Skip to main content

wgpu_3dgs_viewer/
preprocessor.rs

1use crate::{
2    CameraBuffer, GaussiansDepthBuffer, IndirectArgsBuffer, IndirectIndicesBuffer,
3    PreprocessorCreateError, RadixSortIndirectArgsBuffer,
4    core::{
5        BufferWrapper, ComputeBundle, ComputeBundleBuilder, GaussianPod, GaussianTransformBuffer,
6        GaussiansBuffer, ModelTransformBuffer,
7    },
8    wesl_utils,
9};
10
11#[cfg(feature = "viewer-selection")]
12use crate::{editor::SelectionBuffer, selection};
13
14/// Preprocessor to preprocess the Gaussians.
15///
16/// It computes the depth for [`RadixSorter`](crate::RadixSorter) and do frustum culling.
17#[derive(Debug)]
18pub struct Preprocessor<G: GaussianPod, B = wgpu::BindGroup> {
19    /// The bind group layout.
20    #[allow(dead_code)]
21    bind_group_layout: wgpu::BindGroupLayout,
22    /// The bind group.
23    bind_group: B,
24    /// The pre preprocess bundle.
25    pre_bundle: ComputeBundle<()>,
26    /// The preprocess bundle.
27    bundle: ComputeBundle<()>,
28    /// The post preprocess bundle.
29    post_bundle: ComputeBundle<()>,
30    /// The marker for the Gaussian POD type.
31    gaussian_pod_marker: std::marker::PhantomData<G>,
32}
33
34impl<G: GaussianPod, B> Preprocessor<G, B> {
35    /// Create the bind group.
36    #[allow(clippy::too_many_arguments)]
37    pub fn create_bind_group(
38        &self,
39        device: &wgpu::Device,
40        camera: &CameraBuffer,
41        model_transform: &ModelTransformBuffer,
42        gaussian_transform: &GaussianTransformBuffer,
43        gaussians: &GaussiansBuffer<G>,
44        indirect_args: &IndirectArgsBuffer,
45        radix_sort_indirect_args: &RadixSortIndirectArgsBuffer,
46        indirect_indices: &IndirectIndicesBuffer,
47        gaussians_depth: &GaussiansDepthBuffer,
48        #[cfg(feature = "viewer-selection")] selection: &SelectionBuffer,
49        #[cfg(feature = "viewer-selection")]
50        invert_selection: &selection::PreprocessorInvertSelectionBuffer,
51    ) -> wgpu::BindGroup {
52        Preprocessor::create_bind_group_static(
53            device,
54            &self.bind_group_layout,
55            camera,
56            model_transform,
57            gaussian_transform,
58            gaussians,
59            indirect_args,
60            radix_sort_indirect_args,
61            indirect_indices,
62            gaussians_depth,
63            #[cfg(feature = "viewer-selection")]
64            selection,
65            #[cfg(feature = "viewer-selection")]
66            invert_selection,
67        )
68    }
69
70    /// Get the number of invocations in one workgroup.
71    pub fn workgroup_size(&self) -> u32 {
72        self.bundle.workgroup_size()
73    }
74
75    /// Get the bind group layouts.
76    pub fn bind_group_layout(&self) -> &wgpu::BindGroupLayout {
77        &self.bind_group_layout
78    }
79}
80
81impl<G: GaussianPod> Preprocessor<G> {
82    /// The label.
83    const LABEL: &str = "Preprocessor";
84
85    /// The main shader module path.
86    const MAIN_SHADER: &str = "wgpu_3dgs_viewer::preprocess";
87
88    /// The bind group layout descriptor.
89    pub const BIND_GROUP_LAYOUT_DESCRIPTOR: wgpu::BindGroupLayoutDescriptor<'static> =
90        wgpu::BindGroupLayoutDescriptor {
91            label: Some("Preprocessor Bind Group Layout"),
92            entries: &[
93                // Camera uniform buffer
94                wgpu::BindGroupLayoutEntry {
95                    binding: 0,
96                    visibility: wgpu::ShaderStages::COMPUTE,
97                    ty: wgpu::BindingType::Buffer {
98                        ty: wgpu::BufferBindingType::Uniform,
99                        has_dynamic_offset: false,
100                        min_binding_size: None,
101                    },
102                    count: None,
103                },
104                // Model transform uniform buffer
105                wgpu::BindGroupLayoutEntry {
106                    binding: 1,
107                    visibility: wgpu::ShaderStages::COMPUTE,
108                    ty: wgpu::BindingType::Buffer {
109                        ty: wgpu::BufferBindingType::Uniform,
110                        has_dynamic_offset: false,
111                        min_binding_size: None,
112                    },
113                    count: None,
114                },
115                // Gaussian transform uniform buffer
116                wgpu::BindGroupLayoutEntry {
117                    binding: 2,
118                    visibility: wgpu::ShaderStages::COMPUTE,
119                    ty: wgpu::BindingType::Buffer {
120                        ty: wgpu::BufferBindingType::Uniform,
121                        has_dynamic_offset: false,
122                        min_binding_size: None,
123                    },
124                    count: None,
125                },
126                // Gaussian storage buffer
127                wgpu::BindGroupLayoutEntry {
128                    binding: 3,
129                    visibility: wgpu::ShaderStages::COMPUTE,
130                    ty: wgpu::BindingType::Buffer {
131                        ty: wgpu::BufferBindingType::Storage { read_only: true },
132                        has_dynamic_offset: false,
133                        min_binding_size: None,
134                    },
135                    count: None,
136                },
137                // Indirect args storage buffer
138                wgpu::BindGroupLayoutEntry {
139                    binding: 4,
140                    visibility: wgpu::ShaderStages::COMPUTE,
141                    ty: wgpu::BindingType::Buffer {
142                        ty: wgpu::BufferBindingType::Storage { read_only: false },
143                        has_dynamic_offset: false,
144                        min_binding_size: None,
145                    },
146                    count: None,
147                },
148                // Radix sort indirect args storage buffer
149                wgpu::BindGroupLayoutEntry {
150                    binding: 5,
151                    visibility: wgpu::ShaderStages::COMPUTE,
152                    ty: wgpu::BindingType::Buffer {
153                        ty: wgpu::BufferBindingType::Storage { read_only: false },
154                        has_dynamic_offset: false,
155                        min_binding_size: None,
156                    },
157                    count: None,
158                },
159                // Indirect indices storage buffer
160                wgpu::BindGroupLayoutEntry {
161                    binding: 6,
162                    visibility: wgpu::ShaderStages::COMPUTE,
163                    ty: wgpu::BindingType::Buffer {
164                        ty: wgpu::BufferBindingType::Storage { read_only: false },
165                        has_dynamic_offset: false,
166                        min_binding_size: None,
167                    },
168                    count: None,
169                },
170                // Gaussians depth storage buffer
171                wgpu::BindGroupLayoutEntry {
172                    binding: 7,
173                    visibility: wgpu::ShaderStages::COMPUTE,
174                    ty: wgpu::BindingType::Buffer {
175                        ty: wgpu::BufferBindingType::Storage { read_only: false },
176                        has_dynamic_offset: false,
177                        min_binding_size: None,
178                    },
179                    count: None,
180                },
181                // Selection buffer
182                #[cfg(feature = "viewer-selection")]
183                wgpu::BindGroupLayoutEntry {
184                    binding: 8,
185                    visibility: wgpu::ShaderStages::COMPUTE,
186                    ty: wgpu::BindingType::Buffer {
187                        ty: wgpu::BufferBindingType::Storage { read_only: true },
188                        has_dynamic_offset: false,
189                        min_binding_size: None,
190                    },
191                    count: None,
192                },
193                // Invert selection buffer
194                #[cfg(feature = "viewer-selection")]
195                wgpu::BindGroupLayoutEntry {
196                    binding: 9,
197                    visibility: wgpu::ShaderStages::COMPUTE,
198                    ty: wgpu::BindingType::Buffer {
199                        ty: wgpu::BufferBindingType::Uniform,
200                        has_dynamic_offset: false,
201                        min_binding_size: None,
202                    },
203                    count: None,
204                },
205            ],
206        };
207
208    /// Create a new preprocessor.
209    #[allow(clippy::too_many_arguments)]
210    pub fn new(
211        device: &wgpu::Device,
212        camera: &CameraBuffer,
213        model_transform: &ModelTransformBuffer,
214        gaussian_transform: &GaussianTransformBuffer,
215        gaussians: &GaussiansBuffer<G>,
216        indirect_args: &IndirectArgsBuffer,
217        radix_sort_indirect_args: &RadixSortIndirectArgsBuffer,
218        indirect_indices: &IndirectIndicesBuffer,
219        gaussians_depth: &GaussiansDepthBuffer,
220        #[cfg(feature = "viewer-selection")] selection: &SelectionBuffer,
221        #[cfg(feature = "viewer-selection")]
222        invert_selection: &selection::PreprocessorInvertSelectionBuffer,
223    ) -> Result<Self, PreprocessorCreateError> {
224        if (device.limits().max_storage_buffer_binding_size as wgpu::BufferAddress)
225            < gaussians.buffer().size()
226        {
227            return Err(PreprocessorCreateError::ModelSizeExceedsDeviceLimit {
228                model_size: gaussians.buffer().size(),
229                device_limit: device.limits().max_storage_buffer_binding_size,
230            });
231        }
232
233        let this = Preprocessor::new_without_bind_group(device)?;
234
235        log::debug!("Creating preprocessor bind group");
236        let bind_group = this.create_bind_group(
237            device,
238            camera,
239            model_transform,
240            gaussian_transform,
241            gaussians,
242            indirect_args,
243            radix_sort_indirect_args,
244            indirect_indices,
245            gaussians_depth,
246            #[cfg(feature = "viewer-selection")]
247            selection,
248            #[cfg(feature = "viewer-selection")]
249            invert_selection,
250        );
251
252        Ok(Self {
253            bind_group_layout: this.bind_group_layout,
254            bind_group,
255            pre_bundle: this.pre_bundle,
256            bundle: this.bundle,
257            post_bundle: this.post_bundle,
258            gaussian_pod_marker: std::marker::PhantomData,
259        })
260    }
261
262    /// Preprocess the Gaussians.
263    pub fn preprocess(&self, encoder: &mut wgpu::CommandEncoder, gaussian_count: u32) {
264        self.pre_bundle.dispatch(encoder, 1, [&self.bind_group]);
265
266        self.bundle
267            .dispatch(encoder, gaussian_count, [&self.bind_group]);
268
269        self.post_bundle.dispatch(encoder, 1, [&self.bind_group]);
270    }
271
272    /// Create the bind group statically.
273    #[allow(clippy::too_many_arguments)]
274    fn create_bind_group_static(
275        device: &wgpu::Device,
276        bind_group_layout: &wgpu::BindGroupLayout,
277        camera: &CameraBuffer,
278        model_transform: &ModelTransformBuffer,
279        gaussian_transform: &GaussianTransformBuffer,
280        gaussians: &GaussiansBuffer<G>,
281        indirect_args: &IndirectArgsBuffer,
282        radix_sort_indirect_args: &RadixSortIndirectArgsBuffer,
283        indirect_indices: &IndirectIndicesBuffer,
284        gaussians_depth: &GaussiansDepthBuffer,
285        #[cfg(feature = "viewer-selection")] selection: &SelectionBuffer,
286        #[cfg(feature = "viewer-selection")]
287        invert_selection: &selection::PreprocessorInvertSelectionBuffer,
288    ) -> wgpu::BindGroup {
289        device.create_bind_group(&wgpu::BindGroupDescriptor {
290            label: Some("Preprocessor Bind Group"),
291            layout: bind_group_layout,
292            entries: &[
293                // Camera uniform buffer
294                wgpu::BindGroupEntry {
295                    binding: 0,
296                    resource: camera.buffer().as_entire_binding(),
297                },
298                // Model transform uniform buffer
299                wgpu::BindGroupEntry {
300                    binding: 1,
301                    resource: model_transform.buffer().as_entire_binding(),
302                },
303                // Gaussian transform uniform buffer
304                wgpu::BindGroupEntry {
305                    binding: 2,
306                    resource: gaussian_transform.buffer().as_entire_binding(),
307                },
308                // Gaussian storage buffer
309                wgpu::BindGroupEntry {
310                    binding: 3,
311                    resource: gaussians.buffer().as_entire_binding(),
312                },
313                // Indirect args storage buffer
314                wgpu::BindGroupEntry {
315                    binding: 4,
316                    resource: indirect_args.buffer().as_entire_binding(),
317                },
318                // Radix sort indirect args storage buffer
319                wgpu::BindGroupEntry {
320                    binding: 5,
321                    resource: radix_sort_indirect_args.buffer().as_entire_binding(),
322                },
323                // Indirect indices storage buffer
324                wgpu::BindGroupEntry {
325                    binding: 6,
326                    resource: indirect_indices.buffer().as_entire_binding(),
327                },
328                // Gaussians depth storage buffer
329                wgpu::BindGroupEntry {
330                    binding: 7,
331                    resource: gaussians_depth.buffer().as_entire_binding(),
332                },
333                // Selection buffer
334                #[cfg(feature = "viewer-selection")]
335                wgpu::BindGroupEntry {
336                    binding: 8,
337                    resource: selection.buffer().as_entire_binding(),
338                },
339                // Invert selection buffer
340                #[cfg(feature = "viewer-selection")]
341                wgpu::BindGroupEntry {
342                    binding: 9,
343                    resource: invert_selection.buffer().as_entire_binding(),
344                },
345            ],
346        })
347    }
348}
349
350impl<G: GaussianPod> Preprocessor<G, ()> {
351    /// Create a new preprocessor without interally managed bind group.
352    ///
353    /// To create a bind group with layout matched to this preprocessor, use the
354    /// [`Preprocessor::create_bind_group`] method.
355    pub fn new_without_bind_group(device: &wgpu::Device) -> Result<Self, PreprocessorCreateError> {
356        let main_shader: wesl::ModulePath = Preprocessor::<G>::MAIN_SHADER
357            .parse()
358            .expect("preprocess module path");
359
360        let wesl_compile_options = wesl::CompileOptions {
361            features: wesl::Features {
362                flags: G::features()
363                    .into_iter()
364                    .chain(std::iter::once((
365                        "selection_buffer",
366                        cfg!(feature = "viewer-selection"),
367                    )))
368                    .map(|(k, v)| (k.to_string(), v.into()))
369                    .collect(),
370                ..Default::default()
371            },
372            ..Default::default()
373        };
374
375        let bind_group_layout =
376            device.create_bind_group_layout(&Preprocessor::<G>::BIND_GROUP_LAYOUT_DESCRIPTOR);
377
378        let pre_bundle = ComputeBundleBuilder::new()
379            .label(format!("Pre {}", Preprocessor::<G>::LABEL).as_str())
380            .bind_group_layout(&Preprocessor::<G>::BIND_GROUP_LAYOUT_DESCRIPTOR)
381            .entry_point("pre")
382            .main_shader(main_shader.clone())
383            .wesl_compile_options(wesl_compile_options.clone())
384            .resolver(wesl_utils::resolver())
385            .build_without_bind_groups(device)?;
386
387        let bundle = ComputeBundleBuilder::new()
388            .label(Preprocessor::<G>::LABEL)
389            .bind_group_layout(&Preprocessor::<G>::BIND_GROUP_LAYOUT_DESCRIPTOR)
390            .entry_point("main")
391            .main_shader(main_shader.clone())
392            .wesl_compile_options(wesl_compile_options.clone())
393            .resolver(wesl_utils::resolver())
394            .build_without_bind_groups(device)?;
395
396        let post_bundle = ComputeBundleBuilder::new()
397            .label(format!("Post {}", Preprocessor::<G>::LABEL).as_str())
398            .bind_group_layout(&Preprocessor::<G>::BIND_GROUP_LAYOUT_DESCRIPTOR)
399            .entry_point("post")
400            .main_shader(main_shader)
401            .wesl_compile_options(wesl_compile_options)
402            .resolver(wesl_utils::resolver())
403            .build_without_bind_groups(device)?;
404
405        log::info!("Preprocessor created");
406
407        Ok(Self {
408            bind_group_layout,
409            bind_group: (),
410            pre_bundle,
411            bundle,
412            post_bundle,
413            gaussian_pod_marker: std::marker::PhantomData,
414        })
415    }
416
417    /// Preprocess the Gaussians.
418    pub fn preprocess(
419        &self,
420        encoder: &mut wgpu::CommandEncoder,
421        bind_group: &wgpu::BindGroup,
422        gaussian_count: u32,
423    ) {
424        self.pre_bundle.dispatch(encoder, 1, [bind_group]);
425
426        self.bundle.dispatch(encoder, gaussian_count, [bind_group]);
427
428        self.post_bundle.dispatch(encoder, 1, [bind_group]);
429    }
430}