🤖 roboto_origin_03 Wiki
首页 / 部署 / ONNX 模型加载与实时推理

本页深入解析推理节点(InferenceNode)中 ONNX 模型的加载机制、运行时环境配置以及实时推理循环的实现细节。作为连接强化学习策略与硬件执行的核心枢纽,推理节点采用 ONNX Runtime C++ API 实现跨平台(x64/aarch64)部署,并通过多线程实时调度确保控制频率的确定性。理解本节内容将为后续阅读 观测堆叠与多策略切换动作序列加载与运动策略 奠定必要的代码层面基础。

推理节点整体架构

InferenceNode 是 ROS2 节点,内部封装了 ONNX Runtime 环境、多策略运行时以及硬件接口。每个策略(PolicyRuntime)拥有独立的 ModelContext,包含一个 ONNX Session、输入输出 Tensor 缓冲区以及观测堆叠状态。节点启动时会根据 YAML 配置为所有策略并行完成模型加载与验证,随后进入分离的推理线程与控制线程。

classDiagram
    class InferenceNode {
        +vector~PolicyRuntime~ policies_
        +unique_ptr~Ort::Env~ env_
        +thread inference_thread_
        +thread control_thread_
        +setup_model(ctx, path, input_size)
        +inference()
        +control()
        +update_stacked_obs(...)
    }
    class ModelContext {
        +unique_ptr~Ort::Session~ session
        +unique_ptr~Ort::Value~ input_tensor
        +unique_ptr~Ort::Value~ output_tensor
        +vector~float~ input_buffer
        +vector~float~ output_buffer
        +vector~int64_t~ input_shape
        +vector~int64_t~ output_shape
    }
    class PolicyRuntime {
        +string model_path
        +vector~ObsSourceSpec~ obs_layout
        +int frame_stack
        +ObsStackOrder stack_order
        +unique_ptr~ModelContext~ ctx
    }
    InferenceNode --> PolicyRuntime : 管理 N 个
    PolicyRuntime --> ModelContext : 包含 1 个

推理节点在构造阶段完成所有模型的加载,避免运行时的动态文件 I/O 干扰实时性。构造函数中首先初始化 Ort::Env,随后遍历 policies_ 数组调用 setup_model,最后启动两个实时线程。Sources: inference_node.hpp

ONNX Runtime 环境与会话配置

环境级线程控制

ONNX Runtime 的全局线程行为通过 Ort::ThreadingOptions 配置。代码中暴露 intra_threads 参数(默认值为 1),用于控制算子内的线程并行度。在边缘设备(aarch64)上,通常将其设为 1 以避免推理内部多线程与外部控制线程发生 CPU 抢占,从而保障 SCHED_FIFO 调度的确定性。

Ort::ThreadingOptions thread_opts;
if (intra_threads_ > 0) {
    thread_opts.SetGlobalIntraOpNumThreads(intra_threads_);
}
env_ = std::make_unique<Ort::Env>(thread_opts, ORT_LOGGING_LEVEL_WARNING, "ONNXRuntimeInference");

Sources: inference_node.hpp

会话优化选项

每个 Ort::Session 的创建均伴随严格的性能优化配置:

配置项 设置值 作用说明
DisablePerSessionThreads 禁用会话私有线程池,复用全局线程选项中的线程数
EnableCpuMemArena 启用 CPU 内存池,减少推理过程中的系统分配开销
EnableMemPattern 允许内存复用模式,降低缓存未命中
SetGraphOptimizationLevel ORT_ENABLE_ALL 启用全部图优化(常量折叠、算子融合等)

这些选项的组合确保了单次 Session::Run 的延迟在亚毫秒级,满足高频控制需求。Sources: inference_node.cpp

模型加载与严格校验

setup_model 是模型加载的核心函数,它不仅创建会话,还对模型结构进行多项运行时断言,防止配置与模型实际输入输出不匹配导致未定义行为。

单输入约束与尺寸验证

系统当前仅支持单输入 ONNX 模型。加载时会读取输入张量的形状,将动态 batch 维度 -1 修正为 1,随后计算模型期望的总输入元素个数,并与配置给出的 input_size 进行严格比对。若不一致则立即抛出异常,避免在运行时才发现维度错位。

ctx->num_inputs = ctx->session->GetInputCount();
if (ctx->num_inputs != 1) {
    throw std::runtime_error("Only single-input ONNX models are supported");
}
// ...
size_t model_input_size = 1;
for (size_t i = 0; i < ctx->input_shape.size(); i++) {
    model_input_size *= static_cast<size_t>(ctx->input_shape[i]);
}
if (model_input_size != static_cast<size_t>(input_size)) {
    throw std::runtime_error("ONNX input size mismatch ...");
}

Sources: inference_node.cpp

零拷贝 Tensor 绑定

为了消除推理前后的数据拷贝开销,代码采用 ONNX Runtime 的零拷贝机制:预先分配 input_bufferoutput_buffer,并通过 Ort::Value::CreateTensor 将其直接包装为 Ort::Value。在推理循环中,Session::Run 直接读写这些缓冲区,无需额外的 memcpy

ctx->memory_info = std::make_unique<Ort::MemoryInfo>(
    Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU));

ctx->input_tensor = std::make_unique<Ort::Value>(Ort::Value::CreateTensor<float>(
    *ctx->memory_info, ctx->input_buffer.data(), ctx->input_buffer.size(),
    ctx->input_shape.data(), ctx->input_shape.size()));

Sources: inference_node.cpp

实时推理循环与调度策略

推理节点包含两个独立的高优先级线程:inference 线程负责策略前向传播,control 线程负责将动作平滑后下发至电机。两者均通过 pthread_setschedparam 设置为 SCHED_FIFO 实时调度,优先级均为 70,高于主线程(优先级 50)。

flowchart TD
    A[inference_thread<br/>周期: dt * decimation] --> B[读取传感器 / 更新观测]
    B --> C[观测堆叠与 clamp]
    C --> D[Ort::Session::Run]
    D --> E[动作 clamp / scale / 坐标映射]
    E --> F[写入 act_ 缓冲区]
    G[control_thread<br/>周期: dt] --> H[读取 act_ 缓冲区]
    H --> I[EMA 平滑: act_alpha_]
    I --> J[robot_->apply_action]

推理线程周期

推理线程的运行周期由 dtdecimation 共同决定。以默认配置为例,dt = 0.004 sdecimation = 5,则推理周期为 20 ms(50 Hz)。线程内部通过 std::this_thread::sleep_for 进行周期对齐,并在超时时打印警告。

auto period = std::chrono::microseconds(
    static_cast<long long>(dt_ * 1000 * 1000 * decimation_));
// ...
auto elapsed_time = std::chrono::duration_cast<std::chrono::microseconds>(
    loop_end - loop_start);
auto sleep_time = period - elapsed_time;

Sources: inference_node.cpp

动作平滑与输出管线

推理线程输出原始动作后,并不直接下发电机。control 线程以更高频率(如 250 Hz)执行指数移动平均(EMA)平滑:

last_act_[i] = act_alpha_ * act_[i] + (1 - act_alpha_) * last_act_[i];
robot_->apply_action(last_act_);

act_alpha_ 默认为 1.0,即不启用平滑;降低该值可获得更柔顺的动作过渡。这种双线程解耦设计保证了即使策略推理耗时抖动,电机控制频率仍保持稳定。Sources: inference_node.cpp

观测向量构建与源注册

ONNX 模型的输入观测并非硬编码,而是通过声明式配置动态组装。系统维护一个观测源注册表 obs_source_definitions,将字符串名称映射到成员函数指针,实现配置驱动的观测采集。

观测源注册表

观测源名称 数据维度 采集来源 说明
ang_vel 3 IMU 机体角速度,带 obs_scales_ang_vel_ 缩放
gravity_b 3 IMU 四元数 世界重力向量在机体坐标系下的投影
cmd_vel 3 /joy 或 /cmd_vel 线速度 x/y 与角速度 yaw
dof_pos 23 电机反馈 关节位置经 usd2urdf_ 映射并减去默认角度
dof_vel 23 电机反馈 关节速度经 usd2urdf_ 映射
last_action 23 上帧输出 策略上帧输出的原始动作值
interrupt 1 中断标志 是否处于中断模式的布尔值
perception N /elevation_data 外部感知模块输入(如地形高程)
motion_pos 23 NPZ 动作文件 当前帧的参考关节位置
motion_vel 23 NPZ 动作文件 当前帧的参考关节速度

Sources: obs_manager.cpp

观测段组装流程

配置中的 obs_layouts 字符串(如 "ang_vel:3, gravity_b:3, cmd_vel:3, dof_pos:23, dof_vel:23, last_action:23")在初始化时被解析为 vector<ObsSourceSpec>。运行时,update_obs_segments 通过成员函数指针依次调用对应的 getter,将数据填充到 obs_segments 中;随后 flatten_obs_segments 将其拼接为一维向量。

void InferenceNode::update_obs_segments(
    std::vector<std::vector<float>>& segments,
    const std::vector<ObsSourceSpec>& layout) {
    for (size_t i = 0; i < layout.size(); i++) {
        (this->*(layout[i].source->get))(segments[i]);
    }
}

Sources: obs_manager.cpp

观测堆叠策略

时序策略通常需要过去多帧的观测作为输入。系统支持两种堆叠顺序,通过 obs_stack_orders 配置指定。

FrameMajor(帧优先)

整帧观测按时间顺序连续排列。对于 obs_num = 78frame_stack = 10 的情况,输入张量布局为 [t-9_frame, t-8_frame, ..., t_frame]。首帧时所有历史帧初始化为当前帧;后续帧通过 std::move 实现滑动窗口。

ObsMajor(观测域优先)

每个观测域独立堆叠。例如对于域 [ang_vel(3), gravity_b(3), ...],布局为 [ang_vel_t-9, ..., ang_vel_t, gravity_b_t-9, ..., gravity_b_t, ...]。该模式适用于部分域不需要堆叠或需要不同堆叠深度的场景。

if (is_first_frame) {
    for (int frame = 0; frame < frame_stack; frame++) {
        std::copy(obs.begin(), obs.end(),
                  input_buffer.begin() + frame * obs_num);
    }
} else {
    std::move(input_buffer.begin() + obs_num,
              input_buffer.begin() + frame_stack * obs_num,
              input_buffer.begin());
    std::copy(obs.begin(), obs.end(),
              input_buffer.begin() + (frame_stack - 1) * obs_num);
}

Sources: inference_node.cpp

堆叠模式 适用场景 内存移动特征
FrameMajor 通用时序策略,所有观测域等长堆叠 单次大块 std::move
ObsMajor 混合策略,部分域不堆叠或需差异化处理 按域多次细粒度 std::move

多策略运行时与模型上下文

推理节点原生支持多策略部署,通过 model_names 数组可同时加载多个 ONNX 模型。每个策略拥有独立的 PolicyRuntime 实例,内部包含完整的 ModelContext、观测布局、堆叠配置以及可选的 MotionLoader

flowchart LR
    subgraph 策略运行时 0
        A0[ModelContext<br/>policy.onnx]
        B0[obs_layout 78]
        C0[frame_stack 10]
    end
    subgraph 策略运行时 1
        A1[ModelContext<br/>policy_wave.onnx]
        B1[obs_layout 141]
        C1[frame_stack 1]
    end
    D[active_policy_idx_] -->|切换| A0
    D -->|切换| A1

inference_beyondmimic.yaml 为例,配置中声明了 4 个模型:一个基础策略和三个参考运动策略(wave、dance、punch)。各策略的观测维度、堆叠深度可以完全不同。运行时通过 active_policy_idx_ 选择当前激活的策略,仅对该策略执行 update_obs_segmentsSession::Run。Sources: inference_node.hpp, ros_interface.cpp

推理执行与后处理

单次推理循环的核心调用极为精简,仅为一行 Session::Run

policy.ctx->session->Run(Ort::RunOptions{nullptr},
    policy.ctx->input_names_raw.data(), policy.ctx->input_tensor.get(), policy.ctx->num_inputs,
    policy.ctx->output_names_raw.data(), policy.ctx->output_tensor.get(), policy.ctx->num_outputs);

输出之后依次执行:动作值 clamp(限制在 ±clip_actions_)、USD 到 URDF 的关节坐标映射(usd2urdf_)、动作缩放(action_scale_)与默认角度偏置叠加。若策略支持中断模式(interrupt),还会将中断动作覆盖到输出向量的尾部关节。

for (int i = 0; i < policy.ctx->output_buffer.size(); i++) {
    policy.ctx->output_buffer[i] = std::clamp(policy.ctx->output_buffer[i], -clip_actions_, clip_actions_);
    act_[usd2urdf_[i]] = policy.ctx->output_buffer[i];
    act_[usd2urdf_[i]] = act_[usd2urdf_[i]] * action_scale_ + joint_default_angle_[usd2urdf_[i]];
}

Sources: inference_node.cpp

构建系统与跨平台部署

ONNX Runtime 以预编译库形式嵌入 thirdparty/ 目录,CMake 根据 CMAKE_SYSTEM_PROCESSOR 自动选择 x64 或 aarch64 版本:

if(CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
    set(ONNXRUNTIME_ROOT_DIR ${CMAKE_SOURCE_DIR}/thirdparty/onnxruntime-linux-x64-1.21.0)
elseif(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64")
    set(ONNXRUNTIME_ROOT_DIR ${CMAKE_SOURCE_DIR}/thirdparty/onnxruntime-linux-aarch64-1.21.0)
endif()

推理可执行文件通过 -O3 -march=native 编译,并与 libonnxruntime.so 动态链接。安装阶段会将共享库一同部署到 lib/ 目录,确保目标设备无需额外安装 ONNX Runtime。Sources: CMakeLists.txt

下一步阅读建议

掌握 ONNX 模型加载与实时推理机制后,建议继续阅读 观测堆叠与多策略切换 以深入理解多策略切换的时序与状态管理细节,或阅读 动作序列加载与运动策略 了解 MotionLoader 如何与 ONNX 策略协同实现参考运动跟踪。若需从训练端导出兼容的 ONNX 模型,亦可参考 接入自定义强化学习策略