图解Megatron TP中的计算通信overlap

学术   2025-01-04 20:09   江苏  

这篇文章想来探索Megatron中实现计算通信overlap的方法。

具体来说,Megatron的dp、tp和pp部分,都有可以做overlap的地方,本文探索的是tp部分(更准确地说是megatron sp-tp)。做这个探索的主要目的是:了解在哪些位置有做overlap的潜能,以及当前一些可行的实现思路。

最后,特别感谢overlap大师,megatron特级学者,大众点评美食优惠券killer:https://www.zhihu.com/people/yu-huo-er-wang 为本文提供的各类参考资料。

一、TP中哪些地方做了overlap

我们说的tp,是指“开启megatron sp做了activation显存优化”的tp,下图绘制了在megatron sp中单卡上Attn + MLP的运作流程(更多详细解读请参见这里

由此我们知道,在megatron sp中,tp部分的通讯被拆成若干个all-gatherreduce-scatter,在下文中我们会用AGRS来简称它。现在我们对tp中的fwd和bwd过程再做一个重新绘制,更清晰地展示通信步骤(绿色)和计算步骤(蓝色)

图中的红/黄框则分别展示了计算和通信之间的依赖关系,具体来说:

  • 红色:通信和相关的计算有依赖关系,需要串行。但是可以通过优化使得计算通信能overlap。

  • 黄色:通信和相关的计算没有依赖关系,可以并行。dgrad表示算的是input grad;wgrad表示算的是weight grad。

在Megatron-LM中,以下参数将控制是否开启红/黄框中的计算通信overlap:

  • tp_comm_overlap_ag:开启红框中ag相关的overlap

  • tp_comm_overlap_rs:开启红框中rs相关的overlap

  • tp_comm_bulk_dgrad:开启黄框中dgrad + ag的overlap

  • tp_comm_bulk_wgrad:开启黄框中wgrad + ag的overlap

  • tp_comm_overlap_rs_dgrad:黄框中的dgrad计算出来后会做rs,这里控制的就是这两者间的overlap。需要注意的是,如果此项为True,则会关闭 tp_comm_bulk_dgrad 和 tp_comm_bulk_wgrad(参见代码),猜测可能是因为同时开启时,存在对缓冲区资源的竞争及复杂管理等问题,会造成整体性能下降。

  • tp_comm_overlap:应该是一个总开关。只有当它为True时,才可以根据需要自动开关以上5项。否则是不开启tp部分的计算通讯overlap的(参考这份代码: https://github.com/NVIDIA/TransformerEngine/blob/c9ea6be92948e1ec553037f1a04900617b9f7f6b/transformer_engine/pytorch/module/layernorm_mlp.py#L265)

我们在Megatron-LM中设置这些参数,进而更改Transformer Engine(以下简称TE)的相关配置,最终的overlap是在TE中实现的。下面我们就来详细介绍这几个overlap技术。

二、tp_comm_overlap_ag

我们以下图圈出来的部分为例:

2.1 朴素all-gather 

假设我们采取的是最朴素的,没有任何overlap的策略,那么红框中的计算流程应该是下图这样的,这里假设tp_size = 2:

如上图所示,我们有2张gpu(tp_size = 2):

  • 在all-gather开始前,gpu0上存储着输入A0和模型分块B0,gpu1上存储着输入A1和模型分块B1。这里的B就对应着上图中的fc1。
  • 在朴素的all-gather中,我们先对输入A矩阵做all-gather,之后两张卡上的数据都变成[A0, A1]
  • 然后再各自个和B矩阵(fc1)相乘,得到最终的结果。不难发现,这里我们需要先等输入数据A到齐,然后才可以开始计算,也就是没有实现任何的计算通信overlap。

针对这张图,我们额外说明一点:例如[A0, A1]这样的形式,不代表A一定就是按照列切割的,只代表我们以分块的视角看待A。而Enisum可理解为一种自适应式的矩阵乘。因此我们要根据实际应用的场景来理解这张图,后文同理。

2.2 all-gather overlap p2p 

现在我们引入计算通信overlap,流程如下图所示:

  • 在最开始阶段,gpu0上存放着输入A0和模型分块B0,gpu1上存放着输入A1和模型分块B1。
  • 现在开始操作:
    • 在gpu0上,我们先把A0发送到gpu1,于此同时开始做gemm(A0, B0),以便得到C00,实现计算通讯overlap
    • 在gpu1上,我们先把A1发送到gpu0,于此同时开始做gemm(A1, B1),以便得到C11,实现计算通讯overlap
    • 等gpu0计算完C00,并收到A1后,它就可以继续gemm(A1, B0),以便得到C10;gpu1也是同理
  • 在overlap下,我们无需等到输入数据all-gather到齐后再进行计算,这样就可以减少整体的运行时间。

以上展示了2卡情况下的all-gather overlap,在多卡情况下也是同理,整体流程如下图所示:

  • partition即为卡,iteration则为每轮迭代,每轮迭代里包含了计算-通信的overlap。partition中的Di表示目前正在使用哪块输入做计算。

  • 从图中我们可以发现,这里采取的是p2p ring exchange的方式,也就是每张卡只和自己相邻的2张卡做数据的收-发。

  • 例如,在iteration0上时,每张卡做计算时,都用自己维护的那份数据做计算,所以这里Di和partition_i的下标是一一对应的。同时,每张卡会和相邻的2张卡做数据收发。例如partition2会把自己的数据D2发送给partition1,并从partition3上接受D3。

  • 再如,在iteration1上时,partition2就用自己收到的D3做计算了,同时它准备把D3发送给partition1,并从partition3上接收D0。以此类推。

相关的代码实践在TE仓库的CommOverlapP2PBase类下,大家可以自行阅读。注意代码里的A=weight, B=input,后文也是同理。

https://github.com/NVIDIA/TransformerEngine/blob/c9ea6be92948e1ec553037f1a04900617b9f7f6b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp#L561

三、tp_comm_overlap_rs

我们以下图圈出来的部分为例:

(备注:tp_comm_overlap_rs_dgrad,也就是右侧bwd中fc1_dgrad和下一个黄框中的RS做overlap的本质也是如此,所以后文不会再单独介绍它了)

3.1 朴素reduce-scatter 

假设我们有2张gpu(tp_size = 2)

  • B0和B1即为fc2,也就是按行切割的模型权重
  • A0和A1理解成fc2的输入。这里A0 = [A00, A10],A1 = [A10, A11]
  • 我们需要对B矩阵(fc2)的输出结果做reduce-scatter,而两张卡上的这个输出结果分别为C0 = [C00, C10], C1 = [C01, C11]。
  • 不难知道,做完reduce-scatter后:
    • gpu0上,C0 = C00 + C01
    • gpu1上,C1 = C10 + C11
  • 同样,在朴素reduce-scatter中,我们也需要等到[C00, C10]和[C01, C11]这个结果计算出来后,再做reduce-scatter,即计算通信没有overlap

针对这张图,我们额外说明一点:之所以要修改原始图片中的[A0, A1],是因为在tp mlp的fc2中,每张卡上的输入是不一样的,所以这里特别针对这个场景做了修改

3.2 reduce-scatter overlap p2p 

上图直接理解起来可能会有点头晕,我们不妨从一个更形象的视角理解一下:

  • 还是和all-gather overlap一样,这里采用的是p2p ring exchange的通信方式。

  • 在初始阶段,每个 gpu_i 都会发送出一个“碗C_i”,这个“碗C_i”的意思是,请把和我(gpu_i)相关的计算结果装在这个碗里。

  • 那么接下来,哪个gpu接收到这个碗C_i,它就要负责计算和这个 C_i 相关的结果,并把结果更新在 C_i 里

  • 假设我们共有 n 块gpu,那么 n-1 轮后,C_i 又流转回 gpu_i 手里,这时轮到 gpu_i 做和 C_i 相关的计算,这次计算结束后,就得到了最终的 C_i

基于此解读以上2卡的情况就不难了。那我们顺水推舟到多卡的情况:

  • 在iteration0里,由于p2p ring exchange机制的影响:
    • 对于partition0,它接收到来自partition1的碗C1,所以它只能做和C1相关的计算,也就是利用D1进行计算,然后把计算结果更新到C1中。
    • 其余partition也是同理
  • 在iteration1里:
    • 对于partition0,此时它接收的是来自partition1的碗C2(因为在上一次迭代中,partition1就是接收的partition2的C2,所以现在继续击鼓传花式地传递),因此partition0只能做C2相关的计算,也就是利用D2来计算。
    • 其余partition也是同理。
  • 在iteration3里:
    • 对于partition0,它终于接到了在iteration0里它传出去的碗C0,此时C0已经装满了其余partition上和partiton0相关的计算结果了,现在只要partition0把自己的这份结果更新进去,就大功告成了。

相关的代码实践可以参考下文链接。注意代码里的A=weight, B=input。这个过程理解起来有点绕,大家可以多体会下:

https://github.com/NVIDIA/TransformerEngine/blob/c9ea6be92948e1ec553037f1a04900617b9f7f6b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp#L924

3.3 reduce-scatter overlap pipeline chunk 

对于像fc2这种需要对输出结果做reduce-scatter的情况,除了p2p形式的overlap,megatron还提供了另外一种overlap的方法:pipeline chunk。

Pipeline chunk的思想是:假设原来是做完gemm(A, B)后再对结果reduce-scatter,那么现在我可以把矩阵(比如A)拆分成若干chunk,每次等gemm(chunk_i, B)的结果出来后,把这个结果发出去做reduce-scatter的同时,再继续做下一个chunk的计算,以此实现overlap。当然chunk的数量也不能太多(也就是不能把矩阵切得太小),否则反而会降低整体性能。在代码中默认chunk数量 = 4 (_num_splits = 4)。

代码详情可以参见下面链接。(我觉得这个代码写得可能有点问题。目前看来它能奏效是因为fc2的一个维度是4h,而num_splits = 4,切分后刚好是个方形矩阵(h, h),所以不管怎么转置做gemm都没有尺寸匹配问题,但是计算逻辑就不对了,另外这个函数似乎也只在fc2_fprop里用)

https://github.com/NVIDIA/TransformerEngine/blob/c9ea6be92948e1ec553037f1a04900617b9f7f6b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp#L324

四、 tp_comm_bulk_ag 和 tp_comm_bulk_rs

我们先来看下面框中 fc1_dgrad 和 AG 的overlap

这个过程对应到megatron sp的的架构图里如下:

之所以 fc1_dgrad和 AG 可以并行操作,是因为当前进程上做 fc1_dgrad只依赖上层传导过来的链式结果和fc1_weight。但是计算 fc1_wgrad却要依赖 AG 后的完整data(input activation)。fc1_dgrad 和 fc1_wgrad 计算完毕之后,前者做 RS 后继续向下层传导,后者用于更新 fc1_weight。

在代码中,管黄框中的overlap叫bulk overlap,并通过设置主流(stream_main)和通信流(stream_comm)来实现这个overlap,我们直接来看代码细节:

/*
** Bulk GEMM + COMM
** This function assumes the communication input is pre-copied to _ubuf
*/

void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb,
                                   TensorWrapper &D, TensorWrapper &bias,
                                   TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
                                   bool accumulate, bool use_split_accumulator,
                                   CommOverlapType comm_type, TensorWrapper &rs_output,
                                   cudaStream_t stream_main)
 
{
// 设置通信的上下文参数_ub_comm
int ori_sms = _ub_comm->sms;
  _ub_comm->use_ce = _use_ce;
  _ub_comm->sms = _num_comm_sm;
  _ub_comm->cga_size = _cga_size;

// Catch up the default torch stream
// 同步主流(用以计算,stream_main)和通信流(用以通信,_stream_comm)
  NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, stream_main)); // 在主流中记录事件_start_comm
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));// 让通信流等待该时间完成,这样可以确保通信流在正确时间启动

// Communication: AG and RS
// 通信流执行通信:根据入参comm_type选择不同的通信类型(AG或RS)
int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size();  // UBUF uses 2Byte element size
if (comm_type == CommOverlapType::AG) {
    allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
                                (cudaEvent_t)_comm_launch_event);
  } else {
    if (_ubuf.element_size() == 1) {
      assert(_ubuf_scale_inv_initialized);
      comm_elements *= 2;
      assert(rs_output.numel() == _ubuf.numel() / _tp_size);
      assert(rs_output.size(0) == _ubuf.size(0) / _tp_size);
      assert(rs_output.element_size() == 2);
      char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
      reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0,
                                                 comm_elements, _ub_comm, _stream_comm,
                                                 (cudaEvent_t)_comm_launch_event);
    } else {
      reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
                                      (cudaEvent_t)_comm_launch_event);
    }
  }

//  主流执行gemm计算:
  assert(pre_gelu_out.numel() == 0);
// When the kernel launch order is defined, enforce the GEMM kernel launch to wait for the communication kernel launch
if (_comm_launch_event)
    NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _comm_launch_event, 0));
  nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb,
                   grad, workspace.data(), accumulate, use_split_accumulator, _math_sms,
                   stream_main);

// 让主流等待通信流完成,这样接下来才可以继续做后续的计算流程
  _ub_comm->sms = ori_sms;
  NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
}  // CommOverlapBase::bulk_overlap

fc1_wgrad 和 RS 的overlap也是用这个函数实现的,这里不再赘述。

https://github.com/NVIDIA/TransformerEngine/blob/c9ea6be92948e1ec553037f1a04900617b9f7f6b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp#L171

五、小结

本文第一节中,展示了在megatron sp-tp中,一个decoder layer做fwd和bwd时需要做的计算与通信,其中:

对于红框部分,理论上计算和通信是有串行依赖的关系,但是可以通过一些优化办法做成overlap。具体来说TE实现了以下2种办法,它们本质上都是通过把计算拆分成更小的若干算子,从而实现边算边通信的目的:

  • 串行overlap方法一:p2p ring exchange,参见2.1(2), 2.2(2)
  • 串行overlap方法二:pipeline chunk,参见2.2(3)

对于黄框部分,理论上计算和通信没有依赖关系,所以天然可以做成overlap。TE提供了一种bulk overlap的方法,通过设置计算流和通信流,完成两者间的交叠:

  • 并行overlap方法:bulk overlap,参见第四节

六、参考

1、https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/optimizations/communication_overlap.html

2、https://dl.acm.org/doi/10.1145/3567955.3567959

3、https://github.com/NVIDIA/TransformerEngine

4、https://github.com/NVIDIA/Megatron-LM

进技术交流群请添加AINLP小助手微信(id: ainlp2)

请备注具体方向+所用到的相关技术点

关于AINLP

AINLP 是一个有趣有AI的自然语言处理社区,专注于 AI、NLP、机器学习、深度学习、推荐算法等相关技术的分享,主题包括LLM、预训练模型、自动生成、文本摘要、智能问答、聊天机器人、机器翻译、知识图谱、推荐系统、计算广告、招聘信息、求职经验分享等,欢迎关注!加技术交流群请添加AINLP小助手微信(id:ainlp2),备注工作/研究方向+加群目的。


AINLP
一个有趣有AI的自然语言处理公众号:关注AI、NLP、大模型LLM、机器学习、推荐系统、计算广告等相关技术。公众号可直接对话双语聊天机器人,尝试对对联、作诗机、藏头诗生成器、自动写作等,查询相似词,测试NLP相关工具包。
 最新文章