1.前言
在使用 expand() 函数时,查看官方文档,文档中说返回的是原tensor
的一个view
,不太理解view
的意思。遂查找了解。后续会更新expand()
函数的用法。
2.tensor 的存储方式
2.1.基本知识
tensor
的存储分为两部分(一个tensor
占用两个内存位置)
- 存储真实数据的位置,我们称之为
储藏区域
(Storage) - 一个位置存储 tensor 的 形状( size ),步长( stride ),索引等信息。我们称为
头信息
(Tensor)
假设我们有两个tensor: A, B
。我们用=
号把A
赋值给B
,其实是浅拷贝。也就是说A
和B
共享数据(存储部分),唯一的区别是头部信息
头信息是存储区域的表示,它决定了我们看到真实数据的排列方式。其实这就是tensor
的观点。 tensor.view()
功能是通过改变表头信息以不同的形式显示数据(真实数据没有改变)。我们将在下面提到它。
让我们用代码来说明。函数tensor.storage().data_ptr()
用于获取tensor
的存储区地址。
import torch
a = torch.tensor([1, 2, 3])
b = a
b[0] = 100
print(f'a : {a}')
print(f'b : {b}')
print(f'storage address of a: {a.storage().data_ptr()}')
print(f'storage address of b: {b.storage().data_ptr()}')
>>a : tensor([100, 2, 3])
>>b : tensor([100, 2, 3])
>>storage address of a: 94784411983360
>>storage address of b: 94784411983360
操作:我们用=
将a
赋给b
,然后修改b
。
从结果可以看出a , b
发生了变化。两个内存区域具有相同的地址。这意味着两者共享内存区域。
2.1.1官方文档
其实这里的官方文档里有提到。
也就是说,PyTorch
支持作为现有tensor
View
的tensor
。 (现有的tensor
叫base tensor
,另一个叫view tensor
)。两者共享内存。
- 这种操作可以避免一些数据复制,使得我们能够快速,并且节省内存得进行 reshape ,切片,和一些基于元素的操作
2.2.tensor 的stride() 和 storage_offset() 属性
tensor
为了节省内存,很多操作都在改变头信息区。标题信息区域包含如何组织数据以及从何处开始组织数据。其中两个重要的属性是stride()
和storage_offset()
2.2.1 stride()
在指定维度dim
中从一个元素跳转到下一个元素所需的步长(存储区域中传递的元素个数)
a = torch.randn(3, 2)
print(a.stride())
>>(2, 1)
其实不难理解,在第0
维,想要跳到下一个元素,比如从a[0][0] -> a[1][0]
,需要经过两个元素,步长是 2。在第1
维,想跳到下一个元素,从a[0][0] -> a[0][1]
,需要经过一个元素,步长是 1。
2.2.2 storage_offset()
表示tensor
的第 0 个元素与真实存储区的第 0 个元素的偏移量
a = torch.tensor([1, 2, 3, 4, 5])
b = a[1:]
c = a[3:]
print(b.storage_offset())
print(c.storage_offset())
>>1
>>3
可见,b
的第 0 个元素与a
的第 0 个元素之间的偏移量是 1,c
与a
的偏移量是 3
3.view(),reshape(),resize_()之间的关系
3.1.view()
view
字面意思是景色。所以,就是把数据按照一定的排列方式展示给我们,不改变存储区的真实数据,只改变头信息区。
a = torch.tensor([1, 2, 3, 4, 5, 6])
b = a.view(2, 3)
print(f'a : {a}')
print(f'b : {b}')
print(f'storage address of a: {a.storage().data_ptr()}')
print(f'storage address of b: {b.storage().data_ptr()}')
>>a : tensor([1, 2, 3, 4, 5, 6])
>>b : tensor([[1, 2, 3],
[4, 5, 6]])
>>storage address of a: 93972187307840
>>storage address of b: 93972187307840
可以看出两者共享存储区域
print(f'storage of a: {a.storage()}')
print(f'storage of b: {b.storage()}')
>>[torch.LongStorage of size 6]
storage of a:
1
2
3
4
5
6
>>[torch.LongStorage of size 6]
storage of b:
1
2
3
4
5
6
存储区的数据没有变化
print(f'stride of a : {a.stride()}')
print(f'stride of b : {b.stride()}')
>>(1,)
>>(3, 1)
可以看出stride
变化,即头部信息区变化
3.2.reshape()
3.2.1.tensor的连续性
tensor
的连续性其实就是stride()
属性和size()
的关系
导通条件:
表示从i
维度到下一个元素的步数是从i + 1
维度到下一个维度的步数乘以i + 1
维度的个数。
例如,在一个二维数组中,表示第0
维进入下一个数字,需要完成这一行。
就像上面的例子
a = torch.tensor([1, 2, 3, 4, 5, 6])
b = a.view(2, 3)
对于b
:stride[0] = 3
、stride[1] = 1
、size[1] = 3
。满足以上条件
直观上就是:在存储区的真实数据中,我旁边的数字还在我旁边,这叫连续
一些操作会改变连续性,例如转置
a = torch.tensor([1, 2, 3, 4, 5, 6]).view(2, 3)
b = a.t()
print(f'a : {a}')
print(f'b : {b}')
print(a.stride())
print(b.stride())
>>a : tensor([[1, 2, 3],
[4, 5, 6]])
>>b : tensor([[1, 4],
[2, 5],
[3, 6]])
>>(3, 1)
>>(1, 3)
b
是a
的转置。==在第0
维走到下一个元素,1 -> 2
,步长是 1,因为在存储区中 1 和 2 是相邻的。==这就不满足上面的式子了,因此是不连续的。
再次强调:stride
是需要传入存储区才能到达当前维度下一个元素的元素个数
直观理解:在第0
维,如果数据是连续的,走到下一个元素,应该把这一行走完,步长是 2。现在 4 的邻居是 1 和 2 了,实际上应该是 3 和 5。这就说明数据不连续了
不连续性不能使用view()
法。有没有办法让b
使用view()
?是让它连续(b.contiguous()
)
a = torch.tensor([1, 2, 3, 4, 5, 6]).view(2, 3)
b = a.t()
c = b.contiguous()
print(f'a : {a}')
print(f'b : {b}')
print(f'c : {c}')
print(f'stride of a : {a.stride()}')
print(f'stride of b : {b.stride()}')
print(f'stride of c : {c.stride()}')
print(f'storage address of a: {a.storage().data_ptr()}')
print(f'storage address of b: {b.storage().data_ptr()}')
print(f'storage address of c: {c.storage().data_ptr()}')
>>a : tensor([[1, 2, 3],
[4, 5, 6]])
>>b : tensor([[1, 4],
[2, 5],
[3, 6]])
>>c : tensor([[1, 4],
[2, 5],
[3, 6]])
>>stride of a : (3, 1)
>>stride of b : (1, 3)
>>stride of c : (2, 1)
>>storage address of a: 94382097256256
>>storage address of b: 94382097256256
>>storage address of c: 94382053572096
我们可以看到c
的数据恢复了连续性,其存储区的地址与a, b
的不同。
contiguous()
函数实际上是创建一个新的tensor
。在存储区,将数据存入b
,以便得到c
所以我们可以解释reshape()
和view()
之间的区别
- 当 tensor 满足连续性要求时, reshape() = view() ,和原来 tensor 共用存储区
- 当 tensor 不满足连续性要求时, reshape() = **contiguous() + view() ,会产生新的存储区的 tensor ,与原来 tensor 不共用存储区
3.3.resize_()
前面说到的reshape
和view
都必须要用到全部的原始数据,比如你的原始数据只有12个,无论你怎么变形都必须要用到12个数字,不能多不能少。因此你就不能把只有12个数字的tensor
强行reshap
成2*5
的维度的tensor
。但是resize_()
可以做到,无论你存储区原始有多少个数字,我都能变成你想要的维度,数字不够怎么办?随机产生凑!数字多了怎么办?就取我需要的部分!
3.3.1.数据多的时候
a = torch.tensor([1, 2, 3, 4, 5, 6, 7])
b = a.resize_(2, 3)
print(f'a : {a}')
print(f'b : {b}')
print(f'stride of a : {a.stride()}')
print(f'stride of b : {b.stride()}')
print(f'storage address of a: {a.storage().data_ptr()}')
print(f'storage address of b: {b.storage().data_ptr()}')
>>a : tensor([[1, 2, 3],
[4, 5, 6]])
>>b : tensor([[1, 2, 3],
[4, 5, 6]])
>>stride of a : (3, 1)
>>stride of b : (3, 1)
>>storage address of a: 94579423708416
>>storage address of b: 94579423708416
print(a.storage())
>>
1
2
3
4
5
6
7
可见,取的是前 6 个。
会改变a
,但不改变存储区中的数据,a, b
共享存储区
3.3.2.数据少的时候
a = torch.tensor([1, 2, 3, 4, 5])
b = a.resize_(2, 3)
print(f'a : {a}')
print(f'b : {b}')
print(f'stride of a : {a.stride()}')
print(f'stride of b : {b.stride()}')
print(f'storage address of a: {a.storage().data_ptr()}')
print(f'storage address of b: {b.storage().data_ptr()}')
>>a : tensor([[ 1, 2, 3],
[ 4, 5, 94673159007352]])
>>b : tensor([[ 1, 2, 3],
[ 4, 5, 94673159007352]])
>>stride of a : (3, 1)
>>stride of b : (3, 1)
>>storage address of a: 94673159007296
>>storage address of b: 94673159007296
print(a.storage())
>>
1
2
3
4
5
94673159007352
可以看到添加了一个数字
会改变a
,并且改变存储区的数据,a, b
共享存储区(但已经不是刚才的存储区了,地址变了)
3.3.3.处理不连续数据
a = torch.arange(6).view(2, 3)
b = a.t()
c = b.resize_(3, 2)
print(f'a : {a}')
print(f'b : {b}')
print(f'c : {c}')
print(f'stride of a : {a.stride()}')
print(f'stride of b : {b.stride()}')
print(f'stride of c : {c.stride()}')
print(f'storage address of a: {a.storage().data_ptr()}')
print(f'storage address of b: {b.storage().data_ptr()}')
print(f'storage address of c: {c.storage().data_ptr()}')
>>a : tensor([[0, 1, 2],
[3, 4, 5]])
>>b : tensor([[0, 3],
[1, 4],
[2, 5]])
>>c : tensor([[0, 3],
[1, 4],
[2, 5]])
>>stride of a : (3, 1)
>>stride of b : (1, 3)
>>stride of c : (1, 3)
>>storage address of a: 94375435009664
>>storage address of b: 94375435009664
>>storage address of c: 94375435009664
可以看出,使用resize_()
后,数据仍然保持连续性。并且没有打开新的tensor
,和原来的tensor
共享存储区
print(a.storage())
>>
0
1
2
3
4
5
此外,存储区中的数字也不会改变。
也就是说,resize_()
只改变头部信息,使数据以我们想要的形式呈现,不改变其他信息。
4.总结
最后总结一下view()
、reshape()
、resize_()
的关系和区别。
- view()只能对满足连续性要求的tensor使用。
- 当 tensor 满足连续性要求时, reshape() = view() ,和原来 tensor 共用内存。
- 当 tensor 不满足连续性要求时, reshape() = **contiguous() + view() ,会产生新的存储区的 tensor ,与原来 tensor 不共用存储区。
- resize_()可以随意的获取任意维度的 tensor ,不用在意真实数据的个数限制,但是不推荐使用。
参考:Pytorch——Tensor的储存机制以及view()、reshape()、reszie_()三者的关系和区别 – Circle_Wang – 博客园 (cnblogs.com)
文章出处登录后可见!