7. ViT Combining Convolution and Transformer

I. Review and Introduction

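In the previous two lessons we covered the structure of Swin Transformer and its details and implemented it, including the Swin Block and Patch Merging. Last lesson covered the shift and the mask in SW-MSA: after the shift, the mask keeps the attention that each window needs and discards the rest, implemented with broadcasting among other tricks.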

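In this lesson we look at how to combine convolution and Transformer. First, recall how a convolution is computed: with a 5×5 input and a 3×3 kernel, sliding the window (stride 1, no padding) produces a 3×3 output feature map.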
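As a quick sanity check of the 5×5 → 3×3 arithmetic, here is a minimal PyTorch sketch, assuming stride 1 and no padding. conv_out_size is a helper written only for this illustration; it uses the usual output-size formula (n + 2·padding − k) / stride + 1.

```python
import torch
from torch import nn


def conv_out_size(n, k, stride=1, padding=0):
    """Output side length of a convolution along one spatial dimension."""
    return (n + 2 * padding - k) // stride + 1


x = torch.randn(1, 1, 5, 5)            # [batch, channels, H, W]
conv = nn.Conv2d(1, 1, kernel_size=3)  # stride=1, padding=0 by default
y = conv(x)
print(y.shape)                         # torch.Size([1, 1, 3, 3])
print(conv_out_size(5, 3))             # (5 - 3) // 1 + 1 = 3
```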

Let's look at the details of how this works.

A 1×1 convolution can change the number of channels. For example, for a 4×4×3 image, a 1×1 convolution with out_channels=4 produces a 4×4×4 feature map.
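A minimal PyTorch sketch of the 4×4×3 → 4×4×4 example (sizes taken from the text; PyTorch stores images as [batch, channels, H, W]). The second half shows why a 1×1 convolution only mixes channels: it is the same linear map applied at every pixel.

```python
import torch
from torch import nn

x = torch.randn(1, 3, 4, 4)  # one 4x4 image with 3 channels (NCHW)
pointwise = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=1)
y = pointwise(x)             # spatial size unchanged, channels 3 -> 4
print(y.shape)               # torch.Size([1, 4, 4, 4])

# The same result as a per-pixel matrix multiply over the channel axis
w = pointwise.weight.view(4, 3)  # [out_channels, in_channels]
manual = torch.einsum('oc,bchw->bohw', w, x) + pointwise.bias.view(1, 4, 1, 1)
print(torch.allclose(y, manual, atol=1e-5))
```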

Now suppose we have a 4×4 image with 3 channels, and we want the output to keep the spatial size while the channel count becomes 4. A Transformer has three projections, Q, K and V, each mapping the input to embed_dim dimensions; it computes QKᵀ and multiplies the result by V.
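To make the comparison concrete, here is a single-head self-attention sketch on the same 4×4×3 input with embed_dim = 4. It assumes the standard scaled form softmax(QKᵀ/√d)V and omits biases, multiple heads and the output projection, so it illustrates the idea rather than a full Transformer layer.

```python
import torch
from torch import nn

torch.manual_seed(0)
H, W, C_in, embed_dim = 4, 4, 3, 4

x = torch.randn(H, W, C_in)
tokens = x.reshape(H * W, C_in)  # 16 tokens, each a 3-dim vector

# Three linear projections to embed_dim (bias omitted for simplicity)
to_q = nn.Linear(C_in, embed_dim, bias=False)
to_k = nn.Linear(C_in, embed_dim, bias=False)
to_v = nn.Linear(C_in, embed_dim, bias=False)
q, k, v = to_q(tokens), to_k(tokens), to_v(tokens)  # each [16, 4]

# Every token attends to every other token: softmax(QK^T / sqrt(d)) V
attn = torch.softmax(q @ k.T / embed_dim ** 0.5, dim=-1)  # [16, 16]
out = attn @ v                                            # [16, 4]

out = out.reshape(H, W, embed_dim)  # back to a 4x4 map, now with 4 channels
print(out.shape)                    # torch.Size([4, 4, 4])
```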

So the Transformer is somewhat similar to a CNN. Points to note:

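1. Conv computes with a sliding window, while attention computes the relationship between every position in the current tensor and every other position. Even so, both achieve the same thing: they compute a representation of the input tensor and can change its dimension along the way.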

For example, keeping the spatial size unchanged while changing the number of channels from 3 to 4 is something both Conv and Transformer can do.

If we treat their internals as a black box, they accomplish similar things; only the concrete computation differs:

Conv slides a fixed-size window over the input tensor and multiplies by a fixed conv kernel, so it is relatively local and computes local information; a Transformer can attend to every position, so it sees global information. That is the difference between the two. We can use only the global or only the local information, but we can also combine them. A convolution is usually cheaper than self-attention (whose cost grows quadratically with the number of tokens), so in scenarios that need small models, conv tends to work better.
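A rough way to see the cost difference is to count multiply-accumulates (MACs). The sketch below assumes a 32×32×64 feature map, a 3×3 same-padded convolution, and one single-head self-attention layer over all 1024 tokens. It ignores biases, softmax and normalization, so the numbers are only indicative; the helper names are hypothetical.

```python
def conv_macs(h, w, c_in, c_out, k):
    """MACs of a k x k convolution, stride 1, 'same' padding."""
    return h * w * k * k * c_in * c_out


def self_attention_macs(h, w, d):
    """MACs of one single-head self-attention layer over all h*w tokens."""
    n = h * w
    proj = 4 * n * d * d  # Q, K, V and output projections
    scores = n * n * d    # Q @ K^T
    weighted = n * n * d  # attn @ V
    return proj + scores + weighted


h = w = 32
d = 64
print(conv_macs(h, w, d, d, 3))      # 37748736
print(self_attention_macs(h, w, d))  # 150994944
```

With these assumed sizes, attention costs about 4× the convolution, and because of the n² terms the gap widens as the resolution grows.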

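MobileViT, today's topic, uses this kind of structure: it contains Transformer layers plus a residual path (drawn as the red line in the lecture's diagram). In short, what we build is Conv + Transformer.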

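In the overall architecture, the front part consists of conventional Conv layers and the middle part contains Transformer layers, built from MV2 Blocks and MViT Blocks. An MV2 Block is one block of MobileNetV2 (an inverted residual block), similar in spirit to a ResNet block; the MViT Block is covered shortly.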

The input part at the front is usually called the Stem (the input, a 3×3 convolution, then pooling).

The part after the final Global Pool is called the head.

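The middle consists of the block layers. Look at the core module, the MV2 Block: its structure differs for Stride=1 and Stride=2 (only the Stride=1 version has the residual Add). Its activation function is SiLU. SiLU and GELU have very similar curves and differ only slightly in formula: SiLU(x) = x·σ(x) uses the logistic sigmoid, while GELU(x) = x·Φ(x) uses the standard normal CDF.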
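A small PyTorch check of the two activation functions, comparing torch.nn.functional.silu and torch.nn.functional.gelu (in its exact, erf-based form, which is PyTorch's default) against the closed-form definitions.

```python
import math

import torch
import torch.nn.functional as F

x = torch.linspace(-4, 4, steps=9)

silu = F.silu(x)  # x * sigmoid(x)
gelu = F.gelu(x)  # x * Phi(x), Phi = standard normal CDF

# Check both against their closed forms
silu_ref = x * torch.sigmoid(x)
gelu_ref = x * 0.5 * (1 + torch.erf(x / math.sqrt(2)))
print(torch.allclose(silu, silu_ref, atol=1e-5))
print(torch.allclose(gelu, gelu_ref, atol=1e-5))
print(silu - gelu)  # the curves are close but not identical
```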

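The block starts with a 1×1 convolution that adjusts (expands) the number of channels. DConv is a depthwise convolution, i.e. a group convolution with one group per channel, used to reduce computation. Each of these first two convolutions is followed by BN and SiLU. A final 1×1 convolution plus BN projects the channels back, and the Add forms the residual connection. That is the architecture of an MV2 Block.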
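The structure described above can be sketched in PyTorch as follows. This is an illustrative re-implementation, not the MobileViT authors' code: the class name MV2Block and the expansion factor of 4 are assumptions, and the residual Add is used only when stride is 1 and the input and output channel counts match, as in MobileNetV2.

```python
import torch
from torch import nn


class MV2Block(nn.Module):
    """Inverted residual block in the MobileNetV2 style (sketch)."""

    def __init__(self, in_channels, out_channels, stride=1, expansion=4):
        super().__init__()
        hidden = in_channels * expansion
        self.use_residual = stride == 1 and in_channels == out_channels
        self.block = nn.Sequential(
            # 1x1 conv: expand channels
            nn.Conv2d(in_channels, hidden, 1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.SiLU(),
            # 3x3 depthwise conv: groups == channels, one filter per channel
            nn.Conv2d(hidden, hidden, 3, stride=stride, padding=1,
                      groups=hidden, bias=False),
            nn.BatchNorm2d(hidden),
            nn.SiLU(),
            # 1x1 conv: project back; BN but no activation
            nn.Conv2d(hidden, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        out = self.block(x)
        return x + out if self.use_residual else out


x = torch.randn(2, 16, 32, 32)
print(MV2Block(16, 16, stride=1)(x).shape)  # torch.Size([2, 16, 32, 32])
print(MV2Block(16, 32, stride=2)(x).shape)  # torch.Size([2, 32, 16, 16])
```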

MobileViT takes this design and replaces some of the MV2 Blocks with ViT blocks (MobileViT blocks).

II. The Paper

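First, X_L is an H×W×d tensor obtained from the input by an n×n convolution followed by a 1×1 (point-wise) convolution. It is unfolded into non-overlapping, flattened patches, giving X_U. A Transformer then models the inter-patch relationships, which is the core of this paper, producing X_G. X_G is folded back to H×W×d and then goes through a point-wise convolution, a concatenation with the block's input, and an n×n convolution.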
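The shapes in the paragraph above can be written compactly as follows. The notation follows the MobileViT paper as far as we recall it (P = wh is the number of pixels in a w×h patch and N = HW/P is the number of patches); treat it as a sketch rather than a quotation of the paper.

```latex
\begin{aligned}
\mathbf{X}_L &\in \mathbb{R}^{H \times W \times d}
  && \text{($n \times n$ conv, then $1 \times 1$ conv)} \\
\mathbf{X}_U &\in \mathbb{R}^{P \times N \times d}, \quad P = wh,\ N = \tfrac{HW}{P}
  && \text{(unfold into non-overlapping patches)} \\
\mathbf{X}_G(p) &= \operatorname{Transformer}\big(\mathbf{X}_U(p)\big), \quad 1 \le p \le P
  && \text{(inter-patch relationships)} \\
\mathbf{X}_F &\in \mathbb{R}^{H \times W \times d}
  && \text{(fold $\mathbf{X}_G$ back)}
\end{aligned}
```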

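After reading the text the picture may still be fuzzy, so let's use the block diagram to show the structure more intuitively. The 3×3 and 1×1 convolutions at the start correspond to the first step above; they only transform the shape of the tensor. A few more operations are needed before the tensor can enter the Transformer. In principle the Transformer does not change the tensor's size, so its output is again a tensor of the same shape, which then goes through the remaining 1×1 Conv and so on.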

The key part is the one drawn as a blue block: given the H×W×d tensor produced by the convolutions, how do we feed it into the Transformer?

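Concretely, starting from the conv output on the left, we apply a Patch Partition, which can also be called a window operation: we cut the H×W×d tensor into small windows, for example 9 of them, A, B, C through I. From the last lesson we know that Swin would compute self-attention inside each window; MobileViT does something different.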

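Before the Transformer, the H×W×d input is flattened patch by patch: each pixel of patch A is a 1×d token, so A becomes a short sequence of tokens, and every other patch is unfolded the same way.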
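A minimal sketch of the patch partition and flattening with hypothetical sizes: a 6×6×d map split into 3×3 = 9 patches (A through I) of 2×2 pixels, each flattened into 4 tokens of dimension d. A channels-last [H, W, d] layout is assumed to match the text.

```python
import torch

H, W, d, ph, pw = 6, 6, 8, 2, 2  # 6x6 map -> 3x3 = 9 patches (A..I) of 2x2
x = torch.randn(H, W, d)

# Split H and W into (num_patches_along_axis, patch_size), then group by patch
patches = x.reshape(H // ph, ph, W // pw, pw, d)  # [3, 2, 3, 2, d]
patches = patches.permute(0, 2, 1, 3, 4)          # [3, 3, 2, 2, d]
patches = patches.reshape(-1, ph * pw, d)         # [9, 4, d]: 9 patches, 4 tokens each
print(patches.shape)                              # torch.Size([9, 4, 8])

# Patch A (top-left): its 2nd token is the pixel at row 0, column 1
print(torch.equal(patches[0, 1], x[0, 1]))
```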

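Then the first token of each of A, B, C through I is taken out and sent into the Transformer together; the Transformer does not change the dimension. Once the first position is done, the second token of every patch is taken and sent through the Transformer, and so on.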

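Overall, with a 2×2 window, we first take the first element of every window and compute self-attention among them, then the element at the second position, and so on. A 2×2 window needs 4 such passes and a 3×3 window would need 9; MobileViT fixes the window at 2×2.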
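The per-position procedure can be sketched with two reshapes, assuming 2×2 patches and a channels-first [B, d, H, W] input. unfold and fold are hypothetical helper names, and nn.MultiheadAttention stands in for a full Transformer layer (no LayerNorm or MLP). Instead of looping over the 4 positions, the sketch folds the position axis P into the batch axis, so all 4 attention passes run in one call.

```python
import torch
from torch import nn


def unfold(x, ph=2, pw=2):
    """[B, d, H, W] -> [B, P, N, d]: P = ph*pw positions per patch, N = number of patches."""
    B, d, H, W = x.shape
    x = x.reshape(B, d, H // ph, ph, W // pw, pw)  # [B, d, nH, ph, nW, pw]
    x = x.permute(0, 3, 5, 2, 4, 1)                # [B, ph, pw, nH, nW, d]
    return x.reshape(B, ph * pw, (H // ph) * (W // pw), d)


def fold(x, H, W, ph=2, pw=2):
    """Inverse of unfold: [B, P, N, d] -> [B, d, H, W]."""
    B, P, N, d = x.shape
    x = x.reshape(B, ph, pw, H // ph, W // pw, d)
    x = x.permute(0, 5, 3, 1, 4, 2)                # [B, d, nH, ph, nW, pw]
    return x.reshape(B, d, H, W)


B, d, H, W = 2, 32, 8, 8
x = torch.randn(B, d, H, W)
tokens = unfold(x)               # [2, 4, 16, 32]: P = 4, N = 16
P, N = tokens.shape[1], tokens.shape[2]

# Attention over the N tokens that share position p; all P positions batched together
attn = nn.MultiheadAttention(embed_dim=d, num_heads=2, batch_first=True)
seq = tokens.reshape(B * P, N, d)
out, _ = attn(seq, seq, seq)
y = fold(out.reshape(B, P, N, d), H, W)
print(tokens.shape, y.shape)     # torch.Size([2, 4, 16, 32]) torch.Size([2, 32, 8, 8])
```

Because unfold and fold are exact inverses, the spatial layout is restored after attention, which lets the block continue with ordinary convolutions.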

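The purpose is to see global information while reducing computation, similar in spirit to dilated convolution: each attention pass involves only the N = HW/P tokens at one position instead of all HW tokens, and because the preceding convolutions have already mixed local information into every pixel, attending across all patches still gives each pixel an effectively global receptive field.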
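The saving can be checked with a back-of-the-envelope count of the two attention matrix products (QKᵀ and attention·V), ignoring projections and softmax. The sizes and helper names are assumptions for illustration; the ratio comes out as exactly P = 4 for 2×2 patches.

```python
def full_attention_macs(H, W, d):
    """QK^T and attn@V over all H*W tokens at once."""
    n = H * W
    return 2 * n * n * d


def mobilevit_attention_macs(H, W, d, ph=2, pw=2):
    """P separate attentions, each over N = H*W/P tokens."""
    P = ph * pw
    N = (H * W) // P
    return P * 2 * N * N * d


H = W = 32
d = 64
print(full_attention_macs(H, W, d) / mobilevit_attention_macs(H, W, d))  # 4.0
```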

III. Code Implementation

3.1 Completing Swin

In the previous lesson we wrote the code for the Swin Transformer's WindowAttention (W-MSA and SW-MSA) and SwinBlock.

Today we write SwinStage and the complete Swin class.

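# Class 7
# SwinBlock, PatchMerging and PatchEmbedding come from the previous lessons.
# Indentity is a pass-through module defined outside this snippet (it appears in the printed model below).
import torch
from torch import nn


class SwinStage(nn.Module):
    def __init__(self, dim, input_resolution, depth, num_heads, window_size, patch_merging=None):
        super().__init__()
        self.blocks = nn.ModuleList()
        for i in range(depth):
            # Blocks alternate W-MSA (shift_size=0) and SW-MSA (shift_size=window_size//2)
            self.blocks.append(
                SwinBlock(dim=dim,
                          input_reslution=input_resolution,
                          num_heads=num_heads,
                          window_size=window_size,
                          shift_size=0 if (i % 2 == 0) else window_size // 2))
        if patch_merging is None:
            # The last stage does no patch merging
            self.patch_merging = Indentity()
        else:
            # patch_merging is the class passed in (PatchMerging); instantiate it here
            self.patch_merging = patch_merging(input_resolution=input_resolution, dim=dim)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        x = self.patch_merging(x)
        return x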

The Swin class:

# Class 7
class Swin(nn.Module):
    def __init__(self,
                 image_size=224,
                 patch_size=4,
                 in_channels=3,  # RGB input; not forwarded, the PatchEmbedding used here is built for 3 channels
                 embed_dim=96,
                 window_size=7,
                 num_heads=[3, 6, 12, 24],
                 depths=[2, 2, 6, 2],
                 num_classes=1000):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.num_stages = len(depths)
        # Channels after the last stage: 96 * 2**3 = 768
        self.num_features = int(self.embed_dim * 2 ** (self.num_stages - 1))
        # Token grid after patch embedding: 224 // 4 = 56 -> [56, 56]
        self.patch_resolution = [image_size // patch_size, image_size // patch_size]

        self.patch_embedding = PatchEmbedding(patch_size=patch_size, embed_dim=embed_dim)
        self.stages = nn.ModuleList()
        for idx, (depth, num_heads) in enumerate(zip(self.depths, self.num_heads)):
            # Each stage doubles the channels and halves the resolution of the previous one
            stage = SwinStage(dim=int(self.embed_dim * 2 ** idx),
                              input_resolution=(self.patch_resolution[0] // (2 ** idx),
                                                self.patch_resolution[1] // (2 ** idx)),
                              depth=depth,
                              num_heads=num_heads,
                              window_size=window_size,
                              patch_merging=PatchMerging if (idx < self.num_stages - 1) else None)
            self.stages.append(stage)
        self.norm = nn.LayerNorm(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(self.num_features, self.num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        for stage in self.stages:
            x = stage(x)
        x = self.norm(x)        # [B, num_tokens, num_features]; 7*7 = 49 tokens for a 224x224 input
        x = x.permute(0, 2, 1)  # [B, num_features, num_tokens]
        x = self.avgpool(x)     # [B, num_features, 1]
        x = x.flatten(1)        # [B, num_features]
        x = self.fc(x)          # [B, num_classes]

        return x

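Finally, main no longer tests the pieces one by one (those lines are now commented out); everything is assembled through SwinStage into the full Swin model.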

def main():
    t = torch.randn([4, 3, 224, 224])
    # patch_embedding = PatchEmbedding(patch_size=4, embed_dim=96)
    # swin_block_w_msa = SwinBlock(dim=96, input_reslution=[56,56], num_heads=4, window_size=7,shift_size=0)
    # swin_block_sw_msa = SwinBlock(dim=96, input_reslution=[56,56], num_heads=4, window_size=7,shift_size=7//2)
    #
    # patch_merging = PatchMerging(input_resolution=[56,56], dim=96)
    #
    # print('image shape = [4, 3, 224, 224]')
    # out = patch_embedding(t)  #[4, 56, 56, 96]
    # print('patch_embedding out shape= ',out.shape)
    # out = swin_block_w_msa(out)
    # out = swin_block_sw_msa(out)
    # print('swinBlock out shape= ',out.shape)
    # out = patch_merging(out)
    # print('patch_merging out shape= ',out.shape)
    model = Swin()
    print(model)
    out = model(t)
    print(out.shape)  # torch.Size([4, 1000])


if __name__ == '__main__':
    main()

The output is as follows.

Model structure:

Swin(
  (patch_embedding): PatchEmbedding(
    (patch_embed): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
  )
  (stages): ModuleList(
    (0): SwinStage(
      (blocks): ModuleList(
        (0): SwinBlock(
          (attn_norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (softmax): Softmax(dim=-1)
            (qkv): Linear(in_features=96, out_features=288, bias=True)
            (proj): Linear(in_features=96, out_features=96, bias=True)
          )
          (mlp_norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=96, out_features=384, bias=True)
            (fc2): Linear(in_features=384, out_features=96, bias=True)
            (act): GELU(approximate='none')
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (1): SwinBlock(
          (attn_norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (softmax): Softmax(dim=-1)
            (qkv): Linear(in_features=96, out_features=288, bias=True)
            (proj): Linear(in_features=96, out_features=96, bias=True)
          )
          (mlp_norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=96, out_features=384, bias=True)
            (fc2): Linear(in_features=384, out_features=96, bias=True)
            (act): GELU(approximate='none')
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
      )
      (patch_merging): PatchMerging(
        (reduction): Linear(in_features=384, out_features=192, bias=True)
        (norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      )
    )
    (1): SwinStage(
      (blocks): ModuleList(
        (0): SwinBlock(
          (attn_norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (softmax): Softmax(dim=-1)
            (qkv): Linear(in_features=192, out_features=576, bias=True)
            (proj): Linear(in_features=192, out_features=192, bias=True)
          )
          (mlp_norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=192, out_features=768, bias=True)
            (fc2): Linear(in_features=768, out_features=192, bias=True)
            (act): GELU(approximate='none')
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (1): SwinBlock(
          (attn_norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (softmax): Softmax(dim=-1)
            (qkv): Linear(in_features=192, out_features=576, bias=True)
            (proj): Linear(in_features=192, out_features=192, bias=True)
          )
          (mlp_norm): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=192, out_features=768, bias=True)
            (fc2): Linear(in_features=768, out_features=192, bias=True)
            (act): GELU(approximate='none')
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
      )
      (patch_merging): PatchMerging(
        (reduction): Linear(in_features=768, out_features=384, bias=True)
        (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
    )
    (2): SwinStage(
      (blocks): ModuleList(
        (0): SwinBlock(
          (attn_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (softmax): Softmax(dim=-1)
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (proj): Linear(in_features=384, out_features=384, bias=True)
          )
          (mlp_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, bias=True)
            (fc2): Linear(in_features=1536, out_features=384, bias=True)
            (act): GELU(approximate='none')
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (1): SwinBlock(
          (attn_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (softmax): Softmax(dim=-1)
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (proj): Linear(in_features=384, out_features=384, bias=True)
          )
          (mlp_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, bias=True)
            (fc2): Linear(in_features=1536, out_features=384, bias=True)
            (act): GELU(approximate='none')
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (2): SwinBlock(
          (attn_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (softmax): Softmax(dim=-1)
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (proj): Linear(in_features=384, out_features=384, bias=True)
          )
          (mlp_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, bias=True)
            (fc2): Linear(in_features=1536, out_features=384, bias=True)
            (act): GELU(approximate='none')
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (3): SwinBlock(
          (attn_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (softmax): Softmax(dim=-1)
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (proj): Linear(in_features=384, out_features=384, bias=True)
          )
          (mlp_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, bias=True)
            (fc2): Linear(in_features=1536, out_features=384, bias=True)
            (act): GELU(approximate='none')
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (4): SwinBlock(
          (attn_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (softmax): Softmax(dim=-1)
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (proj): Linear(in_features=384, out_features=384, bias=True)
          )
          (mlp_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, bias=True)
            (fc2): Linear(in_features=1536, out_features=384, bias=True)
            (act): GELU(approximate='none')
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (5): SwinBlock(
          (attn_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (softmax): Softmax(dim=-1)
            (qkv): Linear(in_features=384, out_features=1152, bias=True)
            (proj): Linear(in_features=384, out_features=384, bias=True)
          )
          (mlp_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=384, out_features=1536, bias=True)
            (fc2): Linear(in_features=1536, out_features=384, bias=True)
            (act): GELU(approximate='none')
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
      )
      (patch_merging): PatchMerging(
        (reduction): Linear(in_features=1536, out_features=768, bias=True)
        (norm): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
      )
    )
    (3): SwinStage(
      (blocks): ModuleList(
        (0): SwinBlock(
          (attn_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (softmax): Softmax(dim=-1)
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (act): GELU(approximate='none')
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (1): SwinBlock(
          (attn_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (softmax): Softmax(dim=-1)
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (act): GELU(approximate='none')
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
      )
      (patch_merging): Indentity()
    )
  )
  (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (avgpool): AdaptiveAvgPool1d(output_size=1)
  (fc): Linear(in_features=768, out_features=1000, bias=True)
)

3.2 Data Loading

Normally we need to load data from a dataset. How is the data fetched?

First, let's look at what a Python for loop does.
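Roughly, a for loop calls iter() once and then next() until StopIteration is raised. The sketch below spells this out; for_loop is a hypothetical name used only for illustration.

```python
def for_loop(iterable, body):
    """Roughly what `for item in iterable: body(item)` does."""
    iterator = iter(iterable)      # calls iterable.__iter__()
    while True:
        try:
            item = next(iterator)  # calls iterator.__next__()
        except StopIteration:      # raised when the iterator is exhausted
            break
        body(item)


seen = []
for_loop([10, 20, 30], seen.append)
print(seen)  # [10, 20, 30]
```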

For example, if I want to traverse an object with a for loop, how should I write the MyInterable class?

my_iterable = MyInterable()
for d in my_iterable:
    print(d)

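We need to implement __iter__, which returns an iterator object; the iterator's __next__ method then produces the values one by one. If we also want to pick elements by index manually, we implement __getitem__ in MyInterable.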
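A minimal sketch of such a class. The helper class MyIterator and the default contents (1, 2, 3) are made up for the demo; only the protocol methods __iter__, __next__ and __getitem__ are standard Python.

```python
class MyIterator:
    """Hypothetical helper: walks through a list, one element per __next__ call."""

    def __init__(self, data):
        self.data = data
        self.pos = 0

    def __iter__(self):
        return self  # iterators are themselves iterable

    def __next__(self):
        if self.pos >= len(self.data):
            raise StopIteration  # tells the for loop to stop
        value = self.data[self.pos]
        self.pos += 1
        return value


class MyInterable:
    def __init__(self, data=(1, 2, 3)):  # placeholder contents for the demo
        self.data = list(data)

    def __iter__(self):
        return MyIterator(self.data)  # a fresh iterator for every for loop

    def __getitem__(self, index):
        return self.data[index]  # enables my_iterable[i]

    def __len__(self):
        return len(self.data)


my_iterable = MyInterable()
for d in my_iterable:
    print(d)           # 1, then 2, then 3
print(my_iterable[1])  # 2
```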

A Dataset returns a single sample, while a DataLoader returns a batch of data.
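A small PyTorch sketch of the difference; ToyDataset and its random data are hypothetical. Indexing the Dataset returns one (image, label) pair, while a DataLoader with batch_size=4 stacks four of them.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical dataset: sample i is a random 3x32x32 'image' with label i % 10."""

    def __init__(self, n=10):
        self.images = torch.randn(n, 3, 32, 32)
        self.labels = [i % 10 for i in range(n)]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]  # ONE sample


ds = ToyDataset()
img, label = ds[0]
print(img.shape, label)   # torch.Size([3, 32, 32]) 0

loader = DataLoader(ds, batch_size=4, shuffle=False)
imgs, labels = next(iter(loader))
print(imgs.shape, labels) # torch.Size([4, 3, 32, 32]) tensor([0, 1, 2, 3])
```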

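Let's walk through the flow. A for loop fetches data from a DataLoader, which essentially calls its __iter__() method. That returns a DataLoader iterator (DataLoaderIter; in recent PyTorch versions _SingleProcessDataLoaderIter for single-process loading), whose __next__() method fetches the data: it first gets indices from the Sampler, then the fetch step calls the Dataset's __getitem__() for each index. The samples fetched from the Dataset form a list, and collate_fn() turns that list into batched tensors.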
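The same pipeline can be reproduced by hand with public torch.utils.data pieces: a sampler produces indices, BatchSampler groups them, indexing the dataset returns a list of samples, and default_collate stacks that list into tensors. This is a simplified single-process sketch of what DataLoader does, not its actual code; the toy list dataset is hypothetical.

```python
import torch
from torch.utils.data import BatchSampler, SequentialSampler
from torch.utils.data.dataloader import default_collate

# Any object with __getitem__ and __len__ works as a map-style dataset
dataset = [(torch.full((2,), float(i)), i) for i in range(5)]

sampler = SequentialSampler(dataset)  # yields indices 0, 1, 2, ...
batch_sampler = BatchSampler(sampler, batch_size=2, drop_last=False)

batches = []
for indices in batch_sampler:                # e.g. [0, 1]
    samples = [dataset[i] for i in indices]  # __getitem__: a list of (tensor, int) pairs
    batch = default_collate(samples)         # list -> [stacked tensor, label tensor]
    batches.append(batch)

x, y = batches[0]
print(x.shape, y)    # torch.Size([2, 2]) tensor([0, 1])
print(len(batches))  # 3 (the last batch has a single sample)
```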
