forward() takes 2 positional arguments but 3 were given

问题描述:

在forward中明明正确数量的参数,却报错:forward() takes 2 positional arguments but 3 were given;

问题分析:

使用nn.Sequential()定义的网络,只接受单输入

例如:

self.backbone=nn.Sequential(nn.lstm(input_size=20, hidden_size=40, num_layers=2),

                                    nn.linear(in_features=40, out_features=2))

def forward(self, input):

        h0 = torch.randn(hidden_layers, batch_size, hidden)

        c0 = torch.randn(hidden_layers, batch_size, hidden)
        output, _ = self.backbone(input) 
(对)

         output, _ = self.backbone(input, (h0, c0)   (错误,因为nn.Sequential()定义的网络,只接受单输入

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
扎眼的阳光的头像扎眼的阳光普通用户
上一篇 2023年8月4日
下一篇 2023年8月4日

相关推荐