在强化学习与策略蒸馏的训练循环中,RolloutStorage 扮演着环境交互经验的中枢仓库角色。它负责在收集阶段暂存每个时间步的观测、动作、奖励与终止信号,并在学习阶段将这些经验组织成算法所需的批次格式。与通用回放缓冲区不同,此处的存储是严格按 on-policy 范式设计的:数据在每一次策略更新前被顺序写入,更新后即被清空,不保留历史策略产生的旧经验。本章将深入解析 RolloutStorage 的存储结构、内部 Transition 对象的生命周期,以及面向不同网络架构的批次生成机制。
flowchart TD
subgraph 收集阶段
A[算法 act] -->|填充 Transition| B[环境 step]
B -->|返回 rewards/dones| C[process_env_step]
C -->|add_transition| D[RolloutStorage]
end
subgraph 后处理
D --> E[compute_returns<br/>GAE 计算]
end
subgraph 学习阶段
E --> F[mini_batch_generator<br/>或 recurrent_mini_batch_generator]
F --> G[PPO update]
end
RolloutStorage 核心架构
RolloutStorage 以 [num_transitions_per_env, num_envs, ...] 为主要张量排布维度,其中第一维对应每个环境在时间轴上推进的步数,第二维对应并行环境的批量编号。这种设计天然适配向量化环境的高并发特性:所有环境在同一时刻的交互数据在内存中连续对齐,便于后续做 flatten 与随机洗牌操作。初始化时,存储会根据传入的 training_type 动态决定需要预分配的字段——强化学习模式需要值函数、对数概率、回报与优势估计;蒸馏模式则需要记录教师模型的特权动作。
| 字段分组 | 张量名称 | 形状示意 | 适用训练类型 |
|---|---|---|---|
| 核心数据 | observations |
[T, N, ...] (TensorDict) |
全部 |
| 核心数据 | actions |
[T, N, action_dim] |
全部 |
| 核心数据 | rewards |
[T, N, 1] |
全部 |
| 核心数据 | dones |
[T, N, 1] (byte) |
全部 |
| RL 专属 | values |
[T, N, 1] |
rl |
| RL 专属 | actions_log_prob |
[T, N, 1] |
rl |
| RL 专属 | mu / sigma |
[T, N, action_dim] |
rl |
| RL 专属 | returns / advantages |
[T, N, 1] |
rl |
| 蒸馏专属 | privileged_actions |
[T, N, action_dim] |
distillation |
| RNN 专属 | saved_hidden_state_a / saved_hidden_state_c |
列表,每层 [T, N, hidden_dim] |
rl / distillation |
这里 T 表示 num_transitions_per_env,N 表示 num_envs。对于循环网络,隐藏状态以列表形式存储,列表长度等于网络层数,每个元素对应一层在全部时间步与环境上的状态快照。这种分层存储使得在后续按轨迹切分时,可以准确提取每个轨迹起始时刻对应的 RNN 初始化状态。
Sources: rollout_storage.py
Transition:单步数据容器
RolloutStorage.Transition 是一个轻量的临时数据包,用于桥接策略前向传播、环境步进与存储写入三者之间的异步数据流。它的生命周期严格遵循“填充 → 完成 → 写入 → 清空”的节拍,保证每一步的数据在送入 RolloutStorage 前不会相互覆盖。Transition 包含的字段与 RolloutStorage 的核心列一一对应,但额外携带了循环网络的隐藏状态,因为 RNN 的隐状态需要在 act() 执行时捕获,而不能在 process_env_step() 阶段重新计算。
在 PPO.act() 中,算法会调用策略网络完成一次前向传播,将结果赋值给 Transition:动作由 policy.act() 采样得到,值估计由 policy.evaluate() 给出,动作对数概率、分布均值与标准差也一并记录。若策略是循环网络,还会通过 policy.get_hidden_states() 保存 actor 与 critic 在当前步的隐状态。随后 act() 返回动作供环境执行。这一步的关键在于预记录观测:transition.observations 保存的是执行动作之前的状态,因为后续计算优势函数与回报时需要匹配 (s_t, a_t) 的时序关系。
Sources: rollout_storage.py, ppo.py
数据写入与后处理流程
当环境执行完 step() 返回 rewards 与 dones 后,process_env_step() 负责将 Transition 最终封包并推入 RolloutStorage。此阶段还涉及若干重要的后处理逻辑:首先,观测归一化器会根据最新观测更新统计量;其次,若启用了 RND 探索奖励,会将内禀奖励叠加到外禀奖励上;最后,对于因达到最大步长而触发的 time_outs,算法会采用 bootstrap 技巧——将当前状态的值函数估计按折扣率加回到奖励中,避免人为截断造成的值估计偏差。完成上述修正后,storage.add_transition() 将 Transition 按当前 step 索引写入对应槽位,并将 step 计数器递增。
当一轮 Rollout 的全部步数收集完毕后,compute_returns() 方法会从最后一个时间步反向遍历,利用 GAE(Generalized Advantage Estimation)计算每个时间步的优势估计与折扣回报。具体而言,它通过 delta = r_t + gamma * V(s_{t+1}) * (1 - done) - V(s_t) 计算 TD 残差,再以 advantage = delta + gamma * lambda * (1 - done) * advantage 递归累积优势。最终 returns = advantages + values。若未启用按 mini-batch 归一化,则全局优势会在存储层被标准化为零均值、单位方差,以稳定策略梯度更新。
Sources: ppo.py, rollout_storage.py, ppo.py
批次生成器:从存储到梯度
RolloutStorage 提供了三种生成器接口,分别服务于不同的网络架构与训练范式。生成器采用惰性求值(Generator)模式,在 PPO 的 update() 循环中逐批次产出数据,避免一次性加载全部数据造成显存峰值。
前馈网络生成器:mini_batch_generator
对于不含循环连接的策略(如标准 MLP 或 CNN),mini_batch_generator 将全部时间步与环境的数据 flatten 为 [T*N, ...] 的一维批次,然后随机打乱索引,等分为 num_mini_batches 个子批次。在 num_epochs 轮迭代中,每一轮都会重新打乱索引,确保同一条经验在不同 epoch 中可能落入不同的 mini-batch。返回的每个批次包含观测、动作、目标值、优势、回报、旧对数概率、旧分布参数均值与标准差,以及预留的隐藏状态和掩码位(前馈网络下为 None)。
Sources: rollout_storage.py
循环网络生成器:recurrent_mini_batch_generator
循环网络无法像前馈网络那样随意打乱时间顺序,因为隐藏状态具有时序依赖性。recurrent_mini_batch_generator 首先调用 split_and_pad_trajectories,依据 dones 将 [T, N, ...] 的数据切割为若干条独立轨迹,并用零填充至最长轨迹长度,输出形状变为 [max_traj_len, num_trajectories, ...]。随后,生成器按轨迹维度将环境分组(每组大小为 num_envs // num_mini_batches),并在组内提取对应的轨迹块与掩码。隐藏状态的提取尤为复杂:原始存储形态为 [T, num_layers, N, hidden_dim],需要先 permute 为 [N, T, num_layers, hidden_dim],再根据 dones 的边界提取每条轨迹起始时刻的隐状态,最终调整为 [num_layers, batch, hidden_dim] 以匹配 RNN 的输入约定。同时,代码兼容 GRU(单层隐藏状态)与 LSTM(元组形式)两种格式。
Sources: rollout_storage.py, utils.py
蒸馏生成器:generator
在策略蒸馏场景下,数据不需要计算回报或优势,因此生成器逻辑最为简洁。它按顺序遍历每个时间步,产出 (observations, actions, privileged_actions, dones) 四元组,供学生网络进行行为克隆(Behavior Cloning)。这里的 privileged_actions 是教师网络在完整特权观测下产生的动作,学生网络仅访问非特权观测,目标是最小化两者动作差异。
Sources: rollout_storage.py, distillation.py
| 生成器 | 适用网络 | 数据组织方式 | 打乱策略 | 返回隐藏状态 |
|---|---|---|---|---|
mini_batch_generator |
前馈 (MLP/CNN/Attention) | Flatten 为 [T*N, ...] |
全局随机打乱 | None |
recurrent_mini_batch_generator |
循环 (GRU/LSTM) | 按 dones 切分轨迹并填充 |
按环境组切分轨迹 | 提取轨迹起始隐状态 |
generator |
蒸馏 | 按时间步顺序 [T, N, ...] |
无 | 不返回 |
CircularBuffer:环形历史缓冲区
除了 RolloutStorage 之外,storage 模块还包含一个独立的 CircularBuffer 组件,用于在多环境设置下维护一条固定长度的历史数据队列。与 RolloutStorage 的单次使用-清空语义不同,CircularBuffer 支持持续追加(append)与 LIFO 式检索(__getitem__),并允许对特定环境批次执行重置。其内部以环形数组实现,当数据写满后会覆盖最旧条目。在策略网络架构中,该组件可用于为注意力编码器或时序卷积提供历史观测帧。CircularBuffer 同样提供了 mini_batch_generator,支持从历史数据中按随机索引抽取固定长度的序列片段,适用于需要历史上下文采样的辅助训练任务。
Sources: circular_buffer.py
与训练运行器的交互
OnPolicyRunner 在 _construct_algorithm() 阶段实例化 RolloutStorage,并将其注入 PPO 算法。在 learn() 的主循环中,Runner 不负责直接操作存储,而是通过调用 alg.act()、alg.process_env_step()、alg.compute_returns() 与 alg.update() 四个标准接口,由算法层间接驱动存储的写入与读取。这种分层设计解耦了环境调度与数据存储细节,使得同样的 RolloutStorage 可以在 PPO、AMP、Distillation 等不同算法中被复用,只需在初始化时指定对应的 training_type 即可。
Sources: on_policy_runner.py, distillation_runner.py
延伸阅读
理解 Rollout 数据的存储结构后,下一步可以深入探索这些批次数据如何被算法层消费以计算策略梯度与值函数损失,详见 PPO 算法实现与训练流程。若你正在使用循环或注意力策略,建议结合 循环与注意力策略变体 来理解 recurrent_mini_batch_generator 中隐藏状态与掩码的传递语义。对于经验采样中涉及的 GAE 与优势归一化细节,也可以参考 经验采样与小批量生成。