知乎:姜富春(已授权)
链接:https://zhuanlan.zhihu.com/p/16730036197
编辑:「深度学习自然语言处理」公众号
引言
deepseek最近比较出圈,本人也一直关注deepseek发布的一些技术报告。在模型训练、推理性能和计算成本上一直能给大家惊喜。读了deepseek的技术报告,我个人有两个比较强的感受。第一:deepseek在模型细节上扣的比较极致,魔改了一些模型框架(比如模型优化方面: MLA, GRPO,MTP);第二:工程能力上确实比较强,对于主流的一些框架和技术点能敏捷地整合到自己的系统内(比如:在Infra方面,能看到deepspeed, Megatron,DistServer、vLLM等框架的核心技术点)。后面准备用几篇笔记学习和整理下deepseek的技术。
本文重点讲解下MLA(Multi-Head Latent Attention)
注:我在学习的过程中,通常会有些知识盲点,或掌握不精确的地方,我会递归学习一些扩展的脉络。本文也是沿着一些必要的背景知识,逐层解读下MLH的提出背景、要解决的问题和最终的效果。
MLA主要通过优化KV-cache来减少显存占用,从而提升推理性能。直接抛出这个结论可能不太好理解。首先我们来看下,对于生成模型,一个完整的推理阶段是什么样的,推理性能上有什么问题。
1. LLM模型推理过程
LLM推理分为两个阶段:prefill阶段和 decode阶段
prefill阶段:是模型对全部的Prompt tokens一次性并行计算,最终会生成第一个输出token decode阶段:每次生成一个token,直到生成EOS(end-of-sequence)token,产出最终的response
在推理过程中,由于模型堆叠了多层transformer,所以核心的计算消耗在Transformer内部,包括MHA,FFN等操作,其中MHA要计算Q,K ,V 矩阵,来做多头注意力的计算。
在LLM生成过程中,是一个基于前向序token列预测下一个token的过程,序列中的token(无论是prefill阶段,还是decode阶段)只与它前面的token交互来计算attention,我们也称这种Attention为Causal Attention。矩阵计算上通过一个下三角的Causal Attention Mask来实现token交互只感知前向序列。如图1所示,展现的Transformer内部的细节:
我们以一个序列的位置的token为例,计算一层Tansformer的attention过程,如列下公式所示:
公式中的符号:表示计算序列中第个token;中的两个下标,前一个表示token位置,后一个表示对应的Head下标。
从公式(7)可以看到,在计算Attention时,位置的只与位置前的 做计算,所以我们有如下两个结论:
计算前面的并不受后面token的影响。 后面计算位置的Attention,要使用前序的1t 位置的的值是始终不变的。
所以为了加速训练和推理的效率,在token-by-token生成过程中,避免重复计算前序的。研究者们提出把前序计算好的缓存起来,这也就是目前主流的KV-cache的机制。KV-cache本质是通过空间换时间的方法。我们知道当前LLM size都比较大,GPU的显存空间也是比较宝贵的,通过显存来保存KV-cache势必会带来访存的瓶颈。换句话说,如果不用KV-cache模型直接计算(重复计算前序),是个计算密集型任务;增加了KV-cache,现在不是通过计算得到,而是从「存储介质」里读出来,GPT内核与存储介质之间要频繁读写,这样就变成了一个访存密集型任务。所以使用了KV-cache的机制,解决的重复计算的问题,但访存的速率也就直接影响到训练和推理的速度。
接下来我们再详细看看对于一个典型的推理架构有几级访存速率,模型推理过程中又有哪些数据要做存储下来,应该如何分配存储。
2. LLM推理阶段显存使用情况
2.1 访存速率分级
为了直观理解访存的速率,我们以一个分布式推理架构为例。
比如2台机器,每台机器有8张A100, 那么在这样一个系统内,卡内,单机卡间,机器之间的数据访问效率如图3所示。
注:我们的例子中,只描述了一种访存介质HBM (也就是我们常说的显卡的显存),我们知道通常GPU的存储介质除了显存,还有SRAM和DRAM。SRAM也被成为片上存储,是GPU计算单元上即时访问更快的存储,所有的计算都要先调度到片上存储SRAM才能做计算,一般只有几十M大小,带宽可达到20T/s左右,SRAM是跟计算单元强绑定的,推理阶段一般不考虑将SRAM作为存储单元使用。而DRAM是我们常说的CPU的内存,由于访问速率较慢,推理阶段一般也不考虑使用。所以我们讨论的推理存储介质,一般就指的是HBM(显存)
由上图的访存带宽可知,卡内的带宽是单机卡间的带宽的3倍,是跨机带宽的20倍,所以我们对于存储的数据应该优先放到卡内,其次单机内,最后可能才考虑跨机存储。
接下来我们再看下,推理过程中,有哪些数据要存储到显存上。
2.2. 模型推理阶段显存分配
下面我画了一张图,如图4所示,推理阶段主要有三部分数据会放到显存里。
KV Cache :如上一节所述,前序token序列计算的结果,会随着后面tokent推理过程逐步存到显存里。存储的量随着Batch,Sequence_len长度动态变化 模型参数:包括Transformer、Embedding等模型参数会存到显存里。模型大小固定后,这个存储空间是固定的。 运行时中间数据: 推理过程中产出的一些中间数据会临时存到显存,即用即释放,一般占用空间比较小
由上述可知,推理阶段主要存储消耗是两部分:模型参数和 KV Cache。那么模型参数占多少,KV Cache又占多少?
首先我们先以一个token的计算过程为例,看下一个token计算要存储多少KV?为了方便理解,我们以Qwen-72B模型为例,模型配置详见:Qwen-72B-Chat(https://huggingface.co/Qwen/Qwen-72B-Chat/blob/main/config.json)。
模型共80层,每层有64个Head,每个Head的向量维度是128,,,
注:这里先不考虑qwen 72B GQA的设置(实际KV做了压缩处理),只考虑朴素的MHA的模型结构(假设未做任何处理),GQA后面再详细讨论。
如下图5所示,计算一个token,每个Transformer层的每个Head都要存储一对。
所以针对一个token,缓存的数据总量:其中公式里的表示1个和 1个。一个token就要缓存10240个,这个数是不是有点意料之外! 这么多占了多少存储呢? 我们假设模型推理阶段是半精度(bf16)参数,每个参数占2Byte。最终一个token的存储占用,如公式(2)计算所示:我们现在知道了一个Token计算后需要缓存的数量和存储量。那么对于一个实际的推理场景,还要考虑批量Batch(B)和 序列长度Sequence_len(S) 两个维度,来确认整体KV Cache的存储消耗。这两个维度通常是可以动态变化的。我们看看下面两个场景:
场景1:单条短文本场景
Batch和序列设置: B = 1 , S = 2048。此时 cache总量:
场景2:并发长文本场景
Batch和序列设置: B = 32 , S = 4096。此时 cache总量:
除了消耗存储空间,我们知道模型参数也要占用存储,推理阶段模型参数占用的存储空间是固定的,计算也比较简单。假设模型参数量为: ,以bf16 半精度做推理,则参数量为(Byte)。还是以qwen-72B为例,参数占用存储空间:
我们再结合上面两个场景,看看显存的整体分配:
场景1: 模型存储,kv存储,模型的参数储存占主导,使用80G的A100, 至少需要2张卡做推理。 场景2:模型存储,kv存储,KV Cache储存占主导,使用80G的A100, 至少需要7张卡做推理。
这里还要多啰嗦几句,推理阶段根据离线、在线的业务场景,到底组多大的Batch,其实是一个Balance的过程,Batch选择比较小,虽然并发度不高,但可能单卡就能装下完整模型参数和KV Cache,这时候卡内带宽会比较高,性能可能依然出众,可以考虑适当增加Batch把单卡显存用满,进一步提升性能。但当Batch再增大,超出单卡范围、甚至超出单机范围,此时并发会比较大,但跨卡或跨机访存性能会降低,导致访存成为瓶颈,GPU计算资源使用效率不高,可能实际导致整体推理性能不高。所以单从推理Batch设置角度来看,要实测找到性能最佳的平衡点。
当前LLM都比较大,而访存的容量和访存速率有分级的特点。所以推理过程中,减少跨卡、卡机的访存读写是优化推理性能的一个有效路径。一方面单次读写的数据越少,整体速度会越快;另一方面整体显存占用越少,就能尽量把数据放到单卡或单机上,能使用更高的带宽读写数据。
本文要学习的MLA就是通过减少KV Cache来压缩显存占用,从而优化推理速度。我们在展开了解MLA之前,先看看当前有哪些优化KV Cache的方法。
3. 减小KV cache的方法
3.1. KV Cache 优化方法汇总
业界针对KV Cache的优化,衍生出很多方法,这里我根据自己的积累,稍微总结下,只简单描述优化的思路,不过多展开。
方法主要有四类:
共享KV:多个Head共享使用1组KV,将原来每个Head一个KV,变成1组Head一个KV,来压缩KV的存储。代表方法:GQA,MQA等 窗口KV:针对长序列控制一个计算KV的窗口,KV cache只保存窗口内的结果(窗口长度远小于序列长度),超出窗口的KV会被丢弃,通过这种方法能减少KV的存储,当然也会损失一定的长文推理效果。代表方法:Longformer等 量化压缩:基于量化的方法,通过更低的Bit位来保存KV,将单KV结果进一步压缩,代表方法:INT8等 计算优化:通过优化计算过程,减少访存换入换出的次数,让更多计算在片上存储SRAM进行,以提升推理性能,代表方法:flashAttention等
本文要讨论的MLA是共享KV分支下的一种优化方法,下面我们先展开看看共享KV方法有哪些,这些方法也是MLA拿来对比的方法。
3.2. 共享KV优化显存方法
共享KV主要有两种方法,MQA和GQA都是Google提出的,详见: MQA(2019),GQA(2023),如图6所示。
3.2.1. MQA(Multi-Query Attention)
MQA方法比较简单,详见上图6最右侧的图,每一层的所有Head,共享同一个来计算Attention。相对于MHA的单个Token需要保存的KV数减少到了个,即每一层共享使用一个向量和一个向量
3.2.2. GQA(Group-Query Attention)
GQA是平衡了MQA和MHA的一种折中的方法,不是每个Head一个KV,也不是所有Head共享一个KV,而是对所有Head分组,比如分组数为,那么每组:个Head 共享一个KV。当时,GQA就等价于MQA,当时,GQA就等价于MHA。
为了方便自己更清晰的理解GQA和MQA ,我还是以一个Token计算KV过程(如图5),画了一些相对细节展开的图,把所有层都画出来,并且加了一些注释。如图7所示:
我们再总结下单token计算下,几种方法KV Cache的存储量(模型层数:,每层Head数量:)
MHA共缓存 个 MQA共缓存 个 GQA共缓存 个 ,是分组数,,一般取值能被整除
本文要讲的MLA也是一种优化共享KV优化的变体,下面我们看看MLA的原理和细节
4. MLA
4.1. MLA KV优化速览
我们先走马观花看看MLA的计算方式和与MQA、GQA的压缩KV的效果对比。
首先我们看看MLA计算Attention的完整公式,如下图8所示
在论文中提到,每个Transformer层,只缓存了上述公式蓝框的向量: 和 ,这两个向量的大小分别为:
: 维度为
:维度为
对比MQA(每层有一个维度的和一个维度的,共个元素),MLA相当于增加了2.25倍的存储,但DeepSeek描述自己的方法不仅比MQA强,而且比非共享KV的原始MHA也要强,后面4.4节我们在展开讨论。
MLA号称又快又省又强大,下一节我们逐步看看具体的实现。
4.2. MLA原理解读
下面我们参照图8的公式看看MHA的计算过程,首先对图中公式的变量做如下解释说明:
:MLA低秩压缩的维度,论文中取值: :是单个head的向量维度 :是每层head的数量 :隐层维度, 是低秩变换矩阵
1. 先看下KV的计算过程
首先公式(41)对输入做一个低秩压缩,将维的输入经过变换后压缩成维的。DeepSeek-V3中, 然后通过公式(42)和公式(45)两个变换矩阵(,),将KV的维度扩展回,也就是每个Head有一个单独的(跟MHA的KV数量一致)
注:经过上述的变换,非常类似LoRA做低参数微调的逻辑。通过两个低秩矩阵先做压缩、再做扩展,最终能降低参数的数量。但MLA本质是要做到减少KV-cache的存储。LoRA强调的是参数量的减少,类似MLA这操作确实也减少了参数量,按DeepSeek-V3的参数配置,两个低秩矩阵参数量:,而正常MHA的参数矩阵参数量:。但MLA强调的是KV-cache的减少,也就是KV的激活值减少。当前我们还看不出来怎么减少激活值的数量的,因为单从KV的数量和维度上看跟MHA是一个量级,比GQA和MQA都要多,同时计算又多了一步。当前是比较迷糊的...我们再往下继续看...
2. 再看下Q的计算过程
公式(37),(38)类似KV的逻辑,通过两个矩阵(,)也做了一层低秩变换,这一步Q的变换看着趋是为了减少模型的参数的数量。在Deepseek-V3里。是KV压缩维度的3倍。但相对于还是压缩了不少。
3. 增加Rope位置编码
我们注意到在增加RoPE位置编码并没有在上述计算出的,的基础上乘以Rope的对角矩阵。而是单独计算了两个带着位置编码的,如公式(39)和公式(43)所示
注意这里计算带RoPE的,有两个细节:
(1) ,的向量维度是个比较小的维度,DeepSeek设置为单Attention Head维度的一半:
(2) 这部分计算的实际是个MQA的计算方式,同一层中,所有的Head共享同一个
然后按如下公式(40),(44)跟已经计算的,拼接,构成完整的,向量。
注:这里的下标表示Attention Head的索引
所以到目前为止,我们得到的包括两部分拼接而成:一部分是做了低秩压缩得到的向量,一部分是增加了RoPE位置编码的向量。(后面这部分向量是基于MQA方式计算得到的,所有Head共享1个)。
如何理解上述的操作过程?这也是MLA方法的核心。
我们先来看看DeepSeek-V2论文中有一段原文解释(中文翻译):
位置编码使用RoPE,但RoPE与低秩KV不兼容。具体来说,RoPE对Q和K都是位置敏感的。如果我们为应用RoPE,那么公式(42)的(K的权重矩阵)将与位置敏感的RoPE矩阵耦合。因此,在推理过程中,无法再被吸收到(Q的权重矩阵)中,因为与当前生成的token相关的RoPE矩阵将位于和之间,而矩阵乘法不满足交换律。因此,我们必须在推理过程中重新计算所有前缀token的k,这将极大地降低推理效率。
论文中提到了「矩阵吸收计算」,这个概念对理解MLA比较重要,我们用一个简单的例子理解下:
假设有两个向量变量都是3维的向量。有两个固定的变换矩阵分别对做线性变换得到新的向量。最终求两个向量的乘积。
方法1: 常规计算方法2:矩阵吸收计算
我们知道矩阵乘法是满足结合律的,对于公式(c)我们可以先计算好两个变换矩阵的乘积:然后通过与相乘,计算出,而则不做任何操作再计算和 乘积
通过上面的例子我们可以看到,两种方法计算出的结果是一样的,但第二种方法是先做了矩阵乘法,相当于把的变换矩阵吸收到了的变换矩阵里。
理解了上面的例子,我们再来看看原文说的「RoPE与低秩KV不兼容,没法做矩阵吸收计算」的问题。
a) 不加RoPE
我们先假设当前不增加RoPE,那么乘积计算如下,其中表示变换矩阵第个Head的切片:不加RoPE,我们可以提前计算好,也就上面说的吸收到中,这样在做的变换的时候,也就同时计算了矩阵的乘法。
这样的好处是,我们只需要缓存,而不是缓存的结果。维度只有的长度,而是个 的变换,也就是完全恢复了隐层的维度 (DeepSeek-v3 配置为64)。这也是MLA的压缩KV Cache的核心原理。
b) 现在假设增加RoPE
我们再看看,加上Rope后,计算乘积,会在和之间,增加一个融合了相对位置的变量,如公式(2)所示:
中间这个分量是随这相对位置变化而变化的,并不是个固定的矩阵,因此并不能提前计算好。所以论文中说RoPE与低秩变换不兼容。
c)通过增加一个很小分量,引入RoPE
为了引入位置编码,作者在一个很小维度下,用MQA方式计算了,也就是在每层网络中,所有Head只计算一个(如论文中公式43所示)。引入位置编码的向量维度取的比较小为:。
所以最终向量通过两部分拼接而成,计算权重时,由前后两部分分别相乘再相加得到,如下公式(8)所示:前一项 按公式(6)计算,通过矩阵吸收处理,全Head只缓存一个,后一项 按正常MQA的方式计算,全Head只缓存了一个共享。
通过类似的计算方式,可以处理将的变换矩阵吸收到最终的结果变换矩阵中,这样也不用实际计算和缓存的值。而是只缓存跟一样的 即可,详细推导与上述类似,不过多赘述。
上面我们就完整介绍了MLA做KV Cache压缩的原理。我们再来回顾下,MLA实际缓存的向量(如图8蓝色框),维度如下:
:维度为 :维度为
是低秩压缩的向量,是引入位置编码的MQA范式计算的共享
注:原理篇其实苏神已经解释的非常清晰了(详见:缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA - 科学空间|Scientific Spaces),本文原理的部分也基本按苏神的逻辑概述下关键思路,感谢苏神的分享
4.3. MLA与MQA、GQA对比
最后我们再简单看看几种方法的对比,直接截取DeepSeeku-V2论文的图,如下:
从上图我们可以看到,虽然MLA缓存的Latent KV比较短(相当于2.25个MQA的缓存量),但MLA有恢复全的能力,特征表达能力显著比GQA、MQA要强。所以MLA能做到又快又省又强。论文中也给出了下图的数据
注:图中能力的比较上,描述比MHA更强我比较存疑,并没看到有消融的实验对比,也不太好从原理上解释。
5. 总结
本文试图通过引入更多基础知识和辅助信息,来深入理解MLA。内容比较长,可能觉得比较啰嗦。这是本人在理解MLA过程递归总结的一些扩展信息,最终整理了一个系统的脉络,发出来供大家参考。
6.参考文献
[1]deepseek-v1:https://arxiv.org/pdf/2401.02954
[2]deepseek-v2:https://arxiv.org/pdf/2405.04434
[3]deepseek-v3:https://arxiv.org/pdf/2412.19437
[4]缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA - 科学空间|Scientific Spaces(https://spaces.ac.cn/archives/10091)
[5]https://zhuanlan.zhihu.com/p/659770503
[6]GQA:https://arxiv.org/pdf/2305.13245
[7]MQA:https://arxiv.org/pdf/1911.02150
个人水平有限,欢迎指正~