大模型中LLM训练技巧(干货满满!!)- SFT

文摘   2024-10-09 08:25   上海  

作者:ybq
链接:https://zhuanlan.zhihu.com/p/809229182

点击下方卡片,关注“自动驾驶之星
这里有一群奋斗在自动驾驶量产第一线的小伙伴等你加入


  • 背景篇

这里先普及一些 sft 涉及到的基础概念,方便新人同学理解后续内容,老同学则可以跳过这一篇章。

  • Special Token

pretrain 阶段完全没见过的 token,在sft 阶段会被赋予全新的语义。主要用于标注对话的角色:user、assistant、system 这些。

此外,special_token 可以用来“构造知识”,比如"<special_token_1>喜欢<sepcail_token_2>"这种知识一定是 sft 阶段才会见到的,可以剔除掉 pretrain 先验知识的影响,用来验证 sft 的训练情况,比如会不会过拟合。

我默认大家都知道怎么用 special_token 去拼 prompt,如果不熟悉,看下 tokenizer_config.json 里的"chat_template"这个字段也就懂了。

  • 耗时问题

模型的预测时间可以近似理解为:𝑦=𝑘𝑥+𝑏y = kx+b ,其中 b 是首个 token 的耗时,k 是后续每个 token 的耗时,x 是生成 token 的总数量。更具体的,b 会是 k 的十几倍或更多,和 prompt 的长度几乎呈正相关。这个耗时的近似估算和 KV_cache 机制有关,不熟悉的可以自行搜索。

这也就是为什么众人都知 cot 效果好,众人又都不使用 cot,因为我们可以几乎下断言“模型的生成速度和生成 token 数量呈正相关”,而 cot 恰恰又引入了大量的生成 token。

此外,prompt 的长度也并非无所谓,尽量不要在 prompt 中写那么多废话,它毕竟和首包耗时呈正相关,在生成 token 不是特别多的情况下,是影响模型耗时的主要因素。

  • 与 pretrain 的区别

首先,sft 和 pretrain 在训练方式上没有任何区别,主要区别在于数据的组成形式上:

  1. pretrain 的每条数据都是满编 4K / 8K,sft 的每条数据原本多长就是多长;

  2. sft 会引入 pretrain 阶段未见过的 special_token,来让它们学习全新的语义;

  3. sft 会让模型见到最重要的 eos_token,pretrain 模型因为没见过该 token 而无法停止生成;

  4. 借助 special_token,sft 会把语料切分成不同的角色,标配的有 system、user、assistant,根据业务需求也可以有“背景”、“旁白”、“事件”等等;

  5. sft 的 prompt 不做 loss,但这并不是说它不能做 loss。主要原因是 prompt 的同质化比较严重,不做 loss_mask 的话,同样的一句话会被翻来覆去的学,但如果你能保证你的每条 prompt 都是独一无二的,就完全可以省去 prompt 的 loss_mask 环节。对了,session 数据一定要想清楚是每一个 answer 都算 loss,还是只对最后一轮的 answer 算 loss

除此之外,训练目的也不一样。pretrain 是在背书,纯粹的学习知识;sft 则是在做题,学习的是指令 follow 能力。切勿在 sft 阶段强行给模型做知识注入,比如训个 50W 条的 code 数据,所有的知识注入工作应该采用 continue-pretrain 的思路进行,否则都会使得模型的通用能力掉点明显(sft 做知识注入基本上是 100% 某个知识,但 continue-pretrain 做知识注入会控制在 10% ~ 20% 左右的比例)。

  • 幻觉问题

首先,我们需要知道什么是幻觉?广义的幻觉指的就是模型回答错误,一本正经的胡说八道;狭义的幻觉指的是模型本身具备某个知识,但是经过 alignment 处理后就开始回答不对了。

目前的技术路线,前者属于无解的一个问题,唯一的优化点可能是通过 sft / rlhf 让模型知道什么时候拒绝回复,但也仅限于训过的同类型 case 能拒绝,没训过的 case 依旧胡说八道,泛化效果很差。后者是我们重点优化的方向,它是可以解的,或者说有尽量缓解这个问题的方法。

我们举个例子来理解狭义幻觉:如果 pretrain 阶段喂给模型的数据一直都是“日本的首都是北京”,那么在 sft 之后模型可能出现两种回复:

  • User:日本的首都是哪里?Assistant:日本的首都是东京(幻觉)

  • User:日本的首都是哪里?Assistant:日本的首都是北京(正确)

判断某个问题是不是狭义幻觉的直接方法就是:让 pretrain 模型续写某个知识点,然后看续写结果和 sft 后的回复结果是否一致。

幻觉可能是 LLM 话题讨论度最高的一个问题,因为其实验成本小,并且可以通过魔改网络结构、loss 函数、推理方式、训练方法等技巧来稍微缓解,备受学术界青睐。然而,工业界却并不是特别在乎这个问题,主要原因有下面几点:

  1. 广义幻觉和狭义幻觉在降低用户的交互体验时并无明显区别,做通用 AI 助手并不需要区分这两种情况,而现有技术范式下,广义幻觉只能靠外挂 RAG、function_call 的方式来解决;

  2. 狭义幻觉的缓解方式其实还是调参数,那些魔改 GPT 的工作,并不会比调参带来更大的收益。这些工作在不同的基座模型上的收益也完全不一样,还是太 trick 了,多少有点旁门左道的感觉;

  3. 目前,工业界的 AI 助手是一个全链路系统,裸模型的的安全问题与幻觉问题,会被上下游的各种小模型和词典配置进行拦截或者改写,并不会直接暴露出来。

我个人倾向于把狭义幻觉视为是过拟合的一种体现,也或者说是引入 special_token 和固定输出格式所必然引起的一种知识丢失现象。所以本文不再讨论如何减少幻觉,对幻觉感兴趣的同学可以去搜索相关论文。

  • 数据篇

先分享下 sft 工作者的一天:晚上下班挂个精心准备的实验,早上起床看结果并随手挂个实验防止 gpu 资源浪费,白天做一天的 case 分析,晚上下班挂一个结合 case 分析结果优化完数据的新实验(完成闭环)。

因此,不用质疑,分析数据和清洗数据就是 sft 工作者的 90% 的工作量。

  • 数据多样性

经历了一年多的磕磕绊绊,目前的 LLM 从业人员大多都会认同:sft 训练数据的核心是数据多样性和数据质量,数据数量并不重要

数据质量就不谈了,prompt 可以不那么严谨,能看懂就行,但 answer 是尽量一个标点符号都不要有错误的,该中文引号就中文引号,该单引号就单引号,该把 GPT4 啰哩啰嗦的回复精简一下就精简。

我们重点说说数据多样性。即使到了今天,也没人能定义清楚说怎样的一份训练数据叫做数据多样性足够好。我们能做的只能是从先验的角度,把模型能遇到的各种任务类型都让它见一次。从个人经验来说,我认为数据多样性主要包含两个维度,“数据用途”和“数据形式”。

先说数据用途,也就是 task_type,可以结合这几个思路进行数据收集:

  1. OpenAI 官网列出了 ChatGPT 擅长的所有任务项,诸如翻译、emoji 聊天……之类的。我们就每个任务项都想办法来一点数据,照着尖子生的作业抄;

  2. LLM 毕竟是个语言模型,传统的每个 NLP 模型它都应该能胜任,那就把什么 NER、机器阅读理解、意图识别等传统的 NLP 任务也给模型补充一点,如果已有类似任务就不补充了。训练数据也很好搞,传统 NLP 数据集质量都很高,直接拿来用就行;

  3. 参考业务需求,下游业务需要某个特殊场景的任务,那就让 sft 阶段提前见一见,这种数据的典型代表就是过年前给模型灌一些对春联、猜灯谜的的数据。只要数据质量没问题,一般都不会破坏模型能力;

  4. ……

重点来了,每一条 sft 训练数据必须要 task_type 类型,千万别搞大杂烩,否则对后续的 case 分析简直是灾难性的伤害。在实际工作中,双层 task_type 都很常见,比如“逻辑推理 - 常识推理”,“逻辑推理 - cot 多步骤推理” 这种。至于每种 task_type 的数据量,别搞平均主义:难 task_type 酒数据多点,简单 task_type 就数据少点,也要结合自己的 base 模型能力动态调整。

task_type 的划分就是 sft 数据最重要的基建工作,没有之一。

我们还需要从数据形式的角度来兼顾数据的多样性:

  1. prompt 表达方式多样性,不要千篇一律的“把中文句子 A 翻译成英文”,也要适当有一些“我在英国旅游,我现在需要向路人问路,我想表达 A 的意思,该怎么说”,“我是一个英文老师,我需要向我的学生讲解句子 A 用英文怎么写,请你用最正宗的表达方式帮我完成。”这么做的目的是防止模型只认识 prompt 中的几个关键 token,进而导致训练过拟合或者泛化性变差;

  2. prompt 长度均衡,既要有短数据,也要有长数据,避免模型的 attention 退化到无法聚焦长 prompt。长数据还不能是字面意思的长,要有那种关键信息藏在 开头 / 中间 / 结尾 的各种数据场景,避免模型在训练时偷懒,只对 prompt 的起始 token 或结束 token 有 attention;

  3. answer 长度均衡,不能让模型没出输几个 token 就停止,适当的有一些语料让它学会输出尽量长的 answer,否则模型会很难 follow “不少于2000字” 这种指令;

  4. 多轮聊天的切换 topic 能力,也就是说,有的数据当前 query 是和 session 有关系的,有的数据则是当前 query 和 session 毫无关系,要让模型自己学会判断 query 是否和 session 有关。类似的数据还要有 system 是否生效,有些数据 system 是个摆设,有些数据的 answer 则和 system 直接相关;

  5. answer 分布的多样性,这最重要,千万别总共一万条训练数据,一千条数据的 answer 都说同一句话,answer 可是算 loss 的,太单一的话会严重让模型过拟合;

  6. ……

概括起来,所有的数据形式多样性都可以总结为一句话:数据形式不能让模型找到规律,关键信息在 prompt 中的位置分布要足够随机。目的是避免模型在训练时退化,只聚焦于某些或某些位置的 token,而不是聚焦于完整的 prompt。模型和人一样,骨子里都是有偷懒倾向的。

  • 数据生产

  • 生产 prompt

说实话,我已经不太记得通用模型的 prompt 是怎么造的了,那都是去年的工作,感觉当时都是直接翻译英文数据集的 prompt 并重新标注完成的。印象里,斯坦福有一个 self-Instruct 的工作,给每个 task_type 准备一些 seed prompt,然后随机采样 seed,在喂给一个能力很强的 pretrain 模型,让它基于这些 seed 问题再续写出一些问题。其实也不必是 pretrain 模型,GPT4 模型的指令 follow 能力已经足够强了,让它基于一些 seed 问题直接仿写出一些 prompt 也是可以的。

今年的话,应该有很多现成的 sft 训练集,或者是 nlp 训练集,想个办法到处搜刮一下,然后简单筛选下质量就行,反正我们只要 prompt,并不要 answer。最近讨论的比较热的“合成数据”,基本也都是各种启发式规则造 prompt,可以重点留意一下。按照我前文中介绍的数据多样性,去搜集不同 task_type 的数据集集合,然后适当做做改写。实在是找不到合适的 prompt ,就自己动手写一点,answer 写不出来,prompt 还能写不出来吗?

特别要注意,收集或设计 prompt 的时候一定要结合实际情况,不要指望模型一次性写一篇万字爽文,这种事情连人都做不到。我们要把比较困难的任务提前拆解好 prompt ,比如:

  • prompt1 :请设计一个重生故事的大纲,大纲包含“父母重男轻女,女主高考状元,弟弟彩礼”等要素;

  • prompt2 :请基于给定的故事大纲,扩充内容,生成一篇不少于多少字的文章。

LLM 只是知识量比人多,而不是知识掌握度比人精细。如果普通人做起来都费劲,那这个 prompt 大概率是需要拆解的,这在“利用 sft 后的模型去对接业务”时格外重要。

  • 生产 answer

GPT4 is all you need,这里的 GPT4 不仅仅是字面意思上的 GPT4,还可以理解为 good model 的意思,指的是利用一个效果好的模型来生产 answer。

  • 不在乎成本,就选 GPT4 / Claude 3,用过的人都说好;

  • 在乎成本,就在自己的机器上部署 Qwen_72B / deepseek_MOE,部署过的人都说好;

  • llama 系列的模型就算了,它的中文能力,体验过的人都说不好;

  • 文心 / 豆包 等效果不如 GPT4 的闭源模型,属于品味之选,为国产大模型助力,点赞。

这里需要注意,你一定要知道你喜欢的模型适合用什么 prompt,提前在 ChatGPT 的 playground 上多测一下,找到模型回复效果最好的 prompt,该加 few_shot 就加 few_shot (few shot 最好有一个种子池,不然模型的回复会比较单一),访问 GPT4 的 prompt 并不等价于喂给模型的 prompt

然后,我们说最实用且最经济的一个方法:训个小模型,这里再次搬出万能公式:小模型 + SFT ≈ GPT4 + zero_shot / few_shot / cot(复杂指令和逻辑推理可能不行)

开卷考试就是这么无解,小模型知道考卷是什么,然后只学什么,就是能考出来好成绩。对于某种特殊需求的 task_type,我们利用 GPT4 生产一千条 answer,然后去训小模型,再利用小模型去预测出上万条数据,这个方法真的十分非常相当的好用。特别地、利用 GPT4 生产数据的时候,由于模型不 follow 格式,数据可用率大概只有 70% 左右,但是利用自己训的小模型生产数据,那可是 100% 的 follow 格式。

值得一提的是,任何模型在预测的时候,有 cot 确实比没有 cot 效果好很多,尤其是分类任务。这很容易理解嘛,直接说答案肯定不如分析完每个选项再说答案靠谱。我前面提到过,实际工作中,出于耗时的考虑,可能不会用 cot 来训模型,但是数据生产的时候,为了保证回复质量还是应该让 GPT4 用 cot 的方式进行回复,我们在训自己的模型的时候,省去 cot 环节即可。

最后,苦力还是要做的,GPT4 也好,自己训模型也罢,还是会出现出现数据质量不可用的情况,这时候必须要写规则,或者通过肉眼看来做个校验。数据去重环节也得做,因为一个模型针对一种 task_type 生产出来的数据,同质化十分严重,一定要避免 answer 过于相似的情况发生,实在看不过来就大批量剔除生产的训练数据吧。还是那句话,sft 数据要的是质不是量。

  • 小结

数据质量就是 sft 工作最核心的内容,数据生产工作一定不能当甩手掌柜,把 excel 给到标注同学后,等他们标完看都不看就直接拿来用。有时候,把想办法“造数据/ 洗数据”的时间拿来手动标数据,工作早做完了,还能加深自己对数据的理解,所以不要把事情复杂化,也不要排斥去做那些所谓的“脏活”。

prompt 的表达方式,answer 的回复风格,训练者一定要烂熟于心。

数据飞轮

模型的上线不并代表着 sft 工作的结束,它反倒代表着 sft 真正工作的开始。只有到了这一刻,我们才开始接触“最真实的用户 prompt”。

前面说了,prompt 的生产是需要有 seed 种子的,也就是终归是有限的,但用户的脑洞是无限的啊,用户的 query 就是我们的候选 prompt 数据集。尤其是多轮聊天数据,自己生成的多轮对话数据,通常都默认模型回复的是正确的,用户会 follow 模型的回复。但线上可不是这种情况,你聊你的,我聊我的是时有发生的事情。

以代码任务为例,我让 GPT4 模型给我写个代码,它写了,我复制粘贴加执行,然后报错了,我把报错复制粘贴发给 GPT4,它修改了代码,我又执行还是报错 …… 重复了这个流程4、5 轮之后,它写的代码终于执行成功了。显然,我和模型的这 5 轮对话数据,就是最好的多轮理解 + 代码生成数据,但它几乎没有任何能标注出来的可能性,只能靠捞用户日志来获得。

不仅如此,用户日志往往还配了“点赞 / 点踩”的选项,甚至还能为 dpo / rlhf 生产数据呢(一定要清洗,这种数据很脏,我朋友说他每次都是反着点的,就是故意要污染 OpenAI 的训练数据)。

用户的 prompt 天然比我们自己准备的 prompt 复杂,我们自己的 sft 训练集可能就是让模型翻译一个句子,但是用户的需求可不这么简单,用户会让模型把翻译后句子的某个单词换一个表达方式,或者是提问这个句子中某个的单词是什么意思。因此,基于用户 log 生产的训练数据,是很适合培养模型的话题转移能力,自我纠错能力,坚持己见能力,结合新需求重新改写答案的能力,等等。

只有把“定期拉取用户日志,利用规则筛选有价值的 prompt,访问 GPT4 获取答案,加入新数据更新模型”这样的数据飞轮 run 起来了,我们的 sft 工作才进入到了一个良性循环状态。

这里再额外说一个东西,我们的训练数据最好有一些“鲁棒性数据”:也就是 answer 很正常,但 prompt 表达很差劲的训练语料 。prompt 差指的是,它或者是有错别字,或者是话没说完整,亦或者是中文英文拼音夹杂着表达。不用担心会破坏模型效果,毕竟 prompt 根本不算 loss,这么做的目的是适应线上用户的糟糕表达,没有一个用户会希望听到“不是我们的模型不行,而是你 prompt 写的不行”这种观点(我试了一圈,糟糕 prompt 的理解能力,感觉国内模型和 GPT4 的差距挺大的)。

鲁棒性数据可以直接从线上拉取,也可以手动修改原本的 prompt。切记给这类数据打上一个专属标签,千万别让新人看见之后直接给当成脏数据给清洗了

  • 专项数据

所谓专项数据,也就是我们老生常谈的 RAG、长文本、Agent、复杂指令、function_call 等 sft 数据。这些 sft 的进阶任务,在训练上几乎没有任何额外的技巧(除了长文本训练要学会变 rope 基底和 sequence_parallel),它们所有的工作难点一半在数据生产,另一半在工程开发,而后者和算法同学也没啥太大关系。

既然这些专项都是数据工程,那就不要把它们想的那么高大上,大胆的尝试吧。这里我针对每个专项简单介绍两句它们是什么,如果想更深入的了解还需要去实操,遇到几次瓶颈也就会了。

  • RAG

rag 的核心工作在于建库,知识库检索的准确性决定了这个工作的上限。此外,rag 需要外挂两个模型:

  • 知识 / 聊天二分类模型,用于判别该不该做 rag。不要纠结说自己的模型知道世界最高山是什么,这个知识不用做 rag。你根本没办法测出来哪些知识是模型具备且正确的,所以是知识问题就必须做 rag;

  • 传统的 IR 模型,快速从库里面进行检索出候选候选文档,没太多说的,老 NLP 技能了。

rag 的训练 sft 数据构造主要有几个细节需要留意:

  1. 检索内容为空的时候模型会怎么回复,别让它自由发挥出一些奇怪的结果;

  2. 检索内容相互矛盾的情况,别让他只盯着第一条 / 最后一条的内容回复;

  3. 检索内容和 query 完全无关的情况,也是需要让模型见过,防止出奇怪的结果;

  4. 检索内容错了。那就让模型照着错的答案念,千万别想着让模型自己判断 rag 的知识和自己的知识谁更正确。我们做 rag 的大前提就是默认“数据库知识准确率高于模型自己具备的知识”。这种取巧心理很容易把模型搞迷糊,到时候模型不 follow rag 内容就麻烦了。

Agent / function_call

我个人喜欢把 agent 和 function_call 理解为同一个东西,后者是前者的主要实现形式。实现起来真的也没什么复杂的,就是在 system 里加上一句这样的表达:“遇见数值计算任务你就输出 <special_token> + 调计算机”,然后再构造类似的数据就行了。

也就是说,在我们的训练数据里,除了有system,user,assistant 之外,还要有“调计算机”,“计算机返回结果”这两个轮次。如果需要其他的 function,就在 system 里写,在训练数据里补充对应样本。

我们团队的 agent 主要技术负责人,绝大多数时间的工作就是在培训数据标注同学,不是在对标准就是在对 case。可以说,数据就是 agent 最核心的内容。

我之前咨询过这个 agent 负责人,他给我说刘知远老师的面壁智能团队,是这个方向的主要贡献者,感兴趣的同学可以去留意下他们的工作,有几篇蛮经典的论文。

  • 长文本

字面意思,模型能处理的 token 可长了,那怎么实现呢?

训练上利用 ntk 外推:增大 rope 基底,找一些长文本数据让模型适应新基底(我一直不太理解,为什么只需要一点点语料就能让模型适应新的基底,有没有数学大佬看见了科普一下)。因为 attention 的计算量和数据长度呈平方关系,所以显存会不够用,训练的时候要使用 sequence_parallel 技术。

数据的话,想方设法构造长文本理解数据吧,不能是短数据 concat 的,前面说了模型有偷懒倾向,所以我们需要让模型不知道候选答案会出现在 200K prompt 的哪个位置。paper 数据,书籍数据,甚至是 RAG 数据都是比较好的长文本数据胚子。

  • 简单的长文本任务就是“背密码”,随机把密码插入到 200K 文本的任意一个位置让模型来复述;

  • 复杂点的长文本任务就可以是让模型概括 paper 的 instruction 内容,让模型列举出所有林黛玉出场的章节;

  • 挑战性的任务则直接让模型去算林黛玉的出场次数。

长文本往深了做是很有学问的,和 agent 一样有很多工程上的工作,我没做过就不瞎分析了。kimi 在这一方向上做得不错,同为清华系,希望它能和面壁 / 智谱一样大方点,多纰漏点技术细节吧。

复杂指令

复杂指令通常指的是“prompt 里包含了非常多的限制”:既要不少于多少字,又要在什么什么场景下插入 emoji,说话还要押韵,时不时还要模型自己输出一些 special_token ……

具体怎么做,说实话我也不是特别知道,我就是一股脑的堆 sft 数据,但我之前的一个同事好像在用 rlhf 来解决这个问题。这个怎么说呢,目前的技术路线还不明晰,我个人觉着做好复杂指令必须要借助 cot 或者自我纠错能力。模型在 next_token_prediction 的时候,很难找到一个完美 token 满足所有的限制条件,所以要么提前 cot 打好草稿,要么让模型意识到“已经输出的结果不可能再满足某个限制条件”时进行纠错。所以我觉着,o1 的技术路线可能就是复杂指令的正确解法。

在 o1 的技术路线成熟之前,我觉着硬造 sft / rlhf 的数据,应该也能凑合着应付大多数用户需求。这里分享一个造数据的小技巧:先射箭,再画靶

意思就是:你搞了一个很复杂的 prompt ,但即使是 GPT4 的回复,也没有 follow 所有的限制,那怎么办呢?重写答案是很麻烦的,所以就直接去修改 prompt。原本的 prompt 要求模型输出不少于 200 字,实际上只输出了 189 个字,那就把 prompt 改成不少于 180 字(很多复杂指令模型模型根本无法精准回复,限制输出字数这种本来就是学个大概,没必要特别认真,改改 prompt 凑合着用就行了)。

训练篇

先强调一下,这一篇章中我不讨论 lora 和各种 sft 的训练变种,我只聊最朴素的 sft。

我理解 lora 的出现就是为了省显存,在有算力做全参训练的情况下,似乎没啥优点,可能能防止过拟合?那我少训点数据,或者开 dropout ,调学习率也能防止过拟合呀,我在实际工作中几乎没用过 lora,身边同事也不怎么用。

至于针对 sft 做的各种训练方法的变种,比如蒸馏训练,我的观点是普适性不强,并不适合所有的模型和场景。不如让子弹再飞一会,究竟是旁门左道,还是像 DPO 这种真心不怕火炼的万金油工作,现在还没有定数。不搞学术研究的话,老老实实的用最朴素的 pretrain_gpt.py 的训练方式做 sft 又不是训不出来,没必要做过于激进的尝试。

还有一些只更新部分模型参数的训练,比如某些层微调 / embedding 微调 / attention 微调 / mlp 微调。做这种实验一般都有特殊的需求,实现起来也很容易,把不想更新的参数冻结即可,我也不专项讨论了。

训练框架

不同于 pretrain 阶段我力挺使用 megatron,我反倒觉着 sft 阶段用 deepspeed 挺好的。

由于 sft 的训练语料不是很多,使用 deepspeed / megatron 的训练代码都可以,速度性能上的差异也就是带来一个小时左右的时间,无伤大雅。deepspeed 在 sft 阶段的优点主要有:

  • alignment 的很多开源工作和开源代码都是基于 deepspeed 实现的,复现起来省事;

  • 利用 AutoModelForCausalLM 可以直接训起来大多数开源模型,而不需要每次都 trans_hf_to_megatron;

  • 训出来的模型可以直接起 tgi 服务,vllm推理等,用 megatron 的话还得 trans_megatron_to_hf。

不管使用哪个框架,有几个参数是着重需要注意的,每次启动训练前都要看一遍怎么设置的:

  • epoch

  • gradient_accumulation_steps

  • global_batch_size ( megatron 的参数,deepspeed 同学可无视)

  • learning_rate

  • lr_scheduler_type

  • dropout

有几个参数你需要需要知道为什么要打开或者设置成这个值,它们会直接影响训练速度:

  • zero_stage

  • max_seq_len

  • offload

  • gradient_checkpointing

  • seq_parallel_size

有几个参数对模型的训练影响不是那么大,但必须知道它们是什么意义,对其他参数有什么影响:

  • weight_decay

  • per_device_train_batch_size

  • num_warmup_steps

每个参数的具体含义我就不想分析了,查查官方教程,或者问问 ChatGPT 都会有明确答案。如果连这种最基础的东西都不愿意自己去调研一下的话,那可能不适合 LLM 这个方向。

  • 炼丹技巧

接下来讨论炼丹技巧。其实也没啥说的,翻来覆去那几句话:小模型大学习率,大模型小学习率,epoch 基本上就是 1~3个,数据是 10W 级别左右,稍点多点都行,但少不能少于 1W,多也不能到达 100W (没有理论,数据量是偏经验的一些结论);起始训练适当做点 warmup,几种主流的 lr_scheduler 都试一下,gradient_accumulation_steps 是个比较重要的参数,就 16 / 32 / 64 / 128 等数字都尝试下;按需求开 dropout,这东西不开没啥大事,开了反倒容易训炸。

其他的好像真没更多东西了,不过 loss 曲线倒是需要重点留意几个细节:

  • 不同 task_type 要有不同的 channel_loss,分别观察;

  • special_token 的 loss 一开始会有点高,但是下降也是很快;

  • 创作类任务的 loss 会比其他任务的 loss 更高一点,这个现象很 make sense,答案越固定,搜索结果越单一的语料,loss 越低,反之亦然;

  • 只要训练语料是通用数据,且数据进行了 sample,那么模型的初始 loss 就不会特别高,7B / 13B 可能在 2 左右,数据难了也有可能到 3,72B 则大多在 1 ~ 2 之间这个水平状态;最终 loss 则大概在 0.5 左右,根据语言模型定义,如果 loss 更低,那基本上模型只会说这一句话了,别的 token 都没概率了;

  • 如果 loss 持续升高,不要对自己的训练数据产生任何质疑,想着是不是数据太难了不好学之类,这就是训练代码有问题。next_token_prediction 的训练方式就是在背书,它不存在学不会的情况,只存在学会了但不会泛化的情况。你就算是一堆随机乱码,模型啥也学不到,它也应该是 loss 持平,而不是 loss 升高啊。

哦对,关于 loss 还有一个非常关键的地方:阶梯形 loss。很多文章都分析过这个现象,基本公认的结论就是这是模型过拟合的体现,我在这里不过多分析了。

我想说的点是:sft 过拟合真的就是坏事吗?我们都知道 sft 阶段就是让模型学会指令 follow 能力,而 follow 指令的直观表现就是对问答格式过拟合,“模型开始回答 question,而不是续写 question”。因此,我认为 sft 过拟合并不是一个坏现象,至少格式过拟合肯定不是,我们怕的是模型对 answer 内容过拟合,不管什么 question 都只会车轱辘的重复一个 answer。

这里我举个例子,question:“以 json 格式输出 XXXX”

  • answer1: ```json……

  • answer2 : 好的,这就为你用 json 输出结果 ```json……

answer1 就是对格式过拟合,如果模型被训的失去了回复 answer2 的能力,这是坏事吗?显然不是。

所以,阶梯形 loss 毕竟只是过拟合的一种体现,可能是格式过拟合了,也可能是 answer 过拟合了,观察到这种现象是合理的,并不代表模型训得过火了。我们需要的是理解为什么有阶梯形 loss,而不是花精力去研究如何避免阶梯形 loss。

提到阶梯形 loss ,还涉及到一个话题是模型该训几个 epoch。我个人青睐 3 这个数字,很多技术报告(比如 llama 和 qwen)青睐 2 这个数字,还有很多人认为 1 个 epoch 学一下格式就够了。我的饭搭子更极端,他认为应该是 1.1 个 epoch 这种数字,因为训练初期,模型重点关注了对 special_token 语义和输出格式的学习,answer 反倒学的不充分,所以训练初期的 10% 语料应该再被训一次。

怎么说呢,多实践,think is easy,show me the experiment result。不同的 base 模型,不同的通用数据,可能适合不同的 epoch,多点实验,少点脑洞。有时间分析这么细就 1 / 2 / 3 都对比试一下,没时间分析这么细就认准一个喜欢的 epoch 去调整其他参数。

  • 拟合问题

现在忘掉上面介绍的 sft 格式过拟合这个概念,我们讨论下“sft 欠拟合 / 过拟合”这两个常见问题。

  • 欠拟合

欠拟合,字面意思,模型没学好训练数据,做下游任务的能力很差,一般只表现在某个 task_type 上。导致欠拟合的因素有很多:训练集里的相关 task_type 的数据量与数据质量,task_type 的难易程度,模型本身的 size 和能力,epoch、学习率、以及 gradient_accumulation_steps 的设置,甚至是和训练方式有关。

欠拟合首先要确定一个问题,是真的连训练数据都没学会,还是说学会了训练数据但无法进行泛化。测试方法也很简单,直接让模型回答训练集,如果这个都答不上来,那就是没学会,再多训 1 个 epoch,多补充一些 task_type 的训练数据,学习率适当调整等方法均可以解决这个问题(我问了 ChatGPT_o1 如何调整学习率,它的答案是数据欠拟合时,观察 loss 曲线和梯度,如果 loss 下降缓慢就增大学习率跳出局部最小值,如果 loss 比较震荡学习很困难就减小学习率提高训练稳定性)。要是反复训都学不会训练数据的话,请 review 训练代码。

接着,如果说训练集学会了,测试集完全不会(注意,是测试集完全不会,而不是测试集老输出训练集的 answer,后者大概率是过拟合了),那相对麻烦一些。因为我们需要确定模型能不能做好这个 task_type,还是说 answer 是否干净或 answer 的表达方式是否合理。

模型能不能做好这个 task_type 需要结合一些主观判断,我们需要知道什么任务是难任务,诸如复杂指令、逻辑推理、数学计算等任务,很可能这个 size 的模型压根就不行。这时候一个比较常用的方法是看别人的技术报告,下载同 size 的开源模型测能力,多和同行交流经验(这种经验大多都属于可以分享的话题)。此外,别人家的孩子成绩好可不代表自家孩子成绩好,很可能自己的模型就是在 pretrain 阶段欠缺某 task_type 的基础知识,还是那个例子,pretrain 阶段没学过唐诗宋词,sft 阶段训再多作诗数据也没用啊。如果不知道 pretrain 阶段模型的学习情况,那就多让它续写一下这个 task_type 相关的语料,看看掌握程度,最好也和开源的 base 模型对比一下。

假设模型具备这个 task_type 的基础知识,且该任务也不算很难,那大概率就是数据的问题了。抽样一些 answer 去 check 质量,不仅仅是回答的是否准确,还要看回答是否符合语言模型。这里一定要读慢一点,一个字一个字的读,看看 answer 里是不是每一句都有逻辑 —— 很多话只是读着通顺但表达很差劲,人能理解并不代表模型能理解。典型例子就是倒装句,根本不符合语言模型,模型很难学会。check 完之后,觉着数据质量也挺 OK 的话,那就只能再多造点训练数据了,或者上采样一下已有的训练数据(别的数据训 3 遍,这个数据训 6 遍)。

至此,如果还是解决不了欠拟合问题,那就只剩一个杀手锏了:重写 prompt。意思就是,把这个 task_type 里的 answer 的知识尽可能的削减,prompt 里的背景知识则尽可能的增加,甚至可以把 answer 中常用到的 token 都以某种表达方式写进 prompt 里,提高生成这些 token 的概率。也就是说,我们要想方设法的减少这个任务的复杂程度,把一道高考考题改成中考难度,模型总该会了吧。前文提到的拆解 prompt 也算是重写 prompt 的一种情况。

欠拟合还有一个可能性,那就是训练方式,这里分享一下去年的一次 sft 经验。

去年在 llama2 的技术报告刚发布的时候,meta 说他们的 sft 阶段和 pretrain 阶段一样把数据 concat 成 4K 长度做训练,且不做 attention_mask。我们觉着很合理,不同的数据 concat 在一起,可以培养模型学习 session 是否和当前 query 有关的能力,因此就去复现了,结果发现“测试数据中,分类任务的 ACC 有明显下降”。

经过一通分析后,我们发现,新的训练方式改变了短 answer 数据的 loss 占比,毕竟模型在计算 loss 的时候,是先算一个句子内每个 token 的 平均 loss,再算一个 batch_size 内的平均 loss。

分类任务的 answer 通常只有 1 个 token:不 concat 的时候,它的 loss 贡献就是 1 / batch_size;concat 的时候,它就需要先和别的 answer 的 token 算平均 loss,再贡献 1 / batch_size。

这也就是说,采用 llama2 提到的 先 concat 语料再做 sft 训练,会对短 answer 数据很不公平,也就更容易造成短 answer 数据的欠拟合,pretrain 由于所有 token 都算 loss 则没有这个现象。最终,我们通过上采样短 answer 数据,成功的避免了分类任务的效果下滑。

  • 过拟合

过拟合也是字面意思,模型训得太狠了,对训练集里面的 answer / pattern 记得太死了。相比于欠拟合,过拟合则好解决很多,至少模型已经具备了这个 task_type 的能力,只不过是能力被限制在一些 token 或者一些 pattern 上了而已,想办法缓解即可。

sft 的过拟合并不像传统深度学习一样,通过调整训练 epoch、学习率、dropout、weight_decay 来解决。因为大概率模型只是某项能力局部过拟合了,大部分能力都是正常的,盲目调整超参数反倒会让模型整体上欠拟合。

具体地,在确定模型并没有全局过拟合之后(如果是全局过拟合,模型整体的效果应该都很差劲,那就通过炼丹来解决,这里不赘述了),我们主要的解决方案是通过优化训来数据来缓解过拟合,主要措施是删减对应 task_type 的数据,或是扩充该 task_type 的数据多样性

过拟合的难点是让模型暴露出来它到底对什么过拟合了,好让我们去 grep 对应的训练数据来做修改。通常,我们观察到模型过拟合是因为它回答错了某个知识,而且是非常蠢的错误:比如日本的首都是北京。

  • 首先,通过让 base 模型续写,判断是不是 pretrain 阶段训错了(通常情况下都不是),如果是的话,那没辙了,强行在 sft 阶段做知识注入来扭转 pretrain 的错误知识吧,一两条语料影响应该不会很大;

  • 然后,判断 sft 模型对哪个 pattern 过拟合了,对 answer 里面的核心关键词进行测试,也就是“日本”,“首都”,“北京”。对着我们的模型发出一连串的提问,美国的首都是哪里?日本最大的城市是哪里?北京是谁的首都?日本的首都是北京吗?日本的首都是东京吗?……目的很简单,测试出来模型到底对哪个 token,哪个 pattern 过拟合了,到底是把日本的所有城市都回答成北京,还是把所有国家的首都都回答成北京。进而 grep,大概率类似 pattern 的语料都过多了。这时候或删除该 pattern 的语料,或改造该 pattern 的语料,都无所谓了。

  • 小结

拟合问题应该是 sft 工作者遇见的最多的问题,每天都在过拟合和欠拟合上左右横跳,说再多 debug 技巧也总有查不出来问题的时候。

这时候没有更好的办法,唯有多尝试,让 pretrain 模型续写,让 sft 模型回复,观察模型从哪个 token 开始回答错误,观察模型是格式不 follow 还是内容错误,测试训练集中的 prompt,测试和训练集相似的 prompt ……

多试 case,多 grep 训练数据,少在 debug 的时候看可解释性论文,便能找到模型不拟合的原因。

夹逼准则

夹逼准则是我的饭搭子提出的,当时我们在做领域模型后训练,我的各种数据配比都会使模型通用能力掉点,他说了句很经典的话:“学习率等于零模型的通用能力是不是就不会掉点,那根据夹逼准则,你肯定能找到一个合适的学习率让模型的通用能力只下降一点点(已知模型不掉点是一个我能达到的上限)。”

本质上,夹逼准则是在强调一件事情:确定模型可达到的效果上限在哪里,进而定位训练问题

  • 经验分享

我的同事和我讨论过 continue-pretrain 训练领域数据的时候的一个奇怪现象:一份考试类型的数据被划分成 train / test 集合,然后先把 train 集合混合了 common 数据做了 pretrain,一段时间后,test 集合的 loss 明显下降,但是模型在 test 集合上的做题得分却没有任何变化。

背景就是这么个背景,我这里不会再说后续,只讨论下如何利用夹逼准则去分析这个现象。

第一步:test 集合是一个考试类型的题目,大部分内容是选择题,也就是说 question 的长度会远远大于 answer 的长度,毕竟选择题的 answer 只是一个 token。那么,观察到的 loss 下降,大概率是模型会背下来这道题目造成的,而不是模型提高了 Prob( answer | question) 造成的。也就是说,不能因为 test 集合 的 loss 下降了就轻易下结论说是模型学会了这个知识。所以,我建议不直接看在 test 集合上的做题能力,而是看 sum[ Prob( answer | question)] 这个整体概率是否提升,因为模型可能是量变还未引起质变的情况,如果这个值在提高那就继续往下训,没到火候而已(当时,这个“概率和”好像也没啥变化,所以要继续往下分析)。

第二步:既然是混合了 common 数据的 continue-pretrain 训练任务,那就说明领域知识是以一定的比例在参与训练。现在模型没有学会这个知识,可能的原因有多个:训得太少了,学习率设置的不够合理,领域数据质量太差根本学不到知识等。这时候我们逐一分析,先利用夹逼准则来判断数据质量是否存在问题

我建议同事基于 base 模型做一个 100% 领域集合的训练任务,亦或者是直接拿领域数据做 sft 训练。这两种方法都突出一个“放弃 common 能力,只学领域模型数据”,因此训出来的模型就是“当前模型 + 当前领域数据”所能拿到的最好的效果。如果这样子训练出来的模型都无法提升 test 集合上的做题能力,那这份数据估计是废了,大概率选项都是随便给的。但换句话说,如果这样子训出来的模型大幅度提升了在 test 集合上的做题能力,也可以基本敲定数据没问题,该研究超参数了。或调大学习率、或增加领域数据的占比,等等方案均需要尝试。

  • 实战思路

前面提到的拟合问题,就可以用夹逼准则来进行问题定位。

我们不知道模型没拟合是数据问题还是训练问题,那就只训目标 task_type 的数据,并且训 10 个 epoch,总该学会了吧。如果这个都学不会,说明“这个模型 + 这份数据”的上限就是学不会,那也就不要再为难模型了。

如果模型学会了,那就说明“能学会”这个上限是可触及的。把训练配置改成训 5 个 epoch,改成 task_type 数据只占比 50%,一点一点的往下采样,逐渐的去找到那个“临界值”:训练量小于这临界值模型就学不会,大于这个临界值模型就能学会。当问题分析到这一步,基本也就能定位到不拟合的原因了。

  • 评估篇

  • 评估方式

首先,我们肯定要事先准备好一个高质量的评测集合,这个评测集合要和 sft 训练集合一样有明确的 task_type,在这份评测集合上的“模型可用性”就是模型能否上线的标准。

不同于 pretrain 的评估只需要看知识能力,sft 的评估是需要看经典的 3H 原则的:Helpfulness、Honesty、Harmlessness。当然,实际工作的评估中,倒也不必完全是按照这三个原则进行评估,可以按需求制定自己模型的指标:是否 follow 指令,是否 system 穿透,是否内容准确,是否产生幻觉,是否安全……等等等等。

评估的时候,每个维度都要有一个单独的得分,最后根据自己制定的加权公式来确定这条回复的可用性。这样,当模型在某个 case 上的得分变低的时候,我们能比较直观的看出来到底是哪个维度变差了,好结合训练数据做 case 分析。这里需要提醒一句,做评估的时候,一定要了解自己的整个 LLM 系统。大部分公司的 LLM_engine,前置黄反拦截是很严格的,这种情况下你天天盯着裸模型的安全性做评估,就完全没意义。有 RAG 的系统,过度关注裸模型幻觉也同理。我们需要了解哪些是 LLM_engine 无法兜底的内容,那才是要重点评估的能力项

至于评估方式的话,目前基本是两种:

机评:在利用 GPT4 进行评估的时候,prompt 一定要好好揣摩。大模型的评估是有明显偏好的,对比评估的时候,A 和 B 倾向于选 A,长句子和短句子倾向于选长的;打分评估的时候,一个正确的 answer 让模型打分三次,很可能就分别是 3、4、5分(假设是 1—5 分的候选区间)。

关于如何写机评 prompt,我这里也没太多可分享的,只能说多做尝试。Alignbench 的评估 prompt 就写的挺好的,可以进行仿照。对了,给模型一个参考答案很多时候能让模型的打分更准一些,毕竟模型会考虑“候选 answer”和 “gold answer”的相似度来进行打分。

人评:我估计人评是应该是各 LLM 公司所使用的主流方案,它并没有大多数人想象中的又贵又笨重(人评大概率是没 o1 评估贵的)。我们其实要明白一点,人是有长记忆的,人的 kv_cache 可是能保存好多天的。sft 的评估环节最耗时的其实不是看 answer,而是看 prompt,如果我们对 prompt 非常熟悉,我们就知道模型的 answer 容易犯什么错,关注重点是什么,效率就会很高。

所以,只要不天天改评测集合,那么评估同学在做了两次评估之后,基本就把 prompt 记在心里了,这时候的评估效率就嘎嘎的高,基本上速度是不逊于机评的。

  • 评估分析

sft 的评估结果分析其实就是做 case 分析,大部分的方法我都在训练篇的分析模型不拟合时已经提过了,这里我只再额外强调几个细节。

sft 的评估一定是对比评估!这句话的意思是,你直接看模型在测试集上的表现是没有指导意义的,你一定是和上一版模型的测试集合并在一起来观察的。当你观察到 sft_v2 相比于 sft_v1,A 能力项的平均得分有明显下降,就去回想一下这一版训练数据里 A 能力项的数据分布是否有变化:

  • 有可能是 A 能力项的训练数据被更新了,质量不如上一版数据,需要重新 review;

  • 也有可能是整体的训练数据变多了,导致 A 能力项的数据占比变小了,有点欠拟合了;

  • 还有可能是这次训练的某些参数没配好;

  • 当然,最有可能的一种情况是“波动”,这时候,你就重点挑两个 sft_v2 回答很差劲的 case,然后用 do_sample 的方式让 sft_v2 与 sft_v1 都预测个十次。我们看一下,是不是 path@10 也变差了,如果不是,那就不用很在意了。波动经常会影响 path@1 的结果,但对 path@N 的结果往往影响不大,训练方式才会让后者发生较大的改变。

总结下来就是,每次拿到新的评估结果后,看一下模型在各个 task_type 上的可用性得分。得分明显变好和得分明显变差的 case,我们都要去反推训练数据和训练参数有什么变化。得分变好的情况就是我们以后洗数据的好经验,得分变差的情况就是我们下一版优化的目标,反复迭代即可。


  • 总结篇

结语

在大模型浪潮初期,我和我的前辈曾经有过一段对话。

  • 我:这工作(某个方向的 sft)交给我合适吗,我能胜任吗?

  • 前辈:这工作谁都能做。

  • 我:那你为啥选我来做?

  • 前辈:这不是因为我认识你,跟你熟悉嘛。

  • 我:……

  • 前辈:你做不做,不做有的是人想做,不行我招个实习生来做。

  • 我:我做我做,我当然做。

谨以这段对话表明一个观点,sft 真的很简单,它没有传统 NLP 任务中经常涉及到的:“训练代码开发、输入特征设计、网络结构魔改、模型不收敛、训练数据难以构造、千万条级别的训练语料处理、训练语料去重/平滑/采样、长尾语料设计” …… 因此不要有任何畏难情绪,每一个有计算机基础的新人,只要态度端正,都能快速胜任 sft 这个工作。

大模型这波技术浪潮,拼的不是谁代码写得好,拼的是谁有算力、谁有训练经验、谁有 debug 能力。当算法新人还在纠结模型这个 case 为什么回答不对、那个 case 为什么瞎说的时候,经验丰富的人看眼 case 就知道:训练数据有猫腻 / 训练数据缺少某个能力项 / 这 case 为什么没触发 RAG / 有黄反拦截倒也不必纠结。总之,多了解自己的 base 模型的能力,多培养训练 feel,就是做好 sft 的重要法门。

自动驾驶之星和生成式AI与具身智能知识星球,新人优惠券来袭,结识一群志同道合的小伙伴一起成长。

下一个风口会不会是生成式AI 与具身智能的时代,我们特意创建了生成式AI与具身智能交流社区,关于大模型,机器人的相关业界动态,学术方向,技术解读等等都会在社区与大家交流,欢迎感兴趣的同学加入我们(备注具身智能)!   

自动驾驶之星知识星球主打自动驾驶量产全技术栈学习,并包括: 学习板块,求职面试,有问必答,论文速递,行业动态五大板块!星球内部包括端到端大模型,VLM大模型,BEV 障碍物/车道线/Occ 等的学习资料!

生成式AI与具身智能知识星球,我们相信生成式AI 与具身智能会碰撞出出乎我们意料的内容,本知识形象并包括: 学习板块,求职面试,有问必答,论文速递,行业动态五大板块!星球内部包括生成式AI大模型,具身智能,业界资料整理等的学习资料!

自动驾驶之星是面向自动驾驶&智能座舱量产向相关的交流社区,欢迎大家添加小助手加入我们的交流群里,这里有一批奋斗在量产第一线的小伙伴等你的加入!

👇点个“赞”和“在看”吧

自动驾驶之星
自动驾驶之星,是一个以自动驾驶\x26amp;智能座舱量产交流为主的社区。这里有自动驾驶\x26amp;智能座舱量产第一线的前沿动态,有一群奋斗在自动驾驶\x26amp;智能座舱量产第一线的小伙伴在分享他们的量产经历。期待你的加入!希望每个人在这个浪潮中都能成为自动驾驶之星!
 最新文章