01
引言
Multi-Query Attention 是一种注意力机制,可以加快解码器生成Token的速度,同时确保模型的性能。它在大型语言模型时代得到了广泛应用,许多 LLM 都采用了 MQA,如 Falcon、PaLM、StarCoder 等。
闲话少说,我们直接开始吧!
02
Multi-Head Attention
在介绍 MQA 之前,让我们先回顾一下Transformer中的默认注意力机制。多头注意力机制是Transformer模型中的默认注意力机制,如下图所示:
然而,基于Transformer的自回归语言模型在文本生成方面存在以下问题。
在训练过程中,我们可以访问真正的目标序列,并有效地实现并行性。不过,在推理过程中,每个位置的查询Query都会关注在该位置或之前生成的所有key-value对。换句话说,自注意力层在特定位置的输出会影响下一个标记Token的生成。由于无法进行并行计算,解码速度会变慢。
以下是基于Transformer解码器的自回归语言模型中自注意力层的解码过程:
def MHAForDecoder(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
q = tf.einsum("bd, hdk−>bhk", x, P_q)
new_K = tf.concat([prev_K, tf.expand_dims(tf.einsum ("bd, hdk−>bhk", x, P_k), axis = 2)], axis = 2)
new_V = tf.concat([prev_V, tf.expand_dims(tf.einsum("bd, hdv−>bhv", x, P_v), axis = 2)], axis = 2)
logits = tf.einsum("bhk, bhmk−>bhm", q, new_K)
weights = tf.softmax(logits)
O = tf.einsum("bhm, bhmv−>bhv", weights, new_V)
Y = tf.einsum("bhv, hdv−>bd", O, P_o)
return Y, new_K, new_V
上述代码中:
x为当前step的输入向量,维度为(b, d)
P_q、P_k:Query和Key投影张量,维度为 (h,d,k)
P_v:Value投影张量,维度为 (h,d,v)
P_o:学到的线性投影张量,维度为 (h,d,v)
Prev_K:上一步的Key张量,维度为 (b, h, m, k)
Prev_V:上一步的Value张量,维度为 (b, h, m, v)
new_K:添加了当前step的Key张量,维度为 (b, h, m+1, k)
new_V:添加了当前step的Value张量,维度为 (b, h, m+1, v)
03
Multi-Query Attention
Multi-Query Attention 是多头注意力机制的一种变体。如下图所示,这意味着所有 Query 头共享同一组 K 头和 V 头,因此被称为Multi-Query:
MQA 解码过程的代码与 MHA 的代码基本相同,只是在 K、V、P_k 和 P_v 的 tf.einsum 公式中去掉了代表heads维度的字母 h,代码如下:
def MQAForDecoder(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
q = tf.einsum("bd, hdk−>bhk", x, P_q)
new_K = tf.concat([prev_K, tf.expand_dims(tf.einsum ("bd, dk−>bk", x, P_k), axis = 2)], axis = 2)
new_V = tf.concat([prev_V, tf.expand_dims(tf.einsum("bd, dv−>bv", x, P_v), axis = 2)], axis = 2)
logits = tf.einsum("bhk, bmk−>bhm", q, new_K)
weights = tf.softmax(logits)
O = tf.einsum("bhm, bmv−>bhv", weights, new_V)
Y = tf.einsum("bhv, hdv−>bd", O, P_o)
return Y, new_K, new_V
04
性能对比
论文中还有关于性能的实验,结果表明 MQA 的性能仅略低于基线。更多详情,请参阅本文,链接在本文底部。
05
分析
更高的内存效率
更低的计算复杂度
一般来说,MQA 通过以下方法实现推理加速:
KV Cache的大小减少了 h 倍,这意味着需要存储在 GPU 内存中的张量也相应地减少了。这些节省下来的空间可以用来增加batchsize,从而提高推理效率。
减少了从内存读取的数据量,从而缩短了计算单元的等待时间,提高了内存计算利用率。 MQA 的 KV cache相对较小,可以放入高速缓存(SRAM)中。而 MHA 的 KV 缓存较大,无法完全存储在缓存中,需要从 GPU 内存(DRAM)中读取,非常耗时。
06
总结
点击上方小卡片关注我
添加个人微信,进专属粉丝群!