本文介绍大模型推理阶段的序列并行,相对于训练,大模型推理的序列并行要复杂很多,具体而言不同的大模型推理阶段,需要的序列并行策略不一样。
- 1.Prefill阶段:Prefill阶段和训练的前向比较类似,可以使用交换kv的RingAttention;
- 2.Decode阶段:Decode阶段q的token数量为1,k和v的数量很大,此时就不适合RingAttention那种交换k和v的方式,该阶段适合交换q的分布式Attention;
- 3.Prefix-Cache阶段:Prefix-Cache场景,q数量是N,k和v的数量是M,此时采用交换q还是交换kv不确定,具体需要根据通信和计算负载,但一般情况交换q比较适合。
本文的行文思路如下(如不阅读本文,也可以按照本文的流程,自行推导出大模型推理阶段的序列并行方案):
分块attention
分块Attention主要的场景是Attention无法一次计算完,如原来的q k v,分成了q0、q1,k0、k1,v0、v1,整体分为2次计算。
计算q0、k0、v0,得到部分结果attn0;
计算q1、k1、v1,得到部分结果attn1;
那么,有2个问题需要我们去思考:
attn0和attn1是的结果该如何合并(分块attention结果合并);
attn0和attn1的计算是否可以多卡并行进行(分布式attention交换q k v)。
Flashattention
欢迎加入自动驾驶实战群
FlashAttention算是比较早的q、k、v分块的,如上图所示Q1和[K1, K2, K3][V1、V2、V3]计算得到[O1_1、O1_2、O1_3],关于FlashAttention里的公式详细推导详见:方佳瑞:FlashAttention算法之美:极简推导版(https://zhuanlan.zhihu.com/p/4264163756),相对一般的FlashAttention解析,这偏文章分析了lse在公式里的使用。
Hydragen
这偏文章专注于分析Append场景下的分块attention的计算。
上图的例子也是分块attention,分为2块,左边的是计算推到部分,右边的是最终结论部分。
上图是根据公示的Pytorch的分块attention实现,公式对应的代码,已经用对应数据和颜色标好。
FlashInfer实现
与上述Hydragen的分块attention类似的是,FlashInfer也提供了分块attention的实现,我们从接口看,输入输出基本是一样的。
总结:分块attention的计算已经是非常成熟和流行的方式了,该attention合并方式可以叫merge_attention。
Prefill序列并行
当引入序列并行之后,每张卡只有部分Q、K、V。如上图左边所示,GPU1有Q1、K1、V1。右边上版本部分是需要计算的QKV。这里引入了一个问题Qi和Kj、Vj(其中i != j)是需要计算的,但是他们在不同的卡上,这里就引入了分布式attention的另一个问题,就是需要交换数据(q k v)。有2种交换选择,分别是交换q和交换kv,注意,只需要选择一种方式就可以。RingAttention采用的是交换kv。
RingAttention还涉及到一个负载均衡的问题,这里可以采用如下的数据拆分规则。我们可以看到如上图所示,序列并行的维度为4,此时将数据拆分为8份,rank0拿到的数据是chunk0和chunk7,rank1拿到的数据是chunk6和chunk2。
结合负载均衡和数据发送流程图如下:
图中虚线表示已经计算过的矩阵,粗实现表示当前正在计算的矩阵,由图可以看出,每个step,所有的机器,计算量都是一样的,达到了完全的负载均衡。
kv cache的缓存,与训练阶段RingAttention的计算不一样,推理阶段我们还需要考虑kv cache的保存,这个和decode的策略有关,一种方式就是和计算保持一样的策略,如rank0保存k0和v0、rank1保存k1和v1 。
Decode序列并行
书接上回,当我们Prefill阶段之后用了RingAttention,之后2张卡各自保存了部分的kv cache,如上图蓝色的是rank0保存的kv cache、绿色的是rank1保存的kv cache。q只有一个,但是q需要分别和rank0以及rank1的kv cache计算。这里其中一个方法是继续重新交换kv,完全按照RingAttention方式继续做一遍,但是通信量太大,且计算量较小,很难通信和计算重叠。
于是我们想一下,是否可以通过交换q的方式进行计算呢?答案是可以的。如果按照q交换,每张卡是全量的q和部分的kv,那么得出的out也是部分的,这其实还是一个分块attention问题,只不过这次分块attention的输入也在不同的卡上了。
如上图公式所示:右下角是分块attention得公式结论。单卡内原来的分块attention计算出的out或者lse可以用 + ,但是kv在不同卡上,也就意味着lse和out也在不同的卡上,那么 + 天然的可以转换成Allreduce。
上图是代码部分, + 改成对应的allreduce即可。q的通信也有讲究,Tree-Attention采用直接通信AllGather Q的方式,而meta的论文采用Ring Q的方式,Ring Q最终的Decode阶段attention如下图所示。
Prefiex-Cache序列并行
当遇到Prefix-Cache场景时,计算负载如上图所示,kv cache还是分布式存储的。但是q的计算已经变得和prefill以及decode不一样了。如上图所示。此时可以选择类似RingAttention方式或者类似上述Tree-decode方式。具体得看计算和通信的负载。不同的算法如下图所示。
进一步思考
分离式架构结合思考的
笔者一开始对分离式架构核心观点之一就是并行策略不一样,我们可以看到序列并行就有明显的差异,可以prefill阶段用RingAttention,而Decode阶段,其实选择很多,kimi的报告里decode实际上采用的是一种数据并行的策略。
和网络拓扑的结合思考
如上图所示,其实Tree-decode这篇文章的核心思想是,当机器数量很多,Ring的效率很差,所以它相通过交换query来提高机器数量很多时候的扩展比,那为什么将交换q换成交换kv能在机器数量多的时候有用呢?
Ring KV采用的是Ring的方式,本质上是Ring Allgather的方法,这种方式在单机多卡效率很高,但是跨机无法利用所有的带宽;
AllReduce可以在多机下利用带宽,机器内用Ring,机器之间用Double Tree,能充分的利用带宽;
相对于AllGather算子,AllReduce算子在使用tree的时候,由于进行了reduce,所以数据量减少了,此时tree的集合通信不仅能充分利用带宽,而且能利用reduce降低tree之间的通信。
参考
Hydragen: High-Throughput LLM Inference with Shared Prefixes
Tree Attention: Topology-Aware Decoding for Long-Context Attention on GPU Clusters
Context Parallelism for Scalable Million-Token Inference https://flashinfer.ai/2024/02/02/introduce-flashinfer.html
https://
最后别忘了,帮忙点“在看”。
您的点赞,在看,是我创作的动力。
AiFighing是全网第一且唯一以代码、项目的形式讲解自动驾驶感知方向的关键技术。
长按扫描下面二维码,加入知识星球。