Towards a theory of learning dynamics in deep state space models
进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群
目录
0. 摘要
1. 简介
2. 傅里叶域中的状态空间模型
3. 简化的学习动态
4. 较大潜在状态大小的学习动态
5. 结论
0. 摘要
状态空间模型(State space models,SSM)在许多长序列建模任务中表现出了显著的经验性能,但对这些模型的理论理解仍然缺乏。在这项工作中,我们研究了线性 SSM 的学习动态,以理解数据的协方差结构、潜在状态大小和初始化如何影响参数在梯度下降学习过程中的演变。我们表明,关注频域中的学习动态可以在温和的假设下提供解析解(analytical solutions),并且我们建立了一维 SSM 与深度线性前馈网络动态之间的联系。最后,我们分析了潜在状态过参数化(over-parameterization)如何影响收敛时间,并描述了将我们的结果扩展到具有非线性连接的深度 SSM 的研究工作。此项工作是迈向深度状态空间模型学习动态理论的一步。
1. 简介
深度状态空间模型(SSM)已经成为长序列任务中具有竞争力和高效的构建模块 [1-13]。尽管相关工作不断增多 [14-16],但对于这些模型的学习动态的理论理解仍然很少。在这里,我们将深度线性前馈网络中的学习动态理论 [17-19] 扩展到线性 SSM 的情况。通过研究这一设置,我们展示了数据的协方差结构和模型参数化如何影响学习。正如 Saxe 等人 [18] 所示,通过研究深度线性模型获得的见解可以在温和假设下转化为非线性模型。
2. 傅里叶域中的状态空间模型
我们考虑线性时不变系统。我们从单输入、单输出离散时间状态空间模型开始,
设 u_t ∈ R 表示时间 t 的输入,x_t ∈ R 表示潜在状态,y_t ∈ R 是输出信号。SSM 由状态转移矩阵 A ∈ R^(N×N)、输入向量 B ∈ R^(N×1) 和输出向量 C ∈ R^(1×N) 参数化。
时间域中的梯度下降动态因时间递归而变得复杂,这通常需要通过时间进行反向传播。然而,在频域中,SSM 允许更简单的表示,即元素乘法。此外,时间域和频域中的学习动态在乘法常数下是等价的。以下命题总结了这些性质。图 1 A-D 展示了在离散傅里叶变换下从潜在状态递归到标量乘法的转变。
图 1:频域中 SSM 的学习动态。
A. 根据方程(1)定义的线性 SSM 展开为长度为 L 的序列。
B. 应用离散傅里叶变换,SSM 通过其频率响应 H_k 完全描述,将时间域中的递归转化为频域中的调制标量乘法。
C. 时间域中的示例输入信号。
D. 输入信号在频域中的离散傅里叶变换。
E. 即使在强假设下,方程(5)中的解析学习动态也能近似 SSM 在简单输入输出模式下的经验演变。每个子图显示单个输入输出对的频率响应演变。
F. 将理论扩展到 N 维单层 SSM,我们展示了潜在状态的过参数化如何导致更快的收敛。实线表示自动微分(automatic differentiation)产生的轨迹,虚线轨迹是从学习动态的解析解的数值模拟中获得的。
命题 1。设 U_k ∈ C 和 Y_k ∈ C 分别表示输入 u_(1:T) 和输出 y_(1:T) 的离散傅里叶变换 (DFT),其中 k = 1,…,L。对于对角动态矩阵 A=diag(a_1,…,a_N),其中对所有 i = 1,…,N,∣a_i∣<1 以确保稳定性,方程 (1) 中的 SSM 完全由其频率响应 Y_k = H_k·U_k 描述,其中 H_k ∈ C 由下式给出:
命题 2。在频域中平方损失下的梯度下降学习动态
与时间域中的动态通过序列长度的比例常数相关。
3. 简化的学习动态
我们推导了一层 SSM 在频域中的连续时间学习动态。有关完整的推导和扩展到两层和 K 层的情况,请参见附录 C。对于参数 θ ∈ {A,B,C} 在平方误差损失函数下的连续时间动态方程如下:
这些连续时间常微分方程表示一层线性 SSM 的学习动态的一般形式,但它们太复杂,无法直观理解模型在梯度下降下的收敛方式。挑战在于 A、B 和 C 的动态是非线性耦合的,因为 Hk 依赖于所有三个参数。为了对这些动态有一些直观的理解,让我们在一些简化假设下考虑梯度下降的表现。
首先,考虑一层、一维 (N=1) SSM 的情况,其中 A∈R 固定,因此唯一可学习的参数是 B,C∈R(实际上,之前的工作发现即使 A 固定也能得到合理的结果 [20])。在这些简化假设下,B 和 C 的动态为:
其中 σ 和 η 是在频域中求和输入输出协方差的充分统计量(sufficient statistics)。因此,我们得到一个二维非线性系统,其中每个坐标的动态在给定另一个坐标的条件下是线性的。
遵循 Saxe 等人 [18] 的方法,我们通过进一步假设 C=B 得到了一个闭形式解。(对于 C≠B 的解,可以通过双曲坐标变换获得,参见 Saxe 等人 [18, App. A])。约束系统的动态由乘积 Λ=CB 的动态特征化。在假设 C=B 下,从方程(4)可以得出:
从(5)可以看出,Λ(t) 以时间常数 τ / (2σ) 收敛到其极限值 σ/η。直观上,在这个简化的情况下,学习的时间常数与频域中的总输入输出协方差成反比。换句话说,更强的输入输出协方差将导致更快的收敛。此外,这一结果表明,对于任何给定的序列数据,其最强的协方差结构将首先被学习到。这一结论恢复了 Saxe 等人 [18] 对两层前馈神经网络分析的一个主要结果,建立了在特定简化假设下频域中的 SSM 与前馈神经网络之间的联系。
图 1E 显示了方程(5)的数值模拟,忠实再现了自动微分的动态。现在我们继续放宽早期对潜在状态大小的假设,以考虑高维模型的动态。
4. 较大潜在状态大小的学习动态
随着潜在状态大小 N 的增加,误差极小值的数量趋向于无限,使得每个极限值对参数初始化敏感。为了从分析上解决这个问题,让我们考虑潜在状态维度上所有参数的对称初始化,即
现在,假设 b=c 并将 a 视为固定,如第 3 节所述,我们得到了乘积 Λ=cb 的动态,
与方程(5)中描述的一维情况相比,方程(6)在 O(τ / (Nσ)) 时间内收敛到其固定点值。因此,我们看到,通过使用更多潜在状态维度来参数化 SSM 的组件可以加速学习收敛。图 1F 展示了对于一个固定输入输出对的合成数据,在一系列潜在状态大小 N 中的这种效果。此外,学习时间仍然与充分统计量 σ 成反比。放宽 B 和 C 之间的平衡假设,我们发现了一些情况下学习的时间过程对潜在状态大小呈逆二次依赖关系。我们在附录 B 中研究了这些其他情况。这里和第 3 节中做出的假设揭示了 N 维单层线性 SSM 的学习动态与深度前馈线性网络的动态之间的密切相似性。我们预计这些结果对于研究具有非线性连接的堆叠线性 SSM 层具有重要意义,类似于 Saxe 等人 [18] 的深度线性前馈网络分析展示了与非线性前馈网络的联系。一个主要的挑战是找到在梯度下降下状态转移矩阵 A 的动态提供解析或近似解释的情况,涵盖方程(3)中的所有非线性交互源。我们将这一分析留待未来的工作。
5. 结论
在这项工作中,我们推导了在线性状态空间模型中基于平方损失进行梯度下降的动态解析形式。在温和假设下,我们发现了一个解,展示了数据的充分统计量如何影响收敛时间,并将这一结果与现有的深度前馈神经网络学习动态理论联系起来。随后,我们扩展了分析,描述了 N 维线性 SSM 中的动态,得出结论:在约束条件下,过参数化可以导致更快的学习。
在未来的工作中,我们计划扩展我们的分析,以理解数据和参数化在多层 SSM 的学习动态中的作用,不论层之间是线性还是非线性连接。遵循深度前馈网络理论的先前工作,我们相信这里获得的结果将在考虑这些更复杂的设置时提供一个有用的起点。
论文地址:https://arxiv.org/abs/2407.07279