详解Transformer中Self-Attention以及Multi-Head Attention

原文名称:Attention Is All You Need
原文链接:https://arxiv.org/abs/1706.03762

如果不想看文章的可以看下我在b站上录的视频:https://b23.tv/gucpvt

最近Transformer在CV领域很火,Transformer是2017年Google在Computation and Language上发表的,当时主要是针对自然语言处理领域提出的(之前的RNN模型记忆长度有限且无法并行化,只有计算完详解Transformer中Self-Attention以及Multi-Head Attention时刻后的数据才能计算详解Transformer中Self-Attention以及Multi-Head Attention时刻的数据,但Transformer都可以做到)。在这篇文章中作者提出了Self-Attention的概念,然后在此基础上提出Multi-Head Attention,所以本文对Self-Attention以及Multi-Head Attention的理论进行详细的讲解。在阅读本文之前,建议大家先去看下李弘毅老师讲的Transformer的内容。本文的内容是基于李宏毅老师讲的内容加上自己阅读一些源码进行的总结。

前言

如果之前你有在网上找过self-attention或者transformer的相关资料,基本上都是贴的原论文中的几张图以及公式,如下图,讲的都挺抽象的,反正就是看不懂(可能我太菜的原因)。就像李弘毅老师课程里讲到的"不懂的人再怎么看也不会懂的"。那接下来本文就结合李弘毅老师课上的内容加上原论文的公式来一个个进行详解。

attention is all you need

Self-Attention

下面这个图是我自己画的,为了方便大家理解,假设输入的序列长度为2,输入就两个节点详解Transformer中Self-Attention以及Multi-Head Attention,然后通过Input Embedding也就是图中的详解Transformer中Self-Attention以及Multi-Head Attention将输入映射到详解Transformer中Self-Attention以及Multi-Head Attention。紧接着分别将详解Transformer中Self-Attention以及Multi-Head Attention分别通过三个变换矩阵详解Transformer中Self-Attention以及Multi-Head Attention(这三个参数是可训练的,是共享的)得到对应的详解Transformer中Self-Attention以及Multi-Head Attention(这里在源码中是直接使用全连接层实现的,这里为了方便理解,忽略偏执)。

self-attention

其中

  • 详解Transformer中Self-Attention以及Multi-Head Attention代表query,后续会去和每一个详解Transformer中Self-Attention以及Multi-Head Attention进行匹配
  • 详解Transformer中Self-Attention以及Multi-Head Attention代表key,后续会被每个详解Transformer中Self-Attention以及Multi-Head Attention匹配
  • 详解Transformer中Self-Attention以及Multi-Head Attention代表从详解Transformer中Self-Attention以及Multi-Head Attention中提取得到的信息
  • 后续详解Transformer中Self-Attention以及Multi-Head Attention详解Transformer中Self-Attention以及Multi-Head Attention匹配的过程可以理解成计算两者的相关性,相关性越大对应详解Transformer中Self-Attention以及Multi-Head Attention的权重也就越大

假设详解Transformer中Self-Attention以及Multi-Head Attention那么:
详解Transformer中Self-Attention以及Multi-Head Attention
前面有说Transformer是可以并行化的,所以可以直接写成:
详解Transformer中Self-Attention以及Multi-Head Attention
同理我们可以得到详解Transformer中Self-Attention以及Multi-Head Attention详解Transformer中Self-Attention以及Multi-Head Attention,那么求得的详解Transformer中Self-Attention以及Multi-Head Attention就是原论文中的详解Transformer中Self-Attention以及Multi-Head Attention详解Transformer中Self-Attention以及Multi-Head Attention就是详解Transformer中Self-Attention以及Multi-Head Attention详解Transformer中Self-Attention以及Multi-Head Attention就是详解Transformer中Self-Attention以及Multi-Head Attention。接着先拿详解Transformer中Self-Attention以及Multi-Head Attention和每个详解Transformer中Self-Attention以及Multi-Head Attention进行match,点乘操作,接着除以详解Transformer中Self-Attention以及Multi-Head Attention得到对应的详解Transformer中Self-Attention以及Multi-Head Attention,其中详解Transformer中Self-Attention以及Multi-Head Attention代表向量详解Transformer中Self-Attention以及Multi-Head Attention的长度,在本示例中等于2,除以详解Transformer中Self-Attention以及Multi-Head Attention的原因在论文中的解释是“进行点乘后的数值很大,导致通过softmax后梯度变的很小”,所以通过除以详解Transformer中Self-Attention以及Multi-Head Attention来进行缩放。比如计算详解Transformer中Self-Attention以及Multi-Head Attention
详解Transformer中Self-Attention以及Multi-Head Attention
同理拿详解Transformer中Self-Attention以及Multi-Head Attention去匹配所有的详解Transformer中Self-Attention以及Multi-Head Attention能得到详解Transformer中Self-Attention以及Multi-Head Attention,统一写成矩阵乘法形式:
详解Transformer中Self-Attention以及Multi-Head Attention
接着对每一行即详解Transformer中Self-Attention以及Multi-Head Attention详解Transformer中Self-Attention以及Multi-Head Attention分别进行softmax处理得到详解Transformer中Self-Attention以及Multi-Head Attention详解Transformer中Self-Attention以及Multi-Head Attention,这里的详解Transformer中Self-Attention以及Multi-Head Attention相当于计算得到针对每个详解Transformer中Self-Attention以及Multi-Head Attention的权重。到这我们就完成了详解Transformer中Self-Attention以及Multi-Head Attention公式中详解Transformer中Self-Attention以及Multi-Head Attention部分。

self-attention

Multi-Head Attention

刚刚已经聊完了Self-Attention模块,接下来再来看看Multi-Head Attention模块,实际使用中基本使用的还是Multi-Head Attention模块。原论文中说使用多头注意力机制能够联合来自不同head部分学习到的信息。Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions.其实只要懂了Self-Attention模块Multi-Head Attention模块就非常简单了。

首先还是和Self-Attention模块一样将详解Transformer中Self-Attention以及Multi-Head Attention分别通过详解Transformer中Self-Attention以及Multi-Head Attention得到对应的详解Transformer中Self-Attention以及Multi-Head Attention,然后再根据使用的head的数目详解Transformer中Self-Attention以及Multi-Head Attention进一步把得到的详解Transformer中Self-Attention以及Multi-Head Attention均分成详解Transformer中Self-Attention以及Multi-Head Attention份。比如下图中假设详解Transformer中Self-Attention以及Multi-Head Attention然后详解Transformer中Self-Attention以及Multi-Head Attention拆分成详解Transformer中Self-Attention以及Multi-Head Attention详解Transformer中Self-Attention以及Multi-Head Attention,那么详解Transformer中Self-Attention以及Multi-Head Attention就属于head1,详解Transformer中Self-Attention以及Multi-Head Attention属于head2。

multi-head

multi-head

multi-head

multi-head

在这里插入图片描述

Self-Attention与Multi-Head Attention计算量对比

在原论文章节3.2.2中最后有说两者的计算量其实差不多。Due to the reduced dimension of each head, the total computational cost is similar to that of single-head attention with full dimensionality.下面做了个简单的实验,这个model文件大家先忽略哪来的。这个Attention就是实现Multi-head Attention的方法,其中包括上面讲的所有步骤。

  • 首先创建了一个Self-Attention模块(单头)a1,然后把proj变量置为Identity(Identity对应的是Multi-Head Attention中最后那个详解Transformer中Self-Attention以及Multi-Head Attention的映射,单头中是没有的,所以置为Identity即不做任何操作)。
  • 再创建一个Multi-Head Attention模块(多头)a2,然后设置8个head。
  • 创建一个随机变量,注意shape
  • 使用fvcore分别计算两个模块的FLOPs
import torch
from fvcore.nn import FlopCountAnalysis

from model import Attention


def main():
    # Self-Attention
    a1 = Attention(dim=512, num_heads=1)
    a1.proj = torch.nn.Identity()  # remove Wo

    # Multi-Head Attention
    a2 = Attention(dim=512, num_heads=8)

    # [batch_size, num_tokens, total_embed_dim]
    t = (torch.rand(32, 1024, 512),)

    flops1 = FlopCountAnalysis(a1, t)
    print("Self-Attention FLOPs:", flops1.total())

    flops2 = FlopCountAnalysis(a2, t)
    print("Multi-Head Attention FLOPs:", flops2.total())


if __name__ == '__main__':
    main()

终端输出如下, 可以发现确实两者的FLOPs差不多,Multi-Head AttentionSelf-Attention略高一点:

Self-Attention FLOPs: 60129542144
Multi-Head Attention FLOPs: 68719476736

其实两者FLOPs的差异只是在最后的详解Transformer中Self-Attention以及Multi-Head Attention上,如果把Multi-Head Attentio详解Transformer中Self-Attention以及Multi-Head Attention也删除(即把a2的proj也设置成Identity),可以看出两者FLOPs是一样的:

Self-Attention FLOPs: 60129542144
Multi-Head Attention FLOPs: 60129542144

Positional Encoding

如果仔细观察刚刚讲的Self-Attention和Multi-Head Attention模块,在计算中是没有考虑到位置信息的。假设在Self-Attention模块中,输入详解Transformer中Self-Attention以及Multi-Head Attention得到详解Transformer中Self-Attention以及Multi-Head Attention。对于详解Transformer中Self-Attention以及Multi-Head Attention而言,详解Transformer中Self-Attention以及Multi-Head Attention详解Transformer中Self-Attention以及Multi-Head Attention离它都是一样近的而且没有先后顺序。假设将输入的顺序改为详解Transformer中Self-Attention以及Multi-Head Attention,对结果详解Transformer中Self-Attention以及Multi-Head Attention是没有任何影响的。下面是使用Pytorch做的一个实验,首先使用nn.MultiheadAttention创建一个Self-Attention模块(num_heads=1),注意这里在正向传播过程中直接传入详解Transformer中Self-Attention以及Multi-Head Attention,接着创建两个顺序不同的详解Transformer中Self-Attention以及Multi-Head Attention变量t1和t2(主要是将详解Transformer中Self-Attention以及Multi-Head Attention详解Transformer中Self-Attention以及Multi-Head Attention的顺序换了下),分别将这两个变量输入Self-Attention模块进行正向传播。

import torch
import torch.nn as nn


m = nn.MultiheadAttention(embed_dim=2, num_heads=1)

t1 = [[[1., 2.],   # q1, k1, v1
       [2., 3.],   # q2, k2, v2
       [3., 4.]]]  # q3, k3, v3

t2 = [[[1., 2.],   # q1, k1, v1
       [3., 4.],   # q3, k3, v3
       [2., 3.]]]  # q2, k2, v2

q, k, v = torch.as_tensor(t1), torch.as_tensor(t1), torch.as_tensor(t1)
print("result1: \n", m(q, k, v))

q, k, v = torch.as_tensor(t2), torch.as_tensor(t2), torch.as_tensor(t2)
print("result2: \n", m(q, k, v))

对比结果可以发现,即使调换了详解Transformer中Self-Attention以及Multi-Head Attention详解Transformer中Self-Attention以及Multi-Head Attention的顺序,但对于详解Transformer中Self-Attention以及Multi-Head Attention是没有影响的。

test

positional encoding

超参对比

关于Transformer中的一些超参数的实验对比可以参考原论文的表3,如下图所示。其中:

  • N表示重复堆叠Transformer Block的次数
  • 详解Transformer中Self-Attention以及Multi-Head Attention表示Multi-Head Self-Attention输入输出的token维度(向量长度)
  • 详解Transformer中Self-Attention以及Multi-Head Attention表示在MLP(feed forward)中隐层的节点个数
  • h表示Multi-Head Self-Attention中head的个数
  • 详解Transformer中Self-Attention以及Multi-Head Attention表示Multi-Head Self-Attention中每个head的key(K)以及query(Q)的维度
  • 详解Transformer中Self-Attention以及Multi-Head Attention表示dropout层的drop_rate

在这里插入图片描述

到这,关于Self-Attention、Multi-Head Attention以及位置编码的内容就全部讲完了,如果有讲的不对的地方希望大家指出。

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
xiaoxingxing的头像xiaoxingxing管理团队
上一篇 2023年2月16日 下午9:52
下一篇 2023年2月16日 下午9:52

相关推荐