(2024,RNN,梯度消失和爆炸,记忆诅咒,重参数化和动态学习率,权重矩阵对角化,复值 RNN)梯度消失和爆炸并不是故事的结局

文摘   2024-06-27 21:08   新加坡  

Recurrent neural networks: vanishing and exploding gradients are not the end of the story

进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群

目录

0. 摘要

1. 梯度消失和梯度爆炸

2. 记忆的诅咒

2.1 直觉

2.2 线性对角 RNN 中的信号传播

2.3 将分析扩展到非对角情况

3. 减轻记忆诅咒

3.1 一个解决方案:归一化和重新参数化

3.2 多种 RNN 架构隐式缓解了记忆的诅咒

4. 线性教师-学生分析

4.1 一维情况

4.2 对角连接简化优化

4.3 自适应学习率的重要性

5. 初始化时深度 RNN 中的信号传播

6. 结论



0. 摘要

循环神经网络(Recurrent neural networks,RNNs)在学习长期记忆方面一向表现不佳,主要原因是梯度消失和梯度爆炸。最近,状态空间模型(SSMs)作为 RNNs 的一个子类,在克服这些困难方面取得了成功,这对我们的理论理解提出了挑战。在本文中,我们深入探讨了 RNNs 的优化挑战,发现随着网络记忆的增加,其参数的变化会导致输出变化显著增大,从而使基于梯度的学习变得高度敏感,即使没有梯度爆炸。我们的分析进一步揭示了元素级递归设计图样(pattern)的重要性,结合谨慎的参数化,可以减轻这种影响。这一特性存在于 SSMs 以及其他架构中,如LSTMs。总体而言,我们的见解为 RNNs 基于梯度学习的一些困难提供了新的解释,并解释了为什么某些架构比其他架构表现更好。

论文地址:https://arxiv.org/abs/2405.21064

1. 梯度消失和梯度爆炸

首先,我们介绍将在本文中使用的符号。我们考虑一个具有隐藏状态 h_t、由参数 θ 参数化的更新函数 f_θ 和输入序列 (x_t)_t 的循环神经网络。网络的平均性能通过损失 L 来衡量。我们有

瞬时损失 L_t 相对于参数 θ 的梯度等于

在上述方程中,我们使用 ∂ 表示偏导数,使用 d 表示全导数。使用这种符号可以区分 ∂_(ht) Lt,它对应于通过读取(readout)函数从当前损失项反向传播到隐藏状态的误差,以及 d_(ht) L,它积累了通过未来隐藏状态值反向传播的误差。特别地,

当在多层循环层上堆叠时,∂_(ht) Lt 对应于通过网络层次结构反向传播到隐藏状态 ht 的当前误差,而 d_(ht) L 对应于通过递归反向传播的未来误差信号。

早期工作指出,梯度下降难以使循环神经网络记住过去的输入,这些输入在以后会有助于产生期望的行为。这是因为时间上向后传播的误差信号往往会爆炸或消失。关键量是

可以注意到,当雅可比矩阵 ∂_h f_θ 的谱半径上界是一个严格小于 1 的常数时,这个量会指数级收敛到 0;如果存在某个分量大于1,则会指数级爆炸。时间 t 处的误差信号反向传播到时间 t′ 的行为类似,因为

因此,基于梯度的长期记忆学习是困难的:随着考虑的时间跨度增加,过去隐藏状态对当前损失的贡献要么变得微不足道,要么变得占主导地位。

自那时起,分析得到了改进,循环架构的开发主要是为了解决这一病态问题。最著名的是,LSTM单元,以及后来出现的 GRU,通过使用便于直接信息存储和检索的记忆神经元来解决这个问题,并通过同样的方式进行误差反向传播。其他解决该问题的方法包括梯度剪裁(gradient clipping)、活动归一化(activity normalization)、谨慎的权重初始化或实施架构约束,如分层处理、正交权重矩阵和振荡等。

2. 记忆的诅咒

根据常见的深度学习经验,通常认为解决梯度消失和梯度爆炸问题可以使循环神经网络学习长期依赖。我们对此观点提出质疑:解决这些问题是否足以确保良好的损失景观(loss landscapes)?我们否定地回答这个问题,显示出即使在网络动态保持稳定的情况下,随着网络记忆的增加,梯度也可能会爆炸

2.1 直觉

循环神经网络有一些特殊之处:相同的更新函数 fθ 一次又一次地被应用。因此,修改参数 θ 不只影响一次更新,就像在前馈神经网络中改变给定层的权重一样,而是会影响所有更新。随着网络记忆的增加,隐藏状态会保留更多更新的影响。隐藏状态因此对参数变化变得越来越敏感。这就是记忆的诅咒。我们借用了 [38, 39] 中的术语,并指出 Martens 和 Sutskever[40] 假设这种现象可能在 RNNs 中出现并阻碍它们的优化。

让我们形式化我们的直觉,并考虑隐藏状态 ht 对参数 θ 的敏感性:

当信息在网络的记忆中停留时间更长时,不可忽略的雅可比项 d_(ht′) ht 的数量增加。因此,当网络编码更长期的依赖时,这种敏感性的幅度增加,学习 θ 变得更加困难。关键是要注意,即使通过将递归雅可比矩阵的特征值约束在小于 1 的范围内,并确保网络动态保持稳定从而消除梯度爆炸,这种现象仍然会出现。本节的其余部分将专注于定量研究这种行为。

2.2 线性对角 RNN 中的信号传播

我们研究了隐藏状态和梯度幅度如何随着网络编码更长期依赖而演变。理想情况下,我们希望这些量不消失或爆炸。这一特性改善了损失景观的条件性,从而简化了优化。我们做出以下假设:

  • 线性对角循环神经网络。我们限制自己使用形式为 fθ(ht,x_(t+1)) = λ⊙ht + x_(t+1) 的更新函数,其中 λ 是隐藏状态大小的向量,⊙ 是元素级乘积。为了便于说明,我们在这里展示实值 λ 的结果;关于复值情况,请参见附录 A.2。虽然这一假设较强,但它允许我们识别一些关键机制,并且对于某些模型(如 S4 和 LRUs)是满足的。我们稍后会展示我们的分析可以模拟更复杂网络的一些特征。

  • 无限时间视界。我们考虑无限序列,并在 t_0 = −∞ 初始化网络动态。这简化了我们的计算,同时当所考虑的序列长于我们想要学习的依赖关系的特征时间尺度时,这是一个合理的假设。

  • 广义平稳性。我们假设网络接收的不同量,包括输入 xt,都是广义平稳(wide-sense stationary,WSS)的。一个随机过程 Xt 被称为 WSS,如果它的自相关函数与时间无关,即对于所有 t∈ℤ 和 Δ∈ℤ 有 E_X [Xt+ΔXt]=:R_X (Δ),其中 E_X 表示对数据的期望。这相当于假设数据中不同序列的统计量对于时间移位是不变的。

我们现在有能力分析一个循环层中的信号传播,无论是前向传播还是反向传播。我们展示了当 ∣λ∣→1 时,隐藏状态和反向传播的误差都会爆炸。

正向传播。在这里,我们感兴趣的是理解隐藏状态方差 E[h^2_t] 如何作为网络中特征时间尺度(编码在 λ 中)以及输入自相关函数 Rx 的函数的演变。经过计算(详见附录 A.2),我们得到:

重要的是,当网络中编码的长期依赖增加时,即 ∣λ∣→1,隐藏状态的方差趋向于无穷大。此外,发散速度取决于输入数据分布:当输入分布中连续时间步长的相关性增加时(即,较少的 Rx(Δ) 可忽略),它会增加。这一行为已经突显出包含线性递归层的深度神经网络基于梯度学习的潜在困难,因为神经活动的方差可能变得任意大,从而阻碍更深层的学习能力。

反向传播。让我们首先导出损失相对于 λ 的梯度。使用链式法则我们有

因此,我们需要理解 d_λ ht 的行为。我们注意到

因此 d_λ ht 是隐藏状态的低通滤波版本,而隐藏状态本身是输入的低通滤波版本。因此,当 ∣λ∣→1 时,d_λ ht 的方差比 ht 的方差发散得更快。更准确地说,我们得到:

我们在图 1 中绘制了当 Rx(Δ)=ρ^∣Δ∣ 满足时该量的精确行为,并请感兴趣的读者参阅附录中的方程 6 的推导。更一般地说,当网络达到动态稳定的边缘(即 ∣λ∣→1)时,网络的隐藏状态及其最终输出对递归参数的变化变得越来越敏感。

最后,我们需要考虑的是反向传播到递归层输入 x 的误差。可以观察到,反向传播是正向传播的对偶过程,因为它是一个接收反向传播误差 d_λ ht 并在时间上反向运行的递归过程:

其中,我们使用了 ∂_(xt) ht = 1。因此,我们对正向传播的分析在这里也适用。关键是,这意味着爆炸行为对于递归参数而非潜在的输入或读取权重最为显著。

2.3 将分析扩展到非对角情况

现在我们将结果推广到形式为 h_(t+1) = Aht + xt 的全连接线性递归神经网络。为了分析,我们假设 A 是复数对角化的,即存在一个复数值矩阵 P 和复数值向量 λ,使得 A=Pdiag(λ)P^(−1)。注意,在随机初始化 A 的情况下,这种情况发生的概率为 1。在这种情况下,

从前面的分析中,我们知道在 ∣λ∣→1 的极限中 ∂_P ht、d_λ ht 和 d_(P^(−1)) ht 之间占主导地位的项是 d_λ ht,因为 P 和 P^(−1) 的作用是读取和输入权重。鉴于所有其他项不直接依赖于 λ 的幅度,我们有

详见附录 A.2.3 的正式声明。这有两个后果:首先,随着编码更长记忆,ht 对 A 的敏感性会爆炸,这直接来自于 A 的特征值。其次,由于 A 的每个条目通常会影响矩阵的所有特征值,爆炸行为将分布在所有条目上,而在对角情况中则集中在特征值上。我们稍后会观察到,这在实践中有显著的后果,并部分解释了为什么全连接线性 RNNs 难以训练。作为旁注(side note),我们注意到强制矩阵 A 为正交矩阵解决了梯度消失和梯度爆炸问题,但由于记忆的诅咒,这些权重可能仍然敏感。

3. 减轻记忆诅咒

我们已经讨论了递归网络对参数更新的敏感性。鉴于这个问题,如何才能缓解它呢?具有对角连接的递归网络特别适合这个目的。除了能够控制雅可比矩阵并避免梯度爆炸之外,它们还促进了记忆诅咒的缓解。在这种情况下,我们证明了状态空间模型和门控 RNN 固有地包含了这种机制。

3.1 一个解决方案:归一化和重新参数化

随着网络编码更长的记忆,正向和反向传递都会爆炸。当 h_(t+1) = λht + x_(t+1) 时,我们认为减轻这种效果相对容易。我们旨在使 E[h^2_t]、E[(d_λ ht)^2] 和 E[(d_(xt) ht)^2] 与 λ 无关,类似于确保神经活动幅度在深层网络中保持恒定的初始化方案,并且与层宽无关。

输入归一化。确保 E[h^2_t] 保持恒定的一种简单方法是引入一个缩放因子 γ(λ),应用于神经元接收的输入,满足

考虑到输出误差的反向传播到输入的过程与前向传递是对偶的,γ 在反向传递中的作用将类似。因此,γ 的值需要同时依赖于输入分布以归一化前向传递,以及输出误差分布以归一化反向传递。完美的归一化可能不现实,但如图 2.A 所示,某种程度的归一化是有帮助的。

特征值重新参数化。现在我们需要控制损失相对于 λ 的梯度。输入归一化在一定程度上减少了记忆引起的爆炸效应,但并不能完全消除,因为 d_λ ht 的方差远大于 ht 的方差(见图 1.A)。重新参数化可以弥补这一差距。实际上,如果 λ 通过 ω 参数化,我们有

当 λ 趋近于1时,选择一个越来越细化的参数化有助于保持 d_ω ht 的恒定。假设为了简化,γ 与 λ 无关,实现

需要求解微分方程

虽然由于依赖于输入分布,得出通用的最优参数化是不现实的,但如图 2.B 所示,重新参数化确实有帮助。图 6 展示了它如何影响损失景观。

复数的情况如何?我们还没有讨论 λ∈C 的情况,这对于像 S4 这样的 SSMs 是相关的。我们在附录 A.3.2 中扩展了对复数 λ 的分析,并强调它们很难正确参数化。简言之,我们的分析表明,如果 λ 被参数化为 ν·exp⁡(iθ),θ 的参数化必须依赖于 ν 的参数化,但反之则不必。然而,这样做不会影响学习,如我们在附录 A.3.2 中所示。

3.2 多种 RNN 架构隐式缓解了记忆的诅咒

状态空间模型以及门控RNN都具有某种形式的归一化和重新参数化,这有助于信号传播。我们在下面讨论如何实现这一点。

状态空间模型。SSMs 最初是作为连续时间微分方程 ˙h = Ah + Bx 的离散化动机而提出的。对微分方程的简单离散化得到

当 dt 较小时,这已经在某种程度上充当了输入归一化。更复杂的离散化方案,如零阶保持(zero-order hold),实际上重新参数化了 A 矩阵,例如用 exp⁡(dtA)。其中,对角化是出于计算效率和简便性的原因【18】。虽然这种模型可以近似任何平滑映射【47,48】,但它们的表达能力仍然有限【49】。这些模型的下一代,如 Mamba【21】,结合了输入依赖的门控,根据输入 xt 调节 dt。我们上面开发的理论不严格适用于这种情况,因为 dt 不是恒定的。然而,由于模型结构的其余部分保持不变,我们预计这种行为及其补救措施也会保持不变。

门控 RNN。虽然 LSTM【3】或 GRU【23】等门控 RNN 的最初动机与 SSMs 的动机大不相同,但它们具有相似的机制。在这些网络中,存储在隐藏神经元中的记忆内容可以通过遗忘门擦除,而传入的输入可以通过输入门选择性地写入记忆。数学上,这对应于形式为 h_(t+1) = f_(t+1) ⊙ ht + i_(t+1) ⊙ x_(t+1) 的隐藏状态更新,其中遗忘门 f_(t+1) 和输入门 i_(t+1) 是 x_(t+1) 和 ht 的独立非线性函数。遗忘门类似于 λ,通常涉及一个 sigmoid 非线性函数,这在反向传递中有类似重新参数化 λ 的效果。输入门可根据网络的初始化或与遗忘门(如 GRU【29】中的 f_t = 1 − i_t)耦合时,充当输入归一化的重要角色。重要的是,这些门依赖隐藏状态,从而使雅可比矩阵 ∂_(ht) h_(t+1) 非对角化。然而,我们认为这些架构仍然有对角化的偏向。实际上,通过遗忘门和输入门的隐藏状态的贡献是间接的,当连接隐藏状态到门的权重较小时,可以忽略不计。因此,我们回到了前一段中讨论的设置;我们在第 5 节中证实了这一直觉。在这种近似不成立的情况下,研究信号传播需要比我们在这里所做的更复杂的分析【50】。

4. 线性教师-学生分析

我们考虑具有线性递归网络的教师-学生任务【51】。这是可以训练递归网络的最简单设置之一,然而,正如我们将看到的那样,它极其复杂。我们首先转向一维设置,以直观地说明记忆诅咒和梯度消失如何相互作用。然后,我们处理一般设置,观察到线性网络确实受到记忆诅咒的影响,并且我们在上一节研究的补救措施是有效的。此外,我们发现对角性极大地改变了损失景观的结构,并有助于自适应学习率的优化器补偿最终增加的敏感性。

4.1 一维情况

我们首先考虑遵循一维动态  h_(t+1) = λht + x_(t+1) 的学生和教师,学生的复值参数 λ,教师的复值参数 λ*。为了简单起见,我们从均值为 0,标准差为 1 的正态分布中抽取 x_(t+1),并注意到其他输入分布不会在质上改变结果。学生的表现通过损失 L 来衡量,该损失平均了整个序列的每个时间步损失

这个简单的模型已经捕捉到了基于梯度学习递归神经网络的两个关键难点。在图 1 中,我们绘制了不同 λ* 值下的损失景观,当 λ 在实轴的正部分演化(图 1.B)和当它在复平面上 ∣λ*∣ 半径的圆上演化(图 1.C)时。我们限制 λ 的绝对值小于1:梯度爆炸不在考虑范围内。然而,基于梯度的学习在这里出现了两个困难。一方面,梯度消失导致难以逃脱的平坦损失区域。另一方面,由于记忆诅咒,随着学生编码更长的记忆,损失变得更加尖锐。因此,基于梯度的优化变得极其繁琐,即使在这个简单的例子中也是如此。

4.2 对角连接简化优化

现在我们转向教师按以下公式演化的一般情况:

其中 ht∈R^n,xt∈R 从 N(0,1) 中独立同分布抽取,A∈R^(n×n),B∈R^(n×1),C∈R^(1×n), D∈R^(1×1)。其中,输入和输出都是标量。

根据我们迄今所发展的直觉,我们预计完全连接的线性递归神经网络在处理教师编码更长记忆的任务时会遇到困难,不仅因为梯度爆炸,还因为记忆的负面影响。相反,对角化有助于特征值的重新参数化,避免梯度爆炸问题,并使其更加稳定。我们进行了以下实验来验证这一直觉。我们随机选择隐藏维度为 n = 10 的教师,并将递归矩阵 A 的复特征值变换为其模值接近我们控制的值 ν。ν 越大,教师编码的记忆越长。我们在这个任务上训练了一个线性 RNN,以及一个具有隐藏维度为 64 的 LRU [20]。因此,学生们在很大程度上是过度参数化的。我们选择 LRU 架构来代表 SSM,因为它简单明了。该架构使用输入规范化和特征值的指数重新参数化,与我们在第 3 节中分析的类似。这两个网络均使用 Adam 优化器 [52] 和余弦退火调度进行了 10,000 步的训练,在批次大小为 128 的情况下。序列包含 300 个时间步长。学习率针对每种方法和训练分布分别调整。我们在图 3.A 中绘制的结果验证了我们的直觉:在需要学习长期记忆时,LRU 明显优于线性 RNN,尽管参数少了 10 倍。

接下来,我们想知道 LRU 架构背后的设计选择对性能改进的关键性是什么。为此,我们通过以下方式在线性 RNN 和 LRU 之间进行插值:首先,我们将线性 RNN 的权重矩阵限制为大小为 2 的块对角线。每个这样的块可以表示一个复数,总共 32 个复数。我们还将隐藏神经元的数量加倍。其次,我们将这些 2×2 的块(及其输入和输出权重)改为复数。最后,我们添加 γ 输入规范化和指数参数化,得到最终的 LRU 架构。我们在图 3.B 中报告了这个实验的结果。我们发现,大部分差距来自于引入复数,并且可以通过使权重矩阵块对角线化来部分减少。有趣的是,这两个改变减少了模型的参数数量,并略微降低了模型的表达能力,因此这种行为的解释可能与这些模型的优化特性有关。我们在下一节中确认了这一假设。

4.3 自适应学习率的重要性

到目前为止,我们的结果突显了直接参数化递归连接矩阵的复杂特征值的重要性。这种参数化并不缓解任何爆炸性行为,但修改了损失景观,使得具有自适应学习率的优化器能够补偿这些行为。为了证明这一点,我们研究了损失的 Hessian 矩阵:

如果网络可以完全拟合目标数据,这里的第二项在最优情况下会消失。

我们在图 4.A 和 B 中绘制了在标准线性递归网络和具有复值对角参数化的情况下(均为 4 个隐藏神经元,ν = 0.99)的 Hessian 矩阵。我们观察到,两种架构的特征值谱类似,都表现出大量有记忆诅咒的项,这使得使用随机梯度下降几乎不可能。然而,它们的结构不同。对于完全连接的线性 RNN,顶部特征向量分布在许多坐标上,而对于复值对角线的情况,则集中在少数坐标上。这个特性有助于自适应优化器(如 Adam):对于大曲率的适应更容易,因为病态方向(pathological directions)与规范基(canonical basis)几乎对齐。这正是我们在实践中观察到的。在图 4.C 和 D 中,我们比较了 Adam 使用的有效学习率,通过向优化器提供一个全为 1 的向量来计算。对于密集线性 RNN,自适应学习率无法补偿组件间复杂耦合,导致非常小的学习率。相反,复值对角线 RNN 的敏感性集中在少数参数上,自适应学习率可以补偿,导致更快的学习速度和更大的整体学习率。顺便提一下,教师的复特征值成对出现。然而,在训练过程中,复 RNN 的复数值并不互为共轭,从而增加了 Hessian 矩阵的对角性。最后,对于 LRU 进行这种分析,我们发现其 Hessian 谱与对角设置相似,而 Hessian 的爆炸维度几乎完全由角参数引起,这与我们的理论分析一致;参见图 9。

在结束本节之前,我们调查是否存在能够打破 Hessian 矩阵对角结构的特征值分布,使得优化变得更加困难,并增加特征值重新参数化的压力。我们在附录 B.2 中从理论上证明了这种直觉结果:特征值越集中,Hessian 矩阵的对角性越低。因此,复值对角网络和 LRU 之间的差距加大,但前者仍然大大优于其全连接的对应物;参见图 10。

5. 初始化时深度 RNN 中的信号传播

我们理论研究的最终目标是获得有关 RNN 训练的实际见解。具体来说,我们旨在通过研究初始化时的信号传播来验证理论和受控实验中建立的趋势是否在实际中成立。

我们提供 512 个文本标记的序列作为输入,给包含四个块(每个块有 256 个隐藏神经元)的深度递归网络,并使用下一个标记预测损失来衡量其性能。每个块由一个递归层和一个前馈门控线性单元 [57] 组成。默认情况下,此架构中没有归一化层。更多细节可在附录 C.1 中找到。我们从经验上研究了当递归层的记忆(由 ν 控制)增加时,E[h^2_t] 和 E[d_θ h^2_t] 如何演变。我们比较了三种不同的递归层:复数对角 RNN (cRNN),LRU 和使用 chrono 初始化的 LSTM [29]。

结果与我们的理论一致。复数 RNN 遭受记忆的负面影响。LRU几乎完美地减轻了这种效果,无论是在前向传播(图 5.A)还是反向传播(图 5.B),除了角参数 θ,这是预料之中的。我们还想知道层归一化是否可以取代 LRU 的输入归一化和重新参数化。我们发现,它在宏观水平上缓解了由记忆引起的梯度爆炸(图 5.C),但可能会杀死最小特征值的任何学习信号。最后,LSTM 成功地保持了不同记忆水平下的梯度范数,这与我们在第 3.2 节中发展起来的直觉一致,尽管 LSTM 特定的参数表现出比前馈参数更小的梯度。

6. 结论

梯度消失和爆炸使 RNN 的学习变得复杂,但解决这些问题还不够。我们发现了训练这种网络的另一个难点,这根源于其迭代性质,并在动态稳定性边缘出现。重新参数化和自适应学习率在实践中可以有效地缓解这种行为,而对递归进行对角化则简化了这两者。我们的分析还揭示了学习复数特征值角度的复杂性(见附录 B.1.4),这可能解释了为什么在最近的大多数状态空间模型架构中,复数未被发现有用 [21, 22]。

我们的研究还发现了独立模块(这指的是神经元,更一般地说是小头(small heads))与线性递归网络中的自适应学习率优化器之间的共生关系。这种设计模式具有有希望的特性:它促进了在线学习 [58] 和组合泛化 [59],允许高水平的并行化 [22],并在高层次上匹配皮质柱中皮层的模块化组织 [60]。理解如何在保持其出色优化特性的同时增加小线性模块的表达能力,构成了未来研究的一个有希望的方向。

EDPJ
CV 博士在读。文章搜索:公众号主页右上角放大镜搜关键词。
 最新文章