在足式机器人的策略部署中,一个经典的矛盾是:训练时可以依赖全量 privileged 观测(如接触力、全局地形高度、机器人本体完整状态)以获得最优策略,但真机部署时传感器只能获取局部 onboard 观测(如 IMU、关节编码器、深度相机)。学生-教师蒸馏框架(Student-Teacher Distillation)正是为解决这一"感知鸿沟"而设计:先以 PPO 等算法训练一个使用全观测的教师策略,再通过行为克隆(Behavior Cloning)将教师的知识迁移到一个仅使用 onboard 观测的学生网络。与端到端的强化学习不同,蒸馏阶段不涉及奖励函数和价值网络,仅最小化学生与教师输出动作之间的差异,从而显著降低学生策略的输入维度与推理开销。
Sources: student_teacher.py, distillation.py
整体架构与模块关系
蒸馏框架在 RSL-RL 中由三个纵向层级构成:网络模块层定义学生与教师的网络结构;算法层实现行为克隆的损失计算与参数更新;运行层编排环境交互与训练循环。三者通过 DistillationRunner 统一调度,而底层数据流转则复用 RolloutStorage 的存储机制,仅针对蒸馏任务扩展了 privileged_actions 字段。
graph TB
subgraph 网络模块层
ST[StudentTeacher<br/>前馈网络]
STR[StudentTeacherRecurrent<br/>循环网络]
end
subgraph 算法层
D[Distillation<br/>行为克隆算法]
end
subgraph 运行层
DR[DistillationRunner<br/>训练循环]
end
subgraph 数据存储层
RS[RolloutStorage<br/>distillation模式]
end
Env[IsaacLab / MuJoCo<br/>向量环境] -->|obs: TensorDict| DR
DR -->|act / evaluate| D
D -->|调用 student / teacher| ST
D -->|调用 student / teacher| STR
D -->|add_transition| RS
RS -->|generator| D
style ST fill:#e1f5e1
style STR fill:#e1f5e1
style D fill:#fff2cc
style DR fill:#dae8fc
与 PPO 训练流程相比,蒸馏框架最大的架构差异在于消除了 Critic 网络和价值估计。算法不再依赖 GAE 优势估计或策略梯度,而是纯粹基于监督学习范式。这一简化使得蒸馏阶段的超参数空间大幅缩小,训练稳定性也显著提高。
Sources: distillation_runner.py, rollout_storage.py
观测分组:不对称感知的核心机制
蒸馏框架的前提是学生与教师看到不同的观测。RSL-RL 通过 obs_groups 配置字典将环境中的多个观测组(observation groups)映射到不同的观测集(observation sets)。对于蒸馏任务,obs_groups 至少包含两个键:"policy" 定义学生使用的观测组列表,"teacher" 定义教师使用的观测组列表。resolve_obs_groups 函数会在 runner 初始化阶段校验这些配置:若 teacher 键缺失,默认会回退到 policy 的观测组,但在蒸馏场景下通常需要显式配置以体现感知差异。
例如,一个典型的蒸馏配置将 onboard 传感器映射给学生,将 privileged 状态映射给教师:
obs_groups:
policy: ["policy"] # 学生:关节角度、IMU、速度指令
teacher: ["teacher"] # 教师:接触力、地形高度、全局状态
StudentTeacher 模块在初始化时会遍历这些列表,拼接对应观测组的张量维度,从而为 student 和 teacher 分别构建输入层。这种设计允许一个环境同时输出多组观测,而 runner 无需修改环境代码即可灵活重组输入。
Sources: utils.py, student_teacher.py
网络模块详解
StudentTeacher(前馈版本)
StudentTeacher 是蒸馏框架的基础网络模块,继承自 nn.Module。其内部同时维护两个独立的 MLP:student 接收 obs_groups["policy"] 的拼接观测,teacher 接收 obs_groups["teacher"] 的拼接观测。两者输出维度均为 num_actions,但仅有 student 的参数会在训练中被优化。模块还包含独立的学生/教师观测归一化器(EmpiricalNormalization),因为两组观测的统计分布通常差异极大,共享归一化器会导致一方失真。
值得特别关注的是 train() 方法的重载:当调用 policy.train() 时,父类会将所有子模块切换到训练模式,但该方法显式将 teacher 和 teacher_obs_normalizer 重新设为 eval(),确保教师网络在前向传播时保持推理统计量(如 BatchNorm 的运行均值,尽管当前 MLP 未使用 BatchNorm,但这是一种防御性设计)。
Sources: student_teacher.py
StudentTeacherRecurrent(循环版本)
对于需要历史记忆的策略(如处理延迟观测或时间序列特征),StudentTeacherRecurrent 在 student 侧引入了 Memory 模块(支持 LSTM 或 GRU)。教师侧可选择是否为循环网络,通过 teacher_recurrent 参数控制。当教师也为循环网络时,模块会额外实例化 memory_t;否则教师保持纯前馈结构。
循环版本对 hidden states 的管理更为精细:reset() 在环境 done 标志触发时重置记忆,detach_hidden_states() 在反向传播前截断梯度流以防止 BPTT 跨越 episode 边界。get_hidden_states() 返回二元组 (memory_s.hidden_state, memory_t.hidden_state),供 runner 在迭代边界保存和恢复记忆上下文。
Sources: student_teacher_recurrent.py
前馈与循环模块对比
| 特性 | StudentTeacher | StudentTeacherRecurrent |
|---|---|---|
| 循环记忆 | 无 | Student 必带 Memory;Teacher 可选 |
| RNN 类型 | — | lstm / gru |
| 隐藏状态维度 | — | rnn_hidden_dim(默认 256) |
| 层数 | — | rnn_num_layers(默认 1) |
is_recurrent |
False |
True |
reset() |
空操作 | 重置 student / teacher 记忆 |
get_hidden_states() |
(None, None) |
返回实际 hidden state 元组 |
Sources: student_teacher.py, student_teacher_recurrent.py
蒸馏算法:Distillation
Distillation 类是算法的核心,其实现极为简洁:一个优化器、一个损失函数、一个梯度累积机制。与 PPO 的复杂 surrogate loss 不同,蒸馏仅计算行为克隆损失:
$$\mathcal{L}{\text{behavior}} = \frac{1}{N} \sum{i=1}^{N} \text{loss_fn}\left( \pi_{\text{student}}(o_i), \pi_{\text{teacher}}(o_i) \right)$$
其中 loss_fn 支持 MSE 或 Huber 损失,通过 loss_type 配置选择。gradient_length 参数控制梯度累积步数:每累积 gradient_length 个 transition 的损失后才执行一次 optimizer.step(),这在 batch size 较小时有助于稳定梯度估计。
训练阶段的数据流
sequenceDiagram
participant Env as 向量环境
participant Alg as Distillation
participant ST as StudentTeacher
participant Storage as RolloutStorage
loop Rollout 阶段(num_steps_per_env)
Alg->>ST: act(obs) → student采样动作
Alg->>ST: evaluate(obs) → teacher推理动作(no_grad)
Alg->>Storage: add_transition(actions, privileged_actions)
Alg->>Env: step(actions)
end
loop Update 阶段(num_learning_epochs)
Storage->>Alg: generator() 遍历 transitions
Alg->>ST: act_inference(obs) → 学生确定性输出
Note over Alg: loss = MSE(act_inference, privileged_actions)
Alg->>Alg: 累积 gradient_length 步后 backward()
Alg->>Alg: clip_grad_norm_(student.parameters())
Alg->>Alg: optimizer.step()
end
注意一个关键细节:rollout 阶段调用 act() 时,学生网络会采样带有探索噪声的动作(通过 Normal 分布),并将该采样动作送入环境;而 evaluate() 直接返回教师的确定性输出(无探索噪声),作为监督标签存入 privileged_actions。在 update 阶段,算法调用的是 act_inference(),即学生的确定性输出,与教师标签计算损失。这种设计确保环境交互仍具有一定的随机性,而监督信号保持确定。
Sources: distillation.py
数据存储:RolloutStorage 的蒸馏模式
RolloutStorage 通过 training_type 参数区分 RL 与蒸馏两种存储模式。当 training_type == "distillation" 时,存储器额外分配 privileged_actions 张量,其形状与 actions 一致,用于保存教师输出。蒸馏模式不存储 value、log_prob、returns、advantages 等 RL 专属字段,从而节省显存。
蒸馏模式提供专用的 generator() 方法,按时间步顺序 yield (observations, actions, privileged_actions, dones)。与 RL 模式的 mini_batch_generator(随机打乱并生成 mini-batch)不同,蒸馏的 generator 保持时序顺序,这对于循环网络至关重要——RNN 的 hidden state 依赖时间连续性,随机打乱会破坏 episode 内部的状态传播逻辑。
Sources: rollout_storage.py
教师模型加载与状态恢复
蒸馏训练的第一个前提是将预训练的教师参数注入 teacher 网络。StudentTeacher.load_state_dict() 实现了智能的键名映射,支持从两种来源加载:
- 从 RL 训练 checkpoint 加载:PPO 训练保存的模型中,actor 网络的键名前缀为
actor.。load_state_dict检测到键名包含"actor"时,会自动将actor.替换为student_teacher中teacher的前缀,并将actor_obs_normalizer映射到teacher_obs_normalizer。加载完成后设置loaded_teacher = True,返回False表示这不是一次训练恢复,而是教师初始化。 - 从蒸馏 checkpoint 恢复:若键名包含
"student",说明这是一个之前蒸馏训练的断点,直接调用父类的load_state_dict加载 student 和 teacher 的全部参数,返回True表示训练恢复。
DistillationRunner.learn() 会在训练循环开始前强制检查 loaded_teacher 标志,若教师未加载则抛出 ValueError,防止无意义的随机初始化训练。
Sources: student_teacher.py, distillation_runner.py
DistillationRunner 与训练生命周期
DistillationRunner 继承自 OnPolicyRunner,复用了后者的环境交互循环、日志记录、模型保存与分布式训练基础设施。其定制化主要体现在三个覆写点:
首先,_get_default_obs_sets() 返回 ["teacher"],告知 resolve_obs_groups 必须确保 teacher 观测集已正确配置。其次,_construct_algorithm() 解析配置中的 StudentTeacher 或 StudentTeacherRecurrent 类,并实例化 RolloutStorage(..., training_type="distillation")。最后,learn() 增加了教师加载的前置校验。
在训练入口 train.py 中,当配置 agent_cfg.class_name == "DistillationRunner" 时,系统实例化蒸馏 runner,并且无论是否设置 resume 都会加载 checkpoint——因为蒸馏必须从一个预训练的教师模型开始。加载路径通过 get_checkpoint_path 解析,通常是之前 PPO 训练保存的最佳模型。
Sources: distillation_runner.py, train.py
关键配置参数
以下 YAML 片段展示了蒸馏任务的核心配置结构,参数含义见表格:
runner:
class_name: DistillationRunner
num_steps_per_env: 24
obs_groups:
policy: ["policy"]
teacher: ["teacher"]
policy:
class_name: StudentTeacher
student_hidden_dims: [256, 256, 256]
teacher_hidden_dims: [256, 256, 256]
student_obs_normalization: true
teacher_obs_normalization: true
algorithm:
class_name: Distillation
learning_rate: 1.0e-3
loss_type: "mse" # 或 "huber"
gradient_length: 15
max_grad_norm: 1.0
| 参数 | 所属层级 | 说明 |
|---|---|---|
class_name |
runner | 必须为 DistillationRunner |
obs_groups |
runner | 定义 "policy" 与 "teacher" 观测组映射 |
class_name |
policy | StudentTeacher 或 StudentTeacherRecurrent |
student_hidden_dims |
policy | 学生 MLP 隐藏层维度 |
teacher_hidden_dims |
policy | 教师 MLP 隐藏层维度 |
student_obs_normalization |
policy | 是否对学生观测做在线归一化 |
teacher_obs_normalization |
policy | 是否对教师观测做在线归一化 |
loss_type |
algorithm | 行为克隆损失:"mse" 或 "huber" |
gradient_length |
algorithm | 梯度累积步数,等效于 batch accumulation |
max_grad_norm |
algorithm | 仅对学生网络参数做梯度裁剪 |
learning_rate |
algorithm | 优化器学习率(默认 Adam) |
Sources: example_config.yaml, distillation.py
推理部署与模型导出
蒸馏完成后,部署时仅使用 student 网络进行推理。在 play.py 中,runner.get_inference_policy() 返回的是 alg.policy.act_inference,其实际调用的是学生网络的确定性前向传播,教师网络完全不被加载到推理路径中。
模型导出阶段,play.py 会检测策略是否包含 student_obs_normalizer,若存在则将其作为归一化器传入 export_policy_as_jit 和 export_policy_as_onnx。导出的 ONNX 或 TorchScript 模型仅包含学生网络及其归一化参数,输入维度与 onboard 观测一致,天然适合嵌入式部署。
Sources: play.py
分布式多卡训练
蒸馏框架继承了 RSL-RL 的分布式能力。Distillation 类在初始化时解析 multi_gpu_cfg,并在 update() 的 backward 后调用 reduce_parameters(),通过 torch.distributed.all_reduce 对所有 GPU 上的 student 梯度进行均值同步。由于 teacher 参数不参与梯度计算,分布式通信量仅为学生网络参数量,相比完整 Actor-Critic 的训练开销更低。
Sources: distillation.py
进一步阅读
学生-教师蒸馏框架通常建立在高质量的教师策略之上。若尚未完成教师策略的训练,建议先阅读 PPO 核心算法与超参数 与 Actor-Critic 网络架构详解。对于循环网络变体,理解 On-Policy Runner 训练循环 中的 RNN hidden state 管理机制将大有裨益。完成蒸馏后,可参考 策略测试与 Sim2Sim 部署 与 MuJoCo Sim2Sim 部署与真机迁移 将学生策略导出并部署到真机。