跳转至

ONNX 模块

本模块记录了 ONNX/WebGPU 集成,用于直接在 GPU 上生成高斯 Splatting 数据,无需 CPU 读回。它为动态点云生成提供仅限 GPU 的推理,支持静态和每帧动态推理模式。

概览

  • 执行提供程序: onnxruntime-web (WebGPU)
  • I/O 绑定: 所有输入和输出均为 GPU 缓冲区(无 CPU 往返)
  • 设备共享: 重用应用的 WebGPU 设备以实现缓冲区兼容性
  • 精度支持: 自动检测和配置 float16, float32, int8, uint8
  • 输出:
  • gaussian: (N, 10) 打包的高斯参数(精度自动检测)
  • shrgb: (N, colorDim) 颜色数据(SH 或 RGB,精度自动检测)
  • num_points: int32[1] 每次推理的实际计数
  • 消费者: DynamicPointCloud + GaussianPreprocessor (num_points GPU 覆盖)

主要特性

  • 纯 GPU 管道: 无 CPU 读回,所有数据停留在 GPU 上
  • 自动精度检测: 从模型元数据或输出名称检测数据类型
  • 静态与动态模式: 支持一次性推理或每帧更新
  • 颜色模式检测: 自动检测 SH vs RGB 输出
  • 容量检测: 从模型元数据自动检测最大点数
  • 图捕获支持: 可选的 WebGPU 图捕获以提高性能
  • 资源管理: 适当清理和销毁 GPU 资源

关注的文件

  • src/onnx/onnx_generator.ts - 高级生成器外观 (Facade)
  • src/onnx/onnx_gpu_io.ts - 底层 GPU I/O 绑定
  • src/onnx/precision-detector.ts - 自动精度检测
  • src/onnx/precision-types.ts - 精度类型定义
  • src/app/managers/onnx-manager.ts - ONNX 模型生命周期管理

快速开始

基本用法(通过 ONNXManager)

import { ONNXManager } from './app/managers/onnx-manager';

const onnxManager = new ONNXManager(modelManager);

// 加载 ONNX 模型(静态推理)
const entry = await onnxManager.loadONNXModel(
  device,
  '/models/gaussians3d.onnx',
  cameraMatrix,
  projectionMatrix,
  'static-model',
  { 
    staticInference: true,
    maxPoints: 2_000_000,
    debugLogging: true
  }
);

// 加载 ONNX 模型(动态每帧推理)
const dynamicEntry = await onnxManager.loadONNXModel(
  device,
  '/models/gaussians3d_dynamic.onnx',
  cameraMatrix,
  projectionMatrix,
  'dynamic-model',
  { 
    staticInference: false,  // 启用每帧更新
    maxPoints: 2_000_000,
    debugLogging: true
  }
);

直接 Generator 用法

import { ONNXGenerator } from './onnx/onnx_generator';

const generator = new ONNXGenerator({
  modelUrl: '/models/gaussians3d.onnx',
  maxPoints: 2_000_000,
  debugLogging: true,
  device: gpuDevice
});

await generator.initialize();
await generator.generate({
  cameraMatrix: viewMatrix,
  projectionMatrix: projMatrix,
  time: performance.now() / 1000
});

// 访问 GPU 缓冲区
const gaussianBuffer = generator.getGaussianBuffer();
const shBuffer = generator.getSHBuffer();
const countBuffer = generator.getCountBuffer();

精度配置

// 手动精度覆盖
const entry = await onnxManager.loadONNXModel(
  device,
  '/models/model.onnx',
  cam, proj, 'model',
  {
    precisionConfig: {
      gaussian: { dataType: 'float32', bytesPerElement: 4 },
      color: { dataType: 'float16', bytesPerElement: 2 }
    }
  }
);

数据流

  1. 加载: ONNXManager.loadONNXModel() 创建 ONNXGeneratorDynamicPointCloud
  2. 初始化: 生成器初始化 ONNX Runtime 会话并预分配 GPU 缓冲区
  3. 推理: ONNXGenerator.generate() 运行推理并写入 GPU 缓冲区
  4. 动态更新: AnimationManager 每帧调用 DynamicPointCloud.update()ONNXGenerator.generate()
  5. 渲染: 预处理读取 countBuffer 以更新实例计数,然后排序并渲染

详情请参阅: ArchitectureAPI Reference

相关文档

  • Architecture – 会话生命周期、缓冲区所有权和精度检测流程。
  • API Reference – 生成器、管理器和精度配置 API 及使用说明。
  • Point Cloud 模块 – 展示 ONNX 输出如何馈送给 DynamicPointCloud
  • Preprocess 模块 – 详述 ONNX 计数和精度标志如何驱动投影。
  • Timeline 模块 – 涵盖动态生成器的每帧动画钩子。
  • Config 模块 – 记录在推理之前如何配置 ORT WASM 路径。