作者:你的真实姓名
链接:https://www.zhihu.com/question/650979052
最近看到知乎一个回答,把千卡训练的难度吹上天了。但其实真正用过千卡就会发现也就那么几个点。于是想写一篇文章简单讲讲。
本文将包括几个部分:
首先我们将讨论千卡训练的难题,以及应该在什么时候使用千卡训练; 接着,我们将讨论如何在一千张卡上开始训练,如何让他达到近乎线性的性能提升; 然后我们将展开讨论一些千卡训练当中仍然悬而未决(至少对于开源社区来说)的问题
为什么千卡训练是困难的?
其实那篇回答在这部分说的没错。千卡训练和八卡训练的区别是—显卡多了一百多倍。这意味着什么呢?
通信时间增加 故障概率增加
这俩问题都很好理解。时间上,PyTorch内部支持NCCL/Gloo/MPI三个通信后端(请务必使用NCCL。其中AllReduce操作会根据具体硬件配置走Ring AllReduce和Tree AllReduce。Ring的时间复杂度是O(p n),Tree的时间复杂度是O(\log p n)。就算是理论上128节点也比单节点慢至少七倍,实践当中跨节点通讯要远比单节点慢得多。
故障上,一个节点出问题的概率是p,128个节点就是1-(1-p)^128。也就是说如果一个操作在一个训练当中的出错概率是1%,那么在128节点当中的出错概率就是72.37%。此外,随着规模的增大,许多问题都会变得难以忍受。比如数据增强要花0.1s,一亿条数据就是278个小时(当然这只是胡拆的一个数字,实际有各种机制所以不会有这么大影响)。
因此,钱多烧手并不是使用千卡训练的理由。闲得蛋疼可能是,但你得多蛋疼才能想出这么折磨自己的idea?千卡训练解决的问题是大模型&大数据问题。如果你的训练时间没有超过8192GPU日,那么你绝对不需要一千张显卡。看到这里,绝大多数人已经可以关掉这篇文章了。除非你的模型和数据都以B(十亿)来作为计量单位。当然如果你正在厕所里手机没电想看点儿东西解闷儿的话(虽然我很怀疑是否会有人把他打出来……那么可以继续往下看。
如何使用一千张卡训练?
如何提高计算效率?
这件事情其实是一个case by case的事情。因为通信、计算速度啥的受硬件影响更多。而每一个集群的硬件拓扑都是不一样的。同样是A100集群,我全DGX节点,每一张A100都是SXM接口并配一块儿专属的IB网卡。你一个小破普惠服务器插8张PCI-E A100,IB卡一个节点只给一张。那咱俩遇到的问题就完全不是一个问题。
因此,要讨论如何提高训练效率、减少训练耗时,我们首先要了解训练耗时在哪里。那么,一个训练步的耗时在哪里呢?需要谨记,没有profile的优化是没有意义的。
你可能会说,forward backward sync。很好,这说明你了解PyTorch的基本流程。不过现实当中要复杂得多。
dataset读取数据,构建输出 dataloader collate数据,进行数据预处理 模型forward计算输出 loss compute 模型backward计算梯度 模型sync梯度 优化器step更新权重 打印log
当然这是可以无限细分下去的,但一般这些就够了。需要注意的是,除了4-7的耗时是真耗时,其他都需要通过异步操作来盖掉。这也是我们的优化目标。异步执行在PyTorch的dataloader、CUDA和分布式当中都存在。前者可以通过设置num_workers
和prefetch_count
为0来关闭,后两者可以通过cuda.synchronize
和dist.barrier
来执行手动同步。在profile时,我们需要首先需要测整个step的时长。然后再在每次测量前执行手动同步来计算每个部分的时长。如果前者的总耗时等于后者4-7的耗时之和,那么通常不需要执行任何操作。但这种情况在千卡操作中几乎不可能发生。
第6步通信往往需要耗费大量时间。因此,我们还需要进一步优化通信。以下内容是对论文的概括,有感兴趣的同学建议通读并背诵全文。
https://arxiv.org/pdf/2006.15704
计算-通信重叠
在PyTorch当中,梯度的通信和反向传播是交叠进行的。也就是说,每完成一层的梯度计算,都会立即触发当前层的同步。实现起来也很简单,每个进程在完成自己第k层的梯度计算后都会触发一个钩子来给计数器+1。当计数器达到进程数是开火进行梯度通信。有很多同学在计算梯度过程中遇到过RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one.
错误,这就是因为有的模块没有参与计算loss,导致梯度同步卡住了。需要注意,当find_unused_parameters=True
时,PyTorch分布式使用nn.Module.__init__
当中定义sub-module的反向顺序来作为梯度桶的构建顺序。因此,确保模块定义和调用的顺序一致对于高效训练来说很重要。
梯度合桶
尽管理论上来说,同步发生的越及时,重合度越高,性能越好。但实际上每次发起通信都是有上头的。因此,现实当中梯度同步并不是越多越好越快越好。为此,PyTorch引入了梯度合桶机制,通过把多个Tensor装在一个桶里再通信桶来减少通信次数从而减少总耗时。合桶的Buffer Size等等参数往往需要针对硬件和模型来调整从而取得最好的通信效果。PyTorch的默认参数是从0.x时代祖传下来的,这一参数通常都需要调节。
梯度累加
当你做完所有操作之后,惊喜的发现TMD怎么同步时间还是单节点的好几倍。这其实是正常情况……实际上超过256卡的训练想要把通信盖掉就是一件不可能的事情。你说老师我看FB论文说他们256卡就是线性提升啊…那这里不得不提的一个策略就是梯度累加了。梯度累加会执行k次forward+backward之后再执行优化器步进。这有很多好处,首先对于大模型batch size通常不能开多大,梯度累加可以提升等效batch size。其次累加期间的backward不需要通信梯度,加快了训练速度。
少即是快
Python是一种很慢的代码。当然你说JIT trace+torch.compile有提升我也不反对,但对于最高效率来说,只有必须要存在的代码和不存在的代码两种。抱抱脸的Transformers就是一个反例。两个sub-Module就能写完的TransformerLayer他们硬是能写出来一堆…偏偏他们还信奉Single Model File Policy……我寻思你这完全不考虑继承的封这么多层是要搞鸡毛啊?正例反而是PyTorch……(笑死,我竟然会夸脸书代码写得好。具体来说就是nn.functional
当中的各种实现。你会发现他们第一行往往是handle_torch_func
。熟悉Python装饰器的小伙汁通常要问了,为啥这里不用个装饰器统一一下?因为装饰器会引入额外的函数调用,额外的函数调用就是额外的上头。因此,如果你想确保最高的效率,写一个简单的训练代码和模型代码非常重要。毕竟,1%的效率提升,节省的可能是数百个GPU日。
如何平稳训练
这一段当中中咱们只讨论你能控制的问题。
捕捉不致命的异常
故障率高的问题其实很好解决。在训练当中,大部分异常都是非致命异常,捉住他们就好了。
https://danling.org/utils/decorators/#danling.utils.decorators.catch
是我之前写的一个装饰器,它的作用就是catch异常,然后调回调函数(默认当然就是把错误打印到log里)。所有你需要做的只是使用它来装饰非fatal的操作。
咳咳,说点儿正经的。任何联网操作都是需要catch的,常见的联网操作主要包括从ceph读取数据和…写log到远程(逃。其他就没啥了吧,我见过有大哥尝试恢复OOM的,但效果似乎不是很好,至少我自己没用过。简单来说,唯一不应捕捉的错误是集群炸了。那有的大兄弟就说了,集群没爆炸,但是有两张卡突然掉了咋办。这个咱第三部分再讨论。
过程也很重要
有用过丹灵的同学可能比较熟悉。丹灵其他地方都很轻量,唯独实验管理这里写的很复杂。现代丹灵会将创建一个三个级别的实验目录,project/experiment-run/timestamp
。其中project
是用户给出的,experiment
和run
分别是通过代码版本和配置计算出来的,timestamp
就是运行开始的时间。也就是说,如果代码和配置是完全一样的,丹灵就会认为这是同一个运行。在设置中打开auto_resum
就会自动找最新的一个检查点(这就是为啥最后一级要用时间戳)来加载。其实微软用的amlt更好用,他甚至还会创建一个代码的diff文件夹来帮助你回忆当初代码修改了些啥。
收敛,收敛,收敛
模型训着训着发散了几乎是每个训大模型的人都会遇到的问题。输出和loss只要有nan
果断丢掉。梯度先clip by value再clip by norm都是常规操作。哦对了,还有初始化……关于大模型收敛性的论文有一堆,此处不再赘述。
比更大,还更大,再更大
弹性训练
实际上当你的训练超过2048个GPU日时,在整个训练过程当中发生单个GPU甚至单个节点下线是再正常不过的事情了。PyTorch在1.10就引入了torchelastic
弹性训练机制,用过的都骂娘。等下,让我先骂一遍,呸。ok咱们继续吧。
我印象当中在微软的最后一轮面试当中被问到了这个问题:如何设计一个弹性分布式系统。我的回答很教科书。每k
分钟,系统会做一次AllReduce来统计存活进程数,然后选举出一个主进程。主进程会计算好每个进程的rank和local rank进行broadcast。所有进程每次forward开始时向主进程发送一个心跳包来汇报状态。主进程会根据心跳包来确定这一个step参与同步的机器有多少。但很可惜,2024年了。还是没人去写。他妈的。
大小梯度同步
我一直认为梯度同步不应该以GPU/进程为单位。而应该分为大同步(节点间同步)和小同步(节点内同步)。小同步可以更高频的进行,大同步则可以更慢的执行。这样不仅能提高实际的梯度同步频率,降低同步总耗时,并且还能天然的去结合小batch和大batch训练的优点—节点内小batch关注个体,节点间大batch关注整体。
延伸阅读
https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html https://github.com/NVIDIA/nccl/issues/256 https://arxiv.org/abs/2312.16903 https://proceedings.mlr.press/v9/glorot10a.html
有没有发现所有东西都很简单?是这样的,千卡训练是任何一个普通CS本科生花三个月就能学会的东西。没有任何复杂的地方。