🤖 roboto_origin_03 Wiki
首页 / RSL-RL / 模型保存、加载与推理部署

本文档聚焦 rsl_rl 框架中神经网络策略从训练到部署的完整生命周期管理,涵盖 Checkpoint 的自动保存策略、跨场景加载恢复机制,以及面向生产环境的推理接口设计。理解这些机制对于实现训练断点续训、教师模型蒸馏初始化、以及将策略网络无缝迁移到真实机器人部署至关重要。

Checkpoint 保存机制与文件结构

框架通过运行器(Runner)层统一封装模型持久化逻辑。OnPolicyRunnerAMPRunner 均实现了 save() 方法,在训练循环中按配置间隔自动触发,并在训练结束时保存最终模型。保存文件采用 PyTorch 原生 torch.save 序列化,扩展名为 .pt,默认存放于日志目录下,命名格式为 model_<iteration>.pt。基础 OnPolicyRunner 保存的核心字段包括策略网络的 state_dict、优化器状态、当前迭代计数与附加元信息;若启用 RND 探索奖励,则额外追加 RND 网络及其优化器状态。AMPRunner 作为子类进一步扩展了 AMP 判别器、其观测归一化层以及判别器优化器的状态。保存完成后,运行器会调用日志系统的 save_model() 方法,当使用 Weights & Biases 或 Neptune 时自动将模型文件上传至云端实验追踪服务。

保存字段 OnPolicyRunner AMPRunner 说明
model_state_dict 策略网络(Actor-Critic)参数
optimizer_state_dict PPO 策略优化器状态
iter 当前训练迭代数
infos 附加元信息字典
rnd_state_dict 条件 条件 RND 网络参数(若启用)
rnd_optimizer_state_dict 条件 条件 RND 优化器状态(若启用)
amp_discriminator_state_dict AMP 判别器参数
amp_discriminator_normalizer_state_dict 判别器观测归一化统计量
amp_discriminator_optimizer_state_dict 判别器优化器状态

在分布式多 GPU 训练中,框架通过 Logger.disable_logs 标志确保仅全局 rank 0 进程执行磁盘写入与云端上传,避免冗余 I/O 与存储冲突。训练循环内的保存触发条件为 it % self.cfg["save_interval"] == 0,其中 save_interval 由顶层训练配置决定。

Sources: on_policy_runner.py, amp_runner.py, logger.py

模型加载与训练恢复

load() 方法提供了从 Checkpoint 恢复训练状态的完整能力。该方法首先通过 torch.load 反序列化文件字典,随后将 model_state_dict 注入策略网络。这里存在一个关键的设计模式:load() 并非直接调用 PyTorch 的 load_state_dict,而是委托给策略模块自身的 load_state_dict() 方法,并捕获其返回值 resumed_training。该布尔标志决定了后续是否恢复优化器状态与学习迭代计数,这一设计主要服务于蒸馏场景——当加载的是预训练教师模型(仅用于初始化,不恢复训练)时,返回 False 以阻止优化器覆盖。

对于标准的 ActorCriticActorCriticRecurrentActorCriticCNNActorCriticAttnEncload_state_dict() 直接代理父类实现并返回 True,表示可以继续恢复训练。而 StudentTeacher 模块则实现了更复杂的参数路由逻辑:若 Checkpoint 键名中包含 "actor."(表明来源为 RL 训练),则将 Actor 参数重映射后加载至 teacher 网络,并返回 False,同时将 loaded_teacher 标志置为真;若键名中包含 "student."(来源为蒸馏训练),则按常规模块加载并返回 True。这种设计使得同一个 load() 接口既能支持蒸馏前的教师预热加载,也能支持蒸馏过程本身的断点续训。

load() 还支持 map_location 参数,允许 Checkpoint 在不同硬件设备间迁移(例如从 GPU 训练节点加载到 CPU 推理节点),以及 load_optimizer 开关以控制是否恢复优化器动量状态。

Sources: on_policy_runner.py, actor_critic.py, student_teacher.py

推理部署接口

框架将推理部署抽象为获取一个无状态可调用对象的过程。OnPolicyRunner.get_inference_policy() 完成三项准备工作:将策略切换至 eval() 模式(禁用 Dropout 等训练专属行为)、可选地将模型迁移到目标设备,最后返回策略模块的 act_inference 方法句柄。调用方可直接将该句柄传入环境交互循环,无需再感知运行器或算法细节。

policy = runner.get_inference_policy(device="cuda:0")
actions = policy(observations)  # 直接输出确定性动作

推理与训练采样的核心差异体现在 act_inference()act() 的实现上。训练阶段的 act() 会更新动作分布并从中采样,引入策略噪声以保障探索;而 act_inference() 直接输出分布的均值(确定性动作)。以基础 ActorCritic 为例,act_inference 先提取 Actor 观测、经过归一化,再前向传播 Actor MLP,若采用状态依赖型标准差(state_dependent_std),则截取输出张量的均值通道返回。对于循环网络变体 ActorCriticRecurrentact_inference 会自动驱动 memory_a 的隐状态更新,无需调用方手动管理 hidden_statemasks,这极大简化了循环策略的部署复杂度。CNN 变体在推理时同样完成 2D 观测的编码拼接,注意力编码器变体 ActorCriticAttnEnc 还额外支持 return_attention 参数以返回注意力权重图,便于可视化分析。

训练阶段的 Rollout 环节本身也大量复用了推理路径:代码通过 torch.inference_mode() 上下文禁用梯度计算,在环境中执行 self.alg.act(obs) 收集经验,这与部署阶段的推理模式在性能特征上高度一致,仅存在“采样 vs 取均值”的语义差异。

Sources: on_policy_runner.py, actor_critic.py, actor_critic_recurrent.py, actor_critic_cnn.py, actor_critic_attn_enc.py

多 GPU 参数同步与云端备份

在多 GPU 分布式训练中,框架不仅涉及保存加载,还依赖运行时参数同步机制保障各进程模型一致性。PPO.broadcast_parameters() 在训练开始前由 rank 0 将策略 state_dict(以及 RND 预测器参数)广播至所有进程,随后各进程加载该字典,确保初始权重完全一致。梯度同步则由 reduce_parameters() 负责,在各进程反向传播后执行 all-reduce 求平均,保证多卡等价于大单卡 batch 训练。

云端备份方面,Logger 通过策略模式支持 TensorBoard、W&B、Neptune 三种后端。仅当后端为 wandbneptune 且当前进程为主进程时,save_model() 才会将 .pt 文件上传,实现模型版本与实验指标的自动关联。

Sources: ppo.py, logger.py, wandb_utils.py, neptune_utils.py

典型使用场景

以下流程图展示了从训练到部署的典型 checkpoint 生命周期:

flowchart TD
    A[训练启动] --> B{是否加载 checkpoint?}
    B -->|是| C[runner.load(path)]
    C --> D[根据 resumed_training 恢复优化器与迭代数]
    B -->|否| E[初始化新模型]
    E --> F[训练循环]
    D --> F
    F --> G{it % save_interval == 0?}
    G -->|是| H[runner.save(path)]
    H --> I[本地磁盘 .pt]
    H --> J[可选: W&B/Neptune 上传]
    G -->|否| F
    F --> K[训练结束]
    K --> L[保存最终模型]
    L --> M[推理部署]
    M --> N[runner.get_inference_policy]
    N --> O[env.step(policy(obs))]

断点续训模式:调用 runner.load("model_1000.pt"),框架自动恢复网络权重、优化器状态并将 current_learning_iteration 设为 1000,下一次 learn() 将从该迭代继续计数。

蒸馏初始化模式:调用 runner.load("teacher_model.pt") 加载一个标准 PPO 训练产出的 checkpoint 到 DistillationRunner。由于 StudentTeacher.load_state_dict 检测到 "actor." 前缀,参数被路由至教师网络,返回 False,因此优化器与迭代计数不会被覆盖,随后即可开始蒸馏训练。

纯推理部署模式:无需构造完整运行器,只需实例化策略类并直接加载 state_dict 中的 "model_state_dict",或复用 get_inference_policy() 获取闭包。务必确保推理时观测经过与训练时相同的归一化预处理。

Sources: on_policy_runner.py, distillation_runner.py

下一步阅读

掌握模型生命周期管理后,建议继续阅读以下相关主题以完善对整个训练基础设施的理解: