论文阅读笔记:ShuffleNet

1. 背景

由于深度学习模型结构越来越复杂,参数量也越来越大,需要大量的算力去做模型的训练和推理。然而随着移动设备的普及,将深度学习模型部署于计算资源有限基于ARM的移动设备成为了研究的热点。

ShuffleNet[1]是一种专门为计算资源有限的设备设计的神经网络结构,主要采用了pointwise group convolution和channel shuffle两种技术,在保留了模型精度的同时极大减少了计算开销。

[1] Zhang X, Zhou X, Lin M, et al. Shufflenet: An extremely efficient convolutional neural network for mobile devices[C].Proceedings of the IEEE conference on computer vision and pattern recognition. 2018: 6848-6856.

2. 相关工作

在论文中,提到了目前sota的两个工作,一个是谷歌的Xception,另一个是facebook推出的ResNeXt。

2.1 Xception

Xception[2]主要涉及了一个技术:深度可分离卷积,即把原本常规的卷积操作分为两步去做。
常规卷积是利用若干个多通道卷积核对输入的多通道图像进行处理,输出的是既提取了通道特征又提取了空间特征的feature map。
论文阅读笔记:ShuffleNet
而深度可分离卷积将提取通道特征(PointWise Convolution)和空间特征(DepthWise Convolution)分为了两步去做:
首先卷积核从三维变为了二维的,每个卷积核只负责输入图像的一个通道,用于提取空间特征,这一步操作中不涉及通道和通道之间的信息交互。接着通过一维卷积来完成通道之间特征提取的工作,即一个常规的卷积操作,只不过卷积核是1*1的。
论文阅读笔记:ShuffleNet
论文阅读笔记:ShuffleNet
这样做的好处是减少了传统卷积中的参数量。假设输入通道为论文阅读笔记:ShuffleNet,输出通道为论文阅读笔记:ShuffleNet,卷积核大小为论文阅读笔记:ShuffleNet,常规卷积的参数为:论文阅读笔记:ShuffleNet。在深度可分离卷积之后,参数量为论文阅读笔记:ShuffleNet
代码显示如下:

class SeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1):
        super(SeparableConv2d, self).__init__()

        self.conv1 = nn.Conv2d(
            in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=False)
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=False)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pointwise(x)
        return x

[2] Chollet F. Xception: Deep learning with depthwise separable convolutions[C]. Proceedings of the IEEE conference on computer vision and pattern recognition. 2017: 1251-1258.

2.2 ResNeXt

作者灵感来源于VGG的模块化堆叠的结构,提出了一种基于分组卷积和残差连接的模块化卷积块从而降低了参数的数量。简单来说,理解了分组卷积的思想就能理解ResNeXt。
论文阅读笔记:ShuffleNet

[3] Xie S, Girshick R, Dollár P, et al. Aggregated residual transformations for deep neural networks[C]. Proceedings of the IEEE conference on computer vision and pattern recognition. 2017: 1492-1500.

3. ShuffleNet

由于使用论文阅读笔记:ShuffleNet卷积核进行操作时的复杂度较高,因为需要和每个像素点做互相关运算,作者关注到ResNeXt的设计中,论文阅读笔记:ShuffleNet卷积操作的那一层需要消耗大量的计算资源,因此提出将这一层也设计为分组卷积的形式。然而,分组卷积只会在组内进行卷积,因此组和组之间不存在信息的交互,为了使得信息在组之间流动,作者提出将每次分组卷积后的结果进行组内分组,再互相交换各自的组内的子组。
论文阅读笔记:ShuffleNet
论文阅读笔记:ShuffleNet
上图c就是一个shufflenet块,图a是一个简单的残差连接块,区别在于,shufflenet将残差连接改为了一个平均池化的操作与卷积操作之后做cancat,并且将论文阅读笔记:ShuffleNet卷积改为了分组卷积,并且在分组之后进行了channel shuffle操作。

代码显示如下:

class ShuffleUnit(nn.Module):
    """
    ShuffleNet unit.
    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    groups : int
        Number of groups in convolution layers.
    downsample : bool
        Whether do downsample.
    ignore_group : bool
        Whether ignore group value in the first convolution layer.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 groups,
                 downsample,
                 ignore_group):
        super(ShuffleUnit, self).__init__()
        self.downsample = downsample
        mid_channels = out_channels // 4

        if downsample:
            out_channels -= in_channels

        self.compress_conv1 = conv1x1(
            in_channels=in_channels,
            out_channels=mid_channels,
            groups=(1 if ignore_group else groups))
        self.compress_bn1 = nn.BatchNorm2d(num_features=mid_channels)
        self.c_shuffle = ChannelShuffle(
            channels=mid_channels,
            groups=groups)
        self.dw_conv2 = depthwise_conv3x3(
            channels=mid_channels,
            stride=(2 if self.downsample else 1))
        self.dw_bn2 = nn.BatchNorm2d(num_features=mid_channels)
        self.expand_conv3 = conv1x1(
            in_channels=mid_channels,
            out_channels=out_channels,
            groups=groups)
        self.expand_bn3 = nn.BatchNorm2d(num_features=out_channels)
        if downsample:
            self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        self.activ = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        x = self.compress_conv1(x)
        x = self.compress_bn1(x)
        x = self.activ(x)
        x = self.c_shuffle(x)
        x = self.dw_conv2(x)
        x = self.dw_bn2(x)
        x = self.expand_conv3(x)
        x = self.expand_bn3(x)
        if self.downsample:
            identity = self.avgpool(identity)
            x = torch.cat((x, identity), dim=1)
        else:
            x = x + identity
        x = self.activ(x)
        return 

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
青葱年少的头像青葱年少普通用户
上一篇 2022年4月11日 下午3:45
下一篇 2022年4月11日 下午3:54

相关推荐