LLaMA-2论文阅读

1. 基本介绍

LLaMA-2是2023年7月24日Meta发布的LLaMA第二代,跟LLaMA-1几个显著区别:

  • 免费可商用版本的大模型
  • context上下文增加了一倍,从2K变为了4K
  • 训练的总token数从1.0T/1.4T增加为2.0T(LLaMA-2论文阅读), 在1.4T基础上增加40%
  • 对于最大的模型参数量65B也增加到了70B(LLaMA-2论文阅读),并在34B和70B两个版本上使用了 LLaMA-2论文阅读 的方法

训练耗时如下:

ModelA100-80G GPU HoursTokens(LLaMA-2论文阅读)
LLaMA2-7B184320LLaMA-2论文阅读
LLaMA2-13B368640LLaMA-2论文阅读
LLaMA2-34B1038336LLaMA-2论文阅读
LLaMA2-70B1720320LLaMA-2论文阅读

效果上在多个Benchmark上得到了提升:

LLaMA2整体的训练如下图所示,先通过自回归有监督的训练得到pretrain的llama2模型,然后能过有监督的fine-tuning、人类反馈的强化学习RLHF、Ghost Attention(GAtt)一起实现finetuning后的LLaMA-2-chat模型,RLHF中采用了拒绝采样和近似策略优化算法(PPO)。

2. Pretraining

  • LLaMA-2采用的模型结构跟LLaMA-1相同,使用了RMSNorm、SwiGLU、RoPE,在LLaMA-1的基础上Context长度增加了一倍变为4k,同时使用了grouped-query attention(LLaMA-2论文阅读)。
  • LLaMA-2采用AdamW的优化器,LLaMA-2论文阅读;使用了cosine学习率调度,前2000轮进行warpup,后续学习率每次衰减10%;weight decay设为0.1, grad_clip设为1.0
  • Tokenizer使用SentencePiece中的bytepair encoding (BPE)算法,总词表大小为32K个token
  • 训练loss如下:

3. Fine-tuning方法

3.1 Supervised Fine-Tuning (SFT)

使用Scaling Instruction-Finetuned Language Models中的开源数据集进行指令微调。在指令微调过程中,数量有限的高质量数据集可以有效提升模型整体的效果。有监督finetuning中使用cosine的学习率策略,初始学习率为 LLaMA-2论文阅读,weight_decay为0.1, batch_size为64,sequence长度为4096。

3.2 Reinforcement Learning from Human Feedback (RLHF)

在RLHF中会让人来给不同模型的结果进行打分,然后根据人的反馈训练一个奖励模型(reward model),后续可以根据奖励模型自动进行打分。

3.2.1 人类偏好数据准备

第一个阶段是收集有人类偏好的数据用于强化学习,收集过程是先定义一个prompt,然后从两个超参等配置不同的模型进行推理,人类对结果进行评价,分为significantly better/better/slightly better/negligibly better(unsure)四个标签。这里的结果偏好于安全的有帮助的答案,比如prompt是给出制作炸弹的步骤,尽管模型给出制作步骤是有帮助的,但这不符合安全的要求。数据按周级别进行收集和训练。如下示例第一个是没有帮助的答案,第二个是安全的答案。

Meta收集的训练数据集和开源数据对比如下,在Meta的数据集中Example的token数长度显著增长。

3.2.2 反馈模型

第二个阶段是反馈模型的训练,反馈模型输入是一个模型的推理结果和相关的prompt(包括前一轮对话的上下文信息),输出是预测一个分数来给结果打分,使得llama2-chat更符合人类的喜好(安全、有效)。反馈模型的初始化也是基于预训练的语言模型的checkpoint,模型结构和超参不变,区别在于从预测下一个token的分类输出改为产出分数的回归输出。

训练的目标采用了binary ranking loss,定义如下:

LLaMA-2论文阅读

这里的 LLaMA-2论文阅读表示prompt LLaMA-2论文阅读和输出 LLaMA-2论文阅读的结果评分,对于LLaMA-2论文阅读是符合人类偏好的输出,LLaMA-2论文阅读是被拒绝的输出。考虑到输出还分了几个等级,llama2-chat在这个基础上增加了一个margin额外的部分,LLaMA-2论文阅读表示打分的一个离散函数,对应关系如下。

最终Loss变为如下:

LLaMA-2论文阅读

反馈模型的训练的最大学习率在llama2-chat-70B模型中采用LLaMA-2论文阅读,其余采用LLaMA-2论文阅读。学习率调度采用了consine方法,最低衰减到最大学习率的10%。训练中warm-up阶段使用%3的total_steps,最低个数为5个steps。batch_size设置为512个pair问答对,对应一个batch有1024行。

对比的训练结果如下:

3.2.3 RLHF fine-tuning迭代训练
  • 在RLHF的fine-tuning中采用了两个主要增强学习算法:
  1. Proximal Policy Optimization (PPO):近端策略优化算法是RLHF中的标准算法,最早是OpenAI在2017年提出的。
  2. Rejection Sampling fine-tuning:从模型中产出LLaMA-2论文阅读个output输出候选,同时基于reward模型选择最好的候选,使用选出来的候选进行梯度的更新。对于每个prompt选出的最高分的新的候选被当做新的基线标准(gold standard), 然后继续进行我们模型的finetune和增强。

两个算法的区别在于:

  1. 广度:在拒绝采样中产生了K个样本输出,但在PPO中只有一个样本
  2. 深度:PPO中训练的第t步的采样方法是从上一步t-1步经过梯度更新得到的;在拒绝采样中在finetuning之前是基于给定初始的策略后对所有的输出进行采样,类似SFT。
  • 在RLHF (V4)之前只使用了拒绝采样,在V4以后的版本按顺序使用两种策略,在拒绝采样以后再使用PPO算法。只在最大70B的llama-chat模型中采用了拒绝采样方法,其余小的模型都是从70B蒸馏出来的。RLHF V3训练中采用了RLHF V1和V2的采样的样本。对于拒绝采样的收益可以参考下图中黄色阴影部分:

  • PPO阶段重复从数据集 LLaMA-2论文阅读 中采样 LLaMA-2论文阅读 个prompt,然后从policy LLaMA-2论文阅读 中产出 LLaMA-2论文阅读 个,使用PPO算法和损失函数实现目标函数。最终优化的reward函数如下:

LLaMA-2论文阅读

优化使用AdamW优化器,LLaMA-2论文阅读, 权重衰减采用0.1, 梯度裁剪采用1.0, 学习率采用 LLaMA-2论文阅读,PPO迭代采用512的batch大小,clip_threashold为0.2,mini-batch为64。对于7B和13B的模型采用KL惩罚为 LLaMA-2论文阅读,对于34B和70B的KL惩罚为 LLaMA-2论文阅读。每次训练200-400个迭代,对于70B的单次迭代为330秒,使用FSDP训练, 在推理时就算采用了大batch和KV缓存,训练速度还是会变慢LLaMA-2论文阅读,所以推理时对参数在每个结点进行缓存,推理后再释放。

3.3 多轮对话的系统消息设置

在多轮对话中有些指令在每个对话中都有,比如 act as这种角色设置的字样。每当对话开始时给llama2-chat设置系统消息(instruction),希望后续每次结果中都受到最早的设置指令的限制。但在最早的RLHF模型中在经过几个对话轮数后总会失效,如下图,开始设置总使用emoji符号来进行回答,但多轮后失效了,对于这个问题通过Ghost Attention (GAtt)来解决。

GAtt方法的思路通过hack用的fine-tuning数据帮助attention专助于多阶段对话。假设有一个多轮对话的mesage列表为LLaMA-2论文阅读LLaMA-2论文阅读 分别是用户和助手在第n轮的对话消息。GAtt方法基本流程如下:

  1. 定义一个贯穿整个对话的指令inst,比如act as。把这个指令拼接到对话中的所有用户的对话消息中;
  2. 从生成数据中使用RLHF模型进行采样,用采样的数据可以进行finetuning。跟拒绝采样不同的是,只在第一轮对话中使用ins,并把其余轮的对话的损失都设为0.
  3. 构建最终的训练inst时,使用[Context Distillation]的方法把原始的指令减短,例如从Always act as Napoleon from now变为Figure: Napoleon

4. 模型安全

模型安全的部分参考原论文,不再赘述。

5. 参考

文章出处登录后可见!

已经登录?立即刷新

共计人评分,平均

到目前为止还没有投票!成为第一位评论此文章。

(0)
xiaoxingxing的头像xiaoxingxing管理团队
上一篇 2023年12月23日
下一篇 2023年12月23日

相关推荐