wgpu_3dgs_core/
gaussian_config.rs

1use glam::*;
2use half::f16;
3
4/// The spherical harmonics configuration of Gaussian.
5///
6/// Currently, there are four configurations:
7/// - Single precision [`GaussianShSingleConfig`](crate::GaussianShSingleConfig)
8///     - Format: 15 * [`Vec3`]
9/// - Half precision [`GaussianShHalfConfig`](crate::GaussianShHalfConfig)
10///     - Format: (15 * 3 + 1) * [`struct@f16`]
11/// - 8 bit normalized [`GaussianShNorm8Config`](crate::GaussianShNorm8Config)
12///     - Format: (15 * 3 + 3) * [`prim@i8`]
13/// - None [`GaussianShNoneConfig`](crate::GaussianShNoneConfig)
14///    - Cannot be converted back to SH
15pub trait GaussianShConfig {
16    /// The feature name of the configuration.
17    ///
18    /// Must match the [`wesl::Feature`] name in the shader.
19    const FEATURE: &'static str;
20
21    /// The [`GaussianPod`](crate::GaussianPod) field type.
22    type Field: bytemuck::Pod + bytemuck::Zeroable;
23
24    /// Create from [`Gaussian::sh`](crate::Gaussian::sh).
25    fn from_sh(sh: &[Vec3; 15]) -> Self::Field;
26
27    /// Convert the field to [`Gaussian::sh`](crate::Gaussian::sh).
28    fn to_sh(field: &Self::Field) -> [Vec3; 15];
29}
30
31/// The single precision SH configuration of Gaussian.
32pub struct GaussianShSingleConfig;
33
34impl GaussianShConfig for GaussianShSingleConfig {
35    const FEATURE: &'static str = "sh_single";
36
37    type Field = [Vec3; 15];
38
39    fn from_sh(sh: &[Vec3; 15]) -> Self::Field {
40        *sh
41    }
42
43    fn to_sh(field: &Self::Field) -> [Vec3; 15] {
44        *field
45    }
46}
47
48/// The half precision SH configuration of Gaussian.
49pub struct GaussianShHalfConfig;
50
51impl GaussianShConfig for GaussianShHalfConfig {
52    const FEATURE: &'static str = "sh_half";
53
54    type Field = [f16; 3 * 15 + 1];
55
56    fn from_sh(sh: &[Vec3; 15]) -> Self::Field {
57        sh.iter()
58            .flat_map(|sh| sh.to_array())
59            .map(f16::from_f32)
60            .chain(std::iter::once(f16::from_f32(0.0)))
61            .collect::<Vec<_>>()
62            .try_into()
63            .expect("SH half")
64    }
65
66    fn to_sh(field: &Self::Field) -> [Vec3; 15] {
67        field
68            .chunks_exact(3)
69            .map(|chunk| {
70                Vec3::new(
71                    f16::to_f32(chunk[0]),
72                    f16::to_f32(chunk[1]),
73                    f16::to_f32(chunk[2]),
74                )
75            })
76            .collect::<Vec<_>>()
77            .try_into()
78            .expect("SH half")
79    }
80}
81
82/// The 8 bit signed normalized SH configuration of Gaussian.
83///
84/// This is by the fact that SH coefficients are within \[-1, 1\].
85pub struct GaussianShNorm8Config;
86
87impl GaussianShConfig for GaussianShNorm8Config {
88    const FEATURE: &'static str = "sh_norm8";
89
90    type Field = [i8; 3 * 15 + 3];
91
92    fn from_sh(sh: &[Vec3; 15]) -> Self::Field {
93        sh.iter()
94            .flat_map(|sh| sh.to_array())
95            .map(|v| (v * 127.0).clamp(-127.0, 127.0) as i8)
96            .chain(std::iter::repeat_n(0, 3))
97            .collect::<Vec<_>>()
98            .try_into()
99            .expect("SH norm8")
100    }
101
102    fn to_sh(field: &Self::Field) -> [Vec3; 15] {
103        field
104            .chunks_exact(3)
105            .take(15)
106            .map(|chunk| {
107                Vec3::new(
108                    ((chunk[0] as f32) / 127.0).max(-1.0),
109                    ((chunk[1] as f32) / 127.0).max(-1.0),
110                    ((chunk[2] as f32) / 127.0).max(-1.0),
111                )
112            })
113            .collect::<Vec<_>>()
114            .try_into()
115            .expect("SH norm8")
116    }
117}
118
119/// The none SH configuration of Gaussian.
120///
121/// Calling [`GaussianShConfig::to_sh`] will panic on this config.
122pub struct GaussianShNoneConfig;
123
124impl GaussianShConfig for GaussianShNoneConfig {
125    const FEATURE: &'static str = "sh_none";
126
127    type Field = ();
128
129    fn from_sh(_sh: &[Vec3; 15]) -> Self::Field {}
130
131    fn to_sh(_field: &Self::Field) -> [Vec3; 15] {
132        panic!("Cannot convert from SH None configuration")
133    }
134}
135
136/// The covariance 3D configuration of Gaussian.
137///
138/// Currently, there are three configurations:
139/// - Rotation and scale [`GaussianCov3dRotScaleConfig`](crate::GaussianCov3dRotScaleConfig)
140///     - Format: [`Quat`] + [`Vec3`]
141/// - Single precision [`GaussianCov3dSingleConfig`](crate::GaussianCov3dSingleConfig)
142///     - Format: 6 * [`prim@f32`]
143///     - Cannot be converted back to rotation and scale
144/// - Half precision [`GaussianCov3dHalfConfig`](crate::GaussianCov3dHalfConfig)
145///     - Format: 6 * [`struct@f16`]
146///     - Cannot be converted back to rotation and scale
147pub trait GaussianCov3dConfig {
148    /// The name of the configuration.
149    ///
150    /// Must match the [`wesl::Feature`] name in the shader.
151    const FEATURE: &'static str;
152
153    /// The [`GaussianPod`](crate::GaussianPod) field type.
154    type Field: bytemuck::Pod + bytemuck::Zeroable;
155
156    /// Create from [`Gaussian::rot`](crate::Gaussian::rot) and [`Gaussian::scale`](crate::Gaussian::scale).
157    fn from_rot_scale(rot: Quat, scale: Vec3) -> Self::Field;
158
159    /// Convert the field to [`Gaussian::rot`](crate::Gaussian::rot) and [`Gaussian::scale`](crate::Gaussian::scale).
160    fn to_rot_scale(field: &Self::Field) -> (Quat, Vec3);
161}
162
163/// The unconverted rotation and scale covariance 3D configuration of Gaussian.
164///
165/// Instead of storing the covariance matrix, this config stores the rotation and scale directly.
166pub struct GaussianCov3dRotScaleConfig;
167
168impl GaussianCov3dConfig for GaussianCov3dRotScaleConfig {
169    const FEATURE: &'static str = "cov3d_rot_scale";
170
171    type Field = [f32; 7]; // (rot: [f32; 4], scale: [f32; 3])
172
173    fn from_rot_scale(rot: Quat, scale: Vec3) -> Self::Field {
174        [rot.x, rot.y, rot.z, rot.w, scale.x, scale.y, scale.z]
175    }
176
177    fn to_rot_scale(field: &Self::Field) -> (Quat, Vec3) {
178        (
179            Quat::from_xyzw(field[0], field[1], field[2], field[3]),
180            Vec3::new(field[4], field[5], field[6]),
181        )
182    }
183}
184
185/// The single precision covariance 3D configuration of Gaussian.
186///
187/// Calling [`GaussianCov3dConfig::to_rot_scale`] will panic on this config.
188pub struct GaussianCov3dSingleConfig;
189
190impl GaussianCov3dConfig for GaussianCov3dSingleConfig {
191    const FEATURE: &'static str = "cov3d_single";
192
193    type Field = [f32; 6];
194
195    fn from_rot_scale(rot: Quat, scale: Vec3) -> Self::Field {
196        let r = Mat3::from_quat(rot);
197        let s = Mat3::from_diagonal(scale);
198        let m = r * s;
199        let sigma = m * m.transpose();
200
201        [
202            sigma.x_axis.x,
203            sigma.x_axis.y,
204            sigma.x_axis.z,
205            sigma.y_axis.y,
206            sigma.y_axis.z,
207            sigma.z_axis.z,
208        ]
209    }
210
211    fn to_rot_scale(_field: &Self::Field) -> (Quat, Vec3) {
212        panic!("Cannot convert from Cov3d Single configuration")
213    }
214}
215
216/// The half precision covariance 3D configuration of Gaussian.
217///
218/// Calling [`GaussianCov3dConfig::to_rot_scale`] will panic on this config.
219pub struct GaussianCov3dHalfConfig;
220
221impl GaussianCov3dConfig for GaussianCov3dHalfConfig {
222    const FEATURE: &'static str = "cov3d_half";
223
224    type Field = [f16; 6];
225
226    fn from_rot_scale(rot: Quat, scale: Vec3) -> Self::Field {
227        GaussianCov3dSingleConfig::from_rot_scale(rot, scale).map(f16::from_f32)
228    }
229
230    fn to_rot_scale(_field: &Self::Field) -> (Quat, Vec3) {
231        panic!("Cannot convert from Cov3d Half configuration")
232    }
233}