Learnable scalar weight in PyTorch and guarantee the sum of scalars is 1

I have code like this:

import torch
import torch.nn as nn

class MyModule(nn.Module):
    
    def __init__(self, channel, reduction=16, n_segment=8):
        super(MyModule, self).__init__()
        self.channel = channel
        self.reduction = reduction
        self.n_segment = n_segment
        
        self.conv1 = nn.Conv2d(in_channels=self.channel, out_channels=self.channel//self.reduction, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=self.channel, out_channels=self.channel//self.reduction, kernel_size=1, bias=False)
        self.conv3 = nn.Conv2d(in_channels=self.channel, out_channels=self.channel//self.reduction, kernel_size=1, bias=False)
        #whatever

        # learnable weight
        self.W_1 = nn.Parameter(torch.randn(1), requires_grad=True)
        self.W_2 = nn.Parameter(torch.randn(1), requires_grad=True)
        self.W_3 = nn.Parameter(torch.randn(1), requires_grad=True)

    def forward(self, x):
        
        # whatever
        
        ## branch1                
        bottleneck_1 = self.conv1(x)
        
        ## branch2
        bottleneck_2 = self.conv2(x)
        
        ## branch3                
        bottleneck_3 = self.conv3(x)
        
        ## summation
        output = self.avg_pool(self.W_1*bottleneck_1 + 
                          self.W_2*bottleneck_2 + 
                          self.W_3*bottleneck_3) 
        
        return output

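As you can see, three learnable scalars (W_1, W_2, and W_3) are used for weighting. However, this approach does not guarantee that the scalars sum to 1. How can I make my learnable scalars sum to 1 in PyTorch? Thanks.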

Original link: https://stackoverflow.com//questions/71952139/learnable-scalar-weight-in-pytorch-and-guarantee-the-sum-of-scalars-is-1

Answers

  • Alexey Birukov commented:

    Keep it simple: divide each weight by the sum of all three, so the normalized weights add up to 1:

        ## summation
        WSum = self.W_1 + self.W_2 + self.W_3
        output = self.avg_pool( self.W_1/WSum *bottleneck_1 + 
                                self.W_2/WSum *bottleneck_2 + 
                                self.W_3/WSum *bottleneck_3)
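
The normalization above can be tried end to end with a minimal sketch like the one below. It is not the asker's full module: the class name `SumNormalizedMix` is hypothetical, `avg_pool` is assumed to be a global average pool (the question elides its definition), and the three scalars are stored in one parameter vector initialized to ones rather than with `torch.randn`. Note that dividing by the sum guarantees a total of 1 but not non-negative weights, and the division is unstable if the sum drifts close to zero during training.

```python
import torch
import torch.nn as nn


class SumNormalizedMix(nn.Module):
    """Hypothetical module: mixes three branches with weights divided by their sum."""

    def __init__(self, channel=32, reduction=16):
        super().__init__()
        hidden = channel // reduction
        self.convs = nn.ModuleList(
            [nn.Conv2d(channel, hidden, kernel_size=1, bias=False) for _ in range(3)]
        )
        # Assumption: avg_pool is a global average pool; the question does not define it.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Assumption: start from ones so the initial sum is far from zero.
        self.w = nn.Parameter(torch.ones(3))

    def normalized_weights(self):
        # Divide each weight by the total so the three weights sum to 1.
        return self.w / self.w.sum()

    def forward(self, x):
        w = self.normalized_weights()
        mixed = sum(w[i] * conv(x) for i, conv in enumerate(self.convs))
        return self.avg_pool(mixed)


torch.manual_seed(0)
module = SumNormalizedMix()
out = module(torch.randn(2, 32, 8, 8))
weights = module.normalized_weights()
```

Because the normalization happens inside `forward`, the optimizer updates the raw parameters freely while the weights actually applied to the branches always sum to 1.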
    

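    Alternatively, you can use the distributive law and divide once by WSum (computed as above); this gives the same result as long as avg_pool is a linear operation such as average pooling: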

        output = self.avg_pool(self.W_1*bottleneck_1 + 
                          self.W_2*bottleneck_2 + 
                          self.W_3*bottleneck_3) /WSum
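
A quick numerical check can confirm that the two forms agree. The sketch below uses hypothetical stand-in tensors for the branch outputs and fixed example weights, and assumes `avg_pool` is `nn.AdaptiveAvgPool2d(1)`, which is linear; with a nonlinear pooling layer such as max pooling the two forms would generally differ when the sum is negative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Assumption: avg_pool is a global average pool, which is a linear operation.
avg_pool = nn.AdaptiveAvgPool2d(1)

# Hypothetical stand-ins for the three branch outputs and the scalar weights.
b1, b2, b3 = (torch.randn(2, 4, 8, 8) for _ in range(3))
w1, w2, w3 = torch.tensor(0.5), torch.tensor(1.5), torch.tensor(2.0)
w_sum = w1 + w2 + w3

# Form 1: normalize each weight before mixing.
per_weight = avg_pool(w1 / w_sum * b1 + w2 / w_sum * b2 + w3 / w_sum * b3)
# Form 2: mix with the raw weights, then divide once by the sum.
divide_once = avg_pool(w1 * b1 + w2 * b2 + w3 * b3) / w_sum

# True when both forms match up to floating-point tolerance.
same = torch.allclose(per_weight, divide_once, atol=1e-5)
```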
    
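
A different technique from the answer above, offered only as a possible alternative: keep unconstrained learnable logits and pass them through a softmax. The resulting weights are always positive and sum to 1, so there is no division by a sum that could approach zero. The class name `SoftmaxMix` and the example logits are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftmaxMix(nn.Module):
    """Hypothetical module: weights branches with a softmax over learnable logits."""

    def __init__(self, n_branches=3):
        super().__init__()
        # Unconstrained logits; zeros give equal weights at initialization.
        self.logits = nn.Parameter(torch.zeros(n_branches))

    def weights(self):
        # Softmax maps any real logits to positive weights that sum to 1.
        return F.softmax(self.logits, dim=0)

    def forward(self, branches):
        w = self.weights()
        return sum(w_i * b for w_i, b in zip(w, branches))


torch.manual_seed(0)
mix = SoftmaxMix()
with torch.no_grad():
    # Example logits, including a negative one, to show the weights stay valid.
    mix.logits.copy_(torch.tensor([-2.0, 0.0, 3.0]))

w = mix.weights()
branches = [torch.randn(2, 4, 8, 8) for _ in range(3)]
out = mix(branches)
```

The trade-off is that softmax weights can never be exactly zero or negative, which may or may not match what the weighting is meant to express.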