PyTorch 中的 nn.ModuleList 和 nn.Sequential区别(入门笔记)

在学习到yolov3 spp源码时遇到的一些问题,其中之一是有关nn.ModuleList

部分引用:PyTorch 中的 ModuleList 和 Sequential: 区别和使用场景 – 知乎 (zhihu.com)

nn.ModuleList作用:搭建基础网络的时候,存储不同module即神经网络中的Linear、conv2d、Relu层等,与nn.Sequential有些类似。

ModuleList:顾名思义,专门用于存储module的list。

两者的区别(4点):

不同点1(是否自动前向传播):nn.Sequential内部实现了forward函数,因此可以不用写forward函数,而nn.ModuleList则没有实现内部forward函数。(forward函数即前向传播)

不同点2(能否重命名):nn.Sequential可以使用OrderedDict对每层进行重命名,如图model1和model2显示出来的结构前面由序号(0)(1)(2)(3)变为(conv1)(relu1)….

# Example of using Sequential
model1 = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )
print(model1)
# Sequential(
#   (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
#   (1): ReLU()
#   (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
#   (3): ReLU()
# )

# Example of using Sequential with OrderedDict
import collections
model2 = nn.Sequential(collections.OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))
print(model2)
# Sequential(
#   (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
#   (relu1): ReLU()
#   (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
#   (relu2): ReLU()
# )

不同点3(有无顺序):nn.Sequential里面的模块按照顺序进行排列的,上下部分有关联,所以必须确保前一个模块的输出大小和下一个模块的输入大小是一致的。而nn.ModuleList 并没有定义一个网络,它只是将不同的模块储存在一起,这些模块之间并没有什么先后顺序可言,(1)中out_feature与(2)中in_feature不同,因为没有关联

class net3(nn.Module):
    def __init__(self):
        super(net3, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10,20), nn.Linear(20,30), nn.Linear(5,10)])
    def forward(self, x):
        x = self.linears[2](x)
        x = self.linears[0](x)
        x = self.linears[1](x) 
        return x

net = net3()
print(net)
# net3(
#   (linears): ModuleList(
#     (0): Linear(in_features=10, out_features=20, bias=True)
#     (1): Linear(in_features=20, out_features=30, bias=True)
#     (2): Linear(in_features=5, out_features=10, bias=True)
#   )
# )
input = torch.randn(32, 5)
print(net(input).shape)
# torch.Size([32, 30])

不同点4:有的时候网络中有很多相似或者重复的层,我们一般会考虑用 for 循环来创建它们

layers = [nn.Linear(10, 10) for i in range(5)]

完整的写法:

ModeleList实现方式

class net6(nn.Module):
    def __init__(self):
        super(net6, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(3)])
 
    def forward(self, x):
        for layer in self.linears:
            x = layer(x)
        return x
 
net = net6()
print(net)
# net6(
#   (linears): ModuleList(
#     (0): Linear(in_features=10, out_features=10, bias=True)
#     (1): Linear(in_features=10, out_features=10, bias=True)
#     (2): Linear(in_features=10, out_features=10, bias=True)
#   )
# )

Sequential 实现方式

net7 注意 * 这个操作符,它可以把一个 list 拆开成一个个独立的元素。但是,请注意这个 list 里面的模块必须是按照想要的顺序来进行排列的。

class net7(nn.Module):
    def __init__(self):
        super(net7, self).__init__()
        self.linear_list = [nn.Linear(10, 10) for i in range(3)]
        self.linears = nn.Sequential(*self.linear_list)
 
    def forward(self, x):
        self.x = self.linears(x)
        return x
 
net = net7()
print(net)
# net7(
#   (linears): Sequential(
#     (0): Linear(in_features=10, out_features=10, bias=True)
#     (1): Linear(in_features=10, out_features=10, bias=True)
#     (2): Linear(in_features=10, out_features=10, bias=True)
#   )
# )

如下,更改self. linear设置可变成我们想要的顺序

class net7(nn.Module):
    def __init__(self):
        super(net7, self).__init__()
        self.linear = ([nn.Linear(10,20), nn.Linear(20,30), nn.Linear(5,10)])
        self.linears = nn.Sequential(*self.linear)

    def forward(self, x):
        self.x = self.linears(x)
        return x

net = net7()
print(net)
# net7(
#   (linears): Sequential(
#     (0): Linear(in_features=10, out_features=20, bias=True)
#     (1): Linear(in_features=20, out_features=30, bias=True)
#     (2): Linear(in_features=5, out_features=10, bias=True)
#   )
# )

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
社会演员多的头像社会演员多普通用户
上一篇 2022年5月12日
下一篇 2022年5月12日

相关推荐