大模型轻量化解读系列 (五):QuaRot:基于 Rotation 的 4-bit LLM 量化

科技   2024-12-30 22:00   广东  
↑ 点击蓝字 关注极市平台
作者丨科技猛兽
编辑丨极市平台

极市导读

 

4-bit 量化 LLaMA2-70B 模型的 WikiText-2 困惑度损失最多为 0.47,并保留 99% 的 Zero-Shot 性能。QuaRot 可以使用 Round-To-Nearest (RTN) 量化提供无损 6-bit 和 8-bit LLaMA-2 模型,且无需任何校准数据。>>加入极市CV技术交流群,走在计算机视觉的最前沿

太长不看版

采用旋转矩阵解决 4-bit LLM 量化困难。

量化方案:

Weight: Per-channel Symmetric,Activation:Per-token Symmetric

QuaRot 是一种基于旋转 (Rotation) 的新量化方案,它能够以 4-bit 端到端量化 LLM,包括所有的 weight、activation 和 KV cache。QuaRot 通过旋转的方式,在不改变输出的情况下从隐藏状态中去除异常值,进而使量化更容易。这种模式应用于 LLM 的 hidden state,FFN 的激活值,attention 和 KV cache。QuaRot 量化之后,所有的矩阵乘法都使用 4-bit 执行,没有任何 channel 保持更高的精度。

4-bit 量化 LLaMA2-70B 模型的 WikiText-2 困惑度损失最多为 0.47,并保留 99% 的 Zero-Shot 性能。QuaRot 可以使用 Round-To-Nearest (RTN) 量化提供无损 6-bit 和 8-bit LLaMA-2 模型,且无需任何校准数据。

下面是对本文的详细介绍。

本文目录

1 QuaRot:基于 Rotation 的 4-bit LLM 量化
(来自 ETH Zurich,EPFL,微软)
1 QuaRot 论文解读
1.1 QuaRot 研究背景
1.2 正交矩阵,旋转矩阵和 Hadamard 矩阵简单介绍
1.3 QuaRot 阶段 1:Hadamard 变换
1.4 QuaRot 阶段 2:权重和激活值的量化
1.5 实验设置
1.6 精度结果

1 QuaRot:基于 Rotation 的 4-bit LLM 量化

论文名称:QuaRot: Outlier-Free 4-Bit Inference in Rotated LLMs (NeurIPS 2024)

论文地址:

http://arxiv.org/pdf/2404.00456

代码链接:

http://github.com/spcl/QuaRot

1.1 QuaRot 研究背景

大型语言模型 (LLM) 变得越来越重要。然而,实际使用,即推理时,需要大量的计算、显存和能量,特别是在 LLM 的预填充 (Prefill) 阶段,模型应该处理大的 Prompt 并在每一层 cache 它们。量化是改进内存和计算问题的最重要技术之一,通过在前向传递期间以较低的精度保持数据类型。

由于预填充阶段是 Compute Bound 的,联合量化旨在减少参数和 KV cache 的精度 (使得显存的使用量较低) 以及 activation 为的精度。但是量化 activation 很困难,因为它们具有较大的异常值 (示例见图1),使得激活 activation 比量化 weight 更难,尤其是对于 4-bit 的情况。之前的工作[1][2]依赖于使用校准集来表征异常值特征,并将它们保持在更高的精度进行推理。

本文通过使用随机 Hadamard 变换旋转模型的输入来解决异常值特征的问题。作者借助计算不变性思想[3]做到这一点,并将 Hadamard 变换融合到权重矩阵中,从而得到一个没有异常值特征的等价网络。这使得 weight、activation 和 KV cache 能够在精度损失最小的情况下量化为 4-bit。主要贡献是:

  • 本文表明,随机 Hadamard 变换可以应用于权重矩阵,而无需额外的模型修改。反过来,这完全消除了异常值特征,并使激活易于量化,而无需更改模型的输出。这可以看作是在结构化修剪的背景下在 SliceGPT 中提出的计算不变性思想的扩展。
  • 本文扩展了这种方法,将在线 Hadamard 变换应用于 Attention 以消除 Key 和 Value 中的异常值特征,实现 KV cache 的量化。
  • 使用上述修改,QuaRot 实现了所有 weight、activation 和 KV cache 的 4-bit 量化。作者为 QuaRot 提供了有效的 Kernel 支持:在 LLaMA2-70B 模型上,QuaRot 实现了高达 3.33× 的 prefill 加速,解码阶段节省 3.89 倍的显存,WikiText-2 困惑度损失 0.47。QuaRot 保留了 Zero-Shot 任务的 99% 的精度,本文也实现了 6-bit 和 8-bit 量化通过简单的 RTN 量化无损。

1.2 正交矩阵,旋转矩阵和 Hadamard 矩阵简单介绍

正交矩阵   是一个方阵,使得 。在这项工作中,作者只考虑正交矩阵。

旋转矩阵是一个正交矩阵。

Hadamard 矩阵是一个正交矩阵,其内部值为

Walsh-Hadamard 矩阵是大小为 的方阵,其中:

遵循[4]的做法,作者利用随机 Hadamard 矩阵, 方便。设 是一个包含从 中随机抽取的向量。很容易看出 也是一个正交矩阵。

矩阵的不一致性

QuIP[5]在 Weight-only LLM 量化中引入了不一致性处理。定义 -incoherent 的权重矩阵 W 为:

其中, 是矩阵的元素最大值, 是元素的数量。具有高不一致性的权重矩阵很难量化:最大元素相对于平均元素的大小是异常值。QuIP 表明,将权重矩阵左乘,右乘正交矩阵可以减少不一致性,使矩阵更容易量化。

QuaRot 用了类似的技术,将权重矩阵乘以正交矩阵,来改善不一致性。还对 activation 应用了针对不一致性的处理,从而改善了 weight 和 activation 的量化。图 1 中显示了将不一致性处理应用于 LLaMA-2 对于 activation 的影响。

图1:第 10 层中,LLaMA2--7B 模型输入到 FFN 块的激活分布。左:使用从 Hugging Face 下载的默认配置。右图:在使用 QuaRot 处理后。处理后的分布没有异常值,更适合量化

Transformer 架构

图 2 和图 3 (更详细的版本) 介绍了本文针对的 LLM 架构,网络是 "pre-norm" 的,其中每个 Block 前面都有一个 LayerNorm 或 RMSNorm 操作。尽管本文的方法可以直接应用于 MLP 架构,作者还是假设 FFN 使用门控架构,如 LLaMA-2。

图2:LM 中使用的门控前馈网络,包括 RMSNorm
图3:LM 中使用的 Self-attention 的流程,包括 RMSNorm。实线箭头表示训练期间的流量、每个 token 的填充和推理。虚线箭头显示访问和从生成时使用的 KV cache

计算不变性

SliceGPT[3]的计算不变性定理 (Theorem1) 表明,可以使用正交矩阵转换 Transformer 中的 weight 和 Block 之间的 activation,而不更改模型输出。

主要思想是:如果 是一个权重矩阵,出现在 Block 的左侧 (即图1中的 或图2中的 ), 那么可以将左侧乘以正交矩阵 , 并通过将输出矩阵 ( )乘以 来消除这种影响。

上面这个计算不变性的思想尽管两个 Block 之间有 RMSNorm 也不影响。这是因为从概念上讲, RMSNorm 将 activation 除以其范数, 并将正交矩阵 应用于 activation 则不会影响范数。有:

这里假设的是 RMSNorm 会作用于输入 的每一行:

那么 RMS 的这个性质就意味着一件事:给输出矩阵乘以 使得线性层的输出本该是 但变为了 。这个 被归一化之后送入下一个 Block。此 Block 的输入权重现在是 。因此, 这个线性层可以输出原始的 activation。

QuaRot 包括 2 个阶段。

  1. 在第 1 阶段,对模型 weight 进行操作 (以全精度),并将两个额外的 Hadamard 操作插入到模型的前向传递中。
  2. 在第 2 阶段,使用一些现有方法对 weight 进行量化,并将量化操作添加到前向传递中,以实现 activation (和cache) 的在线量化。默认情况下,使用 GPTQ[6]来量化权重,而 activation 使用简单的 RTN 方案即时量化。图 4 和图 5 显示了带有 QuaRot 修改的前向传递的更新框图。

1.3 QuaRot 阶段 1:Hadamard 变换

阶段 1a:权重调节:利用计算不变性将每个权重矩阵乘以正交矩阵

作者首先利用计算不变性将每个权重矩阵乘以正交矩阵。为了实现这一点, LayerNorm 或 RMSNorm 的线性部分被融合到相邻的权重矩阵中。图4显示了如何通过从 RMSNorm( ) 中移除缩放操作并吸收到后续权重矩阵。作者选择了一个大小与模型隐藏维度相匹配的随机 Hadamard 矩阵, 并乘在每个权重矩阵上面。在图 4 和图 5 中, 该矩阵表示为 。例如, Key 投影权重矩阵 修改为:

对于其他权重矩阵也是如此。出现在 Block 输出端的矩阵右乘Q 。

根据计算不变性定理[3],这种权重修改不会影响模型的输出。作者注意到修改后的权重类似于 QuIP#中使用的修改,减少了权重的不一致性,尽管本文的修改在运行时不需要任何额外的处理。此外,在 Transformer Block 之间传递的激活矩阵变成 。图 1 显示了这种处理的结果:处理后的激活不再包含任何异常值。

阶段 1b:FFN 输出旋转:在 FFN 中插入在线 Hadamard 操作

对矩阵权重做了左乘 的修改之后, 作者再在 FFN 中插入在线 Hadamard 操作, 如图4中的下采样投影之前的 hadamard 所示。然后这个操作也需要补偿,方式就是将另一个 Hadamard 矩阵融合到下采样投影矩阵中: 。再加上阶段 1 a 中为每个 Block 输出端右乘 , 因此下采样投影矩阵就变为:

阶段 1c:注意力值投影:对每个注意块应用额外的 Hadamard

作者对每个注意块应用额外的 Hadamard 操作。这种修改一部分在线计算,一部分融合到权重矩阵中。首先,在计算注意力时, 矩阵在每个头部内隐式相乘。

其中, 是方阵, 维度为序列长度。 是每个 head 的 Value 矩阵。上式提供了一个使用 Hadamard matrix 对 矩阵进行处理的机会。

将式 6 的修改代入式 5,计算的结果保持不变。式 6 可以等效为 对 矩阵执行单个 Kronecker 结构化乘法:

作者利用下面的特性:

当头数 和每个 head 的维数 都是 2 的幂时成立。

因为式 7 ,所以:

  • 图5 中的 先右乘 之后, 再进行了一次 hadamard heads, 即: , 相当于是给 右乘了

  • 图5 中的 先左乘 之后, 再左乘 , 相当于是左乘了 , 即

图4:QuaRot 应用于 LLaMa 风格的 FFN。RMSNorm 缩放 (α) 已被吸收到权重矩阵 ((α) 是具有 RMSNorm 参数的对角矩阵)
图5:QuaRot 应用于注意力。RMSNorm 缩放 α 被吸收到输入权重矩阵中,隐藏状态按照与 FFN 相同的方式旋转

阶段 1d:Key 的旋转:对注意块的 Key 应用额外的 Hadamard

注意力模块中的 Key 向量也被认为会受到异常值的影响[7]。因此,本文再使用 Hadamard 旋转来缓解这个问题,允许完全量化 KV cache。注意力分数计算如下:

其中, 是 Softmax 尺度通常设置为 是 Mask (比如 Causal Mask), Pos 表示位置编码。位置嵌入通常仅在第 1 层到输入之前添加, 在这种情况下 Pos 是恒等函数。然而, 最近的方法, 比如 RoPE, 直接向 Key 向量和 Query 向量添加位置信息。

可以观察到 之间的相同交互,这点和 矩阵之间的计算过程很类似。然而, Pos 的存在阻止了将 Hadamard 矩阵直接融合到 中。因此, 作者使用在线头部 Hadamard 旋转来旋转 Query 和Key。因此, Query 和 Key 矩阵的计算更改如下:

由于 Query 和 Key 都被旋转,最终的注意力分数 保持不变。

总体而言,本文前向传播的修改,包括插入特殊的 Hadamard 变换并对权重进行调整不会改变模型的前向传播过程。效果是 Block 之间的 activation 乘以 Hadamard 矩阵,Block 内的 activation 使用 Hadamard 变换在线处理,通过相应的权重矩阵修改完成。现在可以开始量化 weight 和 activation 了。

1.4 QuaRot 阶段 2:权重和激活值的量化

在做好 Hadamard 变换之后,就可以开始量化 weight 和 activation 了。

阶段 2a:Weight 的量化

作者应用 GPTQ[6]来量化网络的权重。在上述前向传递修改之后,可以应用任何量化方法。在随后的部分中,作者也展示了可以以牺牲一些准确性为代价应用简单的 RTN 方案。

阶段 2b:Activation 的量化

在 activation 的量化期间,保留 RMSNorm (without scaling) 的精度为 FP32,使用对称 per-token 量化线性层的输入。

阶段 2c:Attention 的量化

对于更长的序列和更大的 Batch Size,注意力存在显著的显存限制。在旋转 Key 和 Value 后,可以成功地将 cache 量化为低位宽。这减少了所需的 IO 操作的数量。作者将 Query 保留为 FP16,并使用类似于 Flash Attention[8]的在线 softmax 计算。在从内存中加载一段 KV 向量后,以 FP16 精度解量化并计算点积。

1.5 实验设置

在 PyTorch 上使用 Hugging Face 实现 QuaRot。

Input 量化: 所有实验中使用 per-token 对称量化,固定 clipping ratio 为 0.9。

KV cache 量化: 使用 group size 为 128 的非对称量化对 KV 缓存进行量化,clipping ratio 为 0.95。

Weight 量化: 使用 RTN 和 GPTQ,per-channel 对称量化。

使用来自 WikiText-2 训练集的 128 个样本,序列长度为 2048,作为 GPTQ 量化期间的校准集。在单个 NVIDIA A100 GPU 上,使用 QuaRot 修改 LLAMA2-70B 需要 5 分钟,并使用 GPTQ 量化模型需要 2 小时。

模型、任务和 GPU

作者在语言生成和零样本任务,在 LLaMA-2 系列上评估 QuaRot。作者实现了 CUDA kernel,使用 CUTLASS 库执行 4-bit 矩阵乘法。使用 FlashInfer 库来实现 KV cache 量化。由于针对消费者类型的 GPU,作者在 NVIDIA RTX 3090 GPU 上对所有实验进行评估。

1.6 精度结果

语言生成任务

作者首先评估了 QuaRot 在语言生成任务上的准确性。图 6 显示了使用 GPTQ 量化权重时 WikiText-2 上 LLaMA-2 模型的困惑度。作者与 4-bit SmoothQuant 和 OmniQuant 进行比较。QuaRot 在不需要任何重新训练 (比如 OmniQuant) 或更高的精度异常值特征和非对称量化 (比如 QUIK) 的情况下,最多优于所有以前的工作 0.63 的 Perplexity。作者还应用分组量化,对 weight 和 activation 应用相同数量的 group,与 Atom 进行比较。在这种情况下,QuaRot 不需要保留任何更高的精度特征和相关的操作 (比如重新排序)。QuaRot 在 7B 模型中优于 Atom 0.1 的困惑度,在 13B 模型中得到与 Atom 相同的困惑度。

图6:LLaMA-2 模型的 4-bit 量化的 WikiText-2 困惑度结果 ,序列长度为 2048

Zero-Shot 任务

作者接下来专注于在 6 个重要的零样本任务上评估 QuaRot:PIQA、WinoGrande、HellaSwag、LAMBADA (OpenAI) 和 Arc (Easy and Challenge)。使用默认参数的 LM Evaluation Harness 进行实验。图 7 显示了本文的方案在上述任务和平均分数上的精度。在 LLaMA-2 模型家族上,QuaRot 以最多 4.18% 的平均分数损失来保持精度。

图7:PIQA (PQ)、Winogrande (WG)、HellaSwag (HS)、Arc-Easy (A-e)、Arc-Challenge (A-c) 和 LAMBADA (LA) 上的 4 -bit (A4W4KV4) QuaRot 的 LLAMA-2 模型的 Zero-Shot 精度

1.7 性能分析

作者在 PyTorch 之上使用 CUDA/12.1 实现 QuaRot,并使用 CUTLASS 在 TensorCore 上执行 INT-4 矩阵乘法 (结果将保存在 INT32 累加器中)。本节中作者评估了本文 Kernel 在 NVIDIA RTX 3090 GPU 上预填充和解码步骤的性能。作者在单个 Transformer Block 上提供了所有的实验,因为用的 GPU 集群装不下整个模型。

Prefill 阶段性能增提升

对于 compute-bound 的 Prefill 阶段,作者在图 8 中展示了在 2048 的序列长度上使用 QuaRot 的加速。在 LLaMA2-7B 模型上,使用 QuaRot Kernel 获得了比 FP16 实现的版本 1.97-2.16 倍的加速。在 LLaMA2-70B 模型上获得了高达 3.33 倍的加速。注意可以通过优化 Kernel (例如将量化操作融合到 MatMul) 来提高性能结果。

Decoding 阶段显存节约

最后,作者评估了节约的显存,这是 Decoding 阶段的主要瓶颈。图 8 显示了 LLaMA-2 模型的峰值显存节省。作者提供了 LLAMA2-7B 和 LLaMA2-70B 模型的结果。在这两个模型中,在 Decoding 过程中获得了与 FP16 相比至少 3.63 倍的峰值显存节约。注意 LLaMA2--7B 模型中 KV cache 更大,因为 LLaMA2-70B 使用分组查询注意力。在 LLAMA2--7B 模型中,显存节约随着序列长度的增加而增加,显存节约高达 3.75 倍。在 LLaMA2-70B 模型上,几乎在所有情况下都节约了 3.89 倍。作者认为这些值对于整个模型来说更大,因为随着层数的增加,恒定大小对象在显存中的影响变得不那么显著。

图8:使用 NVIDIA RTX 3090 GPU 在 LLaMA-2 模型的单个 Transformer Block 上的性能。左:对于加速结果,使用不同 Batch Size,序列长度 2048 进行评估。右:使用 Batch Size 16 对不同预填充序列长度的 50 个 token 的解码期间的峰值显存节约

参考

  1. ^Atom: Low-bit quantization for efficient and accurate llm serving
  2. ^Towards end-to-end 4-bit inference on generative large language models
  3. ^abcSliceGPT: Compress Large Language Models by Deleting Rows and Columns
  4. ^QuIP#: Even Better LLM Quantization with Hadamard Incoherence and Lattice Codebooks
  5. ^Quip: 2-bit quantization of large language models with guarantees
  6. ^abGPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers
  7. ^KVQuant: Towards 10 Million Context Length LLM Inference with KV Cache Quantization
  8. ^FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness



公众号后台回复“极市直播”获取100+期极市技术直播回放+PPT

极市干货

技术专栏:多模态大模型超详细解读专栏搞懂Tranformer系列大视觉模型 (LVM) 解读扩散模型系列极市直播
技术综述:小目标检测那点事大模型面试八股含答案万字长文!人体姿态估计(HPE)入门教程

极市平台签约作者#


科技猛兽

知乎:科技猛兽


清华大学自动化系19级硕士

研究领域:AI边缘计算 (Efficient AI with Tiny Resource):专注模型压缩,搜索,量化,加速,加法网络,以及它们与其他任务的结合,更好地服务于端侧设备。


作品精选

搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了
用Pytorch轻松实现28个视觉Transformer,开源库 timm 了解一下!(附代码解读)
轻量高效!清华智能计算实验室开源基于PyTorch的视频 (图片) 去模糊框架SimDeblur



投稿方式:
添加小编微信Fengcall(微信号:fengcall19),备注:姓名-投稿
△长按添加极市平台小编

觉得有用麻烦给个在看啦~  

极市平台
为计算机视觉开发者提供全流程算法开发训练平台,以及大咖技术分享、社区交流、竞赛实践等丰富的内容与服务。
 最新文章