🤖 roboto_origin_03 Wiki
首页 / RSL-RL / 策略蒸馏与师生框架

rsl_rl 的策略蒸馏模块实现了一套经典的教师-学生(Teacher-Student)框架,用于将拥有**特权观测(privileged observations)的教师策略知识迁移到只能访问部分观测(partial observations)**的学生策略中。该框架由 DistillationRunnerDistillation 算法以及 StudentTeacher 网络家族组成,与 PPO 等强化学习算法共享底层基础设施,但将学习目标从奖励最大化转变为行为克隆(Behavior Cloning)。在机器人学习场景中,这一设计尤为关键:训练阶段可以利用全状态信息(如接触力、地形高度)训练高性能教师,而部署阶段的学生仅需依赖机载传感器即可复现教师行为。

Sources: distillation_runner.py distillation.py student_teacher.py

核心设计动机

策略蒸馏解决的是观测不对称带来的部署难题。在强化学习训练环境中,系统往往能够提供完整的物理状态作为观测输入,但真实机器人搭载的传感器存在天然局限。如果直接在有限观测上训练策略,性能通常显著低于全状态策略。师生框架的核心思想是:先用标准 RL 算法(如 PPO)训练一个可以访问特权观测的教师策略,再通过监督学习让只能访问部分观测的学生策略模仿教师的输出动作。这样一来,学生策略在部署时无需特权信息,却能继承教师策略的决策质量。

Sources: student_teacher.py

架构全景

整个蒸馏框架由运行器、算法、网络与存储四层协同构成。DistillationRunner 继承自 OnPolicyRunner,复用了向量化环境的 Rollout 循环与日志系统;Distillation 算法替代了 PPO,负责行为克隆损失的计算与优化;StudentTeacher 模块内部并联了两个独立网络,分别对应学生与教师;RolloutStorage 则为蒸馏场景扩展了存储教师标签的能力。

classDiagram
    class OnPolicyRunner {
        +learn()
        +save()
        +load()
        +_construct_algorithm()
    }
    class DistillationRunner {
        +learn()
        +_construct_algorithm()
        +_get_default_obs_sets()
    }
    class Distillation {
        +policy
        +optimizer
        +act()
        +process_env_step()
        +update()
        +compute_returns()
        +broadcast_parameters()
        +reduce_parameters()
    }
    class StudentTeacher {
        +student: MLP
        +teacher: MLP
        +student_obs_normalizer
        +teacher_obs_normalizer
        +act()
        +evaluate()
        +act_inference()
        +load_state_dict()
    }
    class StudentTeacherRecurrent {
        +memory_s: Memory
        +memory_t: Memory
        +teacher_recurrent: bool
    }
    class RolloutStorage {
        +privileged_actions
        +generator()
    }
    
    OnPolicyRunner <|-- DistillationRunner
    DistillationRunner --> Distillation
    Distillation --> StudentTeacher
    StudentTeacher <|-- StudentTeacherRecurrent
    Distillation --> RolloutStorage

Sources: distillation_runner.py distillation.py student_teacher.py rollout_storage.py

师生网络架构

观测分组与网络并联

StudentTeacher 并非单一神经网络,而是在同一个模块内并联了两个独立的前馈网络studentteacher。两者均由 MLP 构成,但输入维度取决于 obs_groups 配置。obs_groups 是一个字典,必须包含 "policy""teacher" 两个键,分别列出学生与教师各自使用的观测组名称。模块通过 get_student_obs()get_teacher_obs() 从输入的 TensorDict 中提取并拼接对应观测,使得教师可以访问比学生更丰富的环境信息。例如,教师可能包含地形高度图和接触力,而学生仅包含 IMU 读数和关节位置。

Sources: student_teacher.py student_teacher.py

教师冻结与训练模式

教师网络在初始化后始终处于评估模式。train() 方法被显式重写:在调用父类将模块设为训练模式后,会立即将 teacherteacher_obs_normalizer 切换为 eval(),确保教师参数不会在蒸馏过程中更新。学生网络及其归一化器则正常参与训练。这种设计保证了蒸馏过程是纯粹的单向知识迁移,而非师生联合优化。

Sources: student_teacher.py

循环变体与记忆管理

StudentTeacherRecurrent 为学生网络引入了 Memory 循环记忆模块(支持 LSTM 与 GRU)。学生的观测首先经过循环网络编码,再将隐状态馈入后续的 student MLP。此外,该变体通过 teacher_recurrent 参数可选地为教师配置循环记忆,使其能够加载来自 ActorCriticRecurrent 训练得到的教师权重。循环版本的 hidden state 管理遵循与 PPO 循环策略相同的约定:在环境终止(dones)时重置对应环境的隐状态,在反向传播前执行 detach_hidden_states() 截断梯度流,并在跨 epoch 时通过 last_hidden_states 保存与恢复。

Sources: student_teacher_recurrent.py student_teacher_recurrent.py student_teacher_recurrent.py

蒸馏算法与训练循环

算法核心差异

Distillation 类替代了 PPO 在训练循环中的位置,但其内部逻辑显著简化。与 PPO 的核心差异体现在三个方面:第一,compute_returns() 为空操作,因为蒸馏属于监督学习,不需要计算折扣回报或 GAE 优势;第二,损失函数仅包含 behavior_loss,通过 MSE 或 Huber 损失衡量学生输出与教师标签的差异;第三,动作标签来源于教师在 act() 阶段通过 evaluate() 生成的确定性输出,并以 privileged_actions 存入 Transition

Sources: distillation.py distillation.py

Rollout 阶段的数据采集

DistillationRunnerlearn() 循环中,每一步的环境交互仍沿用 OnPolicyRunner 的标准流程:调用 alg.act(obs) 获取动作并执行环境步进。不同之处在于,act() 内部除了让学生网络采样动作外,还会同步调用 policy.evaluate(obs) 生成教师标签,两者一并存入 Transition。随后 process_env_step() 将完整 transition 推入 RolloutStorage,并更新学生观测归一化器。

Sources: distillation.py on_policy_runner.py

更新阶段的行为克隆

update() 方法按照 num_learning_epochs 遍历 storage.generator()。对于每一个时间步,学生通过 act_inference() 进行保留梯度的前向计算,教师标签则从存储中直接读取。损失按 gradient_length 进行累积,达到阈值后执行一次反向传播与优化器步进。这种梯度累积机制允许在显存受限的情况下模拟更大的 batch size。梯度裁剪仅作用于 self.policy.student.parameters(),避免误操作已冻结的教师网络。更新结束后,存储被清空,last_hidden_states 被保存以供下一轮迭代使用。

Sources: distillation.py

训练数据流序列

sequenceDiagram
    participant Env as 向量化环境
    participant Runner as DistillationRunner
    participant Alg as Distillation
    participant Policy as StudentTeacher
    participant Storage as RolloutStorage
    
    loop Rollout 阶段 (num_steps_per_env)
        Runner->>Alg: act(obs)
        Alg->>Policy: act(student_obs) → 学生采样动作
        Alg->>Policy: evaluate(teacher_obs) → 教师生成标签
        Alg->>Storage: add_transition(actions, privileged_actions)
        Runner->>Env: step(actions)
        Env-->>Runner: obs, rewards, dones
        Runner->>Alg: process_env_step(obs, rewards, dones, extras)
    end
    
    Runner->>Alg: update()
    loop Epochs × Transitions
        Alg->>Storage: generator()
        Storage-->>Alg: obs, actions, privileged_actions, dones
        Alg->>Policy: act_inference(obs) → 学生预测
        Alg->>Alg: loss_fn(学生预测, privileged_actions)
        note over Alg: 累积 gradient_length 步后<br/>反向传播 + 优化器 step
    end

Sources: distillation.py rollout_storage.py

模型加载与状态迁移

load_state_dict() 是师生框架与 PPO 训练管线衔接的关键桥梁。它智能识别两种来源的检查点,实现无缝的状态迁移:

DistillationRunner.learn() 在启动训练前会强制断言 loaded_teacher 标志,若教师未加载则立即抛出异常,防止在无监督信号的情况下进行无效训练。

Sources: student_teacher.py student_teacher_recurrent.py distillation_runner.py

存储层适配

RolloutStorage 为蒸馏类型进行了专门扩展。当 training_type == "distillation" 时,缓冲区会额外分配 privileged_actions 张量,并在 add_transition() 中持久化教师标签。generator() 方法按时间步顺序依次产出 (observations, actions, privileged_actions, dones),供蒸馏更新使用。这与 RL 类型的 mini_batch_generator(随机打乱索引)形成鲜明对比,因为循环网络必须保持时间顺序以正确维护隐状态传播。

Sources: rollout_storage.py rollout_storage.py

多 GPU 分布式支持

Distillation 完整继承了 OnPolicyRunner 的分布式训练基础设施。broadcast_parameters() 在训练开始前将模型参数从主 GPU 广播到所有进程,确保各 rank 从一致的初始状态出发;reduce_parameters() 在反向传播后通过 torch.distributed.all_reduce 对各 GPU 上的梯度求平均。由于教师网络处于冻结状态,其参数不会生成梯度,因此梯度同步与裁剪实际上仅影响学生网络,既保证了分布式训练的一致性,也避免了不必要的通信开销。

Sources: distillation.py distillation.py

配置参数速查

参数 所属配置 类型 默认值 说明
num_learning_epochs algorithm int 1 每个学习迭代中对整个 buffer 遍历的轮数
gradient_length algorithm int 15 梯度累积步数,每 N 步执行一次优化器 step
learning_rate algorithm float 1e-3 Adam 优化器学习率
max_grad_norm algorithm float / None None 仅对学生网络参数进行梯度裁剪的阈值
loss_type algorithm str "mse" 行为克隆损失类型,可选 "mse""huber"
optimizer algorithm str "adam" 优化器名称
class_name policy str 策略类名,如 "StudentTeacher""StudentTeacherRecurrent"
obs_groups 全局 dict 必须包含 "policy""teacher" 键的观测分组
student_obs_normalization policy bool False 是否对学生观测启用经验归一化
teacher_obs_normalization policy bool False 是否对教师观测启用经验归一化
student_hidden_dims policy list[int] [256,256,256] 学生 MLP 隐藏层维度
teacher_hidden_dims policy list[int] [256,256,256] 教师 MLP 隐藏层维度
rnn_type policy str "lstm" 循环变体的 RNN 类型("lstm" / "gru"
teacher_recurrent policy bool False 教师是否为循环网络(用于加载循环教师权重)

Sources: distillation.py student_teacher.py student_teacher_recurrent.py

与 PPO 的关键差异

维度 PPO Distillation
学习目标 奖励最大化(策略梯度) 行为克隆(监督学习)
网络结构 Actor + Critic Student + Teacher(教师冻结)
回报计算 GAE / 折扣回报 不需要(compute_returns 为空)
损失组成 策略损失 + 值损失 + 熵损失 behavior_loss(MSE / Huber)
动作标签来源 自身采样动作 教师网络 evaluate() 输出
观测分组 actor_obs / critic_obs policy_obs / teacher_obs
存储生成器 mini_batch_generator(随机打乱) generator()(按时间顺序)

Sources: ppo.py distillation.py rollout_storage.py

延伸阅读与下一步