点击下方卡片,关注“AI前沿速递”公众号
点击下方卡片,关注“AI前沿速递”公众号
各种重磅干货,第一时间送达
各种重磅干货,第一时间送达
Transformer中的注意力机制(Attention Mechanism)是其核心组成部分之一,主要用于捕捉输入序列中不同位置之间的依赖关系。通过计算每个输入的隐藏状态(hidden state),同时利用注意力机制来解决上下文关联问题。
Self-Attention
在自注意力机制中,输入通常是一个统一的輸入矩阵,而这个矩阵后续会通过乘以不同的权重矩阵来转换成三个不同的向量集合。这个过程包含三个主要部分:查询(Query,Q)、键(Key,K)和值(Value,V)。在Transformer中,这些通常是输入序列的线性变换。给定一组查询 Q、键 K 和值 V 矩阵,其中每个矩阵的维度分别是 (键和查询的特征维度),(值的特征维度),以及序列长度 Attention计算公式表示为:Attentionsoftmax
首先计算查询与所有键之间的相似度得分,这通常通过点积来完成。为了防止大数值导致梯度消失或爆炸的问题,点积结果被除以进行缩放。接下来,对每个查询的位置应用softmax函数,以获得表示注意力权重的概率分布,最后,使用上面得到的注意力权重对值向量进行加权求和,生成输出。、、为去要学习的权重矩阵,获取encoder输入的embeding,并计算每个embedding 的query,key,value。
MHA(Multi Head Attention)
单头注意力在计算不同位置间关联时用到了加权平均,这在一定程度上影响了特征计算的准确性, 因此要用多头注意力来抵消这种影响。多头注意力允许模型在不同的位置关注信息的不同表征子空间。它通过并行执行多个上述的注意力函数,并将它们的结果拼接起来再通过一个线性变换来实
现。具体来说:
对于每一个head,都有独立的参数矩阵:
查询矩阵
键矩阵
值矩阵用于将输入转换成查询、键和值。每个头的输出是按照上述公式计算出来的。所有头的输出会被拼接在一起,并通过另一个线性变换 来整合这些输出。这个最终的线性变换矩阵 的尺寸是,其中h 是头的数量。多头注意力的公式可以表示为:
MultiHeadConcat .
单头注意力包含、、,产生一个output,多头注意力则包含n个、、 ,
这些参数的权重不共享,产生n个output。在多头注意力机制(Multi-Head Attention)中,各个头产生的输出确实被拼接在一起,并且对拼接后的结果再次进行线性变换(projection),以得到最终的输出。这个过程确保了模型可以综合不同表征子空间的信息,同时保持输出维度的一致性。
MQA(Multi Query Attention)
MQA (Multi-Query Attention) 是一种优化的注意力机制,旨在减少Transformer模型中多头注意力(Multi-Head Attention, MHA) 计算的参数量和计算成本。与标准的MHA不同,MQA通过让所有的注意力头共享同一个键 (Key,K)和值 (Value,V) 线性映射矩阵来实现这一点,而查询 (Query, )矩阵仍然为每个头独立设置。在传统的多头注意力机制中,每个注意力头都有独立的、、 矩阵用于将输入转换成查 询、键和值。这导致了参数量的增长,尤其是在头数较多的情况下。MQA简化了这一过程:
共享的键和值矩阵:对于每一层,所有注意力头共享同一个键 和值 线性映射矩阵,即 和。这意味着每层只有一个 和 矩阵。 独立的查询矩阵:查询矩阵仍然为每个头独立设置,即是头的数量。
这种设计减少了模型的参数量,因为键和值的线性变换不再依赖于头的数量,从而降低了计算复杂度,特别是在大规模模型或长序列上时性能提升更为显著。MQA计算公式表示为:MultiQueryAttentionConcat其中,headAttention。这里的是第个头的查询矩阵,而 和 是该层下所有头共享的键和值矩阵。
查询矩阵 的维度为 ,每个头都有自己独立的查询矩阵。 共享的键矩阵和值矩阵的维度分别为和 ,这两个矩阵在整 个层内被所有头共享。 MQA只让 Q 保留了原始多头的性质 (每个Head存在不同的转换),可以大大减少 和矩阵的参数量以及KV Cache的显存占用,提升推理速度。但是可能会带来精度上的损失。广泛应用于大规模语言模型(LLMs),如ChatGLM2。
GQA(Group Query Attention)
GQA(Grouped Query Attention)是一种在MQA基础上进一步优化的注意力机制,旨在平衡参数量减少和模型性能之间的关系。通过将查询(Query, Q)进行分组,并让每个组内的所有头共享同一组键(Key, K)和值(Value, V),GQA试图保留多头注意力机制的优势,同时减少计算成本和显存占用。GQA(Group Query Attention)实现步骤
:
查询分组:将所有的查询分成若干个组,每个组内的查询共享一组键和值 。 共享键和值矩阵:每个组内共享同一个键矩阵 和值矩阵,其中表示组编号。 独立的查询矩阵:每个组内的查询仍然有自己的查询矩阵,这里 表示组内的头编号,表示组编号。 输出拼接和投影:所有组的输出被拼接在一起,并通过线性变换矩阵 进行投影。
GQA的计算公式表示为:GroupedQueryAttentionConcat其中,=Concat, headAttention这里的是第组中第个头的查询矩阵,而和是该组内所有头共享的键和值矩阵。
MLA( Multi-head Latent Attention)
MLA的核心是对key和value进行低秩联合压缩,以减少KV Cache的计算量,通过一系列矩阵投影实 现对k和v的压缩和解压。核心步骤如下:输入序列长度和模型维度查询权重矩阵,将输入映射到查询空间。
压缩阶段
向下投影
k和v的原始表示通过一个向下投影矩阵被映射到一个更低维度的空间,形成压缩潜在 向量。
其中 是压缩后的KV维度,且 $d_c<<d_hn_h$ .<="" p="">
缓存压缩向量
在推理过程中,只缓存,因此KV缓存只需要存储个元素,其中表示层数。
解压缩阶段
向上投影
当需要使用这些压缩的和时,分别通过两个不同向上投影矩阵和 将其恢复为原始维度。
推理计算
为了减小推理过程中的计算量, 为方阵是转置后还是其本身。将 吸收到查询矩阵, 吸收到的输出变换矩阵 中,在计算注意力的时候使用 和压缩的向量 作用,最后再使用输出变换矩阵 得到最后的输出。新的查询矩阵 表示为:新的输出变换矩阵 表示为:Attention(Q',C_{kv}^t)=softmax(\frac{Q'(C_{kv}^t)^T}{√d_k})C_{kv}^t
其中,MQA这种共享一个K,V,只有Q是独立的,大幅减小了参数数量,带来推理加速的同时会造成模型性能损失,且在训练过程使得模型变得不稳定,从而提出了GQA,它将Query进行分组,每个组内共享一组Key、Value。当GQA的分组数N等于1时,GQA等价于MQA,所有的查询Q共享一组K、V,当GQA的分组数N等于查询Q的数量时,GQA等价于MHA,每一个Q对应自己的K、V矩阵。
小结
MQA这种共享一个K,V,只有Q是独立的,大幅减小了参数数量,带来推理加速的同时会造成模型性能损失,且在训练过程使得模型变得不稳定,从而提出了GQA,它将Query进行分组,每个组内共享一组Key、Value。当GQA的分组数N等于1时,GQA等价于MQA,所有的查询Q共享一组K、V,当GQA的分组数N等于查询Q的数量时,GQA等价于MHA,每一个Q对应自己的K、V矩阵。当GQA的分组数N等于1时,GQA等价于MQA,所有的查询Q共享一组K、V,当GQA的分组数N等于查询Q的数量时,GQA等价于MHA,每一个Q对应自己的K、V矩阵。
特性 | MHA (Multi-Head Attention) | MQA (Multi-Query Attention) | GQA (Grouped Query Attention) | MLA( Multi-head Latent Attention) |
---|---|---|---|---|
参数共享程度 | Q、K、V矩阵都独立 | 共享K、V,Q矩阵独立 | 每组共享K、V,组内Q矩阵独立 | 压缩K、V,通过投影矩阵解压 |
推理效率 | 推理慢,计算量大 | 推理快,减少了计算量 | 推理接近MQA,同时保持模型性能 | 高效,减少了KV缓存占用 |
模型性能 | 最优 | 轻微下降 | 接近MHA | 优于MHA |
训练稳定性 | 强 | 低 | 中 | 中 |
特殊情况 | -- | GQA分组数为1等价于MQA | GQA分组数等于Q矩阵数量等价为MHA | GQA分组数为2.25时两者kv cache相当 |
应用模型 | bert、t5 | palm | llama2、chatglm2、baichuan、qwen | deepseekv2、deepseekv3 |
与MHA 相比,经过升级训练的 MQA 具有更优的权衡,其质量和速度都比 MHA-Large 更高,而 GQA 的性能甚至更好,速度增益与 MHA-XXL 相似,质量也相当。在所有任务上,T5-Large 和 T5-XXL 在所有任务上的平均性能与每个样本的平均推理时间有关,而 T5-XXL 在所有任务上的平均性能与 MQA 和 GQA-8 注意力有关。
GQA-XXL 的每个样本的时间是 GQA 组数的函数,输入长度为 2048,输出长度为 512。从 1 (MQA) 到 8
组会增加适度的推理开销,并且增加更多组的成本也会增加。在分组数量大于8以后推理开销显著变大,因此在llama中GQA的分组数设定为8。
MHA KV 缓存大小: KVCachePerToken = n_h × d_h MQA KV 缓存大小: KVCachePerToken = d_h GQA KV 缓存大小: KVCachePerToken = \frac{n_h}{n_g}×d_h MLA KV 缓存大小: KVCachePerToken = d_c dc 表示 MLA 中解耦查询和键的 KV 压缩维度。DeepSeek-V2 的 KV 缓存等效于 GQA 只有 2.25 组时,LA 的 KV 缓存大小与 GQA 相等。MLA 机制在减少 KV 缓存占用的同时,实现了比传统 MHA 更强的性能。
确保文章为个人原创,未在任何公开渠道发布。若文章已在其他平台发表或即将发表,请明确说明。
建议使用Markdown格式撰写稿件,并以附件形式发送清晰、无版权争议的配图。
【AI前沿速递】尊重作者的署名权,并为每篇被采纳的原创首发稿件提供具有市场竞争力的稿酬。稿酬将根据文章的阅读量和质量进行阶梯式结算。
您可以通过添加我们的小助理微信(aiqysd)进行快速投稿。请在添加时备注“投稿-姓名-学校-研究方向”
长按添加AI前沿速递小助理