Attention 的计算公式中为什么要除以根号k ?这么回答惊艳面试官!

文摘   2024-05-31 09:02   上海  

这个题目可以说是 NLP 面试中一个高频出现的问题,基本上问到 Attention 或者 Transformers 的时候都会问。

这是个好题目,因为很快能了解到面试同学的数学功底怎么样。

如果你是 NLP 学生或者从业者,不妨先试着回答一下。如果有更好的答案欢迎交流。

最基本的答案
这个问题在《Attention is All You Need》的原始论文中是给出了一个粗略的答案的。

While for small values of   the two mechanisms perform similarly, additive attention outperforms dot product attention without scaling for larger values of   [3]. We suspect that for large values of  , the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients. To counteract this effect, we scale the dot products by .

作者说,当 的值变大的时候,softmax 函数会造成梯度消失问题,所以设置了一个 softmax 的 temperature 来缓解这个问题。这里 temperature 被设置为了 , 也就是乘上

这个回答当然没什么问题,但是接下来就会再问两个问题:

  1. 为什么会导致梯度消失?

  2. 为什么是 , 有更好的值么?

下面来回答一下这两个衍生的问题。

变大为什么会导致梯度消失?

先说结论:

  1. 如果 变大, 方差会变大。

  2. 方差变大会导致向量之间元素的差值变大。

  3. 元素的差值变大会导致 softmax 退化为 argmax, 也就是最大值 softmax 后的值为 1, 其他值则为 0。

  4. softmax 只有一个值为 1 的元素,其他都为 0 的话,反向传播的梯度会变为 0, 也就是所谓的梯度消失。

下面分别证明这 4 点。

第一点: 变大,QK 方差会变大。

假设 Q和 K的向量长度为 , 均值为0, 方差为 1。则 Q和 K的点积的方差为:

所以,当 变大时,方差变大。证毕。

第二点:方差变大会导致向量之间元素的差值变大。

这似乎是一个显而易见的结论,因为方差变大就是代表了数据之间的差异性变大。

第三点:softmax 退化为 argmax

当输入向量的方差变得非常大时,softmax 函数将会趋近于将最大的元素赋值为 1,而其他元素赋值为 0,也就是是 argmax 函数。用公式表示的话:

第四点:softmax 什么情况下会梯度消失

这一块有点复杂,直接看以下的实验,一目了然。

梯度实验

我们同样做个实验,看看梯度到底为多少。

import numpy as np

n = 10

x1 = np.random.normal(loc=0, scale=1, size=n)
x2 = np.random.normal(loc=0, scale=np.sqrt(512), size=n)
print('x1最大值和最小值的差值:', max(x1) - min(x1))
print('x1最大值和最小值的差值:', max(x2) - min(x2))

def softmax(x):
    return np.exp(x) / np.sum(np.exp(x), keepdims=True)

def softmax_grad(y):
    return np.diag(y) - np.outer(y, y)

ex1 = softmax(x1)
ex2 = softmax(x2)
print('softmax(x1) =', ex1)
print('max of gradiant of softmax(x1) =', np.max(softmax_grad(ex1)))
print('softmax(x2) =', ex2)
print('max gradiant of softmax(x2) =', np.max(softmax_grad(ex2)))

其结果为:

x1最大值和最小值的差值: 1.8973472870218264
x1最大值和最小值的差值: 66.62254341144866
softmax(x1) = [0.16704083 0.21684976 0.0579299  0.05408421 0.16109133 0.14433417
 0.03252007 0.05499126 0.04213939 0.06901908]
max of gradiant of softmax(x1) = 0.1698259433168865
softmax(x2) = [4.51671361e-19 2.88815837e-21 9.99999972e-01 3.02351231e-17
 3.73439970e-25 8.18066523e-13 2.78385563e-08 1.16465424e-29
 7.25661271e-20 3.21813750e-21]

可以看出,在方差为 的时候,长度仅仅为10的向量x2,其梯度就已经快没有了,最大值为2.78e-8。

而如果将方差控制在1,则最大的梯度为0.1698

scale 有更好的值么?

从上一节的第一步的证明,可以发现,scale 的值为 其实是把 归一化成了一个 均值为 0, 方差为 1 的向量。

至于是不是最好呢?不好说,因为参数的分布我们不太清楚。苏神曾经试图求解了一些常用分布的最佳 scale 值,感兴趣的可以看下:https://spaces.ac.cn/archives/9812


不摸鱼的小律
互联网大厂算法工程师一枚,分享各种技术、职场热点和感悟。不做每日打卡的路人。
 最新文章