站点图标 AI技术聚合

Pytorch-torch.nn.identity()方法详解

Pytorch-torch.nn.identity()方法详解


identity模块不改变输入,直接return input

一种编码技巧吧,比如我们要加深网络,有些层是不改变输入数据的维度的,
在增减网络的过程中我们就可以用identity占个位置,这样网络整体层数永远不变,

应用:
例如此时:如果此时我们使用了se_layer,那么就SELayer(dim),否则就输入什么就输出什么(什么都不做)

m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
input = torch.randn(128, 20)
output = m(input)
print(output.size()) >> torch.Size([128, 20])

文章出处登录后可见!

已经登录?立即刷新
退出移动版