图像分类(二)CBAM —— Spatial Attention空间注意力及Resnet_cbam实现

Spatial Attention空间注意力及Resnet_cbam实现

前言

上一次介绍Renest时,介绍了CNN里的通道注意力Channel-Wise的Split Attention及其block实现

这一次介绍一下另外一种注意力,空间注意力和CBAM结构。

下面是我实现resnet中加入CBAM结构的代码,可以给大家学习一下:Resnet_CBAM 

一、Attention表达改进

这里插播一下CBAM中对于注意力的改进。下图为我之前介绍过的通道注意力的具体实现形式,也即通过AvgPool来汇聚每一个通道中的信息,然后将每一个通道中的信息经过两个全连接层来提取通道间的相关关系,来构造注意力信息(全连接有激活函数来构造非线性,图未画出)

图像分类(二)CBAM —— Spatial Attention空间注意力及Resnet_cbam实现

而在CBAM中,作者认为单单使用AvgPool来提取每个通道中的信息,并不能很好的去表示这个通道的信息,于是又加入了MaxPool来补充这个通道的信息。所以改进后整体的结构图如下 图像分类(二)CBAM —— Spatial Attention空间注意力及Resnet_cbam实现

注意,这里经过AvgPool和MaxPool的两个表达通道特征的信息是共用全连接层的参数的。最后的输出,则为两个计算出来的注意力信息相加。

二、Spatial Attention空间注意力

其实我们从通道注意力,其实就可以推断出空间注意力具体的实现。因为通道注意力是通过压缩每个通道中wxh的信息来构造的,那么空间注意力就是去压缩通道,来构造在空间维度上的信息

具体的实现即如下图所示,在通道维度上进行最大值和平均值的汇聚,CxWxH的feature map压缩成1xWxH的信息,然后通过带有注意力权重的卷积来提取注意力信息,最后,如果是单分支结构,通过sigmoid来使注意力权重非负,如果是多分支结构则应用softmax来使注意力权重非负。

               图像分类(二)CBAM —— Spatial Attention空间注意力及Resnet_cbam实现 

如果我们和通道注意力相比,通道注意力中的注意力信息,是筛选出哪些通道的信息是和目前这个认为是相关的。空间注意力则是去关心对于Feature map来说哪些位置的信息是和目前认为相关的。

所以当我们去提取空间注意力的权重,然后和原图进行相乘操作,会发现空间注意力可以highlight到我们需要识别的物体在这幅图片的位置。 

下面给出改进后的通道注意力和空间注意力的代码,对照着看,理解会更加容易。

class ChannelAttention(nn.Module):
    def __init__(self, inplanes):
        super(ChannelAttention, self).__init__()
        self.max_pool = nn.MaxPool2d(1)
        self.avg_pool = nn.AvgPool2d(1)
        # 通道注意力,即两个全连接层连接
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels=inplanes, out_channels=inplanes // 16, kernel_size=1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_channels=inplanes // 16, out_channels=inplanes, kernel_size=1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        max_out = self.fc(self.max_pool(x))
        avg_out = self.fc(self.avg_pool(x))
        # 最后输出的注意力应该为非负
        out = self.sigmoid(max_out + avg_out)
        return out


class SpatialAttention(nn.Module):
    def __init__(self):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, padding=7 // 2, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        # 压缩通道提取空间信息
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        avg_out = torch.mean(x, dim=1, keepdim=True)
        # 经过卷积提取空间注意力权重
        x = torch.cat([max_out, avg_out], dim=1)
        out = self.conv1(x)
        # 输出非负
        out = self.sigmoid(out)
        return out

 

三、Resnet_CBAM

 CBAM结构其实就是将通道注意力信息核空间注意力信息在一个block结构中进行运用。具体实现方法,以在resnet中添加CBAM结构为例。

图像分类(二)CBAM —— Spatial Attention空间注意力及Resnet_cbam实现

其中的channel attention和Spatial attention即我上文所实现的方法,所以我们想要在resnet中实现这个方法其实也十分简单,即在原始block和残差结构连接前,依次通过channel attention和spatial attention即可。

class Bottleneck(nn.Module):

 
    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ):
        super(Bottleneck, self).__init__()
        expansion: int = 4    
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

        # attention
        self.spatial_atten = SpatialAttention()
        self.channel_atten = ChannelAttention(planes * self.expansion)

   def forward(self, x: Tensor):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # attention
        atten = self.channel_atten(out) * out
        atten = self.spatial_atten(atten) * atten

        if self.downsample is not None:
            identity = self.downsample(x)

        atten += identity
        out = self.relu(atten)

        return out

总结

如果对我上一篇讲解Resnest的文章中注意力的讲解比较熟悉的话,对应CBAM这个结构是比较容易理解和接受的。

但是网络的具体性能如何,大家可以用我的代码测试一下,个人在实验和认知上都觉得CBAM的结构其实还有很多需要优化的地方。

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
心中带点小风骚的头像心中带点小风骚普通用户
上一篇 2023年3月11日
下一篇 2023年3月11日

相关推荐