专题解读 | 利用投机采样加速大模型推理

科技   2025-01-20 16:08   山东  

专题解读:利用投机采样加速大模型推理

简介

在大模型的应用场景中,推理速度是最为影响用户体验的要素之一。针对大模型推理速度的优化有很多工作,本文将聚焦于投机采样技术。

目前的大模型通常采用自回归的方式生成文本,即每次前向传播只输出一个单词,这是限制其推理速度的重要因素。投机采样通过引入一个参数较小的模型生成多个候选词(drafting),然后利用标准模型对候选词进行批量验证,从而减少重复计算,大幅提升推理效率。

如上图所示,每次迭代小模型都会生成一个文本序列,经过大模型验证后,绿色文本被接受,红色和蓝色的文本则被拒绝。这样每次迭代可以生成多个单词,推理速度大幅提升。此外,相比于模型压缩技术,投机采样不会损失模型的生成质量。本文将介绍两篇投机采样的经典工作,它们分别代表了两种不同的技术思路。

SpecInfer: Accelerating large language model serving with tree-based speculative inference and verification (ASPLOS24)

在更早的投机采样工作中,小模型只产生一个候选词序列供大模型进行验证。由于小模型在参数量上的劣势,候选词通常不会被全部接受,因此候选词的接受率成为了影响投机采样算法性能的重要因素。本文提出了一种优化算法,可以利用小模型生成多个候选序列,然后利用Tree Decoding进行快速验证,通过生成更多的候选词来提升增加每次可能被接受的序列长度。这种方法随后受到了广泛应用。

Drafting

SpecInfer希望在Drafting得到多个候选序列,对此有两种思路:采用多个小模型生成多个序列;或者使用单个模型,在每次生成最后的Decoding阶段留下多个单词,从而产生分支,这种方法的根据在于通常被大模型接受的token都在top-k列表里。最后得到的序列合并后会是一个树形结构,树中的每个节点代表一个token,节点的父节点即为其在序列上的前一个token。

Tree Decoding

如果要用大模型对多个序列都进行验证,会产生大量的计算开销,这样很难带来推理速度提升,为此SpecInfer提出了Tree Decoding方法,可以一次对多个序列同时进行验证。

如下图,对于drafting生成的token tree,Tree Decoding将各个节点按照拓扑序展平为一个序列,然后为其生成一个特殊的Causal Mask。在这个Mask中,每个token与它祖先节点的格子上填1(如t9-t8),其余则填0(如t9-t4)。这样在Attention计算时,每个token只与它的祖先节点,也就是在序列上更早出现的单词进行计算。通过这种方式,Tree Decoding可以将多个分支序列合并到一次计算中完成,大幅提升了验证效率。

实验效果

SpecInfer相比其他的分布式推理框架有较大的速度提升,其中,Tree Decoding相比普通的投机采样算法有大概1.2-1.5倍的速度提升。

如下图所示,通过引入更多的分支序列(增大Token tree width),投机采样过程中的平均序列接受长度普遍得到了提升。而序列接受长度直接影响了投机采样的效率,这说明了该方法的有效性。

Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads

理论上,SpecInfer及其类似工作的候选词接受率决定了它们的加速效果,然而实际效果却达不到这个程度,原因在于小模型本身的计算开销是无法被忽视的。Medusa则采用了一种更简单有效的方式,利用模型的隐藏层输出直接进行生成。

模型框架

Medusa的投机采样过程与SpecInfer类似,首先生成多个候选序列,然后用Tree Decoding进行合并验证。Medusa的核心模块是生成候选词的Medusa Head,这个模块直接使用大模型的最后一个隐藏层输出作为输入,经过一个FFN生成候选词。其中第k个head会直接生成候选序列的第k个单词。序列中第一个单词直接由大模型生成,因此可以保证每个过程会输出一个单词。

Medusa从每个head中选择top-k作为候选,将每个head的候选词按顺序组合可以得到候选序列。然后Medusa采用Tree Decoding对不同序列进行合并验证。验证所有组合会带来很大的开销。为此Medusa预先构建了如下所示的模版树,在生成token tree时可以只选择部分组合。这棵树由启发式方法生成,由于概率越大的节点产生的分支被接受的概率越大,这棵树在结构上整体左偏,因此排序更高的token会产生更多序列。

在训练时,Mesuda Head由一个交叉熵损失训练:

或者与原始模型一起训练(Medusa-2):

实验效果

本文在Vicuna-7B/13B模型上进行了实验,Medusa可以带来2倍的推理加速效果,而几乎不损失模型生成质量。而采用了联合训练的Medusa-2能带来更强的加速效果,因为更好的Head能提升候选词的命中率。

总结

本文主要介绍了两项投机采样方向的重要研究。SpecInfer提出的Tree Decoding能提升候选词的命中率,而Medusa采用了高效的方法生成候选词。总的来说,投机采样利用了某些单词能够更容易预测的特性加速推理,如何在提升候选词质量的同时保证生成过程的高效性,是当前投机采样研究的关键问题。

arXiv每日学术速递
工作日更新学术速递!官网www.arxivdaily.com。
 最新文章