Understanding Grouped-Query Attention in One Article

01


Introduction



The standard practice in autoregressive decoding is to cache the Keys and Values of the previous tokens in the sequence to speed up the attention computation. However, as the context window or batch size grows, the memory cost of the KV cache in a multi-head attention model grows significantly as well.
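To make this memory cost concrete, here is a small back-of-the-envelope sketch (added for illustration, not from the original article); the layer count, head counts, and sequence length below are assumed values, roughly in the range of a large decoder-only model:

# Hypothetical illustration: approximate KV-cache size of a decoder-only model.
# All configuration values below are assumptions chosen only for illustration.
def kv_cache_bytes(n_layers, batch_size, seq_len, n_kv_heads, head_dim, bytes_per_elem=2):
    # 2x for keys and values, cached at every layer for every token of every sequence
    return 2 * n_layers * batch_size * seq_len * n_kv_heads * head_dim * bytes_per_elem

# MHA-style cache: one K/V head per query head (assume 64 heads, head_dim 128, fp16)
mha = kv_cache_bytes(n_layers=80, batch_size=8, seq_len=4096, n_kv_heads=64, head_dim=128)
# GQA-style cache: only 8 K/V heads shared by the 64 query heads
gqa = kv_cache_bytes(n_layers=80, batch_size=8, seq_len=4096, n_kv_heads=8, head_dim=128)

print(f"MHA cache: {mha / 2**30:.1f} GiB, GQA cache: {gqa / 2**30:.1f} GiB")
# Under these assumptions the cache shrinks by n_heads / n_kv_heads = 8x.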


Multi-Query Attention (MQA) is a mechanism in which a single Key-Value head serves all of the Query heads. It saves memory and greatly speeds up decoder inference. However, MQA can degrade model quality. In practice we want not only fast inference but also quality comparable to MHA, and this is where Grouped-Query Attention (GQA) comes into play.


GQA sits between MQA and MHA: it achieves quality close to multi-head attention (MHA) while retaining a speed comparable to multi-query attention (MQA).


Without further ado, let's get started!






02


Grouped-Query Attention


GQA can be viewed as a general form of both MQA and MHA:

  • When GQA uses only a single group, it is MQA.

  • When the number of groups in GQA equals the number of attention heads, it is MHA.

The figure below clearly illustrates this relationship.


As the figure above shows, GQA's strategy is to improve inference quality over MQA: it uses multiple Key and Value heads, but fewer of them than the number of Query heads.
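To make the "number of groups" idea concrete, the following is a minimal illustrative sketch (not part of the original post); the head counts are assumed values, and the mapping shows which shared KV head each query head uses:

# Hypothetical sketch: map each query head to the KV head it shares.
n_heads = 8  # assumed number of query heads

for n_kv_heads in (8, 4, 1):  # 8 groups -> MHA, 4 groups -> GQA, 1 group -> MQA
    group_size = n_heads // n_kv_heads  # query heads per KV head
    mapping = [q // group_size for q in range(n_heads)]
    print(f"n_kv_heads={n_kv_heads}: query head -> kv head {mapping}")

# n_kv_heads=8: query head -> kv head [0, 1, 2, 3, 4, 5, 6, 7]   (MHA)
# n_kv_heads=4: query head -> kv head [0, 0, 1, 1, 2, 2, 3, 3]   (GQA)
# n_kv_heads=1: query head -> kv head [0, 0, 0, 0, 0, 0, 0, 0]   (MQA)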





03


Implementation of GQA

Next, let's look at the GQA implementation in Llama 2. The code is as follows:

class Attention(nn.Module):
    """Multi-head attention module."""

    def __init__(self, args: ModelArgs):
        """
        Initialize the Attention module.

        Args:
            args (ModelArgs): Model configuration parameters.

        Attributes:
            n_kv_heads (int): Number of key and value heads.
            n_local_heads (int): Number of local query heads.
            n_local_kv_heads (int): Number of local key and value heads.
            n_rep (int): Number of repetitions for local heads.
            head_dim (int): Dimension size of each attention head.
            wq (ColumnParallelLinear): Linear transformation for queries.
            wk (ColumnParallelLinear): Linear transformation for keys.
            wv (ColumnParallelLinear): Linear transformation for values.
            wo (RowParallelLinear): Linear transformation for output.
            cache_k (torch.Tensor): Cached keys for attention.
            cache_v (torch.Tensor): Cached values for attention.
        """
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        # ColumnParallelLinear and RowParallelLinear are two common strategies
        # for implementing model parallelism.
        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        # The dimension of wk and wv has changed.
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        # KV cache, used for caching keys and values
        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """
        Forward pass of the attention module.

        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int): Starting position for caching.
            freqs_cis (torch.Tensor): Precomputed frequency tensor.
            mask (torch.Tensor, optional): Attention mask tensor.

        Returns:
            torch.Tensor: Output tensor after attention.
        """
        bsz, seqlen, _ = x.shape
        # The dimension of k and v has changed.
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # Incorporate rotary position embedding
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        # Cache the current tokens' keys and values
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        # Retrieve the previously cached keys and values
        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # Repeat k/v heads if n_kv_heads < n_heads,
        # so that the number of kv heads matches the number of query heads
        keys = repeat_kv(keys, self.n_rep)      # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        # Self-attention
        xq = xq.transpose(1, 2)          # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)      # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)

I have added comments to the key steps in this code. A few points are worth noting:

  • self.n_local_heads is the number of heads in the original multi-head attention, i.e., the number of query heads.

  • self.n_local_kv_heads is the number of key and value heads in GQA. This means the KV cache is reduced by a factor of self.n_rep = self.n_local_heads // self.n_local_kv_heads.

  • GQA reduces the size of the KV cache, but in the actual computation (the matrix-multiplication (GEMM) routines) the key/value heads must match the number of query heads, so they need to be expanded back to the original head count. The repeat_kv function duplicates the keys/values so that their head count matches the number of query heads.

The code is as follows:

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x  # MHA
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )  # GQA or MQA

Let's look at an example to see what repeat_kv does:

>>> x = torch.rand(1, 1, 4, 6)
>>> x
tensor([[[[0.1898, 0.5731, 0.4586, 0.5906, 0.2105, 0.7735],
          [0.6219, 0.3407, 0.3804, 0.7781, 0.3234, 0.8874],
          [0.7422, 0.8980, 0.7574, 0.5109, 0.6943, 0.9066],
          [0.3850, 0.7974, 0.1791, 0.6012, 0.5239, 0.4833]]]])
>>> n_rep = 2
>>> bs, slen, n_kv_heads, head_dim = x.shape
>>> print(x[:, :, :, None, :].expand(bs, slen, n_kv_heads, n_rep, head_dim).reshape(bs, slen, n_kv_heads * n_rep, head_dim))
tensor([[[[0.1898, 0.5731, 0.4586, 0.5906, 0.2105, 0.7735],
          [0.1898, 0.5731, 0.4586, 0.5906, 0.2105, 0.7735],
          [0.6219, 0.3407, 0.3804, 0.7781, 0.3234, 0.8874],
          [0.6219, 0.3407, 0.3804, 0.7781, 0.3234, 0.8874],
          [0.7422, 0.8980, 0.7574, 0.5109, 0.6943, 0.9066],
          [0.7422, 0.8980, 0.7574, 0.5109, 0.6943, 0.9066],
          [0.3850, 0.7974, 0.1791, 0.6012, 0.5239, 0.4833],
          [0.3850, 0.7974, 0.1791, 0.6012, 0.5239, 0.4833]]]])


To further illustrate how GQA works, I have drawn two diagrams. As shown in the figure below, the number of Key and Value heads is self.n_local_kv_heads = 4, and the number of Query heads is self.n_local_heads = 8.

As shown in the next figure, after repeat_kv the number of Key and Value heads matches the number of Query heads (each color represents one group, and the heads within each group have been duplicated), so the matrix-multiplication (GEMM) routines can proceed.
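To tie the diagrams back to the code, here is a small shape check (written for this article, not taken from Llama 2). It uses the same head counts as the figures, 4 KV heads and 8 query heads, and relies on torch.repeat_interleave, which the repeat_kv docstring above names as its equivalent; the remaining tensor sizes are assumed toy values:

import torch

bs, cached_len, n_local_kv_heads, head_dim = 1, 5, 4, 16  # assumed toy sizes
n_local_heads = 8
n_rep = n_local_heads // n_local_kv_heads  # 2

keys = torch.rand(bs, cached_len, n_local_kv_heads, head_dim)

# Equivalent to repeat_kv(keys, n_rep): duplicate each KV head n_rep times along dim 2
expanded = torch.repeat_interleave(keys, repeats=n_rep, dim=2)

print(keys.shape)      # torch.Size([1, 5, 4, 16])  -> 4 KV heads stored in the cache
print(expanded.shape)  # torch.Size([1, 5, 8, 16])  -> matches the 8 query heads for the GEMM
assert expanded.shape[2] == n_local_heads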







04


Summary

Neither GQA nor MQA significantly reduces the amount of computation; their main purpose is to cut the memory needed to store the KV cache. With a smaller KV cache, the LLM can handle more concurrent requests, allowing larger batch sizes and higher throughput.
Finally, if there are any errors or omissions in this article, please feel free to point them out.





