Diffusion Forcing: Next-token Prediction Meets Full-Sequence Diffusion
目录
0. 摘要
1. 简介
2. 扩散序列模型
3. 扩散强制
3.1 扩散强制采样及其结果能力
4. 实验
5. 讨论
0. 摘要
本文提出了扩散强制(Diffusion Forcing),这是一种新的训练范式,其中扩散模型被训练来去噪具有独立的每 token 噪声级别的一组 token。我们通过训练一个因果的下一 token 预测模型来生成一个或多个未来 token 而不完全扩散过去的 token,将扩散强制应用于序列生成建模。我们的方法被证明结合了下一 token 预测模型的优势,如可变长度生成,以及全序列扩散(Full-sequence diffusion)模型的优势,如引导采样到理想轨迹的能力。
我们的方法提供了一系列额外的功能,如
滚动生成连续 token 的序列(如视频)超过训练地平线(horizon)的长度,其中基线会发散,
新的采样和引导方案,从扩散强制的可变地平线和因果结构中独特地受益,并在决策和规划任务中带来显著的性能提升。
除了其经验上的成功,我们的方法还被证明优化了从真实联合分布中抽取的所有子序列 token 的似然的变分下界。
1. 简介
当前的下一个 token 预测模型通过教师强制(teacher forcing)[62] 进行训练,模型基于之前 token 的真实历史预测立即下一 token。这导致了两个限制:(1)没有机制可以引导序列的采样以最小化某个目标,(2)当前的下一 token 模型在处理连续数据时容易变得不稳定。例如,当尝试自回归生成视频(与文本 [6] 或矢量量化潜在变量 [33] 相对)超过训练地平线时,帧与帧之间预测的微小错误会累积,导致模型发散。
全序列扩散似乎提供了解决方案。常用于视频生成和长时间规划中,通过扩散 token 的串联直接建模固定数量 token 的联合分布 [31, 1],其中所有 token 的噪声水平是相同的。它们提供扩散引导 [30, 16],以引导采样到理想序列,这在决策(规划)应用中是无价的 [36, 34]。它们在生成连续信号(如视频)方面表现出色 [31]。然而,全序列扩散普遍通过非因果、未掩蔽(unmasked)的结构参数化。除了限制采样到全序列(相对于可变长度生成)外,我们还显示这限制了引导和子序列生成的可能性(图 1)。此外,我们展示了通过训练一个下一 token 预测模型来实现全序列扩散的天真尝试导致了糟糕的生成,直观地说,因为早期 token 中的小不确定性需要晚期 token 中高不确定性。
2. 扩散序列模型
扩散已被广泛应用于序列建模。
[42] 使用全序列扩散模型通过引导实现可控文本生成,例如生成符合指定词性要求的文本。
[31] 训练全序列扩散模型来合成短视频,并使用滑动窗口在先前生成的帧的基础上展开更长的视频。
[36] 在离线强化学习中使用全序列扩散模型作为规划器。通过在与环境交互的轨迹数据集上进行训练,并在采样时使用分类器引导,采样出朝向选定目标获得高奖励的轨迹。
[48] 修改自回归模型以去噪在先前 token 条件下的下一 token。它使用教师强制 [62] 进行训练,并对时间序列数据进行自回归地采样下一 token。
与我们工作最相似的是 AR-Diffusion [63],它使用因果结构沿时间轴训练具有线性依赖噪声级别的全序列文本扩散。在附录 C 中,我们提供了这种方法与我们方法的详细比较。
3. 扩散强制
加噪作为部分掩蔽。回顾一下,掩蔽(mask)是遮挡一部分数据的做法,例如图像的部分区域 [26] 或序列中的时间步长 [15, 47],并训练模型恢复未掩蔽的部分。一般来说,我们可以将任何 token 集合,无论是否连续,视为按 t 索引的有序集合。通过教师强制训练下一 token 预测可以解释为在时间 t 掩蔽每个 token xt,并根据过去的 x_{1 : t−1} 进行预测。限制在序列中,我们将所有这些做法称为沿时间轴的掩蔽。我们也可以将全序列前向扩散,即逐渐向数据
中添加噪声,视为部分掩蔽的一种形式,我们称之为沿噪声轴的掩蔽。实际上,经过 K 步加噪,x^K_{1 : T}(大致)是纯白噪声,没有关于原始数据的信息。
我们建立了沿两个掩蔽轴的统一视图(见图 2)。我们将 x_{1} 表示为 token 序列,其中下标表示时间轴。如上所述,x^{k_t}_t 表示在前向扩散过程中噪声级别为 k_t 的 x_t;x^0_t = x 是不含噪声的 token,而 x^K_t 是白噪声 N(0, I)。因此,(x^{k_t}_t)_{1≤t≤T} 表示一系列噪声观测,其中每个 token 具有不同的噪声级别 k_t,这可以看作是通过加噪对每个 token 应用的部分掩蔽程度。
扩散强制:不同 token 的不同噪声级别。扩散强制(Diffusion Forcing,DF)是一个用于训练和采样任意序列长度的噪声 token (x^{k_t}_t)_{1≤t≤T} 的框架,其中关键在于每个 token 的噪声级别 k_t 可以随时间步长而变化。本文中,我们专注于时间序列数据,因此采用因果结构实例化扩散强制(其中 x^{k_t}_t 仅依赖于过去的噪声 token),我们称之为因果扩散强制(Causal Diffusion Forcing,CDF)。为简单起见,我们专注于使用基础递归神经网络(RNN)[11] 的最小实现。
带权重 θ 的 RNN 保持潜在变量 zt,捕捉过去 token 的影响,并通过具有一个递归层的动态
随时间演化。在有新的噪声观测 x^{k_t}_t 时,隐藏状态以马尔可夫方式更新
给定潜在变量 zt,观测模型 pθ(x^0_t | zt) 预测 xt;该单元具有与标准条件扩散模型相同的输入输出行为,使用条件变量 z_{t−1} 和噪声 token x^{k_t}_t 作为输入预测无噪声的 xt = x^0_t,并通过仿射重参数化间接预测噪声 ϵ^{k_t} [29]。因此,我们可以直接使用传统的扩散训练目标训练(因果)扩散强制。我们以噪声预测 ϵ_θ(z_{t−1}, x^{k_t}_t, k_t) 参数化上述单元。然后通过最小化以下损失找到参数 θ:
其中 k_{1:T} 从 [K]^T 中均匀采样,x_{1:T} 从训练数据中采样,ϵ_t ∼ N(0, σ^2_{kt}·I) 按照前向扩散过程采样(见算法 1 的伪代码)。在附录 D.3 中,我们进一步重新推导了扩散模型训练中的常用技术以用于扩散强制。
3.1 扩散强制采样及其结果能力
采样如算法 2 所示,定义为在二维 M × T 网格 g ∈ [K]^M×T 上规定噪声计划;列对应时间步 t,行按 m 索引确定噪声级别。g_{m,t} 表示行 m 的时间步 t token 的期望噪声级别。要生成长度为 T 的整个序列,初始化 token x_{1:T} 为白噪声,对应噪声级别 k = K。我们按行逐行迭代,按列从左到右去噪到 g 规定的噪声级别。在最后一行 m = 0 时,token 是干净的,即其噪声级别 g_{0,t}≡ 0。
稳定的自回归生成。对于高维、连续的序列如视频,自回归架构已知会发散,特别是在采样超过训练范围时。相反,通过使用与稍微有噪 token(对于一些小噪声级别 0 < k ≪ K)相关的前一个潜在变量更新潜在变量,扩散强制可以稳定地展开长序列,即使超过训练序列长度。我们的实验(第 4.1 节)展示了在长时间生成能力上的显著改进;附录 B.2 提供了进一步的直觉。
保持未来的不确定性。
从一序列白噪声 token [x^K_1, x^K_2, x^K_3] 开始,
我们可以完全去噪第一个 token,部分去噪第二个 token,得到 [x^0_1, x^{K/2}_2, x^K_3],
然后 [x^0_1, x^0_2, x^{K/2}_3],
最后完全去噪所有 token 为 [x^0_1, x^0_2, x^0_3]
将噪声级别解释为不确定性,这种 “之字形” 采样方案直观地编码了比远期未来更确定的近期未来。第 3.2 节描述了这如何导致更有效的序列引导。
长程(Long-horizon)引导。在算法 2 的第 10 行,可以像第 2 节那样向部分扩散的轨迹 x_{1:T} 添加引导。由于未来 token 依赖于过去,引导梯度可以从未来 token 向过去传播。扩散强制的独特优势在于,因为我们可以在不完全扩散过去的情况下扩散未来 token,所以梯度引导过去 token 的采样,从而实现长程引导,同时尊重因果关系。我们在附录 B.1 中详细说明了实现细节。如第 4.2 节所示,以这种方式进行规划明显优于引导的全序列扩散模型。
4. 实验
表 1: 用于规划的扩散强制。
(顶部) 在采样过程中,扩散强制允许每个时间步在不同的噪声计划上去噪,使我们能够在引导规划期间考虑因果不确定性。扩散强制使远期未来比近期未来更不确定,而 Diffuser [36] 在采样过程中将它们放在相同的噪声水平上。
(底部) 定量分析显示,扩散强制在各次运行中获得了最高的平均奖励。Diffuser 在执行实际生成的动作时表现不佳,需要手工设计的 PD 控制器(以星号表示)并忽略生成的动作。
图 4: 在我们的真实机器人任务中,要求一个机器人手臂使用第三个槽来交换两个水果的位置。由于水果在开始时被随机放置在不同的槽中,因此在没有初始水果位置信息的情况下,无法从单一观测中确定下一步动作。如图所示,上方的观测相同,但下方所示的预期结果可能不同——因此,该任务需要记住初始配置。此外,生成动作的相同模型还可以从单个帧合成逼真的视频。
5. 讨论
局限性。我们当前的因果实现基于一个小型 RNN,应用于更高分辨率的视频或更复杂的分布可能需要大型的 transformer 模型。我们没有研究扩散强制在互联网规模的数据集和任务上的扩展行为。
结论。在本文中,我们介绍了扩散强制,这是一种新的训练范式,其中模型被训练以去噪具有独立每个 令牌 噪声水平的 令牌 集。应用于时间序列数据时,我们展示了如何通过扩散强制训练的下一 令牌 预测模型结合了下一 令牌 模型和全序列扩散模型的优点。我们引入了新的采样和引导方案,这些方案在序列决策任务中应用时带来了显著的性能提升。未来的工作可能会研究扩散强制在时间序列生成建模以外的领域的应用,并将扩散强制扩展到更大的数据集。
论文地址:https://arxiv.org/abs/2407.01392
项目页面:https://boyuan.space/diffusion-forcing
公和众与号:EDPJ(进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群)
加 VX 群请备注学校 / 单位 + 研究方向