Skip to main content

wgpu_3dgs_viewer/
multi_model.rs

1use std::{collections::HashMap, hash::Hash};
2
3use crate::*;
4
5/// The buffers for [`Viewer`] related to the world.
6#[derive(Debug)]
7pub struct MultiModelViewerWorldBuffers {
8    pub camera_buffer: CameraBuffer,
9    pub gaussian_transform_buffer: GaussianTransformBuffer,
10}
11
12impl MultiModelViewerWorldBuffers {
13    /// Create a new viewer world buffers.
14    pub fn new(device: &wgpu::Device) -> Self {
15        log::debug!("Creating camera buffer");
16        let camera_buffer = CameraBuffer::new(device);
17
18        log::debug!("Creating gaussian transform buffer");
19        let gaussian_transform_buffer = GaussianTransformBuffer::new(device);
20
21        Self {
22            camera_buffer,
23            gaussian_transform_buffer,
24        }
25    }
26
27    /// Update the camera.
28    pub fn update_camera(
29        &mut self,
30        queue: &wgpu::Queue,
31        camera: &impl CameraTrait,
32        texture_size: UVec2,
33    ) {
34        self.camera_buffer.update(queue, camera, texture_size);
35    }
36
37    /// Update the camera with [`CameraPod`].
38    pub fn update_camera_with_pod(&mut self, queue: &wgpu::Queue, pod: &CameraPod) {
39        self.camera_buffer.update_with_pod(queue, pod);
40    }
41
42    /// Update the Gaussian transform.
43    pub fn update_gaussian_transform(
44        &mut self,
45        queue: &wgpu::Queue,
46        size: f32,
47        display_mode: GaussianDisplayMode,
48        sh_deg: GaussianShDegree,
49        no_sh0: bool,
50        max_std_dev: GaussianMaxStdDev,
51    ) {
52        self.gaussian_transform_buffer.update(
53            queue,
54            size,
55            display_mode,
56            sh_deg,
57            no_sh0,
58            max_std_dev,
59        );
60    }
61
62    /// Update the Gaussian transform with [`GaussianTransformPod`].
63    pub fn update_gaussian_transform_with_pod(
64        &mut self,
65        queue: &wgpu::Queue,
66        pod: &GaussianTransformPod,
67    ) {
68        self.gaussian_transform_buffer.update_with_pod(queue, pod);
69    }
70}
71
72/// The buffers for [`Viewer`] related to the Guassian model.
73#[derive(Debug)]
74pub struct MultiModelViewerGaussianBuffers<G: GaussianPod = DefaultGaussianPod> {
75    pub model_transform_buffer: ModelTransformBuffer,
76    pub gaussians_buffer: GaussiansBuffer<G>,
77    pub indirect_args_buffer: IndirectArgsBuffer,
78    pub radix_sort_indirect_args_buffer: RadixSortIndirectArgsBuffer,
79    pub indirect_indices_buffer: IndirectIndicesBuffer,
80    pub gaussians_depth_buffer: GaussiansDepthBuffer,
81    #[cfg(feature = "viewer-selection")]
82    pub selection_buffer: SelectionBuffer,
83    #[cfg(feature = "viewer-selection")]
84    pub invert_selection_buffer: selection::PreprocessorInvertSelectionBuffer,
85}
86
87impl<G: GaussianPod> MultiModelViewerGaussianBuffers<G> {
88    /// Create a new viewer Gaussian buffers.
89    pub fn new(device: &wgpu::Device, gaussians: &impl IterGaussian) -> Self {
90        Self::new_with(device, GaussiansBuffer::<G>::DEFAULT_USAGES, gaussians)
91    }
92
93    /// Create a new viewer Gaussian buffers with custom gaussians buffer usage.
94    pub fn new_with(
95        device: &wgpu::Device,
96        gaussians_buffer_usage: wgpu::BufferUsages,
97        gaussians: &impl IterGaussian,
98    ) -> Self {
99        log::debug!("Creating model transform buffer");
100        let model_transform_buffer = ModelTransformBuffer::new(device);
101
102        log::debug!("Creating gaussians buffer");
103        let gaussians_buffer =
104            GaussiansBuffer::new_with_usage(device, gaussians, gaussians_buffer_usage);
105
106        log::debug!("Creating indirect args buffer");
107        let indirect_args_buffer = IndirectArgsBuffer::new(device);
108
109        log::debug!("Creating radix sort indirect args buffer");
110        let radix_sort_indirect_args_buffer = RadixSortIndirectArgsBuffer::new(device);
111
112        // Assume it is cheap to call `iter_gaussian`.
113        let len = gaussians.iter_gaussian().len() as u32;
114
115        log::debug!("Creating indirect indices buffer");
116        let indirect_indices_buffer = IndirectIndicesBuffer::new(device, len);
117
118        log::debug!("Creating gaussians depth buffer");
119        let gaussians_depth_buffer = GaussiansDepthBuffer::new(device, len);
120
121        #[cfg(feature = "viewer-selection")]
122        let selection_buffer = {
123            log::debug!("Creating selection buffer");
124            SelectionBuffer::new(device, len)
125        };
126
127        #[cfg(feature = "viewer-selection")]
128        let invert_selection_buffer = {
129            log::debug!("Creating invert selection buffer");
130            selection::PreprocessorInvertSelectionBuffer::new(device)
131        };
132
133        Self {
134            model_transform_buffer,
135            gaussians_buffer,
136            indirect_args_buffer,
137            radix_sort_indirect_args_buffer,
138            indirect_indices_buffer,
139            gaussians_depth_buffer,
140            #[cfg(feature = "viewer-selection")]
141            selection_buffer,
142            #[cfg(feature = "viewer-selection")]
143            invert_selection_buffer,
144        }
145    }
146
147    /// Create a new viewer Gaussian buffers with only the count.
148    pub fn new_empty(device: &wgpu::Device, count: usize) -> Self {
149        Self::new_empty_with(device, count, GaussiansBuffer::<G>::DEFAULT_USAGES)
150    }
151
152    /// Create a new viewer Gaussian buffers with only the count and custom gaussians buffer usage.
153    pub fn new_empty_with(
154        device: &wgpu::Device,
155        count: usize,
156        gaussians_buffer_usage: wgpu::BufferUsages,
157    ) -> Self {
158        log::debug!("Creating model transform buffer");
159        let model_transform_buffer = ModelTransformBuffer::new(device);
160
161        log::debug!("Creating gaussians buffer");
162        let gaussians_buffer =
163            GaussiansBuffer::new_empty_with_usage(device, count, gaussians_buffer_usage);
164
165        log::debug!("Creating indirect args buffer");
166        let indirect_args_buffer = IndirectArgsBuffer::new(device);
167
168        log::debug!("Creating radix sort indirect args buffer");
169        let radix_sort_indirect_args_buffer = RadixSortIndirectArgsBuffer::new(device);
170
171        log::debug!("Creating indirect indices buffer");
172        let indirect_indices_buffer = IndirectIndicesBuffer::new(device, count as u32);
173
174        log::debug!("Creating gaussians depth buffer");
175        let gaussians_depth_buffer = GaussiansDepthBuffer::new(device, count as u32);
176
177        #[cfg(feature = "viewer-selection")]
178        let selection_buffer = {
179            log::debug!("Creating selection buffer");
180            SelectionBuffer::new(device, count as u32)
181        };
182
183        #[cfg(feature = "viewer-selection")]
184        let invert_selection_buffer = {
185            log::debug!("Creating invert selection buffer");
186            selection::PreprocessorInvertSelectionBuffer::new(device)
187        };
188
189        Self {
190            model_transform_buffer,
191            gaussians_buffer,
192            indirect_args_buffer,
193            radix_sort_indirect_args_buffer,
194            indirect_indices_buffer,
195            gaussians_depth_buffer,
196            #[cfg(feature = "viewer-selection")]
197            selection_buffer,
198            #[cfg(feature = "viewer-selection")]
199            invert_selection_buffer,
200        }
201    }
202
203    /// Update the model transform.
204    pub fn update_model_transform(
205        &mut self,
206        queue: &wgpu::Queue,
207        pos: Vec3,
208        rot: Quat,
209        scale: Vec3,
210    ) {
211        self.model_transform_buffer.update(queue, pos, rot, scale);
212    }
213
214    /// Update the model transform with [`ModelTransformPod`].
215    pub fn update_model_transform_with_pod(
216        &mut self,
217        queue: &wgpu::Queue,
218        pod: &ModelTransformPod,
219    ) {
220        self.model_transform_buffer.update_with_pod(queue, pod);
221    }
222}
223
224/// The bind groups for [`MultiModelViewer`].
225#[derive(Debug)]
226pub struct MultiModelViewerBindGroups {
227    pub preprocessor: wgpu::BindGroup,
228    pub radix_sorter: RadixSorterBindGroups,
229    pub renderer: wgpu::BindGroup,
230}
231
232impl MultiModelViewerBindGroups {
233    /// Create a new viewer bind groups.
234    pub fn new<G: GaussianPod>(
235        device: &wgpu::Device,
236        preprocessor: &Preprocessor<G, ()>,
237        radix_sorter: &RadixSorter<()>,
238        renderer: &Renderer<G, ()>,
239        gaussian_buffers: &MultiModelViewerGaussianBuffers<G>,
240        world_buffers: &MultiModelViewerWorldBuffers,
241    ) -> Self {
242        let preprocessor = preprocessor.create_bind_group(
243            device,
244            &world_buffers.camera_buffer,
245            &gaussian_buffers.model_transform_buffer,
246            &world_buffers.gaussian_transform_buffer,
247            &gaussian_buffers.gaussians_buffer,
248            &gaussian_buffers.indirect_args_buffer,
249            &gaussian_buffers.radix_sort_indirect_args_buffer,
250            &gaussian_buffers.indirect_indices_buffer,
251            &gaussian_buffers.gaussians_depth_buffer,
252            #[cfg(feature = "viewer-selection")]
253            &gaussian_buffers.selection_buffer,
254            #[cfg(feature = "viewer-selection")]
255            &gaussian_buffers.invert_selection_buffer,
256        );
257        let radix_sorter = radix_sorter.create_bind_groups(
258            device,
259            &gaussian_buffers.gaussians_depth_buffer,
260            &gaussian_buffers.indirect_indices_buffer,
261        );
262        let renderer = renderer.create_bind_group(
263            device,
264            &world_buffers.camera_buffer,
265            &gaussian_buffers.model_transform_buffer,
266            &world_buffers.gaussian_transform_buffer,
267            &gaussian_buffers.gaussians_buffer,
268            &gaussian_buffers.indirect_indices_buffer,
269        );
270
271        Self {
272            preprocessor,
273            radix_sorter,
274            renderer,
275        }
276    }
277}
278
279/// The model of the [`MultiModelViewer`].
280#[derive(Debug)]
281pub struct MultiModelViewerModel<G: GaussianPod = DefaultGaussianPod> {
282    /// Buffers for the model.
283    pub gaussian_buffers: MultiModelViewerGaussianBuffers<G>,
284
285    /// Bind groups for the model.
286    pub bind_groups: MultiModelViewerBindGroups,
287}
288
289/// The 3D Gaussian splatting viewer for multiple models.
290#[derive(Debug)]
291pub struct MultiModelViewer<G: GaussianPod = DefaultGaussianPod, K: Hash + std::cmp::Eq = String> {
292    pub models: HashMap<K, MultiModelViewerModel<G>>,
293    pub world_buffers: MultiModelViewerWorldBuffers,
294    pub preprocessor: Preprocessor<G, ()>,
295    pub radix_sorter: RadixSorter<()>,
296    pub renderer: Renderer<G, ()>,
297
298    /// The usage for the gaussians buffer when [`MultiModelViewer::insert_model`] is called.
299    ///
300    /// Can be overridden when inserting model using [`MultiModelViewer::insert_model_with`].
301    // If there are more than one of these default, maybe create something like InsertModelOptions
302    pub gaussians_buffer_usage: wgpu::BufferUsages,
303}
304
305impl<G: GaussianPod, K: Hash + std::cmp::Eq> MultiModelViewer<G, K> {
306    /// Create a new viewer.
307    pub fn new(
308        device: &wgpu::Device,
309        texture_format: wgpu::TextureFormat,
310    ) -> Result<Self, ViewerCreateError> {
311        Self::new_with_options(device, texture_format, ViewerCreateOptions::default())
312    }
313
314    /// Create a new viewer with extra [`ViewerCreateOptions`].
315    ///
316    /// Note that only [`ViewerCreateOptions::gaussians_buffer_usage`] is used when inserting models
317    /// with [`MultiModelViewer::insert_model`]. You can also override the usage using
318    /// [`MultiModelViewer::insert_model_with`].
319    pub fn new_with_options(
320        device: &wgpu::Device,
321        texture_format: wgpu::TextureFormat,
322        options: ViewerCreateOptions,
323    ) -> Result<Self, ViewerCreateError> {
324        let models = HashMap::new();
325
326        log::debug!("Creating world buffers");
327        let world_buffers = MultiModelViewerWorldBuffers::new(device);
328
329        log::debug!("Creating preprocessor");
330        let preprocessor = Preprocessor::new_without_bind_group(device)?;
331
332        log::debug!("Creating radix sorter");
333        let radix_sorter = RadixSorter::new_without_bind_groups(device);
334
335        log::debug!("Creating renderer");
336        let renderer =
337            Renderer::new_without_bind_group(device, texture_format, options.depth_stencil)?;
338
339        log::info!("Viewer created");
340
341        Ok(Self {
342            models,
343            world_buffers,
344            preprocessor,
345            radix_sorter,
346            renderer,
347
348            gaussians_buffer_usage: options.gaussians_buffer_usage,
349        })
350    }
351
352    /// Insert a new model to the viewer.
353    pub fn insert_model(
354        &mut self,
355        device: &wgpu::Device,
356        key: K,
357        gaussians: &impl IterGaussian,
358    ) -> Option<MultiModelViewerModel<G>> {
359        self.insert_model_with(device, key, self.gaussians_buffer_usage, gaussians)
360    }
361
362    /// Insert a new model to the viewer with custom gaussians buffer usage.
363    ///
364    /// This ignores [`MultiModelViewer::gaussians_buffer_usage`], and instead uses the provided
365    /// usage in the argument.
366    pub fn insert_model_with(
367        &mut self,
368        device: &wgpu::Device,
369        key: K,
370        gaussians_buffer_usage: wgpu::BufferUsages,
371        gaussians: &impl IterGaussian,
372    ) -> Option<MultiModelViewerModel<G>> {
373        let gaussian_buffers =
374            MultiModelViewerGaussianBuffers::new_with(device, gaussians_buffer_usage, gaussians);
375        let bind_groups = MultiModelViewerBindGroups::new(
376            device,
377            &self.preprocessor,
378            &self.radix_sorter,
379            &self.renderer,
380            &gaussian_buffers,
381            &self.world_buffers,
382        );
383        self.models.insert(
384            key,
385            MultiModelViewerModel {
386                gaussian_buffers,
387                bind_groups,
388            },
389        )
390    }
391
392    /// Remove a model from the viewer.
393    pub fn remove_model(&mut self, key: &K) -> Option<MultiModelViewerModel<G>> {
394        self.models.remove(key)
395    }
396
397    /// Update the camera.
398    pub fn update_camera(
399        &mut self,
400        queue: &wgpu::Queue,
401        camera: &impl CameraTrait,
402        texture_size: UVec2,
403    ) {
404        self.world_buffers
405            .update_camera(queue, camera, texture_size);
406    }
407
408    /// Update the camera with [`CameraPod`].
409    pub fn update_camera_with_pod(&mut self, queue: &wgpu::Queue, pod: &CameraPod) {
410        self.world_buffers.update_camera_with_pod(queue, pod);
411    }
412
413    /// Update the model transform.
414    pub fn update_model_transform(
415        &mut self,
416        queue: &wgpu::Queue,
417        key: &K,
418        pos: Vec3,
419        rot: Quat,
420        scale: Vec3,
421    ) -> Result<(), MultiModelViewerAccessError> {
422        self.models
423            .get_mut(key)
424            .ok_or(MultiModelViewerAccessError::ModelNotFound)?
425            .gaussian_buffers
426            .update_model_transform(queue, pos, rot, scale);
427        Ok(())
428    }
429
430    /// Update the model transform with [`ModelTransformPod`].
431    pub fn update_model_transform_with_pod(
432        &mut self,
433        queue: &wgpu::Queue,
434        key: &K,
435        pod: &ModelTransformPod,
436    ) -> Result<(), MultiModelViewerAccessError> {
437        self.models
438            .get_mut(key)
439            .ok_or(MultiModelViewerAccessError::ModelNotFound)?
440            .gaussian_buffers
441            .update_model_transform_with_pod(queue, pod);
442        Ok(())
443    }
444
445    /// Update the Gaussian transform.
446    pub fn update_gaussian_transform(
447        &mut self,
448        queue: &wgpu::Queue,
449        size: f32,
450        display_mode: GaussianDisplayMode,
451        sh_deg: GaussianShDegree,
452        no_sh0: bool,
453        max_std_dev: GaussianMaxStdDev,
454    ) {
455        self.world_buffers.update_gaussian_transform(
456            queue,
457            size,
458            display_mode,
459            sh_deg,
460            no_sh0,
461            max_std_dev,
462        );
463    }
464
465    /// Update the Gaussian transform with [`GaussianTransformPod`].
466    pub fn update_gaussian_transform_with_pod(
467        &mut self,
468        queue: &wgpu::Queue,
469        pod: &GaussianTransformPod,
470    ) {
471        self.world_buffers
472            .update_gaussian_transform_with_pod(queue, pod);
473    }
474
475    /// Render the viewer.
476    pub fn render(
477        &self,
478        encoder: &mut wgpu::CommandEncoder,
479        texture_view: &wgpu::TextureView,
480        keys: &[&K],
481    ) -> Result<(), MultiModelViewerAccessError> {
482        let models = keys
483            .iter()
484            .map(|key| {
485                self.models
486                    .get(key)
487                    .ok_or(MultiModelViewerAccessError::ModelNotFound)
488            })
489            .collect::<Result<Vec<_>, _>>()?;
490
491        for model in models.iter() {
492            self.preprocessor.preprocess(
493                encoder,
494                &model.bind_groups.preprocessor,
495                model.gaussian_buffers.gaussians_buffer.len() as u32,
496            );
497
498            self.radix_sorter.sort(
499                encoder,
500                &model.bind_groups.radix_sorter,
501                &model.gaussian_buffers.radix_sort_indirect_args_buffer,
502            );
503        }
504
505        {
506            let mut render_pass = encoder.begin_render_pass(&wgpu::RenderPassDescriptor {
507                label: Some("Multi Model Viewer Render Pass"),
508                color_attachments: &[Some(wgpu::RenderPassColorAttachment {
509                    view: texture_view,
510                    resolve_target: None,
511                    ops: wgpu::Operations {
512                        load: wgpu::LoadOp::Clear(wgpu::Color::BLACK),
513                        store: wgpu::StoreOp::Store,
514                    },
515                    depth_slice: None,
516                })],
517                ..Default::default()
518            });
519
520            for model in models.iter() {
521                self.renderer.render_with_pass(
522                    &mut render_pass,
523                    &model.bind_groups.renderer,
524                    &model.gaussian_buffers.indirect_args_buffer,
525                );
526            }
527        }
528
529        Ok(())
530    }
531}