o1复现的一点点心得

科技   2024-12-25 00:01   吉林  


MLNLP社区是国内外知名的机器学习与自然语言处理社区,受众覆盖国内外NLP硕博生、高校老师以及企业研究人员。
社区的愿景是促进国内外自然语言处理,机器学习学术界、产业界和广大爱好者之间的交流和进步,特别是初学者同学们的进步。
转载自 | 知乎
作者 | 皓天

恰逢o3、gemini-flash-thinking版本推出,推理能力增强的模型已经把常见的benchmark刷到了一个遥不可及的结果,比如o3在aime24上直接干到29/30,code-force也刷到超过99.99%的人类。目前,推理模型主要在code、math、arg-agi等等上面有显著提升,而常用的翻译、对话等等场景,可能也用不到这么强的推理能力。可能真正能释放模型推理能力的场景是agent,由推理能力增强模型作为指挥家,调度一群不会思考只会工作的模型,降低推理成本&旧时代模型的合理利用(毕竟,直接把前代模型扔到垃圾桶,还是比较浪费钱)。

不管怎么样,我们也需要在现有资源下,掌握一定的推理模型构建方法论,才能在其他场景包括agent、tool等等场景持续优化和提升模型解决复杂场景问题的能力。本文仅以开源数据和模型,总结了一下最近的一些外部工作,以及自己复现(实际上是蒸馏)的一些初步结果。

外部工作

o1发布后,国内陆续发布了很多类o1模型,比如deepseek-r1、kimi-math、macro-o1、qwq等等。学术界也有诸如[1,2,3]相关的工作。大概总结一下,分为几个派系:

  1. 树搜索派系,主要使用树搜索+multi-agent合成数据

  2. 蒸馏派系,主要通过各种jail-break攻破o1的思维链展示限制、爬deepseek-r1以及使用qwen-qwq刷数据蒸馏。

个人以为,蒸馏派系可以帮助我们更好的了解long-cot的训练方法和效果,而树搜索派系则帮助我们在新场景比如tool、agent场景下快速构造数据。更好的合成数据方法+long-cot的训练效果理解,可以帮助我们快速在新场景、新业务下快速开发推理能力增强的模型。

我们的复现之旅

曾几何时,笔者也是坚定的树搜索派系,从皓天:层次化树搜索皓天:LLM的快思考与慢思考路线之MCTS皓天:再探LLM-MCTS皓天:EBM-based Global Rank+MCTS for COT。从23年8月份开始探讨树搜索在math推理能力上的提升。但当我们引入反思后,self-critic其实对很多模型来说都是一个比较难的任务,从预训练到post-training,很多预料只告诉了我们正确的内容,但没告诉我们作者是如何得到这部分的正确结论的,更别提作者为了得到这个结论,中间经历了哪些失败尝试和总结。尽管预训练规模越来越大,但这部分数据的缺失会极大限制推理能力的提升。[8]则提出使用合成数据,使用模型补全语料中缺失的逻辑内容,当然,也极大提升了模型推理能力(但为啥只推出14B的模型,其实也是个问题)。

树搜索+agent还是需要调很多事情,比如 每个agent的prompt、agent的observation等等,如果之前没积累,速度上会有一定滞后,但对于tool调用的数据合成是比较关键的。(或早或晚都需要,但不影响先做无tool能力的模型)。

所以,笔者也只能先转向数据蒸馏的路线。

由于o1没展示思考过程,o1的蒸馏会有点点麻烦,但对于懂模型攻防的同学,可以使用jailbreak/prompt-hack等手段,破解思考过程输出的限制。当然,在这个时间节点,我们就直接手动爬一些deepseek-r1的文字推理题目,以及使用qwen-qwq刷一部分数据。至少先跑通sft、rl的流程后,再回过头来用树搜索+agent的数据合成方法造新场景、新任务的数据也不迟。

为此,我们基于之前开源的数据[9]尝试做了long-cot的response合成。具体方法还是老一套:基于答案验证的拒绝采样。由于思考过程很长,传统RM打分能力也不太行,暂时就没管思考过程的正确性。

基于数据[9],我们首先使用难度level打分模型对prompt打分,筛选出难度level在7以上的数据。然后,直接qwq暴力采样+答案验证,最终形成了1.3k-prompt,4.4k的prompt-response数据集。这个数据集其实量级很小,但确实很有效,即模型的数据利用效率很高。同时,我们也按照难度从高到低对所有数据刷了一遍qwq的结果,想测试数据scale后,long-cot对模型能力提升的上限。

先说几个结论:

  1. math上只需要难度较高的数据4.4k(1.3k-prompt),其实就足以让合适尺寸模型的math能力吊打对应的instruct版本。

  2. math上单独训的模型,其实在code等等场景也有一定的正向迁移(虽然幅度不够大)。

  3. 如果在instruct上面接着做long-cot,在模糊指令、指令约束比较复杂的场景,可以自动泛化出long-cot的风格、能力(比如 藏尾诗 等等)。但一些场景比如常规问答则继承了instruct的能力(如果后面加上请一步步思考,还是可以激发出 long-cot风格的输出)。

主要展示math上的效果(前期,在7b上面花了很多时间,但涨幅不及预期,扩大数据scale也没达到理想中的涨幅)以下所有评估均使用qwen25-math-evaluation工具包,评估结果可能会和vllm版本相关。

14b模型其实还是有一些收益的,但看着不够显著,尤其是aime24、omni-math上,不如一个math-7b上context扩展+long-cot数据scale的训练。

当我们切换到32b模型后,效果提升会比较显著,不管是从instruct上面直接sft还是base上sft,都有显著增长,只使用了5k-prompt-response(1.3k的prompt)。aime24直接从5/30上涨到10/30,在qwen25-32b-instruct上面long-cot,也能从5/30提升到13/30。

当然,蒸馏数据+拒绝采样,天然可以有偏序,格式正确+答案正确>语种混合+答案正确>答案错误,顺便也调了一个32b-qwq的DPO:总体上略微增长。

同时,也调了一个自己instruct+long-cot的dpo结果,均有一点点正向收益(使用的是qwq采样的偏序,没有跑自己模型采样的偏序)。

从上面可以看到,long-cot貌似数据利用效率很高,只使用1.3k-prompt,5k-qa数据,即可让模型的能力提升较大,尤其是math-hard(从MATH里面挑选level5的全部数据)、aime24、omni-math等等。

而dpo则还能进一步提升qwq的效果,但提升幅度不够看,跑更多epoch、dpo超参数搜索可能也难以再提升20%以上。

相反,7b模型上直接long-cot-sft,则模型效果下降很严重,一个严重问题是模型会不停but、wait、alternate等等,即使对这些词的频次做限制,依旧容易停不下来。如果想7b做的更好,则需要更多的long-cot数据或者continue-pretraining,提升底座模型输出long-text的稳定性。

而32b模型貌似看是一个比较合适的模型尺寸,且仅使用1.3k-math-prompt即可获得超出预期的收益。至于test-time-scaling,目前来看,passk提升显著,比如aime24在rollout-8次的时候,可以解决18-20道题。当然passk的效果越高,policy的rollout次数可以更少但获得更多的正向反馈信号,进一步提升rl的效率。

Scale正确的事情

scale正确的事情在low-level-fruit被吃掉后,其实是一个很重要的问题。合成数据算一个正确的scale,至少phi4、qwen25等技术报告,都证明合成数据对模型推理能力提升很关键,尤其是补全结论的思考过程。但能做好合适的合成数据其实也有一定门槛。

提升推理的数据scale也是一个有前景的方向,1.3k-prompt都有显著提升,增加困难prompt的数量,蒸馏更好的模型 比如gemini-flash-thinking等等,可预期获得更好的提升。且在instruct上面微调,可以天然泛化到很多场景。

RL-Scale本质上和推理数据scale类似,当基座的pass@k提升后,可以使用更少的rollout次数提升policy的RL训练效率。使用更多的困难prompt+gold-rm打分,可以放在数据scale,也可以放在rl-scale,只是数据scale短期有收益且保证拿到结果,RL-scale则与infra相关,infra不太行,RL的效率会低,不一定能在规定时间内拿到结果。(一般来说,短期没希望的事情,也不会有更多资源投入)。

如何远程监督思考过程

process-bench表明,即使一些benchmark上的指标高,但模型过程错误的比例随着问题难度增加而增加。当我们使用long-cot微调的时候,long-cot的错误其实包含几部分:

  1. 模型试错(正确错误)

  2. 模型错误的解答(错误)

可预想到数据scale的核心在思考过程的远程监督。直接看思考过程的问题比如训练RM等等,可能不是一个好主意,越长的输出,稳定识别错误和正确步骤的能力大概率都是下降的。这个时候,一个取巧办法是 让更弱的模型总结思考过程,提炼出一个解答过程。解答过程可以先答案是否正确做筛选。如果多个解答过程的答案都错误,基本上思考过程会错的离谱(即使最终答案正确)。能让学渣看懂的思考过程,才是模型需要的思考过程。

解答过程是否正确则可以使用传统RM进行打分筛选。也可以参考process-bench里面的critic-prompt,用o1-mini打分筛选和标注数据。这样,可以从解答过程的维度对思考过程做远程监督。

总结

推理模型的训练 目前看可能和基座相关,前期使用32b是一个较好的选择。7b、14b会出现停不下来的问题,导致效果骤降。

在math场景的实验可以看到,long-cot其实只需要1.3k-prompt就能达到更强的效果,再使用自己的模型+暴力passk采样更困难的问题,可以逐步迭代数据。在有ground- truth的场景其实都能使用。

思考过程长了会存在更多的过程错误问题。使用弱模型提取思考过程的正确思路解答过程,可能是一个还可以的远程监督方法。帮助我们更好的扩展推理数据的scale。

至于RL-scale,infra好的可以快速尝试,infra不太行的,参考phi4的路线,也能稳扎稳打达到不错的效果。RL-scale的方法选型就比较重要了。比如如果使用ppo,critic在long-text上的value估计效果可能会比较差,如果没有巨量的prompt,大概率需要一个预训练的critic/rm作为初始化。如果有巨量prompt,就变成一个赤裸裸的“text-game”,参考RL在游戏场景的效果,堆算力、堆有gold-rm的prompt,总能把模型推上一个新高度。

至于o3很多榜单都快失效,也给广大科研工作者提出了新的需求:做新的benchmark。其实和评估学生学业能力类似,会评估的人,才能因材施教,做出更好的模型,教出更好的学生。

参考工作

[1] O1 Replication Journey--Part 2: Surpassing O1-preview through Simple Distillation, Big Progress or Bitter Lesson?(https://scholar.google.com/citations?view_op=view_citation&hl=en&user=oIz_CYEAAAAJ&sortby=pubdate&citation_for_view=oIz_CYEAAAAJ:WGv8Og3F3KgC)
[2] O1 Replication Journey: A Strategic Progress Report -- Part 1(https://scholar.google.com/citations?view_op=view_citation&hl=en&user=oIz_CYEAAAAJ&sortby=pubdate&citation_for_view=oIz_CYEAAAAJ:qYOp8iumCsAC)
[3] Imitate, Explore, and Self-Improve: A Reproduction Report on Slow-thinking Reasoning Systems(https://arxiv.org/abs/2412.09413)
[4] 皓天:层次化树搜索(https://zhuanlan.zhihu.com/p/5882307377)
[5] 皓天:再探LLM-MCTS(https://zhuanlan.zhihu.com/p/693374530)
[6] 皓天:LLM的快思考与慢思考路线之MCTS(https://zhuanlan.zhihu.com/p/659230417)
[7] 皓天:EBM-based Global Rank+MCTS for COT(https://zhuanlan.zhihu.com/p/650438958)
[8] Phi-4 Technical ReportPhi-4 Technical Report
[9] https://huggingface.co/datasets(https://huggingface.co/datasets/yingyingzhang/metamath-qwen2-math)


技术交流群邀请函

△长按添加小助手

扫描二维码添加小助手微信

请备注:姓名-学校/公司-研究方向
(如:小张-哈工大-对话系统)
即可申请加入自然语言处理/Pytorch等技术交流群

关于我们

MLNLP 社区是由国内外机器学习与自然语言处理学者联合构建的民间学术社区,目前已经发展为国内外知名的机器学习与自然语言处理社区,旨在促进机器学习,自然语言处理学术界、产业界和广大爱好者之间的进步。
社区可以为相关从业者的深造、就业及研究等方面提供开放交流平台。欢迎大家关注和加入我们。

机器学习算法与自然语言处理
关注AI前沿技术,助力AI学者进步
 最新文章