知乎:方佳瑞(已授权)
地址:https://zhuanlan.zhihu.com/p/4264163756
FlashAttention(FA)是大模型训练和推理性能优化最重要的组件。从并行计算角度,FA算法设计是可以写进教科书的。通过利用简单数学知识,等价变化任务的计算流程,从而将算法并行执行起来,实现最佳的内存效率,这无疑是并行计算Phd心中最完美的idea。
FA的算法流程也以复杂著称,原始论文中公式包含纷繁变量,复杂计算流程图,让普通人很难理解。对于FA也有很多非官方解读版本,比如Zihao Ye的《From Online Softmax to FlashAttention》(2023年5月作为UW研究生作业发布)。Ye从FA的演化历史入手,囊括了完毕的前置知识,让读者读此一文即可搞懂FA。但是Ye版本力求完备,公式保持完整严谨,让普通人很那一遍理解,需要反复揣摩。Zhihao Ye也是FlashInfer的作者。
去年此时,FA v2刚刚更新,我写过文章分析过FA的历史和现状。最近工作需要用到FA的细节,于是又重新看了一下FA的论文[1]。
这里post一下我自己的极简推导版本,平时作为小抄看代码时候使用。也希望能够帮助读者搞懂,理解FA算法之美。这里我避免使用任何复杂数学符号,需要基本线性代数知识就可以跟随下来。如有纰漏也希望大家指正。
一、前置知识
矩阵乘法分块:
O = QK^TV
,三矩阵连续乘法,可以采用分块的方式避免materializeQK^T
。但是如果用softmax之后,就不能那么简单的分块计算了。
Softmax简化公式:
softmax(x) = exp(x - m) / sum(exp(x - m))
,其中m = max(x)
。
exp和log运算性质:
exp(a + b) = exp(a) × exp(b)
,因此exp(x - m2) = exp(x - m1) × exp(m1 - m2)
。log(a x b) = log(a) + log(b)
,log(a / b) = log(a) - log(b)
矩阵乘法性质:
[A1, A2][V1, V2] = A1V1 + A2V2
。
二、目标
我们的目标是计算:O = softmax(Q × [K1, K2]^T) × [V1, V2]
我们接下来把小目标搞懂(如下图所示,我们从左向右计算有颜色的部分),就可以扩展到更大的问题规模。对于Q增加一个外层循环,KV增加内层循环长度,K1, K2, K3,… & V1, V2, V3, …。
三、FlashAttention计算流程
Step 1: 求Q x K^T
首先,每个设备独立计算两个矩阵的乘积:
X1 = Q × K1^T
X2 = Q × K2^T
Step 2: 寻找最大值
找到X矩阵中的最大值,算m1不准确,拿到X2时候需要回头看
m1 = max(X1)
m2 = max(X1, X2)= max(m1, max(X2)
Step 3: Softmax分子登场
计算Softmax的分子部分,同样需要用X2的m2去矫正之前的计算结果,矫正因子alpha
设 alpha = exp(m1 - m2)
a1 = exp(X1 - m1)
a2 = exp(X2 - m2) & a1' = a1 × alpha
【前置知识3】
Step 4: Softmax分母来啦
d1 = sum(exp(X1 - m1))
d2 = sum(exp(X2 - m2))& d1' = sum(exp(X1 - m2)) = d1 × alpha
Step 5: 组装时刻 ️
最后,先规定
O1 = a1 × V1 / d1
O2 = a2 × V2 / d2
然后,进行最终组装:
设 d12 = d1 × alpha + d2
O = [a1', a2] / (d1' + d2) × [V1, V2]
因式分解: O = a1' × V1 / (d1' + d2) + a2 × V2 / (d1' + d2)
【前置知识4】等价变化: O = a1 × alpha × V1 / d12 + a2 × V2 / d12
最终: O = O1 × d1 / d12 × alpha + O2 × d2 / d12
仔细看最终公式展示了优雅的对称性,以此类推可以计算O3, O4, …。
FA实现中返回一个叫做LSE的变量(Log Sum Exp)。所谓LSE就是对X = QK^Tscale的值做log(sum(exp(X)))。这在我介绍的算法流程中没有出现。使用它可以增加数值稳定性,我们再对上述流程略微改进,实现LSE的版本。The Log-Sum-Exp TrickFA实现中返回一个叫做LSE的变量(Log Sum Exp)。所谓LSE就是对X = QK^Tscale的值做log(sum(exp(X)))。这在我介绍的算法流程中没有出现。使用它可以增加数值稳定性,我们再对上述流程略微改进,实现LSE的版本。
四、FlashAttention LSE版本计算流程
Step 1: 求Q x K^T
首先,每个设备独立计算两个矩阵的乘积:
X1 = Q × K1^T
X2 = Q × K2^T
Step 2: 寻找最大值
找到X矩阵中的最大值,算m1不准确,拿到X2时候需要回头看
m1 = max(X1)
m2 = max(X1, X2)= max(m1, max(X2)
Step 3: Softmax分子登场
计算Softmax的分子部分,同样需要用X2的m2去矫正之前的计算结果,矫正因子alpha
设 alpha = exp(m1 - m2)
a1 = exp(X1 - m1)
a2 = exp(X2 - m2) & a1' = a1 × alpha
【前置知识3】
Step 4: Softmax分母来啦
d1 = sum(exp(X1 - m1))
lse1 = log(d1) = lse(X1-m1)
d2 = sum(exp(X2 - m2))& d1' = sum(exp(X1 - m2)) = d1 × alpha
lse2 = log(d2) = log(d1) + log(alpha) = lse1 + m1 - m2
Step 5: 组装时刻 ️
最后,先规定
O1 = a1 × V1 / d1
O2 = a2 × V2 / d2
然后,进行最终组装:
设 d12 = d1 × alpha + d2
设 lse12 = log(d1 x (alpha + d2/d1)) = lse(d1) + lse(alpha + d2/d1) = lse1 + lse(alpha + exp(lse2 - lse1))
O = [a1', a2] / (d1' + d2) × [V1, V2]
因式分解: O = a1' × V1 / (d1' + d2) + a2 × V2 / (d1' + d2)
【前置知识4】等价变化: O = a1 × alpha × V1 / d12 + a2 × V2 / d12
最终: O = O1 × d1 / d12 × alpha + O2 × d2 / d12
最终LSE版本: O1 x exp(lse1 - lse12) x alpha + O2 x exp(lse2 - lse12)
【LSE更新方式】
这里一个技巧是我们计算第二个block的lse时候,再加上一个当前最大值max,也就是lse2_ = lse2 + m2; lse1_ = lse1 + m1,结果可以消除alpha,让公式的对称性更完美
更新全局的 lse12:lse12 = lse1 + log(alpha + exp(lse2_ + m2 - lse1_ - m1))
= lse1_ - m1 + log(alpha + alpha * exp(lse2_ - lse1_))
= lse1_ - m2 + log(1 + exp(lse2_ - lse1_))
更新O: O = O1 x exp(lse1_ - m1 - (lse12_ - m2) + m1 - m2) + O2 x exp(lse2_ - m2 - (lse12_ - m2))
更新O: O = O1 x exp(lse1_ - lse12_) + O2 x exp(lse2_ - lse12_)
大模型训练加速之FlashAttention系列:爆款工作背后的产品观: https://zhuanlan.zhihu.com/p/664061672