Learning to (Learn at Test Time): RNNs with Expressive Hidden States
进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群
目录
0. 摘要
1. 简介
2. 方法
2.1 使用 TTT 更新隐藏状态
2.2 使用 TTT 层训练网络
2.3 为 TTT 学习自监督任务
2.4 通过 mini-batch TTT 进行并行化
2.5 对偶形式
2.6 理论等价性
2.7 实现细节
3. 实验
4. 相关工作
4.1 现代 RNN
4.2 测试时学习
4.2.1 测试时训练(test-time training)
4.2.2 快速权重(Fast Weights)
4.3 学习如何学习
5. 讨论
0. 摘要
自注意力机制在长上下文中表现出色,但具有二次复杂度。现有的 RNN 层具有线性复杂度,但它们在长上下文中的表现受限于其隐藏状态的表达能力。我们提出了一种新类型的序列建模层,具有线性复杂度和具有表现力的隐藏状态。关键思想是使隐藏状态本身成为一个机器学习模型,并将更新规则设为自监督学习的一步。由于隐藏状态在测试序列上也会进行训练更新,我们称这些层为测试时训练(Test-Time Training,TTT)层。我们考虑了两种实例化:TTT-Linear 和 TTT-MLP,其隐藏状态分别是线性模型和双层 MLP。我们在 125M 到 1.3B 参数规模上评估了这些实例化,比较了强大的 Transformer 和现代 RNN Mamba。TTT-Linear 和 TTT-MLP 都与基线相匹配或超越。与 Transformer 类似,它们可以通过以更多的 token 为条件来继续降低困惑度(perplexity),而 Mamba 在 16k 上下文后无法做到。经过初步的系统优化,TTT-Linear 在 8k 上下文时已经比 Transformer 更快,并在实际运行时间上匹配 Mamba。TTT-MLP 在内存 I/O 方面仍面临挑战,但在长上下文中显示出更大的潜力,指向未来研究的一个有前途的方向。
论文地址:https://arxiv.org/abs/2407.04620
项目页面:
JAX:https://github.com/test-time-training/ttt-lm-jax
Pytorch:https://github.com/test-time-training/ttt-lm-pytorch
1. 简介
长上下文的困难是 RNN 层的固有性质所致:与自注意力机制不同,RNN 层必须将上下文压缩成固定大小的隐藏状态。作为一种压缩启发式,更新规则需要发现数千甚至数百万个 token 之间的潜在结构和关系。在本文中,我们首先观察到自监督学习可以将大量训练集压缩到模型的权重中,例如 LLM,这通常展示了其对训练数据中语义连接的深刻理解——这正是我们从压缩启发式中所需要的。
TTT 层。基于这一观察,我们设计了一种新型的序列建模层,其中隐藏状态是一个模型,而更新规则是自监督学习的一步。因为在测试序列上更新隐藏状态的过程等同于在测试时训练一个模型,这种新型层被称为测试时训练(Test-Time Training,TTT)层。
2. 方法
如图 4 所示,所有序列建模层都可以从将历史上下文存储到隐藏状态的角度来看【我们将序列建模层定义为从一个序列到另一个序列的自回归映射】。
例如,RNN 层——如 LSTM [33]、RWKV [56] 和 Mamba [26] 层——在时间轴上将上下文压缩到固定大小的状态中。这种压缩有两个结果。
一方面,将输入 token xt 映射到输出 token zt 是高效的,因为更新规则和输出规则在每个 token 上花费恒定的时间。
另一方面,RNN 层在长上下文中的表现受其隐藏状态 st 的表达能力限制。
自注意力机制也可以从上述角度来看,除了它的隐藏状态通常称为 key-value(KV)缓存,它是一个随着 t 线性增长的列表。它的更新规则只是将当前的 KV 元组附加到此列表中,而输出规则扫描所有元组直到 t 以形成注意力矩阵。隐藏状态明确地存储所有历史上下文而不进行压缩,使得自注意力机制在长上下文中比 RNN 层更有表现力。然而,扫描这个线性增长的隐藏状态也需要每个 token 线性增长的时间。
为了在长上下文中既高效又具有表现力,我们需要一个更好的压缩启发式方法。具体来说,我们需要将数千甚至数百万个 token 压缩到一个能够有效捕捉它们潜在结构和关系的隐藏状态中。这听起来可能是个艰巨的任务,但实际上我们都已经熟悉这种启发式方法。
2.1 使用 TTT 更新隐藏状态
参数学习的过程可以看作是将大量训练集压缩到模型的权重中。具体来说,我们知道通过自监督训练的模型可以捕捉到其训练数据背后的潜在结构和关系 [48]——这正是我们从压缩启发式中所需要的。
LLM 本身就是一个很好的例子。通过自监督的下一个 token 预测任务进行训练,它们的权重可以看作是对互联网现有知识的一种压缩存储形式。通过查询 LLM,我们可以从它们的权重中提取知识。更重要的是,LLM 通常展示出对现有知识之间语义连接的深刻理解,以表达新的推理内容 [1]。
我们的关键思想是使用自监督学习将历史上下文 x1, ..., xt 压缩到一个隐藏状态 st 中,通过将上下文视为无标签数据集并将状态视为模型。具体来说,隐藏状态 st 现在等同于 Wt,即模型 f 的权重,可以是线性模型、小型神经网络或其他任何模型。输出规则很简单:
直观地说,输出 token 就是由 f 在更新后的权重 Wt 下对 xt 的预测。更新规则是对某个自监督损失 ℓ 进行梯度下降的一步:
其中 η 是学习率【现在,考虑 W0 = 0。我们将在 2.7 小节讨论更复杂的初始化 W 的技术。】。从压缩的角度来看,每个启发式方法都需要决定记住或忘记哪个输入。我们的 W 记住了那些产生大梯度的输入——直观上,这些输入使 W 学到了很多。
一种选择 ℓ 的方法是重建 xt 本身。为了使学习问题具有非平凡性,我们首先将 xt 处理成一个破坏的输入 ~x_t(详见 2.3 小节),然后优化
类似于去噪自编码器 [75],f 需要发现 xt 各维度之间的相关性,以便从部分信息 ~x_t 中重建它【在过去的实验中,我们也尝试在 f(编码器)之后添加另一个模型 g(解码器),使得重建由 g◦f 而不是仅由 f 本身完成。虽然这种更复杂的设计确实略微改善了结果,但它使整体训练变得不太稳定并增加了显著的计算成本。因此,我们专注于仅使用编码器的设计】。如图 5 所示,梯度下降能够减少 ℓ,但不能将其减少到零。我们在 2.3 小节中讨论了更复杂的自监督任务的公式。
与其他 RNN 层和自注意力机制一样,我们的算法将输入序列 x1,…,xT 映射到输出序列 z1,…,zT,可以使用上述隐藏状态、更新规则和输出规则将该算法编程到序列建模层的前向传递中。即使在测试时,我们的新层仍然会为每个输入序列训练一系列不同的权重 W1,…,WT。因此,我们称其为测试时训练(TTT)层。
2.2 使用 TTT 层训练网络
TTT 层的前向传递也有相应的反向传递。除了梯度算子 ∇,我们的前向传递只包含标准的可微算子。然而,∇ 只是将一个函数映射到另一个函数,在这种情况下是将 ℓ 映射到 ∇ℓ,并且 ∇ℓ 也是由可微算子组成的。从概念上讲,对 ∇ℓ 进行反向调用意味着获取梯度的梯度——这在元学习中是一种广泛探索的技术 [51]。
TTT 层与 RNN 层和自注意力机制具有相同的接口,因此可以在任何更大的网络架构中替换,这些网络通常包含许多这样的序列建模层。使用 TTT 层训练网络的方式与训练其他语言模型(如Transformer)相同。可以使用相同的数据、配方和目标(如下一个 token 预测)来优化网络其余部分的参数。
我们将训练更大网络称为外循环,而在每个 TTT 层内训练 W 称为内循环。这两个嵌套学习问题之间的重要区别在于,内循环梯度 ∇ℓ 是相对于 f 的参数 W 计算的,而外循环梯度是相对于网络其余部分的参数计算的,我们将这些参数表示为 θ_rest。在本文中,外循环参数总是用带各种下标的 θ 表示。
到目前为止,TTT 层没有外循环参数,这与其他 RNN 层和自注意力机制形成对比。在第 2.3 小节中,我们添加外循环参数到 TTT 层以改进其自监督任务。然后在第 2.4 和 2.5 小节中,我们讨论了改进 TTT 层实际运行时间的两种方法。
2.3 为 TTT 学习自监督任务
可以说,TTT 最重要的部分是自监督任务,因为它决定了 W 将从测试序列中学习到的特征类型。那么我们应该如何设计这个任务呢?TTT 的最终目标是使 zt=f(xt;Wt) 在语言建模中表现良好。我们没有采用从人为先验中手工设计自监督任务的方法,而是采取了一种更端到端的方法——直接为最终的下一个 token 预测目标优化自监督任务。
具体来说,我们将自监督任务作为外循环的一部分来学习。从方程 3 中朴素的重建任务开始,我们添加了一些外循环参数以使这个任务可学习。在 2.1 小节中,我们没有具体说明如何将 xt 转变为~xt。一种设计是将其设为低秩投影 ~xt = θ_K·xt,其中 θ_K 是一个可学习的矩阵【下标 K 暗示了与自注意力机制的联系,这一点我们将在2.6小节中建立】。按照多视图重建的术语,θ_K·xt 被称为训练视图。
此外,也许并非 xt 中的所有信息都值得记住,因此重建标签可以是另一个低秩投影 θ_V·xt 而不是xt。这里 θ_V·xt 被称为标签视图,其中 θ_V 也是可学习的。总之,我们新的自监督损失为:
由于 W 和各种 θ 一起出现在方程 4 中,我们再次强调它们在性质上的差异。在内循环中,只有 W 被优化,因此写作 ℓ 的参数;而 θ 是这个损失函数的 “超参数”。在外循环中,θK, θV, θQ 与 θrest 一起被优化,而 W 只是一个隐藏状态,不是一个参数。图 6 用代码说明了这一差异,其中 θK 和 θV 实现为 TTT 层的参数,类似于自注意力机制中的 Key 和 Value 参数。
最后,训练视图 θKxt 的维度比 xt 少,所以我们不能再使用方程 1 中的输出规则。最简单的解决方案是创建一个测试视图 θQxt,并将我们的输出规则改为:
这个解决方案有一个额外的好处。训练视图和标签视图指定了 xt 中被压缩到 Wt 并在时间上向前传播的信息。测试视图指定了可能不同的信息,这些信息被映射到当前输出 token zt 并在网络层中向前传播,因此为自监督任务增加了更多的灵活性。
总之,所有可能选择的 θK, θQ, θV 构成了一个多视图重建任务的家族,外循环可以解释为从这个家族中选择一个任务。为了简化,我们将所有视图设计为线性投影。未来的工作可能会尝试更灵活的变换,或更大和不同的自监督任务家族。
2.4 通过 mini-batch TTT 进行并行化
到目前为止,朴素的 TTT 层在浮点操作数(FLOPs)方面已经很高效。然而,其更新规则
不能并行化,因为 Wt 在两个地方依赖于 W_(t−1):减号前和 ∇ℓ 内。由于 ∇ℓ 包含了大部分计算量,我们专注于使这第二部分并行化。
我们通过 TTT 框架中的概念来解决这个系统挑战。梯度下降(GD)有很多变种。GD 的一般更新规则可以表示为:
其中,Gt 是下降方向。需要注意的是,一旦我们计算出 Gt 对于 t=1,…,T,我们可以通过方程 6 的后半部分中的累加来获得所有的 Wt。我们的朴素更新规则,即online梯度下降,使用
对于 t = 1, . . . , T,为了并行化 Gt ,我们可以获得所有对应于 W0 的值。这种使用 Gt = ∇ℓ(W0; xt) 的变体被称为 batch 梯度下降,因为
相当于将在 x1, . . . , xt 上计算的对应于 W0 的梯度作为一个 batch。然而,在 batch 梯度下降中,Wt 从 W0 实际上只有一个梯度步骤的距离,与online梯度下降不同,在online梯度下降中,Wt 与 W0 相距 t 步。因此,batch 梯度下降具有更小的有效搜索空间,这最终会影响语言建模的性能。
我们提出的解决方案—— mini-batch 梯度下降——如图 7 所示。将 batch 大小设为 b。我们使用 Gt = ∇ℓ(W_t'; xt),其中 t' = t - mod(t, b) 是前一个 mini-batch 的最后时间步(或对于第一个 mini-batch 为 0),因此我们可以一次并行化 b 个梯度计算。经验上,b 控制速度和质量之间的权衡,如图 8 所示。在本文的所有实验中,我们选择了 b = 16。
【5:理论上,b 可能过小而导致 mini-batch 之间的方差过高,从而影响优化效果。然而,在实践中我们并没有观察到这种效应。
6:在图 8 中,我们在 TTT-Linear 1.3B 中使用单个 TTT 层,纯 PyTorch 实现。我们的融合核显著提高了时间效率,但使得清晰地分解计算 Wb 和 z1, . . . , zb 的时间变得困难。】
总结一下,有两种潜在的渠道可以将信息从 Ws 传播到 Wt(其中 s < t):累加和(cumsum) 和梯度算子。累加和(cumsum) 总是活跃的,但是梯度通道只有在 Ws 是来自前一个 mini-batch 时才活跃。梯度下降的不同变体只影响梯度通道,即梯度 Gt 相对于哪个 W 取的方向。然而,由于更新规则的自回归性质,即使选择了 Gt ,梯度步骤 Wt = W_(t-1) - ηGt 总是从 W_(t-1) 开始。
2.5 对偶形式
上述介绍的并行化对于实现墙钟时间(wall-clock time)的效率是必要的但不足够的。现代加速器专门用于矩阵乘法,即 matmuls。例如,NVIDIA A100 GPU 包含高度优化的单元称为 TensorCores,每次只能执行一个操作 - 将两个大小为 16 × 16 的矩阵相乘。如果没有足够的这些 matmuls,TensorCores 将处于空闲状态,A100 的大部分潜力将无法实现。
不幸的是,即使使用 mini-batch,迄今开发的 TTT 层仍然具有非常少的 matmuls。考虑 ℓ 的最简单情况,其中 θK = θV = θQ = I,仅针对大小为 b 的第一个 mini-batch。此外,考虑 f 作为线性模型。复制方程 3,我们在时间 t 的损失是:
如在第 2.4 小节讨论的那样,对于 t=1,…,b,我们可以并行计算:
然而,我们不能通过单个矩阵乘法计算所有 b 个 G_t。相反,我们需要 b 个外积逐个计算它们。更糟糕的是,对于每个 xt∈R^d,G_t 是 d×d,这比大 d 值的 xt 的内存占用和 I/O 成本更高。
为了解决这两个问题,我们做出一个简单的观察:实际上,我们不需要把 G_1,…,G_b 物化,只要我们能在 mini-batch 结束时计算出 W_b 和输出 token z_1,…,z_b(参见图 7)。现在我们用上述简化的 TTT-Linear 情况演示这些计算。
设 X = [x_1,…,x_b],则:
因此,可以方便地用矩阵乘法计算 W_b。为了计算 Z = [z_1,…,z_b],我们知道:
设
矩阵 Δ=[δ1,…,δb],我们可以推导出:
其中 mask 是下三角形式的掩码(类似于注意力掩码,但将无穷替换为零),而项 W0·X − X 可以从 W_b 的计算中重复使用。现在,Δ 也可以通过矩阵乘法方便地计算。将 Δ 插入方程 7,我们得到 Z = W_0·X - 2ηΔ。
我们将此过程称为对偶(dual)形式,与本小节之前的原始形式形成对比,其中 Gs 和 Ws 明确地物化。正如讨论的那样,这两种形式在输出上是等效的。原始形式和对偶形式的术语遵循先前探索类似数学形式的工作的惯例。在附录 A 中,我们展示了当 f 是具有非线性层的神经网络时,对偶形式仍然适用,只是符号更加复杂。
在 TTT mini-batch 中的原始形式的时间复杂度为 O(b × d^2)。对偶形式的时间复杂度为仅计算 Wb 时的 O(b × d^2),然后计算 z1,…,zb 的额外时间复杂度 O(b^2 × d)。与原始形式相比,对偶形式在理论复杂性上有所牺牲,以换取硬件利用率。在实践中,d 通常是几百,而 b 只选择为 16。因此,计算 z1,…,zb 的墙钟时间相对较小,如图 8 右图所示。在我们的 JAX 实现中,使用对偶形式的训练比使用原始形式快 5 倍以上。
2.6 理论等价性
在第 2.1 小节中,我们提到 f 可以是线性模型或神经网络。在第 2.4 小节中,我们还讨论了更新规则的三种变体:online GD,batch GD 和 mini-batch GD。这 2×3 个组合的每种都会引发 TTT 层的不同实例化,如图 9 所示。现在我们证明,在这些引发的实例化中,具有线性模型和 batch GD 的 TTT 层等同于线性注意力 [41],这是一种广为人知的 RNN 层。
【简而言之,线性注意力 [41] 就是没有 softmax 的自注意力。回忆一下自注意力的定义:
去掉 softmax 后,这变为
这是线性注意力的最简单形式。类似于其他 RNN 层,它可以写成递归形式,其中
是隐藏状态。对于 t=1,…,T 由于
可以通过 累加和(cumsum) 计算,线性注意力相对于 T 也具有线性复杂度】
定理 1:考虑 TTT 层,其中内循环模型为 f(x) = Wx,更新规则为 batch 梯度下降且 η=1/2,且 W0=0。然后,给定相同的输入序列 x1,…,xT,方程 5 中定义的输出规则产生的输出序列 z1,…,zT 与线性注意力相同。
证明:根据方程 4 中 ℓ 的定义,
根据方程 6 中 batch GD 的定义:
将 Wt 代入方程 5 中的输出规则,我们得到输出 token:
这是线性注意力的定义。
在表 1 中,我们首先通过改进的线性注意力实现来经验性地验证上述等价性【[41] 中线性注意力的原始公式包含一个归一化器和对 xt 的特征扩展,这些内容仍然可以包含在等效的 TTT 层中。然而,之前的工作发现这两个添加会损害性能 [58],我们在自己的实验中也验证了这一点(表 1 的第一行与第二行)。因此,我们仅构建了一个与最简单的线性注意力公式等效的 TTT 层,而不包含这两个添加内容】。然后,为了说明我们每个组件的贡献(包括将在下一小节中介绍的一些组件),我们将它们逐行添加到等效于线性注意力的 TTT 层,最终获得我们提出的 TTT-Linear 实例。从 batch 梯度下降(GD)到 mini-batch 梯度下降的变化贡献了最大的改进。
尽管图 9 中的 模型×优化器 的空间已经很大,机器学习方法远比优化模型 f 的参数 Wt 的方法要丰富得多。还有非参数学习器(nonparametric learners),例如最近邻(nearest neighbors)、支持向量机(SVM)和核岭回归(kernel ridge regression)。根据定义,非参数学习器没有参数 Wt,而是直接使用训练数据 x1, ..., xt。因此,我们使用记号 f(x;x1, ..., xt)。我们现在展示,对于一个特定的非参数学习器,得到的 TTT 层等同于自注意力。
定理 2:考虑使用 Nadaraya-Watson 估计器 [7, 12] 定义的 TTT 层:
其中 ys=θV·xs 是第 2.3 小节中讨论的标签视图,且
是具有带宽超参数 θK 和 θQ 的核函数。然后,给定相同的输入序列 x1, ..., xT,方程 5 定义的输出规则产生的输出序列 z1, ..., zT 与自注意力相同。
证明:将上述 ys 和 κ 代入方程 8 中,得到自注意力的定义。
附录 B 包含对上述 Nadaraya-Watson 估计器和核 κ 的详细解释。与定理 1 不同,定理 2 没有产生与注意力不同的实现。对于上述 TTT 层,隐藏状态是 x1, ..., xt 或类似的处理训练数据列表,更新规则将 xt 添加到列表中,输出规则使用 κ 扫描列表。在前几小节中,我们的隐藏状态被定义为Wt,更新规则是梯度步骤,输出规则是调用 f。为了统一这两种构造,我们定义了一种新的抽象称为学习器,它唯一地诱导一个 TTT 层。
类似于标准机器学习包中的定义 [54],所有学习器都需要实现两个方法:训练和预测。现在我们将诱导的 TTT 层的隐藏状态重新定义为学习器的内部存储,更新和输出规则为训练和预测方法。
在这个新的 TTT 层定义下,可以包括定理 1 中的参数学习器和定理 2 中的非参数学习器。图 10 总结了在所有序列建模层的更广泛范围内的这种 TTT 层的通用定义。这种通用定义对参数学习器有一个额外的好处:在参数学习器的内部存储中,除了 W 之外,还可以有更多的对象,例如优化器状态,这也将包含在诱导的 TTT 层的隐藏状态中。此扩展允许 TTT 层在未来的工作中使用更复杂的优化器,如 Adam[42]。
2.7 实现细节
f 的实例化。我们提出了两种 TTT 层的变体——TTT-Linear 和 TTT-MLP,它们仅在 f 的实例化上有所不同。对于 TTT-Linear,f_lin(x)=W·x,其中 W 是方阵。对于 TTT-MLP,f_MLP 有两层,类似于 Transformer 中的 MLP。具体来说,隐藏维度是输入维度的 4 倍,接着是一个 GELU 激活 [31]。为了在 TTT 过程中更好地稳定性,f 总是包含层归一化(Layer Normalization, LN)和残差连接。也就是说,f(x) = x + LN(fres(x)),其中 fres 可以是 f_lin 或 f_MLP。
可学习的 W0。TTT 初始化 W0 在所有序列之间共享,尽管后续权重 W1, ..., WT 对于每个输入序列是不同的。我们可以将 W0 作为外循环的一部分进行学习,而不是设置 W0=0。由于外循环参数总是表示为 θs 而不是 Ws,我们为其分配一个别名 θ_init=W0。实际上,与重建视图 θK, θQ, θV 相比,θ_init 增加的参数量可以忽略不计,因为其输入和输出都是低维的。经验表明,学习 W0 显著提高了训练的稳定性。
可学习的 η。学习率通常是梯度下降最重要的超参数,因此我们尝试将方程 6 中内循环的学习率 η 作为外循环的一部分进行学习。为了增加灵活性,我们使 η 成为输入token的函数(因此在时间上不同)。具体来说,我们设计了 η(x)=η_base·σ(θ_lr ·x),其中可学习向量 θ_lr 是一个外循环参数,σ 是 sigmoid 函数,标量 η_base 是基础学习率,TTT-Linear 设为 1,TTT-MLP 设为 0.1。或者,η(x) 也可以解释为 ∇ℓ 的一个门控。
主干架构。将任何 RNN 层集成到更大架构中的最干净方式是直接替换 Transformer 中的自注意力,在这种情况下被称为主干。然而,现有的 RNN 如 Mamba [26] 和 Griffin [18] 都使用与 Transformer 不同的主干。最显著的是,他们的主干在 RNN 层之前包含时间卷积,这可能有助于跨时间收集局部信息。经过 Mamba 主干的实验,我们发现它也提高了 TTT 层的困惑度,因此我们将其纳入我们提出的方法中。详见图 16(附录)。
3. 实验
主干架构。正如在第 2.7 小节中讨论的,Transformer 和 Mamba 使用不同的主干架构,TTT-Linear 和 TTT-MLP 总是使用 Mamba 主干架构,除非另有说明。作为消融研究,图 11 和图 12 包含在 Transformer 主干中的 TTT 层。当一个图中同时包含 Transformer 主干和 Mamba 主干时,我们分别用 (T) 和 (M) 表示它们。
注:为什么 TTT-MLP有相比于 TTT-Linear 明显更高的时延?在文中并未给出解释。这可能源自其使用的隐藏层:TTT-Linear 和 TTT-MLP 的隐藏状态分别是线性模型和类似于 Transformer 的两层 MLP。
如图 14 所示,相比于 TTT-Linear,TTT-MLP 在短上下文中的表现略差,但在长上下文中的表现更好。这个观察结果符合我们的预期,即作为隐藏状态的 MLP 比线性模型更具表现力。 (符合摘要中的描述:TTT-MLP 在内存 I/O 方面仍面临挑战,但在长上下文中显示出更大的潜力)
4. 相关工作
4.1 现代 RNN
Mamba 是许多结构化状态空间模型之一 [27, 21, 57, 18]。这些模型中的隐藏状态是一个向量,类似于 LSTM。对于 TTT-Linear 或 TTT-MLP,隐藏状态是一个矩阵或两个矩阵,因此更大。在图 14 中,我们发现 TTT 层可以利用其更大的隐藏状态在长上下文中压缩更多信息,其中 TTT-MLP 优于 TTT-Linear,而后者优于 Mamba。
类似于 TTT-Linear,RWKV[55, 56]、xLSTM [5] 和门控线性注意力(GLA)[79] 也具有矩阵隐藏状态,这些状态继承自线性注意力 [41]。现代 RNN 如 GLA 使用块式并行性来提高硬件效率,因此块内的 token 可以通过矩阵乘法而不是累加和(cumsum)进行处理。然而,块式并行性不会改变模型的表达能力,因为所有时间依赖关系仍然等效于累加和。
相比之下, mini-batch TTT 允许跨 mini-batch 的更复杂的时间依赖关系。每个隐藏状态 Wt 仍然通过累加和依赖于其 mini-batch 内的前一个 Ws,但也通过梯度运算符依赖于前一个 mini-batch 的 Ws。如图 8 所示, mini-batch TTT 在表达能力和硬件效率之间实现了折衷,因为较小的 batch 大小 b 在提高困惑度的同时导致更高的延迟。这种折衷是 TTT 的一个独特且重要的特性。如表 1 所示,中间 batch 大小 b=16 显著优于完全累加和的 b=T。
(2020|ICML PMLR,线性 Transformer,核函数,RNN)Transformer 是 RNN
(2021,AFT,MHA,RWKV 基础,线性内存复杂度)无注意力的 Transformer
(2023|EMNLP,RWKV(RWKV-4),Transformer,RNN,AFT,时间依赖 Softmax,线性复杂度)
(2024,RWKV-5/6,RNN,矩阵值注意力状态,数据依赖线性插值,LoRA,多语言分词器)Eagle 和 Finch
(2023,SSM,门控 MLP,选择性输入,上下文压缩)Mamba:具有选择性状态空间的线性时间序列建模
(2024|ICML,Mamba2,SSD,SSM,SMA,矩阵变换,张量收缩,张量并行)Transformer 是 SSM
(2024,Attention-Mamba,MoE 替换 MLP)Jamba:混合 Transformer-Mamba 语言模型
(2024,FLOPs分配,MoD,MoDE,top-k 路由,块丢弃)在基于 Transformer 的语言模型中动态分配计算
(2024,Infini-T,Infini-A,压缩记忆,长期记忆)使用无限注意力的高效无限上下文 Transformer
(2024|ICML PMLR,强化学习,动态决策,网络跳过,质量和计算的权衡)可切换决策:动态神经生成网络
(2024,LSTM,Transformer,指数门控,归一化器状态,多头内存混合)xLSTM:扩展的 LSTM
4.2 测试时学习
测试时学习的想法在机器学习中有着悠久的历史。这一想法的最早版本之一被称为局部学习(local learning)(Bottou和Vapnik[10]):对于每个测试输入,在做出预测之前先对其邻域进行训练。这一过程已被有效地应用于从 SVM [81] 到现代大型语言模型(LLM)[29] 的各种模型。
测试时学习的另一早期版本被称为直推学习(transductive learning)[22]。Vladimir Vapnik [74] 阐述的直推原则是 “...获取你真正需要的答案,而不是更一般的答案。” 实际的直推学习实现使用测试数据为 SVM 的边缘添加约束 [39, 17]。然而,直推学习通常需要多个测试实例才能在经验上有效,这与许多测试时训练的实例不同,后者一次只需要一个测试实例(图像、视频或自然语言序列)。
在计算机视觉中,测试时学习的想法已经应用于面部检测 [38]、目标检测 [53]、图像超分辨率 [65]和 3D 重建 [50] 等应用中。最近,同样的想法也被应用于自然语言处理领域,被称为动态评估 [44, 45]。基本方法是直接在测试序列上微调语言模型,通常以提示的形式出现。
接下来,我们详细讨论两个相关的工作方向:测试时训练和快速权重。
4.2.1 测试时训练(test-time training)
测试时训练(TTT)的核心思想是每个测试实例定义其自身的学习问题,该测试实例本身就是泛化的目标 [69]。具体来说,对于每个测试实例 x,传统做法是使用一个为所有训练实例平均优化的预测器 f 来预测 f(x)。TTT 首先定义一个由 x 定义的学习问题,然后训练一个模型 f_x(通常以 f 作为初始化)并预测 f_x(x)。
由于测试实例没有标签,因此学习问题只能通过自监督任务来制定。先前的研究表明,使用重建的 TTT 显著提高了性能,特别是在异常值上 [23]。在流式到达的(arrive in a stream)视频帧上进行测试并且 TTT 是自回归时,性能提升更加明显 [76],因为 ft 是在过去的帧 x1, ..., xt 上进行训练的。自回归连接使得 [76] 与我们的论文最为相关。
从概念上讲,我们的论文与之前工作的最大区别在于我们的重建任务是在外循环中学习的,而不是通过人类先验手工设计的。TTT 的后续工作探索了诸如机器人操作 [28] 和运动 [68] 等应用,这些应用通常需要为自监督任务设计不同的方法。
4.2.2 快速权重(Fast Weights)
快速权重的普遍思想是只在最相关的数据上更新 “快” 模型的参数,而不是在所有数据上更新 “慢” 模型 [71]。这一思想自 20 世纪 80 年代以来就存在 [32]。最相关的数据可以是测试实例本身,因此 TTT 可以被视为快速权重的特例。
先前的快速权重研究通常避免形成明确的学习问题来优化数据上的某些目标。例如,给定每个输入 x,Hebbian 学习和 Hopfield 网络 [35] 的更新规则只是添加 x·x^T(或其某种变体)[4] 到快速权重上。相反,TTT 采用了明确制定学习问题的思想,其中测试实例是泛化的目标。我们的更新规则也是一个明确的优化步骤。
快速权重程序(fast weight programmers,FWPs)的思想是用 “慢” 模型更新快速权重 [62]。我们的内循环权重 W 可以被视为 “快” 权重,而外循环权重 θ 是 “慢” 权重。因此,包含 TTT 层的网络可以被视为 FWPs 的一种特例 [43],类似于 TTT 可以被视为快速权重的一种特例。上述具有 Hebbian 更新规则的 FWP 等效于线性注意力 [60],因此也等效于带批量梯度下降的简单 TTT-Linear。
FWPs 的定义非常广泛。事实上,所有具有某种门控机制的网络,例如带有 SwiGLU 块的Transformers [63],也可以被视为 FWPs 的一种特例。最近的工作正在实验使用 FWPs进行语言建模:Irie 等人 [37] 设计了 “快” 网络,其权重作为 “慢” 网络的输出生成。Clark 等人 [16] 为 Transformer 添加了一个快速权重的最终层,其初始化作为慢权重进行训练。与现有 FWP 工作相比,我们的贡献是明确制定了更新的学习问题,这使我们能够借用诸如小批量和 LN 等学习工具。
4.3 学习如何学习
几十年来,研究人员一直认为,学习如何学习(也称为元学习或双层优化)应该是智能的一个关键组成部分 [61, 6, 70, 47]。在先前的工作中,如 [2]、[20] 和 [52],内循环每次从整个数据集学习,而不是从一个序列学习,因此外循环需要收集数据集或任务。简而言之,外循环比常规训练 “高一个层级”。由于难以收集数百万个数据集,这种外循环难以扩展。
相反,对于 TTT,每个序列本身就是一个数据集,并定义了它自己的泛化问题。内循环比常规训练“低一个层级”,所以我们的外循环只是监督学习标准问题的另一种解决方案,而不是像跨数据集泛化那样的新问题设置。如表 2 所示,我们的外循环与常规训练 “在同一个层级”。这使得我们的外循环更容易扩展。
我们的论文将监督学习重新表述为学习如何学习(learning to learn),包含两个嵌套循环。外循环的突出部分与常规训练相同。外循环的参数成为内循环的超参数。直观地说,内循环,即 TTT,是在常规训练 “下一级” 的过程。
5. 讨论
我们已经将经典的监督学习问题重新表述为测试时学习(learn at test time)。我们的表述为构建传统上称为网络架构的事物提供了一个替代的概念框架。我们在表 2 中总结了我们目前的实例化。
在这个框架内,寻找有效实例化的搜索空间非常大,而我们的论文只是迈出了一小步。幸运的是,如果我们的观点成立,那么常规训练中的启发式方法可以转移到测试时训练中,搜索也可以高效。接下来我们概述一些特别有前景的未来研究方向。
外循环参数化。有很多其他方式可以参数化多视图重建任务的家族,或者更广泛的自监督任务家族。如果我们尝试的第一个方法恰好是最好的,那将是一个很大的巧合。
系统优化。我们在第 3.3 小节中的系统优化至多是初步的,还有很多改进的空间。此外,通过时间的流水线并行可能允许我们在多个设备上一起处理数百万个 token 的长序列。
更长的上下文和更大的模型。受限于我们的学术资源,我们尚未在数百万或数十亿的上下文长度下进行训练,根据图 19,这也需要更大的模型。在更长的上下文中,TTT 层的优势应该会更加明显。
更雄心勃勃的 f 的实例化。当上下文长度变长时,f 也需要变大。对于视频任务和具身代理(embodied agents),其上下文长度可以轻松扩展到数百万或数十亿,f 可以是卷积神经网络。
多级学习。如果 f 本身是一个自注意力层,那么根据定理 2,它可以解释为现有内循环内的另一个内循环。通过这种方式,我们可以潜在地构建多层嵌套的学习问题。
为什么我们研究 TTT?首先一个更基本的问题:为什么研究 AI?对于我们中的一些人来说,AI 是探索人类智能本质的游乐场。以往的研究经常尝试用机器学习来建模人类学习,其中训练是在一个带有 i.i.d. 实例的打乱的数据集上进行,推理是在一个单独的测试集上进行。然而,人类并不自然地用 i.i.d. 实例学习,也没有训练-测试划分。我们认为人类学习与 TTT,即我们的内循环,有着更有希望的联系,其数据是一个可能非常长的序列,具有强烈的时间依赖性,任何数据片段都可以用于训练和测试。这就是为什么我们研究 TTT。