DeMo介绍:用于高效分布式LLM训练的解耦动量优化

本文介绍了使用大量参数训练大型语言模型(LLM)时面临的计算密集和通信量大的问题,并重点介绍了Nous Research提出的DeMo方法,该方法通过减少通信成本,降低训练成本,并允许使用较差的连接和较低成本的硬件进行训练。DeMo 算法基于梯度的高度可压缩性,通过离散余弦变换(DCT)提取动量中的快速分量,从而显著降低了数据传输量,同时保持了模型的收敛性。

摘要

使用数十亿参数训练大型语言模型 (LLM) 需要大量的计算,并且需要在专门的数据中心进行大量的通信。Nous Research 发布了 DeMo,展示了如何将这些通信成本降低几个数量级,从而降低成本,并能够使用较差的连接和较不昂贵的硬件进行训练。本文介绍了基本概念,并讨论了该论文

介绍

机器学习的问题在于找到一个从输入集合 XX 到输出集合 YY 的函数(或映射)。这种关系可能非常复杂,我们希望通过拥有样本 (x,y)(x,y) 的信息来近似它。例如,我们可能对悬挂弹簧的长度对增加重量的响应感兴趣;为此,我们将测量我们正在增加的重量 ww,并记录长度的变化 ΔxΔx。另一个例子可以是基于心率、体重、身高和骨骼肌质量等信息来关联一个人的能量消耗。我们可能还想训练一个智能体来识别图像。虽然底层关系和目标可能非常不同,但它们可以通过一些数学方法族来处理。在深入研究大型语言模型 (LLM) 和人工智能 (AI) 的细节之前,让我们先关注更简单的问题,比如测量弹簧的伸长量与重量的关系,或者由于施加给定的电压而在导线中循环的电流。

对于弹簧的情况,我们得到一些重量(例如,25 克、50 克、100 克、200 克)。我们测量弹簧运动结束后的最终伸长量,比如 1、2、4 和 8 厘米。利用物理学的经验知识,只要我们处于弹性状态,胡克定律 就成立:重量(施加的力)与伸长量成正比,kΔx=wkΔx=w,其中 kk 是弹簧的刚度。这种关系并非总是如此,因为如果我们增加太多的重量,弹簧会发生显著变形并失去其行为。因此,我们想要解决的问题是,

找到 kk 使得 kΔxi=wikΔxi=wi 对于 i=0,1,2,...ni=0,1,2,...n。这是一个线性方程组,如果没有测量误差并且这种关系是真实的映射,那么 k=wi/Δxik=wi/Δxi。

我们面临的一些问题是:

  1. 我们使用的关系/映射可能是真实关系的近似值。
  2. 存在与测量相关的误差(我们可以暂时假设这些误差是随机的,而不是由观察者系统性地引入的)。
  3. 我们没有大量的测量数据 (Δx,w)(Δx,w)。

这使得事情变得相当困难。首先,方程组 kΔxj=wjkΔxj=wj 可能不再有有效的解。例如,我们可能有 (1,25)(1,25) 和 (2.01,49.9)(2.01,49.9),这转化为:

k.1=25
k.2.01=49.9

第一个方程得出 k=25k=25,而第二个方程得出 k=24.82k=24.82。这个方程组没有解,但我们可能仍然有兴趣从现有信息中估计 kk(这两个值相差不太远,所以也许我们可以做些什么)。我们可以定义一个新的函数来衡量观察到的输出 wjwj 和预测的输出 ΔxjΔxj 之间的差异。我们称这个函数为 损失函数。例如,

$L(k)=(k\Delta x_0-w_0)^2+(k\Delta x_1-w_1)^2=(\hat{w}_0-w_0)^2+(\hat{w}_1-w_1)^2$

该函数衡量了胡克定律预测的重量与我们的测量值之间的二次方误差。我们的目标是找到 kk,使得损失函数最小,

$min_{k \in K} L(k)$

微积分告诉我们,如果关于 kk 的导数为零,则该函数(假设它是“好的”)达到极值,

$dL/dk=0$

使用导数的链式法则,

$dL/dk=2(k\Delta x_0-w_0)\Delta x_0+2(k\Delta x_1-w_1)\Delta x_1=0$

这个方程是线性的,我们可以直接求解它。让我们把问题复杂化一点,假设

  1. 我们有几个参数,k0,k1,...,kmk0,k1,...,km
  2. 找到参数的方程是非线性的。

如果我们有几个参数,可以使用多元微积分来推广这个过程。我们要求关于每个参数的偏导数为零:

$\partial L/\partial k_0=0$

$\partial L/\partial k_1=0$

$\partial L/\partial k_2=0$

$\vdots$

$\partial L/\partial k_m=0$

包含所有这些偏导数的向量是 LL 的梯度。我们有一个包含多个变量的方程组需要求解。

如果上面的方程不容易求解会发生什么?我们有两个事实:

  1. 梯度应该在最小值处为零。
  2. 梯度的方向给出函数中最大增长的方向(因此,沿着相反的方向应该给出最陡的下降)。

这是最速下降搜索的工作原理。从一组参数 k0k0 开始,我们递归地设置

kn+1=kn−γ∇Lkn+1=kn−γ∇L

其中 γγ 是一个参数(称为学习率)。γγ 的高值会产生不稳定性和收敛问题,而 γγ 的低值意味着我们缓慢地朝着最小值移动。

我们现在面临一些之前没有解决的进一步问题:

  1. 一个函数可以有几个(局部)最小值,那么我们如何确保找到真正的(全局)最小值?
  2. 有没有办法调整学习率 γγ,以便我们可以更快地实现收敛?
  3. 如果观察次数 (xi,yi)(xi,yi) 非常多,并且损失函数具有复杂或难以评估的表达式,会发生什么?

我们将首先解决第三个问题,然后尝试解决其他问题。我们有一个如下形式的表达式:

$L(k)=\sum_j E_j(x_j,y_j,k)$

例如,Ej=(f(xj,k)−yj)2Ej=(f(xj,k)−yj)2 可能是每个观察值的二次方误差,并且 ff 是给出输入和输出之间关系的函数。计算整个梯度涉及每个 Ej(xj,yj)Ej(xj,yj) 的(偏)导数,并对 jj 的所有值求和,这使得梯度的评估成本很高。我们可以尝试仅通过选择一个观察值来减少项的数量,并通过该值来近似真实梯度:

∇L≈∇Ej∇L≈∇Ej

这牺牲了准确性来降低计算负担。我们也可以尝试使用观察值或小批量的一个子集来估计梯度。这就是随机梯度下降的思想。

由于我们正在处理近似值,因此可能需要重新调整学习率,并以特定的速率降低它,使其变为 γnγn。

我们可以通过引入动量来改进该方法,动量会在下一次迭代更新时跟踪先前的梯度。基本上,

Δkn=αΔkn−1−γ(∇L)nΔkn=αΔkn−1−γ(∇L)n

kn+1=kn+Δknkn+1=kn+Δkn

我们可以看到,如果 α=0α=0,我们会恢复原始的梯度下降。如果 αα 不同于零,我们会累积先前的梯度,并考虑早期步骤给出的方向。这将确保如果我们在一段时间内朝着给定的方向前进,我们将继续朝着该方向前进,从而避免方向的突然变化。

由于梯度可能具有值非常不同的分量,因此我们可以像 Adam 优化器 的情况一样,调整每个变量的学习率。

局部最小值的问题可以通过这种动量方法(这将防止我们陷入浅最小值)、尝试不同的起始点以及 退火方法 来解决。

我们可以通过使用神经网络来创建或近似更复杂的行为。给定输入变量 x1,...xmx1,...xm,我们可以使用权重 wj0wj0 形成线性组合,并应用激活函数 ff,获得新值 z11,...z1mz11,...z1m,如下所示:

a1j=∑lwjlxl+wj0a1j=∑lwjlxl+wj0

z1j=f(∑lwjlxl+wj0)z1j=f(∑lwjlxl+wj0)

我们可以通过执行线性组合并应用激活函数来添加一个新层,使用上面的输出

z2j=f(∑lw(2)jlz1l+w(2)j0)z2j=f(∑lwjl(2)z1l+wj0(2))

我们可以类似地添加其他层,直到我们得到神经网络的输出,

z3j=f(∑lw(3)jlz2l+w(3)j0)z3j=f(∑lwjl(3)z2l+wj0(3))

可以使用反向传播有效地计算梯度。我们将再次以我们的损失函数作为项的总和开始,每个项对应于一个样本,

$L(k)=\sum_j E_j(x_j,y_j,k)$

我们将专注于计算一个 EjEj 关于每个参数的导数,

∂Ej∂wji=∂Ej∂aj∂aj∂wij∂Ej∂wji=∂Ej∂aj∂aj∂wij

右侧的第二个偏导数很简单,因为 ajaj 是 wijwij 的线性组合,

∂aj∂wij=zi∂aj∂wij=zi

对于另一个导数,我们将其简称为

∂Ej∂aj=δj∂Ej∂aj=δj

这样

∂Ej∂wji=ziδj∂Ej∂wji=ziδj

可以通过评估 δjδj 并使用提供的公式来计算每层的导数。对于隐藏层,

δj=∑m∂Ej∂am∂am∂akδj=∑m∂Ej∂am∂am∂ak

我们最终可以得出 δjδj 的反向传播公式,

δj=f′(aj)∑mwmjδmδj=f′(aj)∑mwmjδm

评估导数的基本过程是首先计算所有层和输出的 ajaj,评估输出的 δjδj,并使用最后一个公式使用反向传播来获得每个内层的每个 δjδj。

许多大型语言模型 (LLM) 都基于神经网络。它们在不同的领域表现出良好的性能,例如翻译和对话式人工智能。这些参数的数量级可能达到数万亿。因此,为了获得合理的训练时间,我们需要加速器,例如 GPU 和 TPU。我们经常在 GPU 集群中遇到异构性,并且互连被划分为每个机器中的高带宽岛和跨机器的低带宽,从而限制了训练速度和次优的硬件利用率。这也会影响内存规划,并且频繁的内存碎片整理会显着降低训练速度。这也会转化为资本和运营成本。

诸如分布式数据并行和完全分片数据并行等策略使加速器拆分权重并同步梯度,通信量与模型的大小成正比(例如,在 4 台机器上使用 10B 个 token 训练 GPT-J-6B 将需要传输 915 TB 的数据!使用 70 亿参数的 LlaMa 预训练需要超过 58 GB 的内存来存储参数、激活和梯度)。这使得梯度同步需要昂贵的高速互连,迫使所有设备位于同一物理空间中。将通信成本降低一个数量级以上不仅可以降低成本或缩短训练时间,还可以允许使用更分布式的硬件。

用于减少内存占用和通信成本的一些技术是:

在这篇博文中,我们将讨论 DeMo,它由 Nous Research 最近发布,它在通信和内存使用方面提供了显着的节省,从而允许使用较差的连接和功能较弱的硬件来训练 LLM。

Nous Research

Nous Research 致力于研究以人为本的语言模型和模拟器,重点关注模型架构、数据合成、微调和推理等领域,所有这些都旨在使 AI 系统与真实世界的用户体验保持一致。四个月前,他们发布了一份关于 DisTro 的初步报告,DisTro 是一系列与架构无关且与网络无关的优化器,可将通信成本显着降低几个数量级,从而实现 AI 的高效分布式训练。

工作假设

该论文表明,非常大的 LLM 的梯度表现出冗余性和高可压缩性。这是支持 DeMo 的核心见解。它基于以下三个观察结果:

  1. 动量的快速移动分量表现出高的空间自相关性,并且具有少量主成分。
  2. 快速移动的动量分量显示出低的时间方差,应立即用于更新参数。缓慢移动的分量表现出高的时间方差,并受益于时间平滑。
  3. 缓慢移动的动量分量对于长期收敛至关重要,应保留而不是过滤掉。

使用这些推测,作者修改了带有动量的 SGD 方法,以解耦不同加速器之间的动量。更新动量后,使用离散余弦变换 (DCT) 提取动量的快速分量 qq,并以最小的通信量共享这些分量。

DeMo 是如何工作的?

起点是带有动量的随机梯度下降 (SGD) 算法。我们将计算局部梯度并使用它们来更新(解耦的)动量,而不是计算总梯度。然后,我们将为每个动量提取 kk 个最快的分量,并从解耦的动量中减去它们。最后,我们将通信并同步所有快速分量,并使用这个同步的梯度来更新参数。这是 论文 中描述的算法:

Screenshot 2024-12-04 at 2.35.32 PM

快速分量的提取对于算法的性能至关重要。虽然 Kosambi-Karhunen-Loève 变换提供了一种实现去相关、分离和提取主要成分的方法,但 DCT 在上述假设下提供了极好的近似。DCT 的优势在于其高效的计算和高度的并行化。此外,它是在固定的正交基上计算的,这使我们能够有效地解码 DCT 编码的信号,而无需额外的信息。

我们可以将每个动量张量作为 d 维自相关信号进行处理,将它们分块,并将 DCT 应用于每个分块,提取最高的 kk 个值及其频率。这将创建两个张量,一个包含频率(使用索引),另一个保持幅度(使用浮点数)。在 DCT 中,频率由 2πi/N2πi/N 给出,因此给出 ii 足以指定频率,因此我们将得到对 (i,A)(i,A),指示最快分量的频率和幅度。然后,我们可以使用这些张量执行逆 DCT 以恢复分量的值 qtqt,并从动量中删除这些值(算法的第四步)。

在收集完所有最快的局部组件后,我们就可以同步它们了。第一步是对重复频率的幅度进行平均(如果索引 11 给出的频率,对应于 2π11/N2π11/N,在局部梯度的最快分量中重复)。在第二步中,我们执行逆 DCT 以恢复全局梯度的最快分量 QtQt 的值。优点在于,如果我们适当地选择参数,则我们必须共享的最快分量的数量远小于梯度。

实验结果表明,与 AdamW 相比,DeMo 可以将通信成本降低至少一个数量级,而收敛性没有明显变化。

总结

本文介绍了与机器学习和 LLM 相关的基本概念,解释了训练非常大的模型时出现的目标、策略和挑战。在多个加速器之间拆分参数和计算的需求引入了对专用连接的需求,从而使所有设备位于同一物理位置。利用来自训练 LLM 的经验观察,Nous Research 提出了 DeMo,利用 DCT 提取最快的分量并减少加速器必须共享的数据量。实验结果表明,相对于 AdamW,减少了至少一个数量级(具体取决于参数的选择,可能会更高),从而允许使用带宽较差的网络和异构硬件来训练 LLM,从而降低了资本和运营成本。

  • 原文链接: blog.lambdaclass.com/int...
  • 登链社区 AI 助手,为大家转译优秀英文文章,如有翻译不通的地方,还请包涵~
点赞 0
收藏 0
分享
本文参与登链社区写作激励计划 ,好文好收益,欢迎正在阅读的你也加入。

0 条评论

请先 登录 后评论
lambdaclass
lambdaclass
LambdaClass是一家风险投资工作室,致力于解决与分布式系统、机器学习、编译器和密码学相关的难题。