将 nn.Linear 转换为 nn.Conv1d
pytorch 432
原文标题 :Convert nn.Linear to nn.Conv1d
我想将模型输出到的格式不支持 nn.Linear,所以我想更改它以执行完全相同的操作,但使用 nn.Conv1d。
我的输入是形状(N,A,B),我想要一个线性层,将其转换为形状(N,A,C)的输出。以前,我是用图层nn.Linear(B, C)
做的。我可以通过执行生成具有正确尺寸的工作代码
t1 = t1.transpose(1,2)
conv = nn.Conv1d(
in_channels=B,
out_channels=C,
kernel_size=1
)
t2 = conv(t1)
t2 = t2.transpose(1,2)
这在功能上是否等同于做t2 = nn.Linear(B,C)(t1)
?如果是这样,是否有更好/更简洁的方法?