跳转至

ZeRO优化器设计理念解析

标签: 并行训练LLM

1. 背景与动机

随着深度学习模型参数规模不断攀升,传统的数据并行 (Data Parallel, DP) 训练面临严重的显存瓶颈。在标准DP中,每块GPU都维护模型的完整拷贝,包括模型参数、各层梯度和优化器状态。这种状态冗余导致显存需求随GPU数线性增加,却无法通过增加GPU来降低单卡内存占用。事实上,在当前32GB显存的GPU上,基础DP在模型超过约14亿参数时就会内存耗尽 (OOM)。以GPT-2 (15亿参数) 为例,仅Adam优化器的状态 (权重的FP32副本、一阶梯度m和二阶梯度v) 即占用约18GB显存,在单卡32GB环境下无法训练。即使使用8张GPU做数据并行,传统DP仍会在每卡复制这18GB优化器状态,总显存需求远超单卡限制。因此,大模型训练亟需新的并行策略来突破显存壁垒。

Zero Redundancy Optimizer (ZeRO) 由微软DeepSpeed团队提出,正是为解决上述内存冗余问题而生。ZeRO的核心思想是利用数据并行集群的整体内存资源,通过分区存储模型状态来消除各GPU间的重复占用。简而言之,不再让每张卡都保存完整的参数、梯度和优化器状态,而是按需将这些数据切分后分布到不同GPU,只在计算时临时汇集需要的部分。这种“零冗余”数据并行方案(ZeRO-DP)在保持与数据并行相当的计算和通信效率前提下,大幅提高了内存利用率。ZeRO的出现,使得模型规模几乎可以随着GPU数量线性扩展:例如理论上1024张GPU即可训练包含1万亿参数的模型,每卡分担约16GB显存。研究者利用ZeRO的这一突破成功训练了当时规模最大的17B参数语言模型。总的来说,ZeRO的背景动机在于消除数据并行中的内存冗余,从而在无需修改模型代码的情况下突破单机显存限制,支持百亿乃至万亿参数级模型的高效训练。

2. ZeRO 三阶段设计

ZeRO通过逐步增大的分片范围来实现内存优化,共划分为三个阶段(Stage 1/2/3),每个阶段在前一阶段基础上进一步减少显存冗余。各阶段主要区别在于分片对象的不同:

  • ZeRO-1:优化器状态分片 – 只对优化器状态进行分区,每块GPU仅维护全部优化器状态中的一部分。这样可以显著削减Adam等优化器带来的内存占用,但模型参数和梯度仍完整复制在每卡上。
  • ZeRO-2:优化器状态 + 梯度分片 – 在ZeRO-1基础上,进一步将反向传播产生的梯度也分片存储,各GPU只保留对应部分的梯度。这进一步降低显存需求,但模型参数本身仍未分片。
  • ZeRO-3:参数 + 梯度 + 优化器状态三重分片 – 对模型训练中的所有三类主要状态都做分片。也即参数、梯度和优化器状态均打散分配到所有GPU,真正实现“零冗余”。这是内存优化幅度最大但实现也最复杂的阶段。

下面分别介绍每个阶段的原理、通信开销和优缺点。

2.1 ZeRO-1:优化器状态分片

原理: 将优化器的所有状态按参数划分为N份,分配给N个GPU进程中的各一个。例如,对于Adam优化器,每个参数对应的FP32权重副本、一阶动量m、二阶动量v这三项会拆分成N块,令第i个GPU仅保存第i块并更新相关参数。这样每张卡只需维护原本1/N的优化器状态,内存占用下降约N倍。在一个训练step结束时,各GPU需要同步它们更新的参数片段,以确保所有卡上的参数一致。这通常通过一次全局All-Gather操作收集更新后的参数来完成。

通信开销: ZeRO-1的通信模式与常规数据并行非常接近,因为正向和反向过程中各卡仍拥有完整参数参与计算。主要的通信仅发生在反向后的梯度归并(All-Reduce或等价实现)以及优化器更新后的参数同步上。相较于标准DP,ZeRO-1几乎没有额外通信量开销,通信成本与传统数据并行类似。因此ZeRO-1的计算速度与DP基本持平,但显存开销显著降低。

优缺点: 优点是实现简单、对现有训练流程改动小,通信代价低,能够直接将优化器状态内存减少约4倍。针对优化器状态占比较高的场景(例如Adam在FP16训练下状态可达参数4倍以上)效果突出。缺点是在模型参数本身很大的情况下,仍需每卡保存完整参数和梯度,无法彻底解决超大模型的显存瓶颈。因此ZeRO-1适合模型规模中等但优化器内存占用巨大的情况,例如参数数亿、使用Adam优化器的模型。

2.2 ZeRO-2:优化器状态 + 梯度分片

原理: 在ZeRO-1的基础上,ZeRO-2进一步将反向传播的梯度也按参数分区。每个GPU在反向中只计算并保留自己负责参数的那部分梯度,其余梯度通过分布式规约后就地丢弃。具体实现上,当各层反向计算产生梯度时,系统将对应梯度根据参数所属分区执行Reduce-Scatter操作:将不同参数的梯度划分并规约到负责该分区的GPU上。这样每张卡最终只留下了1/N的梯度(即与其参数分区对应的部分),不用再存储完整梯度矩阵。优化器更新时,各GPU仍各自使用本地的梯度分片和优化器状态分片来更新对应的参数分片,本质上实现了数据并行下梯度同步和参数更新的融合。

通信开销: ZeRO-2相比ZeRO-1增加了梯度归并的通信步骤。传统DP使用All-Reduce汇总梯度,而ZeRO-2采用等价的Reduce-Scatter(规约并散播)来完成梯度同步。总通信量与All-Reduce相当,但通过规约到不同GPU避免了全副本梯度的存储。因此通信开销相对ZeRO-1略有提高,但仍保持在与DP同阶的量级。此外,DeepSpeed在ZeRO-2中常用“通信桶 (bucket)”策略,将多个小梯度打包后再做Reduce-Scatter,减少通信次数并可以与反向计算重叠。综上,ZeRO-2的通信代价中等,可通过异步化和分桶优化来减小影响。

优缺点: 相较ZeRO-1,ZeRO-2进一步释放了存储梯度的内存,令每卡显存占用最多可缩减到原来的1/8(理论8倍降低)。这使得在相同硬件上可以训练更大模型或使用更大batch。其缺点是实现复杂度和通信量有所增加,尤其是在梯度大量、通信带宽有限的场景下可能影响训练吞吐。但总体而言,ZeRO-2在显存换取少量通信的折衷上性价比较高,适合GPU数量有限但希望训练接近十亿级模型的情况。

2.3 ZeRO-3:全参数状态分片

原理: ZeRO-3是ZeRO系列的最终阶段,对模型参数、梯度和优化器状态进行三重细粒度分片。在该模式下,没有任何一个GPU保存整个模型的参数:每个GPU只持有全部参数中的1/N(假设N张卡)。具体来说,模型初始化时将各层参数张量按行或列切分成N块,分别放到不同GPU上。前向传播时,遇到某层需要的权重,框架会从其它GPU拉取其余碎片,临时拼接出完整权重再计算。这种拉取通常通过高效的多对多All-Gather实现,并可在计算前预取下一层参数以隐藏延迟。反向传播时,每张卡根据局部输出误差只计算它拥有的参数片段的梯度,如果梯度需要跨分片汇总则使用通信规约。梯度计算后立即执行Reduce-Scatter,将完整梯度平均分发回各负责分片的GPU上。这样每块GPU同样只保留了对应参数分区的梯度。参数更新阶段,各GPU仅使用本地分片的梯度和优化器状态来更新本地参数碎片,无需全局同步,其余GPU的对应参数将由它们各自更新。下一次前向需要该参数时再重新按需收集。通过上述机制,ZeRO-3将模型三类状态的冗余拷贝全部去除,每张卡仅存储原来的1/N大小的数据碎片,大幅突破了单卡内存限制。

通信开销: ZeRO-3为了动态汇集所需的参数和梯度,不可避免地引入了频繁的通信。在前向过程中,每层计算前都要进行一次参数All-Gather,将该层分散在各GPU的权重碎片收集起来;在反向过程中,也可能需要对某些梯度执行All-Gather以计算参数更新(视实现策略而定),然后对梯度执行Reduce-Scatter完成跨卡梯度聚合。因此,每层的前向和反向都伴随通信操作。总体来看,ZeRO-3的通信量约为传统数据并行的1.5倍(因为参数需要额外在各卡传输),通信频次也显著更高。这对GPU互联带宽提出了更高要求。不过,DeepSpeed通过一系列通信优化来降低开销,包括:在计算时并行进行通信(隐藏All-Gather/Reduce的延迟)、提前预取下一批参数碎片以及动态调整通信粒度等。实践表明,即便在千亿参数训练中,经过优化的ZeRO-3通信耗时也可控制在总时间的15%以内。简而言之,ZeRO-3以增加一定通信换取线性级的内存节省,其通信成本虽最高但通过重叠和压缩技术可被显著缓解。

优缺点: 优点是不言而喻的:ZeRO-3最大化利用集群总显存,使模型规模不再受单卡限制。每增加一块GPU,就等效增加一份模型容量。例如64卡集群下理论上可训练比单卡大64倍的模型。这使得像GPT-3(1750亿参数)这类超大模型在有限资源下成为可能。另外,ZeRO-3还支持结合CPU内存和NVMe磁盘的分片异构存储(即ZeRO-Infinity),可以将参数/优化器状态进一步跨设备分片和Offload,突破GPU显存总量限制。缺点是实现复杂度和调度开销最高,对通信带宽和延迟较为敏感。尤其在GPU很多或网络较慢时,如果没有良好通信优化,可能出现通信瓶颈影响吞吐。此外,ZeRO-3对框架和模型有更多约束,例如模型参数需要在使用时可动态收集,某些不规则的模型结构可能需要特殊处理。总的来说,ZeRO-3非常适合百亿到万亿级模型的训练,已成为当前超大模型分布式训练的核心基石技术。在资源足够且通信优化得当的前提下,它能够以较小代价换取巨大的内存扩展能力。

阶段比较: 综合来看,ZeRO各阶段在内存节省通信开销上呈递增关系。ZeRO-1最简单,约4倍内存节省,几乎无额外通信;ZeRO-2可达8倍左右节省,通信略增但可接受;ZeRO-3节省率与GPU数近似线性 (1/N),但通信最复杂。具体选择上,小规模模型可用ZeRO-1/2以减少优化器开销,而真正的超大模型训练几乎只能依赖ZeRO-3 来彻底打破单卡内存天花板。实际工程中常常直接采用ZeRO-3,并通过配置调整在需要时关掉某些功能(如offload)来退化为ZeRO-2或ZeRO-1,以平衡性能和内存需求。

以下是ZeRO-3一次训练迭代中的核心交换:

时间点 主要操作 所需通信 参与张量 备注
模型初始化 按参数维度/行列把 FP16 权重 切成 N 块并散布到 N 张 GPU (直接分配) 参数碎片 deepspeed.zero.Init
完成
前向(FWD) All‑Gather 当前层的权重碎片到每张卡 FWD‑AllGather 参数碎片 → 完整层权重 预取下一层,计算后立即 release
反向(BWD) (可选)再次 All‑Gather 激活所需权重;随后对局部梯度做 Reduce‑Scatter BWD‑AllGather + BWD‑ScatterReduce 完整层权重 / 梯度碎片 梯度钩子在 bucket 满时触发通信
优化器 step 每张卡仅用本地梯度 & 本地状态更新 本地参数碎片 无(完全本地) m、v、FP32 主权重等 单卡持有 1/N 状态

3. 内存优化机制

本节从显存占用角度分析ZeRO各阶段如何优化参数、梯度、优化器状态三类内存,并讨论激活重计算等配合策略。

3.1 模型三类状态的显存占用

在大模型训练中,显存主要被以下三部分消耗:

  • 模型参数 (Parameters): 即模型的权重矩阵和偏置等,可视为_静态存储_,在训练过程中需要常驻显存。其占用随模型规模线性增长。例如百亿级参数模型光权重就需要几十GB显存。
  • 激活值 (Activations): 前向计算过程中每层的中间激活需要暂存以供反向使用,占用随batch大小和序列长度增长,一般是显存主要消耗之一。深层Transformer中,激活占用甚至可能超过参数本身。
  • 优化器状态 (Optimizer States): 优化算法为加速收敛而维护的附加变量,比如Adam的一阶、二阶动量,还有FP32精度的主权重拷贝等。这些状态每个参数对应多份值,因此总大小往往是参数的数倍。以Adam在混合精度训练下为例,m和v通常以32位浮点存储,相当于每个参数有额外2×Param大小,再加上维护的FP32权重副本,优化器状态可达到参数4~6倍内存,占据显存大头。

在传统数据并行中,这三部分在每块GPU上各保存一份,总体占用非常高,常导致“显存爆炸”。例如某大型模型在单GPU上训练时,参数+梯度+优化器状态总共可能需要其参数量5~8倍的显存,这对硬件提出了极高要求。

ZeRO针对上述参数、梯度、优化器状态引入分片机制来分别优化。其效果可以用一个7.5B参数模型、64卡并行的案例来说明:

  • 不使用ZeRO (标准DP):每卡需要完整存储参数、梯度、优化器状态。例如7.5B参数模型混合精度下,每卡大约需要120GB显存,显然远超单卡容量。
  • ZeRO Stage 1 (Pos):仅分片优化器状态,参数和梯度仍全量复制。对于64卡并行,优化器状态内存从原来的12ψ(这里ψ表示参数大小常数)降低到12ψ/64,约减少4倍。上述7.5B模型每卡占用从120GB降至约31.4GB。
  • ZeRO Stage 2 (Pos+g):分片优化器+梯度。梯度占用从2ψ降低到2ψ/64,再结合优化器状态分片,总体内存相比DP减少8倍左右。7.5B模型例子中每卡需16.6GB,比未分片的120GB降低了86%。
  • ZeRO Stage 3 (Pos+g+p):分片优化器+梯度+参数。此时模型状态三者均摊到64卡上,单卡所需内存约为原来的1/64。7.5B模型每卡只需约1.9GB,内存缩减达98%。换言之,只要GPU数量足够,ZeRO-3理论上可以支持任意大小的模型。

上述是理论最大节省值,实际中还会有一些额外开销(例如通信缓存、内存碎片等)使得节省率略低于理想值。但总体趋势符合4倍 (ZeRO-1)、8倍 (ZeRO-2)、N倍 (ZeRO-3) 的级数增长。在真实实验中,ZeRO-3相较ZeRO-2通常能进一步节省约一半左右的显存。例如某GPT-3规模模型在8卡下,ZeRO-2时单卡占用74.8GB,而ZeRO-3降至41.5GB,减少了44%。虽然没达到理论1/8=12.5%(87.5%降低)的极限,但已经大幅缓解了显存压力。可见ZeRO-3结合大规模并行GPU,能将模型状态占用从“无法训练”降至“轻松容纳”的水平。

3.2 各阶段显存节省总结

综合以上分析,ZeRO各阶段对三类状态的内存占用影响如下:

  • 参数 (Weights): Stage1/2未分片参数,每卡保留全量参数,占用=100%。Stage3将参数按GPU数均分,每卡仅存原来的1/N,大幅减少了模型静态内存。
  • 梯度 (Gradients): Stage1未分片梯度,反向后每卡存完整梯度=100%。Stage2/3对梯度做Reduce-Scatter分片,每卡只保留1/N梯度,显存下降为原来的1/N。
  • 优化器状态 (Opt State): Stage1开始分片优化器状态,每卡保留1/N,显存降为1/N。Stage2/3同理保持优化器状态分片。对于Adam这意味着从原本动辄4×参数量的占用降为原来的1/N,大模型训练成为可能。

简单来说,ZeRO-1主要节省优化器状态内存(约4倍减小),ZeRO-2在此基础上再节省梯度内存(约8倍减小),ZeRO-3则连参数也按N等比例缩减(N倍减小)。这三类状态分片相辅相成,使得模型总状态内存随并行度增加而大幅降低。正因为如此,在实际使用中我们往往将ZeRO-3与激活重计算等手段结合,充分压榨硬件显存潜力,做到“低资源跑大模型”

3.3 激活重计算与 ZeRO 的协同

激活重计算 (Activation Checkpointing,又称Recompute) 是应对激活内存占用的常用策略,即在前向计算时有选择地不保存中间激活,等反向需要时再重新计算获取,从而节省内存。Recompute本质上用额外的计算换取内存:一般会略微降低训练吞吐,但对非常深的模型能节约40%以上的激活占用。

ZeRO优化的是模型状态(参数、梯度、优化器)的存储,与Recompute针对激活内存的优化是正交互补的。二者结合可以覆盖显存开销的两个主要方面:ZeRO削减模型静态状态,Recompute压缩动态激活占用,从而最大化整体节省。DeepSpeed官方也强烈建议在使用ZeRO-2/3时开启激活重计算,以换取更大batch或更深的模型。例如在训练50层以上的Transformer时,开启Recompute可将激活内存占用降低一半左右,是ZeRO不可或缺的伙伴策略。

值得一提的是,ZeRO本身也引入了一些针对激活的高级优化,被称为ZeRO-R(Residual Optimizations)。其中包括:(1) 分区激活重计算 (Partitioned Activations),将模型并行环境下需要重复的激活也进行跨GPU分片,在需要用到时再All-Gather到一起。这可以进一步减少激活冗余拷贝,占用随并行度线性降低(甚至可选择将激活碎片暂存CPU)。(2) 恒定大小内存池,将反复申请释放的临时张量用固定缓冲区复用,避免碎片化。(3) 动态内存整理,将生命周期不同的张量放入预先分配的连续区,减少CUDA allocator碎片。这些技术可以视作ZeRO对激活和临时空间的“软优化”。在实际工程中,用户主要需要关注是否开启activation checkpointing。DeepSpeed提供了便捷选项:例如配置文件中"activation_checkpointing": {"partition_activations": true, ...}可以让ZeRO-3自动对重计算的激活进行分片存储。总之,Recompute + ZeRO 的组合是大模型训练的常规操作:前者解决深度网络激活显存占用高的问题,后者解决宽大网络参数冗余的问题,两者结合能让有限显存支持远超以往规模的模型。

4. 通信开销与分布式原语

大规模分布式训练不可避免地引入通信开销,而ZeRO各阶段的通信模式和优化策略是其性能关键。下面结合分布式原语AllReduce/AllGather/ReduceScatter,分析ZeRO的通信开销与并行调度。

4.1 各阶段的通信模式

在ZeRO方案下,不同阶段需要的通信操作有所区别:

  • 数据并行 (Baseline DDP):每个step主要通信是梯度All-Reduce,将各卡计算的梯度求和后同步到所有GPU。参数更新后通常不需要额外通信,因为每卡参数保持一致副本。
  • ZeRO-1:与DDP类似,只需对梯度做All-Reduce同步。另外在每次优化器更新后,用Broadcast或All-Gather发送更新后的参数给其他卡,以替换它们本地的旧参数(因为只有拥有对应优化器状态的GPU真正更新了该参数)。该参数同步与All-Reduce等价(只是发生在更新后),通信量很小。
  • ZeRO-2:使用Reduce-Scatter + All-Gather来替代All-Reduce完成梯度归并。反向中,每张卡将本地梯度分块后与对应GPU Reduce-Scatter累加,这相当于All-Reduce拆解为规约和散发两个步骤。在实现上,DeepSpeed直接对梯度进行Reduce-Scatter汇聚,各GPU得到各自分片的总梯度,然后各自更新参数。由于参数仍全量复制,各卡最终需要一致参数,所以还需一次参数广播。但如果所有GPU都应用了同样的全局梯度更新,参数本就是一致的,无需再同步。因此ZeRO-2通常避免了额外的参数通信,只在梯度归并上比DP多用了分散操作。其总通信量略高于All-Reduce但同阶,主要特点是将一次大的All-Reduce拆成两次更小操作,更易与计算重叠
  • ZeRO-3:通信最为频繁,涵盖前向的参数All-Gather后向的参数All-Gather + 梯度Reduce-Scatter。具体而言,在前向每层计算前,各GPU通过All-Gather从别的卡获取该层的参数碎片,组装成完整权重用于计算;在反向计算该层梯度时,同样需要确保参与计算的参数是完整的,可通过再次All-Gather参数或在前向时缓存使用;梯度产生后,立即对其执行Reduce-Scatter,将梯度片段规约并分发回各负责更新的GPU。因此,每层都有两个通信过程(前向All-Gather + 后向Reduce-Scatter,必要时还有后向All-Gather)。这些操作通过NCCL等库高效实现,但相对于ZeRO-2/DP而言通信总量增加显著。正如前文所述,ZeRO-3整体通信量约为原来的1.5倍,并会随着层数线性增加通信频次。它适合在高带宽互联环境(如NVLink、InfiniBand)下使用,在网络较慢时需要额外优化才能保持吞吐。

4.2 通信优化与异步重叠

面对ZeRO-3高频通信带来的挑战,DeepSpeed和PyTorch FSDP等实现都投入了大量优化手段,以尽可能降低通信开销对训练性能的影响。以下介绍几项关键技术:

  • 计算与通信重叠 (Overlap Compute/Communication): 通过多CUDA流并行执行计算和通信,使GPU不必等待数据传输完成才继续计算。例如在前向传播中,当某层参数All-Gather尚未完全完成时,可以先行利用已到达的部分数据开始计算,或者提前启动下一层的预取;在反向中,可一边进行当前层的反向计算,一边在后台异步执行前一层梯度的Reduce-Scatter。DeepSpeed提供了配置项"overlap_comm": true来开启这一特性,使通信耗时几乎全部被掩盖在计算过程中。Amazon团队曾发现原始ZeRO-3由于同步过于粗放,GPU常处于等待通信完成的空转状态。他们通过细粒度同步和解耦通信依赖,将通信与计算的重叠程度大大提高,显著减少了等待浪费。优化后,All-Gather和Reduce-Scatter可以几乎pipelin式地穿插在前向/后向各层的计算之间,最大化GPU利用率。
  • 通信分桶与分片 (Bucketization & Chunking): 考虑到模型中存在大量小张量通信的情况,将多个小tensor拼接成一个大buffer再进行一次All-Gather或Reduce-Scatter可以显著提升带宽利用率。DeepSpeed允许用户调节如reduce_bucket_sizeallgather_bucket_size参数,将梯度或参数分片累积到一定大小(如50MB)再发起通信。这样做可以摊平通信初始化开销,减少过多小报文带来的效率损失。同时,分桶传输也利于与计算重叠(每处理完一桶梯度就立刻通信,计算不必等待所有梯度算完才同步)。需要注意分桶大小过大可能增加一次通信时延,实际配置时通常在50~100MB范围调优。
  • 预取下一步数据 (Prefetching): 针对ZeRO-3中层与层之间串行通信的瓶颈,引入参数预取机制:在当前层前向/后向计算时并行请求下一层所需的参数碎片。这样当进入下一层时,所需参数可能已经在路上甚至接收完毕,缩短了等待时间。DeepSpeed的配置"stage3_prefetch_bucket_size"等即控制预取的粒度。合理的预取可以将通信延迟隐藏在计算中,但预取过多数据可能造成无谓传输或内存占用,需要动态权衡。实际结果表明,预取对于小batch情况下提升明显——因为每层计算更快,通信来不及完全隐藏,需要预先调度。
  • 压缩与量化通信 (通信压缩): 针对某些通信瓶颈,研究者也探索了梯度压缩、量化传输等方法。例如ZeRO++中提出对All-Gather的权重块进行8-bit量化再传输,以降低通信字节数。这些方法需配合额外的还原开销,属于前沿探索,在此不展开。

通过以上策略,ZeRO各阶段的通信开销被大大优化。在理想环境下,ZeRO-3可近乎线性扩展:微软报告在400 GPU上训练1000亿参数模型时获得了超过10倍加速,通信并未成为主要瓶颈。而在通信较慢环境(如多机弱网络)下,充分的通信重叠和分桶也能使ZeRO-3保持较好的扩展效率。例如某实验对比显示,8卡扩展到32卡时,ZeRO-3的并行效率达到了76%,明显优于ZeRO-2的54%。总之,通信优化使得ZeRO在极大降低内存的同时,尽可能保持了训练吞吐率。在工程实践中应根据硬件和模型规模,调整如overlap_comm, bucket_size等参数,以获得通信和计算的最佳平衡。