知乎:姜富春
地址:https://zhuanlan.zhihu.com/p/15540962086
编辑:「深度学习自然语言处理 公众号」,已授权
1.什么是PRM?
随着OpenAI O1、O3陆续推出,包括国内的一些大模型公司也相继推出K1(moonshot),R1(deepseek), .....,让我们看到了大模型能力渐渐从一个只会"舞文弄墨的文科生"渐渐演变成了"逻辑缜密的理科生"。大家对O1的新的研发范式的各种猜想也层出不穷。也涌现出越来越多「新名词」:self-play, test-time scaling law,PRM,MCTS,RFT,......。
前面提到O1是一种新的研发范式,那为什么说是新的研发范式?从O1的技术报告(https://openai.com/index/learning-to-reason-with-llms/)和OpenAI发表的一些工作有迹可循:
从名字上,O1的全名是OpenAI O1, 请注意,并不是GPT-O1。从OpenAI的模型版本控制上可以看出,这是个全新的版本,并不是沿着GPT-3.5、GPT-4版本演进的。 输出相比之前模型有较大变化,输出是个long thought模式,从GPT-4快思考模式(首token 毫秒级别)到O1的慢思考模式(几秒后输出)。OpenAI提供第一个例子(参照给定一个解码示例,对提供的字符串解码:https://openai.com/index/learning-to-reason-with-llms/),如图1所示可以看到,模型先做了一个很长的thought(模型用了5秒钟思考,人阅读一遍5分钟,很长...)然后根据thought的推理过程回答问题,并给出正确的答案和解题过程,这与GPT-4 的推理过程完全不一样。 Open AI 也称做GPT-4 和 O1 是两种计算模式:train-time compute , test-time compute
上文交代了这么多,我也看了很多大佬对O1的技术复现的猜想,但本文并不会讲述O1的复现逻辑(个人水平确实有限)。只针对大家在复现猜想中频繁提到的PRM做些解读。大家如果对O1感兴趣可以参考文末整理的附录材料继续阅读。
回到PRM,什么是PRM?在上一篇文章中也已经介绍过(参考姜富春:OpenRLHF源码解读:理解Reward Model训练过程:https://zhuanlan.zhihu.com/p/14993645091)。这里我们再简单总结下:
PRM(Process-supervised Reward Model)是OpenAI在Let’s Verify Step by Step(https://arxiv.org/pdf/2305.20050)一文中,首次提出的概念。与之相对应的是ORM(Outcome-supervised Reward Model)。PRM和ORM都是奖励模型,两者区别:
PRM:过程奖励模型,是在生成过程中,分步骤,对每一步进行打分,是更细粒度的奖励模型。 ORM:结果奖励模型,是不管推理有多少步,对完整的生成结果进行一次打分,是一个反馈更稀疏的奖励模型。
为了更好的理解PRM,我们先了解下PRM在O1的研发范式下的作用(当然这里说的也都是一些可循的猜想或复现,毕竟没人公开O1的具体实现)。
2.PRM有什么用?
假设我们当前已经有了一个训练好的PRM,那么我们怎么能让模型有long thought能力,进而能有更好的推理能力呢?
首先我们从O1提供的例子能看出,Thought的过程是一个明显有步骤的推理过程。这个步骤包括:
任务规划step:频繁出现"First"、"Second"等 提出假设step:“Alternatively”等 结果反思step:“Hmm”、"wait"等 ......
这些step,也可以被看作是动作空间中的不同的动作(action),PRM的作用可以对这些动作打分,引导模型生成到获得收益最大的路径(也就是正确的解题步骤和正确的答案)。
对于这种多步骤的过程,参考Let’s Verify Step by Step中提供的思路看下PRM的作用:
首先,为了让模型有按步输出的能力,我们先通过一个按步骤回答的指令集,训练一个generator模型,模型不保证步骤一定是正确的,但能遵循指令按step1, step2,...格式输出。 然后,可以对上面的模型做N次采样(如best-of-N , Beam Search, lookahead Search方法等),并通过PRM对每个采样的每步推理做打分 最终,通过对每个步骤的打分拟合一个整体过程的打分,并按该整体打分选取打分最高的结果作为最终的答案。
如图2所示:
上文的思路是一种倾向优化inference阶段,提升答案质量的方式。也就是说在这种方式下,倾向于不在training阶段做更多优化,而是在infer阶段做更多采样,然后通过PRM做为Verifier筛选答案。当然有了PRM还可以做拒绝采样、做RL,继续通过Post-Training 优化generator模型效果,这个过程就类似于ORM在Post-Training阶段的作用,只不过PRM是对过程做监督,来优化目标。
我们上面假设是已经有个训好的PRM,那么如何通过数据驱动,训练一个PRM呢?这里面最关键的是样本的收集过程,我们进一步看看OpenAI怎么做。
3.PRM样本工程
在Let’s Verify Step by Step一文中,OpenAI详细描述了他们收集样本的方法,所有的样本都做了人工标注。最终他们共搜集1.2万个问题,7.5万答题过程(每个问题的回答包括多步),共80万带标记label的解题步骤。OpenAI也详细描述了他们做数据标注的过程。
OpenAI也公开发布了这部分数据集,详见github:PRM800K(https://github.com/openai/prm800k/tree/main)
数据标注过程在paper中,描述共分为两阶段:
阶段1:冷启阶段。当前就只有一个按step格式输出的generator模型,作者提出的标注方法是:通过generator采集初版标注样本,对每个步骤采集多种表述给标注人员做标注(详见下面阶段1样本标注Demo);同时对一条标注数据,如果最终没有给出任何正确的执行步骤,标注人员需要做人工的编辑,产出正确的是step样本。这个过程因为标注人员要对一个条数的的单个步骤标注不同的表述,所以这个阶段是人效比较低的,但通过标注大量的多表述样本,样本的多样性得到了较好的保证,能让奖励模型的多样的、流畅的step by step结果做更好的判别;另外这个标注过程并没有提供给标注人员"书面解题步骤",我理解目的也是提升标注人员的标注多样性和提升标注水平。由于这个过程人效较低,共标注了大概5%的样本规模。 阶段2:通过主动学习(active learning)标注难样本,提升模型对边界样本的学习能力。通过主动学习的方式迭代模型是机器学习中常用高效调优方法。文中作者通过拿当前学习到的最好的PRM,然后对采集的样本进行打分。挑选打分高且最终推理错误的样本做为下一轮标注的样本,作者共做了10轮迭代。这一阶段作者提供了更多的标注信息作为参考,包括"书面的解题步骤",PRM打分的结果等,同时降低了标注的复杂度,要求标注人员标注到第一个错误步骤就终止标注。通过这个阶段有效的提升了样本的标注效率和难样本的质量。
在数据集中,能找到两阶段标注任务的详细的需求文档,详见:instructions_phase_1.pdf(https://github.com/openai/prm800k/blob/main/prm800k/instructions/instructions_phase_1.pdf),instructions_phase_2.pdf(https://github.com/openai/prm800k/blob/main/prm800k/instructions/instructions_phase_2.pdf)。
从PRM800K数据集中采样了两条阶段1和阶段2的数据Demo,如下:
阶段1样本Demo:
只提供问题+答案, 每一步采样多个表述做标注
{
"question": {
"problem": "How many seconds are in 7.8 minutes?",
"ground_truth_answer": "468"
},
"label": {
"steps": [
Object{...},
Object{...},
Object{...},
Object{...},
{
//对每一步采样了多种表述,做标注
"completions": [
{
"text": "Right. Let's check our work. 7.8 minutes is the same as 7 minutes and 0.8 minutes.",
"rating": 0,
"flagged": false
},
{
"text": "Exactly.\n\n# Answer\n\n468",
"rating": 1,
"flagged": false
},
{
"text": "That's correct.\n\n# Answer\n\n468",
"rating": 1,
"flagged": false
},
{
"text": "Correct.\n\n# Answer\n\n468",
"rating": 1,
"flagged": false
},
{
"text": "That's correct.\n\n# Answer\n\n468",
"rating": 1,
"flagged": false
},
{
"text": "Correct.\n\n# Answer\n\n468",
"rating": 1,
"flagged": false
},
{
"text": "That's right!\n\n# Answer\n\n468",
"rating": 1,
"flagged": false
}
],
"human_completion": null,
"chosen_completion": 1
}
],
"total_time": 92405,
"finish_reason": "solution"
}
}
阶段2样本Demo:
提供问题和问题的正确解答过程(ground_truth_solution),同时将主动学习筛选出来的样本的完整步骤提供出来,在标注阶段,只标注到第一步出错的结果("rating":-1)
{
"question": {
"problem": "For how many different digits $n$ is the three-digit number $14n$ divisible by $n$?\n\nNote: $14n$ refers to a three-digit number with the unit digit of $n,$ not the product of $14$ and $n.$",
"ground_truth_solution": "We have to account for each possible value of $n$ here. First of all, we can quickly find that for $n = 1, 2, 5,$ the resulting number $14n$ must be divisible by $n$, using their respective divisibility rules.\n\nWe see that for $n = 3$, we get $143.$ Since $1 + 4 + 3 = 8,$ which is not a multiple of $3,$ we can see that $n = 3$ does not work. Moreover, if $143$ is not divisible by $3$, then $146$ and $149$ are not divisible by $3$ or any multiple of $3$, so $n = 6$ and $n = 9$ do not work.\n\nFor $n = 4$, we can see that $144$ is divisible by $4$ because $44$ is divisible by $4,$ so $n = 4$ works.\n\nFor $n = 7$, we can easily perform division and see that $147$ is divisible by $7,$ so $n = 7$ works.\n\nFor $n = 8$, we have little choice but to find that $\\dfrac{148}{8} = \\dfrac{37}{2},$ and so $n = 8$ does not work.\n\nAll in all, we have that $n$ can be $1,$ $2,$ $4,$ $5,$ or $7,$ so we have $\\boxed{5}$ possible choices for $n$ such that $14n$ is divisible by $n.$",
"ground_truth_answer": "5",
"pre_generated_steps": [
"To find the digits $n$ that make $14n$ divisible by $n,$ I need to find the values of $n$ that satisfy the equation $14n = kn,$ where $k$ is some integer.",
"This equation can be simplified by dividing both sides by $n,$ as long as $n \\neq 0.$",
"I get $14 = k,$ so $k$ must be $14.$",
"This means that $14n$ is divisible by $n$ only when $14n = 14n,$ which is always true.",
"Therefore, any nonzero digit $n$ will make $14n$ divisible by $n.$",
"There are $9$ nonzero digits, so there are $9$ different digits $n$ that satisfy the problem.",
"# Answer\n\n9"
],
"pre_generated_answer": "9",
"pre_generated_verifier_score": 0.0025003258208338673
},
"label": {
"steps": [
{
"completions": [
{
"text": "To find the digits $n$ that make $14n$ divisible by $n,$ I need to find the values of $n$ that satisfy the equation $14n = kn,$ where $k$ is some integer.",
"rating": 0,
"flagged": null
}
],
"human_completion": null,
"chosen_completion": 0
},
{
"completions": [
{
"text": "This equation can be simplified by dividing both sides by $n,$ as long as $n \\neq 0.$",
"rating": -1,
"flagged": null
}
],
"human_completion": null,
"chosen_completion": null
}
],
"total_time": 46015,
"finish_reason": "found_error"
}
}
这一节我们详细讲解了OpenAI在Let’s Verify Step by Step一文中富集PRM样本的过程。其实针对不同业务,样本富集过程是最难的,也是优化效果最关键的。本文作者提供了一个两阶段并做了多轮的机造+ 人工标注的样本富集流程,有较大的参考价值。当然这里面还有个问题,在做数学解题的场景,答案往往是规范的且唯一的,我们很容易通过简单匹配方法校验生成结果是否是正确的,从而使我们能在主动学习阶段较轻量的筛选出"难负样本"。但针对一些更泛的业务场景:比如主观问答、文本创作。不能简单判断生成的结果和目标结果是否一致,这就需要根据业务,衍生对结果做判别的环节,这里由于业务不同,不做过多展看。当然可以屡试不爽的做trick处理,拿最好的模型(GPT-4、claude)Prompt一个裁判模型,对结果做校验。
4. 总结
本文重点基于Let’s Verify Step by Step一文讨论了PRM的作用和样本富集过程。相信通过上面的介绍,你已经对PRM有了个更深刻的了解。
基于OpenRLHF源码梳理下PRM的训练过程,详见:姜富春:OpenRLHF源码解读:理解PRM训练过程(https://zhuanlan.zhihu.com/p/16027048017)
5.附录
O1猜想的一些汇总:https://github.com/hijkzzz/Awesome-LLM-Strawberry
张俊林:Reverse-o1:OpenAI o1原理逆向工程图解:https://zhuanlan.zhihu.com/p/721952915
猛猿:OpenAI o1 技术初探1:整体框架,利用Test-Time Scaling Law提升逻辑推理能力:https://zhuanlan.zhihu.com/p/773907223
Let’s Verify Step by Step:https://arxiv.org/pdf/2305.20050