在o1的整体框架篇中(https://zhuanlan.zhihu.com/p/773907223),我们从现有开源的论文和代码中(https://github.com/hijkzzz/Awesome-LLM-Strawberry),抽象出了o1可能的技术实现路径,如下图:
这里对于这张框架图我们不再做赘述,详情可以参见上面《框架篇》的文章链接。
我们之前说过,这是一张高度抽象的框架图,旨在说明o1官方技术报告中提到的“把更多算力花在inference阶段上,以提升模型的逻辑推理能力”的含义。而从本文开始,我们将以具体的算法去扩展这张框架图的细节。
今天我们要具体扩展的,就是框架图中的Inference部分(黄色块),从框架图可知,Inference部分一般有两个作用:
作用1:直接对inference过程进行优化,具体的优化方法例如:
PRM + some search methods。其中PRM表示我们额外训练的、用于评估“模型中间步骤”而不是“模型答案结果”的奖励模型。我们在框架篇中给过使用这种优化方法的具体例子,这里不再赘述
MCTS(Monte Carlo Tree Search)。使用蒙特卡洛树搜索的方法(AlphaGo中采用过),通过self-play的方式来找到一条最佳的“原始问题->中间步骤->答案”路径。从广义上来说,PRM + some search methods的方法其实也算是一种MCTS-style类型的搜索方法,只不过在MCTS中,我们通过“探索”步骤去估计结点的reward,而一个训练好的PRM则是直接替代了这种“探索-评估”过程。如果你对这些描述觉得抽象,那也没关系,MCTS是本文讲述的重点,我们马上会在文章中看到它的实现细节。
作用2:用于在post-training过程中筛选高质量的数据进行训练。
从对目前的一些开源工作的总结中,我们发现,在提升模型推理能力这一环节有一个核心的原则:尽量少用人工标注,多借助已有模型(base generator)本身的能力,去自动化地生产训练数据。然后再利用这些训练数据,通过sft或者强化学习等等post-training的方法,去提升模型的推理能力。
为了保证这些自动化生成的训练数据的质量,我们可以引入Inference模块,帮助我们搜索出高质量的数据。
所以,Inference模块可以看作是o1实现中的一块积木。当你理解这块积木的目的、以及一些可能的实现方法后。你就可以按需要灵活把它组装在你心目中o1的任何一个环节。在网上关于o1的资料中,我们可能经常会看见“MCTS,self-play”这样的关键词,它其实就是这块黄色积木的一种实现方式。不过笔者认为,o1走的不是纯靠优化inference的路线(即上图中的framework1),更可能走的是post-training + inference路线(即上图中的framework3,因为o1的技术报告中提过它把算力也花在了RL阶段上)。但是无论如何,了解这块积木的实现总是必要的。
在这篇文章中,我们将以微软在今年开源的rStar这个工作为例(https://github.com/zhentingqi/rStar),全面从源码出发,来详细看下MCTS技术是如何运用在nlp的逻辑推理任务上的(毕竟我们对MCTS的主要了解都来自AlphaGO,我们肯定非常好奇它要如何运作在自然语言上,特别是这个前提下它的搜索空间是什么)。阅读本文不需要任何MCTS先验知识,文中会循序渐进地做介绍。
一、为什么选择rStar
rStar的目的同样是提升模型的逻辑推理能力,但是它走的是上图中的framework1,也就是纯靠inference的搜索优化来实现目标,同时它选择的是MCTS而非PRM + search methods的方法。rStar作出这样选择的原因如下:
Base generator是个小模型(SLM, Small Language Model)。rStar针对的是小模型场景,对于小模型来说,它本身的能力就不强,所以我们不能指望小模型能借助pretrain阶段的能力去生产高质量的训练数据,也即post-training自产自消的方法在小模型上难以走通。同时,在大部分业务场景下,我们可能也没那么多训练资源。
PRM的训练是费钱的。如果非要用人工标注,那么大概率这个标注会花在PRM的训练上(参考框架篇中对openai的PRM的训练方式介绍)。对于身处贫苦环境中的我们,以及被落地okr催促的老板们,时间和金钱成本是能省则省。
正因为rStar走的是纯Inference的路线,所以更便于我们从”一块积木”的视角来理解框架图中的黄色块。同时,利好小模型的场景也更适合资源有限的我们。最后,当然是rStar的代码完全开源,方便我们一探所有的细节,少一些自己的想象(rStar的论文其实写得比较精简,少了很多细节的描述,也一定程度上造成代码不太好读)。
二、按照人的思考方式构造一棵搜索树
这里我们先不谈MCTS的任何概念,我们只看:对于某个问题,你会采用什么样的思维链来解决它?
假设我们有一个简单的问题:
user_question:
If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
为了解决它,我们可能有如下思考方式(所有的思考方式都以字母A开头,表示Action)
2.1 A1(propose a one-step-thought)
我们会做过程的拆解,每次提出一个推理step,直到生成最后的答案。我们记这种思考方式为A1。例如:
A1(propose a one-step-thought)
### Instruction:
If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
### Response:
Let's think step by step.
Step 1: Start with the number of cars that are already in the parking lot, which is 3 cars.
Step 2: Add the number of cars that arrive, which is 2 cars.
Step 3: Add the numbers together. there are 3 cars + 2 cars = 5 cars in the parking lot.
Step 4: The answer is 5.
观察上面的steps,我们会发现:
总是以Let's think step by step.开头
每个step的形式是“该step的推理文字+该step的答案”。例如step1中,在一段推理相关的文字结束后,能提取出“3”这个数字答案
最后一个step以“The answer is”开头,表示产出了原始问题的最终答案。
2.2 A2(propose the remaining thought steps)
对于一些简单的问题,我们可能并不会步步思考。我们会一次性通过一些简单的推理后直接给出答案,例如:
### Instruction:
If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
### Response:
Let's think step by step. There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is: 5.
2.3 A3 (propose next sub-question along with its answer)
有时候,我们会把原始问题拆解成很多子问题,然后回答一个个子问题,最终给出答案,例如:
Question 1: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
Question 1.1: How many cars are there in the park before?
Answer 1.1: There are 3 cars in the park before.
Question 1.2: How many cars arrive then?
Answer 1.2: 2 more cars arrive.
Question 1.3: Now we can answer the question: how many cars are in the parking lot?
Answer 1.3: There are 3 + 2 = 5 cars in the parking lot now. The answer is 5.
其中,Question1是原始问题,其余是拆解的子问题。其中,Question 1.3属于终结类型的子问题,因为回答它就等于回答了最终答案。这种拆解子问题的方式更适合用来解决困难问题,我们的例子比较简单,这里只是展现出一个形式。
2.4 A4 (Answer the sub-question again)
这种方式将和A3一起配套使用,例如,对于A3的Question1.1,你可能并不确定Answer1.1是否正确,这时你想重新再思考一次Answer1.1的答案。由于此时你只是对某一个子答案做修正,因此你可能采用A2(propose the remaining thought steps)的方式,做一些简单的推理,重新取得Answer1.1。此时相当于把Answer1.1用A2例子中的输出结果进行替代,这里不再给出具体例子。
2.5 A5(Rephrase the question/sub-question)
有时我们在做题时,通常会在大段的原始题目描述中,把关键信息提取出来,例如:condition1..., condition2等等。我们可以先通过这种方式改写原始题目/子题目,然后再做回答。这个比较好理解,同样也不再给出具体的示例。
2.6 整合:构造一颗搜索树
总结一下,目前为止,我们按照人类的思维方式,总结出了人类解决一个问题时可能采用的5种方法:
A1(propose a one-step-thought):步步推理,每一步都有一些中间答案,然后在最后一步中得到最终答案
A2(propose the remaining thought steps):一次性推理完毕,直接得出最终答案
A3 (propose next sub-question along with its answer):将原始问题拆解成若干子问题并做相关回答。最后一个子问题的答案即是最终答案(和A1有些类似,但采取的是subquestion-subanswer这种指示方式)
A4 (Answer the sub-question again):有时A3中某个子问题的回答不一定可信,我们尝试重新回答它。这时我们会采用A2的模版,重新回答这个子问题
A5(Rephrase the question/sub-question):重新复述一个原始问题/子问题。例如去掉大段文字表述信息,只把关键部分提取成condition1..., condition2之类的形式,用这个形式当作新的问题。
在代码操作中,我们会按2.1~2.5的示例,构造相应的prompt来指示模型执行不同的动作。下图给出了A1的prompt示例,更多例子大家可以参见源码中rStar/prompts部分:
当人解决问题时,可能会根据问题的难度,决定不同的解决模式,但是当我们采用模型进行搜索时,模型是很难预知问题难度的,所以我们总是希望:模型能够尽可能地把这些解决方式(Action)都探索一遍。
那么接下来,我们就配合着rStar的源码,一起来看下这棵搜索树长什么样子(这里我们不使用论文中的图,因为它缺少了太多细节,我们直接从源码出发重新绘制):
我们先看一些基本信息:
搜索树的根结点是原始问题。
方形node表示终止结点(leaf node),例如图中的cot结点(A2)。但注意,不是只有cot结点才是leaf node。例如A1中的最后一个step,A3中的最后一个子问题-子回答都可以成为leaf node。
虚线表示选择性探索(根据你的脚本配置决定),实现表示必须性探索。
接下来我们来看图中的更多细节:
我们从根结点(第0层)出发,根结点是用户的原始问题,对于根结点来说:
A1(a step),A5(rephrase)是选做的,A2(remaining steps,图中按源码的命名方式称为cot),A3(next subqs and subas)是必做的。其中,经过A5后,相当于从一个全新的用户问题出发,所以A5之后创建分支的规律和根结点一致,因此图中不再画出。现在我们观察图中的第1层,也就是根结点所有的子节点
先看第1层中的A1部分(蓝色块),它表示对原始问题并行采样若干个step1(代码中默认采样3个)。每个step1都是一个A1类型结点。并行采样的目的也是为了拓宽搜索空间。从直觉上理解,对于一个原始问题,当我们决定采取A1的思考方式解决它时,不同的人产出的step1是不同的,所以这里我们要做并行采样。
再看第1层中的A2部分(绿色块),它表示对原始问题做简单推理后,一步生成最终结果。这里同样采用了并行采样的方式(默认值为32)。因为已经生成最终答案,所以A2结点都是leaf node,它不会再往下做任何扩展。
再看第1层中的A3部分(红色块),它表示对原始问题产出next subquestion及其对应的subanswer。这里同样采用并行采样的方式(默认值为3),例如对原始问题,我们并行采样三个Question1.1和Answer1.1。
再看第1层中的A5块(灰色块),它表示对原始问题进行重述,前面已提过,这里不再赘述。
那么根结点的子节点中为什么没有A4(re-answer subquestion)呢?因为根据A4的性质,A4一定只能出现在A3之后,所以这里不会有。
接下来我们从第1层出发,以第1层为例,探索下不同类型的结点可以生成什么类型的子节点,以及最终可能的leaf node类型。只要搞清楚了第1层,其余层就可以类推了。
对于A1(a step)类型结点
它只能产生A1(选做)和A2(必做)两种类型的子节点。也就是说,基于当前step,我们必须产出一个一步到位的推理结果,选择性产出一个step by step的结果。
A1结点最后的leaf node既可以是A2(一步到位产出了最终结果),也可以是A1类型(最后一个step的结果)
A1类型结点的值包含“该step推理文本 + 该step阶段性答案”,参见2.2中的例子
对于A2(remaining thoughts)类型结点
如前文所说,它已经一步到位产出了答案,所以是leaf node,它不会再有子节点 对于A3(subqs and subas)类型结点
它可以产生A1(选做),A2(必做),A3(必做),A4(必做)类型子节点
图中第1层,我们只画出了并行采样出的第1个A3结点的子节点情况,其余并行采样的结果也是类推,因此图中没有画出,只用简单的省略号表示。(这个省略号其实也应该画在第2层,因为图的尺寸限制画偏了,特此说明)
A3结点往下延伸的leaf node,可以是A1(最后一个step结果),也可以是A2(一步到位产出最终结果),也可以是A3(最后一个subqs + subas结果,参考2.3示例)。
对于A4(re-answer subqs)类型结点(这一部分我们参考第2层)
它可以产生A1(选做),A2(必做),A3(必做)类型子节点
当我们执行A4时,你可以理解成只是重新修改了它的parent层的sub answer。
它的leaf node可以是A1(最后一个step结果),也可以是A2(一步到位产出最终结果),也可以是A3(最后一个subqs + subas结果,参考2.3示例)
总结一下,到目前为止我们已经解决了:
我们先根据人类思考问题的模式,设置搜索动作空间(Action,缩写为A)。
搜索空间中的不同动作之间可能有前-后(parent-child)的依赖条件,我们根据这些条件,决定了一个完整的搜索树要长什么样。
但是,仍有一些重要但未解的问题:
这棵树是我们站在上帝视角,(基本)穷尽所有的动作可能后构造出来的。那么对于模型,它应该怎么按我们的想法构造出这棵搜索树呢?
有了这棵搜索树后,我们要如何从根结点(user question)开始,选择一个最佳的推理路径并产出最后的答案呢?
为了解决这两个问题,现在我们可以请出MCTS这个算法了。
三、使用MCTS搜索最佳推理路径
3.1 使用rollout构造搜索树
对于模型来说,现在它将从原始问题出发,构造一棵搜索树。我们先来看从根结点出发,模型构造搜索树的过程:
对于根结点来说:
执行select步骤。选中根结点(我们马上就来看select更多细节,目前为止我们只用关注这一步select到了根结点)
执行expand步骤。按照第二节中我们说的各结点间的依赖规则,为被选中的结点(这里是根结点)创建所有可能的子节点。为了绘图简便,这里我们略去了2.6节中所述的“并行采样”的过程,但实操中依然是并行采样的!
执行simulate步骤。随机选择一个子节点,重复执行“expand-随机”步骤,直到遇到leaf node或者达到设定的最大搜索深度为止。注意,只有两种类型的node可以成为leaf node(这和第二节中我们列的leaf node的理想情况有些许不同)。
Terminal A3 node:如果一个subquestion结点是最后一个子问题(“最后一个的含义”是,子问题中包含原始问题,或者子问题以“Now we can answer the question"开头,参见2.3示例。能做到这一点是因为我们通过相关prompt来指示模型生成结果 )
Terminal A2 node:这个node本身就是一步推理产出最终结果,前面已经说过,这里不赘述
执行backprop步骤。这一步我们将计算leaf node的reward,同时将本次搜索路径上所有node.reward += leaf_node.reward,node.freq += 1,其中freq表示node被访问的次数。那么如何计算leaf node的reward呢?
Terminal A3 node reward:对于A3类型的leaf node,我们对这最后一个子问题,并行采样若干个子回答。假设我们采样n个子回答,这些回答中指向答案x,y,z的条数分别是a,b,c(n = a+b+c),那么x答案的占比就是a/n,以此类推,我们选择占比最大的那个答案作为最终答案,并将这个占比作为reward。
Terminal A2 node reward:对于A2类型的leaf node,我们则直接在它的所有并行采样结果中计算答案占比,计算方式同上。
这样一轮select + expand + simulate + backprop的步骤,就称为1次rollout。不难发现,在1次rollout过后,我们构造出了一部分搜索树(这里我们先只谈构造,不谈搜索,大家不要着急)
接下来我们执行第2轮rollout,继续构造我们的搜索树(这里不再画图了,我们直接从1st rollout的图例中想象一下):
执行第2轮rollout的select步骤。第2轮rollout将从第1轮backprop后构造的那棵搜索树出发。同样从根结点开始向下选择,我们先走到第1层,发现有3个子节点都没被探索过,这时我们随机选择一个子节点,例如图中第1层的A5(rephrase),这个子节点将被用作expand。到这里,我们再深度总结一下select步骤要做的事情:
每次都从根结点出发,向下逐层探索(explore),直到找到一个未被探索过的结点为止。
如果从根结点出发,发现某一层(比如第1层)所有的结点都被探索过了。那么我们就计算每个结点的UCT值(在3.2节中会细说,这个UCT值可以理解成用于计算一个结点的探索价值,它由结点的reward、freq和用于控制探索权重的超参C决定)。我们选择UCT值最大的结点,向下层继续搜寻,以此类推。
所以,总结来看,select步骤的目的就是尽可能找到一条未被探索、或者具有最高探索价值的路径。以便后续沿着它往下扩展,生成更好的搜索树。
执行第2轮的expand、simulate、backprop步骤。道理同上,不再赘述。
这里额外再提一句,生成搜索树的每一层时,我们都需要用前面所有层的推理步骤作为上文,传递给模型做生成,大家可以自行阅读源码找到构造上文的更多细节,这里不再额外介绍。
好,到这里为止我们已经理清单轮rollout的概念了,以此循环往复,在执行若干轮rollouts(代码默认值为16)后,我们就有一棵相对完整的搜索树了,接下来我们就可以基于这棵树去找到一条最佳的推理路径了。但是在介绍具体的搜索方法之前,让我们再来看看,如何计算一个结点的UCT值(UCT值越大,该结点被探索的价值越大)。
3.2 计算结点的UCT值
一个结点的UCT值计算方式如下:
Q:截止到本轮rollout为止,该结点的累积reward N:截止到本轮rollout为止,该结点的累积被访问的次数 N_parent:截止到本轮rollout为止,该结点的父结点累积被访问的次数 c:探索权重,c值越大,更侧重探索(explore)而不是利用(exploit),有点抽象,不要紧,我们马上细说
什么样的结点更具被访问的价值呢?从直觉上说,平均reward越大的结点,表现越好,应该更具访问价值。这就是Q/N在做的事情。而这一部分也被称为利用(exploit),也即我们直接利用当前的结点价值数据做决策。
但是,如果一个结点被访问的次数比较少(比如它的父结点被访问了几百次,它才被访问几次而已),这说明这个结点所在的路径可能有更多的“宝藏”还没被我们发现,因此我们也应该给这些结点更多的机会。这就是c*sqrt(N_parent/N)在做的事情。而这一部分也被称为探索(explore),也即我们给访问次数较少的路径更大的被探索的机会。
而人为设置的探索权重c,就起到控制explore和exploit程度的作用。一般来说,我们会对c采用一些"退火策略“即:
对于初期的那些rollouts,我们用较大的c,侧重探索
对于后期的那些rollouts,我们用较小的c,侧重利用。毕竟这时对于这棵树,我们已经有很多先验的价值评估了。
理解这一点,阅读代码中的相关部分就不难啦。
3.3 搜索最佳路径
(1)直接从树中搜索
在若干轮rollouts后,我们终于有了一棵相对完整的搜索树了,那么现在对于一个原始问题(根结点),我们该选择一条最佳的路径,帮助我们找到合理的推理过程和答案呢?
首先,对于这棵搜索树,我们找到所有有效的solution nodes。满足以下任意条件的node是solution node:
node是terminal A3 node(具体定义我们在3.1中讲过,不赘述)
node是A2 node(图中cot,天然就是terminal)
node是terminal A1 node(定义是最后一个step中含“The answer is...”字符串,再复习一遍,之所以出现这种模式是因为受我们的prompt控制)。由此我们可以发现,solution nodes = leaf nodes + terminal A1 nodes。
这些被找出的solution nodes及其相关的路径就是我们所有的备选项。我们想从这些备选项中找到1条最好的,所以接下来我们就来研究怎么量化“好不好”的问题。
每一个solution node下都涵盖一个最终答案。我们以这个答案为key对所有solution node做group by,即含有相同答案的solution node为一组。这时我们会得到一个字典,形式如{ans1: [solution nodes], ans2: [solution nodes], ...}
接下来我们同样通过前文所说的“占比计算法”,统计每个答案的投票得分。例如一共有n个solution nodes,答案ans1, ans2, ans3下的solution nodes数量分别为a,b,c(n = a+b+c),那么ans1的投票得分为a/n,以此类推。
计算每个solution node的prior weights,用于衡量其所在路径的整体质量,我们展开细说:
对于一个solution node,我们从它的父结点开始往上逐层遍历
在每一层中,我们将“遍历到的结点 + 它之前所有的祖先结点”产出的推理结果拼在一起,再使用一个A2(propose remaining thoughts)的prompt,基于这个拼接结果,并行采样出若干个一步到位的答案。
我们计算这若干个一步到位的答案中,有多少个solution node的答案一致,我们记这个比例为depth_score,也就是每一层推理上的得分。
那么最终一个solution node的prior weight = prod(depth_score),也就是它的路径上每一层depth score的乘积。
这个prior weight直觉上的理解就是,如果一条路径上,每一步推理过程都稳定地指向solution node的答案,说明整个推理过程是高度自洽的,这时我们就给这个solution node一个比较高的prior weight。这样就可以避免“答案是对的,但是过程可能是懵的”的情况。由于论文中没有谈及这块,且代码写得比较难读,所以可能很多朋友在读源码时会困惑在这里,这边特此说明下。
如果按照之前构造搜索树时simulate步骤找到最终答案和计算reward的方法,我们只需要选择投票得分最高的答案即可。
但是,有时候一个答案的投票得分虽然高,但是整体的推理过程质量却不一定好。那么什么样的推理过程才算好呢?我们可以做一个简单的假设:
现在,对于一个solution node(及其相关路径),我们有一个用于评估其答案的投票得分,和一个用于评估其推理过程质量的prior weight,那么这个solution node的最终得分就是两者的乘积。那么此时,我们同样有2种办法可以选择出一个最佳solution node:
方法1: 选择得分最高的那个solution node及其path即可(目前代码的默认做法)。
方法2:对于一个solution node,计算一个prob值,这个prob值 = 它的score/所有solution nodes的score之和。然后这个solution node将以概率prob被随机选中成为最佳solution node(代码中写了但没有用到)
(2)使用discriminator
在3.3(1)中,我们讲解了直接从构造好的搜索树中选择最佳路径的方法。但是在rStar中,还提供了另一种巧思:借助一个discriminator。也就是我们构造的搜索树相当于一个generator,我们使用discriminator从generator的结果中找到最可信的那个,这和我们熟知的GAN非常相似。其中generator和discriminator都是小模型,但是不同的小模型。
我们来看详细的过程:
在构造好的搜索树中,我们取出每一次rollout找到的leaf node,以及最后一次rollout找到的solution node。这些node将成为我们的备选项。我们将使用discriminator从这些备选项中找到最好的。 对于每一条备选路径,我们把整条路径上的“user question -> 推理步骤 -> 答案”拼接起来,形成一个完整的文本。 对于这个文本,我们随机选择一个推理步骤,然后mask掉这个推理步骤之后的所有结果(包括答案),如下图所示,Question + SLM1是我们的一条备选路径。Masked Solution SLM2则是我们把SLM1中的一些结果mask掉后,输入给我们的discriminator(SLM2)
我们把mask掉的文本喂给discriminator(SLM2)。如果SLM2生成的答案和SLM1的一致,则我们认为这个推理路径是可信的,因为两个不同的模型都基于同样的推理文本给出了相同的结果(consistent,图中绿色示例)。否则就是不可信的(图中红色示例) 我们使用这种一致性,从备选路径中筛选出一波generator和discriminator一致的结果。然后再用之前所说的计算各种得分的方式,找到最可信的1条路径。更多的细节,大家可以自行阅读源码。
四、总结
在本文中,我们以rStar为例,从代码级别的角度,给出了o1(可能的)实现框架中Inference这块积木的一个实现方法。在写这篇文章时,我本来想放一些源码和注释的,但是考虑到它在公众号里太占篇幅,可读性不高,所以没有放出来。但是源码中最精华的部分已经在前面的讲解中了,可以大大降低大家读源码的难度。后续我会把带注释的源码解读更新在我的知乎上(https://www.zhihu.com/people/lemonround),大家有兴趣可以来看~
有了对MCTS如何运用在nlp任务上的一些初步理解,接下来我们就可以按自己的兴趣,广泛探索这块黄色积木的各种实现方式啦(其实本质上都做得差不多)。在后面的系列中,我们将继续对框架做拆解,加入更多的积木。