torch.cat()中dim说明

torch.cat()

torch.cat(Tuple[Tensor],dim)->Tensor

输入为Tensor的List/Tuple,输出为一个Tensor

torch.cat()用于对张量的拼接,与数组拼接函数torch.stack()用法类似,二者区别在于输入的变量是数组还是张量。

其中初学者最费解的就是dim的选取,dim的取值范围由输入张量的维度决定,输入为n维张量,dim取值在[0,n-1],接下来我们以实验理解dim不同取值对应的不同操作结果。

初次接触众多博客对dim的讲解为,对于两个二维张量作为输入,dim取0结果为两个张量按行拼接,取1结果为按列拼接,但是对于高维来说就有点难以直观想象结果了,我们尝试三维情况进而总结规律。

 先从一个简单的例子入手,输入两个张量为二维,dim取值分别为0和1 :

 

import torch
X=torch.tensor([[1,2,3],[4,5,6]])
Y=torch.tensor([[7,8,9],[1,4,7]])
input=[X,Y]
A=torch.cat(input,dim=0)
B=torch.cat(input,dim=1)
print("X:{}\nY:{}\ndim0:{}\ndim1:{}".format(X,Y,A,B))

           

结果如下

       torch.cat()中dim说明

 可以看出对于两个二维张量作为输入,dim取0结果为两个张量按行拼接,取1结果为按列拼接,但是对于高维来说就有点难以直观想象结果了,我们尝试三维情况进而总结规律。

import torch
X=torch.tensor([[[1,2],[3,4]],[[5,6],[7,8]]])
Y=torch.tensor([[[7,6],[5,4]],[[8,9],[9,10]]])
input=[X,Y]
A=torch.cat(input,dim=0)
B=torch.cat(input,dim=1)
C=torch.cat(input,dim=2)

print("X:{}\nY:{}\ndim0:{}\ndim1:{}\ndim2:{}".format(X,Y,A,B,C))

                 

输入为两个三维张量:

                          

torch.cat()中dim说明

输出:

                 

torch.cat()中dim说明

可见对于dim=0,其输出结果为对两个张量的最高维度包含的内容进行拼接,此例中,X和Y均为三维张量,其最高维度包含的内容为二维,因此,dim=0结果是对其二维张量进行拼接组成的三维张量:

                  

torch.cat()中dim说明

 那么对于dim=1的情况,就是对次高维包含内容进行拼接,次高维为2维,其内容为1维,将1维进行拼接得到:

         

torch.cat()中dim说明

以此类推,对于dim=n-1的情况比较难理解,此例dim=2,对次次高维即1维的内容进行拼接,其中1维的内容是0维,可以理解为1维张量括号内的元素,即每个数字,将其进行拼接,得到结果:

            

torch.cat()中dim说明

 

 至此,torch.cat()的dim作用已经讲清楚,建议动手实验一下就可以弄明白其中的奥秘!!!

 

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
扎眼的阳光的头像扎眼的阳光普通用户
上一篇 2023年3月16日 下午10:39
下一篇 2023年3月16日 下午10:44

相关推荐