Megatron SFT 长上下文训练显存优化指南
2026-02-10
ms-swift
00

目录

Megatron SFT 长上下文训练显存优化指南
一、激活值重计算(Activation Recomputation)
1.1 核心参数
1.2 重计算的两种模式详解
1.3 当前配置问题
二、Context Parallel(长序列最重要的参数)
2.1 参数
2.2 关键澄清:SP 和 CP 可以同时使用
2.3 当前配置问题
2.4 Padding 对齐规则
三、优化器显存优化
3.1 核心参数
3.2 低精度优化器状态参数
3.3 官方 Offload 示例
四、并行策略
4.1 并行参数总览
4.2 DP 计算公式
4.3 约束条件
4.4 Virtual Pipeline Parallelism(VPP)
4.5 64 卡并行方案对比
五、其他显存优化参数
5.1 注意力后端
5.2 融合优化
5.3 数据效率
5.4 FP8 训练(如果硬件支持)
5.5 MoE 专用
5.6 其他
5.7 环境变量
六、针对当前配置的具体修改建议
6.1 当前值 vs 推荐值对比
6.2 推荐的 YAML 配置片段
6.3 如果 CP=2 仍然 OOM
七、显存优化优先级排序
附录:参考示例脚本

Megatron SFT 长上下文训练显存优化指南

针对 Qwen3-Omni-30B-A3B-Instruct 模型,60K tokens 长上下文 Megatron SFT 训练场景。 基于 ms-swift 源码分析,所有参数引用均标注源码位置。


一、激活值重计算(Activation Recomputation)

激活值重计算是长上下文训练中最关键的显存优化手段,用计算时间换取显存空间。

1.1 核心参数

参数类型默认值说明源码位置
recompute_granularity'selective' / 'full''selective'重计算粒度。full = 整层重计算(最省显存),selective = 只重计算选定模块megatron_args.py:402
recompute_method'uniform' / 'block'None重计算方法。仅在 recompute_granularity='full' 时生效uniform = 均匀分布重计算层,block = 块状分布megatron_args.py:403
recompute_num_layersintNone每次重计算的层数。仅在 recompute_granularity='full' 时生效。The larger the recompute_num_layers, the smaller the memory usage but higher computation cost. Default is None.megatron_args.py:404
recompute_modulesList[str]['core_attn']仅在 recompute_granularity='selective' 时生效。可选模块:core_attn, moe_act, layernorm, mla_up_proj, mlp, moemegatron_args.py:405
moe_layer_recomputeboolFalseMoE 层额外重计算,对 MoE 模型(如 30B-A3B)有显著效果megatron_args.py:551
vit_gradient_checkpointingboolTrue对视觉编码器(ViT)启用梯度检查点(HuggingFace 风格)megatron_args.py:362
gradient_checkpointing_kwargsdictNone传递给 torch.utils.checkpoint 的额外参数,如 {"use_reentrant": false}megatron_args.py:365

1.2 重计算的两种模式详解

Full 模式(推荐用于长上下文):

前向传播时不保存中间激活值,反向传播时从 checkpoint 点重新计算。设 recompute_num_layers=1 意味着每一层都重新计算,最大化节省显存。

yaml
展开代码
recompute_granularity: full recompute_method: uniform recompute_num_layers: 1 # 最省显存,每层都重计算

实现位置:swift/megatron/model/mm_gpt/qwen3_vl.py:251-280_checkpointed_forward 方法),使用 tensor_parallel.checkpoint() 对 transformer 层做分段 checkpoint。

Selective 模式(默认):

只对 recompute_modules 中指定的模块做 checkpoint,其余模块正常保存激活值。core_attnmlpmoe 使用标准 checkpoint;moe_actlayernormmla_up_proj 使用 output-discarding checkpoint(CheckpointWithoutOutput)。

1.3 当前配置问题

你当前 recompute_num_layers: 2,表示每 2 层为一组做 checkpoint。


二、Context Parallel(长序列最重要的参数)

Context Parallel (CP) 将序列按长度维度切分到多个 GPU,是处理超长序列的核心手段。

2.1 参数

参数类型默认值说明源码位置
context_parallel_sizeint1CP 大小。将序列切分到 N 个 GPU,每 GPU 处理 seq_length / N 个 tokensmegatron_args.py:481
sequence_parallelboolFalse序列并行,在 TP 组内切分 LayerNorm/Dropout 的激活值。需要 tensor_model_parallel_size > 1megatron_args.py:480

2.2 关键澄清:SP 和 CP 可以同时使用

源码中唯一的互斥约束是 mlp_padding_free 与 SP/CP(megatron_args.py:608-609):

python
展开代码
if self.mlp_padding_free and (self.sequence_parallel or self.context_parallel_size > 1): raise ValueError('mlp_padding_free is not compatible with sequence parallel or context parallel.')

这说明:

  • 不兼容mlp_padding_free + SP 或 CP
  • 兼容:SP + CP(可以同时开启)
  • 兼容padding_free(默认 True)+ SP + CP

SP 和 CP 是互补关系:SP 切分 TP 组内的激活值序列维度,CP 在更大范围切分整个序列。

2.3 当前配置问题

你当前 context_parallel_size: 1,意味着 60K tokens 全部放在单 GPU 上处理。这是 OOM 的最主要原因之一。

建议:

CP 值每 GPU tokensDP 数(64卡, TP=4, PP=2)说明
1(当前)60,0008极易 OOM
230,0004推荐首选
415,0002如果 CP=2 仍 OOM

2.4 Padding 对齐规则

开启 SP/CP 后,序列长度会被自动 padding 到特定倍数(源码 swift/megatron/trainers/utils.py:318-332):

条件Padding 对齐到
SP 开启tensor_model_parallel_size 的倍数
CP 开启TP × CP 的倍数
FP8 blockwiseTP × CP × 128 的倍数
FP8 其他max(TP × CP × 8, 16) 的倍数

三、优化器显存优化

全参数训练时,优化器状态(Adam 的一阶矩、二阶矩、fp32 主参数副本)通常占总显存的 50% 以上。

3.1 核心参数

参数类型默认值说明源码位置
use_distributed_optimizerboolTrueZeRO-1:将优化器状态分片到各 DP rank,显存减少为 1/DP_sizemegatron_args.py:472
optimizer_cpu_offloadboolFalse将优化器状态卸载到 CPU,大幅释放 GPU 显存megatron_args.py:422
use_precision_aware_optimizerboolFalse使用 TransformerEngine 的精度感知优化器,支持低精度优化器状态megatron_args.py:424
optimizer_offload_fractionfloat1.0卸载到 CPU 的优化器状态比例(0.0 ~ 1.0),配合 optimizer_cpu_offload 使用megatron_args.py:423

3.2 低精度优化器状态参数

参数类型默认值可选值说明源码位置
main_grads_dtypestr'fp32''fp32', 'bf16'主梯度精度,bf16 节省一半梯度显存megatron_args.py:425
main_params_dtypestr'fp32''fp32', 'fp16'主参数副本精度megatron_args.py:426
exp_avg_dtypestr'fp32''fp32', 'fp16', 'bf16', 'fp8'Adam 一阶矩精度,fp8 最省megatron_args.py:427
exp_avg_sq_dtypestr'fp32''fp32', 'fp16', 'bf16', 'fp8'Adam 二阶矩精度,fp8 最省megatron_args.py:428

3.3 官方 Offload 示例

examples/megatron/moe/qwen3_moe_offload.sh 针对同款 Qwen3-30B-A3B 模型的配置:

bash
展开代码
--optimizer_cpu_offload true \ --use_precision_aware_optimizer true \ --optimizer_offload_fraction 1

该示例在 4 × A100 上,每卡 75GiB 显存。


四、并行策略

4.1 并行参数总览

参数类型默认值说明源码位置
tensor_model_parallel_sizeint1TP:将权重矩阵按列切分到多个 GPUmegatron_args.py:473
pipeline_model_parallel_sizeint1PP:将模型按层切分到多个 GPU 组megatron_args.py:474
context_parallel_sizeint1CP:将序列按长度切分到多个 GPUmegatron_args.py:481
sequence_parallelboolFalseSP:TP 组内切分序列维度的激活值megatron_args.py:480
expert_model_parallel_sizeint1EP:将不同 expert 分配到不同 GPUmegatron_args.py:542
expert_tensor_parallel_sizeint1Expert TP:单个 expert 内部的张量并行megatron_args.py:543

4.2 DP 计算公式

展开代码
DP = total_GPUs / (TP × PP × CP)

源码位置:megatron_args.py:227-228

注意:EP 不参与 DP 公式计算(EP 在 TP 组内部再分),但会影响每个 GPU 上的 expert 数量。

4.3 约束条件

约束说明源码位置
num_query_groups 必须是 TP 的倍数Qwen3-Omni-30B-A3B 的 num_query_groups=4,所以 TP 只能是 1, 2, 4Megatron 内部校验
mlp_padding_free 与 SP/CP 互斥不能同时开启 mlp_padding_free 和 SP 或 CPmegatron_args.py:608-609
freeze_parameters_ratio(0~1) 与 PP>1 互斥部分冻结不能与流水线并行同时使用megatron_args.py:308-309
decoder_first/last_pipeline_num_layers 需要 PP>1PP 首尾层数自定义仅在 PP>1 时生效megatron_args.py:713-716
global_batch_size 必须被 micro_batch_size × DP 整除框架自动计算梯度累积步数Megatron 内部校验

4.4 Virtual Pipeline Parallelism(VPP)

VPP 通过增加 virtual stages 减少 PP bubble,适合 PP 较大的场景。

参数类型默认值说明源码位置
num_layers_per_virtual_pipeline_stageintNone每个虚拟 PP stage 的层数megatron_args.py:486
num_virtual_stages_per_pipeline_rankintNone每个 PP rank 的虚拟 stage 数megatron_args.py:487
decoder_first_pipeline_num_layersintNone第一个 PP stage 的层数(平衡 embedding 层的显存)megatron_args.py:475
decoder_last_pipeline_num_layersintNone最后一个 PP stage 的层数(平衡 LM head 的显存)megatron_args.py:476

4.5 64 卡并行方案对比

方案TPEPPPSPCPDP每 GPU tokens特点
当前配置4441160,000DP=1 吞吐极低,单卡 60K tokens 易 OOM
推荐方案 A4422430,000平衡显存和吞吐
推荐方案 B4424215,000最省激活显存
极限方案4442130,000PP=4 减少模型权重显存,但 bubble 大

五、其他显存优化参数

5.1 注意力后端

参数类型默认值说明源码位置
attention_backendstr'flash''flash' = Flash Attention(O(n) 显存),'fused', 'unfused', 'local', 'auto'megatron_args.py:420

Flash Attention 对长序列至关重要,将注意力显存从 O(n²) 降为 O(n)。

5.2 融合优化

参数类型默认值说明源码位置
cross_entropy_loss_fusionboolFalse融合交叉熵损失计算,减少 logits 张量的显存占用megatron_args.py:416
no_gradient_accumulation_fusionboolFalse禁用梯度累积融合(需要 apex)megatron_args.py:415
moe_permute_fusionboolFalseMoE token permute 操作融合megatron_args.py:547
moe_grouped_gemmboolTrueMoE grouped GEMM 优化megatron_args.py:546
moe_shared_expert_overlapboolFalse共享专家计算与通信重叠megatron_args.py:550

5.3 数据效率

参数类型默认值说明源码位置
padding_freeboolTrue去除 batch 内 padding,减少无效计算和显存浪费megatron_args.py:321
packingboolFalse将多个短序列打包到一个 batch,减少 padding继承自 BaseArguments

5.4 FP8 训练(如果硬件支持)

参数类型默认值说明源码位置
fp8_formatstrNoneFP8 格式:'e4m3''hybrid',大幅减少权重和激活显存megatron_args.py:568
fp8_recipestr'delayed'FP8 算法:'tensorwise', 'delayed', 'mxfp8', 'blockwise'megatron_args.py:569
fp8_param_gatherboolFalseall-gather 时保持 FP8 参数,节省通信和显存megatron_args.py:572

5.5 MoE 专用

参数类型默认值说明源码位置
moe_expert_capacity_factorfloatNone专家容量因子,限制每个 expert 处理的 token 数,防止显存峰值megatron_args.py:552
moe_token_dispatcher_typestr'alltoall'Token 分发方式:'allgather', 'alltoall', 'flex', 'alltoall_seq',不同方式显存特性不同megatron_args.py:544

5.6 其他

参数类型默认值说明源码位置
micro_batch_sizeint1每 GPU 微批次大小,长上下文必须设为 1megatron_args.py:400
use_cpu_initializationboolFalseCPU 初始化模型权重,节省初始化阶段 GPU 显存megatron_args.py:406
manual_gcboolFalse手动控制垃圾回收,减少显存碎片megatron_args.py:430
manual_gc_intervalint0手动 GC 间隔megatron_args.py:431
no_save_optimboolFalse不保存优化器状态到 checkpoint(节省存储,不影响训练显存)megatron_args.py:454

5.7 环境变量

bash
展开代码
# 必须设置,启用 PyTorch 可扩展内存段,减少显存碎片 PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True'

所有官方 Megatron 示例脚本均使用此环境变量。


六、针对当前配置的具体修改建议

6.1 当前值 vs 推荐值对比

参数当前值推荐值改动原因
recompute_num_layers21每层都重计算,最大化节省激活显存
moe_layer_recompute未设trueMoE 层额外重计算,对 30B-A3B 有效
context_parallel_size12(或 4核心改动:将 60K 序列切分,每 GPU 处理 30K(或 15K)
pipeline_model_parallel_size42PP=4 时 DP=1 吞吐太低,降为 2 配合 CP=2 更平衡
optimizer_cpu_offload未设true优化器状态卸载到 CPU
use_precision_aware_optimizer未设true配合 CPU offload
optimizer_offload_fraction未设1.0全部卸载
moe_shared_expert_overlap未设true共享专家计算重叠,提升性能

6.2 推荐的 YAML 配置片段

yaml
展开代码
# ============ Memory Optimization(关键改动)============ recompute_granularity: full recompute_method: uniform recompute_num_layers: 1 # 从 2 改为 1,最省显存 moe_layer_recompute: true # 新增:MoE 层重计算 vit_gradient_checkpointing: true # ============ 优化器 CPU Offload(新增)============ optimizer_cpu_offload: true use_precision_aware_optimizer: true optimizer_offload_fraction: 1.0 # ============ 并行策略(调整)============ tensor_model_parallel_size: 4 # TP=4(受 num_query_groups=4 约束) expert_model_parallel_size: 4 # EP=4 pipeline_model_parallel_size: 2 # PP: 从 4 降为 2,减少 pipeline bubble sequence_parallel: true # SP: TP 组内序列并行 context_parallel_size: 2 # CP=2: 60K ÷ 2 = 每 GPU 30K tokens # DP = 64 / (TP=4 × PP=2 × CP=2) = 4 # ============ 注意力 & 融合优化 ============ attention_backend: flash cross_entropy_loss_fusion: true moe_permute_fusion: true moe_grouped_gemm: true moe_shared_expert_overlap: true # ============ Batch ============ micro_batch_size: 1 global_batch_size: 64 # 需确保能被 micro_batch_size × DP=4 整除

6.3 如果 CP=2 仍然 OOM

逐步升级策略:

展开代码
第 1 步:CP=2, PP=2 → DP=4, 每 GPU 30K tokens 第 2 步:CP=4, PP=2 → DP=2, 每 GPU 15K tokens 第 3 步:CP=4, PP=2 + optimizer_cpu_offload → 进一步释放优化器显存 第 4 步:CP=4, PP=4 → DP=1, 吞吐最低但最省显存 第 5 步:在以上基础上加 FP8 → fp8_format: e4m3(需要硬件支持)

七、显存优化优先级排序

从最有效到最不有效:

优先级手段预期效果代价
⭐⭐⭐⭐⭐context_parallel_size 增大激活显存按比例线性切分增加通信开销
⭐⭐⭐⭐⭐optimizer_cpu_offload: true优化器状态全部卸载到 CPU,释放大量 GPU 显存训练速度略降(CPU-GPU 数据传输)
⭐⭐⭐⭐recompute_num_layers: 1最大化激活重计算节省约增加 30% 计算时间
⭐⭐⭐⭐sequence_parallel: trueTP 组内激活值序列维度切分几乎无额外代价(需 TP>1)
⭐⭐⭐moe_layer_recompute: trueMoE 层额外重计算略增加计算时间
⭐⭐⭐attention_backend: flash注意力显存 O(n²) → O(n)无(纯优化)
⭐⭐⭐cross_entropy_loss_fusion: true减少 logits 显存无(纯优化)
⭐⭐调整 PP/TP 比例平衡权重和激活显存分布PP 增大 → pipeline bubble 增大
⭐⭐fp8_format: e4m3权重和激活显存减半需要 H100/H800 等硬件支持
use_cpu_initialization: true节省初始化阶段显存初始化慢一些
manual_gc: true减少显存碎片可忽略

附录:参考示例脚本

脚本路径场景关键配置
examples/megatron/long_text.sh32K 长文本,4×A100TP=4, SP, recompute full/uniform/1
examples/megatron/moe/qwen3_moe_offload.sh30B-A3B + CPU offload,4×A100EP=4, optimizer_cpu_offload, recompute full/uniform/1
examples/megatron/moe/qwen3_moe.sh30B-A3B 多节点,16卡PP=2, EP=8, recompute full/uniform/1
examples/megatron/multimodal/omni/moe.shQwen3-Omni-30B LoRA,2卡EP=2, LoRA, vit_gradient_checkpointing
如果对你有用的话,可以打赏哦
打赏
ali pay
wechat pay

本文作者:Dong

本文链接:

版权声明:本博客所有文章除特别声明外,均采用 CC BY-NC。本作品采用《知识共享署名-非商业性使用 4.0 国际许可协议》进行许可。您可以在非商业用途下自由转载和修改,但必须注明出处并提供原作者链接。 许可协议。转载请注明出处!