🤖 roboto_origin_03 Wiki
首页 / 部署 / 接入自定义强化学习策略

本文面向具备强化学习训练经验的高级开发者,系统阐述如何将自定义 ONNX 策略模型接入 Atom01 机器人的推理节点。内容涵盖模型接口契约、观测布局配置、动作后处理链路、多策略运行时切换,以及基于 Python SDK 的独立部署方案。掌握这些机制后,你可以在不修改推理节点源码的前提下,仅通过配置调整与模型替换完成策略迭代。

Sources: inference_node.hpp

推理节点架构与策略加载机制

推理节点(InferenceNode)内部采用**策略运行时(PolicyRuntime)**抽象来管理每个 ONNX 模型。每个 PolicyRuntime 实例封装了模型会话、观测内存、帧堆叠状态以及可选的运动序列加载器。节点在构造函数中按配置顺序逐一初始化这些运行时,随后启动 inferencecontrol 两个实时线程,分别负责策略前向推理与电机指令下发。

推理线程以 dt * decimation 为周期执行一次完整的前向传播:先根据当前激活策略的观测布局(obs_layout)从传感器和内部状态中提取各观测段,再按帧堆叠规则拼合成 ONNX 输入张量,调用 Ort::Session::Run 得到原始动作输出,最后经过裁剪、坐标重映射、缩放与平滑后写入电机目标队列。控制线程则以固定频率 dt 从队列中读取并执行 apply_action,实现推理与控制的时序解耦。

graph TD
    A[配置文件 YAML] -->|load_config| B[PolicyRuntime 数组]
    B -->|setup_model| C[ONNX Session 与 MemoryInfo]
    D[传感器数据] -->|get_xxx_obs| E[Obs Segments]
    E -->|flatten + stack| F[ONNX Input Tensor]
    F -->|Session::Run| G[ONNX Output Tensor]
    G -->|clip + usd2urdf + scale + default_angle| H[动作缓冲区 act_]
    H -->|act_alpha 平滑| I[电机接口 RobotInterface]
    I --> J[CAN/CANFD 总线]

Sources: inference_node.hpp, inference_node.cpp

ONNX 模型接口契约

推理节点对 ONNX 模型有严格的单输入、单输出要求。setup_model 在加载时会验证输入张量的总元素个数是否等于配置计算出的尺寸,不匹配将直接抛出运行时异常。

输入规范:模型必须恰好有 1 个输入节点,数据类型为 float32。输入尺寸由以下公式决定:

input_size = obs_num × frame_stack + extra_obs_num

其中 obs_numobs_layout 中各观测源尺寸之和,extra_obs_numextra_obs_layout 中各源尺寸之和。若 frame_stack = 1 且不存在 extra_obs_layout,则输入尺寸即等于单次观测维度。

输出规范:模型输出节点的数据类型同样为 float32,输出尺寸(元素总数)必须等于 joint_num(默认 23),因为推理节点会将每个输出元素直接映射为一个关节目标值。

模型文件需放置在 src/inference/models/ 目录下,并在配置文件的 model_names 数组中以文件名形式引用。CMake 通过 -DROOT_DIR 宏将该目录作为运行时根路径,因此模型路径在部署后会自动解析为安装目录下的对应位置。

Sources: inference_node.cpp

观测系统与配置语法

观测系统是接入自定义策略时最需要精确对齐的环节。推理节点通过声明式字符串描述观测布局,将传感器原始数据转换为模型输入。

观测源注册表

系统内置了 10 种观测源,每种通过 ObsSourceDefinition 静态注册,并在 parse_obs_layout 时按名称匹配到对应的成员函数:

观测源名称 尺寸含义 数据来源 缩放系数
ang_vel 3 IMU 角速度(经外参旋转到机体坐标系) obs_scales_ang_vel
gravity_b 3 世界重力向量经机体姿态旋转后的投影 obs_scales_gravity_b
cmd_vel 3 手柄或 /cmd_vel 指令(x, y, yaw) obs_scales_lin_vel / obs_scales_ang_vel
dof_pos joint_num 关节位置相对默认角度的偏移 obs_scales_dof_pos
dof_vel joint_num 关节速度 obs_scales_dof_vel
last_action joint_num 上一帧策略原始输出(clip 后、scale 前)
interrupt 1 是否处于中断模式(0 或 1)
perception 自定义 外部感知节点发布的 Float32MultiArray
motion_pos joint_num 当前帧运动序列的关节位置
motion_vel joint_num 当前帧运动序列的关节速度

Sources: obs_manager.cpp, obs_manager.cpp

布局配置格式

obs_layoutsextra_obs_layouts 均采用逗号分隔的 name:size 对格式,例如:

obs_layouts:
  - "ang_vel:3, gravity_b:3, cmd_vel:3, dof_pos:23, dof_vel:23, last_action:23"
extra_obs_layouts:
  - "perception:187"

两者关键区别在于:obs_layouts 中的字段会参与帧堆叠(frame stacking),而 extra_obs_layouts 中的字段不会堆叠,直接拼贴在堆叠张量之后。这一设计专为时序无关的高维感知特征(如地形高程图、视觉编码)而设,使其不受历史帧重复填充的维度膨胀影响。

Sources: ros_interface.cpp, inference_node.cpp

帧堆叠与排列顺序

frame_stack 定义历史帧缓存长度,obs_stack_order 决定堆叠后的内存布局:

首次运行时,推理节点会将当前观测复制填满整个堆叠缓冲区,避免冷启动时的零值污染。

Sources: inference_node.cpp

动作后处理与坐标映射

策略模型的原始输出不会直接发送到电机,而是经过一条固定的后处理链。理解这条链路对训练时的动作空间定义至关重要。

处理顺序如下:

  1. 裁剪(Clip)output = clamp(raw_output, -clip_actions, clip_actions)
  2. USD 到 URDF 映射(usd2urdf)act[usd2urdf[i]] = output[i]。训练环境中的关节顺序与真实机器人固件中的电机顺序可能不同,该映射完成索引重排。
  3. 缩放与偏置act[j] = act[j] * action_scale + joint_default_angle[j]。模型输出的是围绕默认姿态的偏移量,而非绝对关节角。
  4. 动作平滑last_act = act_alpha * act + (1 - act_alpha) * last_actact_alpha 控制策略输出突变时的惯性阻尼,1.0 表示无平滑。

训练侧建议:在仿真环境中训练时,确保动作输出已经过与上述链路可逆的归一化。通常推荐让网络输出单位尺度动作,在推理侧通过 action_scalejoint_default_angle 恢复物理量,这样同一模型可以仅通过修改配置适配不同幅度的动作空间。

Sources: inference_node.cpp, inference_node.cpp

接入自定义策略的完整步骤

以下流程展示从训练完成到机器人运行的最小闭环。

flowchart TD
    A[训练得到 PyTorch/JAX 策略] --> B[导出为 ONNX<br/>单输入 float32<br/>输出维度 = joint_num]
    C[定义观测布局字符串<br/>与训练时 obs 一致] --> D[编写 YAML 配置文件]
    B --> E[复制 .onnx 到 models/]
    D --> F[复制 .yaml 到 config/]
    E --> G[修改 launch 文件<br/>指向新配置]
    F --> G
    G --> H[编译并部署]
    H --> I[启动推理节点]
    I --> J[按手柄 B 键开始推理]

步骤 1:导出 ONNX 模型

确保导出的模型满足以下约束:

步骤 2:编写配置

src/inference/config/ 下新建 YAML,至少包含以下字段:

inference_node:
    ros__parameters:
        model_names: ["my_policy.onnx"]
        obs_layouts:
          - "ang_vel:3, gravity_b:3, cmd_vel:3, dof_pos:23, dof_vel:23, last_action:23"
        frame_stacks: [10]
        obs_stack_orders: ["frame_major"]
        joint_num: 23
        decimation: 5
        dt: 0.004
        action_scale: 0.25
        clip_actions: 100.0
        usd2urdf: [0, 6, 12, 1, 7, 13, 18, 2, 8, 14, 19, 3, 9, 15, 20, 4, 10, 16, 21, 5, 11, 17, 22]
        joint_default_angle: [...]

若策略需要感知特征作为额外输入,添加 extra_obs_layouts 并确保有外部节点向 perception_obs_topic 发布 Float32MultiArray

步骤 3:更新启动文件

修改 launch/inference.launch.py,将 configs 列表指向你的新 YAML 文件:

configs = [
    os.path.join(
        get_package_share_directory("inference"),
        "config",
        "my_policy_config.yaml",
    ),
]

步骤 4:编译与运行

执行标准 colcon 构建后,通过 launch 文件启动。节点初始化时会打印各策略的模型路径、观测布局、是否支持中断/运动策略等诊断信息,可用于快速核对配置。

Sources: inference.launch.py, ros_interface.cpp

高级模式:多策略与行为切换

推理节点原生支持在单个节点内加载多个策略,并通过手柄按键切换。此能力适用于需要基础运动策略与特殊行为策略(如舞蹈、击打、起身)共存的应用场景。

Motion Policy 模式

当某个策略配置了非空的 motion_names 条目时,该策略被归类为运动策略。motion_loader 会加载对应的 .npz 文件,提供 motion_posmotion_vel 观测源。观测布局中必须包含这两个源,以便模型根据当前运动帧的参考轨迹生成残差动作。

运行时按手柄 LB 键可在基础策略与运动策略之间切换,按 RB 键可在多个运动策略之间轮询选择。切换时推理线程会先暂停,重置目标策略的运行时状态(清空观测缓冲与堆叠历史),再恢复运行,避免状态污染。

Sources: ros_interface.cpp, inference_node.hpp

Interrupt 模式

若观测布局中包含 interrupt:1,则系统启用中断能力。通过手柄 LB 键可以切换 is_interrupt_ 标志。当标志为真时,推理线程在生成最终 act_ 数组时,会将后 10 个关节(由 interrupt_action_ 数组决定)的值覆盖为外部来源(如另一上位机通过 /joint_ref_states 发布的目标位置),其余关节仍由策略输出控制。这允许在保持主体平衡的同时,由外部模块接管末端关节。

Sources: inference_node.cpp, ros_interface.cpp

Perception 观测

对于需要地形感知或视觉输入的策略,将感知特征放在 extra_obs_layouts 中而非 obs_layouts 中,可免除对感知数据进行历史帧堆叠的开销。推理节点通过 ROS2 Topic 订阅 std_msgs::Float32MultiArray,在每次推理循环前将最新数据拷贝到 perception_obs_buffer_。Topic 名称由 perception_obs_topic 参数指定。

Sources: ros_interface.cpp, inference_node.cpp

通过 Python SDK 独立运行策略

如果你希望完全绕过 C++ 推理节点,在 Python 中自行管理策略前向与观测构造,可以使用 robot_py 模块。该模块通过 pybind11 暴露了 RobotInterface,支持直接读取关节状态、IMU 姿态并下发动作。

import robot_py
import numpy as np
import onnxruntime as ort

robot = robot_py.RobotInterface("config/robot.yaml")
robot.init_motors()

session = ort.InferenceSession("models/my_policy.onnx")
input_name = session.get_inputs()[0].name

while True:
    q = robot.get_joint_q()
    vel = robot.get_joint_vel()
    quat = robot.get_quat()
    ang_vel = robot.get_ang_vel()
    
    obs = construct_obs(q, vel, quat, ang_vel, ...)  # 自行构造观测
    action = session.run(None, {input_name: obs.astype(np.float32)})[0]
    robot.apply_action(action.tolist())

这种方式的灵活性最高,允许你在 Python 侧实现任意复杂的观测堆叠、策略切换或训练时未包含的在线适配逻辑;但相应的,实时性完全依赖 Python GIL 与调度,难以达到 C++ 推理节点的硬实时保障。建议仅用于算法验证或低频交互场景。

Sources: pybind_module.cpp

参数速查表

参数名 类型 说明
model_names string[] ONNX 模型文件名列表,按索引与后续数组对齐
motion_names string[] 运动序列文件名列表,空字符串表示无运动
obs_layouts string[] 主观测布局,格式 name:size,name:size
extra_obs_layouts string[] 额外观测布局,不参与帧堆叠
frame_stacks int[] 每策略的历史帧缓存长度
obs_stack_orders string[] frame_majorobs_major
act_alpha float 动作平滑系数,1.0 为无平滑
action_scale float 模型输出到关节偏移的缩放系数
clip_actions float 模型输出裁剪边界
clip_observations float 观测输入裁剪边界
usd2urdf int[] 训练关节顺序到真实电机顺序的映射
joint_default_angle double[] 各关节默认角度(rad)
dt float 控制周期(秒)
decimation int 推理周期 = dt × decimation
intra_threads int ONNX Runtime 内部线程数,-1 为默认

Sources: ros_interface.cpp

参考与下一步

完成自定义策略接入后,你可能需要进一步调试动作映射或优化推理延迟。相关主题可参阅: