01
引言
在现代深度学习领域,注意力机制是最强大的进步之一。具体来说,基于点积的注意力是Transformer模型的支柱,它给自然语言处理(NLP)等领域带来了革命性的变化。然而,这种技术会耗费大量内存和计算资源,尤其是对于长序列。FlashAttention 是一种优化注意力计算的技术,它将核心操作分解成易于管理的块,在保持精度的同时减少内存的使用。
我们将通过本文的一个简单易懂的示例,以分块矩阵乘法为重点,探讨 FlashAttention 背后的数学原理。
02
基于点积的注意力
基于点积的注意力的核心计算公式如下:
Q为Query矩阵
K为Key矩阵
V为Value矩阵
d_k为Key向量的维度
简单来说,点积 Q*K^T 反映了每个Key和Value之间的相似度。然后用 sqrt(d_k) 对结果进行缩放,以防止出现极大的梯度。最后,softmax 函数将这些相似度得分归一化为概率分布。得出的权重用于计算 V 中各值的加权和。
03
FlashAttention 中的分块计算
在处理大型矩阵时,直接计算注意力权重需要生成大量中间矩阵,这会占用大量内存。FlashAttention 通过逐块处理矩阵运算解决了这一问题。
输入矩阵 Q、K 和 V 被分成较小的块,分别进行处理,而不是一次性计算完整的矩阵乘积 Q*K^T。这样,矩阵运算就能在现代 GPU 的内存限制范围内进行。
04
举个栗子
设置矩阵
将矩阵划分为块
逐块计算注意力
逐行计算softmax,如下:
与对应的V进行相乘,结果如下:
同样,对于Block 2,我们重复同样的过程,可以得到序列中第二部分的结果。
05
分块计算的优势
FlashAttention 中的分块方法每次只将矩阵中的一小部分保存在内存中,从而大大减少了内存使用量。此外,这种方法还能在 GPU 等硬件上实现更好的并行化,使大规模模型能在不牺牲精度的情况下更高效地计算注意力。
06
总结
点击上方小卡片关注我
添加个人微信,进专属粉丝群!