一文弄懂Multi-Query Attention

文摘   科技   2024-10-14 07:21   江苏  
点击蓝字
 
关注我们










01


引言



Multi-Query Attention 是一种注意力机制,可以加快解码器生成Token的速度,同时确保模型的性能。它在大型语言模型时代得到了广泛应用,许多 LLM 都采用了 MQA,如 Falcon、PaLM、StarCoder 等。


闲话少说,我们直接开始吧!







02


Multi-Head Attention


在介绍 MQA 之前,让我们先回顾一下Transformer中的默认注意力机制。多头注意力机制是Transformer模型中的默认注意力机制,如下图所示:

然而,基于Transformer的自回归语言模型在文本生成方面存在以下问题。

在训练过程中,我们可以访问真正的目标序列,并有效地实现并行性。不过,在推理过程中,每个位置的查询Query都会关注在该位置或之前生成的所有key-value对。换句话说,自注意力层在特定位置的输出会影响下一个标记Token的生成。由于无法进行并行计算,解码速度会变慢。

以下是基于Transformer解码器的自回归语言模型中自注意力层的解码过程:

def MHAForDecoder(x, prev_K, prev_V, P_q, P_k, P_v, P_o):    q = tf.einsum("bd, hdk−>bhk", x, P_q)    new_K = tf.concat([prev_K, tf.expand_dims(tf.einsum ("bd, hdk−>bhk", x, P_k), axis = 2)], axis = 2)    new_V = tf.concat([prev_V, tf.expand_dims(tf.einsum("bd, hdv−>bhv", x, P_v), axis = 2)], axis = 2)    logits = tf.einsum("bhk, bhmk−>bhm", q, new_K)    weights = tf.softmax(logits)    O = tf.einsum("bhm, bhmv−>bhv", weights, new_V)    Y = tf.einsum("bhv, hdv−>bd", O, P_o)    return Y, new_K, new_V

上述代码中:

  • x为当前step的输入向量,维度为b, d)

  • P_q、P_k:Query和Key投影张量,维度为 (h,d,k)

  • P_v:Value投影张量,维度为 (h,d,v)

  • P_o:学到的线性投影张量,维度为 (h,d,v)

  • Prev_K:上一步的Key张量,维度为 (b, h, m, k)

  • Prev_V:上一步的Value张量,维度为 (b, h, m, v)

  • new_K:添加了当前step的Key张量,维度为 (b, h, m+1, k)

  • new_V:添加了当前step的Value张量,维度为 (b, h, m+1, v)





03


 Multi-Query Attention

Multi-Query Attention 是多头注意力机制的一种变体如下图所示,这意味着所有 Query 头共享同一组 K 头和 V 头,因此被称为Multi-Query

MQA 解码过程的代码与 MHA 的代码基本相同,只是在 K、V、P_k 和 P_v 的 tf.einsum 公式中去掉了代表heads维度的字母 h,代码如下:

def MQAForDecoder(x, prev_K, prev_V, P_q, P_k, P_v, P_o):    q = tf.einsum("bd, hdk−>bhk", x, P_q)    new_K = tf.concat([prev_K, tf.expand_dims(tf.einsum ("bd, dk−>bk", x, P_k), axis = 2)], axis = 2)    new_V = tf.concat([prev_V, tf.expand_dims(tf.einsum("bd, dv−>bv", x, P_v), axis = 2)], axis = 2)    logits = tf.einsum("bhk, bmk−>bhm", q, new_K)    weights = tf.softmax(logits)    O = tf.einsum("bhm, bmv−>bhv", weights, new_V)    Y = tf.einsum("bhv, hdv−>bd", O, P_o)    return Y, new_K, new_V






04


  性能对比


MQA 究竟能提高多少速度?让我们来看看原始论文中提供的结果图表:

从上表可以看出,MQA 在编码器上的速度提升并不明显,但在解码器上的速度提升却相当明显。

论文中还有关于性能的实验,结果表明 MQA 的性能仅略低于基线。更多详情,请参阅本文,链接在本文底部。




05


  分析


为什么MQA可以提升推理性能呢?:
  • 更高的内存效率

MQA 中,Key和Value张量的大小分别为 b * k 和 b * v,而在 MHA 中,Key和Value的大小分别为 b * h * k 和 b * h * v,其中 h 代表heads的数目
  • 更低的计算复杂度

通过使用 KV Cache,在 MQA 的每个步骤中计算张量 Key 和 Value 的计算成本是 MHA 的 1/h,其中 h 代表heads的数目。

一般来说,MQA 通过以下方法实现推理加速:

  • KV Cache的大小减少了 h 倍,这意味着需要存储在 GPU 内存中的张量也相应地减少了。这些节省下来的空间可以用来增加batchsize,从而提高推理效率。

  • 减少了从内存读取的数据量,从而缩短了计算单元的等待时间,提高了内存计算利用率。
  • MQA 的 KV cache相对较小,可以放入高速缓存(SRAM)中。而 MHA 的 KV 缓存较大,无法完全存储在缓存中,需要从 GPU 内存(DRAM)中读取,非常耗时。





06


  总结

值得一提的是,MQA 是在 2019 年提出的,当时它的应用还没有这么广泛。这是因为以前的模型不需要关注这些方面,例如 LSTM 只需要保持一种状态,不需要保留任何缓存。
Transformer模型最初被提出时,主要用于 Seq2Seq 任务,特别是编码器-解码器模型。然而,这些模型的规模并不大,实际需求也不多,因此 MQA 并没有引起太多关注。随后,同样基于Transformer编码器结构的代表模型BERT规模也不够大。
直到最近基于transformer解码器的大型语言模型(如 GPT)得到广泛应用,人们才发现推理的瓶颈所在。因此,人们重新审视了几年前的技巧,发现它们非常有用。换句话说,这主要是由于对大规模 GPT 类型生成模型的实际需求。
最后,如果本文有任何错误或遗漏,请随时指正。






点击上方小卡片关注我




添加个人微信,进专属粉丝群!



AI算法之道
一个专注于深度学习、计算机视觉和自动驾驶感知算法的公众号,涵盖视觉CV、神经网络、模式识别等方面,包括相应的硬件和软件配置,以及开源项目等。
 最新文章