wgpu_3dgs_editor/
modifier.rs

1use crate::{
2    BasicColorModifiersBuffer, RotScaleBuffer, SelectionBuffer, TransformFlagsBuffer,
3    core::{
4        self, BufferWrapper, ComputeBundle, ComputeBundleBuilder, GaussianPod,
5        GaussianTransformBuffer, GaussiansBuffer, ModelTransformBuffer,
6    },
7    shader,
8};
9
10/// A trait to apply modifier to Gaussians.
11///
12/// ## Overview
13///
14/// This trait simply defines a method to apply modifications to a set of Gaussians stored in a
15/// [`GaussiansBuffer`]. It makes it convenient for users to apply a sequence of modifications.
16///
17/// The trait is also blanket implemented for closures with the same signature, allowing users to
18/// easily create modifier closures instead of having to define a modifier struct.
19///
20/// [`Editor`](crate::Editor) also provides an `apply` method which takes a slice of
21/// [`Modifier`]s to apply them in sequence to the stored Gaussians.
22///
23/// ## Usage
24///
25/// There are many ways to use this but the recommended way is to implement this trait for a closure
26/// which dispatch a [`ComputeBundle`].
27///
28/// ```rust
29/// // Create the modifier compute bundle
30/// let my_modifier_bundle = ComputeBundleBuilder::new()
31///     .label("My Modifier")
32///     .bind_group_layouts([
33///         &MODIFIER_GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR, // For accessing Gaussians and transforms
34///         &MY_CUSTOM_BIND_GROUP_LAYOUT_DESCRIPTOR, // Put your custom bind group layout here.
35///     ])
36///     .resolver({
37///         let mut resolver = wesl::StandardResolver::new("path/to/my/folder/containing/wesl");
38///         resolver.add_package(&core::shader::PACKAGE); // Required for using core buffer structs.
39///         resolver.add_package(&shader::PACKAGE); // Optionally add this for some utility functions.
40///         resolver
41///     })
42///     .main_shader("package::my_wesl_filename".parse().unwrap())
43///     .entry_point("main")
44///     .wesl_compile_options(wesl::CompileOptions {
45///         features: G::wesl_features(), // Required for enabling the correct features for core struct.
46///         ..Default::default()
47///     })
48///     .build(
49///         &device,
50///         [
51///             vec![
52///                 gaussians_buffer.buffer().as_entire_binding(),
53///                 model_transform_buffer.buffer().as_entire_binding(),
54///                 gaussian_transform_buffer.buffer().as_entire_binding(),
55///             ],
56///             vec![ /* Your custom bind group resources */ ],
57///         ],
58///     )
59///     .map_err(|e| log::error!("{e}"))
60///     .expect("my modifier bundle");
61///
62/// // Create the modifier closure
63/// let my_modifier = |device: &wgpu::Device,
64///                    encoder: &mut wgpu::CommandEncoder,
65///                    gaussians: &GaussiansBuffer<G>,
66///                    model_transform: &ModelTransformBuffer,
67///                    gaussian_transform: &GaussianTransformBuffer| {
68///     my_modifier_bundle.dispatch(encoder, gaussians.len() as u32);
69/// };
70///
71/// // Apply the modifier using an editor as an example
72/// let editor = Editor::new(&device, &gaussians);
73/// editor.apply(&device, &mut encoder, [&my_modifier as &dyn gs::Modifier<GaussianPod>]);
74/// ```
75///
76/// ## Shader Format
77///
78/// You may copy and paste the following shader bindings for
79/// [`MODIFIER_GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR`] into your custom selection operation
80/// shader to ensure that the bindings are correct, then add your own bindings after that.
81///
82/// ```wgsl
83/// import wgpu_3dgs_core::{
84///     gaussian::Gaussian,
85///     gaussian_transform::GaussianTransform,
86///     model_transform::ModelTransform,
87/// };
88///
89/// @group(0) @binding(0)
90/// var<storage, read_write> gaussians: array<Gaussian>;
91///
92/// @group(0) @binding(1)
93/// var<uniform> model_transform: ModelTransform;
94///
95/// @group(0) @binding(2)
96/// var<uniform> gaussian_transform: GaussianTransform;
97///
98/// // Your custom bindings here...
99/// //
100/// // You may also apply modifier to selected gaussians only by adding:
101/// // @group(1) @binding(N)
102/// // var<storage, read> selection: array<u32>;
103///
104/// override workgroup_size: u32;
105///
106/// @compute @workgroup_size(workgroup_size)
107/// fn main(@builtin(global_invocation_id) id: vec3<u32>) {
108///     let index = id.x;
109///
110///     if index >= arrayLength(&gaussians) {
111///         return;
112///     }
113///     
114///     @if(/* using selection buffer */) {
115///         let word_index = index / 32u;
116///         let bit_index = index % 32u;
117///         let bit_mask = 1u << bit_index;
118///         if (selection[word_index] & bit_mask) == 0 {
119///             return;
120///         }
121///     }
122///
123///     var gaussian = gaussians[index];
124///
125///     // Your custom modifier operation code here...
126///
127///     gaussians[index] = gaussian;
128/// }
129pub trait Modifier<G: GaussianPod> {
130    /// Apply the modifier to the Gaussians.
131    fn apply(
132        &self,
133        device: &wgpu::Device,
134        encoder: &mut wgpu::CommandEncoder,
135        gaussians: &GaussiansBuffer<G>,
136        model_transform: &ModelTransformBuffer,
137        gaussian_transform: &GaussianTransformBuffer,
138    );
139}
140
141impl<
142    G: GaussianPod,
143    F: Fn(
144        &wgpu::Device,
145        &mut wgpu::CommandEncoder,
146        &GaussiansBuffer<G>,
147        &ModelTransformBuffer,
148        &GaussianTransformBuffer,
149    ),
150> Modifier<G> for F
151{
152    fn apply(
153        &self,
154        device: &wgpu::Device,
155        encoder: &mut wgpu::CommandEncoder,
156        gaussians: &GaussiansBuffer<G>,
157        model_transform: &ModelTransformBuffer,
158        gaussian_transform: &GaussianTransformBuffer,
159    ) {
160        self(
161            device,
162            encoder,
163            gaussians,
164            model_transform,
165            gaussian_transform,
166        );
167    }
168}
169
170/// The bind group layout descriptor for the Gaussians buffer, with the
171/// model transform and Gaussian transform.
172///
173/// This bind group layout takes the following buffers:
174/// - [`GaussiansBuffer`]
175/// - [`ModelTransformBuffer`]
176/// - [`GaussianTransformBuffer`]
177///
178/// This bind group is usually at group 0 for [`Modifier`]s.
179pub const MODIFIER_GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR: wgpu::BindGroupLayoutDescriptor =
180    wgpu::BindGroupLayoutDescriptor {
181        label: Some("Modifier Gaussians Bind Group Layout"),
182        entries: &[
183            // Gaussians storage buffer
184            wgpu::BindGroupLayoutEntry {
185                binding: 0,
186                visibility: wgpu::ShaderStages::COMPUTE,
187                ty: wgpu::BindingType::Buffer {
188                    ty: wgpu::BufferBindingType::Storage { read_only: false },
189                    has_dynamic_offset: false,
190                    min_binding_size: None,
191                },
192                count: None,
193            },
194            // Model transform uniform buffer
195            wgpu::BindGroupLayoutEntry {
196                binding: 1,
197                visibility: wgpu::ShaderStages::COMPUTE,
198                ty: wgpu::BindingType::Buffer {
199                    ty: wgpu::BufferBindingType::Uniform,
200                    has_dynamic_offset: false,
201                    min_binding_size: None,
202                },
203                count: None,
204            },
205            // Gaussian transform uniform buffer
206            wgpu::BindGroupLayoutEntry {
207                binding: 2,
208                visibility: wgpu::ShaderStages::COMPUTE,
209                ty: wgpu::BindingType::Buffer {
210                    ty: wgpu::BufferBindingType::Uniform,
211                    has_dynamic_offset: false,
212                    min_binding_size: None,
213                },
214                count: None,
215            },
216        ],
217    };
218
/// A marker type indicating that a modifier takes a selection buffer.
///
/// Used as the `S` type parameter of [`BasicModifierBundle`] and [`BasicModifier`]. It is a
/// zero-sized type, so it is freely copyable and comparable.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct WithSelection;
222
/// A marker type indicating that a modifier does not take a selection buffer.
///
/// This is the default `S` type parameter of [`BasicModifierBundle`] and [`BasicModifier`]. It
/// is a zero-sized type, so it is freely copyable and comparable.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct NoSelection;
226
227/// A specialized [`ComputeBundle`] for some built-in basic modifier.
228///
229/// This bundle includes the modifiers for [`BasicColorModifiersBuffer`],
230/// [`RotScaleBuffer`], and [`TransformFlagsBuffer`] (which provides flags for applying
231/// [`core::ModelTransformBuffer`] and [`core::GaussianTransformBuffer`]).
232#[derive(Debug)]
233pub struct BasicModifierBundle<G: GaussianPod, S = NoSelection, B = wgpu::BindGroup> {
234    bundle: ComputeBundle<B>,
235    gaussian_pod_marker: std::marker::PhantomData<G>,
236    selection_marker: std::marker::PhantomData<S>,
237}
238
239impl<G: GaussianPod, S, B> BasicModifierBundle<G, S, B> {
240    /// Gets the inner [`ComputeBundle`].
241    pub fn bundle(&self) -> &ComputeBundle<B> {
242        &self.bundle
243    }
244}
245
246impl<G: GaussianPod> BasicModifierBundle<G> {
247    /// The bind group layout descriptor for the [`BasicModifierBundle`].
248    ///
249    /// This bind group layout takes the following buffers:
250    /// - [`TransformFlagsBuffer`]
251    /// - [`BasicColorModifiersBuffer`]
252    /// - [`RotScaleBuffer`]
253    ///
254    /// This is at group 1, because group 0 is the [`MODIFIER_GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR`].
255    pub const BIND_GROUP_LAYOUT_DESCRIPTOR: wgpu::BindGroupLayoutDescriptor<'static> =
256        wgpu::BindGroupLayoutDescriptor {
257            label: Some("Basic Modifier Bind Group Layout"),
258            entries: &BasicModifierBundle::<G, WithSelection>::BIND_GROUP_LAYOUT_DESCRIPTOR
259                .entries
260                .split_at(3)
261                .0,
262        };
263
264    /// Creates a new [`BasicModifierBundle`] bundle.
265    pub fn new(
266        device: &wgpu::Device,
267        gaussians_buffer: &GaussiansBuffer<G>,
268        model_transform_buffer: &ModelTransformBuffer,
269        gaussian_transform_buffer: &GaussianTransformBuffer,
270        transform_flags_buffer: &TransformFlagsBuffer,
271        basic_color_modifiers_buffer: &BasicColorModifiersBuffer,
272        rot_scale_buffer: &RotScaleBuffer,
273    ) -> Self {
274        Self::create_bundle_builder(false)
275            .build(
276                &device,
277                [
278                    [
279                        gaussians_buffer.buffer().as_entire_binding(),
280                        model_transform_buffer.buffer().as_entire_binding(),
281                        gaussian_transform_buffer.buffer().as_entire_binding(),
282                    ],
283                    [
284                        transform_flags_buffer.buffer().as_entire_binding(),
285                        basic_color_modifiers_buffer.buffer().as_entire_binding(),
286                        rot_scale_buffer.buffer().as_entire_binding(),
287                    ],
288                ],
289            )
290            .map(|bundle| Self {
291                bundle,
292                gaussian_pod_marker: std::marker::PhantomData,
293                selection_marker: std::marker::PhantomData,
294            })
295            .map_err(|e| log::error!("{e}"))
296            .expect("basic modifier bundle")
297    }
298
299    /// Creates a new [`ComputeBundleBuilder`] for the basic modifier.
300    fn create_bundle_builder<'a>(
301        has_selection: bool,
302    ) -> ComputeBundleBuilder<'a, wesl::PkgResolver> {
303        ComputeBundleBuilder::new()
304            .label("Basic Modifier")
305            .bind_group_layouts([
306                &MODIFIER_GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR,
307                match has_selection {
308                    true => &BasicModifierBundle::<G, WithSelection>::BIND_GROUP_LAYOUT_DESCRIPTOR,
309                    false => &BasicModifierBundle::<G>::BIND_GROUP_LAYOUT_DESCRIPTOR,
310                },
311            ])
312            .resolver({
313                let mut resolver = wesl::PkgResolver::new();
314                resolver.add_package(&core::shader::PACKAGE);
315                resolver.add_package(&shader::PACKAGE);
316                resolver
317            })
318            .main_shader(
319                "wgpu_3dgs_editor::modifier::basic"
320                    .parse()
321                    .expect("modifier::basic module path"),
322            )
323            .entry_point("main")
324            .wesl_compile_options(wesl::CompileOptions {
325                features: wesl::Features {
326                    flags: G::features()
327                        .into_iter()
328                        .chain(std::iter::once(("selection_buffer", has_selection)))
329                        .map(|(k, v)| (k.to_string(), v.into()))
330                        .collect(),
331                    ..Default::default()
332                },
333                ..Default::default()
334            })
335    }
336}
337
338impl<G: GaussianPod> BasicModifierBundle<G, WithSelection> {
339    /// The bind group layout descriptor for the [`BasicModifierBundle`] with a [`SelectionBuffer`].
340    ///
341    /// Thie bind group layout takes the following buffers:
342    /// - [`TransformFlagsBuffer`]
343    /// - [`BasicColorModifiersBuffer`]
344    /// - [`RotScaleBuffer`]
345    /// - [`SelectionBuffer`]
346    ///
347    /// This is at group 1, because group 0 is the [`MODIFIER_GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR`].
348    pub const BIND_GROUP_LAYOUT_DESCRIPTOR: wgpu::BindGroupLayoutDescriptor<'static> =
349        wgpu::BindGroupLayoutDescriptor {
350            label: Some("Basic Modifier Bind Group Layout"),
351            entries: &[
352                // Transform flags uniform buffer
353                wgpu::BindGroupLayoutEntry {
354                    binding: 0,
355                    visibility: wgpu::ShaderStages::COMPUTE,
356                    ty: wgpu::BindingType::Buffer {
357                        ty: wgpu::BufferBindingType::Uniform,
358                        has_dynamic_offset: false,
359                        min_binding_size: None,
360                    },
361                    count: None,
362                },
363                // Basic color modifiers uniform buffer
364                wgpu::BindGroupLayoutEntry {
365                    binding: 1,
366                    visibility: wgpu::ShaderStages::COMPUTE,
367                    ty: wgpu::BindingType::Buffer {
368                        ty: wgpu::BufferBindingType::Uniform,
369                        has_dynamic_offset: false,
370                        min_binding_size: None,
371                    },
372                    count: None,
373                },
374                // Scale rotation uniform buffer
375                wgpu::BindGroupLayoutEntry {
376                    binding: 2,
377                    visibility: wgpu::ShaderStages::COMPUTE,
378                    ty: wgpu::BindingType::Buffer {
379                        ty: wgpu::BufferBindingType::Uniform,
380                        has_dynamic_offset: false,
381                        min_binding_size: None,
382                    },
383                    count: None,
384                },
385                // Selection buffer
386                wgpu::BindGroupLayoutEntry {
387                    binding: 3,
388                    visibility: wgpu::ShaderStages::COMPUTE,
389                    ty: wgpu::BindingType::Buffer {
390                        ty: wgpu::BufferBindingType::Storage { read_only: true },
391                        has_dynamic_offset: false,
392                        min_binding_size: None,
393                    },
394                    count: None,
395                },
396            ],
397        };
398
399    /// Creates a new [`BasicModifierBundle`] bundle with [`SelectionBuffer`].
400    pub fn new_with_selection(
401        device: &wgpu::Device,
402        gaussians_buffer: &GaussiansBuffer<G>,
403        model_transform_buffer: &ModelTransformBuffer,
404        gaussian_transform_buffer: &GaussianTransformBuffer,
405        transform_flags_buffer: &TransformFlagsBuffer,
406        basic_color_modifiers_buffer: &BasicColorModifiersBuffer,
407        rot_scale_buffer: &RotScaleBuffer,
408        selection_buffer: &SelectionBuffer,
409    ) -> Self {
410        BasicModifierBundle::<G>::create_bundle_builder(true)
411            .build(
412                &device,
413                [
414                    vec![
415                        gaussians_buffer.buffer().as_entire_binding(),
416                        model_transform_buffer.buffer().as_entire_binding(),
417                        gaussian_transform_buffer.buffer().as_entire_binding(),
418                    ],
419                    vec![
420                        transform_flags_buffer.buffer().as_entire_binding(),
421                        basic_color_modifiers_buffer.buffer().as_entire_binding(),
422                        rot_scale_buffer.buffer().as_entire_binding(),
423                        selection_buffer.buffer().as_entire_binding(),
424                    ],
425                ],
426            )
427            .map(|bundle| Self {
428                bundle,
429                gaussian_pod_marker: std::marker::PhantomData,
430                selection_marker: std::marker::PhantomData,
431            })
432            .map_err(|e| log::error!("{e}"))
433            .expect("basic modifier bundle")
434    }
435}
436
437impl<G: GaussianPod, S> BasicModifierBundle<G, S> {
438    /// Apply the basic modifier to the Gaussians.
439    pub fn apply_with_count(&self, encoder: &mut wgpu::CommandEncoder, gaussian_count: u32) {
440        self.bundle().dispatch(encoder, gaussian_count);
441    }
442}
443
444impl<G: GaussianPod, S> Modifier<G> for BasicModifierBundle<G, S> {
445    fn apply(
446        &self,
447        _device: &wgpu::Device,
448        encoder: &mut wgpu::CommandEncoder,
449        gaussians: &GaussiansBuffer<G>,
450        _model_transform: &ModelTransformBuffer,
451        _gaussian_transform: &GaussianTransformBuffer,
452    ) {
453        self.apply_with_count(encoder, gaussians.len() as u32);
454    }
455}
456
457impl<G: GaussianPod> BasicModifierBundle<G, NoSelection, ()> {
458    /// Creates a new [`BasicModifierBundle`] bundle without a bind group.
459    pub fn new_without_bind_group(device: &wgpu::Device) -> Self {
460        BasicModifierBundle::<G>::create_bundle_builder(false)
461            .build_without_bind_groups(&device)
462            .map(|bundle| Self {
463                bundle,
464                gaussian_pod_marker: std::marker::PhantomData,
465                selection_marker: std::marker::PhantomData,
466            })
467            .expect("basic modifier bundle")
468    }
469}
470
471impl<G: GaussianPod> BasicModifierBundle<G, WithSelection, ()> {
472    /// Creates a new [`BasicModifierBundle`] bundle without a bind group with selection buffer.
473    pub fn new_without_bind_group_with_selection(device: &wgpu::Device) -> Self {
474        BasicModifierBundle::<G>::create_bundle_builder(true)
475            .build_without_bind_groups(&device)
476            .map(|bundle| Self {
477                bundle,
478                gaussian_pod_marker: std::marker::PhantomData,
479                selection_marker: std::marker::PhantomData,
480            })
481            .expect("basic modifier bundle")
482    }
483}
484
485impl<G: GaussianPod, S> BasicModifierBundle<G, S, ()> {
486    /// Apply the basic modifier to the Gaussians.
487    ///
488    /// - `gaussians_bind_group` is the bind group created from [`MODIFIER_GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR`].
489    /// - `bind_group` is the bind group created from [`BasicModifierBundle::BIND_GROUP_LAYOUT_DESCRIPTOR`].
490    pub fn apply_with_count<'a>(
491        &self,
492        encoder: &mut wgpu::CommandEncoder,
493        gaussians_bind_group: &wgpu::BindGroup,
494        bind_group: &wgpu::BindGroup,
495        gaussian_count: u32,
496    ) {
497        self.bundle()
498            .dispatch(encoder, gaussian_count, [gaussians_bind_group, bind_group]);
499    }
500}
501
502/// A struct to handle basic modifier.
503///
504/// This modifier holds a [`BasicModifierBundle`] along with necessary buffers, and applies the
505/// basic modifier.
506#[derive(Debug)]
507pub struct BasicModifier<G: GaussianPod, S = NoSelection> {
508    pub transform_flags_buffer: TransformFlagsBuffer,
509    pub basic_color_modifiers_buffer: BasicColorModifiersBuffer,
510    pub rot_scale_buffer: RotScaleBuffer,
511    pub modifier: BasicModifierBundle<G, S>,
512}
513
514impl<G: GaussianPod> BasicModifier<G> {
515    /// Create a new basic modifier.
516    pub fn new(
517        device: &wgpu::Device,
518        gaussians_buffer: &GaussiansBuffer<G>,
519        model_transform_buffer: &ModelTransformBuffer,
520        gaussian_transform_buffer: &GaussianTransformBuffer,
521    ) -> Self {
522        log::debug!("Creating transform flags buffer");
523        let transform_flags_buffer = TransformFlagsBuffer::new(device);
524
525        log::debug!("Creating basic color modifiers buffer");
526        let basic_color_modifiers_buffer = BasicColorModifiersBuffer::new(device);
527
528        log::debug!("Creating rotation scale buffer");
529        let rot_scale_buffer = RotScaleBuffer::new(device);
530
531        log::debug!("Creating basic modifier bundle");
532        let modifier = BasicModifierBundle::new(
533            device,
534            gaussians_buffer,
535            model_transform_buffer,
536            gaussian_transform_buffer,
537            &transform_flags_buffer,
538            &basic_color_modifiers_buffer,
539            &rot_scale_buffer,
540        );
541
542        log::debug!("Basic modifier created");
543
544        Self {
545            transform_flags_buffer,
546            basic_color_modifiers_buffer,
547            rot_scale_buffer,
548
549            modifier,
550        }
551    }
552}
553
554impl<G: GaussianPod> BasicModifier<G, WithSelection> {
555    /// Create a new basic modifier with selection.
556    pub fn new_with_selection(
557        device: &wgpu::Device,
558        gaussians_buffer: &GaussiansBuffer<G>,
559        model_transform_buffer: &ModelTransformBuffer,
560        gaussian_transform_buffer: &GaussianTransformBuffer,
561        selection_buffer: &SelectionBuffer,
562    ) -> Self {
563        log::debug!("Creating transform flags buffer");
564        let transform_flags_buffer = TransformFlagsBuffer::new(device);
565
566        log::debug!("Creating basic color modifiers buffer");
567        let basic_color_modifiers_buffer = BasicColorModifiersBuffer::new(device);
568
569        log::debug!("Creating rotation scale buffer");
570        let rot_scale_buffer = RotScaleBuffer::new(device);
571
572        log::debug!("Creating basic modifier bundle");
573        let modifier = BasicModifierBundle::new_with_selection(
574            device,
575            gaussians_buffer,
576            model_transform_buffer,
577            gaussian_transform_buffer,
578            &transform_flags_buffer,
579            &basic_color_modifiers_buffer,
580            &rot_scale_buffer,
581            selection_buffer,
582        );
583
584        log::debug!("Basic modifier created");
585
586        Self {
587            transform_flags_buffer,
588            basic_color_modifiers_buffer,
589            rot_scale_buffer,
590
591            modifier,
592        }
593    }
594}
595
596impl<G: GaussianPod, S> Modifier<G> for BasicModifier<G, S> {
597    fn apply(
598        &self,
599        _device: &wgpu::Device,
600        encoder: &mut wgpu::CommandEncoder,
601        gaussians: &GaussiansBuffer<G>,
602        _model_transform: &ModelTransformBuffer,
603        _gaussian_transform: &GaussianTransformBuffer,
604    ) {
605        self.modifier
606            .apply_with_count(encoder, gaussians.len() as u32);
607    }
608}