跳转至

ONNX Integration 架构

本文档描述了仅限 GPU 的 ONNX 执行路径,以及它如何在没有 CPU 往返的情况下为高斯 Splatting 提供数据。该架构支持自动精度检测、静态和动态推理模式,以及与渲染管道的无缝集成。

分层架构

ONNXManager (管理层)
├─ 模型生命周期与清理
├─ 与 ModelManager / DynamicPointCloud 的集成
├─ 静态与动态推理控制
└─ 资源跟踪与销毁

ONNXGenerator (生成器层)
├─ OnnxGpuIO 的外观模式封装 (Facade)
├─ 输入/输出缓冲区暴露
├─ 元数据检测 (颜色模式, 容量, 输入名称)
├─ 精度信息访问
└─ 简化的推理 API

OnnxGpuIO (I/O 层)
├─ 仅限 WebGPU 的 InferenceSession (preferredOutputLocation='gpu-buffer')
├─ GPU 缓冲区预分配 (共享设备)
├─ 精度检测与配置
├─ 推理执行与绑定
└─ 图捕获支持 (可选)

PrecisionDetector (检测层)
├─ 从元数据自动检测精度
├─ 基于输出名称的精度推断
└─ 量化参数提取

核心组件

  • ONNXManager (src/app/managers/onnx-manager.ts)
  • 通过 ONNXGenerator 加载 ONNX 模型并创建专用的 DynamicPointCloud 实例
  • 检测颜色模式(RGB vs SH)并将元数据转发给点云
  • 支持 staticInference 和通过 AnimationManager 进行的动态每帧推理
  • 管理资源生命周期并暴露诊断信息
  • 支持通过 PrecisionConfig 进行精度配置
  • 跟踪生成器和点云以便正确清理

  • ONNXGenerator (src/onnx/onnx_generator.ts)

  • 提供 initialize(device?)generate(inputs?) 的薄外观(facade)
  • 暴露 GPU 缓冲区(getGaussianBuffer(), getSHBuffer(), getCountBuffer()
  • 提供元数据访问(容量,颜色模式,颜色维度,输入名称)
  • 暴露精度信息(getGaussianPrecision(), getColorPrecision()
  • 自动检测模型输入/输出命名约定

  • OnnxGpuIO (src/onnx/onnx_gpu_io.ts)

  • 使用应用的 GPUDevice(与渲染器共享)构建仅限 WebGPU 的 ORT InferenceSession
  • 为所有输入和输出预分配具有适当精度的 GPU 缓冲区
  • 通过 PrecisionDetector 实现精度检测
  • 实现 runInference(),通过 Tensor.fromGpuBuffer(...) 将 GPU 缓冲区绑定为 feeds/fetches,并在不接触 CPU 的情况下执行会话
  • 支持 WebGPU 图捕获(如果不支持则回退)
  • 实现独占执行链以防止并发推理冲突

  • PrecisionDetector (src/onnx/precision-detector.ts)

  • 从模型元数据自动检测数据类型(float16, float32, int8, uint8)
  • 回退到基于输出名称的检测(例如,_f16, _f32, _i8 后缀)
  • 提取量化模型的量化参数(scale, zeroPoint)
  • 计算 16 字节对齐的缓冲区大小

设备共享(关键)

initWebGPU_onnx() 获取 ORT 的设备——或者强制 ORT 通过虚拟会话实例化一个——以便渲染器和 ORT 使用相同的 GPUDeviceOnnxGpuIO.init({ device }) 随后在该设备上分配每个缓冲区,允许预处理/排序/渲染阶段直接使用 ONNX 输出。

GPU I/O 绑定

预分配输出

根据检测到的精度和维度分配输出缓冲区:

  • gaussBuf: [maxPoints, 10] 打包的高斯数据
  • 精度: 自动检测(默认 float16,2 字节/元素)
  • 用法: STORAGE | COPY_SRC | COPY_DST | VERTEX
  • 大小: 使用 16 字节对齐计算

  • shBuf: [maxPoints, colorDim] 颜色数据(SH 或 RGB)

  • 精度: 自动检测(默认 float16,2 字节/元素)
  • colorDim: SH 为 48,RGB 为 3(从元数据自动检测)
  • 用法: STORAGE | COPY_SRC | COPY_DST | VERTEX
  • 大小: 使用 16 字节对齐计算

  • countBuf: 用于动态点计数的单个 i32

  • 用法: STORAGE | COPY_SRC | COPY_DST
  • 大小: 16 字节(对齐)

预分配输入

为图捕获兼容性分配输入缓冲区:

  • cameraMatrixBuf: 4×4 float32 矩阵(64 字节,16 字节对齐)
  • 用法: STORAGE | COPY_DST

  • projMatrixBuf: 4×4 float32 矩阵(64 字节,16 字节对齐)

  • 用法: STORAGE | COPY_DST

  • timeBuf: 单个 float32 值(16 字节,对齐)

  • 用法: STORAGE | COPY_DST

在运行时,runInference() 将这些缓冲区绑定在 feeds/fetches 映射中;ORT 直接将输出张量写入 GPU 内存,无需 CPU 往返。

精度检测

系统自动从以下来源检测精度: 1. 模型元数据: ONNX Runtime 会话输出元数据 2. 输出名称: 基于后缀的检测(_f16, _f32, _i8, _u8) 3. 手动覆盖: ONNXLoadOptions 中的 PrecisionConfig

支持的精度: - float16 (2 字节/元素) - 默认,最常见 - float32 (4 字节/元素) - 更高精度 - int8 (1 字节/元素) - 量化模型 - uint8 (1 字节/元素) - 量化模型

数据流

相机 / 时间 → updateInputBuffers() → session.run(feeds, fetches)
                          [gaussBuf, shBuf, countBuf]
DynamicPointCloud (GPU buffers) → 预处理 (读取 countBuf) → 排序 → 渲染
  • 预处理尊重 countBuf(当存在时)以在排序前更新 ModelParams.num_points
  • 排序和渲染随后发出单个间接绘制(indirect draw),实例计数(instanceCount)来源于排序器 uniforms。

静态 vs 动态推理

静态模式 (staticInference: true)

  • 在初始化期间调用一次 generate({})
  • 生成的 DynamicPointCloud 表现得像静态模型
  • 无每帧更新
  • 适用于预计算或一次性生成的模型

动态模式 (staticInference: false)

  • 加载期间进行初始 generate({ cameraMatrix, projectionMatrix, time })
  • 通过 AnimationManager.updateDynamicPointClouds() 进行每帧 generate() 调用
  • 使用更新的相机矩阵和时间来流式传输动画数据
  • DynamicPointCloudONNXGenerator 连接以实现自动更新
  • 启用实时神经渲染和动画

集成流程:

  1. ONNXManager.loadONNXModel() 设置 staticInference: false
  2. DynamicPointCloud.setOnnxGenerator(gen) 连接生成器
  3. AnimationManager.updateDynamicPointClouds() 调用 dpc.update()
  4. DynamicPointCloud.update() 调用带有当前相机/时间的 gen.generate()
  5. 新数据直接流向 GPU 缓冲区以进行渲染

诊断与要求

会话创建

  • 独占使用 WebGPU 执行提供程序
  • 如果 URL 加载失败,从 URL 加载回退到 ArrayBuffer 获取
  • 支持 WebGPU 图捕获(如果不支持则自动回退)
  • 初始化失败时抛出描述性错误

调试工具

  • ONNXTestUtils 提供性能测量和验证
  • ONNXModelTester 用于隔离模型测试
  • 调试助手可以在开发构建中读回 countBuf 以验证模型输出计数
  • 通过 getGaussianPrecision()getColorPrecision() 获取精度信息

要求

  • ONNX Runtime WASM: 确保提供 ONNX Runtime wasm 资产(默认 /src/ort/
  • WebGPU 支持: 需要支持 WebGPU 的浏览器和 GPU
  • 共享设备: 使用应用的 WebGPU 设备以实现缓冲区兼容性
  • 模型格式: ONNX 模型必须输出具有兼容形状的高斯和颜色张量

独占执行

OnnxGpuIO 实现了全局执行链(runExclusive)以防止并发推理冲突。所有推理调用都被序列化,以避免 ORT WebGPU IOBinding 会话冲突。

精度系统

ONNX 模块包含一个全面的精度检测和配置系统:

PrecisionMetadata

interface PrecisionMetadata {
  dataType: 'float32' | 'float16' | 'int8' | 'uint8';
  bytesPerElement: number;
  scale?: number;      // 用于量化的 int8/uint8
  zeroPoint?: number;  // 用于量化的 int8/uint8
}

PrecisionConfig

interface PrecisionConfig {
  gaussian?: Partial<PrecisionMetadata>;  // 高斯输出的覆盖配置
  color?: Partial<PrecisionMetadata>;     // 颜色输出的覆盖配置
  autoDetect?: boolean;                    // 遗留标志(已弃用)
}

检测策略

  1. 元数据优先: 检查 ONNX Runtime 会话输出元数据中的类型信息
  2. 名称回退: 使用输出名称后缀(_f16, _f32, _i8, _u8
  3. 手动覆盖: 如果提供,则应用 PrecisionConfig
  4. 默认: 如果检测失败,回退到 float16

缓冲区大小计算

所有缓冲区均按 16 字节对齐计算:

sizeInBytes = align16(elements * bytesPerElement)

这确保了最佳的 GPU 内存访问模式。