🤖 roboto_origin_03 Wiki
首页 / RSL-RL / 训练运行器生命周期管理

训练运行器(Runner)是 rsl_rl 框架的训练编排中枢,负责将向量化环境、策略算法、经验存储与日志系统组装为完整的训练流水线。本文档从生命周期视角解析 OnPolicyRunner 及其派生类的初始化、训练循环、状态持久化与推理部署四个核心阶段,帮助开发者理解各组件之间的调用时序与状态迁移边界。若你尚未熟悉 PPO 算法细节或向量化环境接口,建议先阅读 PPO 算法实现与训练流程向量化环境抽象接口

运行器类结构与职责边界

框架在 runners/ 目录下提供了三个运行器实现,形成一条以 OnPolicyRunner 为基类的继承链。基类封装了同策(on-policy)训练的标准生命周期,而派生类仅对特定算法所需的组件构造、日志记录和状态管理进行扩展。

运行器类 对应算法 核心扩展点 典型使用场景
OnPolicyRunner PPO (+ RND / Symmetry) 标准生命周期模板 通用强化学习训练
AMPRunner PPOAMP 替换 Logger 为 LoggerAMP;增加 AMP 判别器与循环缓冲区的构造与状态持久化 对抗动作先验模仿学习
DistillationRunner Distillation 重写 _construct_algorithm 以构造师生策略;校验教师模型已加载 策略蒸馏与模型压缩

三者的关键差异体现在初始化阶段的组件装配与检查点格式的扩展上,而训练循环的主干时序(rollout → compute_returns → update → log)完全复用基类实现。这种设计使得新增算法时,开发者通常只需重写 _construct_algorithmsave / load 方法,无需改动训练循环逻辑。

Sources: on_policy_runner.py, amp_runner.py, distillation_runner.py

生命周期总览

运行器的生命周期可划分为四个连续阶段:构造与配置训练循环状态持久化推理部署。以下时序图展示了各阶段内部的关键方法调用及跨组件交互关系。

flowchart TB
    subgraph Init["阶段 1: 构造与配置"]
        A[__init__] --> B[_configure_multi_gpu]
        B --> C[get_observations]
        C --> D[resolve_obs_groups]
        D --> E[_construct_algorithm]
        E --> F[Logger]
    end

    subgraph Train["阶段 2: 训练循环 learn()"]
        G[train_mode] --> H[for it in iterations]
        H --> I[Rollout: act / step / process_env_step]
        I --> J[compute_returns]
        J --> K[update]
        K --> L[log]
        L --> M{save_interval?}
        M -->|Yes| N[save]
        M -->|No| H
    end

    subgraph Checkpoint["阶段 3: 状态持久化"]
        O[save] --> P[model_state_dict + optimizer_state_dict]
        Q[load] --> R[restore state_dict & iter]
    end

    subgraph Inference["阶段 4: 推理部署"]
        S[get_inference_policy] --> T[eval_mode]
        T --> U[return act_inference]
    end

    Init --> Train
    Train --> Checkpoint
    Train --> Inference

在典型训练脚本中,开发者仅显式调用 runner.learn(num_learning_iterations),其余阶段由运行器内部自动推进。这种高层封装隐藏了分布式同步、设备迁移与缓冲区管理的细节,同时通过暴露 save / load / get_inference_policy 等钩子,支持断点续训与模型导出。

Sources: on_policy_runner.py

阶段一:构造与配置

运行器构造函数接收 VecEnv 实例、训练配置字典、log_dirdevice 四个参数,按严格顺序完成以下装配任务。

多 GPU 分布式配置

_configure_multi_gpu 方法首先检测环境变量 WORLD_SIZELOCAL_RANKRANK,判断是否启用分布式训练。若启用,则初始化 torch.distributed 的 NCCL 进程组,并校验 device 与当前进程的 local rank 匹配。配置结果以字典形式存入 self.multi_gpu_cfg,供后续算法构造时使用。该设计将分布式复杂性收敛在运行器层,算法侧仅通过 multi_gpu_cfg 感知自身所处进程上下文。

Sources: on_policy_runner.py

观测组解析与算法装配

运行器通过 env.get_observations() 获取环境的初始观测字典(TensorDict),随后调用 resolve_obs_groups 将用户配置的观测组映射解析为算法所需的观测集合(如 criticrnd_stateteacher)。解析完成后,_construct_algorithm 方法按以下顺序实例化核心组件:

  1. 策略网络:根据 policy_cfg.class_name 动态解析类(如 ActorCriticActorCriticRecurrent),并传入观测字典、观测组、动作维度与策略超参;
  2. 经验存储:创建 RolloutStorage,指定训练类型为 "rl",并预分配 (num_steps_per_env, num_envs) 尺度的张量缓冲区;
  3. 算法实例:解析 alg_cfg.class_name(通常为 PPO),将策略网络、存储对象与分布式配置传入构造器。

AMPRunner 在此阶段额外构造两个 CircularBuffer 实例,分别用于存放策略生成的判别器观测与专家示范观测,随后将它们连同标准组件一并注入 PPOAMPDistillationRunner 则将训练类型改为 "distillation",并构造 StudentTeacher 策略与 Distillation 算法。

Sources: on_policy_runner.py, amp_runner.py, distillation_runner.py

日志系统初始化

运行器在算法装配完成后实例化 Logger(或 LoggerAMP),传入运行器配置、环境配置、进程数与分布式标识。日志器负责管理标量写入器(TensorBoard / WandB / Neptune)、奖励滑动窗口、Git 仓库状态记录以及训练步数统计。值得注意的是,日志器在分布式场景下会自动禁用非主进程的日志输出,避免多进程竞争写入。

Sources: on_policy_runner.py, logger.py

阶段二:训练循环

learn 方法是运行器的核心编排逻辑,其内部按迭代次数展开为 Rollout 采集回报计算策略更新指标记录模型保存 五个子阶段。以下流程图展示了单轮迭代内部的精确调用链。

sequenceDiagram
    participant Runner as OnPolicyRunner
    participant Env as VecEnv
    participant Alg as PPO
    participant Storage as RolloutStorage
    participant Logger as Logger

    Runner->>Runner: train_mode()
    loop num_steps_per_env
        Runner->>Alg: act(obs)
        Alg->>Storage: transition = {...}
        Runner->>Env: step(actions)
        Env-->>Runner: obs, rewards, dones, extras
        Runner->>Alg: process_env_step(obs, rewards, dones, extras)
        Alg->>Storage: add_transition(transition)
        Alg->>Alg: update_normalization(obs)
        Runner->>Logger: process_env_step(rewards, dones, extras)
    end
    Runner->>Alg: compute_returns(last_obs)
    Alg->>Storage: returns / advantages
    Runner->>Alg: update()
    Alg->>Storage: mini_batch_generator()
    Alg-->>Runner: loss_dict
    Runner->>Logger: log(loss_dict, ...)
    opt Logger->>Logger: save_model()
    opt it % save_interval == 0
        Runner->>Runner: save(path)
    end

Rollout 采集

每一轮迭代开始时,运行器在 torch.inference_mode() 上下文中执行固定步数(num_steps_per_env)的环境交互。每一步的时序为:算法根据当前观测采样动作 → 运行器将动作发往环境执行 step → 算法处理环境反馈并将转移存入 RolloutStorageprocess_env_step 内部还负责更新观测归一化统计量、计算 RND 内在奖励、处理 time_outs 引导(bootstrapping),并在回合结束时重置策略的隐藏状态。运行器同时将奖励与终止信号传递给日志器,用于维护滑动窗口统计。

Sources: on_policy_runner.py, ppo.py

回报计算与策略更新

Rollout 结束后,运行器调用 alg.compute_returns(last_obs),利用 GAE(Generalized Advantage Estimation)从最后一步向前递推计算回报值与优势值。随后进入 alg.update(),算法从 RolloutStorage 中抽取小批量数据,执行多次 PPO 裁剪目标优化。若启用了对称性增强、RND 探索或辅助损失,这些模块的梯度计算与反向传播也在 update 内部完成。对于分布式训练,update 还会调用 reduce_parameters 对所有 GPU 的梯度进行 all-reduce 平均,保证参数同步。

Sources: ppo.py, ppo.py

日志记录与定期保存

策略更新完成后,运行器将损失字典、学习率、动作标准差、采集耗时与更新耗时提交给日志器。日志器把标量写入外部服务,并在控制台打印训练摘要。若当前迭代满足 save_interval 的整除条件,运行器立即执行 save 将当前模型写出到磁盘。整个训练流程结束后,无论是否满足间隔条件,运行器都会强制保存最终模型,防止训练结果丢失。

Sources: on_policy_runner.py, logger.py

阶段三:状态持久化

运行器提供了 saveload 两个对称接口,用于训练状态的全量持久化与恢复。标准检查点字典包含 model_state_dictoptimizer_state_dict 与当前迭代号 iter。当启用 RND 时,检查点额外追加 rnd_state_dictrnd_optimizer_state_dictAMPRunner 进一步扩展了检查点格式,纳入 AMP 判别器网络、其观测归一化器与判别器优化器的状态字典。

load 方法在恢复模型参数后,通过 policy.load_state_dict 的返回值 resumed_training 判断网络结构是否兼容。仅当兼容时,才恢复优化器状态与学习迭代号,避免从异构检查点恢复时发生状态污染。map_location 参数支持跨设备加载(例如从 GPU 训练环境加载到 CPU 推理环境)。

Sources: on_policy_runner.py, amp_runner.py

阶段四:推理部署

训练完成后,运行器通过 get_inference_policy 提供推理入口。该方法首先调用 eval_mode() 将策略(以及 RND、AMP 判别器等辅助模块)切换至评估模式,关闭 Dropout 与 BatchNorm 的训练统计更新;随后若指定了目标设备,则将策略网络迁移至该设备;最终返回 alg.policy.act_inference 方法句柄,供外部脚本直接调用。act_inference 与训练阶段的 act 不同,它通常返回确定性的均值动作而非采样动作,且不再记录值函数与 log-prob 等训练专用张量。

Sources: on_policy_runner.py

训练模式与评估模式

运行器显式维护 train_modeeval_mode 两个状态切换方法,以确保 PyTorch 模块在不同阶段处于正确的计算模式。基类中,这两个方法分别对策略网络与 RND 模块调用 .train().eval()AMPRunner 重写了这对方法,额外将 AMP 判别器及其观测归一化器纳入模式切换范围。这种显式管理避免了在 rollout 采集或推理阶段因忘记切换模式而导致的归一化统计污染或 Dropout 噪声问题。

Sources: on_policy_runner.py, amp_runner.py

下一步阅读建议

训练运行器生命周期管理处于 rsl_rl 基础设施层的核心位置,向上衔接算法实现,向下对接环境与日志。根据你的关注点,推荐继续深入以下主题: