torch.nn.Parameter()函数的讲解和使用

0. 引言

在学习SSD网络的时候发现源码里使用nn.Parameter()这个函数,故对其进行了解。

1. 官方文档

先看一下官方的解释:PyTorch官方文档

1.1 语法

torch.nn.parameter.Parameter(data=None, requires_grad=True)

其中:

  • data (Tensor) – parameter tensor. —— 输入得是一个tensor
  • requires_grad (bool, optional) – if the parameter requires gradient. See Locally disabling gradient computation for more details. Default: True —— 这个不用解释,需要注意的是nn.Parameter()默认有梯度。

1.2 官方解释

A kind of Tensor that is to be considered a module parameter.

Parameters are Tensor subclasses, that have a very special property when used with Modules – when they’re assigned as Module attributes they are automatically added to the list of its parameters, and will appear e.g. in parameters() iterator. Assigning a Tensor doesn’t have such effect. This is because one might want to cache some temporary state, like last hidden state of the RNN, in the model. If there was no such class as Parameter, these temporaries would get registered too.

ParametersTensor 的子类,当与 Modules 一起使用时具有一个非常特殊的属性 – 当它们被分配为 Module attributes 时,它们会自动添加到其参数列表中,并将 出现例如 在 parameters() 迭代器中。 分配张量没有这样的效果。 这是因为人们可能想要在模型中缓存一些临时状态,例如 RNN 的最后一个隐藏状态。 如果没有像 Parameter 这样的类,这些临时对象也会被注册。

2. 通俗的解释

torch.nn.Parameter()将一个不可训练的tensor转换成可以训练的类型parameter,并将这个parameter绑定到这个module里面。即在定义网络时这个tensor就是一个可以训练的参数了。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化。

2.1 例子

拿SGE进行举例:

import numpy as np
import torch
from torch import nn  # 引入torch.nn as nn
from torch.nn import init


class SpatialGroupEnhance(nn.Module):
    def __init__(self, groups):
        super().__init__()
        self.groups=groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        
        # 使用torch.nn.Parameter将不可训练的tensor转换为可训练的tensor并在该类中进行注册
        """
        	可以看到,torch.zeros(1,groups,1,1)它是一个没有梯度的tensor,所以不能参与训练,而
        	self.weight=nn.Parameter(torch.zeros(1,groups,1,1))之后,self.weight就是一个有梯度的tensor,可以参与forward并进行反向传播不断学习
		"""
        self.weight=nn.Parameter(torch.zeros(1,groups,1,1))  # [1, G, 1, 1]
        self.bias=nn.Parameter(torch.zeros(1,groups,1,1))  # [1, G, 1, 1]
        
        self.sig=nn.Sigmoid()
        self.init_weights()
        
    def forward(self, x):
        b, c, h, w = x.shape  # [BS, C, H, W]
        x = x.view(b * self.groups, -1, h, w)  # [BS, C, H, W] -> [BS*G, C//G, H, W]
        xn = x * self.avg_pool(x)  # [BS*G, C//G, H, W] * [BS*G, 1] = [BS*G, C//G, H, W]
        xn = xn.sum(dim=1, keepdim=True)  # [BS*G, C//G, H, W] -> [BS*G, 1, H, W]
        t = xn.view(b * self.groups, -1)  # [BS*G, 1, H, W] -> [BS*G, H*W]

        t = t - t.mean(dim=1, keepdim=True)  # [BS*G, H*W] - [BS*G, 1] = [BS*G, H*W]
        std = t.std(dim=1, keepdim=True) + 1e-5  # [BS*G, 1]
        t = t / std  # [BS*G, H*W] / [BS*G, 1] = [BS*G, H*W]
        t = t.view(b, self.groups, h, w)  # [BS*G, H*W] -> [BS, G, H, W]
        
        """
        	self.weight和self.bias是经过nn.Parameter()注册过后的tensor,是可以学习的参数
		"""
        t = t * self.weight + self.bias  # [BS, G, H, W] * [1, G, 1, 1] + [1, G, 1, 1] = [BS, G, H, W]
        t = t.view(b * self.groups, 1, h, w)  # [BS, G, H, W] -> [BS*G, 1, H, W]
        x = x * self.sig(t)  # [BS*G, 1, H, W] -> [BS*G, 1, H, W]
        x = x.view(b, c, h, w)  # [BS*G, 1, H, W] -> [BS, C, H, W]

        return x 

2.2 其他人的观点

https://blog.csdn.net/weixin_42096202/article/details/97964157
这篇文章中认为:linear里面的weightbias就是parameter类型,且不能够使用tensor类型替换,还有linear里面的weight甚至可能通过指定一个不同于初始化时候的形状进行模型的更改。一般是多维的可训练tensor。

赞同!

2.3 总结

在训练网络的时候,可以使用nn.Parameter()来转换一个固定的权重数值,使其可以在反向传播时进行参数更新,从而学习到一个最适合的权重值。

2.4 初始化网络参数的方法

补一个通用的初始化方法:

    def _initialize_weights(self):

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

参考

  1. https://pytorch.org/docs/stable/index.html
  2. https://blog.csdn.net/weixin_42096202/article/details/97964157

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
青葱年少的头像青葱年少普通用户
上一篇 2023年2月23日 下午1:42
下一篇 2023年2月23日 下午1:43

相关推荐