作者:菽陌松囿
原文:https://zhuanlan.zhihu.com/p/15666403147
关于训练,14.8T的token,花费不到600m,大概是llama3.1(15T+的token)成本的1/10,认为加速10倍的建议不要看,大概率MFU不怎么会算,属于没入门。
没有贬低之意, 作为软件公司能对cuda软件硬件有如此了解,没有深厚计算机功底玩不转,以我的认知,在软件公司属于国内外业界top1的存在,大模型infra最缺懂模型+强工程能力,懂硬件/网络更稀缺。
不过话说回来,因为论文没有披露训练MFU,不大好评估优化最终效果,如果按照gemm(6ND) + attention,N=14.8T,D=37B, seq=4k,按照2.788m gpu hours去算MFU, 经过多轮确认参数,gemm约327T,attention计算量约118T(这块, 计算方法: 3 x 61 x 14.8T x 4k x 2 x(128 x 192 + 128 x 128) / (2.788 x 3600 x 1000^2), 一般来说虽然是casual按照full attention去算了, 如果按casual计算就是60T ),总计算量445T, 所以 mfu 45%(445/989), 关于6ND误差,我仔细对比多,主要embedding/访存算子引入的误差, 对于gpt/llama llm模型gap 3%以内,参数越大误差越小,不需要手动去算,用forwad hook捕获gemm的shape累加),各位看官也可以自己算下。
ps:moe架构相比于dense多了两个问题:all2all成本,expert的均衡,所以看论文要带着这两个问题看,就能比较好理解其实现意图。
所以对训练加速点进行一些推测:
以下数据理论乐观估计猜测
fp8加速且如此稳定, 训练速度可接近一倍,32k seq乐估80%(32k seq),如果128k attention占比会超过gemm,可能40%-50%,attention用的bf16。所以H800上,因为fp8算力是bf16两倍,但是由于量化scale、累积精度等影响,不会直接double,贵乎有人说30-40%加速,低估多了怕被喷,另外fp8混合精度可以节省模型参数/激活产生的显存,访存型算子也快了,有了足够的显存,可以调整tp/pp/ep,因为v2用tp=1容易oom,因为dag或者激活没有切充分,另外tp=1, gemm shape很大, gpu降频咋办?对于gpu自保策略不是特别清楚。
dualpipe调度:同样的pp size和 acc数量,bubble约减少50%(F&B的时间需要细品,绝对不是3W,详见评论区讨论),这个影响看单dp的batchsize,因为论文里渐进式增加(3072-15360)按照15360/128,大约5%加速
all2all通信overlap:当然以seq而定,128k,all2all成本估计10%不到,所以这个优化省10-20%,所有论文infra都没有贴热点,无法有旳放失。把all2all跨node通信收敛到4个node,每个node2个expert,ib和nvlink可以互相overlap(其实他这个配比完全按理论带宽算的,不对,实际上ib、nvlink实际带宽,特别IB大概20GB/s,这就是搞ai普遍通病,脑袋聪明工程能力不行,无法观测系统实际情况,所谓的可观测性),因为dualpipe同时再与计算overlap,一般我们把通信/计算扔给两个stream就完了,但其实这两种负载会在sm 资源比如cache、core产生竞争,所以用warp spec技术划分了20个sm,控制计算通信比例,来实现更好的overlap,这个很难得,但计算可用单元少了20%,在大seq/shape是否考虑gemm的降频影响?
moe token dispatch 均衡:这个调整dispatch辅助loss,如果均衡,会有加速效果,但是不大好评估,我就没算, 估计不会太大。
综上训练加速相比于bf16,因为fp8算力更高,综合加速约一倍左右(更新:fp8算力有人说加速40-50%,所以相比bf16实际加速50%), 低估了不要喷我, 理论懂的人很多,实际有操作的并且能观测到瓶颈的欢迎来纠正。
关于推理的加速
api上,相比于v2 提升3倍+吞吐
mtp提升1.8倍,论文里说的
推理通过mb pipeline编排实现all2all overlap,这个蛮重要
pd分离,而且prefill架构和decode完全不一样,对于moe架构必不可少,他们早就做了
更重要分布式推理,特别是decode集群,ep是320,256普通expert和64个冗余expert,每张卡只放一个expert,显存尽可能多的留给query的kvcache,同时也考虑到热点expert问题,以便batchsize做大,由于是集群,多个dp的token可以比较均衡dispatch到320张卡,尽可能把tensor core用起来,这样以达到比较高的吞吐性能……但是推理集群如果有rank挂了,容灾呢?
总之成倍提升训练速度纯粹是不懂,训练加速很有限的,不能突破gemm/attention MFU,我一直走的路是把mb/seq做大,gemm/attention优化到占比90%, 降低卡间/node通信占比,其实是否overlap不是特别重要。