本文是Speeding up model with fusing batch normalization and convolution – LearnML.Today的翻译。conv+bn融合主要是在推理阶段进行加速,BN在推理时无需更新参数,且推理过程满足Conv的计算公式,能合二为一。好处是加快了推理,在量化任务中,也提高了精度(在高精度先乘,相比转换为低精度再乘,减小了精度损失)。YOLOv5中使用了该技术。
这是量化和推理优化模型中使用的常用技术。
今天我们将尝试了解如何使我们的模型在推理时更快一些。
大量的网络使用BN来提高网络的泛化能力 。但是在推理阶段,Batch Normalization 被关闭,取代使用的是每个通道的均值和方差的近似值。最酷的是我们可以通过1×1卷积实现同样的行为。更好的是,我们可以把BN和前面的卷积合并。
Batch Normalization
设 x 是我们要 normalize 的网络中的一个信号(激活)。给定一组这样的信号来自于在一个batch 中 处理不同的样本,每一个都被normalized 如下:
和是在一个 batch 上计算的 均值和方差(mean and variance),是数值稳定性的一个小常数,是比例因子,是 转换因子。在训练期间,和对于每个 batch 都被重新计算:
参数和和网络的其他参数一起从梯度下降慢慢地中学习。在测试期间,我们通常不会在图像的一个batch上运行网络。因此,前面提到公式中的和不能使用。我们使用在训练中通过指数移动平均( exponential moving average )计算它们的估计值。让我们标记它们的近似值为和。
目前,batch normalization 主要应用于卷积神经网络对图像的处理。在该设置中,输入特征图的每个通道都有均值和方差估计、比例和转换参数。我们将这些表示为:对于通道c:,,和
解决
实现冻结的Batch Normalization为一个1×1 Convolution
给定一个具有形状顺序的特征图F, 为了得到 它的 normalized 版本。使用上面的公式,我们需要计算每个空间位置的。
我们清晰的看到:这是,它可以实现为 一个 1×1 Convolution。甚至,因为BN经常在卷积层后,我们可以把卷积和BN融合为一个。
使用一个卷积层融合batch normalization
设,和– 是BN的参数
和– 是在BN前面的卷积层的参数
-卷积层的输入
– 输入层的通道数
k – filter 的size
的部分被 reshaped 到一个shape 为的向量, 因此产生的公式:
因此,我们可以使用下面的参数通过一个单个卷积层替换卷积+BN 两层。
- filter weights:
- bias:
使用 PyTorch 实现:
import torch
import torchvision
def fuse(conv, bn):
fused = torch.nn.Conv2d(
conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
bias=True
)
# setting weights
w_conv = conv.weight.clone().view(conv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps+bn.running_var)))
fused.weight.copy_( torch.mm(w_bn, w_conv).view(fused.weight.size()) )
# setting bias
if conv.bias is not None:
b_conv = conv.bias
else:
b_conv = torch.zeros( conv.weight.size(0) )
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
torch.sqrt(bn.running_var + bn.eps)
)
fused.bias.copy_( b_conv + b_bn )
return fused
# Testing
# we need to turn off gradient calculation because we didn't write it
torch.set_grad_enabled(False)
x = torch.randn(16, 3, 256, 256)
resnet18 = torchvision.models.resnet18(pretrained=True)
# removing all learning variables, etc
resnet18.eval()
model = torch.nn.Sequential(
resnet18.conv1,
resnet18.bn1
)
f1 = model.forward(x)
fused = fuse(model[0], model[1])
f2 = fused.forward(x)
d = (f1 - f2).mean().item()
print("error:",d)
其他参考:
移动平均(Moving Average) – 知乎
深度学习推理时融合BN,轻松获得约5%的提速 – 知乎
【基础算法】六问透彻理解BN(Batch Normalization) – 知乎
BN和Dropout在训练和测试时的差别 – 知乎
7.5。批量规范化 — 动手学深度学习 2.0.0-beta0 documentation
Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
文章出处登录后可见!