FlashAttention算法之美:极简推导版

文摘   2024-11-07 18:30   中国香港  

知乎:方佳瑞(已授权)
地址:https://zhuanlan.zhihu.com/p/4264163756

FlashAttention(FA)是大模型训练和推理性能优化最重要的组件。从并行计算角度,FA算法设计是可以写进教科书的。通过利用简单数学知识,等价变化任务的计算流程,从而将算法并行执行起来,实现最佳的内存效率,这无疑是并行计算Phd心中最完美的idea。

FA的算法流程也以复杂著称,原始论文中公式包含纷繁变量,复杂计算流程图,让普通人很难理解。对于FA也有很多非官方解读版本,比如Zihao Ye的《From Online Softmax to FlashAttention》(2023年5月作为UW研究生作业发布)。Ye从FA的演化历史入手,囊括了完毕的前置知识,让读者读此一文即可搞懂FA。但是Ye版本力求完备,公式保持完整严谨,让普通人很那一遍理解,需要反复揣摩。Zhihao Ye也是FlashInfer的作者。

去年此时,FA v2刚刚更新,我写过文章分析过FA的历史和现状。最近工作需要用到FA的细节,于是又重新看了一下FA的论文[1]

这里post一下我自己的极简推导版本,平时作为小抄看代码时候使用。也希望能够帮助读者搞懂,理解FA算法之美。这里我避免使用任何复杂数学符号,需要基本线性代数知识就可以跟随下来。如有纰漏也希望大家指正。

一、前置知识

  1. 矩阵乘法分块:
  • O = QK^TV,三矩阵连续乘法,可以采用分块的方式避免materialize QK^T。但是如果用softmax之后,就不能那么简单的分块计算了。
  1. Softmax简化公式:
  • softmax(x) = exp(x - m) / sum(exp(x - m)),其中 m = max(x)
  1. exp和log运算性质:
  • exp(a + b) = exp(a) × exp(b),因此 exp(x - m2) = exp(x - m1) × exp(m1 - m2)
  • log(a x b) = log(a) + log(b)log(a / b) = log(a) - log(b)
  1. 矩阵乘法性质:
  • [A1, A2][V1, V2] = A1V1 + A2V2

二、目标

我们的目标是计算:O = softmax(Q × [K1, K2]^T) × [V1, V2]

我们接下来把小目标搞懂(如下图所示,我们从左向右计算有颜色的部分),就可以扩展到更大的问题规模。对于Q增加一个外层循环,KV增加内层循环长度,K1, K2, K3,… & V1, V2, V3, …。

三、FlashAttention计算流程

Step 1: 求Q x K^T

首先,每个设备独立计算两个矩阵的乘积:

  • X1 = Q × K1^T
  • X2 = Q × K2^T

Step 2: 寻找最大值

找到X矩阵中的最大值,算m1不准确,拿到X2时候需要回头看

m1 = max(X1)m2 = max(X1, X2)= max(m1, max(X2)

Step 3: Softmax分子登场

计算Softmax的分子部分,同样需要用X2的m2去矫正之前的计算结果,矫正因子alpha

  • alpha = exp(m1 - m2)
  • a1 = exp(X1 - m1)
  • a2 = exp(X2 - m2) & a1' = a1 × alpha 【前置知识3】

Step 4: Softmax分母来啦

  • d1 = sum(exp(X1 - m1))
  • d2 = sum(exp(X2 - m2))& d1' = sum(exp(X1 - m2)) = d1 × alpha

Step 5: 组装时刻 ️

最后,先规定

  • O1 = a1 × V1 / d1
  • O2 = a2 × V2 / d2

然后,进行最终组装:

  • d12 = d1 × alpha + d2
  • O = [a1', a2] / (d1' + d2) × [V1, V2]
  • 因式分解:O = a1' × V1 / (d1' + d2) + a2 × V2 / (d1' + d2) 【前置知识4】
  • 等价变化:O = a1 × alpha × V1 / d12 + a2 × V2 / d12
  • 最终:O = O1 × d1 / d12 × alpha + O2 × d2 / d12

仔细看最终公式展示了优雅的对称性,以此类推可以计算O3, O4, …。

FA实现中返回一个叫做LSE的变量(Log Sum Exp)。所谓LSE就是对X = QK^Tscale的值做log(sum(exp(X)))。这在我介绍的算法流程中没有出现。使用它可以增加数值稳定性,我们再对上述流程略微改进,实现LSE的版本。The Log-Sum-Exp TrickFA实现中返回一个叫做LSE的变量(Log Sum Exp)。所谓LSE就是对X = QK^Tscale的值做log(sum(exp(X)))。这在我介绍的算法流程中没有出现。使用它可以增加数值稳定性,我们再对上述流程略微改进,实现LSE的版本。

四、FlashAttention LSE版本计算流程

Step 1: 求Q x K^T

首先,每个设备独立计算两个矩阵的乘积:

  • X1 = Q × K1^T
  • X2 = Q × K2^T

Step 2: 寻找最大值

找到X矩阵中的最大值,算m1不准确,拿到X2时候需要回头看

  • m1 = max(X1)
  • m2 = max(X1, X2)= max(m1, max(X2)

Step 3: Softmax分子登场

计算Softmax的分子部分,同样需要用X2的m2去矫正之前的计算结果,矫正因子alpha

  • alpha = exp(m1 - m2)
  • a1 = exp(X1 - m1)
  • a2 = exp(X2 - m2) & a1' = a1 × alpha 【前置知识3】

Step 4: Softmax分母来啦

  • d1 = sum(exp(X1 - m1))
  • lse1 = log(d1) = lse(X1-m1)
  • d2 = sum(exp(X2 - m2))& d1' = sum(exp(X1 - m2)) = d1 × alpha
  • lse2 = log(d2) = log(d1) + log(alpha) = lse1 + m1 - m2

Step 5: 组装时刻 ️

最后,先规定

  • O1 = a1 × V1 / d1
  • O2 = a2 × V2 / d2

然后,进行最终组装:

  • d12 = d1 × alpha + d2
  • lse12 = log(d1 x (alpha + d2/d1)) = lse(d1) + lse(alpha + d2/d1) = lse1 + lse(alpha + exp(lse2 - lse1))
  • O = [a1', a2] / (d1' + d2) × [V1, V2]
  • 因式分解:O = a1' × V1 / (d1' + d2) + a2 × V2 / (d1' + d2) 【前置知识4】
  • 等价变化:O = a1 × alpha × V1 / d12 + a2 × V2 / d12
  • 最终:O = O1 × d1 / d12 × alpha + O2 × d2 / d12
  • 最终LSE版本:O1 x exp(lse1 - lse12) x alpha + O2 x exp(lse2 - lse12)【LSE更新方式】

这里一个技巧是我们计算第二个block的lse时候,再加上一个当前最大值max,也就是lse2_ = lse2 + m2; lse1_ = lse1 + m1,结果可以消除alpha,让公式的对称性更完美

  • 更新全局的lse12:lse12 = lse1 + log(alpha + exp(lse2_ + m2 - lse1_ - m1))
    = lse1_ - m1 + log(alpha + alpha * exp(lse2_ - lse1_))
    = lse1_ - m2 + log(1 + exp(lse2_ - lse1_))
  • 更新O:O = O1 x exp(lse1_ - m1 - (lse12_ - m2) + m1 - m2) + O2 x exp(lse2_ - m2 - (lse12_ - m2))
  • 更新O:O = O1 x exp(lse1_ - lse12_) + O2 x exp(lse2_ - lse12_)
参考资料
[1]

大模型训练加速之FlashAttention系列:爆款工作背后的产品观: https://zhuanlan.zhihu.com/p/664061672


备注:进群,进入大模型技术群

添加好友:baobaogpt,记得备注呦


包包算法笔记
大模型技术和行业认知
 最新文章