🤖 roboto_origin_03 Wiki
首页 / 项目根 / 推理节点与ONNX策略部署

推理节点(inference_node)是连接强化学习训练成果与真实硬件执行的实时决策中枢。它以固定频率读取IMU与电机反馈,将多源观测向量组装为ONNX模型的输入张量,执行神经网络前向推理后将生成的动作指令经平滑处理后同步下发至底层电机控制回路。本文档从架构设计、线程模型、模型加载、观测堆叠、多策略切换与配置部署六个维度,系统阐述推理节点的实现原理与使用方法。

Sources: inference_node.hpp, inference_node.cpp

推理节点在控制架构中的定位

在完整的ROS2控制系统中,推理节点处于策略层执行层的交界位置:向上订阅手柄或/cmd_vel的运动指令、感知模块的地形高程数据;向下通过RobotInterface聚合IMU驱动与多路CAN/CANFD电机总线。节点内部引入了两条独立的实时线程——inference_thread_负责周期性神经网络推理,control_thread_负责周期性动作下发——而ROS2的MultiThreadedExecutor仅用于处理外部话题订阅与服务回调,避免标准回调模型对实时闭环的干扰。

graph TB
    subgraph ROS2通信层["ROS2 通信层 (Executor 2 threads, SCHED_FIFO 50)"]
        A1[/joy 手柄输入/]
        A2[/cmd_vel 速度指令/]
        A3[/elevation_data 感知输入/]
        A4[服务接口]
        A5[/imu, joint_states, action 状态发布/]
    end

    subgraph 推理节点核心["InferenceNode 核心"]
        B1[inference_thread<br/>周期: dt×decimation<br/>SCHED_FIFO 70]
        B2[control_thread<br/>周期: dt<br/>SCHED_FIFO 70]
        B3[PolicyRuntime 0..N]
        B4[act_ / last_act_ 双缓冲]
    end

    subgraph 硬件抽象层["RobotInterface 硬件抽象层"]
        C1[ThreadPool<br/>多总线并行]
        C2[IMU驱动]
        C3[电机驱动总线0]
        C4[电机驱动总线N]
    end

    A1 -->|cmd_vel_| B1
    A3 -->|perception_obs_| B1
    B1 -->|act_| B4
    B4 -->|last_act_| B2
    B2 -->|apply_action| C1
    C1 --> C3
    C1 --> C4
    B1 --> A5

每个策略的运行时状态被封装为独立的PolicyRuntime实例,包含完整的观测布局、时序堆叠缓冲区、ONNX会话上下文以及可选的MotionLoader。这种设计使得多策略热切换无需重新初始化ONNX Session,切换延迟控制在毫秒级。

Sources: inference_node.hpp, robot_interface.hpp

双线程实时推理架构

推理节点采用三层线程架构:主线程(ROS2 Executor)、推理线程、控制线程。三者的职责、调度策略与周期严格分离,形成清晰的实时层级。

线程/线程池 职责 调度策略 典型周期 优先级
main + Executor 节点初始化、参数加载、ROS回调处理 SCHED_FIFO 事件驱动 50
inference 观测采集、堆叠、ONNX推理、动作生成 SCHED_FIFO dt × decimation 70
control 动作指数平滑、调用RobotInterface::apply_action SCHED_FIFO dt 70
ThreadPool 并行执行多总线电机MIT指令 SCHED_FIFO 随control触发 70

main函数在初始化时通过mlockall(MCL_CURRENT | MCL_FUTURE)将进程内存锁定,防止运行时发生页交换导致的非确定性延迟。若任何实时线程未能成功设置SCHED_FIFO,节点会立即调用rclcpp::shutdown()终止运行,避免在降级调度下产生不可控的电机指令。推理线程与控制线程的协作遵循生产者-消费者模式,通过act_mutex_保护的共享缓冲区act_进行握手。

sequenceDiagram
    participant I as inference_thread
    participant B as act_ / last_act_ 缓冲区
    participant C as control_thread
    participant R as RobotInterface

    loop 每 dt × decimation 周期
        I->>I: 采集观测 / 堆叠历史帧
        I->>I: ONNX Session::Run
        I->>B: 写入 act_ (act_mutex_)
    end

    loop 每 dt 周期
        C->>B: 读取 act_ (act_mutex_)
        C->>C: last_act_ = α·act_ + (1-α)·last_act_
        C->>R: apply_action(last_act_)
        R->>R: 多总线并行电机指令
    end

control线程以dt为周期高频运行(默认4 ms,250 Hz),而inference线程以dt × decimation为周期运行(默认20 ms,50 Hz)。apply_action()内部执行动作指数平滑last_act_[i] = act_alpha_ * act_[i] + (1 - act_alpha_) * last_act_[i]。当act_alpha_ = 1.0时输出直接跟踪策略输出;降低该值可获得更柔顺的动作过渡。这种双线程解耦设计保证了即使策略推理耗时抖动,电机控制频率仍保持稳定。

Sources: inference_node.cpp, inference_node.cpp

ONNX Runtime 环境与模型加载

推理节点采用ONNX Runtime C++ API实现跨平台(x86_64/aarch64)部署。节点构造阶段完成所有模型的加载,避免运行时的动态文件I/O干扰实时性。

环境级线程控制

ONNX Runtime的全局线程行为通过Ort::ThreadingOptions配置。代码暴露intra_threads参数(默认值为1),用于控制算子内的线程并行度。在边缘设备(aarch64)上,通常将其设为1以避免推理内部多线程与外部控制线程发生CPU抢占。

会话优化选项

每个Ort::Session的创建均伴随严格的性能优化配置:

配置项 设置值 作用说明
DisablePerSessionThreads 禁用会话私有线程池,复用全局线程选项中的线程数
EnableCpuMemArena 启用CPU内存池,减少推理过程中的系统分配开销
EnableMemPattern 允许内存复用模式,降低缓存未命中
SetGraphOptimizationLevel ORT_ENABLE_ALL 启用全部图优化(常量折叠、算子融合等)

这些选项的组合确保了单次Session::Run的延迟在亚毫秒级,满足高频控制需求。

模型加载与严格校验

setup_model()是模型加载的核心函数。系统当前仅支持单输入ONNX模型。加载时会读取输入张量的形状,将动态batch维度-1修正为1,随后计算模型期望的总输入元素个数,并与配置给出的input_size进行严格比对。若不一致则立即抛出异常,避免在运行时才发现维度不匹配。

零拷贝Tensor绑定

为了消除推理前后的数据拷贝开销,代码采用ONNX Runtime的零拷贝机制:预先分配input_bufferoutput_buffer,并通过Ort::Value::CreateTensor将其直接包装为Ort::Value。在推理循环中,Session::Run直接读写这些缓冲区,无需额外的memcpy

ctx->memory_info = std::make_unique<Ort::MemoryInfo>(
    Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU));

ctx->input_tensor = std::make_unique<Ort::Value>(Ort::Value::CreateTensor<float>(
    *ctx->memory_info, ctx->input_buffer.data(), ctx->input_buffer.size(),
    ctx->input_shape.data(), ctx->input_shape.size()));

Sources: inference_node.hpp, inference_node.cpp

观测向量构建与堆叠机制

ONNX模型的输入观测并非硬编码,而是通过声明式配置动态组装。系统维护一个编译期固定的观测源注册表,将字符串名称映射到成员函数指针,实现配置驱动的观测采集。

观测源注册表

观测源名称 数据维度 采集来源 典型用途
ang_vel 3 IMU角速度 本体姿态变化率
gravity_b 3 IMU四元数转机体重力向量 姿态估计与跌倒检测
cmd_vel 3 /joy/cmd_vel 用户运动指令
dof_pos 23 电机反馈经usd2urdf映射 关节位置(相对默认角度)
dof_vel 23 电机反馈经usd2urdf映射 关节速度
last_action 23 上一帧策略输出 动作平滑与历史依赖
motion_pos 23 .npz运动文件当前帧 模仿学习参考位姿
motion_vel 23 .npz运动文件当前帧 模仿学习参考速度
interrupt 1 布尔标志位 触发关节中断/接管
perception 可配 /elevation_data等外部话题 地形感知等高层特征

配置中的obs_layouts字符串(如"ang_vel:3, gravity_b:3, cmd_vel:3, dof_pos:23, dof_vel:23, last_action:23")在初始化时被解析为vector<ObsSourceSpec>。运行时,update_obs_segments通过成员函数指针依次调用对应的getter,将数据填充到obs_segments二维数组中,再经flatten_obs_segments拼接为当前帧的扁平化观测。

观测堆叠策略

时序策略通常需要过去多帧的观测作为输入。系统通过frame_stackstack_order两个参数实现可配置的观测滑动窗口,直接决定ONNX输入张量的内存布局:

双轨观测输入

除了参与时序堆叠的obs_layout外,系统还引入了extra_obs_layout机制。附加观测不参与滑窗,每帧直接拼接到堆叠缓冲区末尾。这种"双轨"设计让高频时序特征(如关节状态)与低频大维度特征(如187维地形感知)可以共存于同一输入张量。以inference_attn_enc.yaml为例:基础观测布局obs_layout包含角速度、重力向量、指令、关节状态等共78维,设置frame_stack: 5;而extra_obs_layouts中的perception:187作为附加观测,直接以当前帧的187维地形数据拼接到78×5的堆叠区之后,形成总输入维度577。

Sources: obs_manager.cpp, obs_manager.cpp, inference_node.cpp

多策略运行时与热切换

为支持行走、舞蹈、起身、模仿等多策略热切换,InferenceNode在初始化时为配置文件中声明的每个ONNX模型创建独立的PolicyRuntime实例。每个实例持有完整的推理上下文,包括ModelContext(ONNX Session、输入输出Tensor、内存信息)、观测分段、时序堆叠状态以及可选的MotionLoader

classDiagram
    class InferenceNode {
        +vector~PolicyRuntime~ policies_
        +atomic~bool~ is_running_
        +atomic~bool~ is_interrupt_
        +int active_policy_idx_
        +thread inference_thread_
        +thread control_thread_
        +inference()
        +control()
    }
    class PolicyRuntime {
        +string name
        +string model_path
        +vector~ObsSourceSpec~ obs_layout
        +vector~float~ obs
        +int frame_stack
        +ObsStackOrder stack_order
        +ModelContext ctx
        +shared_ptr~MotionLoader~ motion_loader
        +bool is_first_frame
    }
    class ModelContext {
        +unique_ptr~Ort::Session~ session
        +unique_ptr~Ort::Value~ input_tensor
        +unique_ptr~Ort::Value~ output_tensor
        +vector~float~ input_buffer
        +vector~float~ output_buffer
    }
    InferenceNode --> PolicyRuntime
    PolicyRuntime --> ModelContext

系统支持两种截然不同的多策略切换语义:

中断模式(Interrupt Mode):当任意策略的观测布局中包含interrupt:1时,节点启用中断支持。此时通常只配置单一策略(如policy_interrupt.onnx),interrupt观测源在get_interrupt_obs()中读取is_interrupt_原子布尔。手柄LB键按下时,通过switch_while_paused暂停推理、翻转中断标志、恢复推理。在中断激活状态下,策略的输出动作会被/joint_ref_states话题中的目标关节位置覆盖,实现外部对末端关节的直接接管。

运动策略模式(Motion Policy Mode):当配置中存在带motion_names的策略时,节点启用运动策略支持。典型配置如inference_beyondmimic.yaml,包含一个基策略policy.onnx(无运动文件)与三个运动策略policy_wave.onnxpolicy_dance.onnxpolicy_punch.onnx(各绑定.npz动作文件)。此模式下LB键在"基策略 ↔ 当前选中的运动策略"之间切换;RB键则在后台循环选择下一个运动策略(仅在处于基策略时允许切换,避免运动播放中途跳转)。切换同样通过switch_while_paused完成:暂停推理、切换索引、重置目标策略的运行时状态(观测缓冲区清零、is_first_frame置true)、恢复推理。

switch_while_paused是一个通用lambda包装器,执行任何模式变更前先将is_running_置为false,变更完成后再视情况恢复。这保证了策略切换期间不会有半完成的推理循环读写观测缓冲区。此外,mode_mutex_保护active_policy_idx_与中断/运动标志的并发访问。

Sources: inference_node.hpp, ros_interface.cpp

配置文件与启动方式

推理节点的全部行为由YAML参数文件驱动,通过load_config()在构造阶段一次性加载。仓库提供了五种典型配置,覆盖从标准行走到多技能模仿的不同部署场景:

配置文件 策略数 frame_stack 附加观测 切换模式 核心用途
inference.yaml 1 10 标准行走Locomotion
inference_amp.yaml 1 1 AMP风格单帧策略
inference_attn_enc.yaml 1 5 perception:187 注意力编码器 + 地形感知
inference_interrupt.yaml 1 10 中断模式 外部关节中断/接管
inference_getup.yaml 2 10 / 1 运动策略 行走 + 起身恢复
inference_beyondmimic.yaml 4 10 / 1×3 运动策略 行走 + 多动作模仿

以标准行走配置inference.yaml为例,其核心参数说明如下:

inference_node:
    ros__parameters:
        model_names: ["policy.onnx"]           # ONNX模型文件名列表
        obs_layouts:                           # 每个策略的观测布局字符串
          - "ang_vel:3, gravity_b:3, cmd_vel:3, dof_pos:23, dof_vel:23, last_action:23"
        frame_stacks: [10]                     # 时序堆叠深度
        obs_stack_orders: ["frame_major"]      # 堆叠顺序:frame_major 或 obs_major
        act_alpha: 1.0                         # 动作EMA平滑系数(1.0为不启用)
        joint_num: 23                          # 关节总数
        decimation: 5                          # 推理周期 = dt × decimation
        intra_threads: 1                       # ONNX Runtime算子内线程数
        dt: 0.004                              # 控制周期(秒)
        action_scale: 0.25                     # 策略输出缩放系数
        clip_actions: 100.0                    # 动作输出截断阈值
        usd2urdf: [...]                        # USD到URDF关节索引映射
        joint_default_angle: [...]             # 各关节默认角度(弧度)
        joint_limits: [...]                    # 各关节机械限位 [min, max] 对
        gravity_z_upper: -0.5                  # 跌倒检测阈值(机体系重力向量Z分量)

启动推理节点需先确保已构建工作空间,然后使用Launch文件:

# 使用默认配置启动
ros2 launch inference inference.launch.py

# 或指定其他配置文件(需修改launch文件中的config路径)
ros2 launch inference inference.launch.py

inference.launch.py默认加载config/inference.yaml,开发者可通过修改configs列表指向inference_beyondmimic.yaml等其他配置,实现不同策略组合的部署。

Sources: ros_interface.cpp, inference.launch.py, config/

安全监控与故障处理

推理节点内置了多层硬性安全检查,所有检查均运行在inference线程中,确保在策略输出到达电机之前拦截危险状态。

1. 跌倒检测与紧急停机

get_gravity_b_obs()通过IMU四元数计算世界重力向量在机体坐标系下的投影。当gravity_b.z()超过gravity_z_upper_阈值(默认-0.5,表示机器人严重倾斜或已摔倒)时,节点立即输出RCLCPP_FATAL并调用rclcpp::shutdown(),终止全部线程与电机通信。

2. 关节限位保护

get_dof_pos_obs()在将关节位置写入观测向量之前,逐一检查各关节是否超出joint_limits_定义的机械限位。一旦检测到超限,节点同样触发致命错误并立即停机,防止机械结构因过行程而损坏。

3. 推理超时告警

inference线程在每次循环结束时精确计算执行耗时。若单次循环耗时超过设定周期,节点输出Inference loop overran警告,提示开发者当前模型或观测预处理存在性能瓶颈,可能需要降低模型复杂度、减少堆叠深度或提升硬件算力。

4. 实时内存锁定

main()函数在节点启动前调用mlockall(MCL_CURRENT | MCL_FUTURE),将进程的全部虚拟地址空间锁定在物理内存中,防止运行时因内存页交换引入非确定性延迟。若锁定失败,节点仅输出警告但仍继续运行,因为部分系统环境可能不支持该调用。

安全机制 触发条件 响应行为 配置参数
跌倒检测 机体系重力Z分量 > 阈值 rclcpp::shutdown() gravity_z_upper
关节限位 关节位置超出机械限位 rclcpp::shutdown() joint_limits
推理超时 单周期执行耗时 > 设定周期 RCLCPP_WARN告警 dt, decimation
调度失败 实时线程无法设置SCHED_FIFO rclcpp::shutdown()

Sources: obs_manager.cpp, obs_manager.cpp, inference_node.cpp

与前后环节的衔接

推理节点并非孤立运行,其观测上游依赖IMU驱动与传感器集成提供的姿态数据,以及电机驱动与CAN总线通信提供的关节反馈;其动作下游通过RobotInterface将MIT控制指令分发至多路CAN总线。若需深入理解RobotInterface内部的电机总线并行通信、IMU坐标变换和闭链解耦计算,请参阅后续文档。

掌握推理节点的架构与配置后,下一步可以了解手柄控制与服务接口中如何通过手柄按键控制推理启停与策略切换,或参考Python SDK与二次开发在高层实现自定义行为编排。若需调整控制频率、平滑系数或配置新的感知观测话题,可直接修改YAML配置文件并重新编译部署。