在pytorch中模型需要保存下来的参数包括:
- parameter:反向传播需要被 optimizer 更新的,可以被训练。
- buffer:反向传播不需要被 optimizer 更新,不可被训练。
这两种参数都会分别保存到 一个OrderDict 的变量中,最终由 model.state_module() 返回进行保存。
1
nn.Module
的介绍需要先说明下:直接torch.randn(1, 2) 这种定义的变量,没有绑定在pytorch的网络中,训练结束后也就没有在保存在模型中。当我们想要将一些变量保存(如yolov5中的anchor),可以用作简单的后处理,就需要将这种变量注册到网络中,可以使用的api为:
self.register_buffer()
:不可被训练;self.register_parameter()
、nn.parameter.Parameter()
、nn.Parameter()
:可以被训练。
对于pytorch定义网络时,都要继承与nn.Module
。到源码中看到该类的初始化中,成员变量如下,这里我们关心是绿色选中区域,这三个成员都是 OrderedDict() 类型的
成员变量:
_buffers
:由self.register_buffer() 定义,requires_grad默认为False,不可被训练。_parasmeter
:self.register_parameter()、nn.parameter.Parameter()、nn.Parameter() 定义的变量都存放在该属性下,且定义的参数的 requires_grad 默认为 True。_module
:nn.Sequential()、nn.conv() 等定义的网络结构中的结构存放在该属性下。
成员函数:
self.state_dict()
:OrderedDict 类型。保存神经网络的推理参数,包括parameter、bufferself.name_parameters()
:为迭代器。self._module
和self._parameters
中所有的可训练参数的名字+tensor。包括 BN的 bn.weight、bn.bias。self.parameters()
:与self.name_parameters()
一样,但不包含名字self.name_buffers()
:为迭代器。网络中所有的不可训练参数和自己注册的buffer 中的参数的名字+tensor。包括 BN的 bn.running_mean、bn.running_var、bn.num_batches_tracked。self.buffers()
:与self.name_buffers()
一样,但不包含名字net.named_modules()
:为迭代器。self._module
中定义的网络结构的名字+层net.modules()
2 代码示例
import torch import torch.nn as nn class Model(nn.Module): def __init__(self): super(Model, self).__init__() """=======case1: self._modules=======""" self.conv = nn.Conv2d(1, 1, 3, 1, 1) self.TEST_1 = nn.Sequential(OrderedDict([ ('conv', nn.Conv2d(1, 1, 3, bias=False)), ('fc', nn.Linear(1, 2, bias=False)) ])) """=======case2: self._buffers=======""" self.register_buffer('TEST_2', torch.randn(1, 3)) """=======case3: model._parameters=======""" self.register_parameter('TEST_30', nn.Parameter(torch.randn(1, 4))) self.TEST_31 = nn.parameter.Parameter(torch.tensor(1.0)) self.TEST_32 = nn.Parameter(torch.tensor(2.0)) """=======case4=======""" self.TEST_4 = torch.randn(1, 2) def forward(self, x): return x model = Model() print() print(f'=========================================model._modules:\n{model._modules}\n') print(f'=========================================model._buffers:\n{model._buffers}\n') print(f'=========================================model._parameters:\n{model._parameters}\n') print(f'=========================================model.state_dict():\n{model.state_dict()}\n')
其实debug方式查看会更便捷。直接打印也没有问题。
如果要打印介绍的成员函数的内容,则有:named_buffers = [param for param in model.named_buffers()] print(f'===================================named_buffers:\n{named_buffers}\n') named_parameters = [param for param in model.named_parameters()] print(f'===================================named_parameters:\n{named_parameters}\n') named_modules = [param for param in model.named_modules()] print(f'===================================named_modules:\n{named_modules}\n')
文章出处登录后可见!
已经登录?立即刷新