Shaders Module Architecture
This document drills into the WGSL programs that run on the GPU. It focuses on how each shader is structured, how data flows between stages, and which binding groups/resources they expect.
Stage diagram
Gaussians (GPU buffers)
  ├─ preprocess.wgsl        -> splat buffer + sort keys/payloads + indirect counters
  ├─ radix_sort.wgsl        -> sorted payloads (indices) + updated indirect dispatch
  ├─ gaussian.wgsl (VS/FS)  -> color buffer (premultiplied RGBA)
  └─ compositor (display shader lives outside this module)
Utility compute kernels (compress_gaussians.wgsl, convert_precision.wgsl, debug-helpers.wgsl) work on the same buffer layouts but are invoked on demand by tooling and diagnostics.
Binding layout summary
Group 0 : Camera uniforms (view/proj + viewport/focal)
Group 1 : Point cloud data (packed gaussians, SH/RAW color, splat buffer)
Group 2 : Sorting buffers (SortInfos, depth keys, payloads, dispatch)
Group 3 : Render/model settings (RenderSettings, ModelParams)
The renderer forwards ModelParams (transform, base offset, per-model scaling, precision metadata) in group 3 when calling preprocess. gaussian.wgsl only needs groups 0 and 1 (camera + splats) plus the sorted payload indices.
Preprocess (preprocess.wgsl)
Key responsibilities
- Read packed Gaussians; format depends on uModel.gaussDataType (0=f32,1=f16,2=int8,3=uint8).
- Apply camera matrices to generate clip-space positions.
- Transform the 3×3 covariance into a 2×2 screen-space ellipse (via Jacobian J and camera rotation W).
- Evaluate color either via spherical harmonics (evaluate_sh) or direct RGB fetch (USE_RAW_COLOR).
- Write a Splat struct to the shared buffer and depth keys/payload indices to the sorter buffers.
- Increment sort_infos.keys_size, drawIndirect.instance_count, and DispatchIndirect.dispatch_x atomically.
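The covariance step can be cross-checked on the CPU. This TypeScript sketch assumes the standard EWA-splatting form Σ' = J·W·Σ·Wᵀ·Jᵀ, with W the rotation part of the view matrix, J the Jacobian of the pinhole projection (fx·x/z, fy·y/z) evaluated at the view-space center, focal lengths in pixels, and a camera looking down +z. projectCovariance and the row-major array layout are illustrative choices, not taken from preprocess.wgsl.

```typescript
type Mat = number[][]; // row-major: m[row][col]
type Vec3 = [number, number, number];

// Plain matrix product (rows of a) x (cols of b).
function mul(a: Mat, b: Mat): Mat {
  const out: Mat = [];
  for (let i = 0; i < a.length; i++) {
    out.push(new Array(b[0].length).fill(0));
    for (let k = 0; k < b.length; k++) {
      for (let j = 0; j < b[0].length; j++) out[i][j] += a[i][k] * b[k][j];
    }
  }
  return out;
}

function transpose(a: Mat): Mat {
  return a[0].map((_, j) => a.map((row) => row[j]));
}

/** Projects a 3x3 world-space covariance to a 2x2 screen-space covariance (pixels^2). */
function projectCovariance(cov3d: Mat, viewRot: Mat, tView: Vec3, focal: [number, number]): Mat {
  const [x, y, z] = tView;
  const [fx, fy] = focal;
  // Jacobian of (fx*x/z, fy*y/z) with respect to (x, y, z).
  const J: Mat = [
    [fx / z, 0, (-fx * x) / (z * z)],
    [0, fy / z, (-fy * y) / (z * z)],
  ];
  const T = mul(J, viewRot); // 2x3: world direction -> screen offset
  return mul(mul(T, cov3d), transpose(T)); // 2x2 = T * Sigma * T^T
}
```

Because the result is symmetric, a shader only needs three of the four 2×2 entries.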
Important structs
```wgsl
struct CameraUniforms { view: mat4x4<f32>, view_inv: mat4x4<f32>, proj: mat4x4<f32>, proj_inv: mat4x4<f32>, viewport: vec2<f32>, focal: vec2<f32>, }
struct ModelParams { model: mat4x4<f32>, baseOffset: u32, num_points: u32, gaussianScaling: f32, maxShDeg: u32, /* ... precision fields */ }
struct Splat { v_0: u32, v_1: u32, pos: u32, posz: f32, color_0: u32, color_1: u32, }
```
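Most Splat fields are u32, which suggests pairs of 16-bit floats packed the way WGSL's pack2x16float does it (first component in the low 16 bits). That encoding is an assumption, not something this document states. Under it, a CPU-side test could fill a splat buffer with a helper like the hypothetical pack2x16float below, which emulates the WGSL builtin of the same name.

```typescript
/** float -> IEEE 754 binary16 bits. Rounds to nearest (ties away from zero); flushes subnormals to zero. */
function floatToHalf(f: number): number {
  const dv = new DataView(new ArrayBuffer(4));
  dv.setFloat32(0, f);
  const bits = dv.getUint32(0);
  const sign = (bits >>> 16) & 0x8000;
  const exp32 = (bits >>> 23) & 0xff;
  const mant32 = bits & 0x7fffff;
  if (exp32 === 0xff) return sign | 0x7c00 | (mant32 ? 0x200 : 0); // inf or NaN
  const exp = exp32 - 127 + 15; // rebias exponent
  if (exp >= 0x1f) return sign | 0x7c00; // too large: infinity
  if (exp <= 0) return sign; // too small: signed zero (simplification)
  let half = sign | (exp << 10) | (mant32 >>> 13); // keep the top 10 mantissa bits
  if (mant32 & 0x1000) half += 1; // round on the first dropped bit; a carry correctly bumps the exponent
  return half;
}

/** Emulates WGSL pack2x16float: a in bits 0..15, b in bits 16..31. */
function pack2x16float(a: number, b: number): number {
  return (floatToHalf(a) | (floatToHalf(b) << 16)) >>> 0;
}
```

The `>>> 0` keeps the result an unsigned 32-bit value, matching a u32 slot in a Uint32Array.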
uModel.num_points is overwritten on the GPU by a buffer copy from the ONNX count buffers, so preprocess can return early when gid.x >= num_points.
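Patching num_points with a GPU copy requires its byte offset inside the ModelParams buffer. This sketch assumes the field order shown in the struct (model first) and standard WGSL alignment (mat4x4<f32>: size 64, align 16; u32 and f32: size 4, align 4), which places num_points at byte 68. It also assumes the ONNX count is stored as a 32-bit unsigned integer. CopyEncoder is a minimal stand-in whose method has the same signature as GPUCommandEncoder.copyBufferToBuffer; the field table and patchNumPoints are hypothetical.

```typescript
// Structural stand-in for GPUCommandEncoder (same copyBufferToBuffer signature).
interface CopyEncoder {
  copyBufferToBuffer(src: unknown, srcOffset: number, dst: unknown, dstOffset: number, size: number): void;
}

// [name, size, align] for the leading ModelParams fields listed in this document.
const MODEL_PARAMS_FIELDS: Array<[string, number, number]> = [
  ["model", 64, 16], // mat4x4<f32>
  ["baseOffset", 4, 4], // u32
  ["num_points", 4, 4], // u32
  ["gaussianScaling", 4, 4], // f32
  ["maxShDeg", 4, 4], // u32
];

/** Byte offset of a field: each member starts at the previous end rounded up to its alignment. */
function fieldOffset(name: string): number {
  let offset = 0;
  for (const [field, size, align] of MODEL_PARAMS_FIELDS) {
    offset = Math.ceil(offset / align) * align;
    if (field === name) return offset;
    offset += size;
  }
  throw new Error(`unknown ModelParams field: ${name}`);
}

/** Records a 4-byte GPU copy of the point count into ModelParams.num_points. */
function patchNumPoints(enc: CopyEncoder, countBuffer: unknown, countOffset: number, modelParams: unknown): void {
  enc.copyBufferToBuffer(countBuffer, countOffset, modelParams, fieldOffset("num_points"), 4);
}
```

In WebGPU the source buffer needs COPY_SRC usage and the ModelParams buffer needs COPY_DST; offsets and sizes must be multiples of 4, which 68 and 4 satisfy.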
Overrides
- USE_RAW_COLOR (default false): skip SH evaluation and treat the color buffer as RGB channels.
- SH_LAYOUT_CHANNEL_MAJOR (default false): select the channel-major layout for SH coefficients.
- DISCARD_BY_WORLD_TRACE and MAX_WORLD_TRACE: optional guard that drops splats whose world-space covariance trace is too large.
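WGSL override values are supplied when the pipeline is created, through the constants record of the compute stage (GPUProgrammableStage.constants in WebGPU). Values are numbers, and a bool override treats 0 as false and nonzero as true. The sketch below assumes the four names are declared as overrides in preprocess.wgsl (bool flags, f32 for MAX_WORLD_TRACE) and that MAX_WORLD_TRACE has a shader-side default so it can be omitted. The options interface and both functions are hypothetical; the second shows the trace test on its own, assuming the world covariance is available as a symmetric 3×3 matrix.

```typescript
interface PreprocessOverrides {
  useRawColor?: boolean;
  shLayoutChannelMajor?: boolean;
  maxWorldTrace?: number; // undefined disables the world-trace guard
}

/** Builds the `constants` record for the preprocess compute stage. */
function preprocessConstants(o: PreprocessOverrides): Record<string, number> {
  const constants: Record<string, number> = {
    USE_RAW_COLOR: o.useRawColor ? 1 : 0,
    SH_LAYOUT_CHANNEL_MAJOR: o.shLayoutChannelMajor ? 1 : 0,
    DISCARD_BY_WORLD_TRACE: o.maxWorldTrace !== undefined ? 1 : 0,
  };
  // Only pass the threshold when the guard is on (assumes the shader declares a default).
  if (o.maxWorldTrace !== undefined) constants.MAX_WORLD_TRACE = o.maxWorldTrace;
  return constants;
}

/** CPU mirror of the guard: drop when trace(Sigma) = Sxx + Syy + Szz exceeds the limit. */
function exceedsWorldTrace(cov: number[][], maxTrace: number): boolean {
  return cov[0][0] + cov[1][1] + cov[2][2] > maxTrace;
}
```

The record would be passed as `compute: { module, entryPoint, constants: preprocessConstants(...) }` to device.createComputePipeline.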
Precision helpers
read_gaussian_pos_opacity, read_gaussian_cov, and read_color_channel branch on uModel.gaussDataType/colorDataType. INT8 paths use colorScale/colorZeroPoint to dequantize data provided by the ONNX precision converter.
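A CPU reference for one channel read, assuming the quantized paths use the same affine convention as ONNX DequantizeLinear, y = (q - zeroPoint) × scale. The data-type codes come from the Preprocess section (0=f32, 1=f16, 2=int8, 3=uint8); the tightly packed little-endian layout and readChannel itself are assumptions for illustration.

```typescript
const GaussDataType = { F32: 0, F16: 1, I8: 2, U8: 3 } as const;
type GaussDataType = (typeof GaussDataType)[keyof typeof GaussDataType];

/** IEEE 754 binary16 bits -> number (handles zero, subnormals, inf, NaN). */
function halfBitsToFloat(h: number): number {
  const s = h & 0x8000 ? -1 : 1;
  const e = (h >> 10) & 0x1f;
  const m = h & 0x3ff;
  if (e === 0) return s * m * 2 ** -24; // zero or subnormal
  if (e === 31) return m ? NaN : s * Infinity;
  return s * (1 + m / 1024) * 2 ** (e - 15);
}

/** Reads element i of a tightly packed buffer and returns it as a float. */
function readChannel(view: DataView, i: number, type: GaussDataType, scale = 1, zeroPoint = 0): number {
  switch (type) {
    case GaussDataType.F32: return view.getFloat32(i * 4, true);
    case GaussDataType.F16: return halfBitsToFloat(view.getUint16(i * 2, true));
    case GaussDataType.I8: return (view.getInt8(i) - zeroPoint) * scale; // affine dequantization
    case GaussDataType.U8: return (view.getUint8(i) - zeroPoint) * scale;
    default: throw new Error(`unknown data type ${type}`);
  }
}
```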
Radix sort (radix_sort.wgsl)
- Constants such as histogram_wg_size, histogram_sg_size, s_radix_log2 are injected from TypeScript so they match the GPU.
- Four phases per pass: zero_histograms, calculate_histogram, prefix_histogram, scatter_even/scatter_odd.
- Each even/odd scatter handles two passes (0,2) and (1,3) respectively, ping-ponging key/payload buffers.
- Uses a look-back mechanism to coordinate partitions across workgroups while minimising atomics.
- SortInfos and DispatchIndirect are shared with preprocess/renderer via bind group 2.
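A CPU reference is handy for validating the sorted payloads. This sketch is a sequential LSD radix sort over u32 keys with payloads carried along. It assumes 8-bit digits (four passes for 32-bit keys, consistent with the pass pairs above, although the shader's s_radix_log2 is injected at runtime) and runs each pass as histogram, exclusive prefix sum, then stable scatter, ping-ponging between two buffer pairs. It intentionally leaves out workgroup partitioning and look-back, which change how the GPU parallelizes the scatter, not the result.

```typescript
/** Stable LSD radix sort of u32 keys; payloads move with their keys. Returns [keys, payloads]. */
function radixSortPairs(keysIn: Uint32Array, payloadsIn: Uint32Array, radixLog2 = 8): [Uint32Array, Uint32Array] {
  const radix = 1 << radixLog2;
  const mask = radix - 1;
  const n = keysIn.length;
  let keys = keysIn.slice(), payloads = payloadsIn.slice();
  let keysAlt = new Uint32Array(n), payloadsAlt = new Uint32Array(n);
  for (let shift = 0; shift < 32; shift += radixLog2) {
    // 1) histogram of the current digit
    const hist = new Uint32Array(radix);
    for (let i = 0; i < n; i++) hist[(keys[i] >>> shift) & mask]++;
    // 2) exclusive prefix sum: first output slot for each digit value
    let sum = 0;
    for (let d = 0; d < radix; d++) { const c = hist[d]; hist[d] = sum; sum += c; }
    // 3) stable scatter into the alternate buffers
    for (let i = 0; i < n; i++) {
      const dst = hist[(keys[i] >>> shift) & mask]++;
      keysAlt[dst] = keys[i];
      payloadsAlt[dst] = payloads[i];
    }
    // 4) ping-pong: the scattered buffers become the input of the next pass
    [keys, keysAlt] = [keysAlt, keys];
    [payloads, payloadsAlt] = [payloadsAlt, payloads];
  }
  return [keys, payloads];
}
```

Stability of every pass is what makes the final order correct, so equal depth keys keep their original relative order.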
Gaussian rasterizer (gaussian.wgsl)
Vertex shader:
- Reads the Splat record via points_2d[indices[instance_id]] (the indices buffer comes from the sorter payloads).
- Generates four vertices per instance by building a screen-aligned quad scaled by the eigenvectors (mat2(v1, v2) × CUTOFF).
- Emits the NDC position (with clamped z) and passes screen-space coordinates + color to the fragment shader.
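A sketch of how the quad axes v1 and v2 can be derived from a 2×2 screen-space covariance [[a, b], [b, c]]. It assumes each axis is a unit eigenvector scaled by sqrt(2λ). With that scaling, a point x = V·r (V = [v1 v2]) satisfies ½·xᵀΣ⁻¹x = |r|², so the fragment shader's exp(-r²) reproduces the Gaussian. The exact constants in gaussian.wgsl are not stated here; quadAxes and quadCornersNdc are hypothetical helpers, CUTOFF is a parameter, and pixel offsets are mapped to NDC as 2·offset/viewport.

```typescript
type V2 = [number, number];

/** Principal axes of [[a, b], [b, c]], each scaled by sqrt(2 * eigenvalue). */
function quadAxes(a: number, b: number, c: number): [V2, V2] {
  const mid = 0.5 * (a + c);
  const rad = Math.hypot(0.5 * (a - c), b);
  const l1 = mid + rad;
  const l2 = Math.max(mid - rad, 0); // clamp tiny negatives caused by round-off
  let e1: V2;
  if (Math.abs(b) < 1e-12) e1 = a >= c ? [1, 0] : [0, 1]; // already diagonal
  else { const n = Math.hypot(b, l1 - a); e1 = [b / n, (l1 - a) / n]; } // (b, l1 - a) is an eigenvector
  const e2: V2 = [-e1[1], e1[0]]; // orthogonal eigenvector
  const s1 = Math.sqrt(2 * l1), s2 = Math.sqrt(2 * l2);
  return [[e1[0] * s1, e1[1] * s1], [e2[0] * s2, e2[1] * s2]];
}

/** NDC offsets of the four quad corners, in triangle-strip order. */
function quadCornersNdc(axes: [V2, V2], cutoff: number, viewport: V2): V2[] {
  const [v1, v2] = axes;
  const corners: V2[] = [[-1, -1], [1, -1], [-1, 1], [1, 1]];
  return corners.map(([x, y]) => {
    const px = (v1[0] * x + v2[0] * y) * cutoff; // pixel offset along the axes
    const py = (v1[1] * x + v2[1] * y) * cutoff;
    return [(2 * px) / viewport[0], (2 * py) / viewport[1]] as V2;
  });
}
```

A quick property check: v1·v1ᵀ + v2·v2ᵀ equals 2Σ for any valid covariance.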
Fragment shader:
- Computes r² = dot(screen_pos, screen_pos).
- Discards fragments where r² > 2*CUTOFF (tight bounding circle) to save bandwidth.
- Evaluates the Gaussian falloff exp(-r²) and multiplies by the alpha channel (capped at 0.99) to avoid fully opaque artifacts.
- Returns the premultiplied color vec4(rgb, 1) * weight for correct blending.
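The fragment math is small enough to mirror on the CPU. This sketch follows the steps above, assuming screen_pos is already in the units where exp(-r²) is the falloff and that the 0.99 cap applies to the product of falloff and alpha (the wording above could also be read as capping alpha alone). shadeFragment is a hypothetical name, and returning null stands in for discard.

```typescript
type RGBA = [number, number, number, number];

/** Premultiplied fragment color, or null where the shader would discard. */
function shadeFragment(screenPos: [number, number], color: RGBA, cutoff: number): RGBA | null {
  const r2 = screenPos[0] * screenPos[0] + screenPos[1] * screenPos[1]; // dot(screen_pos, screen_pos)
  if (r2 > 2 * cutoff) return null; // outside the bounding circle
  const weight = Math.min(0.99, Math.exp(-r2) * color[3]); // falloff times alpha, capped
  return [color[0] * weight, color[1] * weight, color[2] * weight, weight]; // vec4(rgb, 1) * weight
}
```

Premultiplied output of this form is usually paired with a blend state of srcFactor "one" and dstFactor "one-minus-src-alpha".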
Utility kernels
- compress_gaussians.wgsl: reads FP32 Gaussians and writes packed FP16/INT8 versions. Used by import tools or offline pipelines.
- convert_precision.wgsl: similar to the compressor but operates on existing GPU buffers, so ONNX-driven conversion can run entirely on the GPU.
- debug-helpers.wgsl: small compute kernels that copy counters/buffers into staging areas for inspection (used by developer debug menus).
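Writing the INT8 path requires choosing a scale and zero point. This document does not say how compress_gaussians.wgsl picks them; a common scheme, shown here purely as an assumption, is asymmetric min/max quantization to the uint8 range, the inverse of q -> (q - zeroPoint) × scale. uint8Params and quantizeUint8 are hypothetical helpers.

```typescript
interface QuantParams { scale: number; zeroPoint: number; }

/** Min/max affine parameters for uint8 storage; the range always includes 0 so it is exact. */
function uint8Params(values: ArrayLike<number>): QuantParams {
  let lo = 0, hi = 0;
  for (let i = 0; i < values.length; i++) { lo = Math.min(lo, values[i]); hi = Math.max(hi, values[i]); }
  const scale = hi > lo ? (hi - lo) / 255 : 1; // avoid a zero scale for all-zero input
  const zeroPoint = Math.round(-lo / scale); // quantized code that represents 0.0
  return { scale, zeroPoint };
}

/** q = clamp(round(x / scale) + zeroPoint, 0, 255). */
function quantizeUint8(values: ArrayLike<number>, p: QuantParams): Uint8Array {
  const out = new Uint8Array(values.length);
  for (let i = 0; i < values.length; i++) {
    const q = Math.round(values[i] / p.scale) + p.zeroPoint;
    out[i] = Math.min(255, Math.max(0, q)); // saturate instead of wrapping
  }
  return out;
}
```

Per-channel parameters (one pair per color channel) keep the error bounded by each channel's own range rather than the global one.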
Binding updates
All shaders rely on a consistent bind-group order. The renderer and preprocessor ensure that:
- Group 0: camera uniforms (shared across stages).
- Group 1: point cloud buffers (gaussians, SH/RAW color, splats).
- Group 2: sort data (infos, depth keys, payloads, dispatch indirect).
- Group 3: render/model settings (RenderSettings + ModelParams).
gaussian.wgsl only needs groups 0 and 1; the sort kernels bind group 0 only.
Indirect dispatch and draw counts are produced entirely on the GPU:
1. Preprocess increments sort_infos.keys_size, and increments sort_dispatch.dispatch_x once per 256*15 splats.
2. Radix sort consumes sort_dispatch via dispatchWorkgroupsIndirect.
3. The renderer copies sort_infos.keys_size into the draw-indirect buffer before calling pass.drawIndirect.
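The counter logic can be simulated sequentially. This sketch assumes each preprocess invocation that keeps a splat does k = atomicAdd(&keys_size, 1) and bumps dispatch_x whenever k is a multiple of 256*15, so dispatch_x ends at ceil(keys_size / 3840). It also builds drawIndirect arguments, which WebGPU reads as four u32 values: vertexCount, instanceCount, firstVertex, firstInstance. The vertex count of 4 comes from the four-vertex quad described above; countVisible and drawIndirectArgs are hypothetical names.

```typescript
const KEYS_PER_SORT_WORKGROUP = 256 * 15;

interface Counters { keysSize: number; dispatchX: number; }

/** Sequential stand-in for the atomics preprocess performs for each kept splat. */
function countVisible(visible: boolean[]): Counters {
  const c: Counters = { keysSize: 0, dispatchX: 0 };
  for (const v of visible) {
    if (!v) continue; // culled splats write no key
    const k = c.keysSize++; // atomicAdd returns the previous value
    if (k % KEYS_PER_SORT_WORKGROUP === 0) c.dispatchX++; // first key of a new workgroup's range
  }
  return c;
}

/** drawIndirect arguments: [vertexCount, instanceCount, firstVertex, firstInstance]. */
function drawIndirectArgs(keysSize: number): Uint32Array {
  return new Uint32Array([4, keysSize, 0, 0]);
}
```

On the GPU, step 3 amounts to a 4-byte copyBufferToBuffer into byte offset 4 (the instanceCount slot) of a buffer created with INDIRECT and COPY_DST usage.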
This architecture keeps CPU-GPU sync minimal: once buffers are allocated, the GPU decides how many workgroups and instances to launch every frame.