torch.flatten与torch.nn.flatten

torch.nn.flattentorch.nn.flatten是一个类,作用为将连续的几个维度展平成一个tensor(将一些维度合并)参数为合并开始的维度,合并结束的维度(维度就是索引,从 0 开始)开始维度默认为 1。因为其被用在神经网络中,输入为一批数据,第 0 维为batch(输入数据的个数),通常要把一个数据拉成一维,而不是将一批数据拉为一维。所以torch.nn.Flatten()默认从第一维开始平坦化。结束维度默认为 -1,也就是一直合并到最后一维默认参数情况x =

torch.nn.flatten

torch.nn.flatten是一个类,作用为将连续的几个维度展平成一个tensor(将一些维度合并)

  • 参数为合并开始的维度,合并结束的维度(维度就是索引,从 0 开始)

    • 开始维度默认为 1。因为其被用在神经网络中,输入为一批数据,第 0 维为batch(输入数据的个数),通常要把一个数据拉成一维,而不是将一批数据拉为一维。所以torch.nn.Flatten()默认从第一维开始平坦化。
    • 结束维度默认为 -1,也就是一直合并到最后一维
      torch.flatten与torch.nn.flatten
  • 默认参数情况

    x = torch.ones(2, 2, 2, 2)
    
    F = torch.nn.Flatten()
    y = F(x)
    print(y)
    print(y.shape)
    >>tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1., 1., 1., 1.]])
    >>torch.Size([2, 8])
    
  • 输入一个参数情况:该参数为合并开始的维度

    x = torch.ones(2, 2, 2, 2)
    
    F = torch.nn.Flatten(2)
    y = F(x)
    print(y)
    print(y.shape)
    >>tensor([[[1., 1., 1., 1.],
             [1., 1., 1., 1.]],
    
            [[1., 1., 1., 1.],
             [1., 1., 1., 1.]]])
    >>torch.Size([2, 2, 4])
    
  • 输入两个参数情况:第一个参数代表合并开始维度,第二个参数代表合并结束维度(合并范围包含开始维度和结束维度)

    x = torch.ones(2, 2, 2, 2)
    
    F = torch.nn.Flatten(1, 2)
    y = F(x)
    print(y)
    print(y.shape)
    >>tensor([[[1., 1.],
             [1., 1.],
             [1., 1.],
             [1., 1.]],
    
            [[1., 1.],
             [1., 1.],
             [1., 1.],
             [1., 1.]]])
    >>torch.Size([2, 4, 2])
    

torch.flatten

作用与 torch.nn.flatten 类似,都是用于展平 tensor 的,只是 torch.flatten 是 function 而不是类,其默认开始维度为第 0 维

t = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
print(t.shape)
>>torch.Size([2, 2, 2])

print(torch.flatten(t))
>>tensor([1, 2, 3, 4, 5, 6, 7, 8])

print(torch.flatten(t, 1))
>>tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])

print(torch.flatten(t, 0, 1).shape)
>>torch.Size([4, 2])

若输入是 0 维 tensor,则输出的是一维 tensor

t = torch.tensor(1)
print("before flatten:")
print(t)
print(t.shape)
>>before flatten:
  tensor(1)
  torch.Size([])

print("\n")
print("after flatten:")
print(torch.flatten(t))
print(torch.flatten(t).shape)
>>after flatten:
  tensor([1])
  torch.Size([1])

版权声明:本文为博主长命百岁️原创文章,版权归属原作者,如果侵权,请联系我们删除!

原文链接:https://blog.csdn.net/qq_52852138/article/details/122675391

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
乘风的头像乘风管理团队
上一篇 2022年1月24日 下午7:43
下一篇 2022年1月24日 下午9:44

相关推荐