本页深入解析推理节点(InferenceNode)中 ONNX 模型的加载机制、运行时环境配置以及实时推理循环的实现细节。作为连接强化学习策略与硬件执行的核心枢纽,推理节点采用 ONNX Runtime C++ API 实现跨平台(x64/aarch64)部署,并通过多线程实时调度确保控制频率的确定性。理解本节内容将为后续阅读 观测堆叠与多策略切换 和 动作序列加载与运动策略 奠定必要的代码层面基础。
推理节点整体架构
InferenceNode 是 ROS2 节点,内部封装了 ONNX Runtime 环境、多策略运行时以及硬件接口。每个策略(PolicyRuntime)拥有独立的 ModelContext,包含一个 ONNX Session、输入输出 Tensor 缓冲区以及观测堆叠状态。节点启动时会根据 YAML 配置为所有策略并行完成模型加载与验证,随后进入分离的推理线程与控制线程。
classDiagram
class InferenceNode {
+vector~PolicyRuntime~ policies_
+unique_ptr~Ort::Env~ env_
+thread inference_thread_
+thread control_thread_
+setup_model(ctx, path, input_size)
+inference()
+control()
+update_stacked_obs(...)
}
class ModelContext {
+unique_ptr~Ort::Session~ session
+unique_ptr~Ort::Value~ input_tensor
+unique_ptr~Ort::Value~ output_tensor
+vector~float~ input_buffer
+vector~float~ output_buffer
+vector~int64_t~ input_shape
+vector~int64_t~ output_shape
}
class PolicyRuntime {
+string model_path
+vector~ObsSourceSpec~ obs_layout
+int frame_stack
+ObsStackOrder stack_order
+unique_ptr~ModelContext~ ctx
}
InferenceNode --> PolicyRuntime : 管理 N 个
PolicyRuntime --> ModelContext : 包含 1 个
推理节点在构造阶段完成所有模型的加载,避免运行时的动态文件 I/O 干扰实时性。构造函数中首先初始化 Ort::Env,随后遍历 policies_ 数组调用 setup_model,最后启动两个实时线程。Sources: inference_node.hpp
ONNX Runtime 环境与会话配置
环境级线程控制
ONNX Runtime 的全局线程行为通过 Ort::ThreadingOptions 配置。代码中暴露 intra_threads 参数(默认值为 1),用于控制算子内的线程并行度。在边缘设备(aarch64)上,通常将其设为 1 以避免推理内部多线程与外部控制线程发生 CPU 抢占,从而保障 SCHED_FIFO 调度的确定性。
Ort::ThreadingOptions thread_opts;
if (intra_threads_ > 0) {
thread_opts.SetGlobalIntraOpNumThreads(intra_threads_);
}
env_ = std::make_unique<Ort::Env>(thread_opts, ORT_LOGGING_LEVEL_WARNING, "ONNXRuntimeInference");
Sources: inference_node.hpp
会话优化选项
每个 Ort::Session 的创建均伴随严格的性能优化配置:
| 配置项 | 设置值 | 作用说明 |
|---|---|---|
DisablePerSessionThreads |
— | 禁用会话私有线程池,复用全局线程选项中的线程数 |
EnableCpuMemArena |
— | 启用 CPU 内存池,减少推理过程中的系统分配开销 |
EnableMemPattern |
— | 允许内存复用模式,降低缓存未命中 |
SetGraphOptimizationLevel |
ORT_ENABLE_ALL |
启用全部图优化(常量折叠、算子融合等) |
这些选项的组合确保了单次 Session::Run 的延迟在亚毫秒级,满足高频控制需求。Sources: inference_node.cpp
模型加载与严格校验
setup_model 是模型加载的核心函数,它不仅创建会话,还对模型结构进行多项运行时断言,防止配置与模型实际输入输出不匹配导致未定义行为。
单输入约束与尺寸验证
系统当前仅支持单输入 ONNX 模型。加载时会读取输入张量的形状,将动态 batch 维度 -1 修正为 1,随后计算模型期望的总输入元素个数,并与配置给出的 input_size 进行严格比对。若不一致则立即抛出异常,避免在运行时才发现维度错位。
ctx->num_inputs = ctx->session->GetInputCount();
if (ctx->num_inputs != 1) {
throw std::runtime_error("Only single-input ONNX models are supported");
}
// ...
size_t model_input_size = 1;
for (size_t i = 0; i < ctx->input_shape.size(); i++) {
model_input_size *= static_cast<size_t>(ctx->input_shape[i]);
}
if (model_input_size != static_cast<size_t>(input_size)) {
throw std::runtime_error("ONNX input size mismatch ...");
}
Sources: inference_node.cpp
零拷贝 Tensor 绑定
为了消除推理前后的数据拷贝开销,代码采用 ONNX Runtime 的零拷贝机制:预先分配 input_buffer 和 output_buffer,并通过 Ort::Value::CreateTensor 将其直接包装为 Ort::Value。在推理循环中,Session::Run 直接读写这些缓冲区,无需额外的 memcpy。
ctx->memory_info = std::make_unique<Ort::MemoryInfo>(
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU));
ctx->input_tensor = std::make_unique<Ort::Value>(Ort::Value::CreateTensor<float>(
*ctx->memory_info, ctx->input_buffer.data(), ctx->input_buffer.size(),
ctx->input_shape.data(), ctx->input_shape.size()));
Sources: inference_node.cpp
实时推理循环与调度策略
推理节点包含两个独立的高优先级线程:inference 线程负责策略前向传播,control 线程负责将动作平滑后下发至电机。两者均通过 pthread_setschedparam 设置为 SCHED_FIFO 实时调度,优先级均为 70,高于主线程(优先级 50)。
flowchart TD
A[inference_thread<br/>周期: dt * decimation] --> B[读取传感器 / 更新观测]
B --> C[观测堆叠与 clamp]
C --> D[Ort::Session::Run]
D --> E[动作 clamp / scale / 坐标映射]
E --> F[写入 act_ 缓冲区]
G[control_thread<br/>周期: dt] --> H[读取 act_ 缓冲区]
H --> I[EMA 平滑: act_alpha_]
I --> J[robot_->apply_action]
推理线程周期
推理线程的运行周期由 dt 与 decimation 共同决定。以默认配置为例,dt = 0.004 s,decimation = 5,则推理周期为 20 ms(50 Hz)。线程内部通过 std::this_thread::sleep_for 进行周期对齐,并在超时时打印警告。
auto period = std::chrono::microseconds(
static_cast<long long>(dt_ * 1000 * 1000 * decimation_));
// ...
auto elapsed_time = std::chrono::duration_cast<std::chrono::microseconds>(
loop_end - loop_start);
auto sleep_time = period - elapsed_time;
Sources: inference_node.cpp
动作平滑与输出管线
推理线程输出原始动作后,并不直接下发电机。control 线程以更高频率(如 250 Hz)执行指数移动平均(EMA)平滑:
last_act_[i] = act_alpha_ * act_[i] + (1 - act_alpha_) * last_act_[i];
robot_->apply_action(last_act_);
act_alpha_ 默认为 1.0,即不启用平滑;降低该值可获得更柔顺的动作过渡。这种双线程解耦设计保证了即使策略推理耗时抖动,电机控制频率仍保持稳定。Sources: inference_node.cpp
观测向量构建与源注册
ONNX 模型的输入观测并非硬编码,而是通过声明式配置动态组装。系统维护一个观测源注册表 obs_source_definitions,将字符串名称映射到成员函数指针,实现配置驱动的观测采集。
观测源注册表
| 观测源名称 | 数据维度 | 采集来源 | 说明 |
|---|---|---|---|
ang_vel |
3 | IMU | 机体角速度,带 obs_scales_ang_vel_ 缩放 |
gravity_b |
3 | IMU 四元数 | 世界重力向量在机体坐标系下的投影 |
cmd_vel |
3 | /joy 或 /cmd_vel | 线速度 x/y 与角速度 yaw |
dof_pos |
23 | 电机反馈 | 关节位置经 usd2urdf_ 映射并减去默认角度 |
dof_vel |
23 | 电机反馈 | 关节速度经 usd2urdf_ 映射 |
last_action |
23 | 上帧输出 | 策略上帧输出的原始动作值 |
interrupt |
1 | 中断标志 | 是否处于中断模式的布尔值 |
perception |
N | /elevation_data | 外部感知模块输入(如地形高程) |
motion_pos |
23 | NPZ 动作文件 | 当前帧的参考关节位置 |
motion_vel |
23 | NPZ 动作文件 | 当前帧的参考关节速度 |
Sources: obs_manager.cpp
观测段组装流程
配置中的 obs_layouts 字符串(如 "ang_vel:3, gravity_b:3, cmd_vel:3, dof_pos:23, dof_vel:23, last_action:23")在初始化时被解析为 vector<ObsSourceSpec>。运行时,update_obs_segments 通过成员函数指针依次调用对应的 getter,将数据填充到 obs_segments 中;随后 flatten_obs_segments 将其拼接为一维向量。
void InferenceNode::update_obs_segments(
std::vector<std::vector<float>>& segments,
const std::vector<ObsSourceSpec>& layout) {
for (size_t i = 0; i < layout.size(); i++) {
(this->*(layout[i].source->get))(segments[i]);
}
}
Sources: obs_manager.cpp
观测堆叠策略
时序策略通常需要过去多帧的观测作为输入。系统支持两种堆叠顺序,通过 obs_stack_orders 配置指定。
FrameMajor(帧优先)
整帧观测按时间顺序连续排列。对于 obs_num = 78、frame_stack = 10 的情况,输入张量布局为 [t-9_frame, t-8_frame, ..., t_frame]。首帧时所有历史帧初始化为当前帧;后续帧通过 std::move 实现滑动窗口。
ObsMajor(观测域优先)
每个观测域独立堆叠。例如对于域 [ang_vel(3), gravity_b(3), ...],布局为 [ang_vel_t-9, ..., ang_vel_t, gravity_b_t-9, ..., gravity_b_t, ...]。该模式适用于部分域不需要堆叠或需要不同堆叠深度的场景。
if (is_first_frame) {
for (int frame = 0; frame < frame_stack; frame++) {
std::copy(obs.begin(), obs.end(),
input_buffer.begin() + frame * obs_num);
}
} else {
std::move(input_buffer.begin() + obs_num,
input_buffer.begin() + frame_stack * obs_num,
input_buffer.begin());
std::copy(obs.begin(), obs.end(),
input_buffer.begin() + (frame_stack - 1) * obs_num);
}
Sources: inference_node.cpp
| 堆叠模式 | 适用场景 | 内存移动特征 |
|---|---|---|
| FrameMajor | 通用时序策略,所有观测域等长堆叠 | 单次大块 std::move |
| ObsMajor | 混合策略,部分域不堆叠或需差异化处理 | 按域多次细粒度 std::move |
多策略运行时与模型上下文
推理节点原生支持多策略部署,通过 model_names 数组可同时加载多个 ONNX 模型。每个策略拥有独立的 PolicyRuntime 实例,内部包含完整的 ModelContext、观测布局、堆叠配置以及可选的 MotionLoader。
flowchart LR
subgraph 策略运行时 0
A0[ModelContext<br/>policy.onnx]
B0[obs_layout 78]
C0[frame_stack 10]
end
subgraph 策略运行时 1
A1[ModelContext<br/>policy_wave.onnx]
B1[obs_layout 141]
C1[frame_stack 1]
end
D[active_policy_idx_] -->|切换| A0
D -->|切换| A1
以 inference_beyondmimic.yaml 为例,配置中声明了 4 个模型:一个基础策略和三个参考运动策略(wave、dance、punch)。各策略的观测维度、堆叠深度可以完全不同。运行时通过 active_policy_idx_ 选择当前激活的策略,仅对该策略执行 update_obs_segments 与 Session::Run。Sources: inference_node.hpp, ros_interface.cpp
推理执行与后处理
单次推理循环的核心调用极为精简,仅为一行 Session::Run:
policy.ctx->session->Run(Ort::RunOptions{nullptr},
policy.ctx->input_names_raw.data(), policy.ctx->input_tensor.get(), policy.ctx->num_inputs,
policy.ctx->output_names_raw.data(), policy.ctx->output_tensor.get(), policy.ctx->num_outputs);
输出之后依次执行:动作值 clamp(限制在 ±clip_actions_)、USD 到 URDF 的关节坐标映射(usd2urdf_)、动作缩放(action_scale_)与默认角度偏置叠加。若策略支持中断模式(interrupt),还会将中断动作覆盖到输出向量的尾部关节。
for (int i = 0; i < policy.ctx->output_buffer.size(); i++) {
policy.ctx->output_buffer[i] = std::clamp(policy.ctx->output_buffer[i], -clip_actions_, clip_actions_);
act_[usd2urdf_[i]] = policy.ctx->output_buffer[i];
act_[usd2urdf_[i]] = act_[usd2urdf_[i]] * action_scale_ + joint_default_angle_[usd2urdf_[i]];
}
Sources: inference_node.cpp
构建系统与跨平台部署
ONNX Runtime 以预编译库形式嵌入 thirdparty/ 目录,CMake 根据 CMAKE_SYSTEM_PROCESSOR 自动选择 x64 或 aarch64 版本:
if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
set(ONNXRUNTIME_ROOT_DIR ${CMAKE_SOURCE_DIR}/thirdparty/onnxruntime-linux-x64-1.21.0)
elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64")
set(ONNXRUNTIME_ROOT_DIR ${CMAKE_SOURCE_DIR}/thirdparty/onnxruntime-linux-aarch64-1.21.0)
endif()
推理可执行文件通过 -O3 -march=native 编译,并与 libonnxruntime.so 动态链接。安装阶段会将共享库一同部署到 lib/ 目录,确保目标设备无需额外安装 ONNX Runtime。Sources: CMakeLists.txt
下一步阅读建议
掌握 ONNX 模型加载与实时推理机制后,建议继续阅读 观测堆叠与多策略切换 以深入理解多策略切换的时序与状态管理细节,或阅读 动作序列加载与运动策略 了解 MotionLoader 如何与 ONNX 策略协同实现参考运动跟踪。若需从训练端导出兼容的 ONNX 模型,亦可参考 接入自定义强化学习策略。