pytorch:深入理解 reshape(), view(), transpose(), permute() 函数

前言

view() 函数是进行张量维度重构的函数,permute() 和 transpose() 是进行张量维度转换的函数,高阶张量由若干低阶张量构成,如结构为 (n, c, h, w)的 4 阶张量由 n 个结构为 (c, h, w) 的 3 阶张量构成,结构为 (c, h, w)的 3 阶张量由 c 个结构为 (h, w) 的 2 阶张量构成,结构为 (h, w)的 2 阶张量又由 h 个长度为 w 的 1 阶张量构成,h 为行数,w 为列数。

1. reshape()

reshape() 函数与 view() 函数都是进行维度重组的函数,使用方法类似,区别在于 view() 函数只能对张量进行操作,而 reshape() 函数既可以对张量进行操作,还可以对 numpy 数组进行操作,代码示例如下,具体原理见 view() 函数。

x = np.array([1, 2, 3, 4, 5, 6])  # 一个大小为 6 的一维 numpy 数组
y = torch.Tensor([1, 2, 3, 4, 5, 6])  # 一个大小为 6 的一阶张量
print(x.reshape(2, 3))  # 重组 x 为结构为 (2, 3) 的数组
print(y.reshape(2, 3))  # 重组 y 为结构为 (2, 3) 的张量

在这里插入图片描述

2. view()

① 1 阶变高阶

1 阶变 2 阶

对于一个 1 阶张量 x,进行 view(h, w) 操作就是按照索引先后顺序每次从 x 中取出 w 个元素作为作为一行数据,共取 h 次,构成一个 (h, w) 结构的 2 阶张量,具体见示例。

x = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8])  # 一个含有 8 个元素的 1 阶张量
print(x.view(4, 2))  # 返回一个 (4, 2) 结构的 2 阶张量

在这里插入图片描述

1 阶变 3 阶

对于一个 1 阶张量 x,进行 view(c, h, w) 操作就是按照索引先后顺序每次从 x 中取出 h*w 个元素,对这 h*w 个元素按照 1 阶张量转 2 阶数张量的方法转为一个 (h, w) 结构的 2 阶张量,共取 c 次,构成一个 (c, h, w) 结构的 3 阶张量,具体见示例。

x = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])  # 一个含有 12 个元素的 1 阶张量
print(x.view(3, 2, 2))  # 返回一个 (3, 2, 2) 结构的 3 阶张量

在这里插入图片描述

1 阶变 4 阶

对于一个 1 阶张量 x,进行 view(n, c, h, w) 操作就是按照索引先后顺序每次从 x 中取出 c*h*w 个元素,对这 c*h*w 个元素按照 1 阶张量转 3 阶张量的方法转为一个 (c, h, w) 结构的 3 阶张量,共取 n 次,最终构成一个 (n, c, h, w) 结构的 4 阶张量,具体见示例。

#  # 一个含有 24 个元素的 1 阶张量
x = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
                  13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24])
print(x.view(2, 2, 2, 3))  # 返回一个 (2, 2, 2, 3) 结构的 4 阶张量

在这里插入图片描述

1 阶变 m 阶

对于一个 1 阶张量 x,进行 view(pytorch:深入理解 reshape(), view(), transpose(), permute() 函数, pytorch:深入理解 reshape(), view(), transpose(), permute() 函数, ···, pytorch:深入理解 reshape(), view(), transpose(), permute() 函数, pytorch:深入理解 reshape(), view(), transpose(), permute() 函数) 操作就是按照索引先后顺序每次从 x 中取出 pytorch:深入理解 reshape(), view(), transpose(), permute() 函数*pytorch:深入理解 reshape(), view(), transpose(), permute() 函数*···*pytorch:深入理解 reshape(), view(), transpose(), permute() 函数*pytorch:深入理解 reshape(), view(), transpose(), permute() 函数 个元素,对这 pytorch:深入理解 reshape(), view(), transpose(), permute() 函数*pytorch:深入理解 reshape(), view(), transpose(), permute() 函数*···*pytorch:深入理解 reshape(), view(), transpose(), permute() 函数*pytorch:深入理解 reshape(), view(), transpose(), permute() 函数 个元素按照 1 阶张量转 m-1 阶张量的方法转为一个 (pytorch:深入理解 reshape(), view(), transpose(), permute() 函数, ···, pytorch:深入理解 reshape(), view(), transpose(), permute() 函数, pytorch:深入理解 reshape(), view(), transpose(), permute() 函数) 结构的 m-1 阶张量,共取 m 次,最终构成一个 (pytorch:深入理解 reshape(), view(), transpose(), permute() 函数, pytorch:深入理解 reshape(), view(), transpose(), permute() 函数, ···, pytorch:深入理解 reshape(), view(), transpose(), permute() 函数, pytorch:深入理解 reshape(), view(), transpose(), permute() 函数) 结构的 m 阶张量,其中 pytorch:深入理解 reshape(), view(), transpose(), permute() 函数 代表张量第 n 个索引的值。

② 2 阶变 m 阶

对于一个 2 阶张量 x,结构为 (h, w),要变成一个 m 阶的新张量,首先将该 2 阶张量按行展开成一个大小为 h*w 的 1 阶张量,再按照 1 阶变 m 阶的方法变为一个 m 阶张量,按行展开就是在 w 索引方向上进行拼接,2 阶张量变 3 阶张量的代码示例见下,用一个 1 阶张量来验证分析。

x = torch.Tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9],
                  [10, 11, 12]])  # 一个 (4, 3) 结构的 2 阶张量
y = torch.Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])  # 一个含有 12 个元素的一阶张量
print(x.view(2, 2, 3))  # 返回一个 (2, 2, 3) 结构的 3 阶张量
print(y.view(2, 2, 3))  # 返回一个 (2, 2, 3) 结构的 3 阶张量

在这里插入图片描述

③ 3 阶变 m 阶

对于一个 3 阶张量 x,结构为 (c, h, w),要变成一个 m 阶的新张量,首先将该 3 阶张量按行拼接得到一个结构为 (c*h, w) 的 2 阶张量,再按照 2 阶变 1 阶的方法转变为一个 1 阶张量,按行拼接就是在 h 索引方向上进行拼接,示例见图 1.1 和图 1.2。

在这里插入图片描述

x = torch.Tensor([[[1, 2, 3],
                   [4, 5, 6]],

                  [[7, 8, 9],
                   [10, 11, 12]],


                  [[13, 14, 15],
                   [16, 17, 18]],

                  [[19, 20, 21],
                   [22, 23, 24]]])  # 一个 (4, 2, 3) 结构的 3 阶张量
y = torch.Tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9],
                  [10, 11, 12],
                  [13, 14, 15],
                  [16, 17, 18],
                  [19, 20, 21],
                  [22, 23, 24]])  # 一个 (4*2, 3) 结构的 2 阶张量
print((y.view(2, 2, 2, 3)).equal(x.view(2, 2, 2, 3)))  # 两个张量转变后的结果是否相等
print(x.view(2, 2, 2, 3))  # 返回一个 (2, 2, 2, 3) 结构的 4 阶张量

在这里插入图片描述

④ 4 阶变 m 阶

对于一个 4 阶张量 x,结构为 (n, c, h, w),要变成一个 m 阶的新张量,首先将该 m 阶张量在 c 索引方向进行拼接得到一个结构为 (n*c, h, w) 的 3 阶张量,再按照 3 阶张量变 1 阶张量的方法转变为一个 1 阶张量,最后再按照 1阶变 m 阶的方法得到 m 阶张量,4 阶张量在 c 索引方向进行拼接的示意图如图 2.1和图 2.2 所示。

在这里插入图片描述

x = torch.Tensor([[[[1, 2, 3],
                    [4, 5, 6]],

                   [[7, 8, 9],
                    [10, 11, 12]]],


                  [[[13, 14, 15],
                    [16, 17, 18]],

                   [[19, 20, 21],
                    [22, 23, 24]]]])  # 一个 (2, 2, 2, 3) 结构的 4 阶张量
y = torch.Tensor([[[1, 2, 3],
                   [4, 5, 6]],

                   [[7, 8, 9],
                    [10, 11, 12]],

                   [[13, 14, 15],
                    [16, 17, 18]],

                   [[19, 20, 21],
                    [22, 23, 24]]])  # 一个 (2*2, 2, 3) 结构的 3 阶张量

print(f'x.size() = {x.size()}')  # x(2, 2, 2, 3)
print(x.view(4, 6))  # 返回结构为 (4, 6) 的 2 阶张量
print(f'y.size() = {y.size()}')  # y(4, 2, 3)
print(y.view(4, 6))  # 返回结构为 (4, 6) 的 2 阶张量

在这里插入图片描述

3. transpose()

transpose() 函数一次进行两个维度的交换,参数是 0, 1, 2, 3, … ,随着待转换张量的阶数上升参数越来越多。

② 2 阶张量

对于一个 2 阶张量,结构为 (h, w),对应 transpose() 函数中的参数是 (0, 1) 两个索引,进行 transpose(0, 1) 操作就是在交换 h, w 两个维度,得到的结果与常见的矩阵转置相同,具体代码示例见下。

x = torch.Tensor([[1, 2],
                  [3, 4],
                  [5, 6]])  # 一个结构为 (3, 2) 的 2 阶张量
print(f'x.size() = {x.size()}')  # 返回张量 x 的结构
y = x.transpose(0, 1)  # 交换 h, w 两个维度
# y = x.t()  # 对 x 进行转置
print(f'y.size() = {y.size()}')  # 返回张量 y 的结构
print(y)  # 打印交换维度后的张量 y,结构为 (2, 3)

在这里插入图片描述

③ 3 阶张量

对于一个 3 阶张量,结构为 (c, h, w),对应 transpose() 函数中的参数是 (0, 1, 2) 3 个索引,进行 transpose(0, 1) 操作就是在交换 c, h 两个维度,交换 c, h 两个维度的示意图见图 3.1 和图 3.2,其他维度的交换方式同理,实在不明白可以拿几本书放一起比划一下。

在这里插入图片描述

x = torch.Tensor([[[1, 2, 3], [4, 5, 6]],
                  [[7, 8, 9], [10, 11, 12]],
                  [[13, 14, 15], [16, 17, 18]],
                  [[19, 20, 21], [22, 23, 24]]])  # 一个结构为 (4, 2, 3) 的 3 阶张量
print(f'x.size() = {x.size()}')  # 返回张量 x 的结构
print(x.transpose(0, 1))  # 交换张量的 c, h 维度, 结构为 (2, 4, 3)

在这里插入图片描述

④ 4 阶张量

对于一个 4 阶张量,结构为 (n,c, h, w),对应 transpose() 函数中的参数是 (0, 1, 2,3) 4 个索引,对应 transpose() 的操作相对复杂一些,为方便理解这里具体分为 transpose(0, 1),和 transpose(0, 3),和 transpose(1, 2) 三种,具体原因见以下分析。

3.4.1 transpose(0, 1) 操作就是交换 n, c 两个维度,交换 n, c 两个维度的示意图见图 4.1 和图 4.2,实在不明白可以拿几本书比划一下。

在这里插入图片描述

4 阶张量交换 n, c 维度的代码示例见下,其他维度交换同理,不难发现对 4 阶张量而言进行 transpose(0, 1) 操作就是 n 索引方向上进行通道重新分组,如下代码中原张量 n 索引方向上有 2 组,每组有 3 个通道,交换 n, c 维度后变为 3 组,每组有 2 个通道。

x = torch.Tensor([[[[1, 2], [3, 4]],
                   [[5, 6], [7, 8]],
                   [[9, 10], [11, 12]]],

                  [[[13, 14], [15, 16]],
                   [[17, 18], [19, 20]],
                   [[21, 22], [23, 24]]]])  # 一个结构为 (2, 3, 2, 2) 的 4 阶张量
print(f'x.size() = {x.size()}')  # 返回张量 x 的结构
print(x.transpose(0, 1))  # 返回交换 n, c 维度后的张量,结构为 (3, 2, 2, 2)

在这里插入图片描述

对于一个结构为 (2, 2, 2, 3) 的 4 阶张量 x,进行 transpose(0, 3) 操作即将原 4 阶张量变成一个结构为 (3, 2, 2, 2) 的新 4 阶张量,可以理解为在保证原 4 阶张量中元素 c , h 索引不变的情况下的将每一个元素的 n, w 进行交换,类似于坐标系变换。

x = torch.Tensor([[[[1, 2, 3], [4, 5, 6]],
                   [[7, 8, 9], [10, 11, 12]]],

                  [[[13, 14, 15], [16, 17, 18]],
                   [[19, 20, 21], [22, 23, 24]]]])  # 结构为 (2, 2, 2, 3) 的 4 阶张量
print(f'x.size() = {x.size()}')  # 返回张量 x 的结构
y = x.transpose(0, 3)  # 交换 n, w 维度
print(f'y.size() = {y.size()}')  # 返回张量 y 的结构
print(y)

在这里插入图片描述

4. permute()

permute() 函数一次可以进行多个维度的交换或者可以成为维度重新排列,参数是 0, 1, 2, 3, … ,随着待转换张量的阶数上升参数越来越多,本质上可以理解为多个 transpose() 操作的叠加,因此理解 permute() 函数的关键在于理解 transpose() 函数,代码示例如下。

x = torch.Tensor([[[1, 2, 3, 4],
                   [5, 6, 7, 8],
                   [9, 10, 11, 12]],

                  [[13, 14, 15, 16],
                   [17, 18, 19, 20],
                   [21, 22, 23, 24]]])  # 一个结构为 (2, 3, 4) 的 3 阶张量
print(f'x.size() = {x.size()}')  # 返回张量 x 的结构
y = x.permute(2, 0, 1)  # 对张量 x 进行维度重排
z = x.transpose(0, 1).transpose(0, 2)  # 对张量 x 连续交换两次维度
print(y.equal(z))  # 判断张量 y 和张量 z 是否相同
print(f'z.size() = {z.size()}')  # 返回张量 z 的结构
print(z)

在这里插入图片描述

结语

通过以上分析可以得出结论,reshpe()view() 两个函数满足条件时可以根据需要设置维度,而 transpose()permute() 两个函数只能在已有的维度之间进行变换,另外 transpose() 函数在 pytorch 和 numpy 中略有不同,numpy 中的 transpose() 函数相当于 pytorch 中的 permute() 函数。

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
社会演员多的头像社会演员多普通用户
上一篇 2023年2月25日 下午10:33
下一篇 2023年2月25日

相关推荐