标题 | AlphaZero-Like Tree-Search can Guide Large Language Model Decoding and Training |
---|---|
作者 | Xidong Feng; Ziyu Wan; Muning Wen; Stephen Marcus McAleer; Ying Wen; Weinan Zhang; Jun Wang |
机构 | University College London; Shanghai Jiao Tong University; Carnegie Mellon University |
论文 | https://arxiv.org/pdf/2309.17179 |
摘要 / Abstract
最近的工作,如思维树(ToT, Tree-of-Thought)和 计划推理(RAP, Reasoning via Planning),旨在通过使用树搜索算法来引导多步推理,从而增强 LLMs 的推理能力。这些方法依赖于提示预训练模型作为价值函数,并专注于搜索深度较低的问题。因此,在预训练 LLM 没有足够知识作为有效价值函数或需要长期规划的领域中,这些方法将无法工作。为了解决这些限制,我们提出了一种类似 AlphaZero 的树搜索学习框架,用于 LLMs(称为 TS-LLM),系统地说明了如何使用具有学习价值函数的树搜索来引导 LLM 解码。TS-LLM 在两个关键方面有所不同。(1) 利用学习的价值函数和类似 AlphaZero 的算法,我们的方法可以普遍适用于各种任务、任何规模的语言模型以及不同搜索深度的任务。(2) 我们的方法可以在推理和训练过程中引导 LLMs,逐步改进 LLM。经验结果表明,在推理、规划、对齐和决策任务中,TS-LLM 优于现有方法,并且可以处理深度为 64 的树。
引言 / Introduction
LLMs 可以进一步通过规划算法进行引导,例如树搜索。该领域的初步工作包括采用深度 / 广度优先搜索的思维树(ToT)与基于蒙特卡洛树搜索 MCTS 的计划推理(RAP)。他们通过在扩展的树上搜索、自我评估,取得了性能提升。
然而这些方法存在明显的局限性。首先,树搜索算法中的值函数是通过提示 LLMs 获得的。因此,这些算法缺乏普适性,严重依赖于精心设计的提示和先进 LLMs 的强大能力。并且基于提示的自我评估是不可靠的。其次,ToT 和 RAP 使用 BFS/DFS 和 MCTS 进行树搜索,将它们的能力限制在相对简单和浅层的任务上。它们的最大深度仅为 10 或 7,远低于 AlphaZero 在国际象棋或围棋中达到的深度。因此,ToT 和 RAP 可能在需要大量分析深度和长期规划视野的复杂问题上遇到困难,降低了它们的可扩展性。
为了解决这些问题,我们引入了增强树搜索 LLM(TS-LLM),这是一种类似于 AlphaZero 的框架,利用树搜索来提高 LLMs 在一般自然语言任务上的性能。TS-LLM 将之前的工作扩展到类似于 AlphaZero 的深度树搜索,使用了一个基于学习的 LLM 值函数,可以在推理和训练过程中引导 LLM。与之前的工作相比,TS-LLM 具有以下两个新特性:
TS-LLM 提供了一个通用且可扩展的流水线。通用性:通过学习价值函数,TS-LLM 可以应用于各种任务和任意大小的 LLMs。我们的学习价值函数可能比基于提示的对应物更可靠,而且不需要任何精心设计的提示或先进的大规模 LLMs。我们的实验表明,TS-LLM 可以处理从 125M 到 7B 参数范围内的 LLMs,甚至在与 GPT-3.5 相比提供更好的评估。可扩展性:TS-LLM 可以进行深度树搜索,将树搜索扩展到生成 LLM 的深度达到 64。这远远超出了 ToT 的 10 和 RAP 的 7。 TS-LLM 可能作为一种新的 LLM 训练范式,超越推理解码。 通过将树搜索操作视为策略改进运算符,我们可以通过树搜索进行策略改进的迭代过程,然后通过蒸馏改进策略,通过树搜索轨迹上的地面真实训练标签改进价值函数。
通过对推理、规划、对齐和决策任务的全面实证评估,我们对 TS-LLM 中的核心设计元素进行了深入分析,探讨了不同变体的特点、优势和局限性。这展示了 TS-LLM 作为一个通用框架来引导 LLM 解码和训练的潜力。
使用树形搜索增强 LLMs
对于给定的自然语言任务,将奖励函数 定义为时间步 生成的任务性能反馈 。由于一般任务缺乏大规模和高质量的中间奖励标签,其通常为稀疏奖励设置,即前 时间步内的中间奖励都为零,只有第 步不是零。
在本文中,我们关注如何使用树形搜索算法进行优化。如上图所示,考虑两种动作空间设计:
句级动作节点:对于具有步骤 / 句级结构(例如思维链推理)的任务,将每个思想视为句级动作节点是很自然的。这也是 ToT 和 RAP 所采用的技术。对于每个非终端节点,通过采样几个可能的后续中间步骤并丢弃重复生成来扩展搜索树。 token 级动作节点:类似于离散动作空间 MDP 中的树搜索,我们可以将每个 token 视为 LLM 策略的离散动作,并且树搜索可以在 token 级别进行。对于那些中间步骤没有明确定义的任务(例如 RLHF),将输出序列分割成 token 可能是一个不错的选择。
通常,搜索空间由两个与算法无关的参数确定,即树最大宽度 和树最大深度 。在 LLM 迭代时,动作空间设计在搜索空间上都有其优势和局限性。通过将方案分割成句子,句级动作节点提供了一个相对较浅的树,简化了树搜索过程。然而,句级生成的大样本空间使得完全枚举所有可能的句子变得不可行。我们必须设置最大树宽度 来在扩展过程中对 个节点进行子采样。这种方式会导致树搜索空间和 LLM 生成空间之间的差异。对于 token 级别的动作节点,虽然它可以消除搜索空间的差异和额外的计算负担,但它大大增加了树的深度,使树搜索变得更具挑战性。
使用树形搜索引导 LLM 推理解码
树形搜索算法的一个好处是,它们可以通过简单的搜索来优化累积奖励,而无需进行任何梯度计算或更新。在本节中,我们提供了完整的流程,以说明对于一个给定的 LLM 策略,如何使用树搜索方法来引导 LLM 推断解码。
训练基于 LLM 的值函数
对于树搜索算法,如何构建可靠的价值函数 和奖励模型 是主要问题。ToT 和 RAP 通过提示先进的 LLMs(如 GPT-4 或 LLaMA-33B)获得这两个模型。为了使树搜索算法普遍适用,我们的方法构造了一个可学习的基于 LLM 的价值函数 ,该函数取决于状态 ,以及一个可学习的最终步骤奖励模型(ORM, outcome reward model),这是因为大多数任务可以被构造为稀疏奖励问题。由于我们主要处理语言任务,我们构建一个共享的价值网络和奖励模型,其结构是一个仅包含解码器 decoder 的 transformer 模型,并包含了一个 MLP 层,用于为输入 token 的每个位置输出一个标量。通常,LLM 价值函数的解码器是从原始 LLM 策略函数 的解码器进行迁移的,或者 LLM 价值函数 和策略函数 共享同一个解码器。对于一个句子级扩展的中间步骤 ,我们使用最后一个 token 的预测标量作为其值预测 。即当将完整句子 () 输入模型时,最终奖励可以在最后一个 token 处获得。
因此,我们使用语言模型 作为策略模型,使用任务训练数据集来采样生成。在训练数据中使用真实标签或给定的奖励函数,可以获得一组大小为 的采样数据 ,其中 是输入文本, 是 步的输出文本, 是真实奖励。类似于大多数强化学习算法中的 critic 训练方式,我们通过 TD- 或 MC 估计方法在每个时间步 上构建值目标 。值网络通过均方误差进行优化:
ORM 是以相同目标函数学习的。训练准确的值函数和 ORM 对于树搜索过程非常关键,因为它们提供主要的指导。
树形搜索算法
对于一个给定的可学习的值函数,下面介绍了五种类型的树形搜索算法。
基于值函数的剪枝广度优先和深度优先搜索(BFS-V/DFS-V):这两种搜索算法 ToT 中使用。其核心思想是利用值函数来修剪树以进行高效搜索,这种修剪分别发生在树的广度或深度上。BFS-V 可以被视为以累积奖励为目标的波束搜索。 蒙特卡洛树搜索(MCTS):这种方法是在 RAP 中使用的,它指的是经典的 MCTS 过程。它在终端节点上反向传播值,依赖于价值的蒙特卡洛估计,并从初始状态节点开始搜索。 具有值函数逼近的 MCTS(称为 MCTS- ):这是 AlphaZero 中使用的 MCTS 变体。从初始状态开始,我们选择状态 的节点作为根节点,并进行多次搜索模拟,包括选择、扩展和评估以及备份,叶节点的值由学习的值函数评估,将被反向传播到所有祖先节点。搜索结束后,根据根节点访问次数的指数选择动作概率,即 ,并移动到相应的下一个状态。上述迭代将重复进行直至完成。MCTS- 具有两个主要特点。首先,MCTS- 一旦采取行动就无法追溯到其先前的状态。因此,除非进行多次搜索,否则无法从初始状态重新开始搜索。其次,与 MCTS 相比,MCTS- 使用价值函数,因此可以在中间步骤中进行反向操作,而无需完成整个生成过程以获得蒙特卡洛估计。 MCTS-Rollout:结合 MCTS 和 MCTS- 的特点,我们提出了一种新的变体 MCTS-Rollout 用于树搜索。与 MCTS 类似,MCTS-Rollout 总是从初始状态节点开始。它进一步进行类似于 MCTS- 的搜索模拟,并且在中间步骤中可以使用值函数进行备份过程。它重复上述操作,直到过程找到 个完整答案或达到计算限制(例如最大令牌数量)。MCTS-Rollout 可以被视为 MCTS- 的离线版本,因此它们可能具有类似的应用范围。唯一的区别是,MCTS-Rollout 可以增加令牌消耗以获得更好的性能,因为它总是从头开始重新进行搜索。
多重搜索和搜索聚合
LLM 可以通过多次采样和聚合候选项来提高在推理任务上的表现,TS-LLM 也有潜力聚合 个由多个树搜索或单个搜索的多代产生的完整答案(设置 BFS 波束大小 )。
在进行多树搜索时,我们通常采用树内搜索设置。树内搜索在同一棵树上进行多次树搜索,因此状态空间完全相同。这种方法在计算上是高效的,因为搜索树可以多次重复使用。然而,多代之间的多样性可能会减少,因为前一次树搜索可能会影响后续的树搜索。此外,在句级动作空间中搜索空间受限,因为一旦跨越多个树搜索扩展,它们将被固定。
考虑以下三种不同的聚合方法:
多数投票:使用多数投票聚合答案: ORM-Max:选择具有最大最终奖励的答案: ORM-Vote:选择具有最大奖励总和的答案:
树搜索的额外计算负担
树搜索算法将不可避免地带来额外的计算负担,特别是在节点扩展阶段计算合法子节点及其相应值时。先前的方法学,如 ToT 和 RAP,倾向于将它们的性能与基准算法进行比较,使用相同数量的生成路径(命名为 Path@ )。这种方法忽视了树搜索过程的额外计算需求。
一个更公平的比较需要监控为节点扩展生成的令牌数量。这在操作在可比较的令牌生成条件下时,提供了算法性能的合理比较。我们在我们的实验中解决了这个问题。
实验 / Experiments
实验设置
任务设置:对于给定的 MDP,搜索空间的性质主要由两个维度来表征:深度和宽度。为了展示树搜索算法在不同搜索空间上的有效性,我们在五个任务上评估所有算法,这些任务具有不同的搜索宽度和深度,包括数学推理任务 GSM8k,数学规划任务 Game24,逻辑推理任务 PrOntoQA,使用合成 RLHF 数据的 RLHF 对齐任务,以及国际象棋残局。
基准算法:我们随后比较 ToT-GPT3.5 和 TS-LLM,以验证学习价值函数的有效性。所有树搜索算法都将进行基准测试,包括 MCTS-,MCTS-Rollout,MCTS,BFS-V 和 DFS-V。我们将这些变体与直接解码基线进行比较,包括 CoT 贪婪解码,以及具有自一致性的 CoT(记为 CoT-SC)。考虑到直接解码和树解码之间的搜索空间差距(特别是句子级动作节点),我们包括了在树的句子节点上执行 CoT-SC 的 CoT-SC-Tree 基线。
模型和训练细节:在树搜索中使用的部署策略中,我们在三个推理任务上使用 LLaMA2-7B,在 RLHF 任务和国际象棋残局中使用 GPT-2-small(125M)。所有 LLMs 将首先在训练集上进行监督微调(SFT),从而实现它们的零 - shot CoT 能力。对于值和 ORM 训练,数据是通过在训练集上对 SFT 策略的部署进行采样生成的。我们的策略 LLM 和值 LLM 是两个独立的模型,但都是从相同的基础模型进行调整的。
实验结果
如上表,学习到的值函数比基于提示的 GPT-3.5 更可靠,即使 GPT-3.5 比 LLaMA2-7B 强得多。我们对 Game 24 和 GSM8K 进行了 BFS Path@1 的比较,使用不同组合的策略和值。策略选择包括少样本 GPT-3.5 和我们的监督微调的 LLaMA2-7B。对于值,我们使用基于提示的 GPT-3.5/LLaMA2-7B(TOT)和我们学习到的值函数 LLaMA2-V。尽管少样本 GPT-3.5 策略对 LLaMA2-V 的评估是一种超出分布的策略,LLaMA2-V 在所有设置中仍然表现出对基于提示的 GPT-3.5/LLaMA2-7B 的主导性能。通过提示的 LLM 有限的自我评估能力现象与其他论文一致。这一发现大大增加了学习到的值函数的必要性。
使用可靠的学习价值函数,我们比较不同生成方法的性能。如上表,首先,在表格的上部,我们展示了 MCTS- 和 MCTS-Rollout 相对于 BFS-V(BFS-/DFS-V 和 MCTS 在 path@1 情况下退化为贪婪值树搜索)和 CoT-Greedy 的 Path@1 结果。实验结果表明,AlphaZero-like 搜索算法,MCTS-𝛼和 MCTS-Rollout 在长期规划至关重要的任务(RLHF 和国际象棋残局)中明显优于基线。在浅层树上搜索时,它们足够强大,以保持与基线相当的准确性。
为了进行公平比较,在上表底部中,我们展示了同样 token 数量 (Equal-Token) 下的结果,试图通过控制与 Path@1 TS-LLM 相似的计算消耗规模来进行比较。首先,我们提供了额外的基线,CoT-SC 具有两种聚合方法:多数投票(MAJ)和 ORM 投票(标记为 ORM,它利用了 TS-LLM 中学习到的 ORM)。在这种情况下,与 CoT-SC ORM 相比,TS-LLM 的优势大幅减少,特别是在 GSM8K 上(只有 BFS 贪婪值搜索是最佳的)。我们惊讶地发现,即使是这样简单的算法,在公平比较时也能表现出色。尽管如此,大多数树搜索算法在其余四个任务中仍然占主导地位,因为搜索空间更大(CoT-SC)。
此外,我们还比较了在搜索多条路径(由 ORM 模型聚合)时,BFS-/DFS-V 和 MCTS 的行为,其计算消耗在可比范围内。比较这 3 个变体,MCTS 在性能和计算成本方面几乎是最佳的,这表明了价值回传的重要性。与 Path@1 结果相比,MCTS- 𝛼 和 MCTS-Rollout 在浅层搜索问题(GSM8k、Game24 和 ProntoQA)中实现了可比的准确性,并在深层搜索问题(RLHF 和 Chess Endgame)中占据主导地位。这验证了在深层搜索问题下进行 Alphazero 风格的中间值回传的必要性。
如上图,我们展示了 RLHF 任务的平均 / 最大奖励以及 GSM8K、Game24 和 ProntoQA 的 3 个聚合结果中的最佳结果。我们根据路径数量和令牌消耗来衡量聚合的性能。
从图中,我们主要总结了两个结论:首先,大多数 TS-LLM 变体受益于聚合,并且与其他基线相比可以展示出较大的优势。CoT-SC 仅在 GSM8k 中以相同的令牌大小击败 TS-LLM,主要是因为其更大的搜索空间。其次,小规模问题中,树搜索算法的聚合效益不及 CoT-SC。 在 GSM8K 和 Game24 中,TS-LLM 在大量聚合数下难以改进。我们认为这是因为:(1)CoT-SC 和树搜索算法之间的搜索空间差距。树搜索算法本质上探索的句子较少,这通过比较 CoT-SC-Tree@50 和 CoT-SC@50 之间的令牌消耗得到验证。(2)树搜索算法已经利用了值函数和 ORM,再次利用 ORM 进行聚合的好处变得不那么明显。