01
Introduction
The standard practice in autoregressive decoding is to cache the Keys and Values of previously processed tokens in the sequence to speed up the attention computation. However, as the context window or batch size grows, the memory cost of the KV cache in a multi-head attention model grows significantly as well.
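To make that memory cost concrete, the MHA KV cache stores one key vector and one value vector for every layer, head, and token. Below is a minimal back-of-the-envelope sketch; all configuration numbers are illustrative assumptions, not taken from any particular model:

# Rough KV-cache size for standard multi-head attention (MHA).
# All configuration numbers below are illustrative assumptions.
n_layers     = 32    # transformer layers
n_heads      = 32    # attention heads; in MHA every head stores its own K and V
head_dim     = 128   # dimension per head
batch_size   = 8
seq_len      = 4096  # context window
bytes_per_el = 2     # fp16 / bf16

# Factor 2 accounts for keys and values.
kv_cache_bytes = 2 * n_layers * n_heads * head_dim * batch_size * seq_len * bytes_per_el
print(f"KV cache: {kv_cache_bytes / 1024**3:.1f} GiB")  # 16.0 GiB for these numbers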
Multi-Query Attention (MQA) uses a single Key/Value head shared by all Query heads, which saves memory and greatly speeds up decoder inference. However, MQA can hurt model quality. In practice we want not only fast inference but also quality on par with MHA, and that is where Grouped-Query Attention (GQA) comes in.
GQA is a middle ground between MQA and MHA: it achieves quality close to multi-head attention (MHA) while keeping a speed comparable to multi-query attention (MQA).
Without further ado, let's get started!
02
Grouped-Query Attention
GQA can be viewed as a generalization of both MQA and MHA:
With only one group, GQA reduces to MQA.
When the number of groups equals the number of attention heads, GQA becomes MHA.
The figure below illustrates this relationship clearly.
As the figure shows, GQA's strategy is to recover inference quality by offering an improved version of MQA: it uses multiple Key/Value heads, but fewer of them than Query heads. A small sketch of this spectrum follows.
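In code, the whole spectrum is controlled by a single hyperparameter, the number of key/value heads. The sketch below (head counts are illustrative assumptions) shows how n_kv_heads selects MHA, GQA, or MQA, and the corresponding KV-cache reduction:

# How n_kv_heads interpolates between MHA and MQA.
n_heads = 8  # illustrative assumption

for n_kv_heads in (8, 4, 1):
    if n_kv_heads == n_heads:
        variant = "MHA"  # one K/V head per query head
    elif n_kv_heads == 1:
        variant = "MQA"  # a single K/V head shared by all query heads
    else:
        variant = "GQA"  # each group of query heads shares one K/V head
    group_size = n_heads // n_kv_heads
    print(f"{variant}: n_kv_heads={n_kv_heads}, group size={group_size}, "
          f"KV cache size relative to MHA: 1/{group_size}")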
03
GQA Implementation
Now let's look at how GQA is implemented in Llama 2. The code is as follows:
class Attention(nn.Module):
    """Multi-head attention module."""

    def __init__(self, args: ModelArgs):
        """
        Initialize the Attention module.

        Args:
            args (ModelArgs): Model configuration parameters.

        Attributes:
            n_kv_heads (int): Number of key and value heads.
            n_local_heads (int): Number of local query heads.
            n_local_kv_heads (int): Number of local key and value heads.
            n_rep (int): Number of repetitions for local heads.
            head_dim (int): Dimension size of each attention head.
            wq (ColumnParallelLinear): Linear transformation for queries.
            wk (ColumnParallelLinear): Linear transformation for keys.
            wv (ColumnParallelLinear): Linear transformation for values.
            wo (RowParallelLinear): Linear transformation for output.
            cache_k (torch.Tensor): Cached keys for attention.
            cache_v (torch.Tensor): Cached values for attention.
        """
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        # ColumnParallelLinear and RowParallelLinear are two common strategies for implementing model parallelism.
        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        # Compared with MHA, the output dimension of wk and wv is n_kv_heads * head_dim instead of n_heads * head_dim.
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        # KV cache, used for caching keys and values
        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """
        Forward pass of the attention module.

        Args:
            x (torch.Tensor): Input tensor.
            start_pos (int): Starting position for caching.
            freqs_cis (torch.Tensor): Precomputed frequency tensor.
            mask (torch.Tensor, optional): Attention mask tensor.

        Returns:
            torch.Tensor: Output tensor after attention.
        """
        bsz, seqlen, _ = x.shape
        # xk and xv have n_local_kv_heads heads, fewer than xq when n_kv_heads < n_heads.
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # Incorporate rotary position embedding
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        # Cache the current tokens' keys and values
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        # Retrieve the previously cached keys and values
        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        # make the number of heads in kv and q the same
        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)

        # Self-attention
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
I have annotated the key steps in the code above. A few points are worth noting:
self.n_local_heads is the number of query heads (per model-parallel partition), i.e. the same head count as in the original multi-head attention.
self.n_local_kv_heads is the number of key/value heads in GQA. As a result, the KV cache shrinks by a factor of self.n_rep = self.n_local_heads // self.n_local_kv_heads.
GQA reduces the size of the KV cache, but the actual computation (the matrix-multiplication (GEMM) kernels) still requires the key/value heads to match the number of query heads, so they have to be expanded back to that size. The repeat_kv function duplicates the keys/values so that their head count matches that of the queries.
The code is as follows:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x  # MHA
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )  # GQA or MQA
Let's look at an example to see what repeat_kv does:
>>> x = torch.rand(1, 1, 4, 6)
>>> x
tensor([[[[0.1898, 0.5731, 0.4586, 0.5906, 0.2105, 0.7735],
[0.6219, 0.3407, 0.3804, 0.7781, 0.3234, 0.8874],
[0.7422, 0.8980, 0.7574, 0.5109, 0.6943, 0.9066],
[0.3850, 0.7974, 0.1791, 0.6012, 0.5239, 0.4833]]]])
>>> n_rep = 2
>>> bs, slen, n_kv_heads, head_dim = x.shape
>>> print(x[:, :, :, None, :].expand(bs, slen, n_kv_heads, n_rep, head_dim).reshape(bs, slen, n_kv_heads * n_rep, head_dim))
tensor([[[[0.1898, 0.5731, 0.4586, 0.5906, 0.2105, 0.7735],
[0.1898, 0.5731, 0.4586, 0.5906, 0.2105, 0.7735],
[0.6219, 0.3407, 0.3804, 0.7781, 0.3234, 0.8874],
[0.6219, 0.3407, 0.3804, 0.7781, 0.3234, 0.8874],
[0.7422, 0.8980, 0.7574, 0.5109, 0.6943, 0.9066],
[0.7422, 0.8980, 0.7574, 0.5109, 0.6943, 0.9066],
[0.3850, 0.7974, 0.1791, 0.6012, 0.5239, 0.4833],
[0.3850, 0.7974, 0.1791, 0.6012, 0.5239, 0.4833]]]])
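As the docstring suggests, the expand/reshape trick is equivalent to torch.repeat_interleave along the head dimension, and it means that expanded head i is a copy of original K/V head i // n_rep, so each group of n_rep consecutive query heads shares one key/value head. A quick sanity check, assuming the repeat_kv definition above is in scope:

import torch

# Same shape as the example above: (bs=1, slen=1, n_kv_heads=4, head_dim=6).
x = torch.rand(1, 1, 4, 6)
n_rep = 2

out = repeat_kv(x, n_rep)                               # (1, 1, 8, 6)
ref = torch.repeat_interleave(x, repeats=n_rep, dim=2)  # same values, same layout
assert torch.equal(out, ref)

# Expanded head i is a copy of original K/V head i // n_rep, so each group of
# n_rep consecutive query heads ends up sharing one key/value head.
for i in range(out.shape[2]):
    assert torch.equal(out[:, :, i], x[:, :, i // n_rep])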
To further illustrate how GQA works, I drew two diagrams. As shown below, the number of Key/Value heads is self.n_local_kv_heads = 4 while the number of Query heads is self.n_local_heads = 8:
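For the same head counts as in the diagrams (8 query heads, 4 key/value heads), here is a stripped-down sketch of the core GQA computation. It reuses the repeat_kv helper above but omits rotary embeddings, the KV cache, and model parallelism; all tensor sizes are illustrative assumptions:

import math
import torch
import torch.nn.functional as F

# Illustrative sizes; n_heads / n_kv_heads match the diagrams above.
bsz, seqlen, head_dim = 2, 5, 64
n_heads, n_kv_heads = 8, 4
n_rep = n_heads // n_kv_heads  # 2 query heads share each K/V head

xq = torch.randn(bsz, seqlen, n_heads, head_dim)     # queries: 8 heads
xk = torch.randn(bsz, seqlen, n_kv_heads, head_dim)  # keys:    4 heads
xv = torch.randn(bsz, seqlen, n_kv_heads, head_dim)  # values:  4 heads

# Expand K/V so their head count matches the queries, as in the forward pass above.
keys = repeat_kv(xk, n_rep)    # (bsz, seqlen, n_heads, head_dim)
values = repeat_kv(xv, n_rep)  # (bsz, seqlen, n_heads, head_dim)

xq = xq.transpose(1, 2)        # (bsz, n_heads, seqlen, head_dim)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)

scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values)  # (bsz, n_heads, seqlen, head_dim)
print(output.shape)                    # torch.Size([2, 8, 5, 64])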
04
Summary
GQA lets groups of query heads share key/value heads, sitting between MHA and MQA: the KV cache shrinks by a factor of n_rep = n_heads // n_kv_heads while quality stays close to MHA. In Llama 2 this only requires smaller wk/wv projections plus the repeat_kv helper to expand the cached heads back for the attention GEMMs.