custom_modify_selection/
custom_modify_selection.rs

//! This example selects parts of a model, then applies a custom modifier to the selection
//! using [`SelectionModifier`](wgpu_3dgs_editor::SelectionModifier) and a user-defined
//! compute pipeline built with [`ComputeBundleBuilder`](wgpu_3dgs_editor::core::ComputeBundleBuilder).
//!
//! The custom selection is a cylinder aligned with the z-axis, centered at `pos` with a given `radius`.
//!
//! The custom modifier shifts the hue of the selected Gaussians according to their x and y
//! coordinates relative to the axis of the selection cylinder.
//!
//! The modified model is written to `--output` (default `target/output.ply`); both `.ply` and
//! `.spz` outputs are supported, chosen by the file extension.
//!
//! For example, to use the defaults (pos = (0, 0, 0), radius = 3.0) and apply the custom modifier:
//!
//! ```sh
//! cargo run --example custom-modify-selection -- -m "path/to/model.ply"
//! ```
15
16use clap::{Parser, ValueEnum};
17use glam::*;
18
19use wgpu::util::DeviceExt;
20use wgpu_3dgs_editor::{
21    self as gs,
22    core::{BufferWrapper, GaussianPod as _},
23};
24
25/// The command line arguments.
26#[derive(Parser, Debug)]
27#[command(
28    version,
29    about,
30    long_about = "\
31    A 3D Gaussian splatting editor to apply custom modifier to Gaussians in a model selected by a cylinder along the z-axis.
32    "
33)]
34struct Args {
35    /// Path to the .ply file.
36    #[arg(short, long, default_value = "examples/model.ply")]
37    model: String,
38
39    /// The output path for the modified .ply file.
40    #[arg(short, long, default_value = "target/output.ply")]
41    output: String,
42
43    /// The position of the selection cylinder.
44    #[arg(
45        short,
46        long,
47        allow_hyphen_values = true,
48        num_args = 3,
49        value_delimiter = ',',
50        default_value = "0.0,0.0,0.0"
51    )]
52    pos: Vec<f32>,
53
54    /// The radius of the selection cylinder.
55    #[arg(short, long, default_value_t = 3.0)]
56    radius: f32,
57}
58
59#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
60enum Factory {
61    Struct,
62    Closure,
63}
64
65type GaussianPod = gs::core::GaussianPodWithShSingleCov3dRotScaleConfigs;
66
67#[pollster::main]
68async fn main() {
69    env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
70
71    let args = Args::parse();
72    let model_path = &args.model;
73    let pos = Vec3::from_slice(&args.pos);
74    let radius = args.radius;
75
76    log::debug!("Creating wgpu instance");
77    let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor::default());
78
79    log::debug!("Requesting adapter");
80    let adapter = instance
81        .request_adapter(&wgpu::RequestAdapterOptions::default())
82        .await
83        .expect("adapter");
84
85    log::debug!("Requesting device");
86    let (device, queue) = adapter
87        .request_device(&wgpu::DeviceDescriptor {
88            label: Some("Device"),
89            required_limits: adapter.limits(),
90            ..Default::default()
91        })
92        .await
93        .expect("device");
94
95    log::debug!("Creating gaussians");
96    let gaussians = [
97        gs::core::GaussiansSource::Ply,
98        gs::core::GaussiansSource::Spz,
99    ]
100    .into_iter()
101    .find_map(|source| gs::core::Gaussians::read_from_file(model_path, source).ok())
102    .expect("gaussians");
103
104    log::debug!("Creating editor");
105    let editor = gs::Editor::<GaussianPod>::new(&device, &gaussians);
106
107    log::debug!("Creating buffers");
108    let pos_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
109        label: Some("Position Buffer"),
110        contents: bytemuck::bytes_of(&pos),
111        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
112    });
113
114    let radius_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
115        label: Some("Radius Buffer"),
116        contents: bytemuck::bytes_of(&radius),
117        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
118    });
119
120    const BIND_GROUP_LAYOUT_DESCRIPTOR: wgpu::BindGroupLayoutDescriptor =
121        wgpu::BindGroupLayoutDescriptor {
122            label: Some("Bind Group Layout"),
123            entries: &[
124                // Position uniform buffer
125                wgpu::BindGroupLayoutEntry {
126                    binding: 0,
127                    visibility: wgpu::ShaderStages::COMPUTE,
128                    ty: wgpu::BindingType::Buffer {
129                        ty: wgpu::BufferBindingType::Uniform,
130                        has_dynamic_offset: false,
131                        min_binding_size: None,
132                    },
133                    count: None,
134                },
135                // Radius uniform buffer
136                wgpu::BindGroupLayoutEntry {
137                    binding: 1,
138                    visibility: wgpu::ShaderStages::COMPUTE,
139                    ty: wgpu::BindingType::Buffer {
140                        ty: wgpu::BufferBindingType::Uniform,
141                        has_dynamic_offset: false,
142                        min_binding_size: None,
143                    },
144                    count: None,
145                },
146                // Selection buffer (only in modifier pipeline)
147                wgpu::BindGroupLayoutEntry {
148                    binding: 2,
149                    visibility: wgpu::ShaderStages::COMPUTE,
150                    ty: wgpu::BindingType::Buffer {
151                        ty: wgpu::BufferBindingType::Storage { read_only: true },
152                        has_dynamic_offset: false,
153                        min_binding_size: None,
154                    },
155                    count: None,
156                },
157            ],
158        };
159
160    log::debug!("Creating cylinder selection compute bundle");
161    let cylinder_selection_bundle = gs::core::ComputeBundleBuilder::new()
162        .label("Selection")
163        .bind_group_layouts([
164            &gs::SelectionBundle::<GaussianPod>::GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR,
165            &wgpu::BindGroupLayoutDescriptor {
166                entries: &BIND_GROUP_LAYOUT_DESCRIPTOR.entries[..2],
167                ..BIND_GROUP_LAYOUT_DESCRIPTOR
168            },
169        ])
170        .main_shader("package::selection".parse().unwrap())
171        .entry_point("main")
172        .wesl_compile_options(wesl::CompileOptions {
173            features: GaussianPod::wesl_features(),
174            ..Default::default()
175        })
176        .resolver({
177            let mut resolver =
178                wesl::StandardResolver::new("examples/shader/custom_modify_selection");
179            resolver.add_package(&gs::core::shader::PACKAGE);
180            resolver
181        })
182        .build_without_bind_groups(&device)
183        .map_err(|e| log::error!("{e}"))
184        .expect("selection bundle");
185
186    log::debug!("Creating custom modifier");
187    #[allow(dead_code)]
188    let modifier_factory = |selection_buffer: &gs::SelectionBuffer| /* -> impl gs::Modifier<GaussianPod> */ {
189            log::debug!("Creating custom modifier compute bundle");
190            let modifier_bundle = gs::core::ComputeBundleBuilder::new()
191                .label("Modifier")
192                .bind_group_layouts([
193                    &gs::MODIFIER_GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR,
194                    &BIND_GROUP_LAYOUT_DESCRIPTOR,
195                ])
196                .resolver({
197                    let mut resolver =
198                        wesl::StandardResolver::new("examples/shader/custom_modify_selection");
199                    resolver.add_package(&gs::core::shader::PACKAGE);
200                    resolver.add_package(&gs::shader::PACKAGE);
201                    resolver
202                })
203                .main_shader("package::modifier".parse().unwrap())
204                .entry_point("main")
205                .wesl_compile_options(wesl::CompileOptions {
206                    features: GaussianPod::wesl_features(),
207                    ..Default::default()
208                })
209                .build(
210                    &device,
211                    [
212                        vec![
213                            editor.gaussians_buffer.buffer().as_entire_binding(),
214                            editor.model_transform_buffer.buffer().as_entire_binding(),
215                            editor
216                                .gaussian_transform_buffer
217                                .buffer()
218                                .as_entire_binding(),
219                        ],
220                        vec![
221                            pos_buffer.as_entire_binding(),
222                            radius_buffer.as_entire_binding(),
223                            selection_buffer.buffer().as_entire_binding(),
224                        ],
225                    ],
226                )
227                .map_err(|e| log::error!("{e}"))
228                .expect("modifier bundle");
229
230            // This is a modifier closure because this function signature has blanket impl of the modifier trait
231            move |_device: &wgpu::Device,
232                  encoder: &mut wgpu::CommandEncoder,
233                  gaussians: &gs::core::GaussiansBuffer<GaussianPod>,
234                  _model_transform: &gs::core::ModelTransformBuffer,
235                  _gaussian_transform: &gs::core::GaussianTransformBuffer| {
236                  modifier_bundle.dispatch(encoder, gaussians.len() as u32);
237            }
238        };
239
240    #[allow(dead_code)]
241    struct Modifier<G: gs::core::GaussianPod>(gs::core::ComputeBundle, std::marker::PhantomData<G>);
242
243    impl<G: gs::core::GaussianPod> Modifier<G> {
244        #[allow(dead_code)]
245        fn new(
246            device: &wgpu::Device,
247            editor: &gs::Editor<G>,
248            pos_buffer: &wgpu::Buffer,
249            radius_buffer: &wgpu::Buffer,
250            selection_buffer: &gs::SelectionBuffer,
251        ) -> Self {
252            log::debug!("Creating custom modifier compute bundle");
253            let modifier_bundle = gs::core::ComputeBundleBuilder::new()
254                .label("Modifier")
255                .bind_group_layouts([
256                    &gs::MODIFIER_GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR,
257                    &BIND_GROUP_LAYOUT_DESCRIPTOR,
258                ])
259                .resolver({
260                    let mut resolver =
261                        wesl::StandardResolver::new("examples/shader/custom_modify_selection");
262                    resolver.add_package(&gs::core::shader::PACKAGE);
263                    resolver.add_package(&gs::shader::PACKAGE);
264                    resolver
265                })
266                .main_shader("package::modifier".parse().unwrap())
267                .entry_point("main")
268                .wesl_compile_options(wesl::CompileOptions {
269                    features: GaussianPod::wesl_features(),
270                    ..Default::default()
271                })
272                .build(
273                    device,
274                    [
275                        vec![
276                            editor.gaussians_buffer.buffer().as_entire_binding(),
277                            editor.model_transform_buffer.buffer().as_entire_binding(),
278                            editor
279                                .gaussian_transform_buffer
280                                .buffer()
281                                .as_entire_binding(),
282                        ],
283                        vec![
284                            pos_buffer.as_entire_binding(),
285                            radius_buffer.as_entire_binding(),
286                            selection_buffer.buffer().as_entire_binding(),
287                        ],
288                    ],
289                )
290                .map_err(|e| log::error!("{e}"))
291                .expect("modifier bundle");
292
293            Self(modifier_bundle, std::marker::PhantomData)
294        }
295    }
296
297    impl<G: gs::core::GaussianPod> gs::Modifier<G> for Modifier<G> {
298        fn apply(
299            &self,
300            _device: &wgpu::Device,
301            encoder: &mut wgpu::CommandEncoder,
302            gaussians: &gs::core::GaussiansBuffer<G>,
303            _model_transform: &gs::core::ModelTransformBuffer,
304            _gaussian_transform: &gs::core::GaussianTransformBuffer,
305        ) {
306            self.0.dispatch(encoder, gaussians.len() as u32);
307        }
308    }
309
310    log::debug!("Creating selection modifier");
311    let mut selection_modifier = gs::SelectionModifier::<GaussianPod, _>::new(
312        &device,
313        &editor.gaussians_buffer,
314        vec![cylinder_selection_bundle],
315        modifier_factory,
316        // Uncomment the following line to use modifier struct instead of closure
317        // |selection_buffer| {
318        //     Modifier::new(
319        //         &device,
320        //         &editor,
321        //         &pos_buffer,
322        //         &radius_buffer,
323        //         selection_buffer,
324        //     )
325        // },
326    );
327
328    log::debug!("Creating selection expression");
329    selection_modifier.selection_expr = gs::SelectionExpr::selection(
330        0,
331        vec![
332            selection_modifier.selection.bundles[0]
333                .create_bind_group(
334                    &device,
335                    1, // index 0 is the Gaussians buffer, so we use 1,
336                    [
337                        pos_buffer.as_entire_binding(),
338                        radius_buffer.as_entire_binding(),
339                    ],
340                )
341                .expect("selection expr bind group"),
342        ],
343    );
344
345    log::info!("Starting editing process");
346    let time = std::time::Instant::now();
347
348    log::debug!("Editing Gaussians");
349    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
350        label: Some("Edit Encoder"),
351    });
352
353    editor.apply(
354        &device,
355        &mut encoder,
356        [&selection_modifier as &dyn gs::Modifier<GaussianPod>],
357    );
358
359    queue.submit(Some(encoder.finish()));
360
361    device
362        .poll(wgpu::PollType::wait_indefinitely())
363        .expect("poll");
364
365    log::info!("Editing process completed in {:?}", time.elapsed());
366
367    log::debug!("Downloading Gaussians");
368    let modified_gaussians = editor
369        .gaussians_buffer
370        .download_gaussians(&device, &queue)
371        .await
372        .map(|gs| {
373            match &args.output[args.output.len().saturating_sub(4)..] {
374                ".ply" => {
375                    gs::core::Gaussians::Ply(gs::core::PlyGaussians::from_iter(gs.into_iter()))
376                }
377                ".spz" => {
378                    gs::core::Gaussians::Spz(
379                        gs::core::SpzGaussians::from_gaussians_with_options(
380                            gs,
381                            &gs::core::SpzGaussiansFromGaussianSliceOptions {
382                                version: 2, // Version 2 is more widely supported as of now
383                                ..Default::default()
384                            },
385                        )
386                        .expect("SpzGaussians from gaussians"),
387                    )
388                }
389                _ => panic!("Unsupported output file extension, expected .ply or .spz"),
390            }
391        })
392        .expect("gaussians download");
393
394    log::debug!("Writing modified Gaussians to output file");
395    modified_gaussians
396        .write_to_file(&args.output)
397        .expect("write modified Gaussians to output file");
398
399    log::info!("Modified Gaussians written to {}", args.output);
400}