wgpu_3dgs_core/buffer/
gaussian.rs

1use glam::*;
2
3use wgpu::util::DeviceExt;
4
5use crate::{
6    BufferWrapper, DownloadBufferError, Gaussian, GaussianCov3dConfig, GaussianCov3dHalfConfig,
7    GaussianCov3dRotScaleConfig, GaussianCov3dSingleConfig, GaussianShConfig, GaussianShHalfConfig,
8    GaussianShNoneConfig, GaussianShNorm8Config, GaussianShSingleConfig,
9    GaussiansBufferTryFromBufferError, GaussiansBufferUpdateError, GaussiansBufferUpdateRangeError,
10    IterGaussian,
11};
12
/// The Gaussians storage buffer.
///
/// This buffer holds an array of Gaussians represented by the specified [`GaussianPod`].
///
/// It is a newtype over [`wgpu::Buffer`]; the [`std::marker::PhantomData`] field only records
/// the POD type `G` and carries no data of its own.
#[derive(Debug, Clone)]
pub struct GaussiansBuffer<G: GaussianPod>(wgpu::Buffer, std::marker::PhantomData<G>);
18
19impl<G: GaussianPod> GaussiansBuffer<G> {
20    /// Create a new Gaussians buffer.
21    pub fn new(device: &wgpu::Device, gaussians: &impl IterGaussian) -> Self {
22        Self::new_with_pods(
23            device,
24            gaussians
25                .iter_gaussian()
26                .map(|g| G::from_gaussian(&g))
27                .collect::<Vec<_>>()
28                .as_slice(),
29        )
30    }
31
32    /// Create a new Gaussians buffer with the specified size with [`wgpu::BufferUsages`].
33    pub fn new_with_usage(
34        device: &wgpu::Device,
35        gaussians: &impl IterGaussian,
36        usage: wgpu::BufferUsages,
37    ) -> Self {
38        Self::new_with_pods_and_usage(
39            device,
40            gaussians
41                .iter_gaussian()
42                .map(|g| G::from_gaussian(&g))
43                .collect::<Vec<_>>()
44                .as_slice(),
45            usage,
46        )
47    }
48
49    /// Create a new Gaussians buffer with [`GaussianPod`].
50    pub fn new_with_pods(device: &wgpu::Device, gaussians: &[G]) -> Self {
51        Self::new_with_pods_and_usage(device, gaussians, Self::DEFAULT_USAGES)
52    }
53
54    /// Create a new Gaussians buffer with [`GaussianPod`] and the specified size and
55    /// [`wgpu::BufferUsages`].
56    pub fn new_with_pods_and_usage(
57        device: &wgpu::Device,
58        gaussians: &[G],
59        usage: wgpu::BufferUsages,
60    ) -> Self {
61        let buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
62            label: Some("Gaussians Buffer"),
63            contents: bytemuck::cast_slice(gaussians),
64            usage,
65        });
66
67        Self(buffer, std::marker::PhantomData)
68    }
69
70    /// Create a new Gaussians buffer with the specified size.
71    pub fn new_empty(device: &wgpu::Device, len: usize) -> Self {
72        Self::new_empty_with_usage(device, len, Self::DEFAULT_USAGES)
73    }
74
75    /// Create a new Gaussians buffer with the specified size and [`wgpu::BufferUsages`].
76    pub fn new_empty_with_usage(
77        device: &wgpu::Device,
78        len: usize,
79        usage: wgpu::BufferUsages,
80    ) -> Self {
81        let buffer = device.create_buffer(&wgpu::BufferDescriptor {
82            label: Some("Gaussians Buffer"),
83            size: (len * std::mem::size_of::<G>()) as wgpu::BufferAddress,
84            usage,
85            mapped_at_creation: false,
86        });
87
88        Self(buffer, std::marker::PhantomData)
89    }
90
91    /// Get the number of Gaussians.
92    pub fn len(&self) -> usize {
93        self.0.size() as usize / std::mem::size_of::<G>()
94    }
95
96    /// Check if the buffer is empty.
97    pub fn is_empty(&self) -> bool {
98        self.len() == 0
99    }
100
101    /// Update the buffer.
102    ///
103    /// `gaussians` should have the same number of Gaussians as the buffer.
104    pub fn update(
105        &self,
106        queue: &wgpu::Queue,
107        gaussians: &impl IterGaussian,
108    ) -> Result<(), GaussiansBufferUpdateError> {
109        self.update_with_pod(
110            queue,
111            gaussians
112                .iter_gaussian()
113                .map(|g| G::from_gaussian(&g))
114                .collect::<Vec<_>>()
115                .as_slice(),
116        )
117    }
118
119    /// Update the buffer with [`GaussianPod`].
120    ///
121    /// `pods` should have the same number of Gaussians as the buffer.
122    pub fn update_with_pod(
123        &self,
124        queue: &wgpu::Queue,
125        pods: &[G],
126    ) -> Result<(), GaussiansBufferUpdateError> {
127        if pods.len() != self.len() {
128            return Err(GaussiansBufferUpdateError::CountMismatch {
129                count: pods.len(),
130                expected_count: self.len(),
131            });
132        }
133
134        queue.write_buffer(&self.0, 0, bytemuck::cast_slice(pods));
135
136        Ok(())
137    }
138
139    /// Update a range of the buffer.
140    ///
141    /// `gaussians` should fit in the buffer starting from `start`.
142    pub fn update_range(
143        &self,
144        queue: &wgpu::Queue,
145        start: usize,
146        gaussians: &[Gaussian],
147    ) -> Result<(), GaussiansBufferUpdateRangeError> {
148        self.update_range_with_pod(
149            queue,
150            start,
151            gaussians
152                .iter()
153                .map(G::from_gaussian)
154                .collect::<Vec<_>>()
155                .as_slice(),
156        )
157    }
158
159    /// Update a range of the buffer with [`GaussianPod`].
160    ///
161    /// `pods` should fit in the buffer starting from `start`.
162    pub fn update_range_with_pod(
163        &self,
164        queue: &wgpu::Queue,
165        start: usize,
166        pods: &[G],
167    ) -> Result<(), GaussiansBufferUpdateRangeError> {
168        if start + pods.len() > self.len() {
169            return Err(GaussiansBufferUpdateRangeError::CountMismatch {
170                count: pods.len(),
171                start,
172                expected_count: self.len(),
173            });
174        }
175
176        queue.write_buffer(
177            &self.0,
178            (start * std::mem::size_of::<G>()) as wgpu::BufferAddress,
179            bytemuck::cast_slice(pods),
180        );
181
182        Ok(())
183    }
184
185    /// Download the buffer data into a [`Vec`] of [`Gaussian`].
186    pub async fn download_gaussians(
187        &self,
188        device: &wgpu::Device,
189        queue: &wgpu::Queue,
190    ) -> Result<Vec<Gaussian>, DownloadBufferError> {
191        self.download::<G>(device, queue)
192            .await
193            .map(|pods| pods.into_iter().map(Into::into).collect::<Vec<_>>())
194    }
195}
196
197impl<G: GaussianPod> BufferWrapper for GaussiansBuffer<G> {
198    const DEFAULT_USAGES: wgpu::BufferUsages = wgpu::BufferUsages::from_bits_retain(
199        wgpu::BufferUsages::STORAGE.bits() | wgpu::BufferUsages::COPY_DST.bits(),
200    );
201
202    fn buffer(&self) -> &wgpu::Buffer {
203        &self.0
204    }
205}
206
207impl<G: GaussianPod> From<GaussiansBuffer<G>> for wgpu::Buffer {
208    fn from(wrapper: GaussiansBuffer<G>) -> Self {
209        wrapper.0
210    }
211}
212
213impl<G: GaussianPod> TryFrom<wgpu::Buffer> for GaussiansBuffer<G> {
214    type Error = GaussiansBufferTryFromBufferError;
215
216    fn try_from(buffer: wgpu::Buffer) -> Result<Self, Self::Error> {
217        if !buffer
218            .size()
219            .is_multiple_of(std::mem::size_of::<G>() as wgpu::BufferAddress)
220        {
221            return Err(GaussiansBufferTryFromBufferError::BufferSizeNotMultiple {
222                buffer_size: buffer.size(),
223                expected_multiple_size: std::mem::size_of::<G>() as wgpu::BufferAddress,
224            });
225        }
226
227        Ok(Self(buffer, std::marker::PhantomData))
228    }
229}
230
231/// The Gaussian POD trait.
232///
233/// The number of configurations for this is the combination of all the [`GaussianShConfig`]
234/// and [`GaussianCov3dConfig`].
235///
236/// You can use the corresponding config by using the name in the following format:
237/// `GaussianPodWithSh{ShConfig}Cov3d{Cov3dConfig}Configs`, e.g.
238/// [`GaussianPodWithShSingleCov3dRotScaleConfigs`].
239pub trait GaussianPod:
240    for<'a> From<&'a Gaussian>
241    + Into<Gaussian>
242    + Send
243    + Sync
244    + std::fmt::Debug
245    + Clone
246    + Copy
247    + PartialEq
248    + bytemuck::NoUninit
249    + bytemuck::AnyBitPattern
250{
251    /// The SH configuration.
252    type ShConfig: GaussianShConfig;
253
254    /// The covariance 3D configuration.
255    type Cov3dConfig: GaussianCov3dConfig;
256
257    /// Convert from POD to Gaussian.
258    fn into_gaussian(self) -> Gaussian {
259        self.into()
260    }
261
262    /// Create a new Gaussian POD from the Gaussian.
263    fn from_gaussian(gaussian: &Gaussian) -> Self {
264        Self::from(gaussian)
265    }
266
267    /// Create the features for [`Wesl`](wesl::Wesl) compilation.
268    ///
269    /// You may want to use [`GaussianPod::wesl_features`] most of the time instead.
270    fn features() -> [(&'static str, bool); 7] {
271        [
272            GaussianShSingleConfig::FEATURE,
273            GaussianShHalfConfig::FEATURE,
274            GaussianShNorm8Config::FEATURE,
275            GaussianShNoneConfig::FEATURE,
276            GaussianCov3dRotScaleConfig::FEATURE,
277            GaussianCov3dSingleConfig::FEATURE,
278            GaussianCov3dHalfConfig::FEATURE,
279        ]
280        .map(|name| {
281            (
282                name,
283                name == Self::ShConfig::FEATURE || name == Self::Cov3dConfig::FEATURE,
284            )
285        })
286    }
287
288    /// Create the features for [`Wesl`](wesl::Wesl) compilation as a [`wesl::Features`].
289    fn wesl_features() -> wesl::Features {
290        wesl::Features {
291            flags: Self::features()
292                .iter()
293                .map(|(name, enabled)| (name.to_string(), (*enabled).into()))
294                .collect(),
295            ..Default::default()
296        }
297    }
298}
299
/// Macro to create the POD representation of Gaussian given the configurations.
///
/// For the given `sh` and `cov3d` configuration names this generates:
///
/// - the `#[repr(C)]` struct `GaussianPodWithSh{sh}Cov3d{cov3d}Configs`, ending with
///   `padding_size` `f32`s of padding,
/// - the conversions from `&Gaussian` and into `Gaussian`, and
/// - the [`GaussianPod`] implementation tying the struct to its two configurations.
macro_rules! gaussian_pod {
    (sh = $sh:ident, cov3d = $cov3d:ident, padding_size = $padding:expr) => {
        paste::paste! {
            #[repr(C)]
            #[derive(Debug, Clone, Copy, PartialEq, bytemuck::Pod, bytemuck::Zeroable)]
            pub struct [< GaussianPodWith Sh $sh Cov3d $cov3d Configs >] {
                pub pos: Vec3,
                pub color: U8Vec4,
                pub sh: <[< GaussianSh $sh Config >] as GaussianShConfig>::Field,
                pub cov3d: <[< GaussianCov3d $cov3d Config >] as GaussianCov3dConfig>::Field,
                pub padding: [f32; $padding],
            }

            impl From<&Gaussian> for [< GaussianPodWith Sh $sh Cov3d $cov3d Configs >] {
                fn from(gaussian: &Gaussian) -> Self {
                    // Covariance, encoded from the rotation and scale.
                    let cov3d = <[< GaussianCov3d $cov3d Config >]>::from_rot_scale(
                        gaussian.rot,
                        gaussian.scale,
                    );

                    // Spherical harmonics, encoded by the SH configuration.
                    let sh = <[< GaussianSh $sh Config >]>::from_sh(&gaussian.sh);

                    // Position and color are copied over as they are; the padding is zeroed.
                    Self {
                        pos: gaussian.pos,
                        color: gaussian.color,
                        sh,
                        cov3d,
                        padding: [0.0; $padding],
                    }
                }
            }

            impl From<[< GaussianPodWith Sh $sh Cov3d $cov3d Configs >]> for Gaussian {
                fn from(pod: [< GaussianPodWith Sh $sh Cov3d $cov3d Configs >]) -> Self {
                    // Spherical harmonics, decoded by the SH configuration.
                    let sh = <[< GaussianSh $sh Config >]>::to_sh(&pod.sh);

                    // Rotation and scale, decoded from the stored covariance field.
                    let (rot, scale) =
                        <[< GaussianCov3d $cov3d Config >]>::to_rot_scale(&pod.cov3d);

                    // Position and color are copied over as they are.
                    Self {
                        rot,
                        pos: pod.pos,
                        color: pod.color,
                        sh,
                        scale,
                    }
                }
            }

            impl GaussianPod for [< GaussianPodWith Sh $sh Cov3d $cov3d Configs >] {
                type ShConfig = [< GaussianSh $sh Config >];
                type Cov3dConfig = [< GaussianCov3d $cov3d Config >];
            }
        }
    };
}
372
// One POD type is generated for every SH / covariance 3D configuration pair.
// `padding_size` is the length of the trailing `[f32; N]` padding field of the generated struct.
// NOTE(review): the individual padding values are presumably chosen to match the size/alignment
// the shader side expects for each layout — confirm before adding or changing a configuration.
gaussian_pod!(sh = Single, cov3d = RotScale, padding_size = 0);
gaussian_pod!(sh = Single, cov3d = Single, padding_size = 1);
gaussian_pod!(sh = Single, cov3d = Half, padding_size = 0);
gaussian_pod!(sh = Half, cov3d = RotScale, padding_size = 2);
gaussian_pod!(sh = Half, cov3d = Single, padding_size = 3);
gaussian_pod!(sh = Half, cov3d = Half, padding_size = 2);
gaussian_pod!(sh = Norm8, cov3d = RotScale, padding_size = 1);
gaussian_pod!(sh = Norm8, cov3d = Single, padding_size = 2);
gaussian_pod!(sh = Norm8, cov3d = Half, padding_size = 1);
gaussian_pod!(sh = None, cov3d = RotScale, padding_size = 1);
gaussian_pod!(sh = None, cov3d = Single, padding_size = 2);
gaussian_pod!(sh = None, cov3d = Half, padding_size = 1);
385
386#[cfg(test)]
387mod tests {
388    use super::*;
389
390    macro_rules! test_pod_from_gaussian {
391        ($name:ident, $pod_type:ty, true) => {
392            paste::paste! {
393                #[test]
394                #[should_panic]
395                fn [<test_ $name _into_gaussian_should_panic>]() {
396                    let pod = $pod_type::from_gaussian(&Gaussian {
397                        rot: Quat::from_xyzw(0.0, 0.0, 0.0, 1.0),
398                        pos: Vec3::new(1.0, 2.0, 3.0),
399                        color: U8Vec4::new(255, 128, 64, 32),
400                        sh: [Vec3::new(0.1, 0.2, 0.3); 15],
401                        scale: Vec3::new(1.0, 2.0, 3.0),
402                    });
403
404                    pod.into_gaussian();
405                }
406            }
407        };
408        ($name:ident, $pod_type:ty, false) => {
409            paste::paste! {
410                #[test]
411                fn [<test_ $name _into_gaussian_should_equal_original_pod>]() {
412                    let pod = $pod_type::from_gaussian(&Gaussian {
413                        rot: Quat::from_xyzw(0.0, 0.0, 0.0, 1.0),
414                        pos: Vec3::new(1.0, 2.0, 3.0),
415                        color: U8Vec4::new(255, 128, 64, 32),
416                        sh: [Vec3::new(0.1, 0.2, 0.3); 15],
417                        scale: Vec3::new(1.0, 2.0, 3.0),
418                    });
419
420                    let gaussian = pod.into_gaussian();
421
422                    assert_eq!(pod.pos, gaussian.pos);
423                    assert_eq!(pod.color, gaussian.color);
424                    assert_eq!(
425                        pod.sh,
426                        <$pod_type as GaussianPod>::ShConfig::from_sh(&gaussian.sh),
427                    );
428                    assert_eq!(
429                        pod.cov3d,
430                        <$pod_type as GaussianPod>::Cov3dConfig::from_rot_scale(
431                            gaussian.rot,
432                            gaussian.scale
433                        ),
434                    );
435                }
436            }
437        };
438    }
439
440    macro_rules! test_pod {
441        ($name:ident, $pod_type:ty, $when_into_gaussian_should_panic:expr) => {
442            paste::paste! {
443                #[test]
444                fn [<test_ $name _from_gaussian_should_equal_original_gaussian>]() {
445                    let gaussian = Gaussian {
446                        rot: Quat::from_xyzw(0.0, 0.0, 0.0, 1.0),
447                        pos: Vec3::new(1.0, 2.0, 3.0),
448                        color: U8Vec4::new(255, 128, 64, 32),
449                        sh: [Vec3::new(0.1, 0.2, 0.3); 15],
450                        scale: Vec3::new(1.0, 2.0, 3.0),
451                    };
452
453                    let pod = $pod_type::from_gaussian(&gaussian);
454
455                    assert_eq!(gaussian.pos, pod.pos);
456                    assert_eq!(gaussian.color, pod.color);
457                    assert_eq!(
458                        <$pod_type as GaussianPod>::ShConfig::from_sh(&gaussian.sh),
459                        pod.sh,
460                    );
461                    assert_eq!(
462                        <$pod_type as GaussianPod>::Cov3dConfig::from_rot_scale(
463                            gaussian.rot,
464                            gaussian.scale
465                        ),
466                        pod.cov3d,
467                    );
468                }
469
470                test_pod_from_gaussian!($name, $pod_type, $when_into_gaussian_should_panic);
471
472                #[test]
473                fn [<test_ $name _features_should_be_correct>]() {
474                    let features = <$pod_type as GaussianPod>::features();
475
476                    for (name, enabled) in features {
477                        if name == <$pod_type as GaussianPod>::ShConfig::FEATURE
478                            || name == <$pod_type as GaussianPod>::Cov3dConfig::FEATURE
479                        {
480                            assert!(enabled, "Feature {name} should be enabled");
481                        } else {
482                            assert!(!enabled, "Feature {name} should be disabled");
483                        }
484                    }
485                }
486
487                #[test]
488                fn [<test_ $name _wesl_features_should_be_correct>]() {
489                    let wesl_features = <$pod_type as GaussianPod>::wesl_features();
490                    let features = <$pod_type as GaussianPod>::features();
491
492                    for (name, enabled) in features {
493                        let wesl_enabled = wesl_features
494                            .flags
495                            .get(name)
496                            .map(|v| *v == wesl::Feature::Enable)
497                            .unwrap_or(false);
498
499                        assert_eq!(
500                            enabled, wesl_enabled,
501                            "Feature {name} should be {}",
502                            if enabled { "enabled" } else { "disabled" }
503                        );
504                    }
505                }
506            }
507        };
508    }
509
510    #[rustfmt::skip]
511    mod pod {
512        use super::*;
513
514        test_pod!(single_rotscale, GaussianPodWithShSingleCov3dRotScaleConfigs, false);
515        test_pod!(single_single, GaussianPodWithShSingleCov3dSingleConfigs, true);
516        test_pod!(single_half, GaussianPodWithShSingleCov3dHalfConfigs, true);
517        test_pod!(half_rotscale, GaussianPodWithShHalfCov3dRotScaleConfigs, false);
518        test_pod!(test_half_single, GaussianPodWithShHalfCov3dSingleConfigs, true);
519        test_pod!(test_half_half, GaussianPodWithShHalfCov3dHalfConfigs, true);
520        test_pod!(norm8_rotscale, GaussianPodWithShNorm8Cov3dRotScaleConfigs, false);
521        test_pod!(norm8_single, GaussianPodWithShNorm8Cov3dSingleConfigs, true);
522        test_pod!(norm8_half, GaussianPodWithShNorm8Cov3dHalfConfigs, true);
523        test_pod!(none_rotscale, GaussianPodWithShNoneCov3dRotScaleConfigs, true);
524        test_pod!(none_single, GaussianPodWithShNoneCov3dSingleConfigs, true);
525        test_pod!(none_half, GaussianPodWithShNoneCov3dHalfConfigs, true);
526    }
527}