​NeurIPS 2024 | 中科院自动化所提出MetaLA!线性模型架构的大一统

科技   2024-12-11 22:36   北京  
©PaperWeekly 原创 · 作者 | 李国齐课题组
单位 | 中国科学院自动化所

目前,各种线性复杂度模型来取代 Transformer 结构中的传统 Softmax 注意力被提出,例如线性 Transformer(LinFormer)[1][2],状态空间模型(SSM)[3][4] 和线性 RNN(LinRNN)[5][6][7]


然而,这些线性模型的最佳设计仍然是一个悬而未决的问题。在这项工作中,本研究试图从理论角度找到 Softmax 注意力的最佳线性近似来回答这个问题。


我们首先将现有的线性复杂度模型统一为线性注意力形式,然后确定最佳线性注意力设计的三个条件:i)动态记忆能力;ii)静态近似能力;iii)最小参数近似。


我们发现当前的线性模型都不能满足所有三个条件,导致性能不佳。相反,我们提出了元线性注意力(MetaLA)作为满足这些条件的解决方案。我们在多查询联想回忆 (MQAR) 任务、语言建模、图像分类和长距离依赖(LRA)基准上的实验表明,MetaLA 比现有的线性模型更有效。


论文链接:
https://arxiv.org/abs/2411.10741

代码链接:

https://github.com/BICLab/MetaLA


背景

Transformer 模型凭借高效的并行训练能力和卓越的性能,在深度学习应用中表现出色。然而传统的 Softmax 注意力机制在训练时,计算复杂度随输入长度呈二次增长;在推理时每个时间步和隐状态交互,时间和空间复杂度成线性增长。因此,Transformer 面临着计算成本过高的问题。


为此,当前研究主要致力于开发线性注意力模型,如 LinFormer、SSM(状态空间模型)和 LinRNN,试图达到训练时以线性复杂度替代 Softmax 注意力、推理时达到常数级别的时间和空间复杂度。然而这些模型在当前流行的功能和表现上仍与 Softmax 注意力存在差距。



主要贡献

近日,研究者提出了一种新型线性注意力模块——MetaLA,能够实现对 Softmax 注意力映射的最优线性逼近。MetaLA 的设计突破了传统线性模型的限制,统一了现有线性模型的结构,具有以下显著贡献:


a)统一框架下的线性模型解读


本课题组团队首次将 LinFormer、SSM 和 LinRNN 等线性模型抽象为统一的线性注意力形式,从模型的隐藏状态大小、隐藏状态维护方式以及参数映射策略等角度解析其关键设计。这种统一视角不仅帮助理解现有模型的功能差异,还从理论上揭示了它们在实现 Softmax 注意力功能方面的潜力和不足。


b)定义线性逼近的必要条件


为评估线性模型能否逼近 Softmax 注意力映射,研究者提出了两个必要条件:动态记忆和静态逼近。动态记忆要求线性注意力模型能存储最重要的信息并忘记无关信息,而静态逼近则要求模型能够拟合任意 Softmax 注意力映射。


基于这一理论分析,研究者指出,现有诸如 TransNormer、RetNet、RWKV-4、LRU、HGRN 等模型未能满足必要条件,而某些模型(如 Mamba 和 GLA)因使用多余的 Key 矩阵而非最优参数化方案。


c)最佳性能


实验表明,基于 MetaLA 的 Transformer 在关联记忆、语言建模、长序列建模和图像分类等任务上均取得了显著性能提升。同时,研究团队通过消融实验验证了 MetaLA 中各改进的有效性,并进一步探讨了如何提升线性注意力的逼近能力以及线性注意力的容量上限问题。



方法

3.1 一种通用的线性模型形式

观察现有的 LinFormer、SSM 和 LinRNN 模型,研究发现它们的推理过程可以统一为维护隐藏状态的递归形式。Softmax 注意力通过 KV 缓存实现无限隐藏状态,而线性模型通过限制隐藏状态实现对 Softmax 功能的逼近。


模型具有如下的串行形式,其中,不同线性模型的不同点主要在于  q,k,v,alpha 等向量信息等生成运算过程不同:

针对线性模型,其不仅有统一的串行形式,也存在统一的并行形式:

基于这样的串行形式和并行形式,我们可以绘制出如下的通用线性模型流程图:

▲ 图1:线性模型的通用形式的信息处理流程图。该图可视化了上述的统一公式在并行、串行视角下的信息处理流程


上述公式和流程图统一了 LinFormer [1][2]、SSM [3][4] 和 LinRNN [5][6][7] 等模型的递归形式,为理论分析奠定了基础。


在上文给出了统一的线性模型框架的基础上,我们指出,不同的线性模型,例如线性 Transformer (LinFormer),线性 RNN(LinRNN),状态空间模型(SSM)都是通用框架下的特例。而不同的具体模型,其差异体现在 q,k,v,alpha 等向量的生成过程,以及所维护的状态维度上。

▲ 表1:LinFormer [1][2],LinRNN [5][6],SSM [3][4] 均为线性模型的特例

3.2 逼近softmax的注意力图的最优线性模型的必要条件

研究者从理论上定义了实现 Softmax 注意力最优线性逼近的必要条件:


  • 线性复杂度:训练的时间和空间复杂度需为 O(n),推理复杂度为 O(1)。
  • 动态记忆能力:通过有限隐藏状态动态存储重要信息,忘记不重要信息。
  • 静态逼近能力:能逼近任意 Softmax 注意力映射。
  • 最优参数化:在满足上述条件的前提下,使用最少的参数。

▲ 表2:不同模型对最优理论的满足性分析,可以看,GLA [1],RWKV [5],Transformers [8],等模型都不能满足分析中的全部条件

理论分析表明,动态衰减和 Query 矩阵是实现上述条件的关键。而 Key 矩阵在理论上并非必要,可通过优化动态衰减机制替代。

3.3.基于通用形式的MetaLA架构设计

去除 Key 矩阵,用动态衰减 alpha 替代 Key 矩阵,减少参数冗余并增强动态记忆能力。一方面,这一机制有利于更好的参数调配。另一方面,这一机制保证了我们对最优逼近分析得来的必要条件(动态记忆能力和静态逼近能力)。


引入自增强机制和短卷积增强 Token 对自身的注意力,避免注意力稀释问题,提高当前 Token 的信息表达能力,强化局部特征建模能力。

▲ 图2:基于最优理论,为线性模型找到了最佳的设计方案



实验

我们在多查询联想回忆 (MQAR) 任务、语言建模、图像分类和长距离依赖 (LRA) 基准上的实验表明,MetaLA 比现有的线性模型更有效。


多查询联想回忆(MQAR)旨在测试模型在多查询场景下的联想记忆能力和信息检索效率。该实验的核心任务是让模型记住一系列键值对(Key-Value Pair),并在稍后根据给定的查询键返回正确的值。


通过这一实验,我们能够了解模型在处理动态记忆和高效查询中的表现,以及其是否能够成功应对多次查询的累积负担。

▲ 图3:MetaLA 模型中在多查询联想回忆任务中的性能,反映模型记忆能力


语言建模(CommonsenseReasoning)评估模型对日常生活中常识性知识的掌握程度以及基于常识进行推理的能力。


实验使用了如 Winograd Schema Challenge、HellaSwag 等常用基准数据集,要求模型推理隐含信息或基于有限背景知识做出决策。我们重点分析了模型在处理常识性推理能力。该实验有助于衡量模型在广泛实际应用场景中的泛化和推理能力。
▲ 表3:模型在 CommonSense Reasoning 上的性能对比,反映语言建模能力,模型性能显著高于 Pythia [8],Gated Linear Attention [1],Mamba [3] 等主流模型

图像分类(ImageNet-1k)实验 是经典的视觉分类评测任务,旨在验证模型对图像内容的识别能力。我们使用了包含 1000 个类别的大规模 ImageNet 数据集,测试模型的 Top-1 分类准确率。

▲ 表4:模型在ImageNet上的性能对比,反映图像建模能力


长序列任务 Long Range Arena(LRA)实验 旨在评估模型在处理长距离依赖关系和复杂结构数据方面的性能。LRA 基准任务包括文本分类、结构预测和图形匹配等,挑战模型在长文本或大规模图形数据上的捕捉能力。我们特别关注模型在长序列中的局部信息整合和全局依赖建模能力。

▲ 表5:模型在长距离依赖任务上的建模能力,反映模型对长序列关系的捕捉能力


总结

MetaLA 模块通过去除冗余的 Key 矩阵、引入自增强机制以及增强局部交互的短卷积设计,成功实现了对 Softmax 注意力的最优线性逼近。其创新性地统一了现有线性注意力模型的通用形式,并满足动态记忆和静态逼近的必要条件,同时有效降低了参数复杂度。


这一设计为线性注意力模型在长序列建模任务中的应用提供了全新思路,并显著提升了计算效率和模型性能。


参考文献

[1] Yang S, Wang B, Shen Y, et al. Gated linear attention transformers with hardware-efficient training[J]. arXiv preprint arXiv:2312.06635, 2023.
[2] Qin Z, Li D, Sun W, et al. Scaling transnormer to 175 billion parameters[J]. arXiv preprint arXiv:2307.14995, 2023.
[3] Gu A, Dao T. Mamba: Linear-time sequence modeling with selective state spaces[J]. arXiv preprint arXiv:2312.00752, 2023.
[4] Smith J T H, Warrington A, Linderman S W. Simplified state space layers for sequence modeling[J]. arXiv preprint arXiv:2208.04933, 2022.
[5] Qin Z, Yang S, Sun W, et al. Hgrn2: Gated linear rnns with state expansion[J]. arXiv preprint arXiv:2404.07904, 2024.
[6] Peng B, Alcaide E, Anthony Q, et al. Rwkv: Reinventing rnns for the transformer era[J]. arXiv preprint arXiv:2305.13048, 2023.
[7] Katharopoulos A, Vyas A, Pappas N, et al. Transformers are rnns: Fast autoregressive transformers with linear attention[C]//International conference on machine learning. PMLR, 2020: 5156-5165.

[8] Biderman S, Schoelkopf H, Anthony Q G, et al. Pythia: A suite for analyzing large language models across training and scaling[C]//International Conference on Machine Learning. PMLR, 2023: 2397-2430.


更多阅读



#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编



🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧


·
·
·

PaperWeekly
PaperWeekly是一个推荐、解读、讨论和报道人工智能前沿论文成果的学术平台,致力于让国内外优秀科研工作得到更为广泛的传播和认可。社区:http://paperweek.ly | 微博:@PaperWeekly
 最新文章