Files
Luo Mai 5ac39809e9 Fix AllGather. (#442)
* Update collective.md

* Update overview.md

---------

Co-authored-by: Tanzhipeng <Rudysheeppig@users.noreply.github.com>
2023-03-31 18:39:08 +08:00

30 KiB
Raw Permalink Blame History

集合通信

下面讨论如何利用集合通信在机器学习集群中实现分布式训练系统。作为并行计算的一个重要概念,集合通信经常被用来构建高性能的单程序流/多数据流Single Program-Multiple Data, SPMD程序。接下来首先会介绍集合通信中的常见算子。然后描述如何使用AllReduce算法解决分布式训练系统中网络瓶颈并且讨论AllReduce算法在不同网络拓扑结构下的差异性以及重要性能指标的计算方法。最后介绍现有机器学习系统对不同集合通信算法的支持

常见集合通信算子

下面首先定义一个简化的集合通信模型然后引入常见的集合通信算子Broadcast、Reduce、AllGather、Scatter和 AllReduce。需要指出的是在分布式机器学习的实际场景下人们还会使用许多其他的集合通信算子如ReduceScatter、Prefix Sum、Barrier、All-to-All等但由于篇幅限制便不再赘述。

通信模型

假定在一个分布式机器学习集群中,存在$p$个计算设备,并由一个网络来连接所有的设备。每个设备有自己的独立内存,并且所有设备间的通信都通过该网络传输。同时,每个设备都有一个编号$i$,其中$i$的范围从$1$到$p$。 设备之间的点对点Point-to-Point, P2P通信由全双工传输Full-Duplex Transmission)实现。该通信模型的基本行为可以定义如下:

  • 每次通信有且仅有一个发送者Sender和一个接收者Receiver。在某个特定时刻每个设备仅能至多发送或接收一个消息Message。每个设备可以同时发送一个消息和接收一个消息。一个网络中可以同时传输多个来自于不同设备的消息。
  • 传输一个长度为$l$个字节Byte的消息会花费$a+b \times l$的时间,其中$a$代表延迟Latency即一个字节通过网络从一个设备出发到达另一个设备所需的时间$b$代表传输延迟Transmission Delay即传输一个具有$l$个字节的消息所需的全部时间。前者取决于两个设备间的物理距离如跨设备、跨机器、跨集群等后者取决于通信网络的带宽。需要注意的是这里简化了传输延迟的定义其并不考虑在真实网络传输中会出现的丢失的消息Dropped Message和损坏的消息Corrupted Message的情况。

根据上述通信模型,我们可以定义集合通信算子,并且分析算子的通信性能。下面介绍一些常见的集合通信算子。

Broadcast

常用集合通信算子 :width:800px 🏷️ch10-collective-operators

一个分布式机器学习系统经常需要将一个设备$i$上的模型参数或者配置文件广播Broadcast给其余全部设备。因此可以把Broadcast算子定义为从编号为$i$的设备发送长度为$l$字节的消息给剩余的$p-1$个设备。 :numref:ch10-collective-operators展示了设备1在三个设备的集群里调用Broadcast的初始和结束状态。

一种简单实现Broadcast的算法是在设备$i$上实现一个循环,该循环使用$p-1$次Send/Receive操作来将数据传输给相应设备。然而该算法不能达到并行通信的目的该算法只有$(a+b \times l) \times (p-1)$的线性时间复杂度。为此可以利用分治思想对上述简单实现的Broadcast算法进行优化。假设所有的设备可以重新对编号进行排列使得Broadcast的发送者为编号为$1$的设备。同时,为了简化计算过程,假设对某个自然数$n$$p = 2^n$。 现在可以通过从1 向 p/2 发送一次信息把问题转换为两个大小为$p/2$的子问题编号为1的设备对编号1到编号p/2-1 的Broadcast以及编号为$p/2$的设备对编号$p/2$到编号$p$的Broadcast。我们便可以通过在这两个子问题上进行递归来完成这个算法并把临界条件定义为编号为$i$的设备在$[i,i]$这个区间中的Broadcast。此时由于$i$本身已经拥有该信息不需要做任何操作便可直接完成Broadcast。这个优化后的算法为(a+b \times l) \times \log p 时间复杂度,因为在算法的每一阶段(编号为$t$),有$2^t$个设备在并行运行Broadcast算子。同时算法一定会在\log p 步之内结束。

Reduce

在分布式机器学习系统中另一个常见的操作是将不同设备上的计算结果进行聚合Aggregation。例如将每个设备计算的本地梯度进行聚合计算梯度之和Summation。这些聚合函数表达为$f$往往符合结合律Associative Law和交换律Commutative Law。这些函数由全部设备共同发起最终聚合结果存在编号为$i$的设备上。常见聚合函数有加和、乘积、最大值和最小值。集合通信将这些函数表达为Reduce算子。 :numref:ch10-collective-operators展示了设备1调用Reduce来进行加和的初始和结束状态。

一个简易的Reduce的优化实现同样可以用分治思想来实现即把$1$到$p/2-1$的Reduce结果存到编号为$1$的设备中,然后把$p/2$到$p$的Reduce结果存到$p/2$上。最后,可以把$p/2$的结果发送至$1$,执行$f$,并把最后的结果存至$i$。假设$f$的运行时间复杂度为常数并且其输出信息的长度$l$不改变Reduce的时间复杂度仍然为$(a+b \times l) \times \log p$。

AllReduce

集合通信通过引入AllReduce算子从而将Reduce函数$f$的结果存至所有设备上。:numref:ch10-collective-operators展示了设备1设备2和设备3共同调用AllReduce来进行加和的初始和结束状态。

一种简单的AllReduce实现方法是首先调用Reduce算法并将聚合结果存到编号为$1$的设备上。然后再调用Broadcast算子将聚合结果广播到所有的设备。这种简单的AllReduce实现的时间复杂度为$(a+b \times l) \times \log p$。

Gather

Gather算子可以将全部设备的数据全部收集Gather到编号为$i$的设备上。 :numref:ch10-collective-operators展示了设备1调用Gather来收集全部设备的数据的初始和结束状态。

在收集函数Gather Function符合结合律和交换律的情况下可以通过将其设为Reduce算子中的$f$来实现Gather算子。但是在这种情况下无论是基于链表还是数组的实现在每一步的Reduce操作中$f$的时间复杂度和输出长度$l$都发生了改变。因此Gather的时间复杂度是$a \times \log p + (p-1) \times b \times l$。这是因为在算法的每一阶段$t$,传输的信息长度为$2^{t} \times l$。

AllGather

AllGather算子会把收集的结果分发到全部的设备上。 :numref:ch10-collective-operators展示了设备1设备2和设备3共同调用AllGather的初始和结束状态。

在这里一个简单的方法是使用Gather和Broadcast算子把聚合结果先存到编号为1的设备中再将其广播到剩余的设备上。这会产生一个$a \times \log p + (p-1) \times b \times l + (a+p \times l \times b) \times \log p$的时间复杂度,因为在广播时,如果忽略链表/数组实现所带来的额外空间开销,每次通信的长度为$pl$而不是$l$。简化后,得到了一个a \times \log p + p \times l \times b \times \log p 的时间复杂度。在一个基于超立方体的算法下可以将其进一步优化到和Gather算子一样的时间复杂度$a \times \log p + (p-1) \times b \times l$,由于篇幅问题此处便不再赘述。

Scatter

Scatter算子可以被视作Gather算子的逆运算把一个存在于编号为$i$的设备上,长度为$p$(信息长度为$p \times l$)的链式数据结构$L$中的值分散到每个设备上,使得编号为$i$的设备会得到$L[i]$的结果。 :numref:ch10-collective-operators展示了设备1调用Scatter的初始和结束状态。

可以通过模仿Gather算法设计一个简易的Scatter实现每一步的运算中我们把现在的子链继续对半切分并把前半段和后半段作为子问题进行递归。这时候在算法的每一阶段$t$,传输的信息长度为$l \times 2^{(m-t)}$,其中$m$是算法总共运行的步骤,不会超过\log p 见Broadcast算子的介绍。最终Scatter算子的简易实现和Gather算子一样都有a \times \log p + (p-1) \times b \times l 的时间复杂度。在机器学习系统中Scatter算子经常同时被用于链式数据结构和可切分的数据结构例如张量在一个维度上的$p$等分等。

基于AllReduce的梯度平均算法

下面讨论如何利用AllReduce算子实现大型集群中的高效梯度平均。首先参照前面的分析可以考虑一种简单的计算平均梯度的方法在集群中分配一个设备收集本地梯度并在计算平均梯度后再将其广播到全部的设备。这种做法易于实现但是引入了两个问题。首先多台设备同时给该聚合设备发送数据时聚合设备会因严重的带宽不足产生网络拥塞。其次单台设备需要负担大量的梯度平均计算而受限于单台设备上的有限算力这种计算往往会受限于算力瓶颈。

为了解决上述问题可以引入AllReduce算子的Reduce-Broadcast实现来优化算法其设计思路是通过让全部的节点参与到梯度的网络通信和平均计算中将巨大的网络和算力开销均摊给全部节点。这种做法可以解决先前单个梯度聚合节点的问题。假设有$M$个设备,每个设备存有一个模型副本,该模型由$N$个参数/梯度构成。那么按照AllReduce算子的要求需要先将全部的参数按照设备数量切分成$M$个分区Partition使得每个分区具有$N/M$个参数。首先给出这个算法的初始和结束状态。如 :numref:ch10-collective-operators的AllReduce的例子所示该例子含有3个设备。在每个设备有一个模型副本的情况下这个副本有3个参数。那么按照AllReduce的分区方法参数会被划分成3个分区3个设备而每一个分区则有1个参数$N/M$$N$代表3个参数$M$代表3个设备。在这个例子中假定设备1拥有参数2,4,6设备2拥有参数1,2,3设备3拥有参数4,8,12那么在使用一个AllReduce算子进行计算过后全部的设备都将拥有梯度相加后的结果7,14,21其中分区1的结果7是由3个设备中分区1的初始结果相加而成7 = 1 + 2 + 4。为了计算平均梯度每个设备只需要在最后将梯度之和除以设备数量即可分区1的最终结果为7除以3

AllReduce算法的过程 :width:800px 🏷️ch10-allreduce-process

AllReduce算子会把梯度的计算拆分成$M-1$个Reduce算子和$M-1$个Broadcast算子其中$M$是节点的数量。其中Reduce算子用于计算出梯度的加和Broadcast算子用于把梯度之和广播给全部的节点。 :numref:ch10-allreduce-process展示了一个AllReduce算子的执行过程。AllReduce算子由Reduce算子开始在第一个Reduce算子中AllReduce算子会对全部节点进行配对Pairing让它们共同完成梯度相加的操作。在 :numref:ch10-allreduce-process的第一个Reduce算子中设备1和设备2进行了配对共同对分区1的数据相加。其中设备2把本地的梯度数据1发送给设备1设备1将接收到的梯度数据1和本地的分区1内的梯度数据2进行相加计算出中间梯度相加的结果3。与此同时设备1和设备3进行配对共同完成对分区3的数据相加。而设备3和设备2进行配对共同完成对于分区2的数据相加。

上述的Reduce算子对梯度的分布式计算实现了以下的性能优化:

  • 网络优化: 全部设备都同时在接收和发送数据利用起了每个设备的入口Ingress和出口Egress带宽。因此在AllReduce算法的过程中可利用的带宽是$M * B$,其中$M$是节点数量,$B$是节点带宽,从而让系统实现网络带宽上的可扩展性。
  • 算力优化: 全部设备的处理器都参与了梯度相加的计算。因此在AllReduce算法的过程中可利用的处理器是$M * P$,其中$M$是节点数量,$P$是单个设备的处理器数量,从而让系统实现计算上的可扩展性。
  • 负载均衡: 由于数据分区是平均划分的,因此每次设备分摊到的通信和计算开销是相等的。

在接下来的Reduce算子中AllReduce算法会对不同数据分区选择另外的配对方法。例如在 :numref:ch10-allreduce-process 的第二个Reduce算子中AllReduce算法会将设备1和设备3进行配对负责分区1的数据相加。将设备1和设备2进行配对负责分区2。将设备2和设备3进行配对负责分区3。在一个3个节点的AllReduce集群里在2个Reduce算子完成后就计算出了每个分区的数据相加结果分区1的数据相加结果7此时在设备3上分区2的数据相加结果14此时在设备1上分区3的数据相加结果21此时在设备2上

接下来AllReduce算法将进入Broadcast阶段。这一阶段的过程和Reduce算子类似核心区别是节点进行配对后它们不再进行数据相加而是将Reduce的计算结果进行广播。在 :numref:ch10-allreduce-process 中的第一个Broadcast算子中设备1会将分区2的结果14直接写入设备3的分区2中。设备2会将分区3的结果21直接写入设备1中。设备3会将分区1的结果直接写入设备2中。在一个3个节点的AllReduce集群中我们会重复2次Broadcast算子将每个分区的Reduce结果告知全部的节点。

在本节中我们讨论了AllReduce的其中一种常用实现方法。根据集群网络拓扑的不同人们也会用以下的方法来实现AllReduce树形结构环形结构二维环面结构以及CollNet。在此我们不展开讨论。

集合通信算法性能分析

在讨论集合通信算子的性能时,人们经常会使用一些数值化指标量化不同的算法实现。在计算点对点通信所需的时间时,会在信息长度上乘以一个系数$b$。这个数值化指标就是算法带宽Algorithm Bandwidth泛指单位时间内执行操作通信和计算等的数量。一般计算公式为$b = s/t$,其中$s$代指操作的大小,$t$指操作指定的两个端点之间所经过的时间。以P2P通信举例可以通过衡量一个大小已知的信息$m$在执行Send函数时所花的时间来确定两个设备之间网络的带宽。

前文提到在计算点对点通信所需的时间是会在信息长度之上乘以一个系数b。这个系数就是算法带宽泛指单位时间内执行操作通信计算等的数量。一般计算公式为$b = s/t$,其中$s$代指操作的大小,$t$指操作指定的两个端点之间所经过的时间。以点到点通信举例,我们可以通过衡量一个大小已知的信息$m$在执行send函数时所花的时间来确定两个处理单元之间网络的带宽。

虽然算法带宽的计算方法既简单又高效但很难将其拓展至对于集合通信算子的带宽计算。这是因为取决于具体算子和算法实现的不同一个集合通信算子在执行过程中测得的算法带宽往往会远小于硬件本身的最高带宽。在实际运行相应的测试中经常能观测到随着设备增加算法带宽呈下降趋势。为了解决这一问题NCCL提出了总线带宽Bus Bandwidth这一数值化指标将根据每个集合通信算子的分析所测得的算法带宽乘以一个校正系数Correction Factor从而给出贴近实际硬件表现的带宽值。下面给出常见算子的校正系数

  • AllReduce对于在设备n_1, n_2, \cdots, n_p 上的值 v_1, v_2, \cdots, v_p 计算 $v_1 o v_2 o \cdots o v_p$(其中$o$为符合结合律的算子),再存回每个设备中。在不考虑实际实现算法和网络拓扑的情况下,这个操作在理论上只需要$2 \times (p-1)$次数据传输,其中包含在每个设备上分开进行的$p-1$次 $o$的运算,以及最后 p 次最终数据值的广播,再减去第一个设备的运算和最后一个设备的广播对运行时间的影响。假设每个设备对于外界所有信息处理的带宽为$B$,可以得出对于$S$个在不同设备上的数据运行AllReduce算子能得到最优情况下的运行时间$t = (2 \times S \times (p-1)) / (p*B)$,进行简化后可得 $B = (S/t) \times (2 \times (p-1)/p) = b (2 \times (p-1)/p)$。这里的 $2(p-1)/p$便是校正系数。
  • ReduceScatter对于每个设备来说可以把ReduceScatter理解为只执行AllReduce中的聚合部分。对此只需要考虑上面分析中的$n-1$次$op$的运算,整理后可得$B = (S/t) \times ((p-1)/p) = b \times ((p-1)/p)$。即校正系数为$b \times ((p-1)/p)$。
  • AllGather对于每个设备来说可以把AllGather理解为只执行AllReduce中的广播部分同理可得$B = (S/t) \times ((p-1)/p) = b \times ((p-1)/p)$。即校正系数为$b \times ((p-1)/p)$。
  • Broadcast与AllReduce不同的是Broadcast中所有数据需要从算子本身的发送者发出。即使在上面分治的情况下也需要等待所有子问题运行结束才能确保Broadcast算子本身的正确性。因此在计算带宽时瓶颈仍为发送者对于外界所有信息处理的带宽所以 $B = S/t$,即校正系数为$1$。
  • ReduceReduce需要将所有数据送往算子的接收者因此校正系数为$1$。

由于Gather和Scatter的带宽计算与实际聚合/分散时的数据结构相关性更高,故不给出特定的校正系数。

利用集合通信优化模型训练的实践

针对不同的集群机器学习系统往往会灵活组合不同集合通信算子来最大化通信效率。下面提供两个案例分析ZeRO和DALL-E。

ZeRO

ZeRO是微软提出的神经网络优化器在实践中成功训练了2020年世界上最大的语言模型高达1700亿参数。在训练这个级别的神经网络时优化器本身的参数反向计算时的梯度以及模型参数本身都会对加速器内存空间产生极大的压力。通过简易的计算不难得出1700亿参数的模型在32位浮点表示情况下会占用至少680GB的内存远超于现在内存最高的加速器A100 最高内存80GB。于是需要考虑如何高效地把模型切成数份存储在不同的加速器上以及如何高效地通过使用集合通信算子来进行模型训练和推理。这里介绍三个主要的关于集合通信的优化技术

  • 单一节点上的参数存储: 现代集群中节点内部加速器的带宽远大于节点之间的带宽。为此需要尽量减少节点间的通信并且保证大部分通信仅存在于节点内部的加速器之间。在观察模型切片时又可得模型本身前向和反向计算时需要在不同切片之间进行的通信远小于不同模型副本梯度平均的通信量。针对这一特性ZeRO选择了将单一模型的全部切片存储到同一节点内部从而大大提高了训练效率。
  • 基于AllGather算子的前向计算 假设模型中的参数在层级上呈线性便可按照参数在网络上的顺序从前到后将其分别存储到不同加速器中。在前向时可以注意到某一层的计算仅依赖于其相邻层的参数。对此可以对所有包含模型参数的加速器进行一次AllGather计算用来提取每一层的后一层的参数以及计算该层本身的激活值。为了节约内存在AllGather操作结束后需要立即丢弃除了该层以外其他层的参数。
  • 基于ReduceScatter算子的梯度平均 在反向计算时我们只需要前一层的参数来计算本层的激活值和梯度因此只需要再次使用AllGather来完成每个加速器上的梯度计算。同时在聚集梯度后对于每个加速器仅需要和加速器的编号相同的层数对应的梯度。对此可以使用ReduceScatter算子直接把相应的梯度存到编号为$i$的加速器上而不是通常情况下使用AllReduce算子。

DALL-E

DALL-E是OpenAI提出的一个基于文字的图片生成模型模型同样拥有高达120亿的参数。在训练时除了运用到ZeRO所使用的AllGather + ReduceScatter 技巧OpenAI团队在其他细节上做了进一步的优化。这里介绍两个主要的关于集合通信的优化技术

  • 矩阵分解: 集合通信算子的运行速度和信息本身的长度正相关。在模型训练中这代表了模型参数本身的大小。对此DALL-E 选择用矩阵分解Matrix Factorization的方法先把高维张量调整为一个二维矩阵通过分解后分开用集合通信算子进行传输从而大大减少了通信量。
  • 自定义数据类型: 一种减少通信量的方法在于修改数据类型本身。显然地可以使用16位的半精度浮点数相比正常的32位参数表示可以节省近一倍的通信量。但是在实践中发现低精度的数据类型会使得模型收敛不稳定导致最终训练效果大打折扣。为此OpenAI分析了DALL--E的模型结构并把其中的参数根据对数据类型精度的敏感性分为了三类。其中对精度最敏感的一类照常使用32位浮点表示并只通过AllReduce算子来同步而最不敏感的参数则照常通过矩阵分解进行压缩和传输。对于比较敏感的一类例如Adam优化器其中的动能Moments和方差Variance参数OpenAI 基于 IEEE 754 标准实现了两个全新的数据类型1-6-9和0-6-10其中第一表示正负所需的位数第二表示指数所需的位数第三表示有效数字所需的位数在节省空间的同时保证了训练的收敛。

集合通信在数据并行的实践

数据并行作为最广泛使用的分布式训练方法,是集合通信首先需要支持的范式。 对于数据并行的支持机器学习系统通常提供了两个级别的抽象在第一种级别的抽象里机器学习系统更与硬件耦合可以直接调用集合通信算子的库在另一种级别的抽象里机器学习系统更偏向神经网络实现通过内部调用集合通信算子实现分布式训练和推理的机器学习框架。作为算法工程师通常会接触到后者的抽象包括Horovod、KungFu、TensorFlow Distributed等而作为集群的维护者往往需要深入了解前者的运行原理和具体的调试方法。以 PyTorch 举例在torch.distributed 命名空间Namespace下实现了一系列方便开发者使用的分布式模型训练和推理函数。在其内部会根据实际运行的集群调用更底层的集合通信算子库例如MPINCCL前面已有介绍适用于GPU分布式训练Gloo适用于CPU分布式训练等。下面具体对比PyTorch Distributed和NCCL在AllReduce应用方面的差异 以下代码通过PyTorch自带的分布式数据并行Distributed Data ParallelDDP方法完成了一次简易的机器学习模型计算。

# 基于PyTorch DDP高层次封装实现AllReduce算法

def ddp_allreduce(rank, world_size):
    setup(rank, world_size)

    model = ToyModel().to(rank)
    # 通过调用DDP分布式数据并行方法将模型在每个处理器上完成初始化
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)

    # 在反向计算时框架内部会执行AllReduce算法
    loss_fn(outputs, labels).backward()
    optimizer.step()

下面代码通过Gloo的Python 接口pygloo和Ray完成了一个二维张量的AllReduce计算。

# 基于pygloo底层接口实现AllReduce算法

@ray.remote(num_cpus=1)
def gloo_allreduce(rank, world_size):
    context = pygloo.rendezvous.Context(rank, world_size)
    ...

    Sendbuf = np.array([[1,2,3],[1,2,3]], dtype=np.float32)
    recvbuf = np.zeros_like(Sendbuf, dtype=np.float32)
    Sendptr = Sendbuf.ctypes.data
    recvptr = recvbuf.ctypes.data

    # 标明发送者和接收者并直接调用AllReduce算法
    pygloo.allreduce(context, Sendptr, recvptr,
                    Sendbuf.size, pygloo.glooDataType_t.glooFloat32,
                    pygloo.ReduceOp.SUM, pygloo.allreduceAlgorithm.RING)

可以注意到PyTorch Distributed并没有显式地调用集合通信算子而是通过DistributedDataParallel方法将分布式训练和非分布式训练之间的不同隐藏了起来。如果需要在不同集群上运行这段代码只需要在setup 函数内对应地更改PyTorch使用的底层集合通信库即可。在backward函数被调用时才会真正地使用AllReduce算法。相比如果想要直接使用Gloo不仅需要一步一步地创建通信所需要的数据结构同时也很难和现有的模型训练框架无缝连接。

集合通信在混合并行的实践

随着深度学习的发展,模型和训练数据集的规模呈爆发式增长,单机的算力和存储能力已无法满足需求,因此,分布式训练技术成为行业发展趋势。

本章前几节已总结当前常用的分布式并行训练技术方案如数据并行、模型并行和流水线并行在复杂场景下往往需要不同技术点组合使用才能达到训练大模型的高性能。华为MindSpore开源框架提供混合并行的能力来支撑大模型分布式训练用户可以根据自己的需要进行灵活组合。以下通过简单代码示例来说明如何在MindSpore中组合使用数据并行、模型并行和流水线并行训练技术其他大模型训练技术的使用方法请参照官网教程。

以下代码利用set_auto_parallel_context接口设置并行模式和可用于训练的卡数同时利用该接口设置流水线并行中的stage数量。通过扩展nn.Cell, 定义了简单的神经网络模型其中self.matmul1和self.matmul2的两个矩阵乘操作调用shard接口来配置切分策略如matmul1将第一个输入按照行切成4份实则是在数据维度上切分是数据并行的样例而matmul2对第二个输入进行列切采用了模型并行的方式。为了实现流水线并行以下代码调用nn.PipelineCell接口来包装net_with_loss并指定流水线并行所需的微批次大小。最后通过model.train接口来对神经网络进行混合并行训练。

MindSpore提供了shard接口来允许用户配置切分策略。在这种切分的场景下需要在必要的时候插入集合通信算子来保证计算逻辑的正确性第一种是切分了单一算子的情况将算子切分到多卡进行计算为了保证计算结果和单卡计算结果一致需要集合通信算子来将多卡计算的部分结果同步聚合到每张卡上第二种是多算子情况下相邻算子的切分方式不同前继算子的计算结果排布在不同的卡上后续算子的计算需要用到非当前卡上的数据才能进行此时需要一个集合通信算子来重新排布前继算子的计算结果。

# 基于MindSpore对模型进行混合并行分布式训练

import mindspore.nn as nn
from mindspore import ops
import mindspore as ms

# 设置并行模式为半自动并行,同时设置训练的卡数
ms.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=4)
# 设置流水线并行的stage数量
ms.set_auto_parallel_context(pipeline_stages=stages)

class DenseMatMulNet(nn.Cell):
    def __init__(self):
        super(DenseMutMulNet, self).__init__()
        # 通过shard定义算子切分的方式matmul1是数据并行的样例matmul2是模型并行的样例
        self.matmul1 = ops.MatMul.shard(((4, 1), (1, 1)))
        self.matmul2 = ops.MatMul.shard(((1, 1), (1, 4)))
    def construct(self, x, w, v):
        y = self.matmul1(x, w)
        z = self.matmul2(y, v)
        return z

# 定义训练数据集
data_path = os.getenv('DATA_PATH')
dataset = create_dataset(data_path)
net = DenseMatMulNet()
loss = SoftmaxCrossEntropyExpand(sparse=True)
net_with_loss = nn.WithLossCell(net, loss)
# 用PipelineCell接口包装神经网络第二个参数指定MicroBatch Size
net_pipeline = nn.PipelineCell(net_with_loss, micro_size)
opt = Momentum(net.trainable_params(), 0.01, 0.9)
model = ms.Model(net_pipeline, optimizer=opt)
# 对模型进行迭代训练
model.train(epoch_size, dataset, dataset_sink_mode=True)

:numref:ch10-redistribution 展示了上述代码中matmul1和matmul2在调用shard函数后的数据排布情况。matmul1算子将输入$X$按照行切成4份后分别放置在4个计算设备上D1-D4$W$不切分则以复制的形式放置在4个计算设备上此时matmul1算子计算的结果$Y$以行切的形式被放置在不同设备上而matmul2算子在做计算时需要$Y$的全量数据因此两个计算算子之间需要插入AllGather集合通信算子来从4个不同的设备上收集到$Y$的全量数据。MindSpore能够自动识别不同切分方式的算子之间应该插入哪种集合通信算子并且将该逻辑对用户隐藏只暴露出shard接口供用户配置开发者可以通过合理的策略配置来减少算子间重排布通信算子在神经网络计算图中的占比以提升混合并行分布式训练的端到端速率。

相邻算子之间插入集合通信算子举例 :width:800px 🏷️ch10-redistribution