掌握Transformer之KV Cahce

文摘   科技   2024-09-22 13:29   江苏  
点击蓝字
 
关注我们










01


引言



ChatGPT等平台的成功在很大程度上归功于许多研究人员和工程师为提高大型语言模型(LLM)推理速度所做的努力。用户需要实时的人工智能交互--模仿自然对话的快速反应。


ChatGPT采用的是Transformer结构中的Decoder部分,该结构在推理阶段用KV Cache技术来加速推理经有一段时间了,但也许大家需要了解它到底是什么,以及该技术所带来的巨大推理速度的提升


闲话少说,我们直接开始吧!






02


基于点积的注意力


如下图所示,Key和Value在Transformer中主要用于计算基于点积的注意力得分

基于点积的注意力在Transformer中的应用


这里需要注意的一点为:

KV Cache发生在多个Token生成的步骤中,并且只发生在解码器中(即在 GPT等仅解码器模型中,或在 T5 等编码器-解码器模型的解码器部分)。BERT 等模型不是生成式模型,因此没有 KV 缓存。

解码器以自回归方式工作,如以下 GPT-2 文本生成示例所示。

在解码器的自回归生成过程中,给定输入后,模型会预测下一个标记,然后在下一步中综合输入和上一步输出进行下一步的预测。

这种自回归行为会重复一些计算,我们可以通过放大解码器中基于掩码点积注意力的计算过程来更好的理解这一点。

解码器中基于点积注意力计算的逐步可视化


03


KV-Cache的引入


由于解码器是具有因果关系的(即一个Token的注意力只取决于其前面的Token),因此在每一次的生成中,我们都要重新计算前面相同Token的注意力,而实际上我们只想计算新Token的注意力。
这就是KV-Cache发挥作用的地方。通过缓存之前的Key和Value,我们可以只计算新Token的Attention。

有 KV Cache和无 KV Cache的点积注意力对比

为什么这一优化非常重要?如上图所示,使用 KV Cache获得的矩阵更小,从而加快了矩阵乘法的速度。唯一的缺点是需要更多的 GPU VRAM(或 CPU RAM,如果不使用 GPU)来缓存Key和Value。





04


 性能对比


让我们使用Transformer来比较有 KV Cache和无 KV Cache的 GPT-2 的推理速度。

import numpy as npimport timeimport torchfrom transformers import AutoModelForCausalLM, AutoTokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"tokenizer = AutoTokenizer.from_pretrained("gpt2")model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
for use_cache in (True, False): times = [] for _ in range(10): # measuring 10 generations start = time.time() model.generate(**tokenizer("What is KV caching?", return_tensors="pt").to(device), use_cache=use_cache, max_new_tokens=1000) times.append(time.time() - start)  print(f"{'with' if use_cache else 'without'}\n KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")


在使用 Tesla T4 GPU 的谷歌 Colab 笔记本上,报告了生成 1000 个新Token的平均耗时:

with KV caching: 11.885 +- 0.272 seconds without KV caching: 56.197 +- 1.855 seconds
推理速度的差异非常大,而 GPU VRAM 的使用量却可以忽略不计,正如以下链接所报告的那样

网址:https://discuss.huggingface.co/t/generate-using-k-v-cache-is-faster-but-no-difference-to-memory-usage/31272

因此请确保在Transformer Decoder模型变种中使用 KV 缓存!







点击上方小卡片关注我




添加个人微信,进专属粉丝群!



AI算法之道
一个专注于深度学习、计算机视觉和自动驾驶感知算法的公众号,涵盖视觉CV、神经网络、模式识别等方面,包括相应的硬件和软件配置,以及开源项目等。
 最新文章