Skip to main content

wgpu_3dgs_viewer/
preprocessor.rs

1use crate::{
2    CameraBuffer, GaussiansDepthBuffer, IndirectArgsBuffer, IndirectIndicesBuffer,
3    PreprocessorCreateError, RadixSortIndirectArgsBuffer,
4    core::{
5        BufferWrapper, ComputeBundle, ComputeBundleBuilder, GaussianPod, GaussianTransformBuffer,
6        GaussiansBuffer, ModelTransformBuffer,
7    },
8    wesl_utils,
9};
10
11#[cfg(feature = "viewer-selection")]
12use crate::{editor::SelectionBuffer, selection};
13
14/// Preprocessor to preprocess the Gaussians.
15///
16/// It computes the depth for [`RadixSorter`](crate::RadixSorter) and do frustum culling.
17#[derive(Debug)]
18pub struct Preprocessor<G: GaussianPod, B = wgpu::BindGroup> {
19    /// The bind group layout.
20    #[allow(dead_code)]
21    bind_group_layout: wgpu::BindGroupLayout,
22    /// The bind group.
23    bind_group: B,
24    /// The pre preprocess bundle.
25    pre_bundle: ComputeBundle<()>,
26    /// The preprocess bundle.
27    bundle: ComputeBundle<()>,
28    /// The post preprocess bundle.
29    post_bundle: ComputeBundle<()>,
30    /// The marker for the Gaussian POD type.
31    gaussian_pod_marker: std::marker::PhantomData<G>,
32}
33
34impl<G: GaussianPod, B> Preprocessor<G, B> {
35    /// Create the bind group.
36    #[allow(clippy::too_many_arguments)]
37    pub fn create_bind_group(
38        &self,
39        device: &wgpu::Device,
40        camera: &CameraBuffer,
41        model_transform: &ModelTransformBuffer,
42        gaussian_transform: &GaussianTransformBuffer,
43        gaussians: &GaussiansBuffer<G>,
44        indirect_args: &IndirectArgsBuffer,
45        radix_sort_indirect_args: &RadixSortIndirectArgsBuffer,
46        indirect_indices: &IndirectIndicesBuffer,
47        gaussians_depth: &GaussiansDepthBuffer,
48        #[cfg(feature = "viewer-selection")] selection: &SelectionBuffer,
49        #[cfg(feature = "viewer-selection")]
50        invert_selection: &selection::PreprocessorInvertSelectionBuffer,
51    ) -> wgpu::BindGroup {
52        Preprocessor::create_bind_group_static(
53            device,
54            &self.bind_group_layout,
55            camera,
56            model_transform,
57            gaussian_transform,
58            gaussians,
59            indirect_args,
60            radix_sort_indirect_args,
61            indirect_indices,
62            gaussians_depth,
63            #[cfg(feature = "viewer-selection")]
64            selection,
65            #[cfg(feature = "viewer-selection")]
66            invert_selection,
67        )
68    }
69
70    /// Get the number of invocations in one workgroup.
71    pub fn workgroup_size(&self) -> u32 {
72        self.bundle.workgroup_size()
73    }
74
75    /// Get the bind group layout.
76    pub fn bind_group_layout(&self) -> &wgpu::BindGroupLayout {
77        &self.bind_group_layout
78    }
79
80    /// Get the pre preprocess bundle.
81    pub fn pre_bundle(&self) -> &ComputeBundle<()> {
82        &self.pre_bundle
83    }
84
85    /// Get the preprocess bundle.
86    pub fn bundle(&self) -> &ComputeBundle<()> {
87        &self.bundle
88    }
89
90    /// Get the post preprocess bundle.
91    pub fn post_bundle(&self) -> &ComputeBundle<()> {
92        &self.post_bundle
93    }
94}
95
96impl<G: GaussianPod> Preprocessor<G> {
97    /// The label.
98    const LABEL: &str = "Preprocessor";
99
100    /// The main shader module path.
101    const MAIN_SHADER: &str = "wgpu_3dgs_viewer::preprocess";
102
103    /// The bind group layout descriptor.
104    pub const BIND_GROUP_LAYOUT_DESCRIPTOR: wgpu::BindGroupLayoutDescriptor<'static> =
105        wgpu::BindGroupLayoutDescriptor {
106            label: Some("Preprocessor Bind Group Layout"),
107            entries: &[
108                // Camera uniform buffer
109                wgpu::BindGroupLayoutEntry {
110                    binding: 0,
111                    visibility: wgpu::ShaderStages::COMPUTE,
112                    ty: wgpu::BindingType::Buffer {
113                        ty: wgpu::BufferBindingType::Uniform,
114                        has_dynamic_offset: false,
115                        min_binding_size: None,
116                    },
117                    count: None,
118                },
119                // Model transform uniform buffer
120                wgpu::BindGroupLayoutEntry {
121                    binding: 1,
122                    visibility: wgpu::ShaderStages::COMPUTE,
123                    ty: wgpu::BindingType::Buffer {
124                        ty: wgpu::BufferBindingType::Uniform,
125                        has_dynamic_offset: false,
126                        min_binding_size: None,
127                    },
128                    count: None,
129                },
130                // Gaussian transform uniform buffer
131                wgpu::BindGroupLayoutEntry {
132                    binding: 2,
133                    visibility: wgpu::ShaderStages::COMPUTE,
134                    ty: wgpu::BindingType::Buffer {
135                        ty: wgpu::BufferBindingType::Uniform,
136                        has_dynamic_offset: false,
137                        min_binding_size: None,
138                    },
139                    count: None,
140                },
141                // Gaussian storage buffer
142                wgpu::BindGroupLayoutEntry {
143                    binding: 3,
144                    visibility: wgpu::ShaderStages::COMPUTE,
145                    ty: wgpu::BindingType::Buffer {
146                        ty: wgpu::BufferBindingType::Storage { read_only: true },
147                        has_dynamic_offset: false,
148                        min_binding_size: None,
149                    },
150                    count: None,
151                },
152                // Indirect args storage buffer
153                wgpu::BindGroupLayoutEntry {
154                    binding: 4,
155                    visibility: wgpu::ShaderStages::COMPUTE,
156                    ty: wgpu::BindingType::Buffer {
157                        ty: wgpu::BufferBindingType::Storage { read_only: false },
158                        has_dynamic_offset: false,
159                        min_binding_size: None,
160                    },
161                    count: None,
162                },
163                // Radix sort indirect args storage buffer
164                wgpu::BindGroupLayoutEntry {
165                    binding: 5,
166                    visibility: wgpu::ShaderStages::COMPUTE,
167                    ty: wgpu::BindingType::Buffer {
168                        ty: wgpu::BufferBindingType::Storage { read_only: false },
169                        has_dynamic_offset: false,
170                        min_binding_size: None,
171                    },
172                    count: None,
173                },
174                // Indirect indices storage buffer
175                wgpu::BindGroupLayoutEntry {
176                    binding: 6,
177                    visibility: wgpu::ShaderStages::COMPUTE,
178                    ty: wgpu::BindingType::Buffer {
179                        ty: wgpu::BufferBindingType::Storage { read_only: false },
180                        has_dynamic_offset: false,
181                        min_binding_size: None,
182                    },
183                    count: None,
184                },
185                // Gaussians depth storage buffer
186                wgpu::BindGroupLayoutEntry {
187                    binding: 7,
188                    visibility: wgpu::ShaderStages::COMPUTE,
189                    ty: wgpu::BindingType::Buffer {
190                        ty: wgpu::BufferBindingType::Storage { read_only: false },
191                        has_dynamic_offset: false,
192                        min_binding_size: None,
193                    },
194                    count: None,
195                },
196                // Selection buffer
197                #[cfg(feature = "viewer-selection")]
198                wgpu::BindGroupLayoutEntry {
199                    binding: 8,
200                    visibility: wgpu::ShaderStages::COMPUTE,
201                    ty: wgpu::BindingType::Buffer {
202                        ty: wgpu::BufferBindingType::Storage { read_only: true },
203                        has_dynamic_offset: false,
204                        min_binding_size: None,
205                    },
206                    count: None,
207                },
208                // Invert selection buffer
209                #[cfg(feature = "viewer-selection")]
210                wgpu::BindGroupLayoutEntry {
211                    binding: 9,
212                    visibility: wgpu::ShaderStages::COMPUTE,
213                    ty: wgpu::BindingType::Buffer {
214                        ty: wgpu::BufferBindingType::Uniform,
215                        has_dynamic_offset: false,
216                        min_binding_size: None,
217                    },
218                    count: None,
219                },
220            ],
221        };
222
223    /// Create a new preprocessor.
224    #[allow(clippy::too_many_arguments)]
225    pub fn new(
226        device: &wgpu::Device,
227        camera: &CameraBuffer,
228        model_transform: &ModelTransformBuffer,
229        gaussian_transform: &GaussianTransformBuffer,
230        gaussians: &GaussiansBuffer<G>,
231        indirect_args: &IndirectArgsBuffer,
232        radix_sort_indirect_args: &RadixSortIndirectArgsBuffer,
233        indirect_indices: &IndirectIndicesBuffer,
234        gaussians_depth: &GaussiansDepthBuffer,
235        #[cfg(feature = "viewer-selection")] selection: &SelectionBuffer,
236        #[cfg(feature = "viewer-selection")]
237        invert_selection: &selection::PreprocessorInvertSelectionBuffer,
238    ) -> Result<Self, PreprocessorCreateError> {
239        if (device.limits().max_storage_buffer_binding_size as wgpu::BufferAddress)
240            < gaussians.buffer().size()
241        {
242            return Err(PreprocessorCreateError::ModelSizeExceedsDeviceLimit {
243                model_size: gaussians.buffer().size(),
244                device_limit: device.limits().max_storage_buffer_binding_size,
245            });
246        }
247
248        let this = Preprocessor::new_without_bind_group(device)?;
249
250        log::debug!("Creating preprocessor bind group");
251        let bind_group = this.create_bind_group(
252            device,
253            camera,
254            model_transform,
255            gaussian_transform,
256            gaussians,
257            indirect_args,
258            radix_sort_indirect_args,
259            indirect_indices,
260            gaussians_depth,
261            #[cfg(feature = "viewer-selection")]
262            selection,
263            #[cfg(feature = "viewer-selection")]
264            invert_selection,
265        );
266
267        Ok(Self {
268            bind_group_layout: this.bind_group_layout,
269            bind_group,
270            pre_bundle: this.pre_bundle,
271            bundle: this.bundle,
272            post_bundle: this.post_bundle,
273            gaussian_pod_marker: std::marker::PhantomData,
274        })
275    }
276
277    /// Get the bind group.
278    pub fn bind_group(&self) -> &wgpu::BindGroup {
279        &self.bind_group
280    }
281
282    /// Preprocess the Gaussians.
283    pub fn preprocess(&self, encoder: &mut wgpu::CommandEncoder, gaussian_count: u32) {
284        self.pre_bundle.dispatch(encoder, 1, [&self.bind_group]);
285
286        self.bundle
287            .dispatch(encoder, gaussian_count, [&self.bind_group]);
288
289        self.post_bundle.dispatch(encoder, 1, [&self.bind_group]);
290    }
291
292    /// Create the bind group statically.
293    #[allow(clippy::too_many_arguments)]
294    fn create_bind_group_static(
295        device: &wgpu::Device,
296        bind_group_layout: &wgpu::BindGroupLayout,
297        camera: &CameraBuffer,
298        model_transform: &ModelTransformBuffer,
299        gaussian_transform: &GaussianTransformBuffer,
300        gaussians: &GaussiansBuffer<G>,
301        indirect_args: &IndirectArgsBuffer,
302        radix_sort_indirect_args: &RadixSortIndirectArgsBuffer,
303        indirect_indices: &IndirectIndicesBuffer,
304        gaussians_depth: &GaussiansDepthBuffer,
305        #[cfg(feature = "viewer-selection")] selection: &SelectionBuffer,
306        #[cfg(feature = "viewer-selection")]
307        invert_selection: &selection::PreprocessorInvertSelectionBuffer,
308    ) -> wgpu::BindGroup {
309        device.create_bind_group(&wgpu::BindGroupDescriptor {
310            label: Some("Preprocessor Bind Group"),
311            layout: bind_group_layout,
312            entries: &[
313                // Camera uniform buffer
314                wgpu::BindGroupEntry {
315                    binding: 0,
316                    resource: camera.buffer().as_entire_binding(),
317                },
318                // Model transform uniform buffer
319                wgpu::BindGroupEntry {
320                    binding: 1,
321                    resource: model_transform.buffer().as_entire_binding(),
322                },
323                // Gaussian transform uniform buffer
324                wgpu::BindGroupEntry {
325                    binding: 2,
326                    resource: gaussian_transform.buffer().as_entire_binding(),
327                },
328                // Gaussian storage buffer
329                wgpu::BindGroupEntry {
330                    binding: 3,
331                    resource: gaussians.buffer().as_entire_binding(),
332                },
333                // Indirect args storage buffer
334                wgpu::BindGroupEntry {
335                    binding: 4,
336                    resource: indirect_args.buffer().as_entire_binding(),
337                },
338                // Radix sort indirect args storage buffer
339                wgpu::BindGroupEntry {
340                    binding: 5,
341                    resource: radix_sort_indirect_args.buffer().as_entire_binding(),
342                },
343                // Indirect indices storage buffer
344                wgpu::BindGroupEntry {
345                    binding: 6,
346                    resource: indirect_indices.buffer().as_entire_binding(),
347                },
348                // Gaussians depth storage buffer
349                wgpu::BindGroupEntry {
350                    binding: 7,
351                    resource: gaussians_depth.buffer().as_entire_binding(),
352                },
353                // Selection buffer
354                #[cfg(feature = "viewer-selection")]
355                wgpu::BindGroupEntry {
356                    binding: 8,
357                    resource: selection.buffer().as_entire_binding(),
358                },
359                // Invert selection buffer
360                #[cfg(feature = "viewer-selection")]
361                wgpu::BindGroupEntry {
362                    binding: 9,
363                    resource: invert_selection.buffer().as_entire_binding(),
364                },
365            ],
366        })
367    }
368}
369
370impl<G: GaussianPod> Preprocessor<G, ()> {
371    /// Create a new preprocessor without interally managed bind group.
372    ///
373    /// To create a bind group with layout matched to this preprocessor, use the
374    /// [`Preprocessor::create_bind_group`] method.
375    pub fn new_without_bind_group(device: &wgpu::Device) -> Result<Self, PreprocessorCreateError> {
376        let main_shader: wesl::ModulePath = Preprocessor::<G>::MAIN_SHADER
377            .parse()
378            .expect("preprocess module path");
379
380        let wesl_compile_options = wesl::CompileOptions {
381            features: wesl::Features {
382                flags: G::features()
383                    .into_iter()
384                    .chain(std::iter::once((
385                        "selection_buffer",
386                        cfg!(feature = "viewer-selection"),
387                    )))
388                    .map(|(k, v)| (k.to_string(), v.into()))
389                    .collect(),
390                ..Default::default()
391            },
392            ..Default::default()
393        };
394
395        let bind_group_layout =
396            device.create_bind_group_layout(&Preprocessor::<G>::BIND_GROUP_LAYOUT_DESCRIPTOR);
397
398        let pre_bundle = ComputeBundleBuilder::new()
399            .label(format!("Pre {}", Preprocessor::<G>::LABEL).as_str())
400            .bind_group_layout(&Preprocessor::<G>::BIND_GROUP_LAYOUT_DESCRIPTOR)
401            .entry_point("pre")
402            .main_shader(main_shader.clone())
403            .wesl_compile_options(wesl_compile_options.clone())
404            .resolver(wesl_utils::resolver())
405            .build_without_bind_groups(device)?;
406
407        let bundle = ComputeBundleBuilder::new()
408            .label(Preprocessor::<G>::LABEL)
409            .bind_group_layout(&Preprocessor::<G>::BIND_GROUP_LAYOUT_DESCRIPTOR)
410            .entry_point("main")
411            .main_shader(main_shader.clone())
412            .wesl_compile_options(wesl_compile_options.clone())
413            .resolver(wesl_utils::resolver())
414            .build_without_bind_groups(device)?;
415
416        let post_bundle = ComputeBundleBuilder::new()
417            .label(format!("Post {}", Preprocessor::<G>::LABEL).as_str())
418            .bind_group_layout(&Preprocessor::<G>::BIND_GROUP_LAYOUT_DESCRIPTOR)
419            .entry_point("post")
420            .main_shader(main_shader)
421            .wesl_compile_options(wesl_compile_options)
422            .resolver(wesl_utils::resolver())
423            .build_without_bind_groups(device)?;
424
425        log::info!("Preprocessor created");
426
427        Ok(Self {
428            bind_group_layout,
429            bind_group: (),
430            pre_bundle,
431            bundle,
432            post_bundle,
433            gaussian_pod_marker: std::marker::PhantomData,
434        })
435    }
436
437    /// Preprocess the Gaussians.
438    pub fn preprocess(
439        &self,
440        encoder: &mut wgpu::CommandEncoder,
441        bind_group: &wgpu::BindGroup,
442        gaussian_count: u32,
443    ) {
444        self.pre_bundle.dispatch(encoder, 1, [bind_group]);
445
446        self.bundle.dispatch(encoder, gaussian_count, [bind_group]);
447
448        self.post_bundle.dispatch(encoder, 1, [bind_group]);
449    }
450}