Pytorch distributed overview

发布时间 2024-01-04 19:07:52作者: snake-pupil

torch.distributed
v1.6.0后包括三个主要的部分:
1.Distributed Data-Parallel Training(DDP):单程序多数据训练范式。模型被复制到每个进程中,每个模型副本被提供一组不同的输入数据,并将其梯度计算累加以加快训练速度。(collective communications)
2.RPC-Based Distributed Training(RPC):支持不适合data-parallel training的一般训练结构,eg pipeline parallelism,parameter server paradigm,和DDP与其他训练范式的组合。(P2P)
3.Collective Communication(c10d):支持在一个组内不同进程间发送张量,即提供collective communication APIs (e.g., all_reduce and all_gather)和P2P communication APIs (e.g., send and isend)。DDP和RPC构建在c10d之上。特殊情况使用该API,分布式参数平均,即希望在反向传播之后计算所有模型副本的参数的平均值,而不是用DDP来传播梯度。


数据并行:
-单机多GPU,DataParallel可以在最小化代码修改的情况下加速训练;
-单机多GPU,DistributedDataParallel可以更进一步加快训练;
-多机DistributedDataParallel和启动脚本,如果程序要跨越机器边界延伸;
-使用torch.distributed.elastic来启动分布式训练,如果发生错误(out of memory)

data-parallel 也可以 Automatic Mixed Precision (AMP).