【论文阅读】LSKNet: Large Selective Kernel Network for Remote Sensing Object Detection

这是南开大学在ICCV2023会议上新提出的旋转目标检测算法,基本原理就是通过一系列Depth-wise 卷积核和空间选择机制来动态调整目标的感受野,从而允许模型适应不同背景的目标检测。

论文地址:https://arxiv.org/pdf/2303.09030.pdf

代码地址(可以直接使用mmrotate框架实现):GitHub – zcablii/LSKNet: (ICCV 2023) Large Selective Kernel Network for Remote Sensing Object Dyetection

 一、引言

目前基于旋转框的遥感影像目标检测算法已经取得了一定的进展,但是很少考虑存在于遥感影像中的先验知识。遥感影像中的目标往往尺寸很小,仅仅基于其表观特征很难识别,如果结合其背景信息,如周边环境,就可以提供形状、方向等有意义的信息。据此,作者分析了两条重要的先验知识:

  •  精确识别遥感影像中的目标往往需要大范围的背景信息,有限的背景区域会影响模型的识别效果,例如当背景信息很少时,容易将十字路口识别为道路。
  • 不同类型的目标所需要的背景信息范围是不同的,如足球场可通明显的球场边界线进行区分,所需的背景信息不多,但是十字路口与道路相似,容易受到树木和其他遮挡物的影响,因此需要足够的背景范围信息才能进行识别。
大范围背景下更容易精确识别地物类型
识别不同目标所需的背景信息差别较大

为了解决上述问题,作者提出了一种新的遥感影像目标识别方法,即Large Selective Kernel Network (LSKNet)。该方法通过在特征提取模块动态调整感受野,更有效地处理了不同目标所需的背景信息差异。其中,动态感受野由一个空间选择机制实现,该机制对一大串Depth-wise 卷积核所处理的特征进行有效加权和空间融合。这些卷积核的权重根据输入动态确定,同时允许模型针对空间上的不同目标自适应地选择不同大小的核并调整感受野。

经验证,LSKNet网络虽然结构简单,但能够获得优异的检测性能,在HRSC2016、DOTA-v1.0、FAIR1M-v1.0三个典型数据集上都取得了SOTA。

二、算法原理

1. LSKNet的架构

结构层级依次为:

LSK module(大核卷积序列+空间选择机制) < LSK Block (LK Selection + FFN)<LSKNet(N个LSK Block)

LSK 模块
LSK Block

LSKNet 是主干网络中的一个可重复堆叠的块(Block),每个LSK Block包括两个残差子块,即大核选择子块(Large Kernel Selection,LK Selection)和前馈网络子块(Feed-forward Network ,FFN),如图8。LK Selection子块根据需要动态地调整网络的感受野,FFN子块用于通道混合和特征细化,由一个全连接层、一个深度卷积、一个 GELU 激活和第二个全连接层组成。

LSK module(LSK 模块,图4)由一个大核卷积序列(large kernel convolutions)和一个空间核选择机制(spatial kernel selection mechanism)组成,被嵌入到了LSK Block 的 LK Selection子块中(图8橙色块)。

2. Large Kernel Convolutions

因为不同类型的目标对背景信息的需求不同,这就需要模型能够自适应选择不同大小的背景范围。因此,作者通过解耦出一系列具有大卷积核、且不断扩张的Depth-wise 卷积,构建了一个更大感受野的网络。

具体地,假设序列中第i个Depth-wise 卷积核的大小为 k,扩张率为 d,感受野为 RF,它们满足以下关系:

卷积核大小和扩张率的增加保证了感受野能够快速增大。此外,我们设置了扩张率的上限,以保证扩张卷积不会引入特征图之间的差距。

Table2的卷积核大小可根据公式(1)和(2)计算,详见下图:

这样设计的好处有两点。首先,能够产生具有多种不同大小感受野的特征,便于后续的核选择;第二,序列解耦比简单的使用一个大型卷积核效果更好。如上图表2所示,解耦操作相对于标准的大型卷积核,有效地将低了模型的参数量。

为了从输入数据 X 的不同区域获取丰富的背景信息特征,可采用一系列解耦的、不用感受野的Depth-wise 卷积核:

其中,是卷积核为 k_{i}、扩张率为 d_{i} 的Depth-wise 卷积操作。假设有N个解耦的卷积核,每个卷积操作后又要经过一个1\times 1的卷积层进行空间特征向量的通道融合。

之后,针对不同的目标,可基于获取的多尺度特征,通过下文中的选择机制动态选择合适的卷积核大小。

这一段的意思可以简单理解为:

把一个大的卷积核拆成了几个小的卷积核,比如一个大小为5,扩张率为1的卷积核加上一个大小为7,扩张率为3的卷积核,感受野为23,与一个大小为23,扩张率为1的卷积核的感受野是一样的。因此可用两个小的卷积核替代一个大的卷积核,同理一个大小为29的卷积核也可以用三个小的卷积代替(Table 2),这样可以有效的减少参数,且更灵活。

将输入数据依次通过这些小的卷积核(公式3),并在每个小的卷积核后面接上一个1×1的卷积进行通道融合(公式4)。

3. Spatial Kernel Selection

为了使模型更关注目标在空间上的重点背景信息,作者使用空间选择机制从不同尺度的大卷积核中对特征图进行空间选择。

首先,将来自于不同感受野卷积核的特征进行concate拼接:

然后,应用通道级的平均池化和最大池化提取空间关系 \tilde{U}:

其中,SA_{avg}SA_{max} 是平均池化和最大池化后的空间特征描述符。为了实现不同空间描述符的信息交互,作者利用卷积层将空间池化特征进行拼接,将2个通道的池化特征转换为N个空间注意力特征图:

之后,将Sigmoid激活函数应用到每一个空间注意力特征图\hat{SA_{i}},可获得每个解耦的大卷积核所对应的独立的空间选择掩膜:

又然后,将解耦后的大卷积核序列的特征与对应的空间选择掩膜进行加权处理,并通过卷积层进行融合获得注意力特征 S

最后LSK module的输出可通过输入特征 X 与注意力特征 S 的逐元素点成获得,即:

公式对应于结构图上的操作如下:

三、实验结果

1. 实验数据集

包括HRSC2016、DOTA-v1.0和FAIR1M-v1.0三个。

2. 训练细节

骨干网络先在ImageNet-1K上预训练,然后再在实验数据集上微调。消融实验中,骨干网络预训练迭代了100个epoch。为了获得更优异的检测性能,采用了预训练300epoch的骨干网络获取主要结果。LSKNet默认构建在Oriented R-CNN上,优化器为AdamW。

3. 消融实验(DOTA-v1.0)

  • 大型卷积核解耦:证明了将一个大型卷积核解耦为两个Depth-wise卷积核能够在速度和精度上获得更好的平衡(Table 3)。
  • 感受野大小和选择类型:过小或过大的感受野会限制模型的性能,大小为23的感受野被证明是最有效的。此外,实验证明了本文提出的空间选择方法相比通道注意力机制具有更优异的性能(Table 4)。

  • 空间选择的池化层:实验表明,在空间选择部分同时采用最大和平均池化能够获得最优异的性能,也不会带来推理速度的损失(Table 5)。
  • 不同网络框架下LSKNet主干网络的性能:实验证明,与ResNet-18相比,LSKNet作为主干网络能够有效提升网络性能,同时只有38%的参数量和50%的FLOPS。
  • 与其他大核卷积、其他选择性注意力骨干网络及其他网络结构对比,LSKNet都有明显的优越性。

4. 分析

为了研究每个目标类别的感受野范围,将 R_{C} 定义为类别c的期望选择感受野的面积与地面边界框面积的比值:

I_{c}为包含目标类别 c 的影像数量,A_{i}是输入影像 i 中所有LSK block输出的空间选择激活的总和,D 是LSKNet的block数量,N 是一个LSK module解耦得到的卷积核数量,B_{i} 是所有标注的地面真实目标框 J_{i} 的总像元面积。

DOTA-v1.0数据集中各类别标准化后的Rc结果

四、代码详解

该算法代码采用mmrotate框架,可作为Oriented RCNN、RoI Transformer等基础网络的backbone。将该代码集成至已有的mmrotate框架中,只需要将mmrotate/models/backbones/lsknet.py文件拷贝至对应的的已有的文件夹,同时在__init__.py中导入并在config文件中修改对应配置即可。

LSKNet的代码如下:

import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair as to_2tuple
from mmcv.cnn.utils.weight_init import (constant_init, normal_init,
                                        trunc_normal_init)
from ..builder import ROTATED_BACKBONES
from mmcv.runner import BaseModule
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import math
from functools import partial
import warnings
from mmcv.cnn import build_norm_layer

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class LSKblock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
        self.conv1 = nn.Conv2d(dim, dim//2, 1)
        self.conv2 = nn.Conv2d(dim, dim//2, 1)
        self.conv_squeeze = nn.Conv2d(2, 2, 7, padding=3)
        self.conv = nn.Conv2d(dim//2, dim, 1)

    def forward(self, x):   
        attn1 = self.conv0(x)
        attn2 = self.conv_spatial(attn1)

        attn1 = self.conv1(attn1)
        attn2 = self.conv2(attn2)
        
        attn = torch.cat([attn1, attn2], dim=1)
        avg_attn = torch.mean(attn, dim=1, keepdim=True)
        max_attn, _ = torch.max(attn, dim=1, keepdim=True)
        agg = torch.cat([avg_attn, max_attn], dim=1)
        sig = self.conv_squeeze(agg).sigmoid()
        attn = attn1 * sig[:,0,:,:].unsqueeze(1) + attn2 * sig[:,1,:,:].unsqueeze(1)
        attn = self.conv(attn)
        return x * attn



class Attention(nn.Module):
    def __init__(self, d_model):
        super().__init__()

        self.proj_1 = nn.Conv2d(d_model, d_model, 1)
        self.activation = nn.GELU()
        self.spatial_gating_unit = LSKblock(d_model)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)

    def forward(self, x):
        shorcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        x = x + shorcut
        return x


class Block(nn.Module):
    def __init__(self, dim, mlp_ratio=4., drop=0.,drop_path=0., act_layer=nn.GELU, norm_cfg=None):
        super().__init__()
        if norm_cfg:
            self.norm1 = build_norm_layer(norm_cfg, dim)[1]
            self.norm2 = build_norm_layer(norm_cfg, dim)[1]
        else:
            self.norm1 = nn.BatchNorm2d(dim)
            self.norm2 = nn.BatchNorm2d(dim)
        self.attn = Attention(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        layer_scale_init_value = 1e-2            
        self.layer_scale_1 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(
            layer_scale_init_value * torch.ones((dim)), requires_grad=True)

    def forward(self, x):
        x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
        return x


class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768, norm_cfg=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        if norm_cfg:
            self.norm = build_norm_layer(norm_cfg, embed_dim)[1]
        else:
            self.norm = nn.BatchNorm2d(embed_dim)


    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = self.norm(x)        
        return x, H, W

@ROTATED_BACKBONES.register_module()
class LSKNet(BaseModule):
    def __init__(self, img_size=224, in_chans=3, embed_dims=[64, 128, 256, 512],
                mlp_ratios=[8, 8, 4, 4], drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 depths=[3, 4, 6, 3], num_stages=4, 
                 pretrained=None,
                 init_cfg=None,
                 norm_cfg=None):
        super().__init__(init_cfg=init_cfg)
        
        assert not (init_cfg and pretrained), \
            'init_cfg and pretrained cannot be set at the same time'
        if isinstance(pretrained, str):
            warnings.warn('DeprecationWarning: pretrained is deprecated, '
                          'please use "init_cfg" instead')
            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
        elif pretrained is not None:
            raise TypeError('pretrained must be a str or None')
        self.depths = depths
        self.num_stages = num_stages

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0

        for i in range(num_stages):
            patch_embed = OverlapPatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
                                            patch_size=7 if i == 0 else 3,
                                            stride=4 if i == 0 else 2,
                                            in_chans=in_chans if i == 0 else embed_dims[i - 1],
                                            embed_dim=embed_dims[i], norm_cfg=norm_cfg)

            block = nn.ModuleList([Block(
                dim=embed_dims[i], mlp_ratio=mlp_ratios[i], drop=drop_rate, drop_path=dpr[cur + j],norm_cfg=norm_cfg)
                for j in range(depths[i])])
            norm = norm_layer(embed_dims[i])
            cur += depths[i]

            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"block{i + 1}", block)
            setattr(self, f"norm{i + 1}", norm)



    def init_weights(self):
        print('init cfg', self.init_cfg)
        if self.init_cfg is None:
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    trunc_normal_init(m, std=.02, bias=0.)
                elif isinstance(m, nn.LayerNorm):
                    constant_init(m, val=1.0, bias=0.)
                elif isinstance(m, nn.Conv2d):
                    fan_out = m.kernel_size[0] * m.kernel_size[
                        1] * m.out_channels
                    fan_out //= m.groups
                    normal_init(
                        m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
        else:
            super(LSKNet, self).init_weights()
            
    def freeze_patch_emb(self):
        self.patch_embed1.requires_grad = False

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}  # has pos_embed may be better

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        B = x.shape[0]
        outs = []
        for i in range(self.num_stages):
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            block = getattr(self, f"block{i + 1}")
            norm = getattr(self, f"norm{i + 1}")
            x, H, W = patch_embed(x)
            for blk in block:
                x = blk(x)
            x = x.flatten(2).transpose(1, 2)
            x = norm(x)
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
            outs.append(x)
        return outs

    def forward(self, x):
        x = self.forward_features(x)
        # x = self.head(x)
        return x


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x):
        x = self.dwconv(x)
        return x


def _conv_filter(state_dict, patch_size=16):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k:
            v = v.reshape((v.shape[0], 3, patch_size, patch_size))
        out_dict[k] = v

    return out_dict

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
xiaoxingxing的头像xiaoxingxing管理团队
上一篇 2023年12月6日
下一篇 2023年12月6日

相关推荐