本文主要是关于mamba论文的详解~
论文:Mamba: Linear-Time Sequence Modeling with Selective State Spaces
论文地址:https://arxiv.org/ftp/arxiv/papers/2312/2312.00752.pdf
代码:state-spaces/mamba (github.com)
Demo:state-spaces (State Space Models) (huggingface.co)
概述
Mamba 是一种新的状态空间模型架构,适用于信息密集型数据,例如语言建模。它基于结构化的状态空间模型,具有高效的硬件感知设计和实现。
Mamba是对长数据序列进行建模的新型神经网络.这些是新的选择性状态空间模型(SSM),旨在克服传统序列模型(尤其是Transformers)的局限性。该模型是循环神经网络 (RNN) 和卷积神经网络 (CNN) 的组合,灵感来自经典状态空间模型。
1.介绍
Mamba 根据输入专注于或忽略特定信息。它根据输入参数化选择性状态空间模型 (SSM) 权重,允许模型过滤掉不相关的信息并无限期地保留相关数据。
Mamba 还使用硬件感知算法以递归方式而不是卷积来计算模型。这种方法比传统方法更快、更高效,因为它不会实现拉伸状态,并避免了 GPU 内存层之间的 I/O 访问。
能够处理长序列
传统的转换器模型存在计算复杂度随着序列长度的增加而以平方形式增加的问题。在处理长序列时,这是低效且资源密集型的。Mamba 解决了这个问题,在序列的长度上线性缩放。因此,曼巴蛇可以有效地处理长序列,并具有重要的应用潜力,特别是在语言、音频和基因组学等领域。
计算效率和速度
与 Transformer 相比,Mamba 具有更快的推理速率和更低的内存要求。这意味着 Mamba 在实际应用中效率更高,并节省了训练和推理大规模模型所需的计算资源。
选择性状态空间
Mamba 根据输入对 SSM 参数进行参数化。这允许模型过滤掉不相关的信息,并无限期地保留它需要的信息。这种选择机制允许曼巴只关注相关数据,从而提高数据处理效率。
硬件感知算法
Mamba 使用一种硬件感知算法,该算法以递归方式而不是卷积进行计算。这会阻止 GPU 内存层之间的 IO 访问,并且不会实现扩展状态。因此,无论是在理论上(与序列长度线性缩放)还是在现代硬件上(例如,在 A100 GPU 上速度提高 3 倍),这种实现都比以前的方法更快。
简化架构
Mamba 将之前的 SSM 架构与 Transformer 的 MLP 模块组合成一个模块,提供更简单、更高效的架构。这使得 Mamba 更易于实现和扩展,适用于广泛的应用。
2.状态空间模型
S4基本结构,如下图。
3.选择性状态空间
SSM 是一种模型,旨在对序列数据(例如,随时间变化的数据)进行建模。这些模型结合了传统递归神经网络 (RNN) 和卷积神经网络 (CNN) 的特征,并受到经典状态空间模型的启发。
SSM能够处理长序列,并可应用于各种类型的序列数据。这些功能可以与各种架构相结合,以应用于新形式的序列建模任务。
Mamba的构架
论文提出了2种结构S4和S6,ABCD是可学习的参数,相比与传统的S4,S6中允许A,B,C矩阵具有选择性,也就是依赖于上下文,mamba为O(1),而Transformer为O(n),
变量x和y是具有维度 B(表示批处理大小)、L(表示序列长度)和 D(表示维度大小)的张量。N 被选为任意值,用于确定后续张量的大小。Δ 是我们用于从当前状态 h 过渡到下一个状态 h' 的张量。Sb 、Sc和SΔ 是激活函数:
Sb(x)=LinearN(x)
Sc(x)=LinearN(x)
Sd(x)=BroadcastD(Linear1(x)) Td = Softplus,
其中LinearN到N维的线性映射,离散化是将矩阵或张量从连续时间更改为离散时间。
4.实验评估
作者在不同的章节里对于不同任务进行实验比对,其中包含
- 对于2项合成任务的实验
- 语言模型预训练(缩放规律)和零点下游评估,
- DNA 序列预训练 DNA 序列预训练和长序列分类任务的微调。
- 音频波形预训练和自回归生成语音片段的质量。
- Mamba 在训练和推理时的计算效率。
5.结论
论文为结构化状态空间模型引入了一种选择机制,使其能够执行上下文相关的推理,同时按序列长度线性扩展。当把 Mamba 纳入一个简单的无注意力架构时在不同的领域都取得了最先进的结果,其性能达到或超过了Transformer模型的性能。作者认为选择性状态空间模型可以得到广泛应用,它可以为不同领域建立基础模型,尤其是在需要长语境的新兴模式中,如基因学、音频和视频等等。
疑问与解答(Q&A)
Q1:为什么对于每个输入的token有不同矩阵,就实现了内容感知(SSM的选择性)?
理解为矩阵的乘法本质上是一种对信息加权,当我们从输入中学习到某种模式的时候,矩阵就会变得有意义,这种意义在计算时通过加权就实现了对信息的感知。
Q2:为什么较小的步长会忽略特定单词,关注先前的上下文;较大的步长会更多的关注输入单词而不是上下文?
原文中对于步长的解释,其中有一个很重要的矩阵A,该矩阵捕获先前状态构建新状态。离散化这一节中的公式表明,离散化后的A与步长为幂乘关系,步长较小可以理解为矩阵中的零少,上下文信息保留的更多,因此缺少对当前信息的关注。反之亦然。
Q3:Δ是在学习中自动学习到类似于Gate的操作吗?完全是由模型自己学习到的?
Delta 可以被认为是如何将 A、B、C 分成离散的部分,也是可以学习的参数。
参考文献
[1]Mamba Explained (thegradient.pub)
[2]Mamba: The Easy Way (jackcook.com)
[3]Mamba architecture : A Leap Forward in Sequence Modeling | by Puneet Hegde | Medium
[4]Mamba Simplified - Part 2 - S4 and Mamba (premai.io)
[5]The Annotated S4 (srush.github.io)
[6]Structured State Spaces: Combining Continuous-Time, Recurrent, and Convolutional Models · Hazy Research (stanford.edu)
[7]mamba-notes/Mamba_Slides.pdf at main · hkproj/mamba-notes (github.com)
[8]https://newsletter.maartengrootendorst.com/p/a-visual-guide-to-mamba-and-state
附录详解
A1:RNN中的离散方法
Euler method
Zero-Order Hold(ZOH)
A2:HiPPO
u是原信号,x是压缩后的信号。给定一个持续增长的u,HiPPO允许online update压缩的x。如果使用一个64unit的polynomial压缩器.
HiPPO在低阶信号上work后,我们希望将它扩展到高阶信号上。阶数越高,与LLM越相似,工作的价值就越大。
想到了非常精妙的一个方法:不考虑input x 到state u ,而是直接从state x 到output u 。这样一来,这里的 Cx就是state x 的线性组合,
而 D就是skip connection,是绕开state x,直接从input u到输出 y 的一个连接:
前定义了 x′ (下一时刻的 x )来将input u记忆成state,现在又定义了 y 来将state x 线性组合成一个输出。这就是S4。
A3:SSM详解
SSM 是一种模型,旨在对序列数据(例如,随时间变化的数据)进行建模。这些模型结合了传统递归神经网络 (RNN) 和卷积神经网络 (CNN) 的特征,并受到经典状态空间模型的启发。
状态空间
举例来说:在下图中
You表示当前位置,也就是当前状态
紫色表示下一步去哪里,也就是可能的未来状态
蓝色表示因为什么变化会带You进入下一状态
使用状态空间模型,通过使用方程和矩阵来描述这种行为。根据描述状态的变量,画出入口到出口的距离,也可以称为状态向量。
状态空间模型
SSM 是用于描述这些状态表示的模型,并根据某些输入预测它们的下一个状态。
传统上,在时间t,SSM:
映射输入序列x(t) — (例如,在矩阵中向左和向下移动)
潜在状态表示h(t) — (例如,出口距离和 x/y 坐标)
并得出预测输出序列y(t) —(例如,再次向左移动以更快到达出口)
但是,它不使用离散序列(例如向左移动一次),而是以连续序列作为输入并预测输出序列。
SSM 假设动态系统(例如在 3D 空间中移动的物体)可以通过两个方程根据其在时间t的状态进行预测。
通过求解这些方程,我们假设可以发现统计原理,根据观察到的数据(输入序列和先前状态)预测系统的状态。
其目标是找到这个状态表示h(t),以便我们可以从输入转到输出序列。
这两个方程旨在根据观察到的数据预测系统的状态。由于输入预计是连续的,因此 SSM 的主要表示是连续时间表示。
别忘了点赞👍+关注✨哟~~