1.3 分布式训练的必要性
在大数据和互联网时代,机器学习又遇到了新的挑战,具体如下。
• 样本数据量大:训练数据越来越多,在大型互联网场景下,每天的样本数据量是百亿级别。
• 特征维度多:由于样本数据量巨大而导致机器学习模型参数越来越多,特征维度可以达到千亿甚至万亿级别。
• 训练性能要求高:虽然样本数据量和模型参数量巨大,但是业务需要我们在短期内训练出一个优秀的模型来验证。
• 模型上线实时化:对于推荐类和资讯类的应用,往往要求根据用户实时行为及时调整模型,以对用户行为进行预测。
传统机器学习算法存在如下问题:单机的计算能力和拓展性能始终有限,迭代计算只能利用当前进程所在主机的所有硬件资源,无法将海量数据和超大模型加载到有限的内存之中。而串行执行需要花费大量时间,从而导致计算代价和延迟性都较高,所以大数据和大模型最终将出现以下几个问题。
• 内存墙:单个GPU无法容纳模型,导致模型无法训练,目前最大GPU的主内存也不可能完全容纳某些超大模型的参数。
• 计算墙:大数据和大模型都代表计算量巨大,将导致模型难以在可接受的时间内完成训练。比如,即使我们能够把模型放进单个GPU中,模型所需的大量计算操作也会导致漫长的训练时间。
• 通信墙:有存储和计算的地方,就一定有数据搬运。内存墙和计算墙必然会导致出现通信瓶颈,这也会极大地影响训练速度。下面针对这些问题做具体分析。
1.内存墙
模型是否能够训练和运行的最大挑战是内存墙。一般来说,训练AI模型所需的内存比模型参数量还要多几倍,为了理解此问题,我们需要梳理一下内存增长的机理。显存占用分为静态内存(模型权重、优化器状态等)和动态内存(激活、临时变量等),静态内存比较固定,而动态内存在单次迭代之中有如下特点。
• 因为反向计算需要使用前向传播的中间结果,所以在前向传播时需要保存神经网络中间层的激活值。又因为每一层的激活值都需要保存下来给反向传播使用,所以在前向传播开始之后,显存占用不断增加,并且在前向传播结束之后,显存占用会最终累积达到一个峰值。
• 在反向传播开始之后,由于激活值在计算完梯度之后就可以被逐渐释放掉,所以显存占用将逐渐下降。
• 在反向传播结束之后,显存占用最终会下降到一个较小的数值,这部分显存是参数、梯度等状态信息,就是常说的模型状态。
削峰是处理内存墙的关键手段,只有当削峰无法解决问题时,才能考虑其他处理方法。此外,内存墙问题不仅仅与内存容量尺寸相关,也和包括内存在内的传输带宽相关,这涉及跨越多个级别的内存数据传输。
2.计算墙
因为数据量和模型巨大,所以我们面临巨大的算力需求,需要思考如何提高计算能力和效率。针对强大的敌人有两种策略:壮大自己和找帮手,这对应了两种优化途径:单机优化和多机并行优化。其中,单机优化主要包括:
• 数据加载效率优化,比如使用高性能存储介质或者缓存来加速。
• 算子级别优化,包含如何实现高效算子、如何提高内存利用率、如何把计算与调度分离等。
• 计算图级别优化,包含常量折叠、常量传播、算子融合、死代码消除、表达式简化、表达式替换、如何搜索出更高效的计算图等。
然而,面对巨大的算力需求,单机依然无能为力,所以有必要通过增加计算单元并行度来提高计算能力,即把模型或者数据切分成多个分片,在不同机器上借助其硬件资源对训练进行加速,这就是多级并行优化。根据前面训练迭代的特点,我们可以对并行梯度下降进行计算切分,基本思想是将训练模型并行分布到多个节点之上再进行加速:
• 每个节点都获取最新模型参数,同时将数据平均分配到每个节点之上。
• 每个节点分别利用自己分配到的数据在本地计算梯度。
• 通过聚集(Gather)或者其他方式把每个节点计算出的梯度统一起来,以此更新模型参数。
3.通信墙
为了解决内存墙和计算墙问题,人们尝试采取分布式策略将训练拓展到多个硬件(GPU)之上,希望以此突破单个硬件的内存容量/计算能力的限制,既然多个硬件要同时参与一个任务的计算,这就涉及如何让它们彼此之间协调合作,整体上作为一个巨大的加速器来运行。这使得通信方面的挑战随之而来。虽然我们可以对神经网络进行各种切分以实现分布式训练,但模型训练是一个整体任务,这就意味着必须在前面的切分操作后面添加一个对应的聚集操作,这样才能实现整体任务。于是此聚集操作就是通信瓶颈所在。
神经网络具有如下特点。
• 通信量大。因为模型规模巨大,所以每次更新的梯度都可能是大矩阵,由此导致剧增的通信量很容易就把网络带宽给占满。
• 通信次数多。因为是迭代训练,所以需要频繁更新模型。
• 通信量在短期内达到峰值。神经网络运算在完成一轮迭代之后才更新参数,因此通信量会在短时间内暴增,而在其他时间网络是空闲的。
• 内存墙问题。在通信上也会遇到内存墙问题。
因此,我们需要减少机器之间的通信代价,进而提高并行效率,解决内存墙问题。优化是一个整体方案,可以从两方面入手,一方面提升通信速度,比如优化网络协议层,使用高效通信库,进行通信拓扑和架构的改进,通信步调和频率的优化。另一方面也可以减少通信内容和次数,比如梯度压缩和梯度融合技术等;也可以通过代码优化,减少I/O的阻塞,尽量使得I/O与计算可以做重叠(Overlap)。
4.问题总结
综上所述,大数据和互联网时代机器学习的各个瓶颈并不是孤立的,无法用单一的技术解决,需要一个整体解决方案。该方案既需要考虑庞大的节点数目和计算资源,也要考虑具体框架的运行效率和分布式架构,以达到良好的扩展性和加速比,还要考虑合理的网络拓扑和通信策略。此方案是显存优化和速度优化的整体权衡结果,也是统计准确性(泛化)和硬件效率(利用率)的折中结果。而且,对于不同计算问题来说,计算模式和对计算资源的需求都不一样,因此没有解决所有问题的最好的架构方案,只有针对具体实际问题最合适的架构。我们只有针对机器学习具体任务的特性进行系统设计,才能更加有效地解决大规模机器学习模型训练的问题。因此,这就引出了下一个问题:分布式机器学习究竟在研究什么?