今天呢,并不是什么重要的日子,但绝对是值得记录下Swin transformer的美好时刻。在写Swin transformer之前呢,会不禁有这样一个问题:Swin transformer要解决一个怎样的事情呢?这件事情就是:我们知道,在NLP任务中啊输入的token大小基本相同,而在CV领域例如目标检测中由于目标尺寸并不相同,那用单层级的模型就很难有好的效果;其次将transformer迁移到CV领域,由于图像分辨率高,像素点多,transformer基于全局的自注意力的计算将导致十分巨大的计算量,尤其是在分割任务中,高分辨率会使得计算复杂度呈现输入图片大小的二次方增长,这显然是不能接受的。
一、Swin Transformer’s Architecture
那Swin Transformer长什么样子呢?从下图啊我们可以看到,左边的就是Swin Transformer的全局架构,它包含Patch Partition、Linear Embedding、Swin Transformer Block、Patch Merging四大部分;右边是Swin Transformer Block结构图,这是两个连续的Swin Transformer Block块,一个是W-MSA,另一个是SW-MSA,根据Swin的Tiny版本,图中的Swin Transformer Block块为[2, 2, 6, 2],相对应的attention为:stage1—W-MSA—SW-MSA→stage2—W-MSA—SW-MSA→stage3—W-MSA—SW-MSA–W-MSA—SW-MSA–W-MSA—SW-MSA→stage4—W-MSA—SW-MSA
二、Swin Transformer
1. Patch Partition & Linear Embedding
Swin Transformer是如何工作的的呢?将一幅大小为的picture喂给Swin Transformer之后,图片首先要经过Patch Partition & Linear Embedding。Patch Patition要做一件怎样的事情呢?Patch Patition会将输入的
的图片打成
的patches,然后在channel方向展平(flatten),得到
的图片尺寸。怎么理解呢?假设输入的是RGB三通道图片,那么每个patch就有4×4=16个像素,然后每个像素有R、G、B三个值,所以展平后是4x4x3=48,所以通过Patch Partition后图像shape由[224, 224, 3]变成了[56,56,48]。那Linear Embedding需要做什么呢?在tiny版本中,Linear Embeding对Patch Partition之后的图像的channel做线性变换,由48维映射到96维。即图像的shape再由[56, 56, 48]变成了[56, 56, 96],所以总结一下这一步的输入输出为:
输入为(B, 224, 224,3)
输出为(B, 56, 56,96) —> (B, 224/4=56, 224/4=56,96)
在实现的时候呢使用PatchEmbed函数将这两步结合起来,实际上也就是用了一个卷积的操作:卷积核大小为(4, 4),步长为4:
nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
2. Basic Layer
在代码实现中,将Swin Transformer Block和Patch Merging合并,叫做Basic Layer。
2.1 Swin Transformer Block
Swin Transformer Block内部长什么样子呢?Swin Transformer Block就长这样子:由两个连续Block组成,左边的我们把它叫做Window Multi-head Self-Attention,右边的呢我们把它叫做Shifted Window Multi-head Self-Attention。特征图呢首先经过layer norm层,经过W-MSA后进行一个residual连接,再将得到的向量输入到一个layer norm层,通过MLP后再进行一个residual connected得到输出的向量,并将输出的向量喂给右边的Shifted Window Multi-head Self-Attention,Shifted Window Multi-head Self-Attention重复相同的process,唯一不同的是在通过第一个layer norm层之后是输入到SW-MSA里
那什么是W-MSA呢?W-MSA就是Window Mutii-head self-attention的简称,实际上并不复杂,它要做的事情就是:将feature map按照大小划分成一个个Windows,然后单独对每个Windows内部进行Self-Attention。那这里面维度的变化是怎样的呢?W-MSA在第一个block中,这一步并没有滑动窗,输入维度大小为[B, 3136, 96] (56×56=3136),为了后面的sefl-attention操作,需要将特征图划分为一个个窗口的形式,首先经历了一个window partition操作,变为[64B, 49, 96],这一步怎么计算的呢?输入为:[batch=B,3136=56 x 56,feature maps有96个],在每个
的特征图上划分
的窗口,一共能分
个,乘上前面的
就是
了。
那为什么要进行window partition操作呢?在Vision Transformer中,我们将图片分成了一个个patch(也就是上面左边的图),在进行MSA时,任何一个patch都要与其他所有的patch都进行attention,当patch的大小固定时,计算量与图片的大小成平方增长。Swin Transformer中采用了W-MSA,也就是window的形式,不同的window包含了相同数量的patch,只对window内部进行MSA,当图片大小增大时,计算量仅仅是呈线性增加,也就是说只增加了图片多余部分的计算量,比如之前是224的图像,现在是256的图像,只多了256-224=32像素的计算部分。
将窗口分配完成后就可以执行attention操作了,attention的操作我们在之前的blogSelf-attention is all you need中详细写过。首先我们将维度变换为[64B, 49, 96],进行attention操作时,我们需要qkv三个变量,transformer是通过Linear函数来实现的:
nn.Linear(dim, dim * 3, bias=qkv_bias)
通过这个函数后,维度变为[64B, 49, 288],qkv分别占三分之一,也就是说qkv分别为[64B, 49, 96],第一个阶段的head为3,维度划分为[64B, 3, 49, 32],这就是进行attention时qkv的维度:
q: (64B, 3, 49, 32)
k: (64B, 3, 49, 32)
v: (64B, 3, 49, 32)
接下来进行attention操作:
这里的偏置叫做Relative Position Bias
在window里计算完attention之后,要做的一个事情是把window还原回feature map,这件事其实很显然,因为你划window的目的只是为了减少计算量而已。那具体要怎么操作呢?很简单,我们经过一个window reverse操作就可以回到window partition之前的状态即[B, 56, 56, 96]。window reverse就是window partition的逆过程。
好,总结一下我们在W-MSA里做的事情:
首先进行window partition操作,将输入维度[B,3136,96]变为[64B,49,96];
随后在得到qkv后进行attention操作;
attention后通过window reverse将window维度还原回feature map维度[B,3136,96]
接下来就轮到SW-MSA,什么是SW-MSA呢?SW-MSA就是Sifted window Multi-head Self-attention的简称。Sifted window Multi-head Self-attention在做一件怎样的事情呢?前面有说,采用W-MSA模块时,只会在每个窗口内进行attention计算,所以窗口与窗口之间是无法进行信息传递的。为了解决这个问题,SW-MSA会将窗口进行偏移(如下图),偏移后不同的window便能够相互进行信息的交流。既然要偏移,那偏移多少合适呢?偏移量应为:win_size // 2,在这里也就是。下图中黑色的线所围成的区域代表feature map,一个小黄框表示一个window
之前有说W-MSA和SW-MSA是成对使用的,第L层使用W-MSA,那么第L+1层使用的就是SW-MSA。对于下图右侧的第一行第二列的2×4的窗口,它能够使第L层的第一排的两个窗口信息进行交流。再比如,图右侧的第二行第二列的4×4的窗口,他能够使第L层的四个窗口信息进行交流,其他的同理。那么这就解决了不同窗口之间无法进行信息交流的问题。
不同window间的信息交互问题解决了,那随之而来的另一个问题是:原先只需要算四个window的attention,现在一下要算九个窗口的attention了,计算量显然增大不止“亿点点”,那要怎么解决由于shifted window带来的多余的计算量呢?论文作者提供给我们的方法是:Efficient batch computation for shifted configuration,下面就是原论文中该种方法的示意图:
那Efficient batch computation for shifted configuration是怎么运作的呢?下图左侧是刚刚通过偏移窗口后得到的新窗口,右侧是为了方便大家理解,对每个窗口加上了一个标识。然后0对应的窗口标记为区域A,3和6对应的窗口标记为区域B,1和2对应的窗口标记为区域C。(该部分参考Swin-Transformer网络结构详解_霹雳吧啦Wz-CSDN博客)
先将区域A和C移到最下方(Swin-Transformer网络结构详解_霹雳吧啦Wz-CSDN博客)
再将区域A和B同时移至最右侧(Swin-Transformer网络结构详解_霹雳吧啦Wz-CSDN博客)
偏移过后,4号单独形成一个窗口,5&3为一个窗口,7&1为一个窗口,8&6&2&0为最后一个窗口。这样一来和原来一样是4个4×4的窗口了,所以能够计算复杂度是一样的。
虽然计算量的问题解决了,新的问题又来了。虽然对于4号区域来说里面的元素都是互相紧挨着的,它们之间可以互相去做attention。但是对于剩下的几个窗口来说,它们里面的元素是从别的地方搬过来的,按道理来说它们之间的元素是不应该去做attention的,换句话说这些元素之间就不应该有太大的联系。那怎么解决这个问题呢?可以使用带掩码(mask)的MSA。那具体怎么做呢?以上图的区域5和区域3为例,对于该窗口内的每一个像素(或称token,patch)在进行MSA计算时,都要先生成对应的query(q),key(k),value(v)。假设对于下图的像素0而言,得到后要与每一个像素的k算attention,假设
是
与像素0对应的
计算的到的attention score,那么同理可以得到
~
。按照普通的MSA计算,接下来就是SoftMax操作了。但对于这里的masked MSA,像素0是属于区域5的,我们只想让它和区域5内的像素进行匹配。那么我们可以将像素0与区域3中的所有像素匹配结果都减去100,例如
等。由于
的值都很小,一般都是零点几的数,将其中一些数减去100后在通过SoftMax得到对应的权重都等于0了。所以对于像素0而言实际上还是只和区域5内的像素进行了MSA。那么对于其他像素也是同理,Swin-Transformer网络结构详解_霹雳吧啦Wz-CSDN博客
计算完masked MSA之后呢,还需要做最后一步,就是把位移的区域还原回去,也就是说把A,B,C还原回原来的位置上去,原因呢是我们仍希望保持原来这个feature map的相对位置是不变的,整体图片的语义信息也是不变的。如果不把这个位移还原的话,那相当于我们在把feature map不停地往右下角位移,这样的话这个图片的语义信息很可能就被破坏掉了。
好,总结一下我们在SW-MSA里做的事情:
在移动窗口后得到了9个窗口,但是窗口之间的每个patch数量都不一样
为了到达能够批次处理,减少计算复杂度,把9个窗口变成4个
然后用巧妙的masked方式让每个窗口之间可以合理的计算attention
最后再把算好的区域还原,就完成了基于移动窗口的attention的计算
2.2 Patch Merging
在每个stage结束的阶段都有一个Patch Merging的过程,它相当于CNN里的下采样操作。Patch Merging会将每个2×2的相邻像素划分为一个patch,然后将每个patch中相同位置(同一颜色)像素给拼在一起就得到了4个feature map。接着将这四个feature map在深度方向进行concat拼接,然后在通过一个LayerNorm层。最后通过一个全连接层在feature map的深度方向做线性变化,将feature map的深度由C变成C/2。通过这个简单的例子可以看出,通过Patch Merging层后,feature map的高和宽会减半,深度会翻倍。
版权声明:本文为博主周周周周周大帅原创文章,版权归属原作者,如果侵权,请联系我们删除!
原文链接:https://blog.csdn.net/m0_57541899/article/details/122965037