ONNX Module Architecture
This document describes the GPU-only ONNX execution path and how it feeds Gaussian Splatting without CPU round trips. The architecture supports automatic precision detection, static and dynamic inference modes, and seamless integration with the rendering pipeline.
Layered Architecture
ONNXManager (management layer)
├─ Model lifecycle & cleanup
├─ Integration with ModelManager / DynamicPointCloud
├─ Static vs dynamic inference control
└─ Resource tracking and disposal
ONNXGenerator (generator layer)
├─ Facade over OnnxGpuIO
├─ Input/output buffer exposure
├─ Metadata detection (color mode, capacity, input names)
├─ Precision information access
└─ Simplified API for inference
OnnxGpuIO (I/O layer)
├─ WebGPU-only InferenceSession (preferredOutputLocation='gpu-buffer')
├─ GPU buffer preallocation (shared device)
├─ Precision detection and configuration
├─ Inference execution & binding
└─ Graph capture support (optional)
PrecisionDetector (detection layer)
├─ Automatic precision detection from metadata
├─ Output name-based precision inference
└─ Quantization parameter extraction
Core Components
- ONNXManager (src/app/managers/onnx-manager.ts)
  - Loads ONNX models through ONNXGenerator and creates dedicated DynamicPointCloud instances
  - Detects color mode (RGB vs SH) and forwards metadata to the point cloud
  - Supports staticInference and dynamic per-frame inference via AnimationManager
  - Manages resource lifetime and exposes diagnostics
  - Supports precision configuration via PrecisionConfig
  - Tracks generators and point clouds for proper cleanup
- ONNXGenerator (src/onnx/onnx_generator.ts)
  - Thin façade providing initialize(device?) and generate(inputs?)
  - Exposes GPU buffers (getGaussianBuffer(), getSHBuffer(), getCountBuffer())
  - Provides metadata access (capacity, color mode, color dimensions, input names)
  - Exposes precision information (getGaussianPrecision(), getColorPrecision())
  - Automatically detects model input/output naming conventions
- OnnxGpuIO (src/onnx/onnx_gpu_io.ts)
  - Builds a WebGPU-only ORT InferenceSession using the app's GPUDevice (shared with the renderer)
  - Preallocates GPU buffers for all inputs and outputs with proper precision
  - Implements precision detection via PrecisionDetector
  - Implements runInference(), which binds GPU buffers as feeds/fetches via Tensor.fromGpuBuffer(...) and executes the session without touching the CPU
  - Supports WebGPU graph capture (with fallback if unsupported)
  - Implements an exclusive execution chain to prevent concurrent inference conflicts
- PrecisionDetector (src/onnx/precision-detector.ts)
  - Automatically detects data types (float16, float32, int8, uint8) from model metadata
  - Falls back to output name-based detection (e.g., _f16, _f32, _i8 suffixes)
  - Extracts quantization parameters (scale, zeroPoint) for quantized models
  - Calculates buffer sizes with 16-byte alignment
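A minimal usage sketch of how these layers compose. The loadONNXModel call and the staticInference option come from this document; the return shape and the getOnnxGenerator accessor are illustrative assumptions, not the actual API.

```ts
// Hedged sketch: the management layer drives the generator and I/O layers.
// getOnnxGenerator is a hypothetical accessor for illustration.
declare const onnxManager: {
  loadONNXModel(
    url: string,
    opts: { staticInference: boolean },
  ): Promise<{
    getOnnxGenerator(): {
      getGaussianBuffer(): GPUBuffer;
      getSHBuffer(): GPUBuffer;
      getCountBuffer(): GPUBuffer;
    };
  }>;
};

const dpc = await onnxManager.loadONNXModel('/models/scene.onnx', {
  staticInference: true, // one generate({}) at load time, then static
});

// Downstream render stages consume the generator's GPU buffers directly.
const gen = dpc.getOnnxGenerator();
const gaussians: GPUBuffer = gen.getGaussianBuffer();
const colors: GPUBuffer = gen.getSHBuffer();
```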
Device sharing (critical)
initWebGPU_onnx() acquires ORT’s device—or forces ORT to materialize one via a dummy session—so the renderer and ORT use the same GPUDevice. OnnxGpuIO.init({ device }) then allocates every buffer on that device, allowing preprocess/sort/render stages to consume ONNX outputs directly.
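A sketch of this step, assuming an onnxruntime-web build where ort.env.webgpu.device exposes the execution provider's GPUDevice; the dummy model path is a placeholder.

```ts
import * as ort from 'onnxruntime-web/webgpu';

// Sketch of initWebGPU_onnx(): reuse ORT's GPUDevice, or force ORT to
// materialize one via a throwaway session. '/models/dummy.onnx' is a
// placeholder path, not a real asset.
async function initWebGPU_onnx(): Promise<GPUDevice> {
  let device = ort.env.webgpu.device as GPUDevice | undefined;
  if (!device) {
    const dummy = await ort.InferenceSession.create('/models/dummy.onnx', {
      executionProviders: ['webgpu'],
    });
    device = ort.env.webgpu.device as GPUDevice;
    await dummy.release();
  }
  return device; // the renderer allocates all of its buffers on this device
}
```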
GPU I/O Binding
Preallocated Outputs
Output buffers are allocated based on detected precision and dimensions:
- gaussBuf: [maxPoints, 10] packed Gaussian data
  - Precision: Auto-detected (default float16, 2 bytes/element)
  - Usage: STORAGE | COPY_SRC | COPY_DST | VERTEX
  - Size: Calculated with 16-byte alignment
- shBuf: [maxPoints, colorDim] color data (SH or RGB)
  - Precision: Auto-detected (default float16, 2 bytes/element)
  - colorDim: 48 for SH, 3 for RGB (auto-detected from metadata)
  - Usage: STORAGE | COPY_SRC | COPY_DST | VERTEX
  - Size: Calculated with 16-byte alignment
- countBuf: Single i32 for dynamic point counts
  - Usage: STORAGE | COPY_SRC | COPY_DST
  - Size: 16 bytes (aligned)
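For illustration, a sketch of the gaussBuf allocation under the flags and alignment above; createGaussBuf is a hypothetical helper, not code from the repository.

```ts
// Sketch: allocate the packed Gaussian output buffer per the rules above.
function createGaussBuf(
  device: GPUDevice,
  maxPoints: number,
  bytesPerElement: number, // 2 for float16 (default), 4 for float32
): GPUBuffer {
  const align16 = (n: number) => Math.ceil(n / 16) * 16;
  return device.createBuffer({
    size: align16(maxPoints * 10 * bytesPerElement), // [maxPoints, 10] packed
    usage:
      GPUBufferUsage.STORAGE |
      GPUBufferUsage.COPY_SRC |
      GPUBufferUsage.COPY_DST |
      GPUBufferUsage.VERTEX,
  });
}
```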
Preallocated Inputs
Input buffers are allocated for graph capture compatibility:
- cameraMatrixBuf: 4×4 float32 matrix (64 bytes, 16-byte aligned)
  - Usage: STORAGE | COPY_DST
- projMatrixBuf: 4×4 float32 matrix (64 bytes, 16-byte aligned)
  - Usage: STORAGE | COPY_DST
- timeBuf: Single float32 value (16 bytes, aligned)
  - Usage: STORAGE | COPY_DST
At runtime, runInference() binds these buffers in the feeds/fetches map; ORT writes output tensors straight into GPU memory with no CPU roundtrips.
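A sketch of the binding step, using onnxruntime-web's Tensor.fromGpuBuffer and the two-argument session.run(feeds, fetches). The tensor names here (camera_matrix, time, gaussians_f16) are placeholders; real names are auto-detected as described above.

```ts
import * as ort from 'onnxruntime-web/webgpu';

// Sketch of GPU I/O binding: wrap preallocated GPUBuffers as ORT tensors
// and run the session with preallocated outputs.
async function runInference(
  session: ort.InferenceSession,
  cameraMatrixBuf: GPUBuffer,
  timeBuf: GPUBuffer,
  gaussBuf: GPUBuffer,
  maxPoints: number,
): Promise<void> {
  const feeds: Record<string, ort.Tensor> = {
    camera_matrix: ort.Tensor.fromGpuBuffer(cameraMatrixBuf, {
      dataType: 'float32',
      dims: [4, 4],
    }),
    time: ort.Tensor.fromGpuBuffer(timeBuf, { dataType: 'float32', dims: [1] }),
  };
  const fetches: Record<string, ort.Tensor> = {
    gaussians_f16: ort.Tensor.fromGpuBuffer(gaussBuf, {
      dataType: 'float16',
      dims: [maxPoints, 10],
    }),
  };
  // ORT writes the output tensor straight into gaussBuf; nothing crosses the CPU.
  await session.run(feeds, fetches);
}
```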
Precision Detection
The system automatically detects precision from:
1. Model Metadata: ONNX Runtime session output metadata
2. Output Names: Suffix-based detection (_f16, _f32, _i8, _u8)
3. Manual Override: PrecisionConfig in ONNXLoadOptions
Supported precisions:
- float16 (2 bytes/element) - Default, most common
- float32 (4 bytes/element) - Higher precision
- int8 (1 byte/element) - Quantized models
- uint8 (1 byte/element) - Quantized models
Data flow
Camera / Time → updateInputBuffers() → session.run(feeds, fetches)
↓
[gaussBuf, shBuf, countBuf]
↓
DynamicPointCloud (GPU buffers) → Preprocess (reads countBuf) → Sort → Render
- Preprocess honours countBuf (when present) to update ModelParams.num_points before sorting.
- Sorting and rendering then issue a single indirect draw, with instanceCount sourced from the sorter uniforms (see the sketch below).
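For illustration, one way a dynamic count can reach an indirect draw; this is a sketch only, since in this pipeline instanceCount actually comes from the sorter uniforms as noted above.

```ts
// Hypothetical sketch: copy the model's point count into the indirect-draw
// arguments. Offset 4 is the instanceCount slot of WebGPU's
// [vertexCount, instanceCount, firstVertex, firstInstance] layout.
function updateIndirectArgs(
  encoder: GPUCommandEncoder,
  countBuf: GPUBuffer,    // single i32 written by the ONNX model
  indirectBuf: GPUBuffer, // 16-byte indirect-draw argument buffer
): void {
  encoder.copyBufferToBuffer(countBuf, 0, indirectBuf, 4, 4);
}
```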
Static vs Dynamic Inference
Static Mode (staticInference: true)
- Calls generate({}) once during initialization
- The resulting DynamicPointCloud behaves like a static model
- No per-frame updates
- Useful for pre-computed or one-time generated models
Dynamic Mode (staticInference: false)
- Initial generate({ cameraMatrix, projectionMatrix, time }) during loading
- Per-frame generate() calls via AnimationManager.updateDynamicPointClouds()
- Uses updated camera matrices and time to stream animation data
- DynamicPointCloud is wired to ONNXGenerator for automatic updates
- Enables real-time neural rendering and animation
Integration Flow:
1. ONNXManager.loadONNXModel() with staticInference: false
2. DynamicPointCloud.setOnnxGenerator(gen) wires the generator
3. AnimationManager.updateDynamicPointClouds() calls dpc.update()
4. DynamicPointCloud.update() calls gen.generate() with current camera/time
5. New data flows directly to GPU buffers for rendering
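A sketch of this flow inside a frame loop; the manager shapes declared below are assumptions for illustration.

```ts
// Assumed manager shapes, for illustration only.
declare const onnxManager: {
  loadONNXModel(url: string, opts: { staticInference: boolean }): Promise<unknown>;
};
declare const animationManager: { updateDynamicPointClouds(): void };

await onnxManager.loadONNXModel('/models/animated.onnx', {
  staticInference: false, // steps 1–2: load and wire the generator
});

function frame(): void {
  // Steps 3–5: each update() triggers generate() with the current camera
  // matrices and time, refreshing the GPU buffers in place for rendering.
  animationManager.updateDynamicPointClouds();
  requestAnimationFrame(frame);
}
requestAnimationFrame(frame);
```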
Diagnostics & Requirements
Session Creation
- Uses WebGPU execution provider exclusively
- Falls back from URL loading to ArrayBuffer fetch if URL fails
- Supports WebGPU graph capture (with automatic fallback if unsupported)
- Throws descriptive errors when initialization fails
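A sketch of session creation per the bullets above. preferredOutputLocation and enableGraphCapture are onnxruntime-web session options; treat the exact retry shape as an assumption about this codebase.

```ts
import * as ort from 'onnxruntime-web/webgpu';

// Sketch: WebGPU-only session with GPU-resident outputs and graph capture,
// falling back to an ArrayBuffer fetch without capture if creation fails.
async function createSession(url: string): Promise<ort.InferenceSession> {
  const options: ort.InferenceSession.SessionOptions = {
    executionProviders: ['webgpu'],
    preferredOutputLocation: 'gpu-buffer',
    enableGraphCapture: true,
  };
  try {
    return await ort.InferenceSession.create(url, options);
  } catch {
    // URL loading failed (or graph capture is unsupported): fetch the model
    // bytes ourselves and retry without capture.
    const bytes = new Uint8Array(await (await fetch(url)).arrayBuffer());
    return await ort.InferenceSession.create(bytes, {
      ...options,
      enableGraphCapture: false,
    });
  }
}
```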
Debugging Tools
- ONNXTestUtils provides performance measurement and validation
- ONNXModelTester for isolated model testing
- Debug helpers can read back countBuf in development builds to verify model output counts
- Precision information available via getGaussianPrecision() and getColorPrecision()
Requirements
- ONNX Runtime WASM: Ensure ONNX Runtime wasm assets are served (default /src/ort/)
- WebGPU Support: Requires a WebGPU-capable browser and GPU
- Shared Device: Uses the app's WebGPU device for buffer compatibility
- Model Format: ONNX models must output gaussian and color tensors with compatible shapes
Exclusive Execution
OnnxGpuIO implements a global execution chain (runExclusive) to prevent concurrent inference conflicts. All inference calls are serialized to avoid ORT WebGPU IOBinding session conflicts.
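A minimal sketch of such a chain; the real runExclusive may differ in detail, but the serialization idea is the same.

```ts
// Global promise chain: each task starts only after the previous one
// finishes, whether it resolved or rejected.
let chain: Promise<void> = Promise.resolve();

function runExclusive<T>(task: () => Promise<T>): Promise<T> {
  const result = chain.then(task);
  // Swallow errors on the chain itself so one failed inference does not
  // block every later call; callers still see the rejection via `result`.
  chain = result.then(
    () => undefined,
    () => undefined,
  );
  return result;
}
```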
Precision System
The ONNX module includes a comprehensive precision detection and configuration system:
PrecisionMetadata
```ts
interface PrecisionMetadata {
  dataType: 'float32' | 'float16' | 'int8' | 'uint8';
  bytesPerElement: number;
  scale?: number;     // For quantized int8/uint8
  zeroPoint?: number; // For quantized int8/uint8
}
```
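As an example of how the quantization fields can be applied on a CPU readback (useful for debugging; the GPU path would dequantize in shaders), a hypothetical helper:

```ts
// Dequantize int8/uint8 values using the scale/zeroPoint from
// PrecisionMetadata: value = (q - zeroPoint) * scale.
function dequantize(
  q: Uint8Array | Int8Array,
  meta: PrecisionMetadata,
): Float32Array {
  const scale = meta.scale ?? 1;
  const zeroPoint = meta.zeroPoint ?? 0;
  return Float32Array.from(q, (v) => (v - zeroPoint) * scale);
}
```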
PrecisionConfig
```ts
interface PrecisionConfig {
  gaussian?: Partial<PrecisionMetadata>; // Override for gaussian output
  color?: Partial<PrecisionMetadata>;    // Override for color output
  autoDetect?: boolean;                  // Legacy flag (deprecated)
}
```
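For example, an override forcing float32 gaussian output; the field names follow the interface above, and how the config is passed through ONNXLoadOptions is assumed.

```ts
// Hypothetical override: skip auto-detection for the gaussian output and
// treat it as float32 (4 bytes/element). Passed via ONNXLoadOptions when
// loading the model.
const config: PrecisionConfig = {
  gaussian: { dataType: 'float32', bytesPerElement: 4 },
};
```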
Detection Strategy
- Metadata First: Checks ONNX Runtime session output metadata for type information
- Name Fallback: Uses output name suffixes (_f16, _f32, _i8, _u8)
- Manual Override: Applies PrecisionConfig if provided
- Default: Falls back to float16 if detection fails
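A sketch of the name-fallback step; precisionFromName is a hypothetical helper name.

```ts
type DataType = 'float32' | 'float16' | 'int8' | 'uint8';

// Map output-name suffixes to data types; undefined falls through to the
// metadata result or the float16 default.
function precisionFromName(name: string): DataType | undefined {
  if (name.endsWith('_f16')) return 'float16';
  if (name.endsWith('_f32')) return 'float32';
  if (name.endsWith('_i8')) return 'int8';
  if (name.endsWith('_u8')) return 'uint8';
  return undefined;
}
```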
Buffer Size Calculation
All buffer sizes are rounded up to 16-byte alignment, which ensures optimal GPU memory access patterns.
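The helper below illustrates the rule; alignTo16 and the sizing example are hypothetical, not the codebase's actual helpers.

```ts
// Hypothetical helper: round a byte size up to the next multiple of 16.
function alignTo16(bytes: number): number {
  return Math.ceil(bytes / 16) * 16;
}

// e.g. gaussBuf for 100k float16 points:
// 100_000 points × 10 elements × 2 bytes = 2_000_000 bytes (already aligned).
const gaussBufSize = alignTo16(100_000 * 10 * 2);
```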