ICCV2021 Best Paper : Swin Transformer (一)

今天呢,并不是什么重要的日子,但绝对是值得记录下Swin transformer的美好时刻。在写Swin transformer之前呢,会不禁有这样一个问题:Swin transformer要解决一个怎样的事情呢?这件事情就是:我们知道,在NLP任务中啊输入的token大小基本相同,而在CV领域例如目标检测中由于目标尺寸并不相同,那用单层级的模型就很难有好的效果;其次将transformer迁移到CV领域,由于图像分辨率高,像素点多,transformer基于全局的自注意力的计算将导致十分巨大的计算量,尤其是在分割任务中,高分辨率会使得计算复杂度呈现输入图片大小的二次方增长,这显然是不能接受的。

一、Swin Transformer’s  Architecture

那Swin Transformer长什么样子呢?从下图啊我们可以看到,左边的就是Swin Transformer的全局架构,它包含Patch Partition、Linear Embedding、Swin Transformer Block、Patch Merging四大部分;右边是Swin Transformer Block结构图,这是两个连续的Swin Transformer Block块,一个是W-MSA,另一个是SW-MSA,根据Swin的Tiny版本,图中的Swin Transformer Block块为[2, 2, 6, 2],相对应的attention为:stage1—W-MSA—SW-MSA→stage2—W-MSA—SW-MSA→stage3—W-MSA—SW-MSA–W-MSA—SW-MSA–W-MSA—SW-MSA→stage4—W-MSA—SW-MSA

ICCV2021 Best Paper : Swin Transformer (一)

二、Swin Transformer

1. Patch Partition & Linear Embedding

Swin Transformer是如何工作的的呢?将一幅大小为224\times 224\times 3的picture喂给Swin Transformer之后,图片首先要经过Patch Partition & Linear Embedding。Patch Patition要做一件怎样的事情呢?Patch Patition会将输入的224\times 224\times 3的图片打成4\times 4的patches,然后在channel方向展平(flatten),得到56\times56\times48的图片尺寸。怎么理解呢?假设输入的是RGB三通道图片,那么每个patch就有4×4=16个像素,然后每个像素有R、G、B三个值,所以展平后是4x4x3=48,所以通过Patch Partition后图像shape由[224, 224, 3]变成了[56,56,48]。那Linear Embedding需要做什么呢?在tiny版本中,Linear Embeding对Patch Partition之后的图像的channel做线性变换,由48维映射到96维。即图像的shape再由[56, 56, 48]变成了[56, 56, 96],所以总结一下这一步的输入输出为:

输入为(B, 224, 224,3)
输出为(B, 56, 56,96) —> (B, 224/4=56, 224/4=56,96)

在实现的时候呢使用PatchEmbed函数将这两步结合起来,实际上也就是用了一个卷积的操作:卷积核大小为(4, 4),步长为4:

nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

2. Basic Layer

在代码实现中,将Swin Transformer Block和Patch Merging合并,叫做Basic Layer。

2.1 Swin Transformer Block

Swin Transformer Block内部长什么样子呢?Swin Transformer Block就长这样子:由两个连续Block组成,左边的我们把它叫做Window Multi-head Self-Attention,右边的呢我们把它叫做Shifted Window Multi-head Self-Attention。特征图呢首先经过layer norm层,经过W-MSA后进行一个residual连接,再将得到的向量输入到一个layer norm层,通过MLP后再进行一个residual connected得到输出的向量,并将输出的向量喂给右边的Shifted Window Multi-head Self-Attention,Shifted Window Multi-head Self-Attention重复相同的process,唯一不同的是在通过第一个layer norm层之后是输入到SW-MSA里

ICCV2021 Best Paper : Swin Transformer (一)

那什么是W-MSA呢?W-MSA就是Window Mutii-head self-attention的简称,实际上并不复杂,它要做的事情就是:将feature map按照M\times M大小划分成一个个Windows,然后单独对每个Windows内部进行Self-Attention。那这里面维度的变化是怎样的呢?W-MSA在第一个block中,这一步并没有滑动窗,输入维度大小为[B, 3136, 96] (56×56=3136),为了后面的sefl-attention操作,需要将特征图划分为一个个窗口的形式,首先经历了一个window partition操作,变为[64B, 49, 96],这一步怎么计算的呢?输入为:[batch=B,3136=56 x 56,feature maps有96个],在每个56\times56的特征图上划分7\times7的窗口,一共能分8\times8=64个,乘上前面的B就是64B了。

ICCV2021 Best Paper : Swin Transformer (一)那为什么要进行window partition操作呢?在Vision Transformer中,我们将图片分成了一个个patch(也就是上面左边的图),在进行MSA时,任何一个patch都要与其他所有的patch都进行attention,当patch的大小固定时,计算量与图片的大小成平方增长。Swin Transformer中采用了W-MSA,也就是window的形式,不同的window包含了相同数量的patch,只对window内部进行MSA,当图片大小增大时,计算量仅仅是呈线性增加,也就是说只增加了图片多余部分的计算量,比如之前是224的图像,现在是256的图像,只多了256-224=32像素的计算部分。

将窗口分配完成后就可以执行attention操作了,attention的操作我们在之前的blogSelf-attention is all you need中详细写过。首先我们将维度变换为[64B, 49, 96],进行attention操作时,我们需要qkv三个变量,transformer是通过Linear函数来实现的:

nn.Linear(dim, dim * 3, bias=qkv_bias)

通过这个函数后,维度变为[64B, 49, 288],qkv分别占三分之一,也就是说qkv分别为[64B, 49, 96],第一个阶段的head为3,维度划分为[64B, 3, 49, 32],这就是进行attention时qkv的维度:

q: (64B, 3, 49, 32)
k: (64B, 3, 49, 32)
v: (64B, 3, 49, 32)

接下来进行attention操作:

\text { Attention }(Q, K, V)=\operatorname{SoftMax}\left(\frac{Q K^{T}}{\sqrt{d}}+B\right) V

这里的偏置B叫做Relative Position Bias

在window里计算完attention之后,要做的一个事情是把window还原回feature map,这件事其实很显然,因为你划window的目的只是为了减少计算量而已。那具体要怎么操作呢?很简单,我们经过一个window reverse操作就可以回到window partition之前的状态即[B, 56, 56, 96]。window reverse就是window partition的逆过程。

好,总结一下我们在W-MSA里做的事情:

首先进行window partition操作,将输入维度[B,3136,96]变为[64B,49,96];

随后在得到qkv后进行attention操作;

attention后通过window reverse将window维度还原回feature map维度[B,3136,96]

接下来就轮到SW-MSA,什么是SW-MSA呢?SW-MSA就是Sifted window Multi-head Self-attention的简称。Sifted window Multi-head Self-attention在做一件怎样的事情呢?前面有说,采用W-MSA模块时,只会在每个窗口内进行attention计算,所以窗口与窗口之间是无法进行信息传递的。为了解决这个问题,SW-MSA会将窗口进行偏移(如下图),偏移后不同的window便能够相互进行信息的交流。既然要偏移,那偏移多少合适呢?偏移量应为:win_size // 2,在这里也就是7//2=3。下图中黑色的线所围成的区域代表feature map,一个小黄框表示一个window

ICCV2021 Best Paper : Swin Transformer (一)ICCV2021 Best Paper : Swin Transformer (一)

之前有说W-MSA和SW-MSA是成对使用的,第L层使用W-MSA,那么第L+1层使用的就是SW-MSA。对于下图右侧的第一行第二列的2×4的窗口,它能够使第L层的第一排的两个窗口信息进行交流。再比如,图右侧的第二行第二列的4×4的窗口,他能够使第L层的四个窗口信息进行交流,其他的同理。那么这就解决了不同窗口之间无法进行信息交流的问题。

ICCV2021 Best Paper : Swin Transformer (一)

不同window间的信息交互问题解决了,那随之而来的另一个问题是:原先只需要算四个window的attention,现在一下要算九个窗口的attention了,计算量显然增大不止“亿点点”,那要怎么解决由于shifted window带来的多余的计算量呢?论文作者提供给我们的方法是:Efficient batch computation for shifted configuration,下面就是原论文中该种方法的示意图:

ICCV2021 Best Paper : Swin Transformer (一)

那Efficient batch computation for shifted configuration是怎么运作的呢?下图左侧是刚刚通过偏移窗口后得到的新窗口,右侧是为了方便大家理解,对每个窗口加上了一个标识。然后0对应的窗口标记为区域A,3和6对应的窗口标记为区域B,1和2对应的窗口标记为区域C。(该部分参考Swin-Transformer网络结构详解_霹雳吧啦Wz-CSDN博客)

ICCV2021 Best Paper : Swin Transformer (一)

先将区域A和C移到最下方(Swin-Transformer网络结构详解_霹雳吧啦Wz-CSDN博客)

ICCV2021 Best Paper : Swin Transformer (一)

再将区域A和B同时移至最右侧(Swin-Transformer网络结构详解_霹雳吧啦Wz-CSDN博客)

ICCV2021 Best Paper : Swin Transformer (一)

偏移过后,4号单独形成一个窗口,5&3为一个窗口,7&1为一个窗口,8&6&2&0为最后一个窗口。这样一来和原来一样是4个4×4的窗口了,所以能够计算复杂度是一样的。

虽然计算量的问题解决了,新的问题又来了。虽然对于4号区域来说里面的元素都是互相紧挨着的,它们之间可以互相去做attention。但是对于剩下的几个窗口来说,它们里面的元素是从别的地方搬过来的,按道理来说它们之间的元素是不应该去做attention的,换句话说这些元素之间就不应该有太大的联系。那怎么解决这个问题呢?可以使用带掩码(mask)的MSA。那具体怎么做呢?以上图的区域5和区域3为例,对于该窗口内的每一个像素(或称token,patch)在进行MSA计算时,都要先生成对应的query(q),key(k),value(v)。假设对于下图的像素0而言,得到q^{0}后要与每一个像素的k算attention,假设\alpha _{0,0}q^{0}与像素0对应的k^{0}计算的到的attention score,那么同理可以得到\alpha _{0,0}~\alpha _{0,15}。按照普通的MSA计算,接下来就是SoftMax操作了。但对于这里的masked MSA,像素0是属于区域5的,我们只想让它和区域5内的像素进行匹配。那么我们可以将像素0与区域3中的所有像素匹配结果都减去100,例如\alpha_{0,2}, \alpha_{0,3}, \alpha_{0,6}, \alpha_{0,7}等。由于\alpha的值都很小,一般都是零点几的数,将其中一些数减去100后在通过SoftMax得到对应的权重都等于0了。所以对于像素0而言实际上还是只和区域5内的像素进行了MSA。那么对于其他像素也是同理,Swin-Transformer网络结构详解_霹雳吧啦Wz-CSDN博客

ICCV2021 Best Paper : Swin Transformer (一)

计算完masked MSA之后呢,还需要做最后一步,就是把位移的区域还原回去,也就是说把A,B,C还原回原来的位置上去,原因呢是我们仍希望保持原来这个feature map的相对位置是不变的,整体图片的语义信息也是不变的。如果不把这个位移还原的话,那相当于我们在把feature map不停地往右下角位移,这样的话这个图片的语义信息很可能就被破坏掉了。

好,总结一下我们在SW-MSA里做的事情:

在移动窗口后得到了9个窗口,但是窗口之间的每个patch数量都不一样

为了到达能够批次处理,减少计算复杂度,把9个窗口变成4个

然后用巧妙的masked方式让每个窗口之间可以合理的计算attention

最后再把算好的区域还原,就完成了基于移动窗口的attention的计算

2.2 Patch Merging

在每个stage结束的阶段都有一个Patch Merging的过程,它相当于CNN里的下采样操作。Patch Merging会将每个2×2的相邻像素划分为一个patch,然后将每个patch中相同位置(同一颜色)像素给拼在一起就得到了4个feature map。接着将这四个feature map在深度方向进行concat拼接,然后在通过一个LayerNorm层。最后通过一个全连接层在feature map的深度方向做线性变化,将feature map的深度由C变成C/2。通过这个简单的例子可以看出,通过Patch Merging层后,feature map的高和宽会减半,深度会翻倍。

ICCV2021 Best Paper : Swin Transformer (一)

版权声明:本文为博主周周周周周大帅原创文章,版权归属原作者,如果侵权,请联系我们删除!

原文链接:https://blog.csdn.net/m0_57541899/article/details/122965037

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
心中带点小风骚的头像心中带点小风骚普通用户
上一篇 2022年2月18日
下一篇 2022年2月18日

相关推荐