下游训练任务起飞!FlashAttention终于高性能地支持多样的attention mask!

文摘   2024-11-19 16:59   新加坡  

News1: 最近很多朋友秋招斩获多个offer,前来报喜以及关于选哪个offer来问我意见,在此恭喜,应他们的要求,为了不被定位和求稳,就不放截图了。

News2: 三个月前发布的面试冲刺课程中,新增了6道关于FA v1 v2和flashdecoding的题目和答案,参加了的朋友可移至飞书查看

FlashAttention(后面统称FA)早在去年我就写了三篇文章分别讲解了v1 v2和推理版本的flash decoding,我个人认为我这三篇文章基本涵盖了关于FA的思想,有的读者可能关注我比较晚,未曾看到过,这里给个链接,感兴趣的朋友可以转至v1,v2,flashdecoding

本文主分享一个最新的研究进展,来自PaddlePaddle,通过将attention mask稀疏化以降低FA支持各种各样mask的访存成本,从而支持了相比官方FA更高性能的对loss和精度无损的FA with attention mask kernel

背景

FA至今只支持特定的几种mask类型,tri dao的那个github上已经有一个关于此的issue#352挂了一年了,有很多人提出了需求:让tri dao支持一种API,允许用户自定义attention mask,但是直到今日,也还没有实现,所以就有人自发地去给FA支持各种各样的复杂的attention mask,以使得不同的下游任务的训练(ep.SFT、reward models、DPO、in-context learning.etc)也可以无损地、高性能地利用FA的性能,否则无法利用FA的性能那么训练起来非常慢,且上下文句子长度有限

各种任务的attention mask

各种奇形怪状的都有,怪不得FA的issue里面对attention mask的feature支持这么迫切

动机

FLASHMASK可以理解为是对FA的一个扩展。FA旨在解决传统注意力机制在处理长句子时面临的计算和内存需求呈平方阶(O(N^2))增长的问题。这种增长对于 Transformer 模型在任意一个硬件上来说都是一个重大挑战,尤其是长句子的LLM训练中。具体点讲,FA通过 IO 感知的内存优化减少了注意力延迟,并消除了对 O(N^2)的内存依赖。然而,在上述训练场景下,FA的不足有二:

  • 对某些attention mask类型的原生支持有限,并不天然地适应更复杂的mask需求,如上图,目前FA repo只支持(1)到(4)的causal mask、bidirectional mask、SWA和causal doc mask。

  • 以往的方法使用稠密mask矩阵,这会导致 O(N^2) 的访存增长,从而效率不高,导致支持的最大上下文长度有限。

为了解决这些问题,paddlepaddle提出了 FLASHMASK,核心idea是引入了一种列式稀疏表示的attention mask,有效地表示了广泛的mask类型,并有利FA with mask kernel的优化。通过采用这种新颖的mask稀疏表示方法,FLASHMASK 实现了对attention mask的线性访存复杂度 O(N),使其更适合于长上下文的训练。此外,这种表示方法还使得kernel能够利用attention mask中的稀疏性来消除不必要的计算,从而在不牺牲计算精度的情况下提高计算效率。

方法

提出一种column wise的稀疏表示方法来表示传统的稠密attention mask

先抛一个问题:为什么可以用稀疏表示方法表示稠密attention mask?

答:熟悉稀疏矩阵的朋友都知道,表示稀疏矩阵通常用几个一维数组或向量就可以表示,无需用二维tensor,这也是稀疏化的重要收益来源,也是稀疏属于压缩范畴的原因,因为实实在在的减少了内存占用、访存量、计算量。同理,这里也是相同的思想,用4个向量表示k矩阵每一个token在左下角和右上角对应的哪些q token被mask了

具体,FLASHMASK是怎么做的呢?如下图

右边针对左图的12个mask,给出了它所对应的稀疏向量表示,paddle把mask分为两个区域,一个左下角,一个右上角,LT开头的描述左下角的masked情况,UT表示右上角的masked情况,(6)举例,如下,q有10个token,k也有10个token,针对每个k维度的token,我们来计算对应q维度token的masked情况,比如对于5号token,灰色部分有下图红圈部分,所以[LTS,LTE)=[7,10),[UTS,UTE)=[2,4)

问题又来了:这4个向量是针对每一个token的,但是FA是tile based算法,如何把这几个向量也tiling一下?

答:小事情,我把这几个向量的值也tile一下不久好了,如下为tile size=4的情况,并且,计算出每个block的LTSmin LTSmax LTEmin LTEmax,同理UTS UTE这里不再列举

然后,paddle再制定了一个规则,用以决定哪些mask对应的block是fully masked、partial masked以及unmasked,这里具体我不细讲规则了,感兴趣的可以直接去看看FLASHMASK的论文。这里我讲区分出哪些block是fully masked、partial masked以及unmasked后如何处理:unmasked不用就经过mask矩阵的计算,直接算softmax,fully masked直接skip掉后面的计算,都mask掉了还算啥对吧,只有block里面一部分mask掉一部分unmask的才apply mask,那么由此就省略了一部分计算量。

至于访存,我们用了8个向量来表示稀疏mask,每个向量的大小是seqlen/blocksize, 很明显这是O(N)的复杂度                                        

具体算法


本质上是扩展了FA,即新增了紫色部分,对于熟悉FA算法的话,应该不难理解,在此不赘述了。


实验效果

training吞吐:绿色的flashmask最好

loss收敛:证明flashmask的idea无损loss

kernel性能:遥遥领先pytorch compiler based flexAttention


AI不止算法
AI-HPC/AI工程/AI推理加速/AI算子开发的技术分享和入门转行学习的全套解决方案提供