🤖 roboto_origin_03 Wiki
首页 / RSL-RL / 多 GPU 分布式训练机制

rsl_rl 的分布式训练采用**数据并行(Data Parallelism)**范式:每个 GPU 进程独立运行一组向量化环境、收集本地转移数据,并在策略更新阶段通过梯度归约(Gradient Reduction)与参数广播(Parameter Broadcast)保证所有进程模型一致性。与 PyTorch 的 DistributedDataParallel 不同,框架选择手动调用 torch.distributed 原语,以精确控制同步时机,适应 on-policy RL 中“先完整 rollout、后多 epoch 更新”的特殊节奏。

分布式训练架构概览

在分布式模式下,系统由单个 OnPolicyRunner(或其子类)实例驱动,每个进程通过环境变量感知自身全局/局部 rank,并在构造函数中初始化 NCCL 进程组。算法层(PPOPPOAMPDistillation)接收 multi_gpu_cfg 字典,据此决定是否执行跨 GPU 通信。日志系统仅允许 rank 0 写入,同时将所有进程的采样量计入总吞吐量计算。

flowchart TD
    A[环境变量 WORLD_SIZE/RANK/LOCAL_RANK] --> B[OnPolicyRunner._configure_multi_gpu]
    B --> C[torch.distributed.init_process_group backend=nccl]
    C --> D[创建 multi_gpu_cfg]
    D --> E[算法实例化 PPO/PPOAMP/Distillation]
    E --> F[训练循环 learn]
    F --> G{broadcast_parameters<br/>初始参数同步}
    F --> H[各进程独立 Rollout]
    H --> I[本地数据 compute_returns]
    I --> J[update 循环]
    J --> K[backward]
    K --> L[reduce_parameters<br/>梯度 all_reduce + avg]
    L --> M[optimizer.step]
    J --> N[KL all_reduce + LR broadcast]
    M --> O[下一 mini-batch]
    F --> P[Logger rank0 聚合日志<br/>总步数 × world_size]

该架构的核心假设是:各进程的环境实例在初始化时具有不同随机种子,因此 rollout 数据天然异构,无需像监督学习那样使用 DistributedSampler 划分数据集。

Sources: on_policy_runner.py, ppo.py

Runner 层初始化与进程组配置

分布式能力的入口位于 OnPolicyRunner._configure_multi_gpu()。该方法在构造器中优先执行,负责从环境变量读取分布式元数据、校验设备映射、并初始化进程组。

框架读取三个标准环境变量:WORLD_SIZE(总进程数)、LOCAL_RANK(当前节点内 GPU 序号)、RANK(全局进程序号)。当 WORLD_SIZE > 1 时,框架判定为分布式模式,随后强制校验 device 必须等于 cuda:{local_rank},以避免用户误将多个进程绑定到同一张显卡。校验通过后,调用 torch.distributed.init_process_group 以 NCCL 为后端创建进程组,并执行 torch.cuda.set_device(local_rank) 设置当前 CUDA 设备。所有这些信息被封装为 multi_gpu_cfg 字典,后续传递给算法与日志器。

Sources: on_policy_runner.py

算法层分布式状态管理

算法类通过构造函数中的 multi_gpu_cfg: dict | None 参数接入分布式上下文。以 PPO 为例,当传入非空配置时,算法保存 gpu_global_rankgpu_world_size,并置 is_multi_gpu = True;若配置为 None,则所有 rank 相关标识退化为单 GPU 默认值(rank 0、world size 1)。这种显式分支设计使得同一套算法代码在单卡与多卡场景下均可直接运行,无需外部封装器。PPOAMPDistillation 遵循同样的配置契约:子类通过 super().__init__(..., multi_gpu_cfg=multi_gpu_cfg) 继承基础分布式状态,再根据各自模块特性扩展训练逻辑。

Sources: ppo.py, ppo_amp.py, distillation.py

参数同步原语

rsl_rl 在算法层暴露两个核心同步原语:broadcast_parametersreduce_parameters。两者均直接调用 torch.distributed 底层 API,而非依赖 PyTorch 的高级封装。

参数广播(broadcast_parameters) 在训练开始前由 Runner 调用,目的是确保所有进程从完全一致的模型权重出发。该方法将 self.policy.state_dict()(以及 RND 预测器的状态,若启用)打包为列表,通过 torch.distributed.broadcast_object_list(..., src=0) 从 rank 0 广播到所有进程,随后在各进程本地执行 load_state_dictDistillation 算法也实现了同名方法,仅广播 StudentTeacher 的策略状态。

梯度归约(reduce_parameters) 在每个 mini-batch 的 backward() 之后调用,负责将各进程独立计算的梯度做全局平均。实现上,算法先将所有待同步参数的 grad 展平并拼接为一条长向量 all_grads,执行 torch.distributed.all_reduce(all_grads, op=SUM) 后除以 gpu_world_size,最后按原始形状写回各参数的 grad.data。这种“展平-归约-回填”模式避免了为每个参数单独发起通信,有效降低了延迟。

Sources: ppo.py, distillation.py

训练循环中的分布式协作

OnPolicyRunner.learn() 的主循环中,分布式协作体现在三个精确的时间点:

  1. 训练启动时的参数广播:进入迭代前,若 is_distributed 为真,Runner 调用 self.alg.broadcast_parameters(),保证所有 GPU 上的策略网络初始值一致。
  2. KL 散度归约与学习率广播:在 PPO.update() 的自适应学习率分支中,首先对各进程本地计算的 KL 均值做 all_reduce 求平均;随后仅由 rank 0 根据全局 KL 调整学习率,再通过 torch.distributed.broadcast(lr_tensor, src=0) 将新的学习率下发到全部进程,确保各进程优化器步长一致。
  3. 每步梯度归约:在每个 mini-batch 的 loss.backward() 之后,调用 reduce_parameters() 对策略网络(及 RND 模块)的梯度做全局平均,然后才执行 optimizer.step()
sequenceDiagram
    participant R0 as Rank 0
    participant R1 as Rank 1
    participant RN as Rank N
    Note over R0,RN: 训练开始前
    R0->>R1: broadcast_object_list(state_dict)
    R0->>RN: broadcast_object_list(state_dict)
    Note over R0,RN: 每次 update 迭代
    loop 每个 mini-batch
        R0->>R0: backward()
        R1->>R1: backward()
        RN->>RN: backward()
        R0->>R0: all_reduce(grads, SUM) / world_size
        R1->>R1: all_reduce(grads, SUM) / world_size
        RN->>RN: all_reduce(grads, SUM) / world_size
    end
    opt 自适应学习率
        R0->>R0: 计算全局 KL 均值
        R0->>R0: 调整 learning_rate
        R0->>R1: broadcast(lr_tensor, src=0)
        R0->>RN: broadcast(lr_tensor, src=0)
    end

上述同步点的设计体现了 on-policy RL 的刚性约束:策略网络在更新中途不能出现参数分叉,否则重要性采样比(importance ratio)将失去跨进程可比性。

Sources: on_policy_runner.py, ppo.py, ppo.py

日志系统的分布式适配

分布式训练下,若所有进程同时写入同一日志目录或同一控制台,会导致指标混乱与文件竞争。LoggerLoggerAMP 通过 disable_logs 属性解决该问题:当 is_distributed and gpu_global_rank != 0 时,非主进程的日志写入与终端输出被完全静默。这意味着只有 rank 0 负责向 TensorBoard、WandB、Neptune 等后端上报标量与模型文件。

与此同时,为了正确反映集群级吞吐量,Logger.log() 在计算每轮采样量时显式乘以 gpu_world_size

collection_size = num_steps_per_env * num_envs * gpu_world_size

FPS(每秒帧数)同样基于聚合后的 collection_size 计算,从而让用户直观看到多卡叠加后的数据生成速度。总时间步 tot_timesteps 也按此聚合值累进。

Sources: logger.py, logger.py, logger.py

各算法分布式支持对照

算法 继承关系 参数广播范围 梯度归约范围 KL/LR 同步 特殊说明
PPO 基类 Policy + RND predictor Policy + RND all_reduce + broadcast 完整支持
PPOAMP 继承 PPO 同 PPO(Policy + RND) 同 PPO(Policy + RND) all_reduce + broadcast AMP 判别器参数与梯度未纳入同步原语
Distillation 独立实现 StudentTeacher StudentTeacher 无自适应 LR 拥有独立 broadcast_parametersreduce_parameters 实现

从表中可见,当前所有 on-policy 算法共享同一套“参数广播 + 梯度归约”的通信契约,但各算法需自行确保新增模块(如 AMP 判别器、蒸馏中的教师网络)被纳入同步范围,否则可能出现多卡间模型分叉。

Sources: ppo.py, ppo_amp.py, distillation.py

启动配置与使用要点

由于框架内部已集成进程组初始化,用户侧无需修改训练脚本的核心逻辑,但需遵循以下外部约束:

典型的多卡启动命令形如:

torchrun --standalone --nnodes=1 --nproc_per_node=4 train.py --device cuda

其中 train.py 负责读取 LOCAL_RANK 并实例化 OnPolicyRunner(env, train_cfg, log_dir, device=f"cuda:{local_rank}")

Sources: on_policy_runner.py

延伸阅读与上下游关联

理解多 GPU 机制后,建议继续阅读以下关联章节,以完整把握训练系统的数据流与控制流: