🤖 roboto_origin_03 Wiki
首页 / RSL-RL / 配置文件与参数体系

在强化学习实验中,超参数的选择往往决定了训练成败。rsl_rl 采用了一套层次清晰、扩展友好的配置体系:你只需要向 Runner 提供一个 Python 字典,框架便会自动完成类名解析、观测维度推断、网络构建和算法初始化。本文将带你从顶层到底层,系统性地理解这套参数体系的结构、语义与工作机制。

配置的整体架构

rsl_rl 的配置以单个字典 train_cfg 作为唯一入口,由 Runner 在初始化时将其拆解为三个逻辑层级:运行控制层(Runner 自身使用)、策略网络层policy)和算法层algorithm)。此外,一个名为 obs_groups 的子字典负责告诉框架:环境中的哪些观测应该送给 Actor,哪些应该送给 Critic,以及高级模块(如 RND、AMP)需要订阅哪部分状态。

以下 Mermaid 图展示了配置从用户代码到运行时组件的流向。整个过程中,框架会调用多个 resolve_* 函数对原始配置进行校验、补全和类型解析。

flowchart TD
    A[用户提供的 train_cfg 字典] --> B[OnPolicyRunner]
    B --> C[抽取 policy_cfg]
    B --> D[抽取 alg_cfg]
    B --> E[抽取 obs_groups]
    E --> F[resolve_obs_groups<br/>校验与填充默认值]
    F --> G[策略网络初始化<br/>ActorCritic / CNN / Recurrent]
    C --> G
    D --> H[resolve_rnd_config<br/>resolve_symmetry_config<br/>resolve_amp_config]
    H --> I[算法初始化<br/>PPO / PPOAMP / Distillation]
    G --> I
    I --> J[开始训练循环]

Sources: on_policy_runner.py, utils.py

顶层配置项一览

在根级字典中,除了 policyalgorithm 两个核心命名空间外,还有若干直接控制训练流程的字段。下表总结了初学者最需要关注的顶层键:

配置键 类型 默认值/典型值 说明
num_steps_per_env int 24 每次策略更新前,每个向量化环境收集的转移步数
save_interval int 100 每隔多少迭代保存一次模型权重
obs_groups dict 见下文 观测分组映射,定义策略、Critic 等模块看到的观测子集
policy dict 策略网络配置命名空间
algorithm dict 算法超参数命名空间
empirical_normalization bool 已废弃 旧版参数,现应分别写入 policy.actor_obs_normalizationpolicy.critic_obs_normalization

Sources: on_policy_runner.py, on_policy_runner.py

观测分组配置 obs_groups

向量化环境返回的观测通常以 TensorDict 形式组织,内部包含多个语义不同的观测组(例如 base_lin_velprojected_gravitycommands 等)。obs_groups 的作用就是建立观测组 -> 观测集合的映射,让不同网络只接收它们需要的信息。

框架要求至少定义 policy 集合;对于 criticrnd_stateteacher 等保留名称,若未显式定义,resolve_obs_groups 会按以下规则自动补全:先检查环境中是否存在同名观测组,若存在则直接采用,否则复用 policy 的观测列表。

flowchart LR
    subgraph 环境输出
        A1[base_lin_vel]
        A2[projected_gravity]
        A3[joint_pos]
        A4[height_scan]
    end
    subgraph obs_groups 映射
        B1[policy] --> A1 & A2 & A3
        B2[critic] --> A1 & A2 & A3 & A4
        B3[rnd_state] --> A2 & A3
    end

一个典型的 obs_groups 配置如下:

obs_groups = {
    "policy": ["base_lin_vel", "projected_gravity", "joint_pos"],
    "critic": ["base_lin_vel", "projected_gravity", "joint_pos", "height_scan"],
    "rnd_state": ["projected_gravity", "joint_pos"],
}

Sources: utils.py, vec_env.py

策略网络配置 policy

policy 字典决定使用哪类策略网络及其结构尺寸。rsl_rl 支持通过字符串 class_name 动态解析类,这意味着你无需在代码中显式 import,只需写出类名或完整模块路径即可。

配置键 类型 典型值 说明
class_name str "ActorCritic" 策略类名,支持 "ActorCriticRecurrent""ActorCriticCNN""ActorCriticAttnEnc"
actor_hidden_dims list[int] [256, 256, 256] Actor(策略)MLP 的隐藏层维度
critic_hidden_dims list[int] [256, 256, 256] Critic(价值)MLP 的隐藏层维度
activation str "elu" 激活函数,可选 relutanhselugelu
actor_obs_normalization bool False 是否对 Actor 输入进行在线经验归一化
critic_obs_normalization bool False 是否对 Critic 输入进行在线经验归一化
init_noise_std float 1.0 动作分布初始标准差
noise_std_type str "scalar" 标准差参数化方式:scalar 直接学习标准差,log 学习对数标准差
state_dependent_std bool False 是否让标准差依赖于状态(网络输出两组头)

框架通过 resolve_callableclass_name 字符串解析为实际的 Python 类。你可以直接写简单名(如 "ActorCritic"),也可以写完整模块路径(如 "my_module.policies:CustomActorCritic"),这为用户自定义策略提供了极大的灵活性。

Sources: actor_critic.py, on_policy_runner.py, utils.py

算法配置 algorithm

algorithm 字典存放与训练更新直接相关的超参数。对于标准的 PPO 训练,核心字段如下:

配置键 类型 典型值 说明
class_name str "PPO" 算法类名,AMP 场景下使用 "PPOAMP",蒸馏场景使用 "Distillation"
num_learning_epochs int 5 每次收集数据后,策略更新的轮数(epoch)
num_mini_batches int 4 每次更新将数据切分成多少个小批量
clip_param float 0.2 PPO 裁剪阈值 $\epsilon$
gamma float 0.99 折扣因子
lam float 0.95 GAE 广义优势估计参数 $\lambda$
value_loss_coef float 1.0 价值函数损失权重
entropy_coef float 0.01 策略熵正则化系数
learning_rate float 3e-4 Adam 优化器初始学习率
max_grad_norm float 1.0 梯度裁剪阈值
use_clipped_value_loss bool True 是否对价值函数也使用裁剪损失
schedule str "adaptive" 学习率调度方式:adaptive 根据 KL 散度自动调整,fixed 保持固定
desired_kl float 0.01 自适应学习率调度期望的 KL 散度目标值
normalize_advantage_per_mini_batch bool False 是否在每轮小批量内部重新归一化优势
enable_aux_loss bool False 是否启用策略网络辅助损失
aux_loss_coef float 0.0 辅助损失系数

Sources: ppo.py, ppo.py

高级特性子配置

除了基础 PPO 参数外,algorithm 字典还可以包含若干可选子字典,用于启用探索增强、对称先验或动作风格模仿。

RND 随机网络蒸馏配置 rnd_cfg

rnd_cfg 不为 None 时,框架会自动实例化 RandomNetworkDistillation 模块,将内在奖励叠加到外在奖励上。

配置键 类型 说明
predictor_hidden_dims list[int] 预测网络隐藏层
target_hidden_dims list[int] 目标网络隐藏层
num_outputs int 嵌入维度
activation str 网络激活函数
weight float 内在奖励缩放权重(会被自动乘以 step_dt
state_normalization bool 是否对 RND 输入状态做归一化
reward_normalization bool 是否对内在奖励做归一化
weight_schedule dict 权重调度器,支持 constantsteplinear 模式

框架在 resolve_rnd_config 中会根据 obs_groups["rnd_state"] 自动计算输入维度并注入配置,因此你无需手动填写 num_states

Sources: rnd.py, rnd.py

对称性增强配置 symmetry_cfg

用于利用运动对称性进行数据增强或镜像损失约束。

配置键 类型 说明
use_data_augmentation bool 是否在训练时对批量数据做左右镜像增强
use_mirror_loss bool 是否添加镜像一致性损失
data_augmentation_func str/callable 镜像变换函数,支持字符串名或直接传入 callable
mirror_loss_coeff float 镜像损失权重

Sources: symmetry.py, ppo.py

AMP 对抗动作先验配置 amp_cfg

AMP 算法需要额外的判别器配置。由于 AMP 观测通常是带历史帧的三维张量,resolve_amp_config 会自动推断 disc_obs_stepsdisc_obs_dim

配置键 类型 说明
disc_learning_rate float 判别器 Adam 学习率
disc_trunk_weight_decay float 判别器主干 L2 正则化
disc_linear_weight_decay float 判别器输出层 L2 正则化
disc_max_grad_norm float 判别器梯度裁剪阈值
loss_type str 判别器损失类型:GANLSGANWGAN
style_reward_scale float 风格奖励缩放系数
task_style_lerp float 任务奖励与风格奖励的线性插值权重
amp_discriminator dict 判别器网络结构子配置(hidden_dimsactivation

Sources: amp.py, amp.py, ppo_amp.py

配置解析的生命周期

理解配置何时被修改、何时被消费,有助于你在遇到错误时快速定位问题。整个生命周期可分为四个阶段:

sequenceDiagram
    participant U as 用户代码
    participant R as OnPolicyRunner
    participant Res as resolve_* 函数
    participant Alg as 算法实例

    U->>R: train_cfg (原始字典)
    R->>R: 抽取 policy / algorithm / obs_groups
    R->>Res: resolve_obs_groups()
    Res-->>R: 补全 critic 等默认观测集
    R->>Res: resolve_rnd_config()
    Res-->>R: 注入 num_states, weight*dt
    R->>Res: resolve_symmetry_config()
    Res-->>R: 注入 _env 对象
    R->>Res: resolve_amp_config()
    Res-->>R: 注入 disc_obs_dim, disc_obs_steps
    R->>Res: resolve_callable(class_name)
    Res-->>R: 返回实际 Python 类
    R->>Alg: 实例化 policy + algorithm

需要特别注意的是,resolve_callable 会在解析类名的同时弹出(pop) class_name 字段,因此原始字典会被修改。如果你需要在训练结束后复用同一个配置字典构建其他对象,建议先进行深拷贝。

Sources: on_policy_runner.py, utils.py

完整配置示例

下面给出一份面向足式机器人 locomotion 任务的典型 PPO 配置。你可以将其保存为 YAML 或直接在 Python 代码中以字典形式书写:

train_cfg = {
    "num_steps_per_env": 24,
    "save_interval": 500,
    "obs_groups": {
        "policy": ["base_lin_vel", "projected_gravity", "joint_pos", "joint_vel", "commands"],
        "critic": ["base_lin_vel", "projected_gravity", "joint_pos", "joint_vel", "commands", "height_scan"],
    },
    "policy": {
        "class_name": "ActorCritic",
        "actor_hidden_dims": [512, 256, 128],
        "critic_hidden_dims": [512, 256, 128],
        "activation": "elu",
        "actor_obs_normalization": True,
        "critic_obs_normalization": True,
        "init_noise_std": 1.0,
        "noise_std_type": "scalar",
        "state_dependent_std": False,
    },
    "algorithm": {
        "class_name": "PPO",
        "num_learning_epochs": 5,
        "num_mini_batches": 4,
        "clip_param": 0.2,
        "gamma": 0.99,
        "lam": 0.95,
        "value_loss_coef": 1.0,
        "entropy_coef": 0.01,
        "learning_rate": 3e-4,
        "max_grad_norm": 1.0,
        "use_clipped_value_loss": True,
        "schedule": "adaptive",
        "desired_kl": 0.01,
        "normalize_advantage_per_mini_batch": False,
        # 高级特性(可选)
        "rnd_cfg": None,
        "symmetry_cfg": None,
    },
}

Sources: on_policy_runner.py

常见配置错误与排查

初学者在调整配置时,经常会遇到以下三类错误:

错误现象 根因 解决方法
ValueError: 'policy' key must be in obs_groups 忘记定义 policy 观测集,或拼写错误 确保 obs_groups 包含 "policy"
ValueError: Observation 'xxx' not found obs_groups 中引用的观测组不存在于环境返回的 TensorDict 打印 env.get_observations().keys() 核对可用观测名
Could not resolve 'class_name' class_name 字符串拼写错误,或自定义类未安装到 Python 路径 检查类名拼写;若是自定义类,使用完整模块路径 "module.path:ClassName"
The RND module only supports 1D observations RND 的 rnd_state 观测集中包含了多维观测(如图像、历史帧) 为 RND 单独选择一维向量观测组
Symmetry augmentation is not supported for recurrent policies ActorCriticRecurrent 启用了对称数据增强 对称增强仅支持前馈策略,循环策略需关闭 use_data_augmentation

Sources: utils.py, ppo.py, rnd.py

推荐阅读顺序

掌握配置体系后,你可以根据实验需求继续深入以下专题: