wgpu_3dgs_core/
gaussian_config.rs1use glam::*;
2use half::f16;
3
4pub trait GaussianShConfig {
16 const FEATURE: &'static str;
20
21 type Field: bytemuck::Pod + bytemuck::Zeroable;
23
24 fn from_sh(sh: &[Vec3; 15]) -> Self::Field;
26
27 fn to_sh(field: &Self::Field) -> [Vec3; 15];
29}
30
31pub struct GaussianShSingleConfig;
33
34impl GaussianShConfig for GaussianShSingleConfig {
35 const FEATURE: &'static str = "sh_single";
36
37 type Field = [Vec3; 15];
38
39 fn from_sh(sh: &[Vec3; 15]) -> Self::Field {
40 *sh
41 }
42
43 fn to_sh(field: &Self::Field) -> [Vec3; 15] {
44 *field
45 }
46}
47
48pub struct GaussianShHalfConfig;
50
51impl GaussianShConfig for GaussianShHalfConfig {
52 const FEATURE: &'static str = "sh_half";
53
54 type Field = [f16; 3 * 15 + 1];
55
56 fn from_sh(sh: &[Vec3; 15]) -> Self::Field {
57 sh.iter()
58 .flat_map(|sh| sh.to_array())
59 .map(f16::from_f32)
60 .chain(std::iter::once(f16::from_f32(0.0)))
61 .collect::<Vec<_>>()
62 .try_into()
63 .expect("SH half")
64 }
65
66 fn to_sh(field: &Self::Field) -> [Vec3; 15] {
67 field
68 .chunks_exact(3)
69 .map(|chunk| {
70 Vec3::new(
71 f16::to_f32(chunk[0]),
72 f16::to_f32(chunk[1]),
73 f16::to_f32(chunk[2]),
74 )
75 })
76 .collect::<Vec<_>>()
77 .try_into()
78 .expect("SH half")
79 }
80}
81
82pub struct GaussianShNorm8Config;
86
87impl GaussianShConfig for GaussianShNorm8Config {
88 const FEATURE: &'static str = "sh_norm8";
89
90 type Field = [i8; 3 * 15 + 3];
91
92 fn from_sh(sh: &[Vec3; 15]) -> Self::Field {
93 sh.iter()
94 .flat_map(|sh| sh.to_array())
95 .map(|v| (v * 127.0).clamp(-127.0, 127.0) as i8)
96 .chain(std::iter::repeat_n(0, 3))
97 .collect::<Vec<_>>()
98 .try_into()
99 .expect("SH norm8")
100 }
101
102 fn to_sh(field: &Self::Field) -> [Vec3; 15] {
103 field
104 .chunks_exact(3)
105 .take(15)
106 .map(|chunk| {
107 Vec3::new(
108 ((chunk[0] as f32) / 127.0).max(-1.0),
109 ((chunk[1] as f32) / 127.0).max(-1.0),
110 ((chunk[2] as f32) / 127.0).max(-1.0),
111 )
112 })
113 .collect::<Vec<_>>()
114 .try_into()
115 .expect("SH norm8")
116 }
117}
118
119pub struct GaussianShNoneConfig;
123
124impl GaussianShConfig for GaussianShNoneConfig {
125 const FEATURE: &'static str = "sh_none";
126
127 type Field = ();
128
129 fn from_sh(_sh: &[Vec3; 15]) -> Self::Field {}
130
131 fn to_sh(_field: &Self::Field) -> [Vec3; 15] {
132 panic!("Cannot convert from SH None configuration")
133 }
134}
135
136pub trait GaussianCov3dConfig {
148 const FEATURE: &'static str;
152
153 type Field: bytemuck::Pod + bytemuck::Zeroable;
155
156 fn from_rot_scale(rot: Quat, scale: Vec3) -> Self::Field;
158
159 fn to_rot_scale(field: &Self::Field) -> (Quat, Vec3);
161}
162
163pub struct GaussianCov3dRotScaleConfig;
167
168impl GaussianCov3dConfig for GaussianCov3dRotScaleConfig {
169 const FEATURE: &'static str = "cov3d_rot_scale";
170
171 type Field = [f32; 7]; fn from_rot_scale(rot: Quat, scale: Vec3) -> Self::Field {
174 [rot.x, rot.y, rot.z, rot.w, scale.x, scale.y, scale.z]
175 }
176
177 fn to_rot_scale(field: &Self::Field) -> (Quat, Vec3) {
178 (
179 Quat::from_xyzw(field[0], field[1], field[2], field[3]),
180 Vec3::new(field[4], field[5], field[6]),
181 )
182 }
183}
184
185pub struct GaussianCov3dSingleConfig;
189
190impl GaussianCov3dConfig for GaussianCov3dSingleConfig {
191 const FEATURE: &'static str = "cov3d_single";
192
193 type Field = [f32; 6];
194
195 fn from_rot_scale(rot: Quat, scale: Vec3) -> Self::Field {
196 let r = Mat3::from_quat(rot);
197 let s = Mat3::from_diagonal(scale);
198 let m = r * s;
199 let sigma = m * m.transpose();
200
201 [
202 sigma.x_axis.x,
203 sigma.x_axis.y,
204 sigma.x_axis.z,
205 sigma.y_axis.y,
206 sigma.y_axis.z,
207 sigma.z_axis.z,
208 ]
209 }
210
211 fn to_rot_scale(_field: &Self::Field) -> (Quat, Vec3) {
212 panic!("Cannot convert from Cov3d Single configuration")
213 }
214}
215
216pub struct GaussianCov3dHalfConfig;
220
221impl GaussianCov3dConfig for GaussianCov3dHalfConfig {
222 const FEATURE: &'static str = "cov3d_half";
223
224 type Field = [f16; 6];
225
226 fn from_rot_scale(rot: Quat, scale: Vec3) -> Self::Field {
227 GaussianCov3dSingleConfig::from_rot_scale(rot, scale).map(f16::from_f32)
228 }
229
230 fn to_rot_scale(_field: &Self::Field) -> (Quat, Vec3) {
231 panic!("Cannot convert from Cov3d Half configuration")
232 }
233}