知乎:手抓饼熊
地址:https://zhuanlan.zhihu.com/p/12302566679
编辑:「深度学习自然语言处理 公众号」,转载授权请联系作者
大模型推理并行方向只有2个难点(特指并行而非分布式,个人觉得分离式架构、分布式调度等均属于大模型推理分布式领域),一个是序列并行、一个是张量并行的通信计算重叠优化。序列并行笔者介绍的比较多了(手抓饼熊:大模型推理序列并行),本文介绍一下张量并行通信计算重叠的一些优化方案。
背景
张量并行目前已经是大模型推理的一个必备的技术,笔者之前分析过vLLM的张量并行(手抓饼熊:vLLM源码之模型并行)。然而张量并行一个缺点是通信开销,当推理采用PCIE类卡的时候,该缺点更加明显。
针对通信开销的缺点,训练框架已经有了通信计算重叠优化(手抓饼熊:Megatron-LM Tensor并行计算通信重叠),而目前开源的推理引擎如vLLM和SGLang均没有实现该功能。最近知乎开源的大模型推理引擎ZhiLight支持张量并行通信计算重叠(如何评价知乎刚刚宣布开源的大模型推理框架 ZhiLight?)。笔者认为,在2025年张量并行通信计算重叠将会是所有开源框架的必备功能。本文结合当前最新的论文,介绍张量并行通信计算重叠的做法。
张量并行的几种实现
2.1 朴素版张量并行
如上图所示,标准的Transformer张量并行结构,从图中我们可以看到,每次Transformer前向需要进行2次AllReduce,这会导致模型前向执行AllReduce的时候,计算的GPU比较空闲。
2.2 Gemm版本通信计算重叠
当我们说到张量并行计算通信重叠,一个最直观的实现是分布式Gemm + AllReduce的overlap,目前TransformerEngine、torch([Distributed w/ TorchTitan] Introducing Async Tensor Parallelism in PyTorch)和字节Flux都是采取类似的实现,Flux在优化方面做的更细一点,如上图所示,o_proj其实是一个分布式矩阵乘法,其后面跟着一个Allreduce算子。放大分布式Gemm + AllReduce如上图所示。对上述图示说明如下:
上图上半部分展示的是分布式Gemm,A @ B,其中A是列切,B是行切,A0 @ B0 得到蓝色的 C00 C10部分,A @ B1 得到蓝色的 C01 C11部分。蓝色的 C00 C11和黄色的C01 C11进行ReducesScatter得到 C0 C1。 上图的下半部分是overlap的版本,我们看到原来的A是按照列切分的,计算的时候再按照行分块计算。分2个step计算,在step0时候,A00 @ B0会得到 C00,同样 A01 @ B得到C01,在step1时候,A10 @ B0会得到 C00,同样 A11 @ B得到C11,此时可以同时进行step0计算结果C的规约。
Flux的思想和这个类似,但是还有很多优化,以后有时间再探讨。
2.3 请求间通信计算重叠
上图是张量并行通信与计算重叠的另一种实现(Liger: Interleaving Intra- and Inter-Operator Parallelism for Distributed Large Model Inference)。从图中可以看出该方法有如下特点:
会有多个请求,不同的请求会有不同的stream。 执行请求1的计算的时候,请求2正在进行通信操作,反之依然。 类似的这种做法还有Nanoflow,如下图所示。 这种做法,按理说不需要重写一个计算通信的kernel(右图是计算和通信对应的SM分配情况),但是整体调度实现会很复杂,同样后续有机会再深入分析。
2.4 请求内通信计算重叠
第3种通信计算重叠方式如上图所示(ISO: Overlap of Computation and Communication within Seqenence For LLM Inference),看起来和Flux、Liger均不一样。但凡通过这张图,要是能看懂这个方法的思路,那么大模型并行基础就不错,至少对分布式attention实现还是有很深了解的。
这张图看起来比较疑惑,因为他是单张卡不同的流的图示,把这张图变成多卡。我用红色、黑色、黄色框画了几个step,后续会用到,可以结合着看。如上图所示,上面的是rank0,下面的是rank1 。可以看到一个transformer对于单卡而言,通信和计算是重叠的。从图中可以看出,attention采用了分块attention实现的。关于MLP的计算通信重叠和2.2节应该类似,我们重点看attention如何实现通信计算重叠的。上图是整体流程,可以结合上面的流程图一起看,核心思想是,每张卡,在序列维度分块执行,分块0执行的时候没有通信,分块1执行的时候,执行分块0的通信,具体细节可以看图。
总结
听君一席话,如听君一席话。