1use glam::*;
2
3use wgpu::util::DeviceExt;
4
5use crate::{
6 BufferWrapper, DownloadBufferError, Gaussian, GaussianCov3dConfig, GaussianCov3dHalfConfig,
7 GaussianCov3dRotScaleConfig, GaussianCov3dSingleConfig, GaussianShConfig, GaussianShHalfConfig,
8 GaussianShNoneConfig, GaussianShNorm8Config, GaussianShSingleConfig,
9 GaussiansBufferTryFromBufferError, GaussiansBufferUpdateError, GaussiansBufferUpdateRangeError,
10 IterGaussian,
11};
12
13#[derive(Debug, Clone)]
17pub struct GaussiansBuffer<G: GaussianPod>(wgpu::Buffer, std::marker::PhantomData<G>);
18
19impl<G: GaussianPod> GaussiansBuffer<G> {
20 pub fn new(device: &wgpu::Device, gaussians: &impl IterGaussian) -> Self {
22 Self::new_with_pods(
23 device,
24 gaussians
25 .iter_gaussian()
26 .map(|g| G::from_gaussian(&g))
27 .collect::<Vec<_>>()
28 .as_slice(),
29 )
30 }
31
32 pub fn new_with_usage(
34 device: &wgpu::Device,
35 gaussians: &impl IterGaussian,
36 usage: wgpu::BufferUsages,
37 ) -> Self {
38 Self::new_with_pods_and_usage(
39 device,
40 gaussians
41 .iter_gaussian()
42 .map(|g| G::from_gaussian(&g))
43 .collect::<Vec<_>>()
44 .as_slice(),
45 usage,
46 )
47 }
48
49 pub fn new_with_pods(device: &wgpu::Device, gaussians: &[G]) -> Self {
51 Self::new_with_pods_and_usage(device, gaussians, Self::DEFAULT_USAGES)
52 }
53
54 pub fn new_with_pods_and_usage(
57 device: &wgpu::Device,
58 gaussians: &[G],
59 usage: wgpu::BufferUsages,
60 ) -> Self {
61 let buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
62 label: Some("Gaussians Buffer"),
63 contents: bytemuck::cast_slice(gaussians),
64 usage,
65 });
66
67 Self(buffer, std::marker::PhantomData)
68 }
69
70 pub fn new_empty(device: &wgpu::Device, len: usize) -> Self {
72 Self::new_empty_with_usage(device, len, Self::DEFAULT_USAGES)
73 }
74
75 pub fn new_empty_with_usage(
77 device: &wgpu::Device,
78 len: usize,
79 usage: wgpu::BufferUsages,
80 ) -> Self {
81 let buffer = device.create_buffer(&wgpu::BufferDescriptor {
82 label: Some("Gaussians Buffer"),
83 size: (len * std::mem::size_of::<G>()) as wgpu::BufferAddress,
84 usage,
85 mapped_at_creation: false,
86 });
87
88 Self(buffer, std::marker::PhantomData)
89 }
90
91 pub fn len(&self) -> usize {
93 self.0.size() as usize / std::mem::size_of::<G>()
94 }
95
96 pub fn is_empty(&self) -> bool {
98 self.len() == 0
99 }
100
101 pub fn update(
105 &self,
106 queue: &wgpu::Queue,
107 gaussians: &impl IterGaussian,
108 ) -> Result<(), GaussiansBufferUpdateError> {
109 self.update_with_pod(
110 queue,
111 gaussians
112 .iter_gaussian()
113 .map(|g| G::from_gaussian(&g))
114 .collect::<Vec<_>>()
115 .as_slice(),
116 )
117 }
118
119 pub fn update_with_pod(
123 &self,
124 queue: &wgpu::Queue,
125 pods: &[G],
126 ) -> Result<(), GaussiansBufferUpdateError> {
127 if pods.len() != self.len() {
128 return Err(GaussiansBufferUpdateError::CountMismatch {
129 count: pods.len(),
130 expected_count: self.len(),
131 });
132 }
133
134 queue.write_buffer(&self.0, 0, bytemuck::cast_slice(pods));
135
136 Ok(())
137 }
138
139 pub fn update_range(
143 &self,
144 queue: &wgpu::Queue,
145 start: usize,
146 gaussians: &[Gaussian],
147 ) -> Result<(), GaussiansBufferUpdateRangeError> {
148 self.update_range_with_pod(
149 queue,
150 start,
151 gaussians
152 .iter()
153 .map(G::from_gaussian)
154 .collect::<Vec<_>>()
155 .as_slice(),
156 )
157 }
158
159 pub fn update_range_with_pod(
163 &self,
164 queue: &wgpu::Queue,
165 start: usize,
166 pods: &[G],
167 ) -> Result<(), GaussiansBufferUpdateRangeError> {
168 if start + pods.len() > self.len() {
169 return Err(GaussiansBufferUpdateRangeError::CountMismatch {
170 count: pods.len(),
171 start,
172 expected_count: self.len(),
173 });
174 }
175
176 queue.write_buffer(
177 &self.0,
178 (start * std::mem::size_of::<G>()) as wgpu::BufferAddress,
179 bytemuck::cast_slice(pods),
180 );
181
182 Ok(())
183 }
184
185 pub async fn download_gaussians(
187 &self,
188 device: &wgpu::Device,
189 queue: &wgpu::Queue,
190 ) -> Result<Vec<Gaussian>, DownloadBufferError> {
191 self.download::<G>(device, queue)
192 .await
193 .map(|pods| pods.into_iter().map(Into::into).collect::<Vec<_>>())
194 }
195}
196
197impl<G: GaussianPod> BufferWrapper for GaussiansBuffer<G> {
198 const DEFAULT_USAGES: wgpu::BufferUsages = wgpu::BufferUsages::from_bits_retain(
199 wgpu::BufferUsages::STORAGE.bits() | wgpu::BufferUsages::COPY_DST.bits(),
200 );
201
202 fn buffer(&self) -> &wgpu::Buffer {
203 &self.0
204 }
205}
206
207impl<G: GaussianPod> From<GaussiansBuffer<G>> for wgpu::Buffer {
208 fn from(wrapper: GaussiansBuffer<G>) -> Self {
209 wrapper.0
210 }
211}
212
213impl<G: GaussianPod> TryFrom<wgpu::Buffer> for GaussiansBuffer<G> {
214 type Error = GaussiansBufferTryFromBufferError;
215
216 fn try_from(buffer: wgpu::Buffer) -> Result<Self, Self::Error> {
217 if !buffer
218 .size()
219 .is_multiple_of(std::mem::size_of::<G>() as wgpu::BufferAddress)
220 {
221 return Err(GaussiansBufferTryFromBufferError::BufferSizeNotMultiple {
222 buffer_size: buffer.size(),
223 expected_multiple_size: std::mem::size_of::<G>() as wgpu::BufferAddress,
224 });
225 }
226
227 Ok(Self(buffer, std::marker::PhantomData))
228 }
229}
230
231pub trait GaussianPod:
240 for<'a> From<&'a Gaussian>
241 + Into<Gaussian>
242 + Send
243 + Sync
244 + std::fmt::Debug
245 + Clone
246 + Copy
247 + PartialEq
248 + bytemuck::NoUninit
249 + bytemuck::AnyBitPattern
250{
251 type ShConfig: GaussianShConfig;
253
254 type Cov3dConfig: GaussianCov3dConfig;
256
257 fn into_gaussian(self) -> Gaussian {
259 self.into()
260 }
261
262 fn from_gaussian(gaussian: &Gaussian) -> Self {
264 Self::from(gaussian)
265 }
266
267 fn features() -> [(&'static str, bool); 7] {
271 [
272 GaussianShSingleConfig::FEATURE,
273 GaussianShHalfConfig::FEATURE,
274 GaussianShNorm8Config::FEATURE,
275 GaussianShNoneConfig::FEATURE,
276 GaussianCov3dRotScaleConfig::FEATURE,
277 GaussianCov3dSingleConfig::FEATURE,
278 GaussianCov3dHalfConfig::FEATURE,
279 ]
280 .map(|name| {
281 (
282 name,
283 name == Self::ShConfig::FEATURE || name == Self::Cov3dConfig::FEATURE,
284 )
285 })
286 }
287
288 fn wesl_features() -> wesl::Features {
290 wesl::Features {
291 flags: Self::features()
292 .iter()
293 .map(|(name, enabled)| (name.to_string(), (*enabled).into()))
294 .collect(),
295 ..Default::default()
296 }
297 }
298}
299
/// Generates one [`GaussianPod`] type for an (SH config, covariance config) pair.
///
/// For `sh = X, cov3d = Y, padding_size = N` this expands to:
/// - a `#[repr(C)]` struct `GaussianPodWithShXCov3dYConfigs` whose fields are
///   `pos`, `color`, `sh`, `cov3d` and `N` trailing `f32` padding words;
/// - `From<&Gaussian>` for that struct (padding is zero-filled);
/// - `From<that struct>` for [`Gaussian`];
/// - the [`GaussianPod`] impl naming `GaussianShXConfig` and
///   `GaussianCov3dYConfig` as its configs.
macro_rules! gaussian_pod {
    (sh = $sh:ident, cov3d = $cov3d:ident, padding_size = $pad_words:expr) => {
        paste::paste! {
            /// GPU-uploadable Gaussian generated by `gaussian_pod!`.
            ///
            /// Field order defines the `#[repr(C)]` memory layout; do not reorder.
            #[repr(C)]
            #[derive(Debug, Clone, Copy, PartialEq, bytemuck::Pod, bytemuck::Zeroable)]
            pub struct [< GaussianPodWith Sh $sh Cov3d $cov3d Configs >] {
                /// Position, copied from `Gaussian::pos`.
                pub pos: Vec3,
                /// Color, copied from `Gaussian::color`.
                pub color: U8Vec4,
                /// Spherical harmonics encoded by the SH config.
                pub sh: <[< GaussianSh $sh Config >] as GaussianShConfig>::Field,
                /// Covariance encoded by the cov3d config from rotation and scale.
                pub cov3d: <[< GaussianCov3d $cov3d Config >] as GaussianCov3dConfig>::Field,
                /// Trailing padding words, always zero when built from a `Gaussian`.
                pub padding: [f32; $pad_words],
            }

            impl From<&Gaussian> for [< GaussianPodWith Sh $sh Cov3d $cov3d Configs >] {
                fn from(source: &Gaussian) -> Self {
                    // Struct-literal fields evaluate in the order written, so
                    // the covariance conversion runs before the SH conversion.
                    Self {
                        cov3d: <[< GaussianCov3d $cov3d Config >]>::from_rot_scale(
                            source.rot,
                            source.scale,
                        ),
                        color: source.color,
                        sh: [< GaussianSh $sh Config >]::from_sh(&source.sh),
                        pos: source.pos,
                        padding: [0.0; $pad_words],
                    }
                }
            }

            impl From<[< GaussianPodWith Sh $sh Cov3d $cov3d Configs >]> for Gaussian {
                fn from(packed: [< GaussianPodWith Sh $sh Cov3d $cov3d Configs >]) -> Self {
                    // SH is decoded first, then rotation/scale.
                    let sh = [< GaussianSh $sh Config >]::to_sh(&packed.sh);
                    let (rot, scale) =
                        <[< GaussianCov3d $cov3d Config >]>::to_rot_scale(&packed.cov3d);

                    Self {
                        rot,
                        pos: packed.pos,
                        color: packed.color,
                        sh,
                        scale,
                    }
                }
            }

            impl GaussianPod for [< GaussianPodWith Sh $sh Cov3d $cov3d Configs >] {
                type ShConfig = [< GaussianSh $sh Config >];
                type Cov3dConfig = [< GaussianCov3d $cov3d Config >];
            }
        }
    };
}
372
373gaussian_pod!(sh = Single, cov3d = RotScale, padding_size = 0);
374gaussian_pod!(sh = Single, cov3d = Single, padding_size = 1);
375gaussian_pod!(sh = Single, cov3d = Half, padding_size = 0);
376gaussian_pod!(sh = Half, cov3d = RotScale, padding_size = 2);
377gaussian_pod!(sh = Half, cov3d = Single, padding_size = 3);
378gaussian_pod!(sh = Half, cov3d = Half, padding_size = 2);
379gaussian_pod!(sh = Norm8, cov3d = RotScale, padding_size = 1);
380gaussian_pod!(sh = Norm8, cov3d = Single, padding_size = 2);
381gaussian_pod!(sh = Norm8, cov3d = Half, padding_size = 1);
382gaussian_pod!(sh = None, cov3d = RotScale, padding_size = 1);
383gaussian_pod!(sh = None, cov3d = Single, padding_size = 2);
384gaussian_pod!(sh = None, cov3d = Half, padding_size = 1);
385
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the reference Gaussian shared by every generated test.
    fn sample_gaussian() -> Gaussian {
        Gaussian {
            rot: Quat::from_xyzw(0.0, 0.0, 0.0, 1.0),
            pos: Vec3::new(1.0, 2.0, 3.0),
            color: U8Vec4::new(255, 128, 64, 32),
            sh: [Vec3::new(0.1, 0.2, 0.3); 15],
            scale: Vec3::new(1.0, 2.0, 3.0),
        }
    }

    /// Generates the `into_gaussian` test for one pod type.
    ///
    /// The last argument must be a literal `true` or `false` token:
    /// - `true`: converting the pod back into a `Gaussian` is expected to panic;
    /// - `false`: re-encoding the round-tripped `Gaussian` must reproduce the pod.
    macro_rules! test_pod_into_gaussian {
        ($name:ident, $pod_type:ty, true) => {
            paste::paste! {
                #[test]
                #[should_panic]
                fn [<test_ $name _into_gaussian_should_panic>]() {
                    let pod = <$pod_type as GaussianPod>::from_gaussian(&sample_gaussian());

                    pod.into_gaussian();
                }
            }
        };
        ($name:ident, $pod_type:ty, false) => {
            paste::paste! {
                #[test]
                fn [<test_ $name _into_gaussian_should_equal_original_pod>]() {
                    let pod = <$pod_type as GaussianPod>::from_gaussian(&sample_gaussian());

                    let gaussian = pod.into_gaussian();

                    assert_eq!(pod.pos, gaussian.pos);
                    assert_eq!(pod.color, gaussian.color);
                    assert_eq!(
                        pod.sh,
                        <$pod_type as GaussianPod>::ShConfig::from_sh(&gaussian.sh),
                    );
                    assert_eq!(
                        pod.cov3d,
                        <$pod_type as GaussianPod>::Cov3dConfig::from_rot_scale(
                            gaussian.rot,
                            gaussian.scale
                        ),
                    );
                }
            }
        };
    }

    /// Generates the full test set for one pod type.
    ///
    /// `$into_gaussian_should_panic` is captured as `tt` (not `expr`): a
    /// forwarded `expr` fragment is opaque and could never match the literal
    /// `true`/`false` arms of `test_pod_into_gaussian!`.
    macro_rules! test_pod {
        ($name:ident, $pod_type:ty, $into_gaussian_should_panic:tt) => {
            paste::paste! {
                #[test]
                fn [<test_ $name _from_gaussian_should_equal_original_gaussian>]() {
                    let gaussian = sample_gaussian();

                    let pod = <$pod_type as GaussianPod>::from_gaussian(&gaussian);

                    assert_eq!(gaussian.pos, pod.pos);
                    assert_eq!(gaussian.color, pod.color);
                    assert_eq!(
                        <$pod_type as GaussianPod>::ShConfig::from_sh(&gaussian.sh),
                        pod.sh,
                    );
                    assert_eq!(
                        <$pod_type as GaussianPod>::Cov3dConfig::from_rot_scale(
                            gaussian.rot,
                            gaussian.scale
                        ),
                        pod.cov3d,
                    );
                }

                test_pod_into_gaussian!($name, $pod_type, $into_gaussian_should_panic);

                #[test]
                fn [<test_ $name _features_should_be_correct>]() {
                    let features = <$pod_type as GaussianPod>::features();

                    for (name, enabled) in features {
                        if name == <$pod_type as GaussianPod>::ShConfig::FEATURE
                            || name == <$pod_type as GaussianPod>::Cov3dConfig::FEATURE
                        {
                            assert!(enabled, "Feature {name} should be enabled");
                        } else {
                            assert!(!enabled, "Feature {name} should be disabled");
                        }
                    }
                }

                #[test]
                fn [<test_ $name _wesl_features_should_be_correct>]() {
                    let wesl_features = <$pod_type as GaussianPod>::wesl_features();
                    let features = <$pod_type as GaussianPod>::features();

                    for (name, enabled) in features {
                        let wesl_enabled = wesl_features
                            .flags
                            .get(name)
                            .map(|v| *v == wesl::Feature::Enable)
                            .unwrap_or(false);

                        assert_eq!(
                            enabled, wesl_enabled,
                            "Feature {name} should be {}",
                            if enabled { "enabled" } else { "disabled" }
                        );
                    }
                }
            }
        };
    }

    #[rustfmt::skip]
    mod pod {
        use super::*;

        test_pod!(single_rotscale, GaussianPodWithShSingleCov3dRotScaleConfigs, false);
        test_pod!(single_single, GaussianPodWithShSingleCov3dSingleConfigs, true);
        test_pod!(single_half, GaussianPodWithShSingleCov3dHalfConfigs, true);
        test_pod!(half_rotscale, GaussianPodWithShHalfCov3dRotScaleConfigs, false);
        test_pod!(half_single, GaussianPodWithShHalfCov3dSingleConfigs, true);
        test_pod!(half_half, GaussianPodWithShHalfCov3dHalfConfigs, true);
        test_pod!(norm8_rotscale, GaussianPodWithShNorm8Cov3dRotScaleConfigs, false);
        test_pod!(norm8_single, GaussianPodWithShNorm8Cov3dSingleConfigs, true);
        test_pod!(norm8_half, GaussianPodWithShNorm8Cov3dHalfConfigs, true);
        test_pod!(none_rotscale, GaussianPodWithShNoneCov3dRotScaleConfigs, true);
        test_pod!(none_single, GaussianPodWithShNoneCov3dSingleConfigs, true);
        test_pod!(none_half, GaussianPodWithShNoneCov3dHalfConfigs, true);
    }
}