一文弄懂Flash-Attention

文摘   科技   2024-10-01 15:14   河南  
点击蓝字
 
关注我们










01


引言



在现代深度学习领域,注意力机制是最强大的进步之一。具体来说,基于点积的注意力是Transformer模型的支柱,它给自然语言处理(NLP)等领域带来了革命性的变化。然而,这种技术会耗费大量内存和计算资源,尤其是对于长序列。FlashAttention 是一种优化注意力计算的技术,它将核心操作分解成易于管理的块,在保持精度的同时减少内存的使用。


我们将通过本文的一个简单易懂的示例,以分块矩阵乘法为重点,探讨 FlashAttention 背后的数学原理。







02


基于点积的注意力


基于点积的注意力的核心计算公式如下:

上述公式中:
  • Q为Query矩阵

  • K为Key矩阵

  • V为Value矩阵

  • d_k为Key向量的维度

简单来说,点积 Q*K^T 反映了每个Key和Value之间的相似度。然后用 sqrt(d_k) 对结果进行缩放,以防止出现极大的梯度。最后,softmax 函数将这些相似度得分归一化为概率分布。得出的权重用于计算 V 中各值的加权和。

对于小矩阵来说,这很简单,但对于大序列(常见于文本生成等 NLP 任务)来说,这可能会耗费大量内存,因为必须存储整个 Q K^T 后的矩阵。




03


 FlashAttention 中的分块计算

在处理大型矩阵时,直接计算注意力权重需要生成大量中间矩阵,这会占用大量内存。FlashAttention 通过逐块处理矩阵运算解决了这一问题。

输入矩阵 Q、K 和 V 被分成较小的块,分别进行处理,而不是一次性计算完整的矩阵乘积 Q*K^T。这样,矩阵运算就能在现代 GPU 的内存限制范围内进行。






04


  举个栗子


让我们以小矩阵为例,使用分块式 FlashAttention 计算它们的注意力权重。
  • 设置矩阵
假设有以下 Q、K 和V 的矩阵:

  • 将矩阵划分为块

为了简单起见,我们可以分别把 Q、K 和 V分成两个较小的block块。如下所示:

  • 逐块计算注意力

该我们首先来看第 1 组(Q、K 和 V中的第 1 行),计算点积,如下:

逐行计算softmax,如下:

与对应的V进行相乘,结果如下:

同样,对于Block 2,我们重复同样的过程,可以得到序列中第二部分的结果。




05


  分块计算的优势


FlashAttention 中的分块方法每次只将矩阵中的一小部分保存在内存中,从而大大减少了内存使用量。此外,这种方法还能在 GPU 等硬件上实现更好的并行化,使大规模模型能在不牺牲精度的情况下更高效地计算注意力。

FlashAttention 中使用的分块方法具有足够的通用性,除了 softmax 之外,它还可以与其他函数配合使用,例如 ReLU、sigmoid 或 max-pooling,具体取决于应用情况。其核心原理是将操作分解成更小、更易于管理的块,从而减少内存负荷,使计算更加灵活。





06


  总结

FlashAttention 中的分块计算优化了大型Transformer模型中传统注意力机制的性能。通过分块处理矩阵并融合softmax和矩阵乘法等运算,FlashAttention 显著提高了内存效率,同时保持了相同的精度水平。这项技术是处理规模不断扩大的现代 NLP 和深度学习任务的关键。
您学废了吗?






点击上方小卡片关注我




添加个人微信,进专属粉丝群!



AI算法之道
一个专注于深度学习、计算机视觉和自动驾驶感知算法的公众号,涵盖视觉CV、神经网络、模式识别等方面,包括相应的硬件和软件配置,以及开源项目等。
 最新文章