今天给大家带来一篇好友知乎@王焱的一篇关于大模型增量预训练的文章。
作者:王焱
知乎:https://zhuanlan.zhihu.com/p/707751901
1 背景
去年,国内大模型赚钱最爽的一个方向,就是卖数据。
我也跟一些卖数据团队咨询过,他们把国内各个你能想到的主流中文平台的数据都爬下来,打包卖了。国内的头部玩家,手头数据是一定不缺的,买就行了。
同时,这些玩家显卡资源管够的情况下,肯定是会把能train的数据都train一轮。除非是预训练数据有大的更新,例如清洗的质量更高了,生成数据有大的突破。或者训练手段有大的迭代,重训大模型的价值是在不断降低的。
但底座模型的通用能力,本身就是有上限的,它就是做不到所有都很强。我们想要把某个领域加强,别的领域就或多或少的会被影响。
从2022年这篇OpenAI这篇论文开始,Training language models to follow instructions with human feedback。Aligntment Tax就一直存在。
但很多场景,例如,教育,代码场景,用户的需求往往比较集中。那么保证通用能力不跌很多的情况下,努力把domain效果提上去就好了。
也就是做continue pretrain(领域大模型)
以及,从反馈来看,如果发现continue pretrain后,domain和通用效果都涨了,大概率是底座通用domain训练的不够充分。除此之外,英文到中文的continue pretrain,例如:把llama增训成中文(国内很多公司的操作,这并不丢人,效果还挺好)、long context的continue pretrain。
近期邀请了张舸、浩然、刘乾等人,关于continue pretrain做了一个小范围分享,具体参看论文。
https://arxiv.org/pdf/2406.01375
https://arxiv.org/abs/2404.03608
以及,我们科研小团队,在long context continue pretrain也有一些踩坑经验。
得到了一些比较有趣的结论,这里把不敏感信息分享给大家。
2 continue pretrain的步骤
continue pretrain的步骤整体分成三步。
2.1 扩词表
不建议去轻易扩词表,满足以下两个条件,可以去尝试。
底座模型的词表跟你的domain的词表分布差距很大,增训的domain语料足够多。
大部分词表都是有基础字的,比如原来 「北京」 -> [12, 15]。扩了词后,现在假设变成了「北京」-> [10233]。这种因为动了高频词,刘乾试过各种warmup,frozen,都是想要有正面作用,需要训练更久的时间。
但多语言的continue pretrain,很多小语种的语料就这么点,还没变正向,样本就用完了。。
还有一种情况,大家可以试试,就是你扩充的都是低频词,原有的高频字/词不被影响。
大家还是选一个词表好的底座模型来做continue pretrain更合适,对比于底座训练不充分,词表的坑更大。
2.2 Domain Continue Pretrain
这里参考了张舸和浩然,刘乾他们两个方向的工作。小马过河,大家自行判断和尝试了。
2.2.1 Replay
需要采样pretrain阶段的数据。
还有一个潜在的坑,现在的pretrain往往会在最后阶段混入sft数据,关于这个的合理性,我在之前的文章中有过讨论。
但现在开源base模型,最多也就开源样本比例。 这些pretrain的模型,最后混了哪些SFT数据来提升某些领域的效果,只能靠经验来反推了。
所以continue pretrain后对比原base掉点严重,可能是你少混了一些SFT数据。
2.2.2 learning rate
Sailor paper发现:保持continual pre-training总token数一样的情况下,原有domain 和新domain 的loss几乎都是可以预测的,本质上是「learning rate学习率」和「replay ratio 重放比例」的一个权衡。在continual pre-training总token一样的情况下,学得越快,忘得越多,即使replay 很多也一样如此。
参看下图,当固定 continual pre-training 训练的总token数时,用各种不同的「英语」、「马来语」比例,以及不同学习率的搭配,训练结束时英语 (a) 和马来语 (b) 的validation loss。
相比于马来语,英语的loss更可预测,用二次项函数能取得 99.36%的相关系数。关键的指标是 log(English Proportion) - log(Learning Rate)。根据这个公式,相对比较好的learning rate是1e-4,相比Qwen原始小不少(假设Qwen遵循Llama类似的learning rate,也就是4e-4)。
但上面这个权衡是由于「固定continual pre-training的总token数」导致的。在实践中,如果计算资源和数据资源都不是问题,那可以尽可能让learning rate小。learning rate保持足够小的情况下,是有可能让原有模型效果跌得不多。(例如中文continue pretrain样本就管够,小语种样本就不够了)
但如果想最高效地利用计算资源,建议在continual pre-training的时候首先确定一个合适的learning rate,且放弃追求原domain无损的想法。根据上面的图(c),我们可以观察到:
● learning rate越小,新domain(图中的SEA language loss)的loss下降速度越慢,学习得越慢,但参考图 (a) 原domain遗忘得也越慢;
● learning rate越大,新domain 的loss下降速度加快,知识学得也很快。
● learning rate涨到一定程度,新domain的loss下降速度会出现震荡。因为知识的学习速度肯定是有上限,且潜在的数据分布差异导致模型不能用太快的速度学习。
● learning rate很大时,参考图(a) 横轴一定左移较多,此时剧烈的attention的分布调整对模型在原domain 上的性能会带来不小的损失。
因此,对于continual pre-training,比较有利于新domain学习且不会对原domain有破坏性伤害的learning rate拐点1e-4,就是一个比较合适的平衡点。Sailor的continual pre-training,从4B到14B的模型都有尝试。实践中推荐大家在continual pre-training前先用少量token,某个固定的重放比例(比如0.5),以及多组随机的learning rate来确认learning rate的平衡点之后,再做更精细的数据配比。
这么推荐主要的原因是learning rate的波动范围比较大,比如从5e-6到4e-4能有近百倍的差异,但replay比例限制在0-1,大家日常也就是0.1这样的幅度去寻找比例。
2.2.3 比例控制
domain数据占比过高,可能loss就直接崩了,占比太低,学习效率低,会导致最后domain提升不大。
麻烦的点,就是如何判断什么比例是最佳的。张舸和浩然的论文发现,continue pretrain阶段,随着domain数据占比的提升,通用loss和domain loss的确是一个此消彼长,然后趋于稳定的过程。
假设通用数据的占比是r,那么domain数据的占比就是1-r,张舸和浩然的论文中给出的关于数据比例的scaling law的公式为
增大domain数据的占比,那么domain loss会降低,通用loss会上升,在拟合好上述公式后,就可以计算不同比例下domain loss和通用loss的预估值
domain数据有了,预训练replay数据有了。在小规模的实验(模型参数量小,训练数据少)下continue pretrain,得到一些实验数据点,用实验数据点拟合上述公式,得到拟合参数值,就可以算更大参数量下的domain loss和通用loss。
多语言的比例控制可以参看刘乾的 RegMix 方法。
在训练Sailor时的一大难点就在于面对7个语言的各种不同来源数据(最后一共有20+不同的domain),怎么平衡原domain(英语和中文)和新domain(各种东南亚语言)的性能。
使用RegMix方法在1e-4下做了数据配比的实验,目标是优化各个语言 loss的logsum,可以让模型达到帕累托最优(没有语言的短板)。最后Sailor使用的重放数据(包括英语和中文)总比例大概是 30%,具体可以见下图:
2.2.4 Scaling Law
我想要去计算,我要训练多少步,loss才会降低到一个不错的级别,还是上面的scalling law
可以计算出大概训练多少步后,domain loss几乎就不会再下降了
问题在于,pretrain都做了的公司,其实continue pretrain的训练,不差这点训练成本,大不了我多train一会。
但对于一些中小厂来讲,最多也就continue pretrain一个7B模型,能大致知道一个节点,价值还是很大的。
scaling law失效的情况。
张舸和浩然是在qwen模型做实验,参数量也没那么大,并且还多了一个ratio的变量,导致公式变得更加复杂,并且还有不同domain的数据质量不一致的情况。
所以,这篇论文更大的意义,在于告诉我们,continue pretrain是能有一个公式来拟合预测loss下降的情况。
但这个公式,可能还是需要大家亲自去拿小模型自己实验一下。
2.2.5 参数
learning rate可以多参考刘乾的论文。
batch size也是调大一些,会有一些不错的效果。
如果有退火,从经验来看,需要把lr涨回去,这个时候loss会有一个相对波动比较大的阶段,但你再观察一段时间会稳定下来。
还是需要具体情况,具体分析,不同的base模型的操作会有区别,大家自行探索。
2.3 Domain对齐
对其是另外的一个世界了,先不展开。
但比例方面,大家往往也是会吧domain的SFT数据比例调的高一点,来保证最后的效果。
3 不同domain的特点
continue pretrain分成三大类,领域知识,语言类,long context
受到词表,知识难度,attention分布的影响,这几类知识的学习都会有不少的差距。其中领域知识增强类的domain更容易学习,因为基座llm中存在这样的知识,所以起始loss会更低,遗忘程度低,最优的配比低。
语言类的domain和long context的数据更难学习,前者是因为语言的gap导致初始loss偏高,但随着不断的训练,loss会稳定下降,但遗忘程度高,最优配比高,后者对资源的消耗更高,遗忘程度高,最优配比高。
3.1 领域知识Continue Pretrain
3.1.1 难点
比例的控制,训练多少tokens可以拿出来做对齐。
这里参考张舸和浩然论文即可。
3.1.2 样本质量的变化会不会导致scalling law公式的变化
张舸和浩然论文的数据,都是基于公开数据集。
但问题在于,日常大家自己训练,肯定会自己再做一轮清洗。样本质量的改变,会不会导致这个公式的指导意义出现大的波动?因为语料质量的提升,可能一条样本顶过去两条。
所以针对不同的领域数据而言,公式中的拟合参数都是不一样的,这里建议在自己的模型上进行实验后,然后确定公式中的拟合参数具体值。
所以,continue pretrain对于头部玩家来讲,意义不是特别大,我完全可以跑10个比例,然后一路二分。
但对中小厂,价值就非常高了。我上一篇预训练文章发出来后,不少人私底下问我,预训练中英文的比例是多少比较合适。
毕竟除了几个头部玩家,中小厂的显卡还是很紧缺的。
3.2 语言类Continue Pretrain
3.2.1 难点
去年大家的常用做法,就是已知llama的中文占比5%,那么我一点点增大中文的训练样本比例。
而不是算好一个比例,直接硬train,这很容易导致loss直接崩掉。
3.2.2 为什么语言的continue pretrain,比例不能剧烈变动?
三点原因
不同的知识,集中在不同的transformer层
之前内部实验,发现transformer越往上,最后一层的知识往往就越具体,越底层的知识反而越基础。
类似cnn做人脸识别,第一层抽取的特征是线条,到了最后一层就变成了鼻子,人脸这些特征。
语义这些知识,是最基础的知识,往往是在最底层,更新起来影响的层数更多。
domain知识是最后几层,更新起来影响的层数相对更小一些。
扩词表
新词的embedding是随机初始化的,是transformer最底层了。同理,影响面更大。
learing rate
不合适的learning rate会导致general能力”受损“。以及learning rate大小带来的影响,跟增训中文,一点点提高中文比例,有点异曲同工。
从刘乾的反馈来看,他们不扩词表,先找到合适的learning rate,再找到合适的比例,直接continue pretrain,loss就能稳定持续下降了
3.3 Long Context Continue Pretrain
3.3.1 continue pretrain学了什么
拿long context举例子,根据我们的一些分析,LLM本身就具有long context的能力,或者说是已经学到了文本的框架。
而之所以外推不好,其中一个猜测就是attention分布导致的。
而long context的continue pretrain某种程度上是让attention分布的调整。
https://arxiv.org/abs/2404.15574
知识的重新学习并不是大头。
而中英文,代码,法律等的continue pretrain。我相信底座模型也是有这样的知识的,他们是不是也是某种attention的调整?
continue pretrain,让底座模型对这块的知识attention更加友好一些?
当然,no free lunch,这种attention调整会带来通用domain的下降(但正如开头所说,只要别跌太狠,这些场景并不是特别care)
从这个角度来看,语言类continue pretrain比领域知识类continue pretrain的attention调整要更难,所以贸然的剧烈样本分布,学崩了也很正常。
3.3.2 long context continue pretrain
fuyao的论文就是把各种短文本拼接成长的。
https://arxiv.org/abs/2402.10171
进阶的做法,我可以做个聚合,把相似的拼接到一起。
3.3.3 问题
我们的科研团队,有一个方向就是做long context 低成本continue pretrain,benchmark效果还可以,论文都投出去了。
但一直没有拿出来讲的一个原因,就是如何评估continue pretrain的效果。
我们continue pretrain完,找了一些评估集合来评估,发现指标都不错。
但效果真的好,还是得SFT后的效果才能说了算。
Long Context的SFT是另外一种难(不是技术上的),也是我们近期要重点解决的问题。
其实中英文和domain知识的continue pretrain也一样,是否真的好,还是得SFT后的效果说了算。
4 Pretrain的展望
这次的分享,包括基于这个分享,跟不少人私底下也有一些讨论。
发现,pretrain这半年比较大的进展都是偏架构方面。例如,MOE,deepseek的MLA(这是一个非常棒的工作和尝试)。这块往往是架构和算法都很牛的人才能做好。deepseek的开源moe,也做得非常不错,应该是国内开源top了,他们的pretrain团队做的挺棒的
但算法为主的,做pretrain,往往就是洗数据了。
尴尬的点是,预训练洗数据,因为数据量大,往往都是搞各种小模型+规则,很难说明你做的事情的技术含量,只能体现你对数据的认知很好。
但随着模型参数量的增大,洗这么干净的数据合理么?模型是不是到了后面,自己就能做一些区分了?做那么多label意义真的大么?
对个人的发展的确不是那么友好,这也是真的。(所以,可以往对其,long context,多模态转移一下,这才是我们算法的主战场