pytorch 中 多头注意力机制 MultiHeadAttention的代码实现及应用

本文将对Scaled Dot-Product Attention,Multi-head attention,Self-attention,Transformer等概念做一个简要介绍和区分。最后对通用的 Multi-head attention 进行代码实现和应用。

一、理念:

1. Scaled Dot-Product Attention

在实际应用中,经常会用到 Attention 机制,其中最常用的是Scaled Dot-Product Attention,它是通过计算query和key之间的点积 来作为 之间的相似度。

  • Scaled 指的是 Q和K计算得到的相似度 再经过了一定的量化,具体就是 除以 根号下K_dim;
  • Dot-Product 指的是 Q和K之间 通过计算点积作为相似度;
  • Mask 可选择性 目的是将 padding的部分 填充负无穷,这样算softmax的时候这里就attention为0,从而避免padding带来的影响.

pytorch 中 多头注意力机制 MultiHeadAttention的代码实现及应用

2. Multi-head attention

是在Scaled Dot-Product Attention 的基础上,分成多个头,也就是有多个Q、K、V并行进行计算attention,可能侧重与不同的方面的相似度和权重。

pytorch 中 多头注意力机制 MultiHeadAttention的代码实现及应用

3. Self-attention

自注意力机制 是在Scaled Dot-Product Attention 以及Multi-head attention的基础上的一种应用场景,就是指 QKV的来源是相同的,自己和自己计算attention,类似于经过一个线性层等,输入输出等长。

如果QKV的来源是不同的,不能叫做 self-attention,只能是attention。比如GST中的KV是随机初始化的多个token,而Q是reference encoder得到的梅尔谱的一帧。同理,Q也可以是随机初始化的一个,而KV是来自于输入,这样就可以将某一变长长度为N的输入计算attention得到一个长度为1的向量。

4. Transformer

Transformer 是指 在Scaled Dot-Product Attention 以及Multi-head attention以及Self-attention的基础上的一种通用的模型框架,它包括Positional Encoding,Encoder,Decoder等等。Transformer不等于Self-attention。

pytorch 中 多头注意力机制 MultiHeadAttention的代码实现及应用

二、代码实现

平时经常会用到Attention操作,接下来对Multi-head Attention 进行代码整理和实现,方便以后可以直接调用接口,其中单头注意力机制作为其中的一种特殊情况。

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    '''
    input:
        query --- [N, T_q, query_dim] 
        key --- [N, T_k, key_dim]
        mask --- [N, T_k]
    output:
        out --- [N, T_q, num_units]
        scores -- [h, N, T_q, T_k]
    '''

    def __init__(self, query_dim, key_dim, num_units, num_heads):

        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
        self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
        self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)

    def forward(self, query, key, mask=None):
        querys = self.W_query(query)  # [N, T_q, num_units]
        keys = self.W_key(key)  # [N, T_k, num_units]
        values = self.W_value(key)

        split_size = self.num_units // self.num_heads
        querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0)  # [h, N, T_q, num_units/h]
        keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]
        values = torch.stack(torch.split(values, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]

        ## score = softmax(QK^T / (d_k ** 0.5))
        scores = torch.matmul(querys, keys.transpose(2, 3))  # [h, N, T_q, T_k]
        scores = scores / (self.key_dim ** 0.5)

        ## mask
        if mask is not None:
            ## mask:  [N, T_k] --> [h, N, T_q, T_k]
            mask = mask.unsqueeze(1).unsqueeze(0).repeat(self.num_heads,1,querys.shape[2],1)
            scores = scores.masked_fill(mask, -np.inf)
        scores = F.softmax(scores, dim=3)

        ## out = score * V
        out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]
        out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0)  # [N, T_q, num_units]

        return out,scores

三、实际应用:

1. 接口调用:

## 类实例化
attention = MultiHeadAttention(3,4,5,1)

## 输入
qurry = torch.randn(8, 2, 3)
key = torch.randn(8, 6 ,4)
mask = torch.tensor([[False, False, False, False, True, True],
                     [False, False, False, True, True, True],
                     [False, False, False, False, True, True],
                     [False, False, False, True, True, True],
                     [False, False, False, False, True, True],
                     [False, False, False, True, True, True],
                     [False, False, False, False, True, True],
                     [False, False, False, True, True, True],])

## 输出
out, scores = attention(qurry, key, mask)
print('out:', out.shape)         ## torch.Size([8, 2, 5])
print('scores:', scores.shape)   ## torch.Size([1, 8, 2, 6])

2. mask的作用:

mask之前的 scores:

pytorch 中 多头注意力机制 MultiHeadAttention的代码实现及应用

mask之后的 scores:

pytorch 中 多头注意力机制 MultiHeadAttention的代码实现及应用

softmax之后的scores:

pytorch 中 多头注意力机制 MultiHeadAttention的代码实现及应用

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(1)
xiaoxingxing的头像xiaoxingxing管理团队
上一篇 2022年4月9日 下午2:33
下一篇 2022年4月9日

相关推荐