图解KV Cache:加速大模型推理的幕后功臣

科技   2024-10-14 15:34   日本  

在开始之前,给大家出几个“高频面试题”,看看你能答上来吗?

1. 举例说明 KV Cache 的计算过程

2.为什么要用 KV Cache?它能解决什么问题,代价又是什么?

3. vLLM 里 KV Cache 形影不离的搭档是谁?

还记得之前那篇大语言模型推理,用动画一看就懂!的文章吗?是的!我们再次用动画来演示大语言模型的推理过程!几乎所有的大语言模型(LLM)都基于 Transformer 架构,它依赖于之前生成的 token 来预测下一个字符。而自注意力机制(self-attention)则是模型推理的核心:它不仅需要当前 token,还要每次“回顾”之前的所有 token。

动画演示 KV Cache

为了更加形象理解上面提到的自注意力机制的“回顾机制”,下面我画了一张图。它是大语言模型推理,用动画一看就懂!中那个文本生成步骤的第四步,其中计算 self-attention 时所需的 Key 和 Value 的示意图。

注意:Prompt 是 "The future of AI is" 有五个 token,第一步推理时模型输入的是整个 prompt,会计算出每个 prompt token 对应的 key 值和 value 值,为了清晰起见图里仅用 K1 和 V1 来代表它们。

接下来的动画演示了每一步计算自注意力的过程,清晰起见去掉了其他算子。

从图里看到每一步计算时,当前的 Qi 都需要和之前的 Kj 进行矩阵乘法计算,然后再和之前的 Vj 进行矩阵乘法。那么为了节省算力,我们可以把之前的 Kj、Vj 的结果“缓存”起来,这样每次只需要做增量计算。这个缓存机制就是 KV Cache ,简单却非常有效!来看看加上 KV Cache,推理过程变得多轻松吧!

从上面的动画可以看到除了第一步,其他步骤都可以通过缓存复用之前步骤产生的 Ki 和 Vi。这些步骤在计算 self attention 时只有一个 query,因此叫做 single query attention。

KV Cache 有多大?

一条文本所需的 KV Cache 计算公式如下:

KV Cache Bytes = 2 * 2 * Sequence Length *   Number of Layers * Hidden Size

举个实际的例子,Qwen2 7B 这个国产大模型,在 4 K 序列长度下,KV Cache 大小是 1.6 GB!这是什么概念呢?要知道很多人的显卡也就 8GB 或者 16GB。

KV Cache 的代价

KV Cache 虽然能节省计算,但是显存开销也很显著,随着模型变大(Hidden Size 和 Layer Num 会增大)、序列长度变长,占用的显存迅速膨胀。

假设你实现 KV Cache 时,每次都是预留了一个超大的仓库来存放它,但每次只用了一小部分,这会导致资源浪费,你服务不了太多用户,而且容易出现“撑爆”显存(OOM)的现象。因为你的用户,他们每次推理时的文本长度是变化的!

那么如何解决这些问题呢?vLLM 提出的 PagedAttention 就是聪明地按需分配空间,像是“分隔储物柜”,需要多少就分配多少,避免浪费。

在下一篇文章中,我将继续用动画的方式,深入拆解 KV Cache 的好基友 PagedAttention 的工作原理,带你从源码层面剖析 vLLM 如何用这一技术解决显存瓶颈。敬请期待!

参考资料:

Transformers KV Caching Explained

游凯超(vLLM 核心开发者)知乎上的《一文读懂 KV Cache》

EFFICIENTLY SCALING TRANSFORMER INFERENCE

GiantPandaCV
专注于大语言模型,CUDA,编译器,工程部署和优化等多个方向技术分享。我们不仅坚持原创,也规范转载知乎大佬们的高质量博文。希望在传播知识、分享知识的同时能够启发你,在人类通往AGI的道路上互相帮助(・ω\x26lt;)☆
 最新文章