pytorch中nn.ModuleList()使用方法

定义ModuleList

我们可以将我们需要的层放入到一个集合中,然后将这个集合作为参数传入nn.ModuleList中,但是这个子类并不可以直接使用,因为这个子类并没有实现forward函数,所以要使用还需要放在继承了nn.Module的模型中进行使用。

model_list = nn.ModuleList([nn.Conv2d(1, 5, 2), nn.Linear(10, 2), nn.Sigmoid()])

x = torch.randn(32, 3, 24, 24)
for model in model_list:
    model_list(x)

使用ModuleList定义网络

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.model_list = nn.ModuleList([nn.Conv2d(1, 5, 2), nn.Linear(10, 2), nn.Sigmoid()])
    
    def forward(self, x):
        return self.model_list(x)

打印网络层结构

model = Net()
print(model)
Net(
  (model_list): ModuleList(
    (0): Conv2d(1, 5, kernel_size=(2, 2), stride=(1, 1))
    (1): Linear(in_features=10, out_features=2, bias=True)
    (2): Sigmoid()
  )
)

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
社会演员多的头像社会演员多普通用户
上一篇 2023年11月7日
下一篇 2023年11月7日

相关推荐