ONNX Integration 架构
本文档描述了仅限 GPU 的 ONNX 执行路径,以及它如何在没有 CPU 往返的情况下为高斯 Splatting 提供数据。该架构支持自动精度检测、静态和动态推理模式,以及与渲染管道的无缝集成。
分层架构
ONNXManager (管理层)
├─ 模型生命周期与清理
├─ 与 ModelManager / DynamicPointCloud 的集成
├─ 静态与动态推理控制
└─ 资源跟踪与销毁
ONNXGenerator (生成器层)
├─ OnnxGpuIO 的外观模式封装 (Facade)
├─ 输入/输出缓冲区暴露
├─ 元数据检测 (颜色模式, 容量, 输入名称)
├─ 精度信息访问
└─ 简化的推理 API
OnnxGpuIO (I/O 层)
├─ 仅限 WebGPU 的 InferenceSession (preferredOutputLocation='gpu-buffer')
├─ GPU 缓冲区预分配 (共享设备)
├─ 精度检测与配置
├─ 推理执行与绑定
└─ 图捕获支持 (可选)
PrecisionDetector (检测层)
├─ 从元数据自动检测精度
├─ 基于输出名称的精度推断
└─ 量化参数提取
核心组件
- ONNXManager (
src/app/managers/onnx-manager.ts) - 通过
ONNXGenerator加载 ONNX 模型并创建专用的DynamicPointCloud实例 - 检测颜色模式(RGB vs SH)并将元数据转发给点云
- 支持
staticInference和通过AnimationManager进行的动态每帧推理 - 管理资源生命周期并暴露诊断信息
- 支持通过
PrecisionConfig进行精度配置 -
跟踪生成器和点云以便正确清理
-
ONNXGenerator (
src/onnx/onnx_generator.ts) - 提供
initialize(device?)和generate(inputs?)的薄外观(facade) - 暴露 GPU 缓冲区(
getGaussianBuffer(),getSHBuffer(),getCountBuffer()) - 提供元数据访问(容量,颜色模式,颜色维度,输入名称)
- 暴露精度信息(
getGaussianPrecision(),getColorPrecision()) -
自动检测模型输入/输出命名约定
-
OnnxGpuIO (
src/onnx/onnx_gpu_io.ts) - 使用应用的
GPUDevice(与渲染器共享)构建仅限 WebGPU 的 ORTInferenceSession - 为所有输入和输出预分配具有适当精度的 GPU 缓冲区
- 通过
PrecisionDetector实现精度检测 - 实现
runInference(),通过Tensor.fromGpuBuffer(...)将 GPU 缓冲区绑定为 feeds/fetches,并在不接触 CPU 的情况下执行会话 - 支持 WebGPU 图捕获(如果不支持则回退)
-
实现独占执行链以防止并发推理冲突
-
PrecisionDetector (
src/onnx/precision-detector.ts) - 从模型元数据自动检测数据类型(float16, float32, int8, uint8)
- 回退到基于输出名称的检测(例如,
_f16,_f32,_i8后缀) - 提取量化模型的量化参数(scale, zeroPoint)
- 计算 16 字节对齐的缓冲区大小
设备共享(关键)
initWebGPU_onnx() 获取 ORT 的设备——或者强制 ORT 通过虚拟会话实例化一个——以便渲染器和 ORT 使用相同的 GPUDevice。OnnxGpuIO.init({ device }) 随后在该设备上分配每个缓冲区,允许预处理/排序/渲染阶段直接使用 ONNX 输出。
GPU I/O 绑定
预分配输出
根据检测到的精度和维度分配输出缓冲区:
gaussBuf:[maxPoints, 10]打包的高斯数据- 精度: 自动检测(默认 float16,2 字节/元素)
- 用法:
STORAGE | COPY_SRC | COPY_DST | VERTEX -
大小: 使用 16 字节对齐计算
-
shBuf:[maxPoints, colorDim]颜色数据(SH 或 RGB) - 精度: 自动检测(默认 float16,2 字节/元素)
colorDim: SH 为 48,RGB 为 3(从元数据自动检测)- 用法:
STORAGE | COPY_SRC | COPY_DST | VERTEX -
大小: 使用 16 字节对齐计算
-
countBuf: 用于动态点计数的单个i32 - 用法:
STORAGE | COPY_SRC | COPY_DST - 大小: 16 字节(对齐)
预分配输入
为图捕获兼容性分配输入缓冲区:
cameraMatrixBuf: 4×4 float32 矩阵(64 字节,16 字节对齐)-
用法:
STORAGE | COPY_DST -
projMatrixBuf: 4×4 float32 矩阵(64 字节,16 字节对齐) -
用法:
STORAGE | COPY_DST -
timeBuf: 单个 float32 值(16 字节,对齐) - 用法:
STORAGE | COPY_DST
在运行时,runInference() 将这些缓冲区绑定在 feeds/fetches 映射中;ORT 直接将输出张量写入 GPU 内存,无需 CPU 往返。
精度检测
系统自动从以下来源检测精度:
1. 模型元数据: ONNX Runtime 会话输出元数据
2. 输出名称: 基于后缀的检测(_f16, _f32, _i8, _u8)
3. 手动覆盖: ONNXLoadOptions 中的 PrecisionConfig
支持的精度:
- float16 (2 字节/元素) - 默认,最常见
- float32 (4 字节/元素) - 更高精度
- int8 (1 字节/元素) - 量化模型
- uint8 (1 字节/元素) - 量化模型
数据流
相机 / 时间 → updateInputBuffers() → session.run(feeds, fetches)
↓
[gaussBuf, shBuf, countBuf]
↓
DynamicPointCloud (GPU buffers) → 预处理 (读取 countBuf) → 排序 → 渲染
- 预处理尊重
countBuf(当存在时)以在排序前更新ModelParams.num_points。 - 排序和渲染随后发出单个间接绘制(indirect draw),实例计数(instanceCount)来源于排序器 uniforms。
静态 vs 动态推理
静态模式 (staticInference: true)
- 在初始化期间调用一次
generate({}) - 生成的
DynamicPointCloud表现得像静态模型 - 无每帧更新
- 适用于预计算或一次性生成的模型
动态模式 (staticInference: false)
- 加载期间进行初始
generate({ cameraMatrix, projectionMatrix, time }) - 通过
AnimationManager.updateDynamicPointClouds()进行每帧generate()调用 - 使用更新的相机矩阵和时间来流式传输动画数据
DynamicPointCloud与ONNXGenerator连接以实现自动更新- 启用实时神经渲染和动画
集成流程:
ONNXManager.loadONNXModel()设置staticInference: falseDynamicPointCloud.setOnnxGenerator(gen)连接生成器AnimationManager.updateDynamicPointClouds()调用dpc.update()DynamicPointCloud.update()调用带有当前相机/时间的gen.generate()- 新数据直接流向 GPU 缓冲区以进行渲染
诊断与要求
会话创建
- 独占使用 WebGPU 执行提供程序
- 如果 URL 加载失败,从 URL 加载回退到 ArrayBuffer 获取
- 支持 WebGPU 图捕获(如果不支持则自动回退)
- 初始化失败时抛出描述性错误
调试工具
ONNXTestUtils提供性能测量和验证ONNXModelTester用于隔离模型测试- 调试助手可以在开发构建中读回
countBuf以验证模型输出计数 - 通过
getGaussianPrecision()和getColorPrecision()获取精度信息
要求
- ONNX Runtime WASM: 确保提供 ONNX Runtime wasm 资产(默认
/src/ort/) - WebGPU 支持: 需要支持 WebGPU 的浏览器和 GPU
- 共享设备: 使用应用的 WebGPU 设备以实现缓冲区兼容性
- 模型格式: ONNX 模型必须输出具有兼容形状的高斯和颜色张量
独占执行
OnnxGpuIO 实现了全局执行链(runExclusive)以防止并发推理冲突。所有推理调用都被序列化,以避免 ORT WebGPU IOBinding 会话冲突。
精度系统
ONNX 模块包含一个全面的精度检测和配置系统:
PrecisionMetadata
interface PrecisionMetadata {
dataType: 'float32' | 'float16' | 'int8' | 'uint8';
bytesPerElement: number;
scale?: number; // 用于量化的 int8/uint8
zeroPoint?: number; // 用于量化的 int8/uint8
}
PrecisionConfig
interface PrecisionConfig {
gaussian?: Partial<PrecisionMetadata>; // 高斯输出的覆盖配置
color?: Partial<PrecisionMetadata>; // 颜色输出的覆盖配置
autoDetect?: boolean; // 遗留标志(已弃用)
}
检测策略
- 元数据优先: 检查 ONNX Runtime 会话输出元数据中的类型信息
- 名称回退: 使用输出名称后缀(
_f16,_f32,_i8,_u8) - 手动覆盖: 如果提供,则应用
PrecisionConfig - 默认: 如果检测失败,回退到 float16
缓冲区大小计算
所有缓冲区均按 16 字节对齐计算:
这确保了最佳的 GPU 内存访问模式。