什么是Gradient Checkpointing?

文摘   2024-11-16 10:02   辽宁  

神经网络主要有两种内存使用方式:

1.模型权重(这是固定的内存使用)。

2.训练过程(这是动态的内存使用)。

训练过程又有两种方式:

a. 在前向传播过程中,计算并存储所有层的激活值。

b. 在反向传播过程中,计算每一层的梯度。

后者,即动态内存使用,通常会限制我们训练更大模型和使用更大批次的能力。

这是因为内存使用量是随着批次大小的增加而成比例增长的。

然而,有一个相当了不起的技术可以让我们在保持整体内存使用不变的情况下增加批次大小。

这个技术叫做梯度检查点(Gradient Checkpointing),根据我的经验,它是一个严重被低估的减少内存开销的技术。

让我们更详细地理解一下这个技术。

Gradient Checkpointing如何工作的?

梯度检查点基于两个关于神经网络工作方式的观察:

1.一个特定层的激活值可以仅通过前一层的激活值来计算。例如,在下面的图像中,“层B”的激活值可以仅通过“层A”的激活值来计算。

2.更新某一层的权重只依赖于两件事:

a. 该层的激活值。

b. 在下一层(右边层)计算出的梯度(或者更确切地说,是累积梯度)。

梯度检查点利用这两条观察来优化内存使用。

它的工作原理如下:

步骤 1) 在前向传播之前,将网络划分为多个段:

步骤 2) 在前向传播过程中,仅存储每个段中第一层的激活值。其余的激活值在被用来计算下一层的激活值后立即丢弃。

步骤 3) 在反向传播过程中。为了更新某一层的权重,使用该段中的第一层激活值重新计算该层的激活值。

人工智能大讲堂
专注人工智能数学原理和应用
 最新文章