🤖 roboto_origin_03 Wiki
首页 / RSL-RL / Actor-Critic 基础架构设计

本文档聚焦 rsl_rl 框架中最核心的策略网络基础设施:ActorCritic 基类及其配套组件。作为连接环境观测与 PPO 算法的桥梁,该架构不仅实现了经典的策略-值函数分离设计,还通过观测分组(obs_groups)机制允许 Actor 与 Critic 接收差异化的输入,同时以模块化的网络组件支撑从纯 MLP 到 CNN、RNN、注意力编码器等多元变体的平滑扩展。理解这一层的设计决策,是掌握后续训练流程与高级变体的先决条件。

Sources: actor_critic.py

整体架构概览

rsl_rl 的 Actor-Critic 架构遵循策略网络(Actor)与值函数网络(Critic)分离的经典范式,但在此基础上引入了两项关键设计:一是通过 TensorDictobs_groups 实现观测的语义化解耦,使策略与值函数可以消费不同的观测子集;二是将网络构建、分布生成、归一化更新封装在同一模块内,形成对外统一、对内高内聚的接口。下图展示了核心类之间的静态结构关系:

classDiagram
    class ActorCritic {
        +bool is_recurrent = False
        +MLP actor
        +MLP critic
        +nn.Module actor_obs_normalizer
        +nn.Module critic_obs_normalizer
        +Normal distribution
        +Tensor~Dict~ obs_groups
        +act(obs) Tensor
        +act_inference(obs) Tensor
        +evaluate(obs) Tensor
        +update_normalization(obs) void
        -_update_distribution(obs) void
    }

    class MLP {
        +__init__(input_dim, output_dim, hidden_dims, activation)
        +init_weights(scales) void
    }

    class EmpiricalNormalization {
        +Tensor _mean
        +Tensor _var
        +Tensor _std
        +update(x) void
        +forward(x) Tensor
    }

    class TensorDict {
        +观测字典
    }

    ActorCritic --> MLP : actor / critic
    ActorCritic --> EmpiricalNormalization : 可选归一化
    ActorCritic --> TensorDict : 输入观测
    MLP --|> nn.Sequential

ActorCritic 直接继承自 torch.nn.Module,内部聚合两个 MLP 实例分别承担策略输出与值估计职责,并可选地嵌入 EmpiricalNormalization 模块进行输入归一化。所有变体(循环、CNN、注意力)均复用或重写上述核心交互模式,保持接口一致性。

Sources: actor_critic.py, mlp.py, normalization.py

观测分组机制:解耦 Actor 与 Critic 的输入空间

传统实现中 Actor 与 Critic 往往共享同一观测向量,这在很多任务中是次优的:值函数可能需要全局状态信息(如其他智能体的位置),而策略出于泛化性考虑仅需局部本体感知。rsl_rl 通过 obs_groups 机制显式解决了这一问题。obs_groups 是一个字典,键为观测集合名(如 "policy""critic"),值为环境观测组名的列表。运行时,框架调用 resolve_obs_groups() 校验并补全配置:若 critic 集合缺失,默认会回退到 policy 使用的观测组,或直接使用环境中同名的观测组。在 ActorCritic 初始化阶段,代码遍历 obs_groups["policy"]obs_groups["critic"] 中指定的组名,从输入 TensorDict 中读取对应张量并在最后一维拼接,从而分别得到 num_actor_obsnum_critic_obs。这种设计使得同一环境可以无缝输出结构化观测,而策略与值函数按需取用,无需修改环境代码。

Sources: actor_critic.py, utils.py, vec_env.py

ActorCritic 核心类解析

网络构建与参数体系

ActorCritic.__init__ 的参数体系围绕网络结构归一化策略探索噪声三个维度展开。网络结构侧,actor_hidden_dimscritic_hidden_dims 允许独立配置隐藏层尺寸,默认均为 [256, 256, 256]activation 则通过 resolve_nn_activation() 解析为 PyTorch 激活模块,默认采用 elu。归一化侧,actor_obs_normalizationcritic_obs_normalization 两个布尔开关分别控制是否在 Actor 与 Critic 输入前插入 EmpiricalNormalization。探索噪声侧,框架支持两种标准差建模方式:全局共享的标量/对数参数(state_dependent_std=False),或网络输出头中的状态依赖标准差(state_dependent_std=True)。当启用状态依赖时,actor 的输出维度会被扩展为 [2, num_actions],其中第一维对应动作均值,第二维对应标准差,初始化时对标准差分支的权重置零、偏置设为初始噪声值,确保训练初期分布接近预设值。

Sources: actor_critic.py, utils.py

动作分布与接口方法

ActorCritic 不实现通用的 forward() 方法,而是暴露三个语义明确的接口。act(obs) 用于训练时的随机采样:先提取并归一化 Actor 观测,调用 _update_distribution() 构建 Normal 分布,再执行 sample()act_inference(obs) 用于确定性推理:同样经过前向计算,但直接返回分布均值(若状态依赖标准差则取输出张量的第一个切片)。evaluate(obs) 则服务于 Critic:提取 Critic 观测、归一化后输入 Critic MLP,输出状态值估计。三者职责分离,避免了训练和推理阶段因调用路径不一致导致的潜在错误。此外,get_actions_log_prob()action_meanaction_stdentropy 等属性为 PPO 算法计算损失提供了便捷的分布访问入口。

Sources: actor_critic.py

状态依赖标准差与噪声类型

动作分布的标准差决定了策略的探索强度。框架在 noise_std_type 上提供 "scalar""log" 两种参数化形式:前者将标准差直接作为可学习参数,后者参数化对数标准差以强制正值并改善数值稳定性。当 state_dependent_std=True 时,标准差不再是全局标量,而是随输入状态变化,适用于需要异方差噪声的复杂控制任务。_update_distribution() 方法内部统一处理这四种组合(两种参数类型 × 两种依赖模式),将最终均值与标准差传入 torch.distributions.Normal,并通过 Normal.set_default_validate_args(False) 关闭分布参数校验以换取前向速度。

Sources: actor_critic.py, actor_critic.py

基础网络组件

MLP 多层感知机

MLP 类继承自 nn.Sequential,是 Actor 与 Critic 的骨干网络。其构造函数接受 input_dimoutput_dimhidden_dimsactivation。设计上的两个便利特性值得注意:一是隐藏层维度支持 -1 占位符,框架会自动将其替换为 input_dim,便于构建瓶颈或对称结构;二是 output_dim 支持整数或元组/列表,当传入元组时,最后一层线性输出后接 nn.Unflatten,自动将扁平张量恢复为目标形状,这一特性在状态依赖标准差(输出形状 [2, num_actions])等场景中被直接利用。init_weights() 方法提供正交初始化支持,允许按层传入不同的 gain 缩放因子。

Sources: mlp.py

观测归一化:EmpiricalNormalization

EmpiricalNormalization 基于运行样本统计量对输入进行在线均值-方差标准化。它注册 _mean_var_std 三个 buffer,在 forward() 中执行 (x - _mean) / (_std + eps) 的归一化。update(x) 方法采用增量式 Welford 风格更新:按当前 batch 占总样本的比例逐步修正均值与方差,支持 until 参数设定停止更新的样本阈值,常用于前若干步收集数据时学习分布、之后冻结参数的训练策略。由于 update 被标记为 @torch.jit.unused,该模块可安全地用于 JIT 编译的推理路径而不触发跟踪错误。ActorCriticupdate_normalization() 方法封装了对 Actor 与 Critic 两侧归一化器的批量更新。

Sources: normalization.py, actor_critic.py

可扩展变体与演进路径

ActorCritic 作为最简基类,其设计模式被后续变体严格遵循:is_recurrent 类属性用于运行期快速判断策略类型;act / act_inference / evaluate / reset / update_normalization 的方法签名保持一致;观测分组提取逻辑复用相同模式。例如,ActorCriticRecurrent 在 Actor 与 Critic 的 MLP 前各插入一个 Memory(GRU/LSTM)模块,通过 reset(dones) 在 episode 终止时清零对应环境的隐状态。ActorCriticCNN 则保留相同的 MLP 与噪声结构,但在观测提取阶段区分 1D 与 2D 观测组,将 2D 输入先经 CNN 编码后再与 1D 观测拼接送入 MLP。这种“先编码后决策”的分层范式,使得新变体只需重写观测预处理和分布更新逻辑,无需改动 PPO 训练代码。

Sources: actor_critic_recurrent.py, actor_critic_cnn.py

与训练运行器的协作

OnPolicyRunner 在初始化阶段通过 resolve_callable() 从配置中动态解析策略类名(如 "ActorCritic""ActorCriticRecurrent"),随后将环境观测 TensorDict、解析后的 obs_groups 以及 env.num_actions 传入构造函数。完成策略实例化后,该策略对象被直接注入 PPO 算法。在训练循环中,运行器以 self.alg.act(obs) 采样动作,环境步进后通过 self.alg.process_env_step() 存储 transition,最终调用 self.alg.update() 进行策略更新。这一过程中,ActorCritic 对运行器与算法完全透明,仅需保证接口契约即可。

Sources: on_policy_runner.py, ppo.py


延伸阅读建议:若需了解 Actor-Critic 在 PPO 中的具体损失计算与梯度更新流程,请阅读 PPO 算法实现与训练流程。若对循环记忆网络、CNN 图像编码器或注意力编码器的策略变体感兴趣,可依次参阅 循环与注意力策略变体CNN 观测编码与特征提取。观测归一化的数学细节与工具函数体系则可参考 观测归一化与网络工具函数