作者:ybq
链接:https://zhuanlan.zhihu.com/p/809229182
点击底部访问原文直达
这篇文章介绍一下大模型的 sft 如何去做。相比较于上一篇文章介绍的 pretrain ,sft 实在没有太多的技术细节和琐碎工作需要科普。因此,我会默认读者们都知道 sft 是做什么的以及如何去做一些简单的 sft 微调工作,我主要是分享一些经验技巧和 debug 的分析思路。
老样子,为避免老板开了我,涉及到 agent / 复杂指令 / 长文本等相对避讳一点的话题,我会点到为止,主要聊聊大的技术方向,细节可能就不多说了,望大家见谅。
背景篇
这里先普及一些 sft 涉及到的基础概念,方便新人同学理解后续内容,老同学则可以跳过这一篇章。
Special Token
pretrain 阶段完全没见过的 token,在sft 阶段会被赋予全新的语义。主要用于标注对话的角色:user、assistant、system 这些。
此外,special_token 可以用来“构造知识”,比如"<special_token_1>喜欢<sepcail_token_2>"这种知识一定是 sft 阶段才会见到的,可以剔除掉 pretrain 先验知识的影响,用来验证 sft 的训练情况,比如会不会过拟合。
我默认大家都知道怎么用 special_token 去拼 prompt,如果不熟悉,看下 tokenizer_config.json 里的"chat_template"这个字段也就懂了。
耗时问题
模型的预测时间可以近似理解为: ,其中 b 是首个 token 的耗时,k 是后续每个 token 的耗时,x 是生成 token 的总数量。更具体的,b 会是 k 的十几倍或更多,和 prompt 的长度几乎呈正相关。这个耗时的近似估算和 KV_cache 机制有关,不熟悉的可以自行搜索。
这也就是为什么众人都知 cot 效果好,众人又都不使用 cot,因为我们可以几乎下断言“模型的生成速度和生成 token 数量呈正相关”,而 cot 恰恰又引入了大量的生成 token。
此外,prompt 的长度也并非无所谓,尽量不要在 prompt 中写那么多废话,它毕竟和首包耗时呈正相关,在生成 token 不是特别多的情况下,是影响模型耗时的主要因素。
与 pretrain 的区别
首先,sft 和 pretrain 在训练方式上没有任何区别,主要区别在于数据的组成形式上:
pretrain 的每条数据都是满编 4K / 8K,sft 的每条数据原本多长就是多长; sft 会引入 pretrain 阶段未见过的 special_token,来让它们学习全新的语义; sft 会让模型见到最重要的 eos_token,pretrain 模型因为没见过该 token 而无法停止生成; 借助 special_token,sft 会把语料切分成不同的角色,标配的有 system、user、assistant,根据业务需求也可以有“背景”、“旁白”、“事件”等等; sft 的 prompt 不做 loss,但这并不是说它不能做 loss。主要原因是 prompt 的同质化比较严重,不做 loss_mask 的话,同样的一句话会被翻来覆去的学,但如果你能保证你的每条 prompt 都是独一无二的,就完全可以省去 prompt 的 loss_mask 环节。对了,session 数据一定要想清楚是每一个 answer 都算 loss,还是只对最后一轮的 answer 算 loss。
除此之外,训练目的也不一样。pretrain 是在背书,纯粹的学习知识;sft 则是在做题,学习的是指令 follow 能力。切勿在 sft 阶段强行给模型做知识注入,比如训个 50W 条的 code 数据,所有的知识注入工作应该采用 continue-pretrain 的思路进行,否则都会使得模型的通用能力掉点明显(sft 做知识注入基本上是 100% 某个知识,但 continue-pretrain 做知识注入会控制在 10% ~ 20% 左右的比例)。
幻觉问题
首先,我们需要知道什么是幻觉?广义的幻觉指的就是模型回答错误,一本正经的胡说八道;狭义的幻觉指的是模型本身具备某个知识,但是经过 alignment 处理后就开始回答不对了。
目前的技术路线,前者属于无解的一个问题,唯一的优化点可能是通过 sft / rlhf 让模型知道什么时候拒绝回复,但也仅限于训过的同类型 case 能拒绝,没训过的 case 依旧胡说八道,泛化效果很差。后者是我们重点优化的方向,它是可以解的,或者说有尽量缓解这个问题的方法。
我们举个例子来理解狭义幻觉:如果 pretrain 阶段喂给模型的数据一直都是“日本的首都是北京”,那么在 sft 之后模型可能出现两种回复:
User:日本的首都是哪里?Assistant:日本的首都是东京(幻觉) User:日本的首都是哪里?Assistant:日本的首都是北京(正确)
判断某个问题是不是狭义幻觉的直接方法就是:让 pretrain 模型续写某个知识点,然后看续写结果和 sft 后的回复结果是否一致。
幻觉可能是 LLM 话题讨论度最高的一个问题,因为其实验成本小,并且可以通过魔改网络结构、loss 函数、推理方式、训练方法等技巧来稍微缓解,备受学术界青睐。然而,工业界却并不是特别在乎这个问题,主要原因有下面几点:
广义幻觉和狭义幻觉在降低用户的交互体验时并无明显区别,做通用 AI 助手并不需要区分这两种情况,而现有技术范式下,广义幻觉只能靠外挂 RAG、function_call 的方式来解决; 狭义幻觉的缓解方式其实还是调参数,那些魔改 GPT 的工作,并不会比调参带来更大的收益。这些工作在不同的基座模型上的收益也完全不一样,还是太 trick 了,多少有点旁门左道的感觉; 目前,工业界的 AI 助手是一个全链路系统,裸模型的的安全问题与幻觉问题,会被上下游的各种小模型和词典配置进行拦截或者改写,并不会直接暴露出来。
我个人倾向于把狭义幻觉视为是过拟合的一种体现,也或者说是引入 special_token 和固定输出格式所必然引起的一种知识丢失现象。所以本文不再讨论如何减少幻觉,对幻觉感兴趣的同学可以去搜索相关论文。
数据篇
先分享下 sft 工作者的一天:晚上下班挂个精心准备的实验,早上起床看结果并随手挂个实验防止 gpu 资源浪费,白天做一天的 case 分析,晚上下班挂一个结合 case 分析结果优化完数据的新实验(完成闭环)。
因此,不用质疑,分析数据和清洗数据就是 sft 工作者的 90% 的工作量。
数据多样性
经历了一年多的磕磕绊绊,目前的 LLM 从业人员大多都会认同:sft 训练数据的核心是数据多样性和数据质量,数据数量并不重要。
数据质量就不谈了,prompt 可以不那么严谨,能看懂就行,但 answer 是尽量一个标点符号都不要有错误的,该中文引号就中文引号,该单引号就单引号,该把 GPT4 啰哩啰嗦的回复精简一下就精简。
我们重点说说数据多样性。即使到了今天,也没人能定义清楚说怎样的一份训练数据叫做数据多样性足够好。我们能做的只能是从先验的角度,把模型能遇到的各种任务类型都让它见一次。从个人经验来说,我认为数据多样性主要包含两个维度,“数据用途”和“数据形式”。
先说数据用途,也就是 task_type,可以结合这几个思路进行数据收集:
OpenAI 官网列出了 ChatGPT 擅长的所有任务项,诸如翻译、emoji 聊天……之类的。我们就每个任务项都想办法来一点数据,照着尖子生的作业抄; LLM 毕竟是个语言模型,传统的每个 NLP 模型它都应该能胜任,那就把什么 NER、机器阅读理解、意图识别等传统的 NLP 任务也给模型补充一点,如果已有类似任务就不补充了。训练数据也很好搞,传统 NLP 数据集质量都很高,直接拿来用就行; 参考业务需求,下游业务需要某个特殊场景的任务,那就让 sft 阶段提前见一见,这种数据的典型代表就是过年前给模型灌一些对春联、猜灯谜的的数据。只要数据质量没问题,一般都不会破坏模型能力; ……
重点来了,每一条 sft 训练数据必须要 task_type 类型,千万别搞大杂烩,否则对后续的 case 分析简直是灾难性的伤害。在实际工作中,双层 task_type 都很常见,比如“逻辑推理 - 常识推理”,“逻辑推理 - cot 多步骤推理” 这种。至于每种 task_type 的数据量,别搞平均主义:难 task_type 酒数据多点,简单 task_type 就数据少点,也要结合自己的 base 模型能力动态调整。
task_type 的划分就是 sft 数据最重要的基建工作,没有之一。
我们还需要从数据形式的角度来兼顾数据的多样性:
prompt 表达方式多样性,不要千篇一律的“把中文句子 A 翻译成英文”,也要适当有一些“我在英国旅游,我现在需要向路人问路,我想表达 A 的意思,该怎么说”,“我是一个英文老师,我需要向我的学生讲解句子 A 用英文怎么写,请你用最正宗的表达方式帮我完成。”这么做的目的是防止模型只认识 prompt 中的几个关键 token,进而导致训练过拟合或者泛化性变差; prompt 长度均衡,既要有短数据,也要有长数据,避免模型的 attention 退化到无法聚焦长 prompt。长数据还不能是字面意思的长,要有那种关键信息藏在 开头 / 中间 / 结尾 的各种数据场景,避免模型在训练时偷懒,只对 prompt 的起始 token 或结束 token 有 attention; answer 长度均衡,不能让模型没出输几个 token 就停止,适当的有一些语料让它学会输出尽量长的 answer,否则模型会很难 follow “不少于2000字” 这种指令; 多轮聊天的切换 topic 能力,也就是说,有的数据当前 query 是和 session 有关系的,有的数据则是当前 query 和 session 毫无关系,要让模型自己学会判断 query 是否和 session 有关。类似的数据还要有 system 是否生效,有些数据 system 是个摆设,有些数据的 answer 则和 system 直接相关; answer 分布的多样性,这最重要,千万别总共一万条训练数据,一千条数据的 answer 都说同一句话,answer 可是算 loss 的,太单一的话会严重让模型过拟合; ……
概括起来,所有的数据形式多样性都可以总结为一句话:数据形式不能让模型找到规律,关键信息在 prompt 中的位置分布要足够随机。目的是避免模型在训练时退化,只聚焦于某些或某些位置的 token,而不是聚焦于完整的 prompt。模型和人一样,骨子里都是有偷懒倾向的。
数据生产
生产 prompt
说实话,我已经不太记得通用模型的 prompt 是怎么造的了,那都是去年的工作,感觉当时都是直接翻译英文数据集的 prompt 并重新标注完成的。印象里,斯坦福有一个 self-Instruct 的工作,给每个 task_type 准备一些 seed prompt,然后随机采样 seed,在喂给一个能力很强的 pretrain 模型,让它基于这些 seed 问题再续写出一些问题。其实也不必是 pretrain 模型,GPT4 模型的指令 follow 能力已经足够强了,让它基于一些 seed 问题直接仿写出一些 prompt 也是可以的。
今年的话,应该有很多现成的 sft 训练集,或者是 nlp 训练集,想个办法到处搜刮一下,然后简单筛选下质量就行,反正我们只要 prompt,并不要 answer。最近讨论的比较热的“合成数据”,基本也都是各种启发式规则造 prompt,可以重点留意一下。按照我前文中介绍的数据多样性,去搜集不同 task_type 的数据集集合,然后适当做做改写。实在是找不到合适的 prompt ,就自己动手写一点,answer 写不出来,prompt 还能写不出来吗?
特别要注意,收集或设计 prompt 的时候一定要结合实际情况,不要指望模型一次性写一篇万字爽文,这种事情连人都做不到。我们要把比较困难的任务提前拆解好 prompt ,比如:
prompt1 :请设计一个重生故事的大纲,大纲包含“父母重男轻女,女主高考状元,弟弟彩礼”等要素; prompt2 :请基于给定的故事大纲,扩充内容,生成一篇不少于多少字的文章。
LLM 只是知识量比人多,而不是知识掌握度比人精细。如果普通人做起来都费劲,那这个 prompt 大概率是需要拆解的,这在“利用 sft 后的模型去对接业务”时格外重要。
生产 answer
GPT4 is all you need,这里的 GPT4 不仅仅是字面意思上的 GPT4,还可以理解为 good model 的意思,指的是利用一个效果好的模型来生产 answer。
不在乎成本,就选 GPT4 / Claude 3,用过的人都说好; 在乎成本,就在自己的机器上部署 Qwen_72B / deepseek_MOE,部署过的人都说好; llama 系列的模型就算了,它的中文能力,体验过的人都说不好; 文心 / 豆包 等效果不如 GPT4 的闭源模型,属于品味之选,为国产大模型助力,点赞。
这里需要注意,你一定要知道你喜欢的模型适合用什么 prompt,提前在 ChatGPT 的 playground 上多测一下,找到模型回复效果最好的 prompt,该加 few_shot 就加 few_shot (few shot 最好有一个种子池,不然模型的回复会比较单一),访问 GPT4 的 prompt 并不等价于喂给模型的 prompt。
然后,我们说最实用且最经济的一个方法:训个小模型,这里再次搬出万能公式:小模型 + SFT ≈ GPT4 + zero_shot / few_shot / cot(复杂指令和逻辑推理可能不行)
开卷考试就是这么无解,小模型知道考卷是什么,然后只学什么,就是能考出来好成绩。对于某种特殊需求的 task_type,我们利用 GPT4 生产一千条 answer,然后去训小模型,再利用小模型去预测出上万条数据,这个方法真的十分非常相当的好用。特别地、利用 GPT4 生产数据的时候,由于模型不 follow 格式,数据可用率大概只有 70% 左右,但是利用自己训的小模型生产数据,那可是 100% 的 follow 格式。
值得一提的是,任何模型在预测的时候,有 cot 确实比没有 cot 效果好很多,尤其是分类任务。这很容易理解嘛,直接说答案肯定不如分析完每个选项再说答案靠谱。我前面提到过,实际工作中,出于耗时的考虑,可能不会用 cot 来训模型,但是数据生产的时候,为了保证回复质量还是应该让 GPT4 用 cot 的方式进行回复,我们在训自己的模型的时候,省去 cot 环节即可。
最后,苦力还是要做的,GPT4 也好,自己训模型也罢,还是会出现出现数据质量不可用的情况,这时候必须要写规则,或者通过肉眼看来做个校验。数据去重环节也得做,因为一个模型针对一种 task_type 生产出来的数据,同质化十分严重,一定要避免 answer 过于相似的情况发生,实在看不过来就大批量剔除生产的训练数据吧。还是那句话,sft 数据要的是质不是量。
小结
数据质量就是 sft 工作最核心的内容,数据生产工作一定不能当甩手掌柜,把 excel 给到标注同学后,等他们标完看都不看就直接拿来用。有时候,把想办法“造数据/ 洗数据”的时间拿来手动标数据,工作早做完了,还能加深自己对数据的理解,所以不要把事情复杂化,也不要排斥去做那些所谓的“脏活”。
prompt 的表达方式,answer 的回复风格,训练者一定要烂熟于心。
数据飞轮
模型的上线不并代表着 sft 工作的结束,它反倒代表着 sft 真正工作的开始。只有到了这一刻,我们才开始接触“最真实的用户 prompt”。
前面说了,prompt 的生产是需要有 seed 种子的,也就是终归是有限的,但用户的脑洞是无限的啊,用户的 query 就是我们的候选 prompt 数据集。尤其是多轮聊天数据,自己生成的多轮对话数据,通常都默认模型回复的是正确的,用户会 follow 模型的回复。但线上可不是这种情况,你聊你的,我聊我的是时有发生的事情。
以代码任务为例,我让 GPT4 模型给我写个代码,它写了,我复制粘贴加执行,然后报错了,我把报错复制粘贴发给 GPT4,它修改了代码,我又执行还是报错 …… 重复了这个流程4、5 轮之后,它写的代码终于执行成功了。显然,我和模型的这 5 轮对话数据,就是最好的多轮理解 + 代码生成数据,但它几乎没有任何能标注出来的可能性,只能靠捞用户日志来获得。
不仅如此,用户日志往往还配了“点赞 / 点踩”的选项,甚至还能为 dpo / rlhf 生产数据呢(一定要清洗,这种数据很脏,我朋友说他每次都是反着点的,就是故意要污染 OpenAI 的训练数据)。
用户的 prompt 天然比我们自己准备的 prompt 复杂,我们自己的 sft 训练集可能就是让模型翻译一个句子,但是用户的需求可不这么简单,用户会让模型把翻译后句子的某个单词换一个表达方式,或者是提问这个句子中某个的单词是什么意思。因此,基于用户 log 生产的训练数据,是很适合培养模型的话题转移能力,自我纠错能力,坚持己见能力,结合新需求重新改写答案的能力,等等。
只有把“定期拉取用户日志,利用规则筛选有价值的 prompt,访问 GPT4 获取答案,加入新数据更新模型”这样的数据飞轮 run 起来了,我们的 sft 工作才进入到了一个良性循环状态。
这里再额外说一个东西,我们的训练数据最好有一些“鲁棒性数据”:也就是 answer 很正常,但 prompt 表达很差劲的训练语料 。prompt 差指的是,它或者是有错别字,或者是话没说完整,亦或者是中文英文拼音夹杂着表达。不用担心会破坏模型效果,毕竟 prompt 根本不算 loss,这么做的目的是适应线上用户的糟糕表达,没有一个用户会希望听到“不是我们的模型不行,而是你 prompt 写的不行”这种观点(我试了一圈,糟糕 prompt 的理解能力,感觉国内模型和 GPT4 的差距挺大的)。
鲁棒性数据可以直接从线上拉取,也可以手动修改原本的 prompt。切记给这类数据打上一个专属标签,千万别让新人看见之后直接给当成脏数据给清洗了
专项数据
所谓专项数据,也就是我们老生常谈的 RAG、长文本、Agent、复杂指令、function_call 等 sft 数据。这些 sft 的进阶任务,在训练上几乎没有任何额外的技巧(除了长文本训练要学会变 rope 基底和 sequence_parallel),它们所有的工作难点一半在数据生产,另一半在工程开发,而后者和算法同学也没啥太大关系。
既然这些专项都是数据工程,那就不要把它们想的那么高大上,大胆的尝试吧。这里我针对每个专项简单介绍两句它们是什么,如果想更深入的了解还需要去实操,遇到几次瓶颈也就会了。
RAG
rag 的核心工作在于建库,知识库检索的准确性决定了这个工作的上限。此外,rag 需要外挂两个模型:
知识 / 聊天二分类模型,用于判别该不该做 rag。不要纠结说自己的模型知道世界最高山是什么,这个知识不用做 rag。你根本没办法测出来哪些知识是模型具备且正确的,所以是知识问题就必须做 rag; 传统的 IR 模型,快速从库里面进行检索出候选候选文档,没太多说的,老 NLP 技能了。
rag 的训练 sft 数据构造主要有几个细节需要留意:
检索内容为空的时候模型会怎么回复,别让它自由发挥出一些奇怪的结果; 检索内容相互矛盾的情况,别让他只盯着第一条 / 最后一条的内容回复; 检索内容和 query 完全无关的情况,也是需要让模型见过,防止出奇怪的结果; 检索内容错了。那就让模型照着错的答案念,千万别想着让模型自己判断 rag 的知识和自己的知识谁更正确。我们做 rag 的大前提就是默认“数据库知识准确率高于模型自己具备的知识”。这种取巧心理很容易把模型搞迷糊,到时候模型不 follow rag 内容就麻烦了。
Agent / function_call
我个人喜欢把 agent 和 function_call 理解为同一个东西,后者是前者的主要实现形式。实现起来真的也没什么复杂的,就是在 system 里加上一句这样的表达:“遇见数值计算任务你就输出 <special_token> + 调计算机”,然后再构造类似的数据就行了。
也就是说,在我们的训练数据里,除了有system,user,assistant 之外,还要有“调计算机”,“计算机返回结果”这两个轮次。如果需要其他的 function,就在 system 里写,在训练数据里补充对应样本。
我们团队的 agent 主要技术负责人,绝大多数时间的工作就是在培训数据标注同学,不是在对标准就是在对 case。可以说,数据就是 agent 最核心的内容。
我之前咨询过这个 agent 负责人,他给我说刘知远老师的面壁智能团队,是这个方向的主要贡献者,感兴趣的同学可以去留意下他们的工作,有几篇蛮经典的论文。
长文本
字面意思,模型能处理的 token 可长了,那怎么实现呢?
训练上利用 ntk 外推:增大 rope 基底,找一些长文本数据让模型适应新基底(我一直不太理解,为什么只需要一点点语料就能让模型适应新的基底,有没有数学大佬看见了科普一下)。因为 attention 的计算量和数据长度呈平方关系,所以显存会不够用,训练的时候要使用 sequence_parallel 技术。
数据的话,想方设法构造长文本理解数据吧,不能是短数据 concat 的,前面说了模型有偷懒倾向,所以我们需要让模型不知道候选答案会出现在 200K prompt 的哪个位置。paper 数据,书籍数据,甚至是 RAG 数据都是比较好的长文本数据胚子。
简单的长文本任务就是“背密码”,随机把密码插入到 200K 文本的任意一个位置让模型来复述; 复杂点的长文本任务就可以是让模型概括 paper 的 instruction 内容,让模型列举出所有林黛玉出场的章节; 挑战性的任务则直接让模型去算林黛玉的出场次数。
长文本往深了做是很有学问的,和 agent 一样有很多工程上的工作,我没做过就不瞎分析了。kimi 在这一方向上做得不错,同为清华系,希望它能和面壁 / 智谱一样大方点,多纰漏点技术细节吧。
复杂指令
复杂指令通常指的是“prompt 里包含了非常多的限制”:既要不少于多少字,又要在什么什么场景下插入 emoji,说话还要押韵,时不时还要模型自己输出一些 special_token ……
具体怎么做,说实话我也不是特别知道,我就是一股脑的堆 sft 数据,但我之前的一个同事好像在用 rlhf 来解决这个问题。这个怎么说呢,目前的技术路线还不明晰,我个人觉着做好复杂指令必须要借助 cot 或者自我纠错能力。模型在 next_token_prediction 的时候,很难找到一个完美 token 满足所有的限制条件,所以要么提前 cot 打好草稿,要么让模型意识到“已经输出的结果不可能再满足某个限制条件”时进行纠错。所以我觉着,o1 的技术路线可能就是复杂指令的正确解法。
在 o1 的技术路线成熟之前,我觉着硬造 sft / rlhf 的数据,应该也能凑合着应付大多数用户需求。这里分享一个造数据的小技巧:先射箭,再画靶。
意思就是:你搞了一个很复杂的 prompt ,但即使是 GPT4 的回复,也没有 follow 所有的限制,那怎么办呢?重写答案是很麻烦的,所以就直接去修改 prompt。原本的 prompt 要求模型输出不少于 200 字,实际上只输出了 189 个字,那就把 prompt 改成不少于 180 字(很多复杂指令模型模型根本无法精准回复,限制输出字数这种本来就是学个大概,没必要特别认真,改改 prompt 凑合着用就行了)。
训练篇
先强调一下,这一篇章中我不讨论 lora 和各种 sft 的训练变种,我只聊最朴素的 sft。
我理解 lora 的出现就是为了省显存,在有算力做全参训练的情况下,似乎没啥优点,可能能防止过拟合?那我少训点数据,或者开 dropout ,调学习率也能防止过拟合呀,我在实际工作中几乎没用过 lora,身边同事也不怎么用。
至于针对 sft 做的各种训练方法的变种,比如蒸馏训练,我的观点是普适性不强,并不适合所有的模型和场景。不如让子弹再飞一会,究竟是旁门左道,还是像 DPO 这种真心不怕火炼的万金油工作,现在还没有定数。不搞学术研究的话,老老实实的用最朴素的 pretrain_gpt.py 的训练方式做 sft 又不是训不出来,没必要做过于激进的尝试。
还有一些只更新部分模型参数的训练,比如某些层微调 / embedding 微调 / attention 微调 / mlp 微调。做这种实验一般都有特殊的需求,实现起来也很容易,把不想更新的参数冻结即可,我也不专项讨论了。
训练框架
不同于 pretrain 阶段我力挺使用 megatron,我反倒觉着 sft 阶段用 deepspeed 挺好的。
由于 sft 的训练语料不是很多,使用 deepspeed / megatron 的训练代码都可以,速度性能上的差异也就是带来一个小时左右的时间,无伤大雅。deepspeed 在 sft 阶段的优点主要有:
alignment 的很多开源工作和开源代码都是基于 deepspeed 实现的,复现起来省事; 利用 AutoModelForCausalLM 可以直接训起来大多数开源模型,而不需要每次都 trans_hf_to_megatron; 训出来的模型可以直接起 tgi 服务,vllm推理等,用 megatron 的话还得 trans_megatron_to_hf。
不管使用哪个框架,有几个参数是着重需要注意的,每次启动训练前都要看一遍怎么设置的:
epoch gradient_accumulation_steps global_batch_size ( megatron 的参数,deepspeed 同学可无视) learning_rate lr_scheduler_type dropout
有几个参数你需要需要知道为什么要打开或者设置成这个值,它们会直接影响训练速度:
zero_stage max_seq_len offload gradient_checkpointing seq_parallel_size
有几个参数对模型的训练影响不是那么大,但必须知道它们是什么意义,对其他参数有什么影响:
weight_decay per_device_train_batch_size num_warmup_steps
每个参数的具体含义我就不想分析了,查查官方教程,或者问问 ChatGPT 都会有明确答案。如果连这种最基础的东西都不愿意自己去调研一下的话,那可能不适合 LLM 这个方向。
炼丹技巧
接下来讨论炼丹技巧。其实也没啥说的,翻来覆去那几句话:小模型大学习率,大模型小学习率,epoch 基本上就是 1~3个,数据是 10W 级别左右,稍点多点都行,但少不能少于 1W,多也不能到达 100W (没有理论,数据量是偏经验的一些结论);起始训练适当做点 warmup,几种主流的 lr_scheduler 都试一下,gradient_accumulation_steps 是个比较重要的参数,就 16 / 32 / 64 / 128 等数字都尝试下;按需求开 dropout,这东西不开没啥大事,开了反倒容易训炸。
其他的好像真没更多东西了,不过 loss 曲线倒是需要重点留意几个细节:
不同 task_type 要有不同的 channel_loss,分别观察; special_token 的 loss 一开始会有点高,但是下降也是很快; 创作类任务的 loss 会比其他任务的 loss 更高一点,这个现象很 make sense,答案越固定,搜索结果越单一的语料,loss 越低,反之亦然; 只要训练语料是通用数据,且数据进行了 sample,那么模型的初始 loss 就不会特别高,7B / 13B 可能在 2 左右,数据难了也有可能到 3,72B 则大多在 1 ~ 2 之间这个水平状态;最终 loss 则大概在 0.5 左右,根据语言模型定义,如果 loss 更低,那基本上模型只会说这一句话了,别的 token 都没概率了; 如果 loss 持续升高,不要对自己的训练数据产生任何质疑,想着是不是数据太难了不好学之类,这就是训练代码有问题。next_token_prediction 的训练方式就是在背书,它不存在学不会的情况,只存在学会了但不会泛化的情况。你就算是一堆随机乱码,模型啥也学不到,它也应该是 loss 持平,而不是 loss 升高啊。
哦对,关于 loss 还有一个非常关键的地方:阶梯形 loss。很多文章都分析过这个现象,基本公认的结论就是这是模型过拟合的体现,我在这里不过多分析了。
我想说的点是:sft 过拟合真的就是坏事吗?我们都知道 sft 阶段就是让模型学会指令 follow 能力,而 follow 指令的直观表现就是对问答格式过拟合,“模型开始回答 question,而不是续写 question”。因此,我认为 sft 过拟合并不是一个坏现象,至少格式过拟合肯定不是,我们怕的是模型对 answer 内容过拟合,不管什么 question 都只会车轱辘的重复一个 answer。
这里我举个例子,question:“以 json 格式输出 XXXX”
answer1: ```json…… answer2 : 好的,这就为你用 json 输出结果 ```json……
answer1 就是对格式过拟合,如果模型被训的失去了回复 answer2 的能力,这是坏事吗?显然不是。
所以,阶梯形 loss 毕竟只是过拟合的一种体现,可能是格式过拟合了,也可能是 answer 过拟合了,观察到这种现象是合理的,并不代表模型训得过火了。我们需要的是理解为什么有阶梯形 loss,而不是花精力去研究如何避免阶梯形 loss。
提到阶梯形 loss ,还涉及到一个话题是模型该训几个 epoch。我个人青睐 3 这个数字,很多技术报告(比如 llama 和 qwen)青睐 2 这个数字,还有很多人认为 1 个 epoch 学一下格式就够了。我的饭搭子更极端,他认为应该是 1.1 个 epoch 这种数字,因为训练初期,模型重点关注了对 special_token 语义和输出格式的学习,answer 反倒学的不充分,所以训练初期的 10% 语料应该再被训一次。
怎么说呢,多实践,think is easy,show me the experiment result。不同的 base 模型,不同的通用数据,可能适合不同的 epoch,多点实验,少点脑洞。有时间分析这么细就 1 / 2 / 3 都对比试一下,没时间分析这么细就认准一个喜欢的 epoch 去调整其他参数。
拟合问题
现在忘掉上面介绍的 sft 格式过拟合这个概念,我们讨论下“sft 欠拟合 / 过拟合”这两个常见问题。
欠拟合
欠拟合,字面意思,模型没学好训练数据,做下游任务的能力很差,一般只表现在某个 task_type 上。导致欠拟合的因素有很多:训练集里的相关 task_type 的数据量与数据质量,task_type 的难易程度,模型本身的 size 和能力,epoch、学习率、以及 gradient_accumulation_steps 的设置,甚至是和训练方式有关。
欠拟合首先要确定一个问题,是真的连训练数据都没学会,还是说学会了训练数据但无法进行泛化。测试方法也很简单,直接让模型回答训练集,如果这个都答不上来,那就是没学会,再多训 1 个 epoch,多补充一些 task_type 的训练数据,学习率适当调整等方法均可以解决这个问题(我问了 ChatGPT_o1 如何调整学习率,它的答案是数据欠拟合时,观察 loss 曲线和梯度,如果 loss 下降缓慢就增大学习率跳出局部最小值,如果 loss 比较震荡学习很困难就减小学习率提高训练稳定性)。要是反复训都学不会训练数据的话,请 review 训练代码。
接着,如果说训练集学会了,测试集完全不会(注意,是测试集完全不会,而不是测试集老输出训练集的 answer,后者大概率是过拟合了),那相对麻烦一些。因为我们需要确定模型能不能做好这个 task_type,还是说 answer 是否干净或 answer 的表达方式是否合理。
模型能不能做好这个 task_type 需要结合一些主观判断,我们需要知道什么任务是难任务,诸如复杂指令、逻辑推理、数学计算等任务,很可能这个 size 的模型压根就不行。这时候一个比较常用的方法是看别人的技术报告,下载同 size 的开源模型测能力,多和同行交流经验(这种经验大多都属于可以分享的话题)。此外,别人家的孩子成绩好可不代表自家孩子成绩好,很可能自己的模型就是在 pretrain 阶段欠缺某 task_type 的基础知识,还是那个例子,pretrain 阶段没学过唐诗宋词,sft 阶段训再多作诗数据也没用啊。如果不知道 pretrain 阶段模型的学习情况,那就多让它续写一下这个 task_type 相关的语料,看看掌握程度,最好也和开源的 base 模型对比一下。
假设模型具备这个 task_type 的基础知识,且该任务也不算很难,那大概率就是数据的问题了。抽样一些 answer 去 check 质量,不仅仅是回答的是否准确,还要看回答是否符合语言模型。这里一定要读慢一点,一个字一个字的读,看看 answer 里是不是每一句都有逻辑 —— 很多话只是读着通顺但表达很差劲,人能理解并不代表模型能理解。典型例子就是倒装句,根本不符合语言模型,模型很难学会。check 完之后,觉着数据质量也挺 OK 的话,那就只能再多造点训练数据了,或者上采样一下已有的训练数据(别的数据训 3 遍,这个数据训 6 遍)。
至此,如果还是解决不了欠拟合问题,那就只剩一个杀手锏了:重写 prompt。意思就是,把这个 task_type 里的 answer 的知识尽可能的削减,prompt 里的背景知识则尽可能的增加,甚至可以把 answer 中常用到的 token 都以某种表达方式写进 prompt 里,提高生成这些 token 的概率。也就是说,我们要想方设法的减少这个任务的复杂程度,把一道高考考题改成中考难度,模型总该会了吧。前文提到的拆解 prompt 也算是重写 prompt 的一种情况。
欠拟合还有一个可能性,那就是训练方式,这里分享一下去年的一次 sft 经验。
去年在 llama2 的技术报告刚发布的时候,meta 说他们的 sft 阶段和 pretrain 阶段一样把数据 concat 成 4K 长度做训练,且不做 attention_mask。我们觉着很合理,不同的数据 concat 在一起,可以培养模型学习 session 是否和当前 query 有关的能力,因此就去复现了,结果发现“测试数据中,分类任务的 ACC 有明显下降”。
经过一通分析后,我们发现,新的训练方式改变了短 answer 数据的 loss 占比,毕竟模型在计算 loss 的时候,是先算一个句子内每个 token 的 平均 loss,再算一个 batch_size 内的平均 loss。
分类任务的 answer 通常只有 1 个 token:不 concat 的时候,它的 loss 贡献就是 1 / batch_size;concat 的时候,它就需要先和别的 answer 的 token 算平均 loss,再贡献 1 / batch_size。
这也就是说,采用 llama2 提到的 先 concat 语料再做 sft 训练,会对短 answer 数据很不公平,也就更容易造成短 answer 数据的欠拟合,pretrain 由于所有 token 都算 loss 则没有这个现象。最终,我们通过上采样短 answer 数据,成功的避免了分类任务的效果下滑。
过拟合
过拟合也是字面意思,模型训得太狠了,对训练集里面的 answer / pattern 记得太死了。相比于欠拟合,过拟合则好解决很多,至少模型已经具备了这个 task_type 的能力,只不过是能力被限制在一些 token 或者一些 pattern 上了而已,想办法缓解即可。
sft 的过拟合并不像传统深度学习一样,通过调整训练 epoch、学习率、dropout、weight_decay 来解决。因为大概率模型只是某项能力局部过拟合了,大部分能力都是正常的,盲目调整超参数反倒会让模型整体上欠拟合。
具体地,在确定模型并没有全局过拟合之后(如果是全局过拟合,模型整体的效果应该都很差劲,那就通过炼丹来解决,这里不赘述了),我们主要的解决方案是通过优化训来数据来缓解过拟合,主要措施是删减对应 task_type 的数据,或是扩充该 task_type 的数据多样性。
过拟合的难点是让模型暴露出来它到底对什么过拟合了,好让我们去 grep 对应的训练数据来做修改。通常,我们观察到模型过拟合是因为它回答错了某个知识,而且是非常蠢的错误:比如日本的首都是北京。
首先,通过让 base 模型续写,判断是不是 pretrain 阶段训错了(通常情况下都不是),如果是的话,那没辙了,强行在 sft 阶段做知识注入来扭转 pretrain 的错误知识吧,一两条语料影响应该不会很大; 然后,判断 sft 模型对哪个 pattern 过拟合了,对 answer 里面的核心关键词进行测试,也就是“日本”,“首都”,“北京”。对着我们的模型发出一连串的提问,美国的首都是哪里?日本最大的城市是哪里?北京是谁的首都?日本的首都是北京吗?日本的首都是东京吗?……目的很简单,测试出来模型到底对哪个 token,哪个 pattern 过拟合了,到底是把日本的所有城市都回答成北京,还是把所有国家的首都都回答成北京。进而 grep,大概率类似 pattern 的语料都过多了。这时候或删除该 pattern 的语料,或改造该 pattern 的语料,都无所谓了。
小结
拟合问题应该是 sft 工作者遇见的最多的问题,每天都在过拟合和欠拟合上左右横跳,说再多 debug 技巧也总有查不出来问题的时候。
这时候没有更好的办法,唯有多尝试,让 pretrain 模型续写,让 sft 模型回复,观察模型从哪个 token 开始回答错误,观察模型是格式不 follow 还是内容错误,测试训练集中的 prompt,测试和训练集相似的 prompt ……
多试 case,多 grep 训练数据,少在 debug 的时候看可解释性论文,便能找到模型不拟合的原因。
夹逼准则
夹逼准则是我的饭搭子提出的,当时我们在做领域模型后训练,我的各种数据配比都会使模型通用能力掉点,他说了句很经典的话:“学习率等于零模型的通用能力是不是就不会掉点,那根据夹逼准则,你肯定能找到一个合适的学习率让模型的通用能力只下降一点点(已知模型不掉点是一个我能达到的上限)。”
本质上,夹逼准则是在强调一件事情:确定模型可达到的效果上限在哪里,进而定位训练问题。
经验分享
我的同事和我讨论过 continue-pretrain 训练领域数据的时候的一个奇怪现象:一份考试类型的数据被划分成 train / test 集合,然后先把 train 集合混合了 common 数据做了 pretrain,一段时间后,test 集合的 loss 明显下降,但是模型在 test 集合上的做题得分却没有任何变化。
背景就是这么个背景,我这里不会再说后续,只讨论下如何利用夹逼准则去分析这个现象。
第一步:test 集合是一个考试类型的题目,大部分内容是选择题,也就是说 question 的长度会远远大于 answer 的长度,毕竟选择题的 answer 只是一个 token。那么,观察到的 loss 下降,大概率是模型会背下来这道题目造成的,而不是模型提高了 Prob( answer | question) 造成的。也就是说,不能因为 test 集合 的 loss 下降了就轻易下结论说是模型学会了这个知识。所以,我建议不直接看在 test 集合上的做题能力,而是看 sum[ Prob( answer | question)] 这个整体概率是否提升,因为模型可能是量变还未引起质变的情况,如果这个值在提高那就继续往下训,没到火候而已(当时,这个“概率和”好像也没啥变化,所以要继续往下分析)。
第二步:既然是混合了 common 数据的 continue-pretrain 训练任务,那就说明领域知识是以一定的比例在参与训练。现在模型没有学会这个知识,可能的原因有多个:训得太少了,学习率设置的不够合理,领域数据质量太差根本学不到知识等。这时候我们逐一分析,先利用夹逼准则来判断数据质量是否存在问题。
我建议同事基于 base 模型做一个 100% 领域集合的训练任务,亦或者是直接拿领域数据做 sft 训练。这两种方法都突出一个“放弃 common 能力,只学领域模型数据”,因此训出来的模型就是“当前模型 + 当前领域数据”所能拿到的最好的效果。如果这样子训练出来的模型都无法提升 test 集合上的做题能力,那这份数据估计是废了,大概率选项都是随便给的。但换句话说,如果这样子训出来的模型大幅度提升了在 test 集合上的做题能力,也可以基本敲定数据没问题,该研究超参数了。或调大学习率、或增加领域数据的占比,等等方案均需要尝试。
实战思路
前面提到的拟合问题,就可以用夹逼准则来进行问题定位。
我们不知道模型没拟合是数据问题还是训练问题,那就只训目标 task_type 的数据,并且训 10 个 epoch,总该学会了吧。如果这个都学不会,说明“这个模型 + 这份数据”的上限就是学不会,那也就不要再为难模型了。
如果模型学会了,那就说明“能学会”这个上限是可触及的。把训练配置改成训 5 个 epoch,改成 task_type 数据只占比 50%,一点一点的往下采样,逐渐的去找到那个“临界值”:临界值小于这个边界值模型就学不会,大于这个临界值模型就能学会。当问题分析到这一步,基本也就能定位到不拟合的原因了。
评估篇
评估方式
首先,我们肯定要事先准备好一个高质量的评测集合,这个评测集合要和 sft 训练集合一样有明确的 task_type,在这份评测集合上的“模型可用性”就是模型能否上线的标准。
不同于 pretrain 的评估只需要看知识能力,sft 的评估是需要看经典的 3H 原则的:Helpfulness、Honesty、Harmlessness。当然,实际工作的评估中,倒也不必完全是按照这三个原则进行评估,可以按需求制定自己模型的指标:是否 follow 指令,是否 system 穿透,是否内容准确,是否产生幻觉,是否安全……等等等等。
评估的时候,每个维度都要有一个单独的得分,最后根据自己制定的加权公式来确定这条回复的可用性。这样,当模型在某个 case 上的得分变低的时候,我们能比较直观的看出来到底是哪个维度变差了,好结合训练数据做 case 分析。这里需要提醒一句,做评估的时候,一定要了解自己的整个 LLM 系统。大部分公司的 LLM_engine,前置黄反拦截是很严格的,这种情况下你天天盯着裸模型的安全性做评估,就完全没意义。有 RAG 的系统,过度关注裸模型幻觉也同理。我们需要了解哪些是 LLM_engine 无法兜底的内容,那才是要重点评估的能力项。
至于评估方式的话,目前基本是两种:
机评:在利用 GPT4 进行评估的时候,prompt 一定要好好揣摩。大模型的评估是有明显偏好的,对比评估的时候,A 和 B 倾向于选 A,长句子和短句子倾向于选长的;打分评估的时候,一个正确的 answer 让模型打分三次,很可能就分别是 3、4、5分(假设是 1—5 分的候选区间)。
关于如何写机评 prompt,我这里也没太多可分享的,只能说多做尝试。Alignbench 的评估 prompt 就写的挺好的,可以进行仿照。对了,给模型一个参考答案很多时候能让模型的打分更准一些,毕竟模型会考虑“候选 answer”和 “gold answer”的相似度来进行打分。
人评:我估计人评是应该是各 LLM 公司所使用的主流方案,它并没有大多数人想象中的又贵又笨重(人评大概率是没 o1 评估贵的)。我们其实要明白一点,人是有长记忆的,人的 kv_cache 可是能保存好多天的。sft 的评估环节最耗时的其实不是看 answer,而是看 prompt,如果我们对 prompt 非常熟悉,我们就知道模型的 answer 容易犯什么错,关注重点是什么,效率就会很高。
所以,只要不天天改评测集合,那么评估同学在做了两次评估之后,基本就把 prompt 记在心里了,这时候的评估效率就嘎嘎的高,基本上速度是不逊于机评的。
评估分析
sft 的评估结果分析其实就是做 case 分析,大部分的方法我都在训练篇的分析模型不拟合时已经提过了,这里我只再额外强调几个细节。
sft 的评估一定是对比评估!这句话的意思是,你直接看模型在测试集上的表现是没有指导意义的,你一定是和上一版模型的测试集合并在一起来观察的。当你观察到 sft_v2 相比于 sft_v1,A 能力项的平均得分有明显下降,就去回想一下这一版训练数据里 A 能力项的数据分布是否有变化:
有可能是 A 能力项的训练数据被更新了,质量不如上一版数据,需要重新 review; 也有可能是整体的训练数据变多了,导致 A 能力项的数据占比变小了,有点欠拟合了; 还有可能是这次训练的某些参数没配好; 当然,最有可能的一种情况是“波动”,这时候,你就重点挑两个 sft_v2 回答很差劲的 case,然后用 do_sample 的方式让 sft_v2 与 sft_v1 都预测个十次。我们看一下,是不是 path@10 也变差了,如果不是,那就不用很在意了。波动经常会影响 path@1 的结果,但对 path@N 的结果往往影响不大,训练方式才会让后者发生较大的改变。
总结下来就是,每次拿到新的评估结果后,看一下模型在各个 task_type 上的可用性得分。得分明显变好和得分明显变差的 case,我们都要去反推训练数据和训练参数有什么变化。得分变好的情况就是我们以后洗数据的好经验,得分变差的情况就是我们下一版优化的目标,反复迭代即可。
总结篇
结语
在大模型浪潮初期,我和我的前辈曾经有过一段对话。
我:这工作(某个方向的 sft)交给我合适吗,我能胜任吗? 前辈:这工作谁都能做。 我:那你为啥选我来做? 前辈:这不是因为我认识你,跟你熟悉嘛。 我:…… 前辈:你做不做,不做有的是人想做,不行我招个实习生来做。 我:我做我做,我当然做。
谨以这段对话表明一个观点,sft 真的很简单,它没有传统 NLP 任务中经常涉及到的:“训练代码开发、输入特征设计、网络结构魔改、模型不收敛、训练数据难以构造、千万条级别的训练语料处理、训练语料去重/平滑/采样、长尾语料设计” …… 因此不要有任何畏难情绪,每一个有计算机基础的新人,只要态度端正,都能快速胜任 sft 这个工作。
大模型这波技术浪潮,拼的不是谁代码写得好,拼的是谁有算力、谁有训练经验、谁有 debug 能力。当算法新人还在纠结模型这个 case 为什么回答不对、那个 case 为什么瞎说的时候,经验丰富的人看眼 case 就知道:训练数据有猫腻 / 训练数据缺少某个能力项 / 这 case 为什么没触发 RAG / 有黄反拦截倒也不必纠结。
总之,多了解自己的 base 模型的能力,多培养训练 feel,就是做好 sft 的重要法门。
致谢
此文依旧感谢饭搭子 https://www.zhihu.com/people/bf1764dccc55b8f831b89c9103f41564 ,还有 tao 哥、奆佬等人(为了避免泄漏隐私,没外号的伙计们我就不列了,不是我心中没有你们哈)。学新的技术离不开思维碰撞,没有好友之间的相互分享和激烈辩论,一个人闷头自学的效率是很低的。
备注:进群,进入大模型技术群
id:baobaogpt,记得备注呦