ONNX Module Architecture

This document describes the GPU-only ONNX execution path and how it feeds Gaussian Splatting without CPU round trips. The architecture supports automatic precision detection, static and dynamic inference modes, and seamless integration with the rendering pipeline.

Layered Architecture

ONNXManager (management layer)
├─ Model lifecycle & cleanup
├─ Integration with ModelManager / DynamicPointCloud
├─ Static vs dynamic inference control
└─ Resource tracking and disposal

ONNXGenerator (generator layer)
├─ Facade over OnnxGpuIO
├─ Input/output buffer exposure
├─ Metadata detection (color mode, capacity, input names)
├─ Precision information access
└─ Simplified API for inference

OnnxGpuIO (I/O layer)
├─ WebGPU-only InferenceSession (preferredOutputLocation='gpu-buffer')
├─ GPU buffer preallocation (shared device)
├─ Precision detection and configuration
├─ Inference execution & binding
└─ Graph capture support (optional)

PrecisionDetector (detection layer)
├─ Automatic precision detection from metadata
├─ Output name-based precision inference
└─ Quantization parameter extraction

Core Components

  • ONNXManager (src/app/managers/onnx-manager.ts)
      • Loads ONNX models through ONNXGenerator and creates dedicated DynamicPointCloud instances
      • Detects color mode (RGB vs SH) and forwards metadata to the point cloud
      • Supports staticInference and dynamic per-frame inference via AnimationManager
      • Manages resource lifetime and exposes diagnostics
      • Supports precision configuration via PrecisionConfig (see the loading sketch below)
      • Tracks generators and point clouds for proper cleanup
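
A minimal loading sketch. The method name loadONNXModel and the staticInference flag come from this document; the exact options shape, model path, and dispose() call are assumptions for illustration:

// Sketch only: the options shape and dispose() are illustrative, not the exact API.
const manager = new ONNXManager();
const pointCloud = await manager.loadONNXModel('/models/example.onnx', {
  staticInference: false,                            // dynamic per-frame inference
  precision: { gaussian: { dataType: 'float16' } },  // optional PrecisionConfig override
});
// When the model is no longer needed, the manager releases
// its tracked generators and point clouds:
manager.dispose();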

  • ONNXGenerator (src/onnx/onnx_generator.ts)
      • Thin façade providing initialize(device?) and generate(inputs?) (usage sketched below)
      • Exposes GPU buffers (getGaussianBuffer(), getSHBuffer(), getCountBuffer())
      • Provides metadata access (capacity, color mode, color dimensions, input names)
      • Exposes precision information (getGaussianPrecision(), getColorPrecision())
      • Automatically detects model input/output naming conventions
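
A usage sketch built from the methods listed above; the constructor argument and input key names are assumptions:

// Sketch: generator lifecycle using the getters named above.
const gen = new ONNXGenerator('/models/example.onnx');   // constructor args assumed
await gen.initialize(device);                            // optional shared GPUDevice
await gen.generate({ cameraMatrix, projectionMatrix, time });
const gaussBuf: GPUBuffer = gen.getGaussianBuffer();     // [maxPoints, 10] packed Gaussians
const shBuf: GPUBuffer = gen.getSHBuffer();              // SH or RGB color data
const countBuf: GPUBuffer = gen.getCountBuffer();        // single i32 point count
console.log(gen.getGaussianPrecision().dataType);        // e.g. 'float16'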

  • OnnxGpuIO (src/onnx/onnx_gpu_io.ts)
      • Builds a WebGPU-only ORT InferenceSession using the app's GPUDevice, shared with the renderer (see the session sketch below)
      • Preallocates GPU buffers for all inputs and outputs with the proper precision
      • Implements precision detection via PrecisionDetector
      • Implements runInference(), which binds GPU buffers as feeds/fetches via Tensor.fromGpuBuffer(...) and executes the session without touching the CPU
      • Supports WebGPU graph capture (with fallback if unsupported)
      • Implements an exclusive execution chain to prevent concurrent inference conflicts
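
A hedged sketch of the session setup. executionProviders and preferredOutputLocation are real onnxruntime-web session options; the import path assumes the webgpu bundle, and modelUrl stands for the model's URL:

import * as ort from 'onnxruntime-web/webgpu';

// Sketch: WebGPU-only session whose output tensors stay in GPU memory.
const session = await ort.InferenceSession.create(modelUrl, {
  executionProviders: ['webgpu'],
  preferredOutputLocation: 'gpu-buffer',  // bind outputs to GPUBuffers, no CPU copy
});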

  • PrecisionDetector (src/onnx/precision-detector.ts)
      • Automatically detects data types (float16, float32, int8, uint8) from model metadata
      • Falls back to output name-based detection (e.g., _f16, _f32, _i8 suffixes)
      • Extracts quantization parameters (scale, zeroPoint) for quantized models
      • Calculates buffer sizes with 16-byte alignment

Device sharing (critical)

initWebGPU_onnx() acquires ORT's device (or forces ORT to materialize one via a dummy session) so that the renderer and ORT use the same GPUDevice. OnnxGpuIO.init({ device }) then allocates every buffer on that device, allowing the preprocess/sort/render stages to consume ONNX outputs directly.
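
A sketch of the dummy-session approach. ort.env.webgpu.device is how onnxruntime-web exposes its device once the WebGPU backend is initialized; the dummy model path is an assumption:

import * as ort from 'onnxruntime-web/webgpu';

// Sketch: force ORT to initialize its WebGPU backend, then share its device.
async function initWebGPU_onnx(): Promise<GPUDevice> {
  // Any tiny valid model works here; the path is illustrative.
  await ort.InferenceSession.create('/models/dummy.onnx', {
    executionProviders: ['webgpu'],
  });
  const device = ort.env.webgpu.device as GPUDevice;  // ORT's device, now materialized
  return device;  // shared by the renderer and passed to OnnxGpuIO.init({ device })
}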

GPU I/O Binding

Preallocated Outputs

Output buffers are allocated based on detected precision and dimensions:

  • gaussBuf: [maxPoints, 10] packed Gaussian data
      • Precision: auto-detected (default float16, 2 bytes/element)
      • Usage: STORAGE | COPY_SRC | COPY_DST | VERTEX
      • Size: calculated with 16-byte alignment (allocation sketched below)

  • shBuf: [maxPoints, colorDim] color data (SH or RGB)
      • Precision: auto-detected (default float16, 2 bytes/element)
      • colorDim: 48 for SH, 3 for RGB (auto-detected from metadata)
      • Usage: STORAGE | COPY_SRC | COPY_DST | VERTEX
      • Size: calculated with 16-byte alignment

  • countBuf: single i32 for dynamic point counts
      • Usage: STORAGE | COPY_SRC | COPY_DST
      • Size: 16 bytes (aligned)
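
A sketch of how an output buffer can be preallocated from the detected precision; align16 mirrors the Buffer Size Calculation section, and everything else follows the listing above:

// Sketch: allocate gaussBuf from detected precision and model capacity.
const align16 = (n: number) => Math.ceil(n / 16) * 16;
const bytesPerElement = 2;  // e.g. detected float16
const gaussBuf = device.createBuffer({
  size: align16(maxPoints * 10 * bytesPerElement),
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC |
         GPUBufferUsage.COPY_DST | GPUBufferUsage.VERTEX,
});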

Preallocated Inputs

Input buffers are allocated for graph capture compatibility:

  • cameraMatrixBuf: 4×4 float32 matrix (64 bytes, 16-byte aligned)
      • Usage: STORAGE | COPY_DST

  • projMatrixBuf: 4×4 float32 matrix (64 bytes, 16-byte aligned)
      • Usage: STORAGE | COPY_DST

  • timeBuf: single float32 value (16 bytes, aligned)
      • Usage: STORAGE | COPY_DST

At runtime, runInference() binds these buffers in the feeds/fetches map; ORT writes output tensors straight into GPU memory with no CPU roundtrips.
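
A sketch of that binding step, assuming the session and buffers from the earlier sketches. Tensor.fromGpuBuffer and the fetches form of session.run are standard onnxruntime-web; the tensor names are illustrative, since real names come from metadata detection:

import * as ort from 'onnxruntime-web/webgpu';

// Sketch: bind preallocated GPU buffers as feeds (inputs) and fetches (outputs).
const feeds = {
  camera_matrix: ort.Tensor.fromGpuBuffer(cameraMatrixBuf, { dataType: 'float32', dims: [4, 4] }),
  proj_matrix: ort.Tensor.fromGpuBuffer(projMatrixBuf, { dataType: 'float32', dims: [4, 4] }),
  time: ort.Tensor.fromGpuBuffer(timeBuf, { dataType: 'float32', dims: [1] }),
};
const fetches = {
  gaussians: ort.Tensor.fromGpuBuffer(gaussBuf, { dataType: 'float16', dims: [maxPoints, 10] }),
  colors: ort.Tensor.fromGpuBuffer(shBuf, { dataType: 'float16', dims: [maxPoints, colorDim] }),
  count: ort.Tensor.fromGpuBuffer(countBuf, { dataType: 'int32', dims: [1] }),
};
await session.run(feeds, fetches);  // outputs land directly in gaussBuf/shBuf/countBuf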

Precision Detection

The system automatically detects precision from:

  1. Model Metadata: ONNX Runtime session output metadata
  2. Output Names: suffix-based detection (_f16, _f32, _i8, _u8)
  3. Manual Override: PrecisionConfig in ONNXLoadOptions

Supported precisions:

  • float16 (2 bytes/element): default, most common
  • float32 (4 bytes/element): higher precision
  • int8 (1 byte/element): quantized models
  • uint8 (1 byte/element): quantized models

Data flow

Camera / Time → updateInputBuffers() → session.run(feeds, fetches)
                                              ↓
                            [gaussBuf, shBuf, countBuf]
                                              ↓
DynamicPointCloud (GPU buffers) → Preprocess (reads countBuf) → Sort → Render

  • Preprocess honours countBuf (when present) to update ModelParams.num_points before sorting (one approach sketched below).
  • Sorting and rendering then issue a single indirect draw with instanceCount sourced from the sorter uniforms.
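
A sketch of one way preprocess can honour countBuf without a CPU readback; the destination buffer name and offsets are assumptions (the real implementation may instead read countBuf directly in a compute shader):

// Sketch: copy the model-written point count into the buffer the sorter reads,
// keeping the whole update on the GPU.
const encoder = device.createCommandEncoder();
encoder.copyBufferToBuffer(
  countBuf, 0,           // source: i32 point count written by the model
  sorterUniformsBuf, 0,  // destination: slot consumed as num_points / instanceCount
  4,                     // one i32
);
device.queue.submit([encoder.finish()]);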

Static vs Dynamic Inference

Static Mode (staticInference: true)

  • Calls generate({}) once during initialization
  • The resulting DynamicPointCloud behaves like a static model
  • No per-frame updates
  • Useful for pre-computed or one-time generated models

Dynamic Mode (staticInference: false)

  • Initial generate({ cameraMatrix, projectionMatrix, time }) during loading
  • Per-frame generate() calls via AnimationManager.updateDynamicPointClouds()
  • Uses updated camera matrices and time to stream animation data
  • DynamicPointCloud is wired to ONNXGenerator for automatic updates
  • Enables real-time neural rendering and animation

Integration Flow (sketched in code below):

  1. ONNXManager.loadONNXModel() with staticInference: false
  2. DynamicPointCloud.setOnnxGenerator(gen) wires the generator
  3. AnimationManager.updateDynamicPointClouds() calls dpc.update()
  4. DynamicPointCloud.update() calls gen.generate() with current camera/time
  5. New data flows directly to GPU buffers for rendering
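
The same flow as a sketch; the method names come from the steps above, while the model path and surrounding objects are assumptions:

// Sketch of the dynamic path, mirroring steps 1-5.
const dpc = await onnxManager.loadONNXModel('/models/animated.onnx', {
  staticInference: false,    // step 1
});
dpc.setOnnxGenerator(gen);   // step 2 (the manager does this wiring in practice)

// Per frame, e.g. inside the render loop:
animationManager.updateDynamicPointClouds();  // step 3 → dpc.update() → gen.generate(...)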

Diagnostics & Requirements

Session Creation

  • Uses WebGPU execution provider exclusively
  • Falls back from URL loading to ArrayBuffer fetch if URL fails
  • Supports WebGPU graph capture (with automatic fallback if unsupported)
  • Throws descriptive errors when initialization fails

Debugging Tools

  • ONNXTestUtils provides performance measurement and validation
  • ONNXModelTester for isolated model testing
  • Debug helpers can read back countBuf in development builds to verify model output counts
  • Precision information available via getGaussianPrecision() and getColorPrecision()
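
A sketch of the development-build readback mentioned above, using standard WebGPU buffer mapping (the helper's name and placement are assumptions):

// Sketch: dev-only readback of countBuf to verify the model's point count.
async function readCount(device: GPUDevice, countBuf: GPUBuffer): Promise<number> {
  const staging = device.createBuffer({
    size: 4,
    usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
  });
  const enc = device.createCommandEncoder();
  enc.copyBufferToBuffer(countBuf, 0, staging, 0, 4);
  device.queue.submit([enc.finish()]);
  await staging.mapAsync(GPUMapMode.READ);
  const count = new Int32Array(staging.getMappedRange())[0];
  staging.unmap();
  staging.destroy();
  return count;
}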

Requirements

  • ONNX Runtime WASM: Ensure ONNX Runtime wasm assets are served (default /src/ort/)
  • WebGPU Support: Requires WebGPU-capable browser and GPU
  • Shared Device: Uses app's WebGPU device for buffer compatibility
  • Model Format: ONNX models must output gaussian and color tensors with compatible shapes

Exclusive Execution

OnnxGpuIO implements a global execution chain (runExclusive) to prevent concurrent inference conflicts. All inference calls are serialized to avoid ORT WebGPU IOBinding session conflicts.
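
A minimal sketch of the serialization pattern (the real runExclusive lives in OnnxGpuIO; this is the standard promise-chain form):

// Sketch: serialize all inference calls on a single global promise chain.
let chain: Promise<void> = Promise.resolve();

function runExclusive<T>(task: () => Promise<T>): Promise<T> {
  const result = chain.then(() => task());
  // Keep the chain alive even if this task rejects.
  chain = result.then(() => undefined, () => undefined);
  return result;
}

// Every inference goes through the chain, so two session.run calls never overlap:
// await runExclusive(() => session.run(feeds, fetches));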

Precision System

The ONNX module includes a comprehensive precision detection and configuration system:

PrecisionMetadata

interface PrecisionMetadata {
  dataType: 'float32' | 'float16' | 'int8' | 'uint8';
  bytesPerElement: number;
  scale?: number;      // For quantized int8/uint8
  zeroPoint?: number;  // For quantized int8/uint8
}

PrecisionConfig

interface PrecisionConfig {
  gaussian?: Partial<PrecisionMetadata>;  // Override for gaussian output
  color?: Partial<PrecisionMetadata>;     // Override for color output
  autoDetect?: boolean;                    // Legacy flag (deprecated)
}

Detection Strategy

  1. Metadata First: Checks ONNX Runtime session output metadata for type information
  2. Name Fallback: Uses output name suffixes (_f16, _f32, _i8, _u8)
  3. Manual Override: Applies PrecisionConfig if provided
  4. Default: Falls back to float16 if detection fails
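
A sketch mirroring steps 1-4; PrecisionDetector's real logic lives in src/onnx/precision-detector.ts:

type DataType = 'float32' | 'float16' | 'int8' | 'uint8';

// Step 2: output name-suffix fallback.
function detectFromName(outputName: string): DataType | undefined {
  if (outputName.endsWith('_f16')) return 'float16';
  if (outputName.endsWith('_f32')) return 'float32';
  if (outputName.endsWith('_i8'))  return 'int8';
  if (outputName.endsWith('_u8'))  return 'uint8';
  return undefined;
}

function resolvePrecision(
  metadataType: DataType | undefined,  // step 1: session output metadata
  outputName: string,                  // step 2: suffix fallback
  override?: DataType,                 // step 3: manual PrecisionConfig
): DataType {
  return override ?? metadataType ?? detectFromName(outputName) ?? 'float16';  // step 4
}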

Buffer Size Calculation

All buffers are calculated with 16-byte alignment:

sizeInBytes = align16(elements * bytesPerElement)

This ensures optimal GPU memory access patterns.
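
Worked example: countBuf holds a single i32 (4 bytes), so align16(1 * 4) = 16, matching the 16-byte size listed under Preallocated Outputs; a float16 gaussBuf with an assumed maxPoints of 100 000 comes to align16(100000 * 10 * 2) = 2 000 000 bytes.

// Sketch of align16 and the two examples above.
const align16 = (n: number) => Math.ceil(n / 16) * 16;
console.log(align16(1 * 4));             // 16       (countBuf)
console.log(align16(100_000 * 10 * 2));  // 2000000  (float16 gaussBuf)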