站点图标 AI技术聚合

【pytorch 记录】pytorch的变量parameter、buffer。self.register_buffer()、self.register_parameter()

在pytorch中模型需要保存下来的参数包括:

  • parameter:反向传播需要被 optimizer 更新的,可以被训练。
  • buffer:反向传播不需要被 optimizer 更新,不可被训练。

这两种参数都会分别保存到 一个OrderDict 的变量中,最终由 model.state_module() 返回进行保存。

1 nn.Module的介绍

需要先说明下:直接torch.randn(1, 2) 这种定义的变量,没有绑定在pytorch的网络中,训练结束后也就没有在保存在模型中。当我们想要将一些变量保存(如yolov5中的anchor),可以用作简单的后处理,就需要将这种变量注册到网络中,可以使用的api为:self.register_buffer() :不可被训练;self.register_parameter()nn.parameter.Parameter()nn.Parameter():可以被训练。

对于pytorch定义网络时,都要继承与 nn.Module。到源码中看到该类的初始化中,成员变量如下,这里我们关心是绿色选中区域,这三个成员都是 OrderedDict() 类型的

成员变量:

  • _buffers:由self.register_buffer() 定义,requires_grad默认为False,不可被训练。
  • _parasmeter:self.register_parameter()、nn.parameter.Parameter()、nn.Parameter() 定义的变量都存放在该属性下,且定义的参数的 requires_grad 默认为 True。
  • _module:nn.Sequential()、nn.conv() 等定义的网络结构中的结构存放在该属性下。

成员函数:

  • self.state_dict():OrderedDict 类型。保存神经网络的推理参数,包括parameter、buffer
  • self.name_parameters():为迭代器。self._moduleself._parameters中所有的可训练参数的名字+tensor。包括 BN的 bn.weight、bn.bias。
  • self.parameters():与self.name_parameters()一样,但不包含名字
  • self.name_buffers():为迭代器。网络中所有的不可训练参数和自己注册的buffer 中的参数的名字+tensor。包括 BN的 bn.running_mean、bn.running_var、bn.num_batches_tracked。
  • self.buffers():与self.name_buffers()一样,但不包含名字
  • net.named_modules():为迭代器。self._module中定义的网络结构的名字+层
  • net.modules()

2 代码示例

import torch  
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        """=======case1: self._modules======="""
        self.conv = nn.Conv2d(1, 1, 3, 1, 1)
        self.TEST_1 = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(1, 1, 3, bias=False)),
            ('fc', nn.Linear(1, 2, bias=False))
        ]))

        """=======case2: self._buffers======="""
        self.register_buffer('TEST_2', torch.randn(1, 3))

        """=======case3: model._parameters======="""
        self.register_parameter('TEST_30', nn.Parameter(torch.randn(1, 4)))
        self.TEST_31 = nn.parameter.Parameter(torch.tensor(1.0))
        self.TEST_32 = nn.Parameter(torch.tensor(2.0))

        """=======case4======="""
        self.TEST_4 = torch.randn(1, 2)

    def forward(self, x):
       return x

model = Model() 
print()
print(f'=========================================model._modules:\n{model._modules}\n') 
print(f'=========================================model._buffers:\n{model._buffers}\n') 
print(f'=========================================model._parameters:\n{model._parameters}\n')
print(f'=========================================model.state_dict():\n{model.state_dict()}\n')

其实debug方式查看会更便捷。直接打印也没有问题。


如果要打印介绍的成员函数的内容,则有:

named_buffers = [param for param in model.named_buffers()]
print(f'===================================named_buffers:\n{named_buffers}\n')

named_parameters = [param for param in model.named_parameters()]
print(f'===================================named_parameters:\n{named_parameters}\n')

named_modules = [param for param in model.named_modules()]
print(f'===================================named_modules:\n{named_modules}\n')

文章出处登录后可见!

已经登录?立即刷新
退出移动版