【计算机视觉】ViT:代码逐行解读

文章目录

  • 一、代码
  • 二、代码解读
    • 2.1 大体理解
    • 2.2 详细理解

一、代码

import torch
import torch.nn as nn
from einops import rearrange

from self_attention_cv import TransformerEncoder


class ViT(nn.Module):
    def __init__(self, *,
                 img_dim,
                 in_channels=3,
                 patch_dim=16,
                 num_classes=10,
                 dim=512,
                 blocks=6,
                 heads=4,
                 dim_linear_block=1024,
                 dim_head=None,
                 dropout=0, transformer=None, classification=True):
        """
        Args:
            img_dim: the spatial image size
            in_channels: number of img channels
            patch_dim: desired patch dim
            num_classes: classification task classes
            dim: the linear layer's dim to project the patches for MHSA
            blocks: number of transformer blocks
            heads: number of heads
            dim_linear_block: inner dim of the transformer linear block
            dim_head: dim head in case you want to define it. defaults to dim/heads
            dropout: for pos emb and transformer
            transformer: in case you want to provide another transformer implementation
            classification: creates an extra CLS token
        """
        super().__init__()
        assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible'
        self.p = patch_dim
        self.classification = classification
        tokens = (img_dim // patch_dim) ** 2
        self.token_dim = in_channels * (patch_dim ** 2)
        self.dim = dim
        self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
        self.project_patches = nn.Linear(self.token_dim, dim)

        self.emb_dropout = nn.Dropout(dropout)
        if self.classification:
            self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
            self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
            self.mlp_head = nn.Linear(dim, num_classes)
        else:
            self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))

        if transformer is None:
            self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
                                                  dim_head=self.dim_head,
                                                  dim_linear_block=dim_linear_block,
                                                  dropout=dropout)
        else:
            self.transformer = transformer

    def expand_cls_to_batch(self, batch):
        """
        Args:
            batch: batch size
        Returns: cls token expanded to the batch size
        """
        return self.cls_token.expand([batch, -1, -1])

    def forward(self, img, mask=None):
        batch_size = img.shape[0]
        img_patches = rearrange(
            img, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                                patch_x=self.p, patch_y=self.p)
        # project patches with linear layer + add pos emb
        img_patches = self.project_patches(img_patches)

        if self.classification:
            img_patches = torch.cat(
                (self.expand_cls_to_batch(batch_size), img_patches), dim=1)

        patch_embeddings = self.emb_dropout(img_patches + self.pos_emb1D)

        # feed patch_embeddings and output of transformer. shape: [batch, tokens, dim]
        y = self.transformer(patch_embeddings, mask)

        if self.classification:
            # we index only the cls token for classification. nlp tricks :P
            return self.mlp_head(y[:, 0, :])
        else:
            return y

二、代码解读

2.1 大体理解

这段代码是一个实现了 Vision Transformer(ViT)模型的 PyTorch 实现。

ViT 是一个基于 Transformer 架构的图像分类模型,其主要思想是将图像分成一个个固定大小的 patch ,并将这些 patch 看做是一个个 token 输入到 Transformer 中进行特征提取和分类。

以下是对代码的解读:

  1. ViT类继承自nn.Module类,其构造函数有一系列参数,包括输入图像的尺寸、patch的大小、输出类别数、注意力机制中的头数等等。
  2. project_patches函数通过一个全连接层将每个patch映射到一个d维的特征空间中。
  3. 如果classification = True,则将一个额外的CLS token添加到输入的token序列的开头,即对于每张图像添加一个形状为[1, 1, d]的CLS token。同时,在ViT中采用的是绝对位置编码,因此还添加了一个1D的位置编码向量,其形状为[num_patches + 1, d],其中num_patches表示图像被划分成的patch数目。如果classification = False,则不添加CLS token。
  4. forward函数首先将输入的图像进行patch划分,并通过project_patches函数将每个patch映射到d维特征空间中。接着,将位置编码向量加到映射后的patch特征向量上,并进行dropout处理。如果classification=True,则在特征序列开头添加CLS token。接着将这些特征输入到Transformer中,进行特征提取。最后输出分类结果,如果classification=True,则只返回CLS token的分类结果。

2.2 详细理解

from self_attention_cv import TransformerEncoder

self_attention_cv是一个基于PyTorch实现的库,提供了在计算机视觉任务中使用自注意力机制的模块和网络,例如Transformer EncoderAttention Modules

它主要针对图像分类、对象检测、语义分割等任务,支持多种自注意力模块的实现,包括Simplified Self-AttentionFull Self-AttentionLocal Self-Attention等。此外,该库还提供了一些常见的计算机视觉任务模型的实现,例如Vision Transformer(ViT)Swin Transformer等。

TransformerEncoder是一个自注意力机制的编码器,用于将输入序列转换为编码后的序列。自注意力机制允许模型能够根据输入序列中的其他位置来加权计算每个位置的表示。这种机制在自然语言处理中的应用非常广泛,比如BERT、GPT等模型都采用了自注意力机制。

TransformerEncoder是基于PyTorch实现的,可以在计算机视觉任务中使用,例如图像分类、对象检测、语义分割等。它支持多头注意力、残差连接和LayerNorm等特性。在这个代码中,ViT模型中的Transformer部分采用了TransformerEncoder作为默认的实现。

def __init__(self, *,
                img_dim,
                in_channels=3,
                patch_dim=16,
                num_classes=10,
                dim=512,
                blocks=6,
                heads=4,
                dim_linear_block=1024,
                dim_head=None,
                dropout=0, transformer=None, classification=True):
    super().__init__()
    assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible'
    self.p = patch_dim
    self.classification = classification
    tokens = (img_dim // patch_dim) ** 2
    self.token_dim = in_channels * (patch_dim ** 2)
    self.dim = dim
    self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
    self.project_patches = nn.Linear(self.token_dim, dim)

    self.emb_dropout = nn.Dropout(dropout)
    if self.classification:
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
        self.mlp_head = nn.Linear(dim, num_classes)
    else:
        self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))

    if transformer is None:
        self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
                                                dim_head=self.dim_head,
                                                dim_linear_block=dim_linear_block,
                                                dropout=dropout)
    else:
        self.transformer = transformer

这段代码定义了一个名为 ViT 的 PyTorch 模型类,它是一个使用自注意力机制(Self-Attention)实现的视觉 Transformer 模型。其中主要参数包括:

  • img_dim:输入图片的空间大小
  • in_channels:输入图片的通道数
  • patch_dim:将图片划分成固定大小的 patch 的大小
  • num_classes:分类任务的类别数
  • dim:线性层的维度,用于将每个 patch 投影到 MHSA 空间
  • blocks:Transformer 模型中的块数
  • heads:注意力头的数量
  • dim_linear_block:线性块内部的维度
  • dim_head:每个头的维度,如果没有指定则默认为 dim/heads
  • dropout:用于位置编码和 Transformer 的 dropout 概率
  • transformer:可选的 TransformerEncoder 类实例
  • classification:是否包含额外的 CLS 标记以用于分类任务
def __init__(self, *,
                 img_dim,
                 in_channels=3,
                 patch_dim=16,
                 num_classes=10,
                 dim=512,
                 blocks=6,
                 heads=4,
                 dim_linear_block=1024,
                 dim_head=None,
                 dropout=0, transformer=None, classification=True):
    super().__init__()

这里定义了 ViT 类的构造函数,其包含多个参数,包括输入图像大小 img_dim,输入通道数 in_channels,分块大小 patch_dim,分类数目 num_classes,嵌入维度 dim,Transformer编码器的块数 blocks,头数 heads,线性块的维度 dim_linear_block,注意力头维度 dim_head,Dropout概率 dropout,可选的Transformer编码器 transformer,以及是否进行分类的标志 classification

    assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible'
    self.p = patch_dim
    self.classification = classification

这里检查 img_dim 是否能够被 patch_dim 整除,如果不能整除,则会引发断言错误。同时,将 patch_dim 存储到 self.p 中,并将是否进行分类的标志存储到 self.classification 中。

    tokens = (img_dim // patch_dim) ** 2
    self.token_dim = in_channels * (patch_dim ** 2)
    self.dim = dim
    self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
    self.project_patches = nn.Linear(self.token_dim, dim)

这里计算了输入图像中可分块的数量 tokens,并将每个块的维度 self.token_dim 设置为 in_channels * (patch_dim ** 2)

将嵌入维度 dim 存储到 self.dim 中,并根据 dim_head 是否为 None,设置注意力头维度 self.dim_headself.project_patches 是一个线性层,用于将每个块投影到嵌入空间中。

    self.emb_dropout = nn.Dropout(dropout)
    if self.classification:
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim))
        self.mlp_head = nn.Linear(dim, num_classes)
    else:
        self.pos_emb1D = nn.Parameter(torch.randn(tokens, dim))

这里定义了嵌入层的Dropout层,并根据是否进行分类的标志,设置类别标记 self.cls_token、位置嵌入 self.pos_emb1DMLPself.mlp_head。如果不进行分类,则不需要 self.cls_tokenself.mlp_head

if transformer is None:
        self.transformer = TransformerEncoder(dim, blocks=blocks, heads=heads,
                                                dim_head=self.dim_head,
                                                dim_linear_block=dim_linear_block,
                                                dropout=dropout)
    else:
        self.transformer = transformer

self.emb_dropout = nn.Dropout(dropout): 定义了一个dropout层,用于在embedding后对其进行dropout操作。

if self.classification:: 如果是分类任务,就执行下面的操作,否则跳过。

self.cls_token = nn.Parameter(torch.randn(1, 1, dim)): 定义了一个可训练参数cls_token,表示分类token,它是一个1x1xdim的tensor,其中dim表示embedding维度。

self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, dim)): 定义了一个可训练参数pos_emb1D,表示位置嵌入,它是一个(tokens+1)xdim的tensor,其中tokens表示图像被分成的patch数,dim表示embedding维度。

self.mlp_head = nn.Linear(dim, num_classes): 定义了一个全连接层,将embedding映射到输出类别的数量。

最后,根据传入的参数来选择使用默认的TransformerEncoder,还是使用传入的transformer。如果没有传入,则使用默认的TransformerEncoder,否则使用传入的transformer。

def expand_cls_to_batch(self, batch):
    """
    Args:
        batch: batch size
    Returns: cls token expanded to the batch size
    """
    return self.cls_token.expand([batch, -1, -1])

该方法的作用是将 Transformer 中的分类 token 扩展到整个批次的样本数。它接受一个 batch 参数作为批次大小,返回一个形状为 [batch, 1, dim] 的张量,其中 dim 是 Transformer 模型的维度大小。在这个方法中,使用了 PyTorch 的 expand() 方法来实现扩展操作。

def forward(self, img, mask=None):
    batch_size = img.shape[0]
    img_patches = rearrange(
        img, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                            patch_x=self.p, patch_y=self.p)
    # project patches with linear layer + add pos emb
    img_patches = self.project_patches(img_patches)

    if self.classification:
        img_patches = torch.cat(
            (self.expand_cls_to_batch(batch_size), img_patches), dim=1)

    patch_embeddings = self.emb_dropout(img_patches + self.pos_emb1D)

    # feed patch_embeddings and output of transformer. shape: [batch, tokens, dim]
    y = self.transformer(patch_embeddings, mask)

    if self.classification:
        # we index only the cls token for classification. nlp tricks :P
        return self.mlp_head(y[:, 0, :])
    else:
        return y

forward 函数中,接收输入的 imgmask

通过 img_dimpatch_dim 计算出 tokens 数量,其中 tokens 为图像分割成的块的数量。

将输入的 img 分成 patch,并通过 rearrange 函数重组成形状为 [batch_size, tokens, patch_dim * patch_dim * in_channels] 的张量。

通过 Linear 层将每个 patch 映射到 dim 维度,并加上位置编码 pos_emb1D

如果是用于分类任务,则在序列的开头插入一个 CLS token,然后与处理后的 patch 张量按列拼接。

对 patch_embeddings 应用 dropout,并输入到 TransformerEncoder 中,返回输出张量 y,形状为 [batch_size, tokens, dim]

如果是用于分类任务,则从 y 中取出 CLS token,输入到一个 Linear 层中进行分类,输出分类结果。

如果不是分类任务,则直接返回 y。

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
青葱年少的头像青葱年少普通用户
上一篇 2023年12月6日
下一篇 2023年12月6日

相关推荐