Rho-1:基于选择token建模的预训练方法

科技   2024-12-15 12:43   北京  


(本文阅读时间:9分钟)


在自然语言处理领域,预训练语言模型常因大规模噪声数据而面临挑战。对此,微软亚洲研究院的研究员们提出了一种新型的基于选择 token 建模的预训练方法。该方法通过选择性语言建模(Selective Language Modeling, SLM)策略,精准筛选出对模型训练有价值的 token,有效提升了数据效率和模型性能。这一突破不仅优化了模型训练过程,也为自然语言处理技术的进一步发展提供了新思路。本篇论文在 NeurIPS 2024 上荣获最佳论文 Runner-Up 奖。



论文共同作者之一、微软 GenAI 副总裁陈伟柱在微软研究院播客中,分享了本篇论文的技术细节,欢迎收听。


论文链接:

https://arxiv.org/abs/2404.07965

项目链接:

https://github.com/microsoft/rho





传统预训练方法中的token级挑战


现有大模型基于大批量文本语料进行预训练,在各类文本生成、文本理解和文本逻辑推理等任务上表现突出。然而,预训练过程中从各种来源获取的原始语料存在大量噪声,因此科研人员经常采用一些质量过滤方法对原始语料进行过滤,使其可以用于模型预训练。例如,文档级(document-level)过滤可以去除一些干扰文档,进一步地还可以在行级(line-level)过滤单个文档中的噪声,从而得到高质量的语料,用于预训练。


图1:语料清洗示意图


在以往的方法中,过滤出来的高质量语料输入到以因果语言建模(Causal Language Modeling)方式的模型当中,计算每个 token 的损失并平均后求梯度,然后更新模型的参数。然而,当使用这种 next-token prediction 的形式对完整的句子序列进行建模时,可能忽略一些 token 级(token-level)的内容。


比如看到图2具体的案例:“The farm has 35 hens <Apr12 1:24> and 12 pigs. ##davidjl123 says totaling 47 animals.”,其中包含像“<Apr12 1:24>”这样的时间信息,以及“##davidjl123”这样的用户 id。这种 token 级的噪声在语料中较为常见,这些细粒度的噪声很难通过以前采用的文档级和行级过滤察觉。即使语料质量较高,中间仍然可能存在一些高度不确定性的 token,例如在图2的案例中,通过“12 pigs”和“35 hens”可以推测农场里一共有“47 animals”,但是在没有先验条件和上下文的前提下,让语言模型学习准确预测农场里有几只猪是困难的。


图2:Token 级噪声样例图


然而这些高度不确定的 token 和噪声 token 在因果语言建模中没有区分,以平均的权重参与到最终的模型更新中,会使语言模型感到困惑。


选择性语言建模(SLM)方法


为了进一步研究 token 级对于模型训练的影响,研究员们对语言模型预训练过程做了 token 损失(token loss)的动态分析。研究员们使用了 15B 的 OpenWebMath 语料来训练 Tinyllama-1B 模型,而且在每训练 1B 的 token 后于验证集上评估所有 token 的损失。通过获取所有检查点在验证集上的 token 损失数据,研究员们为验证集中的每个 token 拟合了损失的变化趋势,并重点关注训练初期和末期的 token 损失,以及训练前后的损失差值。基于训练前后损失差值和整体token平均的损失,研究员们将验证集中的 token 分为四类:H→H、L→H、H→L、L→L,分别代表 token 损失在训练过程中动态变化的趋势。(具体的划分依据如图3所示。)


图3:Token 类别划分依据


H→H 代表一直保持较高损失的 token,L→L 则一直保持较低的损失。L→H 代表损失上升的 token,而 H→L 是最常见的变化趋势——token 损失下降。从图4(a)中可以看到,仅26%的 token 属于 H→L 类别,大多数 token 的损失变化不大,甚至有12%的 token 的损失呈现上升趋势。当研究员们随机采样 H→H 和 L→L 类别的 token 并单独观察其损失曲线(图4(b)和图4(c))时,可以发现这些 token 在训练过程中处于反复波动的状态,可能影响模型的收敛速度。因此,研究员们认为如果有一种方法可以合理选择适合学习且更有用的 token,让其参与训练,将可以减少噪声并提升模型的数据效率。


图4:Token 级损失的动态示意图


基于此,研究员们提出了选择性语言建模(Selective Language Modeling, SLM),在保证原有输入序列的情况下,通过在损失端裁剪模型所需的 token 损失,来选择有用的 token,如图5所示。


图5:选择性语言建模示意图


如何选择有用的token?


首先,需要有一个高质量的语料库。第一步,使用传统的因果语言建模损失在高质量的语料库上训练一个参考模型(reference model)来建模高质量 token 的分布。第二步,用训练好的参考模型在离线阶段对预训练语料中的每个 token 打分,最终得分由 Token Scoring 公式计算得到。这样的打分方式不会在实际的训练过程中引入额外的时间开销。第三步,用打好分的预训练语料训练模型,对每个 token 的分数排序后,取 topk% 作为选择的 token,对应 Token Selection 的公式。最后,通过 SLM 方式训练的损失公式迭代更新模型。



图6:选择性语言建模流程图


图7显示,在相同的数据集中,通过 SLM 方式训练的模型比直接训练的模型有效提高了数据效率,加快了收敛速度。


图7:选择性语言建模的数据效率


实验结果与应用


研究员们在数学领域进行了实验。实验中,研究员们使用 14B OpenWebMath 语料,分别在 Tinyllama-1B 和 Mistral-7 上继续预训练该模型,并采用 SLM 训练方式,选择比例分别为60%和70%。从图8中的结果可以看到,使用 SLM 训练的 Rho-1 Math 1B 和 7B 模型相较于直接继续预训练,性能分别提升了16%和10%。


图8:数学领域中的 Rho-1 预训练结果


为了进一步验证预训练的结果,研究员们基于上述训练的基础模型进行了推理微调对比实验。图9显示,该模型在数学领域上取得了与 DeepSeekmath7B 相当的成绩,在数学基准上的准确率均高于50%。值得注意的是,在预训练过程中,该模型仅使用了 14B 的 OpenWebMath,远少于 DeepSeekMath 使用的 120B 数学相关语料,进一步证明了使用 SLM 训练的数据效率。


图9:Rho-1 SFT 在数学领域的结果


研究员们也在通用领域(general domain)上进行了类似的实验,采用了包含总计 80B token 的预训练语料对 Tinyllama1B 进行继续预训练。如图10的评估结果所示,Rho-1 在各项通用基准测试中平均提升了约6%。


图10:Rho-1 在通用领域上的结果


同时,本篇论文还探讨了在缺乏高质量语料作为参考的情况下,SLM 是否能够正常运作。如图11所示,可以直接将预训练好的基础模型作为参考模型进行自我参考(self-reference)迭代。在 Tinyllama 上继续预训练以验证自我参考的可行性。在图12中可以看到,仅通过一轮迭代,SLM 就可以显著提升模型性能,这在模型的自我提升方面具有重要意义。


图11:自我参考流程示意图


图12:自我参考结果


总而言之,该研究表明并非所有 token 在语言模型预训练过程中都是同等重要的。通过基于 token 级数据筛选的 SLM 建模方式能够极大提高模型的数据效率。这种 token 级的思路不仅适用于预训练,还可以应用于微调、强化学习、多模态等领域。此外,选择 token 的形式可以多种多样,要根据具体场景及需求确定不同的 token 选择方法。研究员们希望未来能够出现更多有效的 token 选择策略和重新加权策略。

arXiv每日学术速递
工作日更新学术速递!官网www.arxivdaily.com。
 最新文章