tensor的存储方式 + view() rehape() resize_() 区别

1.前言

在使用 expand() 函数时,查看官方文档,文档中说返回的是原tensor的一个view,不太理解view的意思。遂查找了解。后续会更新expand()函数的用法。

2.tensor 的存储方式

2.1.基本知识

tensor的存储分为两部分(一个tensor占用两个内存位置)

  • 存储真实数据的位置,我们称之为
    储藏区域
    (Storage)
  • 一个位置存储 tensor 的 形状( size ),步长( stride ),索引等信息。我们称为
    头信息
    (Tensor)

假设我们有两个tensor: A, B。我们用=号把A赋值给B,其实是浅拷贝。也就是说AB共享数据(存储部分),唯一的区别是头部信息

头信息是存储区域的表示,它决定了我们看到真实数据的排列方式。其实这就是tensor的观点。 tensor.view()功能是通过改变表头信息以不同的形式显示数据(真实数据没有改变)。我们将在下面提到它。

让我们用代码来说明。函数tensor.storage().data_ptr()用于获取tensor的存储区地址。

import torch
a = torch.tensor([1, 2, 3])
b = a
b[0] = 100
print(f'a : {a}')
print(f'b : {b}')
print(f'storage address of a: {a.storage().data_ptr()}')
print(f'storage address of b: {b.storage().data_ptr()}')
>>a : tensor([100,   2,   3])
>>b : tensor([100,   2,   3])
>>storage address of a: 94784411983360
>>storage address of b: 94784411983360

操作:我们用=a赋给b,然后修改b

从结果可以看出a , b发生了变化。两个内存区域具有相同的地址。这意味着两者共享内存区域。

2.1.1官方文档

其实这里的官方文档里有提到。

tensor的存储方式 + view() rehape() resize_() 区别

也就是说,PyTorch支持作为现有tensorViewtensor。 (现有的tensorbase tensor,另一个叫view tensor)。两者共享内存。

  • 这种操作可以避免一些数据复制,使得我们能够快速,并且节省内存得进行 reshape ,切片,和一些基于元素的操作

2.2.tensor 的stride() 和 storage_offset() 属性

tensor为了节省内存,很多操作都在改变头信息区。标题信息区域包含如何组织数据以及从何处开始组织数据。其中两个重要的属性是stride()storage_offset()

2.2.1 stride()

在指定维度dim中从一个元素跳转到下一个元素所需的步长(存储区域中传递的元素个数)

a = torch.randn(3, 2)
print(a.stride())
>>(2, 1)

tensor的存储方式 + view() rehape() resize_() 区别

其实不难理解,在第0维,想要跳到下一个元素,比如从a[0][0] -> a[1][0],需要经过两个元素,步长是 2。在第1维,想跳到下一个元素,从a[0][0] -> a[0][1],需要经过一个元素,步长是 1。

2.2.2 storage_offset()

表示tensor的第 0 个元素与真实存储区的第 0 个元素的偏移量

a = torch.tensor([1, 2, 3, 4, 5])
b = a[1:]
c = a[3:]
print(b.storage_offset())
print(c.storage_offset())
>>1
>>3

可见,b的第 0 个元素与a的第 0 个元素之间的偏移量是 1,ca的偏移量是 3

3.view(),reshape(),resize_()之间的关系

3.1.view()

view 字面意思是景色。所以,就是把数据按照一定的排列方式展示给我们,不改变存储区的真实数据,只改变头信息区。

a = torch.tensor([1, 2, 3, 4, 5, 6])
b = a.view(2, 3)
print(f'a : {a}')
print(f'b : {b}')
print(f'storage address of a: {a.storage().data_ptr()}')
print(f'storage address of b: {b.storage().data_ptr()}')
>>a : tensor([1, 2, 3, 4, 5, 6])
>>b : tensor([[1, 2, 3],
        [4, 5, 6]])
>>storage address of a: 93972187307840
>>storage address of b: 93972187307840

可以看出两者共享存储区域

print(f'storage  of a: {a.storage()}')
print(f'storage  of b: {b.storage()}')
>>[torch.LongStorage of size 6]
storage  of a:  
 1
 2
 3
 4
 5
 6
>>[torch.LongStorage of size 6]
storage  of b:  
 1
 2
 3
 4
 5
 6

存储区的数据没有变化

print(f'stride of a : {a.stride()}')
print(f'stride of b : {b.stride()}')
>>(1,)
>>(3, 1)

可以看出stride变化,即头部信息区变化

3.2.reshape()

3.2.1.tensor的连续性

tensor的连续性其实就是stride()属性和size()的关系

导通条件:tensor的存储方式 + view() rehape() resize_() 区别

表示从i维度到下一个元素的步数是从i + 1维度到下一个维度的步数乘以i + 1维度的个数。

例如,在一个二维数组中,tensor的存储方式 + view() rehape() resize_() 区别表示第0维进入下一个数字,需要完成这一行。

就像上面的例子

a = torch.tensor([1, 2, 3, 4, 5, 6])
b = a.view(2, 3)

对于bstride[0] = 3stride[1] = 1size[1] = 3。满足以上条件

直观上就是:在存储区的真实数据中,我旁边的数字还在我旁边,这叫连续

一些操作会改变连续性,例如转置

a = torch.tensor([1, 2, 3, 4, 5, 6]).view(2, 3)
b = a.t()
print(f'a : {a}')
print(f'b : {b}')
print(a.stride())
print(b.stride())
>>a : tensor([[1, 2, 3],
        [4, 5, 6]])
>>b : tensor([[1, 4],
        [2, 5],
        [3, 6]])
>>(3, 1)
>>(1, 3)

ba的转置。==在第0维走到下一个元素,1 -> 2,步长是 1,因为在存储区中 1 和 2 是相邻的。==这就不满足上面的式子了,因此是不连续的。

再次强调:stride是需要传入存储区才能到达当前维度下一个元素的元素个数

直观理解:在第0维,如果数据是连续的,走到下一个元素,应该把这一行走完,步长是 2。现在 4 的邻居是 1 和 2 了,实际上应该是 3 和 5。这就说明数据不连续了

不连续性不能使用view()法。有没有办法让b使用view()?是让它连续(b.contiguous()

a = torch.tensor([1, 2, 3, 4, 5, 6]).view(2, 3)
b = a.t()
c = b.contiguous()
print(f'a : {a}')
print(f'b : {b}')
print(f'c : {c}')
print(f'stride of a : {a.stride()}')
print(f'stride of b : {b.stride()}')
print(f'stride of c : {c.stride()}')
print(f'storage address of a: {a.storage().data_ptr()}')
print(f'storage address of b: {b.storage().data_ptr()}')
print(f'storage address of c: {c.storage().data_ptr()}')
>>a : tensor([[1, 2, 3],
        [4, 5, 6]])
>>b : tensor([[1, 4],
        [2, 5],
        [3, 6]])
>>c : tensor([[1, 4],
        [2, 5],
        [3, 6]])
>>stride of a : (3, 1)
>>stride of b : (1, 3)
>>stride of c : (2, 1)
>>storage address of a: 94382097256256
>>storage address of b: 94382097256256
>>storage address of c: 94382053572096

我们可以看到c的数据恢复了连续性,其存储区的地址与a, b的不同。

contiguous()函数实际上是创建一个新的tensor。在存储区,将数据存入b,以便得到c

所以我们可以解释reshape()view()之间的区别

  • 当 tensor 满足连续性要求时, reshape() = view() ,和原来 tensor 共用存储区
  • 当 tensor 不满足连续性要求时, reshape() = **contiguous() + view() ,会产生新的存储区的 tensor ,与原来 tensor 不共用存储区

3.3.resize_()

前面说到的reshapeview都必须要用到全部的原始数据,比如你的原始数据只有12个,无论你怎么变形都必须要用到12个数字,不能多不能少。因此你就不能把只有12个数字的tensor强行reshap2*5的维度的tensor。但是resize_()可以做到,无论你存储区原始有多少个数字,我都能变成你想要的维度,数字不够怎么办?随机产生凑!数字多了怎么办?就取我需要的部分!

3.3.1.数据多的时候

a = torch.tensor([1, 2, 3, 4, 5, 6, 7])
b = a.resize_(2, 3)
print(f'a : {a}')
print(f'b : {b}')
print(f'stride of a : {a.stride()}')
print(f'stride of b : {b.stride()}')
print(f'storage address of a: {a.storage().data_ptr()}')
print(f'storage address of b: {b.storage().data_ptr()}')
>>a : tensor([[1, 2, 3],
        [4, 5, 6]])
>>b : tensor([[1, 2, 3],
        [4, 5, 6]])
>>stride of a : (3, 1)
>>stride of b : (3, 1)
>>storage address of a: 94579423708416
>>storage address of b: 94579423708416

print(a.storage())
>> 
 1
 2
 3
 4
 5
 6
 7

可见,取的是前 6 个。

会改变a,但不改变存储区中的数据,a, b共享存储区

3.3.2.数据少的时候

a = torch.tensor([1, 2, 3, 4, 5])
b = a.resize_(2, 3)
print(f'a : {a}')
print(f'b : {b}')
print(f'stride of a : {a.stride()}')
print(f'stride of b : {b.stride()}')
print(f'storage address of a: {a.storage().data_ptr()}')
print(f'storage address of b: {b.storage().data_ptr()}')
>>a : tensor([[             1,              2,              3],
        [             4,              5, 94673159007352]])
>>b : tensor([[             1,              2,              3],
        [             4,              5, 94673159007352]])
>>stride of a : (3, 1)
>>stride of b : (3, 1)
>>storage address of a: 94673159007296
>>storage address of b: 94673159007296
    
print(a.storage())
>> 
 1
 2
 3
 4
 5
 94673159007352

可以看到添加了一个数字

会改变a,并且改变存储区的数据,a, b共享存储区(但已经不是刚才的存储区了,地址变了)

3.3.3.处理不连续数据

a = torch.arange(6).view(2, 3)
b = a.t()
c = b.resize_(3, 2)
print(f'a : {a}')
print(f'b : {b}')
print(f'c : {c}')
print(f'stride of a : {a.stride()}')
print(f'stride of b : {b.stride()}')
print(f'stride of c : {c.stride()}')
print(f'storage address of a: {a.storage().data_ptr()}')
print(f'storage address of b: {b.storage().data_ptr()}')
print(f'storage address of c: {c.storage().data_ptr()}')
>>a : tensor([[0, 1, 2],
        [3, 4, 5]])
>>b : tensor([[0, 3],
        [1, 4],
        [2, 5]])
>>c : tensor([[0, 3],
        [1, 4],
        [2, 5]])
>>stride of a : (3, 1)
>>stride of b : (1, 3)
>>stride of c : (1, 3)
>>storage address of a: 94375435009664
>>storage address of b: 94375435009664
>>storage address of c: 94375435009664

可以看出,使用resize_()后,数据仍然保持连续性。并且没有打开新的tensor,和原来的tensor共享存储区

print(a.storage())
>> 
 0
 1
 2
 3
 4
 5

此外,存储区中的数字也不会改变。

也就是说,resize_()只改变头部信息,使数据以我们想要的形式呈现,不改变其他信息。

4.总结

最后总结一下view()reshape()resize_()的关系和区别。

  • view()只能对满足连续性要求的tensor使用。
  • 当 tensor 满足连续性要求时, reshape() = view() ,和原来 tensor 共用内存。
  • 当 tensor 不满足连续性要求时, reshape() = **contiguous() + view() ,会产生新的存储区的 tensor ,与原来 tensor 不共用存储区。
  • resize_()可以随意的获取任意维度的 tensor ,不用在意真实数据的个数限制,但是不推荐使用。

参考:Pytorch——Tensor的储存机制以及view()、reshape()、reszie_()三者的关系和区别 – Circle_Wang – 博客园 (cnblogs.com)

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
心中带点小风骚的头像心中带点小风骚普通用户
上一篇 2022年4月13日 下午2:15
下一篇 2022年4月13日 下午2:48

相关推荐