1-bit大模型还能再突破!新一代BitNet架构启用4位激活值

科技   2024-12-06 10:04   吉林  


MLNLP社区是国内外知名的机器学习与自然语言处理社区,受众覆盖国内外NLP硕博生、高校老师以及企业研究人员。
社区的愿景是促进国内外自然语言处理,机器学习学术界、产业界和广大爱好者之间的交流和进步,特别是初学者同学们的进步。
转载自 | 新智元
编辑 | alan

量化到1 bit的LLM还能再突破?
这次,他们对激活值下手了!
近日,BitNet系列的原班人马推出了新一代架构:BitNet a4.8,为1 bit大模型启用了4位激活值:
论文地址:https://arxiv.org/pdf/2411.04965
众所周知,激活值量化通常是比较难办的。
本次的BitNet a4.8采用混合量化和稀疏化策略,来减轻异常通道引入的量化误差。
简单来说就是,对注意力层和FFN层的输入采用4位量化,同时用8位整数稀疏化中间状态。
大量实验表明,BitNet a4.8在相同的训练成本下,实现了与前代BitNet b1.58相当的性能,同时因为可以吃到4位(INT4/FP4)内核的计算红利,实现了更快的推理速度。
BitNet a4.8仅激活55%的参数,并支持3 bit KV cache,进一步提升了大规模LLM部署和推理的效率。

BitNet a4.8

模型架构

模型的整体架构如图1所示,BitNet a4.8采用了与BitNet b1.58相同的布局。
作者使用BitLinear替换注意力(MHA)和前馈网络(FFN)中的线性投影,以从头开始学习1.58 bit权重。对于激活值,采用混合量化和稀疏化策略来减轻异常值维度引入的误差。
图2说明了模型大小为7B的BitNet b1.58中,每个模块输入的分布。
注意力层和FFN层的输入通常类似高斯分布,而在FFN下采样之前的激活值和注意力中的输出投影中,发现了很多异常值通道和大量接近零的条目(全精度LLM也有类似观察结果)。
如图3所示,直接将低位量化应用于这些中间状态会引入很大的量化误差。
因此,作者使用Q-Sparse的稀疏化方法,将这些中间状态保持在8位(同时消除了计算瓶颈)。
对于自注意层的输出投影,使用sparsify-then-quantize函数:
两个Q分别表示权重W和激活X的量化函数,M是掩码,根据激活X的绝对值取topK,⊙是元素乘法。
具体来说,权重量化和激活值量化函数可以表述为:
对于FFN,这里采用squared ReLU和门控线性单元(GLU)来进一步提高激活的稀疏性:
根据初步实验的结果,使用squared ReLU时,下采样输入的稀疏性超过了80%,且对性能的影响最小。
此外,作者还观察到gate + squared ReLU的输出也表现出高激活稀疏性(7B模型为67.5%)。通过首先计算gate projection,然后仅在非零通道上执行up projection,可以进一步减少推理的计算量。
相比之下,attention和FFN的输入中包含的异常值特征要少得多,可以使用absmean函数将激活值量化为4位整数:

模型训练

初始化
BitNet a4.8使用BitNet b1.58的权重开始训练,分为W1.58A8与W1.58A4两阶段。
第一阶段使用8位激活和GLU + squared ReLU训练模型;第二阶段采用上面介绍过的混合量化和稀疏化。
BitNet a4.8只需少量训练,即可快速适应4bit位宽和稀疏激活,同时性能损失可以忽略不计。
梯度近似
作者使用直通估计器(STE)对BitNet a4.8进行梯度逼近,使用混合精度训练来更新参数。
这里直接绕过了不可微函数,包括反向传播过程中的量化函数和topK稀疏函数。对于混合精度训练,保持全精度latent weight来累积参数更新。

模型量化

浮点量化提供了比基于整数的量化更宽的动态范围,这对于处理激活值的长尾分布至关重要。
研究人员将FFN下采样层的输入保留为8位整数,其他激活值使用MinMax量化器量化为FP4:
公式中E和M分别表示指数和尾数部分的位宽。这里采用E2M1格式,因为它的动态范围更大。

实验

本文将BitNet a4.8、BitNet b1.58,以及各种参数量大小的FP16精度LLaMA进行了比较。
其中的1.58 bit模型,遵循BitNet b1.58的训练方案,采用了两阶段权重衰减和学习率调度。
所有模型都使用RedPajama数据集中的100B token进行训练,以确保公平比较。
对于BitNet a4.8,作者首先使用95B token来训练8位激活值的模型。然后重用优化器状态,并使用5B token进行混合量化和稀疏化的训练。实验将topK设置为50%(attention的输出投影位置)。
作者使用lm-evaluation-harness工具包,评估模型在一系列语言任务上的zero-shot准确性,包括ARC-Easy(ARCe)、ARCChallenge(ARCc)、Hellaswag(HS)、Winogrande(WGe)和PIQA(PQ)。另外还测试了在C4数据集(测试集)上的困惑度。
主要结果
表1总结了BitNet a4.8、BitNet b1.58和FP16 LLaMA的详细测试结果。
全精度(FP16)LLaMA和BitNet b1.58之间的性能差距,随着模型大小的增长而缩小。对于7B模型,BitNet b1.58在语言模型困惑度和任务的平均准确性方面与LLaMA相当。
此外,相比于BitNet b1.58,BitNet a4.8的平均精度几乎没有损失。
表2展示了各种大小的BitNet a4.8、BitNet b1.58 和 FP16 LLaMA中每个模块的详细稀疏性(使用C4验证集上的非嵌入参数计算)。
值得注意的是,BitNet a4.8的稀疏性明显高于BitNet b1.58和LLaMA。
比如在7B模型中,BitNet a4.8的整体稀疏性达到了44.5%,只有3.4B的活跃参数。down projection层的输入显示出特别高的稀疏性,且中间状态分布以零为中心。
此外,gate projection的输出非常稀疏,导致了up projection的高稀疏性(因为只需要在从Gate中选择非零通道来执行投影)。
具体来说,对于7B BitNet a4.8,Gate和up projection的稀疏率分别为67.5%和12.0%。
表3显示了BitNet a4.8在3B和7B模型大小下,low-bit attention的详细情况。模型使用4位KV或QKV头,精度损失可忽略不计,同时KV cache可以量化为3位整数。
low-bit attention对于高效的长序列建模至关重要,它减少了KV cache的内存占用和IO,并加速了注意力计算。
在本文的实验中,作者采用RoPE后量化。使用absmax函数将QKV头直接量化为无符号整数,无需任何校准数据集。
对于3 bit KV量化,研究人员将bos token的头保留为4 bit,因为它包含更多的异常值特征。
消融实验
图4显示了700M BitNet a4.8的训练损耗曲线,比较了使用完整的INT4/FP4量化,以及本文的混合量化和稀疏化。
完整的INT4量化会导致发散,而混合架构在训练困惑度方面明显优于完整的FP4架构。
使用RedPajama数据集中25B token,来进行模型的第一阶段训练,采用absmean和MinMax量化器分别进行完整的INT4和FP4量化。
对于完整的INT4量化,由于其输入具有更大的异常值,这里设置β = 2*mean(|X|)。
接下来为1.3B BitNet a4.8的down projection层输入,设置不同的量化或激活函数。
所有模型都使用RedPajama数据集中的50B token进行第一阶段训练。为了确保公平比较,其他激活值都保留在8位。
图5显示了这些模型的训练损失曲线。Squared ReLU的训练困惑度比Swish略好,同时实现了更高的稀疏性。
此外,对down projection的输入应用FP4量化会导致性能显著下降,而将INT4激活与STE一起使用会导致发散。
参考资料:
https://arxiv.org/abs/2411.04965
https://venturebeat.com/ai/how-microsofts-next-gen-bitnet-architecture-is-turbocharging-llm-efficiency/
技术交流群邀请函

△长按添加小助手

扫描二维码添加小助手微信

请备注:姓名-学校/公司-研究方向
(如:小张-哈工大-对话系统)
即可申请加入自然语言处理/Pytorch等技术交流群

关于我们

MLNLP 社区是由国内外机器学习与自然语言处理学者联合构建的民间学术社区,目前已经发展为国内外知名的机器学习与自然语言处理社区,旨在促进机器学习,自然语言处理学术界、产业界和广大爱好者之间的进步。
社区可以为相关从业者的深造、就业及研究等方面提供开放交流平台。欢迎大家关注和加入我们。


机器学习算法与自然语言处理
关注AI前沿技术,助力AI学者进步
 最新文章