wgpu_3dgs_core/buffer/
gaussian_transform.rs

1use glam::*;
2use wgpu::util::DeviceExt;
3
4use crate::{BufferWrapper, FixedSizeBufferWrapper, FixedSizeBufferWrapperError};
5
/// How the Gaussians are displayed.
///
/// The discriminant is written as a [`prim@u8`] into the first flag byte of
/// [`GaussianTransformPod`].
#[repr(u8)]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum GaussianDisplayMode {
    /// Display as splats; this is the default mode.
    #[default]
    Splat = 0,
    /// Display as ellipses.
    Ellipse = 1,
    /// Display as points.
    Point = 2,
}
15
/// The degree of the spherical harmonics (SH) used by the Gaussians.
///
/// The wrapped value is always within \[0, 3\] unless built with
/// [`GaussianShDegree::new_unchecked`] in violation of its contract.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GaussianShDegree(u8);

impl GaussianShDegree {
    /// Construct an SH degree, validating it.
    ///
    /// Returns [`None`] when `sh_deg` is greater than 3, i.e. outside \[0, 3\].
    pub const fn new(sh_deg: u8) -> Option<Self> {
        if sh_deg > 3 {
            return None;
        }
        Some(Self(sh_deg))
    }

    /// Construct an SH degree without validating it.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that `sh_deg` lies in \[0, 3\].
    pub const unsafe fn new_unchecked(sh_deg: u8) -> Self {
        Self(sh_deg)
    }

    /// Return the degree as a raw [`prim@u8`].
    pub const fn get(&self) -> u8 {
        let Self(degree) = *self;
        degree
    }
}

impl Default for GaussianShDegree {
    /// The highest supported degree, 3.
    fn default() -> Self {
        // 3 is inside [0, 3], so the field can be set directly without `unsafe`.
        Self(3)
    }
}
53
/// The Gaussian's maximum standard deviation.
///
/// Stored as a [`prim@u8`] that maps \[0.0, 3.0\] linearly onto \[0, 255\], so values are
/// quantized to steps of `3.0 / 255.0`.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GaussianMaxStdDev(u8);

impl GaussianMaxStdDev {
    /// Create a new Gaussian maximum standard deviation.
    ///
    /// The value is rounded to the nearest representable step.
    ///
    /// Returns [`None`] if the maximum standard deviation is not in the range of \[0.0, 3.0\]
    /// (NaN included).
    pub const fn new(max_std_dev: f32) -> Option<Self> {
        match max_std_dev {
            0.0..=3.0 => Some(Self(Self::encode(max_std_dev))),
            _ => None,
        }
    }

    /// Create a new Gaussian maximum standard deviation without checking.
    ///
    /// The value is rounded to the nearest representable step.
    ///
    /// # Safety
    ///
    /// The maximum standard deviation must be in the range of \[0.0, 3.0\].
    pub const unsafe fn new_unchecked(max_std_dev: f32) -> Self {
        Self(Self::encode(max_std_dev))
    }

    /// Get the maximum standard deviation.
    ///
    /// Note that the returned value may differ from the one passed in by up to about half a
    /// quantization step (`1.5 / 255.0`), due to the internal representation of [`prim@u8`].
    pub const fn get(&self) -> f32 {
        (self.0 as f32) / 255.0 * 3.0
    }

    /// Get the maximum standard deviation as the internal representation of [`prim@u8`].
    pub const fn as_u8(&self) -> u8 {
        self.0
    }

    /// Quantize `max_std_dev` from \[0.0, 3.0\] to \[0, 255\], rounding to nearest.
    ///
    /// The previous plain truncating cast always rounded down (e.g. `1.5` became `127` rather
    /// than `128`). Adding `0.5` before the truncating `as` cast rounds to nearest for
    /// non-negative inputs while keeping this a `const fn`. The `as` cast saturates, so
    /// out-of-range inputs clamp to `0`/`255`, and NaN maps to `0`.
    const fn encode(max_std_dev: f32) -> u8 {
        (max_std_dev / 3.0 * 255.0 + 0.5) as u8
    }
}

impl Default for GaussianMaxStdDev {
    /// The largest allowed value, 3.0.
    fn default() -> Self {
        // `u8::MAX` (255) is the encoding of 3.0.
        Self(u8::MAX)
    }
}
99
100/// The Gaussian transform buffer.
101///
102/// This buffer holds the Gaussian transformation data, including size, display mode, SH degree,
103/// and whether to show SH0.
104#[derive(Debug, Clone)]
105pub struct GaussianTransformBuffer(wgpu::Buffer);
106
107impl GaussianTransformBuffer {
108    /// Create a new Gaussian transform buffer.
109    pub fn new(device: &wgpu::Device) -> Self {
110        let buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
111            label: Some("Gaussian transform Buffer"),
112            contents: bytemuck::bytes_of(&GaussianTransformPod::default()),
113            usage: Self::DEFAULT_USAGES,
114        });
115
116        Self(buffer)
117    }
118
119    /// Update the Gaussian transformation buffer.
120    pub fn update(
121        &self,
122        queue: &wgpu::Queue,
123        size: f32,
124        display_mode: GaussianDisplayMode,
125        sh_deg: GaussianShDegree,
126        no_sh0: bool,
127        max_std_dev: GaussianMaxStdDev,
128    ) {
129        self.update_with_pod(
130            queue,
131            &GaussianTransformPod::new(size, display_mode, sh_deg, no_sh0, max_std_dev),
132        );
133    }
134
135    /// Update the Gaussian transformation buffer with [`GaussianTransformPod`].
136    pub fn update_with_pod(&self, queue: &wgpu::Queue, transform: &GaussianTransformPod) {
137        queue.write_buffer(&self.0, 0, bytemuck::bytes_of(transform));
138    }
139}
140
141impl BufferWrapper for GaussianTransformBuffer {
142    fn buffer(&self) -> &wgpu::Buffer {
143        &self.0
144    }
145}
146
147impl From<GaussianTransformBuffer> for wgpu::Buffer {
148    fn from(wrapper: GaussianTransformBuffer) -> Self {
149        wrapper.0
150    }
151}
152
153impl TryFrom<wgpu::Buffer> for GaussianTransformBuffer {
154    type Error = FixedSizeBufferWrapperError;
155
156    fn try_from(buffer: wgpu::Buffer) -> Result<Self, Self::Error> {
157        Self::verify_buffer_size(&buffer).map(|()| Self(buffer))
158    }
159}
160
161impl FixedSizeBufferWrapper for GaussianTransformBuffer {
162    type Pod = GaussianTransformPod;
163}
164
165/// The POD representation of a Gaussian transformation.
166#[repr(C)]
167#[derive(Debug, Clone, Copy, PartialEq, bytemuck::Pod, bytemuck::Zeroable)]
168pub struct GaussianTransformPod {
169    pub size: f32,
170
171    /// \[display_mode, sh_deg, no_sh0, std_dev\]
172    pub flags: U8Vec4,
173}
174
175impl GaussianTransformPod {
176    /// Create a new Gaussian transformation.
177    pub const fn new(
178        size: f32,
179        display_mode: GaussianDisplayMode,
180        sh_deg: GaussianShDegree,
181        no_sh0: bool,
182        max_std_dev: GaussianMaxStdDev,
183    ) -> Self {
184        let display_mode = display_mode as u8;
185        let sh_deg = sh_deg.0;
186        let no_sh0 = no_sh0 as u8;
187        let max_std_dev = max_std_dev.0;
188
189        Self {
190            size,
191            flags: u8vec4(display_mode, sh_deg, no_sh0, max_std_dev),
192        }
193    }
194}
195
196impl Default for GaussianTransformPod {
197    fn default() -> Self {
198        Self::new(
199            1.0,
200            GaussianDisplayMode::default(),
201            GaussianShDegree::default(),
202            false,
203            GaussianMaxStdDev::default(),
204        )
205    }
206}