逐行对比LLaMA2和LLaMA模型源代码

几个小时前(2023年7月18日),Meta发布了允许商用的开源模型LLaMA2。笔者逐行对比了LLaMA2模型源代码,和LLaMA相比,几乎没有改动,细节如下:

是否改动LLaMA2LLaMA
模型整体构架TransformerTransformer
规范化函数均方根规范化(RMSNorm)均方根规范化(RMSNorm)
位置编码复数形式的旋转位置编码(RoPE)复数形式的旋转位置编码(RoPE)
激活函数SiLUSiLU
注意力机制略有改动分组查询多头注意力机制多头注意力机制
前馈函数逐元素前馈函数逐元素前馈函数
连接残差连接残差连接
掩码因果掩码因果掩码
推理略有改动自回归推理自回归推理

第二版的模型代码,增加了一个repeat_kv函数如下:

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

这个函数在多头注意力机制的前馈函数中使用。

在前馈函数应用位置编码之后,应用查询之前调用:

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(keys, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

这个函数主要作用是当键(key)和值(value)的头数小于查询(query)的头数时,将键和值的头数复制至与查询头数相同。

这个函数的功能并不奇怪,在模型编写的过程中,矩阵变换和匹配是常见的操作。比较奇怪的是代码的写法有点反直觉,这种写法并不象预先设计的,更象是一个补丁。

笔者先叠个甲,在并未做实验的基础上做如下猜测:

支持LLaMA论文提出的分组查询注意力机制(Grouped-Query Attention)。但为什么不能预分配键值的数量而是在位置变换后再单独用一个函数来处理呢?也可以这样解释:为了减少计算负担或存储需求。这是因为键和值的数量直接影响了注意力矩阵和值矩阵的大小,如果序列长度非常大,这些矩阵的存储和计算可能会变得非常昂贵。通过减少键和值的头数,可以有效地减少存储和计算的需求。在这种情况下,需要在计算注意力权重前,将键和值的头数通过复制的方式扩展到与查询头数一样多,才能进行查询和键的点积操作。

在推理函数中,第一版的输出:

output(h[:, -1, :]).float()

不论输入序列的长度为多少,只输出最后一个时间步的词的概率

而第二版的输出,改成:

output(h).float()

对于输入序列中的每个位置,该函数都会输出一个词汇表大小的向量,表示每个词的概率。

如果只是拿来生成下一个词,比如ChatGPT的典型续写应用,这两者没有区别。

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
心中带点小风骚的头像心中带点小风骚普通用户
上一篇 2023年12月19日
下一篇 2023年12月19日

相关推荐