大模型推理张量并行的4种模式

教育   2024-12-22 17:16   江苏  

知乎:手抓饼熊
地址:https://zhuanlan.zhihu.com/p/12302566679
编辑:「深度学习自然语言处理 公众号」,转载授权请联系作者

大模型推理并行方向只有2个难点(特指并行而非分布式,个人觉得分离式架构、分布式调度等均属于大模型推理分布式领域),一个是序列并行、一个是张量并行的通信计算重叠优化。序列并行笔者介绍的比较多了(手抓饼熊:大模型推理序列并行),本文介绍一下张量并行通信计算重叠的一些优化方案。

LLM所有细分方向群+ACL25/ICML25/NAACL25投稿群->LLM所有细分领域群、投稿群从这里进入!

背景

张量并行目前已经是大模型推理的一个必备的技术,笔者之前分析过vLLM的张量并行(手抓饼熊:vLLM源码之模型并行)。然而张量并行一个缺点是通信开销,当推理采用PCIE类卡的时候,该缺点更加明显。

针对通信开销的缺点,训练框架已经有了通信计算重叠优化(手抓饼熊:Megatron-LM Tensor并行计算通信重叠),而目前开源的推理引擎如vLLM和SGLang均没有实现该功能。最近知乎开源的大模型推理引擎ZhiLight支持张量并行通信计算重叠(如何评价知乎刚刚宣布开源的大模型推理框架 ZhiLight?)。笔者认为,在2025年张量并行通信计算重叠将会是所有开源框架的必备功能。本文结合当前最新的论文,介绍张量并行通信计算重叠的做法。

张量并行的几种实现

2.1 朴素版张量并行

如上图所示,标准的Transformer张量并行结构,从图中我们可以看到,每次Transformer前向需要进行2次AllReduce,这会导致模型前向执行AllReduce的时候,计算的GPU比较空闲。

2.2 Gemm版本通信计算重叠

当我们说到张量并行计算通信重叠,一个最直观的实现是分布式Gemm + AllReduce的overlap,目前TransformerEngine、torch([Distributed w/ TorchTitan] Introducing Async Tensor Parallelism in PyTorch)和字节Flux都是采取类似的实现,Flux在优化方面做的更细一点,如上图所示,o_proj其实是一个分布式矩阵乘法,其后面跟着一个Allreduce算子。放大分布式Gemm + AllReduce如上图所示。对上述图示说明如下:

  • 上图上半部分展示的是分布式Gemm,A @ B,其中A是列切,B是行切,A0 @ B0 得到蓝色的 C00 C10部分,A @ B1 得到蓝色的 C01 C11部分。蓝色的 C00 C11和黄色的C01 C11进行ReducesScatter得到 C0 C1。
  • 上图的下半部分是overlap的版本,我们看到原来的A是按照列切分的,计算的时候再按照行分块计算。分2个step计算,在step0时候,A00 @ B0会得到 C00,同样 A01 @ B得到C01,在step1时候,A10 @ B0会得到 C00,同样 A11 @ B得到C11,此时可以同时进行step0计算结果C的规约。

Flux的思想和这个类似,但是还有很多优化,以后有时间再探讨。

2.3 请求间通信计算重叠

上图是张量并行通信与计算重叠的另一种实现(Liger: Interleaving Intra- and Inter-Operator Parallelism for Distributed Large Model Inference)。从图中可以看出该方法有如下特点:

  • 会有多个请求,不同的请求会有不同的stream。
  • 执行请求1的计算的时候,请求2正在进行通信操作,反之依然。
  • 类似的这种做法还有Nanoflow,如下图所示。
  • 这种做法,按理说不需要重写一个计算通信的kernel(右图是计算和通信对应的SM分配情况),但是整体调度实现会很复杂,同样后续有机会再深入分析。

2.4 请求内通信计算重叠

第3种通信计算重叠方式如上图所示(ISO: Overlap of Computation and Communication within Seqenence For LLM Inference),看起来和Flux、Liger均不一样。但凡通过这张图,要是能看懂这个方法的思路,那么大模型并行基础就不错,至少对分布式attention实现还是有很深了解的。

这张图看起来比较疑惑,因为他是单张卡不同的流的图示,把这张图变成多卡。我用红色、黑色、黄色框画了几个step,后续会用到,可以结合着看。如上图所示,上面的是rank0,下面的是rank1 。可以看到一个transformer对于单卡而言,通信和计算是重叠的。从图中可以看出,attention采用了分块attention实现的。关于MLP的计算通信重叠和2.2节应该类似,我们重点看attention如何实现通信计算重叠的。上图是整体流程,可以结合上面的流程图一起看,核心思想是,每张卡,在序列维度分块执行,分块0执行的时候没有通信,分块1执行的时候,执行分块0的通信,具体细节可以看图。

总结

听君一席话,如听君一席话。



备注:昵称-学校/公司-方向/会议(eg.ACL),进入技术/投稿群


id:DLNLPer,记得备注呦

深度学习自然语言处理
一个热衷于深度学习与NLP前沿技术的平台,期待在知识的殿堂与你相遇~
 最新文章