pytorch入门篇2 玩转tensor(查看、提取、变换)

上一篇博客讲述了如何根据自己的实际需要在pytorch中创建tensor:pytorch入门篇1——创建tensor,这一篇主要来探讨关于tensor的基本数据变换,是pytorch处理数据的基本方法。

1 tensor数据查看与提取

tensor数据的查看与提取主要是通过索引和切片实现的。主要方法有:

################################################
#1、直接索引
a = torch.rand(4,1,16,16)
b = a[0]
print(b.size())
"""
输出结果:
torch.Size([1, 16, 16])
"""
################################################
#2、选定索引的维度范围
c = a[1:3,:,:,:]
print(c.size())
"""
输出结果:
torch.Size([2, 1, 16, 16])
"""
################################################
#3、按步长索引
c = a[0:3:2,:,:,:]
print(c.size())
"""
输出结果:
torch.Size([2, 1, 16, 16])
"""
################################################
#4、按mask筛选索引
d = torch.rand(2,3)
mask = torch.BoolTensor([[0,1,0],
                        [1,0,1]])
res = torch.masked_select(d,mask)
print(d)
print(res)
"""
输出结果:
tensor([[0.2507, 0.8419, 0.6681],
        [0.0940, 0.8476, 0.5883]])
tensor([0.8419, 0.0940, 0.5883])
"""

2 tensor数据变换

本部分主要介绍的是将已有的tensor数据改变shape的方法,这是pytorch变换数据的基本操作。

2.1 重置tensor形状:pytorch.view()

torch.view()主要实现重新定义tensor的shape,对tensor中的元素进行重排列,在view过程中tensor的总元素个数保持不变该方法的缺点是会丢失维度信息

#############################################
a = torch.rand(4,4,16,16)
b = a.view(4,32,32)
print(a.shape)
print(b.shape)
"""
输出结果:
torch.Size([4, 4, 16, 16])
torch.Size([4, 32, 32])
#注:在view前后必须保证tensor的总元素个数不变
#   在本例中:4*4*16*16 = 4*32*32
"""

2.2 增加/减少tensor维度:torch.unsqueeze()/torch.squeeze()

#############################################
a = torch.rand(4,4,16,16)
b = a.unsqueeze(0)
c = torch.rand(1,4,1,16,16).squeeze()
print(a.shape)
print(b.shape)
print(c.shape)
"""
输出结果:
torch.Size([4, 4, 16, 16])
torch.Size([1, 4, 4, 16, 16])
torch.Size([4, 16, 16])
"""

2.3 tensor扩充:torch.expand()/torch.repeat()

#############################################
a = torch.rand(4,1,16,16)
#a中只有a[1]是1维度,这里要将它扩充至4,因此expand中第二个参数为4,其它维度不是1,所以要与a保持一致(否则会报错)
b = a.expand(4,4,16,16)
#repeat中的参数都是,代表这要将a中对应维度的元素都复制两次
c = a.repeat(2,2,2,2)
print(a.shape)
print(b.shape)
print(c.shape)
"""
输出结果:
#这是原来a变量的shape
torch.Size([4, 1, 16, 16])
#这是经过expand后的a变量(b)的shape
#可以看出,expand成功将原理index=1位置处的1维度扩充成了对应的4维度
torch.Size([4, 4, 16, 16])
#这是经过repeat后的a变量的shape
#可以看出,repeat将每个维度都复制了2次,即:(4*2,1*2,16*2,16*2)
torch.Size([8, 2, 32, 32])
"""

2.4 tensor维度交换/重新排序:torch.transpose()/torch.permute()

#############################################
a = torch.rand(4,1,12,16)
b = a.transpose(0,2)        #第0个维度和第2个维度互换,即:4和12互换
c = a.permute(3,2,0,1)  #将原有的3, 2, 0,1维度作为新tensor的0,1,2,3维度
print(a.shape)
print(b.shape)
print(c.shape)
"""
输出结果:
torch.Size([4, 1, 12, 16])  #a.shape
torch.Size([12, 1, 4, 16])  #b.shape  (4,1,12,16)-互换4,12->(12,1,4,16) 
torch.Size([16, 12, 4, 1])  #c.shape  重新对a进行排列(a[3],a[2],a[0],a[1])->(16,12,4,1)
"""

3 总结

tensor变换的核心理念是为了更好的服务于高维向量运算,这部分的变换技巧和相关的方法有很多。关于这部分的内容我的建议是:不需要花费大量的时间去刻意记忆,只要知道有这些方法,在编程需要的时候能想起来用这些方法可以解决即可

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
青葱年少的头像青葱年少普通用户
上一篇 2023年3月11日
下一篇 2023年3月11日

相关推荐