【深度学习】Transformer中的mask机制超详细讲解

mask机制

  1. encoder中对输入序列的长度进行pad 0到max_src_len,在计算自注意力的时候,只对有效序列长度进行attention计算,pad的0需要mask; 【endoer_mhsa_mas——【深度学习】Transformer中的mask机制超详细讲解
  2. decoder中的第一个masked多头自注意力模块输入序列为了不能看到当前token之后的信息,需要对当前toekn之后的tokens进行mask;【decoder_mhsa_mask——【深度学习】Transformer中的mask机制超详细讲解
  3. decoder中第二个多头交叉注意力模块中query来自decoder的输入的当前token,key-value来自encoder的输出,综合上述两种mask机制,应该对不需要计算注意力的位置进行mask。【decoder_mhca_mask——【深度学习】Transformer中的mask机制超详细讲解
    【深度学习】Transformer中的mask机制超详细讲解

Pytorch代码实现

预定义输入输出序列

# 词嵌入向量维度
d_model = 512

# 单词表大小
vocab_size = 1000

# dropout比例
dropout = 0.1

# 构造两个序列 序列的长度分别为4 5, 0索引为无效
padding_idx = 0

tgt_len = [4, 2, 6]
src_len = [4, 5, 3]

# 随机生成3个序列 不足最大序列长度的 在序列结尾pad 0
src_seq = x = torch.cat([
    F.pad(torch.randint(1, vocab_size, (1, L)), (0, max(src_len) - L)) for L in src_len
])
# [4, 5, 3]
# tensor([[129, 490, 572, 764,   0],
#         [636, 151, 572, 482, 666],
#         [439, 757,  18,   0,   0]])
tgt_seq = y = torch.cat([
    F.pad(torch.randint(1, vocab_size, (1, L)), (0, max(tgt_len) - L)) for L in tgt_len
])
# [4, 2, 6]
# tensor([[509, 360, 486,  88,   0,   0],
#         [415, 609,   0,   0,   0,   0],
#         [767, 817,  59, 990, 853, 101]])

encoder自注意力中的mask

1/True表示该位置要mask, 0/False表示该位置不需要mask

方法1

该方法利用向量之间的相似性 即 (n, 1) @ (1, n) -> (n, n)就能得到每个维度之间的相关性 最后取反即可得到mask矩阵
这种方法看起来比较直观 类似于求两个向量之间的协方差【深度学习】Transformer中的mask机制超详细讲解

# ----------------------------------
# encoder multi-head self-attn mask
# 在计算输入x token之间的attn_score时 需要忽略pad 0
valid_encoder_mhsa_pos = torch.vstack([
    F.pad(torch.ones(L), (0, max(src_len) - L)) for L in src_len
]).unsqueeze(-1)  # 扩展维度 用于批量计算mask矩阵 (B, Ns, 1) x (B, 1, Ns) -> (B, Ns, Ns)
# print(f'valid_encoder_mhsa_pos: {valid_encoder_mhsa_pos.shape}')

encoder_mhsa_mask = 1 - torch.bmm(valid_encoder_mhsa_pos, valid_encoder_mhsa_pos.transpose(-2, -1))
print(f'encoder_mhsa_mask:\n{encoder_mhsa_mask}')
# ----------------------------------

输出如下:

encoder_mhsa_mask:
tensor([[[0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [1., 1., 1., 1., 1.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 1., 1.],
         [0., 0., 0., 1., 1.],
         [0., 0., 0., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]]])

方法2

# 在对attn进行mask的时候可以直接进行广播 而且只需要关注key的pad即可 无需对query中的pad进行mask 与方法一不同
# 因为query是在和key求相似度,只要把key的无效长度mask即可 这样就能得到与key有效长度的注意力分数
# 当然也可以对query进行mask 类似方法一 这样可能就多一些额外的对query的处理
def get_pad_mask(seq_k, pad_idx):
    return (seq_k == pad_idx).unsqueeze(-2)  
    # 添加一个维度,和attn进行mask操作的时候可以进行broadcast

# 该函数将得到mask进行了expand操作,复制了seq_q的长度份 
def get_attn_pad_mask(seq_q, seq_k):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()

    padding_attn_mask = seq_k.data.eq(0).unsqueeze(1)
    return padding_attn_mask.expand(batch_size, len_q, len_k)
# 相比来说,get_pad_mask则更简洁 推荐使用
# 此处为了简化表示 将query key value 都只用了(B, seq_len)表示有效位置 及pad的0
# 实际情况中需要加上embed_dim,也就是每个token对应的嵌入向量维度
seq_k = torch.Tensor([[1, 1, 1, 1, 0],[1, 1, 1, 1, 1]])
seq_v = torch.Tensor([[1, 1, 1, 1, 0],[1, 1, 1, 1, 1]])
seq_q = torch.Tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]])


# 对query求自注意力的mask
src_mask1 = get_pad_mask(seq_q, 0)
src_mask2 = get_attn_padding_mask(seq_q, seq_q)
attn = torch.randn(seq_q.size(1), seq_q.size(1))
attn1 = attn.masked_fill(src_mask1 == 1, -1e9)
attn2 = attn.masked_fill(src_mask2 == 1, -1e9)
print(attn1 == attn2)  # 可以验证两者相同

输出如下

query的无效长度对key的有效长度的注意力没有被mask 但是应该不影响最终的结果

# seq_k的自注意力mask
tensor([[[0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1]],

        [[0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]]], dtype=torch.uint8)
# 可以将mask矩阵的0维度看做q, 1维度看做k False表示q对k计算了相似度

encoder_mhsa_mask: # 对比上述mask
tensor([[[0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [1., 1., 1., 1., 1.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])

decoder中的masked自注意力中的mask

方法1

先形成一个下三角矩阵, 其他位置pad 0, 然后取反就得到了mask
这样看起来也很直观,对无效长度的地方也进行了mask

# ----------------------------------
# decoder multi-head self-attn mask
# decoder输入需要形成一个下三角矩阵mask (B, Nt, Nt)
# 不能看到当前token之后的信息
decoder_mhsa_mask = 1 - torch.stack([
    F.pad(torch.tril(torch.ones(L, L)), (0, max(tgt_len) - L, 0, max(tgt_len) - L)) \
        for L in tgt_len
])

print(f'decoder_mhsa_mask:\n{decoder_mhsa_mask.shape}')
# ----------------------------------

输出如下

tensor([[[0., 1., 1., 1., 1., 1.],
         [0., 0., 1., 1., 1., 1.],
         [0., 0., 0., 1., 1., 1.],
         [0., 0., 0., 0., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.]],

        [[0., 1., 1., 1., 1., 1.],
         [0., 0., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.]]])

方法2

def get_pad_mask(seq_k, pad_idx):
    return (seq_k == pad_idx).unsqueeze(-2)  
    # 添加一个维度,和attn进行mask操作的时候可以进行broadcast

# 生成一个上三角矩阵
def get_subsequent_mask(seq):
    sz_b, len_s = seq.size()
    subsequent_mask = (torch.triu(
        torch.ones((1, len_s, len_s)), diagonal=1)).bool()
    return subsequent_mask
# 但是想要和方法1一样对key中的无效长度部分mask 需要如下操作

seq_k = torch.Tensor([[1, 1, 1, 1, 0],[1, 1, 1, 1, 1]])
seq_v = torch.Tensor([[1, 1, 1, 1, 0],[1, 1, 1, 1, 1]])
seq_q = torch.Tensor([[1, 1, 1, 1, 0, 0], [1, 1, 0, 0, 0, 0]])

mask = get_pad_mask(seq_k, 0) | get_subsequent_mask(seq_k)
print(get_pad_mask(seq_q, 0).byte())
print(get_subsequent_mask(seq_q).byte())
print(mask)

输出如下

# print(get_pad_mask(seq_q, 0).byte())
tensor([[[0, 0, 0, 0, 1, 1]],

        [[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
torch.Size([2, 1, 6])

# print(get_subsequent_mask(seq_q).byte())
# 上三角矩阵
tensor([[[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
torch.Size([1, 6, 6])

# 上述两个矩阵进行或操作即可得到最终的mask 也就是保证query和key的有效长度计算注意力
# 两个seq对应的上三角矩阵相同 序列长度不一样 因此广播之后就得到了两个不同序列的有效上三角矩阵
tensor([[[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 1, 1]],

        [[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
# query 的无效长度处的token也对key的有效长度计算了注意力 本来应该无 有也不影响 

# 对比方法1
tensor([[[0., 1., 1., 1., 1., 1.],
         [0., 0., 1., 1., 1., 1.],
         [0., 0., 0., 1., 1., 1.],
         [0., 0., 0., 0., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.]],

        [[0., 1., 1., 1., 1., 1.],
         [0., 0., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1.]]])

decoder交叉注意力中的mask

方法1

# ----------------------------------
# decoder multi-head cross-attn mask
# Q --> decoder
# K V --> encoder
# mask shape (B, Nt, Ns)
valid_decoder_mhca_pos = torch.vstack([
    F.pad(torch.ones(L), (0, max(tgt_len) - L)) for L in tgt_len
]).unsqueeze(-1)  # 同encoder
# (B, Nt, 1) x (B, 1, Ns) -> (B, Nt, Ns)
decoder_mhca_mask = 1 - torch.matmul(valid_decoder_mhca_pos, valid_encoder_mhsa_pos.transpose(-2, -1))
print(f'decoder_mhca_mask:\n{decoder_mhca_mask.shape}')
# ----------------------------------

输出如下

tensor([[[0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]]])

方法2


def get_pad_mask(seq_k, pad_idx):
    return (seq_k == pad_idx).unsqueeze(-2)  
    # 添加一个维度,和attn进行mask操作的时候可以进行broadcast

# 生成一个上三角矩阵
def get_subsequent_mask(seq):
    sz_b, len_s = seq.size()
    subsequent_mask = (torch.triu(
        torch.ones((1, len_s, len_s)), diagonal=1)).bool()
    return subsequent_mask

seq_k = torch.Tensor([[1, 1, 1, 1, 0],[1, 1, 1, 1, 1]])
seq_v = torch.Tensor([[1, 1, 1, 1, 0],[1, 1, 1, 1, 1]])
seq_q = torch.Tensor([[1, 1, 1, 1, 0, 0], [1, 1, 0, 0, 0, 0]])
# get_attn_padding_mask(seq_q, seq_k).byte()

# 求交叉注意力mask
src_mask1 = get_pad_mask(seq_k, 0)
src_mask2 = get_attn_padding_mask(seq_q, seq_k)
attn = torch.randn(seq_q.size(1), seq_k.size(1))
attn1 = attn.masked_fill(src_mask1 == 1, -1e9)
attn2 = attn.masked_fill(src_mask2 == 1, -1e9)
# print(attn1 == attn2)  # 可以验证两者相同

输出如下

tensor([[[0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1]],

        [[0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0]]], dtype=torch.uint8)

# 对比方法1
tensor([[[0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]]])

# 相当于seq_q的无效长度处tokens对seq_k的有效长度的token也进行了注意力计算但没mask
# 方法1则mask掉了

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(1)
社会演员多的头像社会演员多普通用户
上一篇 2022年6月15日 上午10:37
下一篇 2022年6月15日

相关推荐