标题 | Quiet-STaR: Language Models Can Teach Themselves to Think Before Speaking |
---|---|
作者 | Eric Zelikman, Georges Harik, Yijia Shao, Varuna Jayasiri, Nick Haber, Noah D. Goodman |
机构 | Stanford University, Notbad AI Inc |
邮箱 | 无 |
论文 | https://arxiv.org/abs/2403.09629 |
摘要 Abstract
推理任务对于语言模型(Language Model, LM)来说十分困难,在 "自学推理器"(STAR) 提出的方法中,有用的思维(原理)是通过在回答问题时从少数几个例子中推断出理由,并从那些能得出正确答案的例子中学习的,然而该方法不能学习推断任意文本中未陈述的理由。
我们提出了 Quiet-STaR,它是 STaR 的广义化,其中语言模型学会在每个标记处生成理由,以解释未来的文本,从而改进其预测。我们解决了一些关键难题,包括:1)生成连续语的计算成本;2)LM 最初不知道如何生成或使用内部思想;3)需要预测单个下一个标记以外的内容。为了解决这些问题,我们提出了一种标记并行采样算法,使用可学习的标记来表示思维的开始和结束,并采用了一种扩展的教师激励技术。生成的理由对建立难以预测的标记模型有不成比例的帮助,并提高了 LM 直接回答疑难问题的能力。特别是,在使用 Quiet-STaR 对网络文本语料库中的 LM 进行持续预训练后,我们发现在零 - 样本数据集 GSM8K(5.9%→10.9%)和 CommonsenseQA(36.3%→47.2%)上的改进,并观察到自然文本中的困难词组的困惑度提高。最重要的是,这些改进无需对这些任务进行微调。Quiet-STaR 标志着我们向能以更通用、更可扩展的方式学习推理的 LMs 迈出了一步。
When writing and talking, people sometimes pause to think. Although reasoning-focused works have often framed reasoning as a method of answering questions or completing agentic tasks, reasoning is implicit in almost all written text. For example, this applies to the steps not stated between the lines of a proof or to the theory of mind underlying a conversation. In the Self-Taught Reasoner (STaR, Zelikman et al. 2022), useful thinking is learned by inferring rationales from few-shot examples in question-answering and learning from those that lead to a correct answer. This is a highly constrained setting -- ideally, a language model could instead learn to infer unstated rationales in arbitrary text. We present Quiet-STaR, a generalization of STaR in which LMs learn to generate rationales at each token to explain future text, improving their predictions. We address key challenges, including 1) the computational cost of generating continuations, 2) the fact that the LM does not initially know how to generate or use internal thoughts, and 3) the need to predict beyond individual next tokens. To resolve these, we propose a tokenwise parallel sampling algorithm, using learnable tokens indicating a thought's start and end, and an extended teacher-forcing technique. Encouragingly, generated rationales disproportionately help model difficult-to-predict tokens and improve the LM's ability to directly answer difficult questions. In particular, after continued pretraining of an LM on a corpus of internet text with Quiet-STaR, we find zero-shot improvements on GSM8K (5.9%→10.9%) and CommonsenseQA (36.3%→47.2%) and observe a perplexity improvement of difficult tokens in natural text. Crucially, these improvements require no fine-tuning on these tasks. Quiet-STaR marks a step towards LMs that can learn to reason in a more general and scalable way.
Quiet-STaR
如上图,Quiet-STaR 有三个主要步骤:
并行推理生成(思考):对于输入序列 n$个token {x_i},我们并行地生成长度为t的r个推理:{},得到 {} 个候选推理。并插入训练好的 <|startofthought|> 和 <|endofthought|> token 来标!!!!记每个推理的开始和结束。 混合推理和基础预测(回答):根据每个推理所产生的隐藏状态输出,我们训练了一个浅层 MLP 混合头,通过一个权重来决定与基础语言模型预测的 logits 相比,推理产生的下一个 token 预测的 logits 应该被使用多少。这种方法可以在微调的早期缓解由于引入了理由导致的分布偏移。 优化推理生成(学习):我们优化推理生成参数(开始 / 结束标记和 LM 权重),以提高可以使得未来文本更合理的推理产生的可能性。使用 REINFORCE 技术,根据合理性对未来标记预测的影响,为合理性提供学习信号。为了减少差异,我们采用了一种教师激励技巧,在损失中不仅包括预测思想后标记的可能性,还包括预测后面标记的可能性。
并行生成
Quiet-STaR 中的一个难点就在于:如何高效地为输入序列中的每一个 token 生成推理。在长文本上单独地为每个 token 计算前向传递会消耗巨大。
如上图,为了高效地并行的处理,我们缓存每一个前向传递并拼接上一个对角注意力矩阵。此时,每一个生成的 token 都会关注与他有关的 tokens 以及自身,
残差混合头
在训练初期,思考会偏离分布。为了平滑过渡到思考状态,引入了一个可学习的插值。在思考后得到 token 的隐藏状态与原始文本 token 中,通过该参数决定思考后预测的 logits 的使用程度。
优化推理生成
优化 Start-of-Thought 和 End-of-Thought tokens
使用 <|startofthought|>
和 <|endofthought|>
作为可学习的 token 用于控制推理的生成。为了提升训练初期的表现,将 <|startofthought|>
和 <|endofthought|>
的 token 初始化为与破折号 “——” 的 token 相同。这是因为其通常在文本数据中代表一个暂停或是思考。
非近视奖励与教师激励技术
由于思考并不一定会影响到预测每个 token,因此模型的奖励应更依赖于未来的文本内容而不仅仅是下一个确切的 token。使用并行注意力掩码来计算真正的下一个 token 的对数概率,同时使用教师激励技术来假设此时选择的下一个 token 为正确的 token。
如下图,实线部分代表语言模型计算过程,虚线代表教师激励技术插入的的 token,搅拌机图标代表 MLP 混合头。
目标优化
使用 REINFORCE 强化学习的方法优化推理生成的可能性。
实验
如上表,评估 Quiet-STaR 在 CommonsenseQA 和 GSM8K 上提高语言模型的零样本推理能力的程度。在 CommonsenseQA 上,我们发现与基本语言模型相比,Quiet-STaR 的性能提高了 10.9%。这种改进随着模型推理中使用的标记数量而持续增加,这表明通过思考 token 进行更彻底的推理正在转化为更好的直接问答性能。同样,在 GSM8K 上,Quiet-STaR 比基本模型提高了 5.0%,并且性能再次随着 Quiet-STaR 训练期间生成的推理的长度而变化。作为参考,在上图中包括一个基线,对应于在同一数据集上训练相同的模型,没有思考的 tokens。我们观察到,在多条曲线中,性能似乎最终会恶化 —— 我们预计这是因为我们没有对这些下游任务进行训练,因此思维标记的角色可能会随着时间的推移而改变。我们还发现了我们的非近视奖励的好处,我们在附录 D 中对此进行了讨论。
虽然思维链提示和我们的方法之间存在天然的相似之处,但它们是正交和互补的。在零样本思维链中,用户主动提示模型 “大声” 思考,否则使用其普通生产分布;相反,Quiet-STaR 允许模型对每个 token 进行安静思考,并训练出有用的分布。我们使用无声的 Quiet-STaR 基本原理进行研究,同时生成明确的 CoT 推理。因为我们的目标是通才推理,根本不需要特定于任务的输入,所以我们使用了零样本提示(“让我们一步一步思考”),而没有上下文示例。我们的实验表明,内部原理允许模型生成更结构化和连贯的思维链。如上图所示。根据对 128 个 GSM8K 测试项目的样本进行评估,使用 Quiet-STaR 的 8 个样品 (cot-maj@8) 的多数表决准确率从 40.6% 提高到 47.7%。