🤖 roboto_origin_03 Wiki
首页 / 训练 / Actor-Critic 网络架构详解

Actor-Critic 架构是 RSL-RL 训练管线中所有策略网络的统一抽象与唯一入口。无论是处理纯向量观测的基础策略,还是需要时序记忆、图像编码或注意力感知的复杂策略,它们都共享同一套接口契约:向 PPO 算法提供动作采样、对数概率、熵以及状态价值估计。深入理解这一架构的内部组成、噪声建模方式以及各变体的扩展模式,是定制足式机器人策略网络、调试训练不稳定现象的前提。本文将系统拆解基类 ActorCritic 的网络组成、观测预处理链路、动作分布构造逻辑,并梳理循环、卷积与注意力三类变体的设计差异,最终落脚到它们与 PPO 训练循环的交互方式。

Sources: actor_critic.py

整体架构概览

在 RSL-RL 中,策略模块并非孤立存在,而是作为 PPO 类的 policy 属性被持有。PPO 通过调用策略的 actevaluate 收集交互数据,再通过 update 阶段重新计算分布统计量以构造替代损失。下图展示了基类与主要变体在类级别的关系,以及它们与算法、存储组件的协作边界。

classDiagram
    class ActorCritic {
        +actor: MLP
        +critic: MLP
        +actor_obs_normalizer: EmpiricalNormalization|Identity
        +critic_obs_normalizer: EmpiricalNormalization|Identity
        +distribution: Normal
        +act(obs) Tensor
        +act_inference(obs) Tensor
        +evaluate(obs) Tensor
        +update_normalization(obs)
    }
    class ActorCriticCNN {
        +actor_cnns: ModuleDict~CNN~
        +critic_cnns: ModuleDict~CNN~
    }
    class ActorCriticRecurrent {
        +memory_a: Memory
        +memory_c: Memory
        +is_recurrent = True
    }
    class ActorCriticAttnEnc {
        +encoder: AttentionEncoder
        +actor_obs_encoder: MLP
        +estimator: MLP
    }
    class PPO {
        +policy: ActorCritic|...
        +storage: RolloutStorage
        +act(obs)
        +update()
    }
    ActorCritic <|-- ActorCriticCNN
    ActorCriticRecurrent --|> ActorCritic : 接口兼容
    ActorCriticAttnEnc --|> ActorCritic : 接口兼容
    PPO --> ActorCritic : 持有并调用

从图中可以看出,ActorCriticCNN 直接继承基类并扩展了 CNN 编码链路;而 ActorCriticRecurrentActorCriticAttnEnc 虽然未从基类继承,但实现了完全一致的公共接口(actevaluateentropyget_actions_log_prob 等),从而保证 PPO 算法可以无差别地替换策略实例。

Sources: ppo.py

基类 ActorCritic 的组成与数据流

基类 ActorCritic 是最简洁的实现,面向一维向量观测场景。它的核心设计哲学是显式分离 Actor 与 Critic 的观测预处理链路,同时保持噪声模型与分布构造的高度可配置性。以下将沿数据流向拆解其五个关键环节。

观测分组与动态拼接

策略不直接感知环境输出的原始 TensorDict,而是通过 obs_groups 字典决定哪些观测字段进入 Actor、哪些进入 Critic。obs_groups 包含 "policy""critic" 两个键,每个键映射到一个字段名列表;初始化时,模块会遍历对应列表,将各字段在最后一维拼接,从而动态推导出 num_actor_obsnum_critic_obs。这种设计允许 Actor 与 Critic 在不修改网络代码的情况下消费不同的观测子集。

Sources: actor_critic.py

Actor 与 Critic 的 MLP 主干

Actor 与 Critic 均使用 rsl_rl.networks.MLP 构建。默认隐藏层维度为 [256, 256, 256],默认激活函数为 elu。Actor 的输出维度通常为 num_actions;若启用状态相关标准差(state_dependent_std=True),则输出维度变为 [2, num_actions],其中第二路用于建模标准差。Critic 始终输出单一标量值。MLP 内部通过 nn.Sequential 组织,支持在最后一层后追加 nn.Unflatten 以处理元组形式的输出维度,同时提供 init_weights 方法供外部进行正交初始化。

Sources: actor_critic.py, mlp.py

观测归一化层

为了缓解早期训练中的数值不稳定,Actor 与 Critic 各自可选择性地接入 EmpiricalNormalization。该模块基于整个批次维护运行均值与方差(而非逐环境独立统计),并在 forward 中执行 (x - mean) / (std + eps)。若配置中关闭归一化,则替换为 torch.nn.Identity(),保证计算图零开销。训练过程中,update_normalization 方法会根据新采集的观测更新运行统计量,通常在 PPO.process_env_step 中被调用。

Sources: actor_critic.py, normalization.py

动作噪声模型与分布构造

动作分布采用对角高斯 torch.distributions.Normal,其标准差支持两种参数化模式:独立可学习参数(state_dependent_std=False)或网络输出(state_dependent_std=True)。对于独立参数模式,noise_std_type 可选择 scalar(直接存储 std)或 log(存储 log_std 并通过指数映射得到正数标准差)。对于状态相关模式,Actor 最后一层同时输出均值与标准差,初始化时对标准差对应分支的权重置零、偏置置为初始噪声,以保证训练初始阶段策略接近确定性。_update_distribution 方法负责将均值与标准差组装为 Normal 实例;为了加速采样,代码显式关闭了分布的参数校验。

Sources: actor_critic.py

公共接口:采样、推理与评估

act 方法完成 Actor 全链路:拼接观测 → 归一化 → 更新分布 → 采样动作。act_inference 则跳过采样,直接返回分布的均值,用于部署或确定性评估。evaluate 方法走 Critic 链路,返回状态价值。此外,get_actions_log_prob 在 PPO 的损失计算阶段被调用,用于求取动作在当前分布下的对数概率;entropy 属性则返回分布熵,作为探索奖励项。所有这些接口均假设观测输入为 TensorDict,保持了与环境封装的一致性。

Sources: actor_critic.py

核心网络组件:MLP 与经验归一化

基类的能力完全建立在两个可复用组件之上:MLPEmpiricalNormalizationMLP 继承自 nn.Sequential,除了常规的线性层与激活函数堆叠外,还支持用 -1 作为隐藏层维度占位符以自动继承输入维度,以及通过 reduce 计算元组输出维度的总大小后接 nn.UnflattenEmpiricalNormalization 则采用增量式 Welford 风格更新:维护 count_mean_var,在 update 中根据新批次大小计算加权率并修正方差,无需保存全部历史数据。这两个组件的独立性意味着它们不仅服务于 Actor-Critic,也可用于 RND、学生-教师蒸馏等其他模块。

Sources: mlp.py, normalization.py

变体架构:循环、卷积与注意力

面对不同观测模态与任务需求,RSL-RL 在基类之上提供了三种官方变体。它们在保持接口一致的前提下,替换了观测编码器或增加了时序记忆。

ActorCriticRecurrent:时序记忆增强

ActorCriticRecurrent 在归一化层与 MLP 之间插入了 Memory 模块。该模块内部封装 nn.GRUnn.LSTM,为 Actor 与 Critic 各自维护独立的隐状态。推理时,隐状态随时间步自回归更新;训练时,隐状态由 RolloutStorage 保存并按轨迹分块后通过 split_and_pad_trajectoriesunpad_trajectories 处理,确保跨 episode 的梯度不泄漏。reset(dones) 会在环境终止时将对应环境的隐状态清零,从而避免价值估计跨 episode 污染。is_recurrent 标志置为 True,提示 PPO 使用 recurrent_mini_batch_generator

Sources: actor_critic_recurrent.py, memory.py

ActorCriticCNN:图像与混合观测编码

ActorCriticCNN 直接继承自 ActorCritic,扩展了对二维图像观测(形状为 B, C, H, W)的支持。初始化阶段,模块通过 obs_groups 中的字段形状自动区分 1D 与 2D 观测,并为每个 2D 字段创建独立的 CNN 编码器(存储于 nn.ModuleDict)。CNN 的输出经展平后与 1D 观测拼接,再送入父类已定义的 MLP。Actor 与 Critic 各自拥有独立的 CNN 字典,因此策略网络与价值网络可以学习不同的视觉表征。_update_distributionevaluate 被重写以在 MLP 前注入 CNN 特征。

Sources: actor_critic_cnn.py, cnn.py

ActorCriticAttnEnc:注意力感知编码

ActorCriticAttnEnc 面向地图类或结构化感知观测,引入了 AttentionEncoder 将空间信息压缩为嵌入向量。除了注意力主干外,该变体还预留了两个可选子模块:一是 obs_encoder,用于将原始观测压缩为低维潜变量;二是 estimator,用于辅助预测 Critic 关注的特定状态量。actevaluate 的拼接逻辑因此更复杂:需要将历史观测切片、潜变量与注意力嵌入按维度拼接后送入 MLP。由于注意力编码器的设计与足式机器人感知密切相关,其详细机制将在后续专题中展开。

Sources: actor_critic_attn_enc.py

变体能力对比

变体 核心扩展 观测类型 是否继承基类 典型场景
ActorCritic 1D 向量 本体感知为主的足式控制
ActorCriticRecurrent GRU/LSTM 记忆 1D 向量 否(接口兼容) 需要历史依赖的导航任务
ActorCriticCNN CNN 编码器 1D + 2D 混合 视觉输入、高度图、摄像头
ActorCriticAttnEnc AttentionEncoder 1D + 地图/感知 否(接口兼容) 结构化空间感知、代价地图

Sources: modules/init.py

与 PPO 算法的交互方式

策略网络的生命周期由 PPO 类驱动。在数据收集阶段,PPO.act 调用 policy.act(obs) 得到采样动作,同时调用 policy.evaluate(obs) 得到价值估计,并将动作对数概率、分布均值与标准差存入 RolloutStorage.Transition。在策略更新阶段,PPO.update 通过 mini-batch 生成器遍历历史数据,对每个 batch 重新执行 policy.act(obs_batch) 以构建当前参数下的分布,进而通过 get_actions_log_prob(actions_batch)entropy 计算替代损失和熵正则项;价值损失则由 evaluate 的输出与回报之差构成。若策略为循环网络,生成器还会提供 maskshidden_states_batch,供 Memory 模块在 batch 模式下正确解包轨迹。这种“重计算”设计避免了在存储中保存中间激活值,显著降低了显存占用。

Sources: ppo.py, ppo.py

配置参数与扩展建议

下表汇总了基类 ActorCritic 中直接影响网络容量与探索行为的关键配置项。变体类在此基础上追加各自的专属参数(如 rnn_hidden_dimactor_cnn_cfgembedding_dim 等)。

参数 默认值 作用说明
actor_hidden_dims [256, 256, 256] Actor MLP 各隐藏层维度
critic_hidden_dims [256, 256, 256] Critic MLP 各隐藏层维度
activation "elu" 隐藏层激活函数(支持 elureluseluswish 等)
actor_obs_normalization False 是否为 Actor 观测启用经验归一化
critic_obs_normalization False 是否为 Critic 观测启用经验归一化
state_dependent_std False 标准差是否由网络输出
noise_std_type "scalar" 标准差参数化方式:scalarlog
init_noise_std 1.0 初始探索噪声水平

若需自定义新的策略架构,推荐遵循以下扩展契约:实现 actact_inferenceevaluateget_actions_log_probentropyupdate_normalization 方法;若涉及时序记忆,额外实现 is_recurrent = Truereset(dones)get_hidden_states()。通过保持接口一致,新策略可直接接入现有的 PPOOnPolicyRunner 与存储模块,无需修改算法代码。

Sources: actor_critic.py, actor_critic_recurrent.py

下一步阅读

掌握 Actor-Critic 的网络结构后,建议继续深入以下主题以打通训练闭环: