博客翻译:利用融合conv和bn的方法加速模型

本文是Speeding up model with fusing batch normalization and convolution – LearnML.Today的翻译。conv+bn融合主要是在推理阶段进行加速,BN在推理时无需更新参数,且推理过程满足Conv的计算公式,能合二为一。好处是加快了推理,在量化任务中,也提高了精度(在高精度先乘,相比转换为低精度再乘,减小了精度损失)。YOLOv5中使用了该技术。​​​​​​​

这是量化和推理优化模型中使用的常用技术。

博客翻译:利用融合conv和bn的方法加速模型

今天我们将尝试了解如何使我们的模型在推理时更快一些。
大量的网络使用BN来提高网络的泛化能力 。但是在推理阶段,Batch Normalization 被关闭,取代使用的是每个通道的均值\mu和方差\sigma ^{2}的近似值。最酷的是我们可以通过1×1卷积实现同样的行为。更好的是,我们可以把BN和前面的卷积合并。

Batch Normalization

设 x 是我们要 normalize 的网络中的一个信号(激活)。给定一组这样的信号x_{1},x_{2},...,x_{n}来自于在一个batch 中 处理不同的样本,每一个都被normalized 如下:

博客翻译:利用融合conv和bn的方法加速模型

\mu\sigma ^{2}是在一个 batch 上计算的 均值和方差(mean and variance),\epsilon是数值稳定性的一个小常数,\gamma是比例因子,\beta是 转换因子。在训练期间,\mu\sigma对于每个 batch 都被重新计算:

博客翻译:利用融合conv和bn的方法加速模型

参数\gamma\beta和网络的其他参数一起从梯度下降慢慢地中学习。在测试期间,我们通常不会在图像的一个batch上运行网络。因此,前面提到公式中的\mu\sigma不能使用。我们使用在训练中通过指数移动平均( exponential moving average )计算它们的估计值。让我们标记它们的近似值为\hat{\mu}\hat{\sigma} ^{2}

目前,batch normalization 主要应用于卷积神经网络对图像的处理。在该设置中,输入特征图的每个通道都有均值和方差估计、比例和转换参数。我们将这些表示为:对于通道c:\mu _{c}\sigma _{c}^{2},\gamma _{c}\beta _{c}

解决

实现冻结的Batch Normalization为一个1×1 Convolution

给定一个具有形状C\times H \times W顺序的特征图F, 为了得到 它的 normalized 版本\hat{F}。使用上面的公式,我们需要计算每个空间位置i,j\hat{x}_{i}

博客翻译:利用融合conv和bn的方法加速模型

我们清晰的看到:这是f(x)=W*x+b,它可以实现为 一个 1×1 Convolution。甚至,因为BN经常在卷积层后,我们可以把卷积和BN融合为一个。

使用一个卷积层融合batch normalization

设,W_{BN}\in \mathbb{R}^{CxC}b_{BN}\in \mathbb{R}^{C}– 是BN的参数

W_{conv}\in \mathbb{R}^{C\times (C_{prev}k^{2})}b_{conv}\in \mathbb{R}^{C}– 是在BN前面的卷积层的参数

F_{prev}-卷积层的输入

C_{prev}– 输入层的通道数

k – filter 的size

F_{prev}k\times k部分被 reshaped 到一个shape 为k^{2}C_{prev}的向量f_{i,j}, 因此产生的公式:

博客翻译:利用融合conv和bn的方法加速模型

因此,我们可以使用下面的参数通过一个单个卷积层替换卷积+BN 两层。

  • filter weights:W=W_{BN}W_{conv}
  • bias:b = W_{BN}b_{conv}+b_{BN}

使用  PyTorch 实现:

博客翻译:利用融合conv和bn的方法加速模型

import torch
import torchvision

def fuse(conv, bn):

    fused = torch.nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=True
    )

    # setting weights
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps+bn.running_var)))
    fused.weight.copy_( torch.mm(w_bn, w_conv).view(fused.weight.size()) )
    
    # setting bias
    if conv.bias is not None:
        b_conv = conv.bias
    else:
        b_conv = torch.zeros( conv.weight.size(0) )
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
                            torch.sqrt(bn.running_var + bn.eps)
                        )
    fused.bias.copy_( b_conv + b_bn )

    return fused

# Testing
# we need to turn off gradient calculation because we didn't write it
torch.set_grad_enabled(False)
x = torch.randn(16, 3, 256, 256)
resnet18 = torchvision.models.resnet18(pretrained=True)
# removing all learning variables, etc
resnet18.eval()
model = torch.nn.Sequential(
    resnet18.conv1,
    resnet18.bn1
)
f1 = model.forward(x)
fused = fuse(model[0], model[1])
f2 = fused.forward(x)
d = (f1 - f2).mean().item()
print("error:",d)

其他参考:

移动平均(Moving Average) – 知乎

深度学习推理时融合BN,轻松获得约5%的提速 – 知乎

【基础算法】六问透彻理解BN(Batch Normalization) – 知乎

BN和Dropout在训练和测试时的差别 – 知乎

7.5。批量规范化 — 动手学深度学习 2.0.0-beta0 documentation

Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
青葱年少的头像青葱年少普通用户
上一篇 2022年4月13日
下一篇 2022年4月13日

相关推荐