torch.nn.flatten
torch.nn.flatten是一个类,作用为将连续的几个维度展平成一个tensor(将一些维度合并)
-
参数为合并开始的维度,合并结束的维度(维度就是索引,从 0 开始)
- 开始维度默认为 1。因为其被用在神经网络中,输入为一批数据,第 0 维为batch(输入数据的个数),通常要把一个数据拉成一维,而不是将一批数据拉为一维。所以torch.nn.Flatten()默认从第一维开始平坦化。
- 结束维度默认为 -1,也就是一直合并到最后一维
-
默认参数情况
x = torch.ones(2, 2, 2, 2) F = torch.nn.Flatten() y = F(x) print(y) print(y.shape) >>tensor([[1., 1., 1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1., 1., 1.]]) >>torch.Size([2, 8])
-
输入一个参数情况:该参数为合并开始的维度
x = torch.ones(2, 2, 2, 2) F = torch.nn.Flatten(2) y = F(x) print(y) print(y.shape) >>tensor([[[1., 1., 1., 1.], [1., 1., 1., 1.]], [[1., 1., 1., 1.], [1., 1., 1., 1.]]]) >>torch.Size([2, 2, 4])
-
输入两个参数情况:第一个参数代表合并开始维度,第二个参数代表合并结束维度(合并范围包含开始维度和结束维度)
x = torch.ones(2, 2, 2, 2) F = torch.nn.Flatten(1, 2) y = F(x) print(y) print(y.shape) >>tensor([[[1., 1.], [1., 1.], [1., 1.], [1., 1.]], [[1., 1.], [1., 1.], [1., 1.], [1., 1.]]]) >>torch.Size([2, 4, 2])
torch.flatten
作用与 torch.nn.flatten 类似,都是用于展平 tensor 的,只是 torch.flatten 是 function 而不是类,其默认开始维度为第 0 维
t = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
print(t.shape)
>>torch.Size([2, 2, 2])
print(torch.flatten(t))
>>tensor([1, 2, 3, 4, 5, 6, 7, 8])
print(torch.flatten(t, 1))
>>tensor([[1, 2, 3, 4],
[5, 6, 7, 8]])
print(torch.flatten(t, 0, 1).shape)
>>torch.Size([4, 2])
若输入是 0 维 tensor,则输出的是一维 tensor
t = torch.tensor(1)
print("before flatten:")
print(t)
print(t.shape)
>>before flatten:
tensor(1)
torch.Size([])
print("\n")
print("after flatten:")
print(torch.flatten(t))
print(torch.flatten(t).shape)
>>after flatten:
tensor([1])
torch.Size([1])
版权声明:本文为博主长命百岁️原创文章,版权归属原作者,如果侵权,请联系我们删除!
原文链接:https://blog.csdn.net/qq_52852138/article/details/122675391