pytorch中的reshape()、view()、nn.flatten()和flatten()

在使用pytorch定义神经网络结构时,经常会看到类似如下的.view() / flatten()用法,这里对其用法做出讲解与演示。

torch.reshape用法

reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用,
其作用是在不改变tensor元素数目的情况下改变tensor的shape。

torch.reshape() 需要两个参数,一个是待被改变的张量tensor,一个是想要改变的形状。

torch.reshape(input, shape) → Tensor
input(Tensor)-要重塑的张量
shape(python的元组:ints)-新形状`

案例1
输入:

import torch

a = torch.tensor([[0,1],[2,3]])
x = torch.reshape(a,(-1,))
print (x)

b = torch.arange(4.)
Y = torch.reshape(a,(2,2))
print(Y)

结果:

tensor([0, 1, 2, 3])
tensor([[0, 1],
[2, 3]])

torch.view用法

view()的原理很简单,其实就是把原先tensor中的数据进行排列,排成一行,然后根据所给的view()中的参数从一行中按顺序选择组成最终的tensor。
view()可以有多个参数,这取决于你想要得到的是几维的tensor,一般设置两个参数,也是神经网络中常用的(一般在全连接之前),代表二维。
view(h,w),h代表行(想要变为几行),当不知道要变为几行,但知道要变为几列时可取-1;w代表的是列(想要变为几列),当不知道要变为几列,但知道要变为几行时可取-1。

一、普通用法(手动调整)

view()相当于reshape、resize,重新调整Tensor的形状。
案例2.
输入

import torch
a1 = torch.arange(0,16)
print(a1)

输出

tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])

输入

a2 = a1.view(8, 2)
a3 = a1.view(2, 8)
a4 = a1.view(4, 4)

print(a2)
print(a3)
print(a4)

输出

tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])

二、特殊用法:参数-1(自动调整size)

view中一个参数定为-1,代表自动调整这个维度上的元素个数,以保证元素的总数不变。
输入

import torch
a1 = torch.arange(0,16)
print(a1)

输出

tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])

输入

a2 = a1.view(-1, 16)
a3 = a1.view(-1, 8)
a4 = a1.view(-1, 4)
a5 = a1.view(-1, 2)
a6 = a1.view(4*4, -1)
a7 = a1.view(1*4, -1)
a8 = a1.view(2*4, -1)

print(a2)
print(a3)
print(a4)
print(a5)
print(a6)
print(a7)
print(a8)

输出

tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3, 4, 5, 6, 7],
[ 8, 9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])
tensor([[ 0],
[ 1],
[ 2],
[ 3],
[ 4],
[ 5],
[ 6],
[ 7],
[ 8],
[ 9],
[10],
[11],
[12],
[13],
[14],
[15]])
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[ 0, 1],
[ 2, 3],
[ 4, 5],
[ 6, 7],
[ 8, 9],
[10, 11],
[12, 13],
[14, 15]])

torch.nn.Flatten(start_dim=1,end_dim=-1)

start_dim与end_dim分别表示开始的维度和终止的维度,默认值为1和-1,其中1表示第一维度,-1表示最后的维度。结合起来看意思就是从第一维度到最后一个维度全部给展平为张量。(注意:数据的维度是从0开始的,也就是存在第0维度,第一维度并不是真正意义上的第一个)。
因为其被用在神经网络中,输入为一批数据,第 0 维为batch(输入数据的个数),通常要把一个数据拉成一维,而不是将一批数据拉为一维。所以torch.nn.Flatten()默认从第一维开始平坦化。

使用nn.Flatten(),使用默认参数

官方给出的示例:

input = torch.randn(32, 1, 5, 5)
# With default parameters
m = nn.Flatten()
output = m(input)
output.size()
#torch.Size([32, 25])
# With non-default parameters
m = nn.Flatten(0, 2)
output = m(input)
output.size()
#torch.Size([160, 5])

#开头的代码是注释
整段代码的意思是:给定一个维度为(32,1,5,5)的随机数据。
1.先使用一次nn.Flatten(),使用默认参数:

m = nn.Flatten()

也就是说从第一维度展平到最后一个维度,数据的维度是从0开始的,第一维度实际上是数据的第二位置代表的维度,也就是样例中的1。
因此进行展平后的结果也就是[32,155]→[32,25]

2.接着再使用一次指定参数的nn.Flatten(),即

m = nn.Flatten(0,2)

也就是说从第0维度展平到第2维度,0~2,对应的也就是前三个维度。
因此结果就是[3215,5]→[160,25]

torch.flatten

torch.flatten()函数经常用于写分类神经网络的时候,经过最后一个卷积层之后,一般会再接一个自适应的池化层,输出一个BCHW的向量。这时候就需要用到torch.flatten()函数将这个向量拉平成一个Bx的向量(其中,x = CHW),然后送入到FC层中。

语句结构

 torch.flatten(input, start_dim=0, end_dim=-1)

input: 一个 tensor,即要被“摊平”的 tensor。
start_dim: “摊平”的起始维度。
end_dim: “摊平”的结束维度。
作用与 torch.nn.flatten 类似,都是用于展平 tensor 的,只是 torch.flatten 是 function 而不是类,其默认开始维度为第 0 维。例1:

import torch

data_pool = torch.randn(2,2,3,3) # 模拟经过最后一个池化层或自适应池化层之后的输出,Batchsize*c*h*w
print(data_pool)

y=torch.flatten(data_pool,1)
print(y)

输出结果:

结果是一个B*x的向量。

本文源于多篇资料的提炼汇总,部分参考资料如下。
参考资料:参考1;参考2;参考3;参考4

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
扎眼的阳光的头像扎眼的阳光普通用户
上一篇 2023年11月14日
下一篇 2023年11月14日

相关推荐