//! custom_modify_selection/custom_modify_selection.rs

1use clap::{Parser, ValueEnum};
2use glam::*;
3
4use wgpu::util::DeviceExt;
5use wgpu_3dgs_editor::{
6    self as gs,
7    core::{BufferWrapper, GaussianPod as _},
8};
9
10/// The command line arguments.
11#[derive(Parser, Debug)]
12#[command(
13    version,
14    about,
15    long_about = "\
16    A 3D Gaussian splatting editor to apply custom modifier to Gaussians in a model selected by a cylinder along the z-axis.
17    "
18)]
19struct Args {
20    /// Path to the .ply file.
21    #[arg(short, long)]
22    model: String,
23
24    /// The output path for the modified .ply file.
25    #[arg(short, long, default_value = "target/output.ply")]
26    output: String,
27
28    /// The position of the selection cylinder.
29    #[arg(
30        short,
31        long,
32        allow_hyphen_values = true,
33        num_args = 3,
34        value_delimiter = ',',
35        default_value = "0.0,0.0,0.0"
36    )]
37    pos: Vec<f32>,
38
39    /// The radius of the selection cylinder.
40    #[arg(short, long, default_value_t = 3.0)]
41    radius: f32,
42}
43
44#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
45enum Factory {
46    Struct,
47    Closure,
48}
49
50type GaussianPod = gs::core::GaussianPodWithShSingleCov3dRotScaleConfigs;
51
52#[tokio::main]
53async fn main() {
54    env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
55
56    let args = Args::parse();
57    let model_path = &args.model;
58    let pos = Vec3::from_slice(&args.pos);
59    let radius = args.radius;
60
61    log::debug!("Creating wgpu instance");
62    let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor::default());
63
64    log::debug!("Requesting adapter");
65    let adapter = instance
66        .request_adapter(&wgpu::RequestAdapterOptions::default())
67        .await
68        .expect("adapter");
69
70    log::debug!("Requesting device");
71    let (device, queue) = adapter
72        .request_device(&wgpu::DeviceDescriptor {
73            label: Some("Device"),
74            required_features: wgpu::Features::empty(),
75            required_limits: adapter.limits(),
76            memory_hints: wgpu::MemoryHints::default(),
77            trace: wgpu::Trace::Off,
78        })
79        .await
80        .expect("device");
81
82    log::debug!("Creating gaussians");
83    let f = std::fs::File::open(model_path).expect("ply file");
84    let mut reader = std::io::BufReader::new(f);
85    let gaussians = gs::core::Gaussians::read_ply(&mut reader).expect("gaussians");
86
87    log::debug!("Creating editor");
88    let editor = gs::Editor::<GaussianPod>::new(&device, &gaussians);
89
90    log::debug!("Creating buffers");
91    let pos_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
92        label: Some("Position Buffer"),
93        contents: bytemuck::bytes_of(&pos),
94        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
95    });
96
97    let radius_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
98        label: Some("Radius Buffer"),
99        contents: bytemuck::bytes_of(&radius),
100        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
101    });
102
103    const BIND_GROUP_LAYOUT_DESCRIPTOR: wgpu::BindGroupLayoutDescriptor =
104        wgpu::BindGroupLayoutDescriptor {
105            label: Some("Bind Group Layout"),
106            entries: &[
107                // Position uniform buffer
108                wgpu::BindGroupLayoutEntry {
109                    binding: 0,
110                    visibility: wgpu::ShaderStages::COMPUTE,
111                    ty: wgpu::BindingType::Buffer {
112                        ty: wgpu::BufferBindingType::Uniform,
113                        has_dynamic_offset: false,
114                        min_binding_size: None,
115                    },
116                    count: None,
117                },
118                // Radius uniform buffer
119                wgpu::BindGroupLayoutEntry {
120                    binding: 1,
121                    visibility: wgpu::ShaderStages::COMPUTE,
122                    ty: wgpu::BindingType::Buffer {
123                        ty: wgpu::BufferBindingType::Uniform,
124                        has_dynamic_offset: false,
125                        min_binding_size: None,
126                    },
127                    count: None,
128                },
129                // Selection buffer (only in modifier pipeline)
130                wgpu::BindGroupLayoutEntry {
131                    binding: 2,
132                    visibility: wgpu::ShaderStages::COMPUTE,
133                    ty: wgpu::BindingType::Buffer {
134                        ty: wgpu::BufferBindingType::Storage { read_only: true },
135                        has_dynamic_offset: false,
136                        min_binding_size: None,
137                    },
138                    count: None,
139                },
140            ],
141        };
142
143    log::debug!("Creating cylinder selection compute bundle");
144    let cylinder_selection_bundle = gs::core::ComputeBundleBuilder::new()
145        .label("Selection")
146        .bind_group_layouts([
147            &gs::SelectionBundle::<GaussianPod>::GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR,
148            &wgpu::BindGroupLayoutDescriptor {
149                entries: &BIND_GROUP_LAYOUT_DESCRIPTOR.entries[..2],
150                ..BIND_GROUP_LAYOUT_DESCRIPTOR
151            },
152        ])
153        .main_shader("package::selection".parse().unwrap())
154        .entry_point("main")
155        .wesl_compile_options(wesl::CompileOptions {
156            features: GaussianPod::wesl_features(),
157            ..Default::default()
158        })
159        .resolver({
160            let mut resolver =
161                wesl::StandardResolver::new("examples/shader/custom_modify_selection");
162            resolver.add_package(&gs::core::shader::PACKAGE);
163            resolver
164        })
165        .build_without_bind_groups(&device)
166        .map_err(|e| log::error!("{e}"))
167        .expect("selection bundle");
168
169    log::debug!("Creating custom modifier");
170    #[allow(dead_code)]
171    let modifier_factory = |selection_buffer: &gs::SelectionBuffer| /* -> impl gs::Modifier<GaussianPod> */ {
172            log::debug!("Creating custom modifier compute bundle");
173            let modifier_bundle = gs::core::ComputeBundleBuilder::new()
174                .label("Modifier")
175                .bind_group_layouts([
176                    &gs::MODIFIER_GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR,
177                    &BIND_GROUP_LAYOUT_DESCRIPTOR,
178                ])
179                .resolver({
180                    let mut resolver =
181                        wesl::StandardResolver::new("examples/shader/custom_modify_selection");
182                    resolver.add_package(&gs::core::shader::PACKAGE);
183                    resolver.add_package(&gs::shader::PACKAGE);
184                    resolver
185                })
186                .main_shader("package::modifier".parse().unwrap())
187                .entry_point("main")
188                .wesl_compile_options(wesl::CompileOptions {
189                    features: GaussianPod::wesl_features(),
190                    ..Default::default()
191                })
192                .build(
193                    &device,
194                    [
195                        vec![
196                            editor.gaussians_buffer.buffer().as_entire_binding(),
197                            editor.model_transform_buffer.buffer().as_entire_binding(),
198                            editor
199                                .gaussian_transform_buffer
200                                .buffer()
201                                .as_entire_binding(),
202                        ],
203                        vec![
204                            pos_buffer.as_entire_binding(),
205                            radius_buffer.as_entire_binding(),
206                            selection_buffer.buffer().as_entire_binding(),
207                        ],
208                    ],
209                )
210                .map_err(|e| log::error!("{e}"))
211                .expect("modifier bundle");
212
213            // This is a modifier closure because this function signature has blanket impl of the modifier trait
214            move |_device: &wgpu::Device,
215                  encoder: &mut wgpu::CommandEncoder,
216                  gaussians: &gs::core::GaussiansBuffer<GaussianPod>,
217                  _model_transform: &gs::core::ModelTransformBuffer,
218                  _gaussian_transform: &gs::core::GaussianTransformBuffer| {
219                  modifier_bundle.dispatch(encoder, gaussians.len() as u32);
220            }
221        };
222
223    #[allow(dead_code)]
224    struct Modifier<G: gs::core::GaussianPod>(gs::core::ComputeBundle, std::marker::PhantomData<G>);
225
226    impl<G: gs::core::GaussianPod> Modifier<G> {
227        #[allow(dead_code)]
228        fn new(
229            device: &wgpu::Device,
230            editor: &gs::Editor<G>,
231            pos_buffer: &wgpu::Buffer,
232            radius_buffer: &wgpu::Buffer,
233            selection_buffer: &gs::SelectionBuffer,
234        ) -> Self {
235            log::debug!("Creating custom modifier compute bundle");
236            let modifier_bundle = gs::core::ComputeBundleBuilder::new()
237                .label("Modifier")
238                .bind_group_layouts([
239                    &gs::MODIFIER_GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR,
240                    &BIND_GROUP_LAYOUT_DESCRIPTOR,
241                ])
242                .resolver({
243                    let mut resolver =
244                        wesl::StandardResolver::new("examples/shader/custom_modify_selection");
245                    resolver.add_package(&gs::core::shader::PACKAGE);
246                    resolver.add_package(&gs::shader::PACKAGE);
247                    resolver
248                })
249                .main_shader("package::modifier".parse().unwrap())
250                .entry_point("main")
251                .wesl_compile_options(wesl::CompileOptions {
252                    features: GaussianPod::wesl_features(),
253                    ..Default::default()
254                })
255                .build(
256                    &device,
257                    [
258                        vec![
259                            editor.gaussians_buffer.buffer().as_entire_binding(),
260                            editor.model_transform_buffer.buffer().as_entire_binding(),
261                            editor
262                                .gaussian_transform_buffer
263                                .buffer()
264                                .as_entire_binding(),
265                        ],
266                        vec![
267                            pos_buffer.as_entire_binding(),
268                            radius_buffer.as_entire_binding(),
269                            selection_buffer.buffer().as_entire_binding(),
270                        ],
271                    ],
272                )
273                .map_err(|e| log::error!("{e}"))
274                .expect("modifier bundle");
275
276            Self(modifier_bundle, std::marker::PhantomData)
277        }
278    }
279
280    impl<G: gs::core::GaussianPod> gs::Modifier<G> for Modifier<G> {
281        fn apply(
282            &self,
283            _device: &wgpu::Device,
284            encoder: &mut wgpu::CommandEncoder,
285            gaussians: &gs::core::GaussiansBuffer<G>,
286            _model_transform: &gs::core::ModelTransformBuffer,
287            _gaussian_transform: &gs::core::GaussianTransformBuffer,
288        ) {
289            self.0.dispatch(encoder, gaussians.len() as u32);
290        }
291    }
292
293    log::debug!("Creating selection modifier");
294    let mut selection_modifier = gs::SelectionModifier::<GaussianPod, _>::new(
295        &device,
296        &editor.gaussians_buffer,
297        vec![cylinder_selection_bundle],
298        modifier_factory,
299        // Uncomment the following line to use modifier struct instead of closure
300        // |selection_buffer| {
301        //     Modifier::new(
302        //         &device,
303        //         &editor,
304        //         &pos_buffer,
305        //         &radius_buffer,
306        //         selection_buffer,
307        //     )
308        // },
309    );
310
311    log::debug!("Creating selection expression");
312    selection_modifier.selection_expr = gs::SelectionExpr::selection(
313        0,
314        vec![
315            selection_modifier.selection.bundles[0]
316                .create_bind_group(
317                    &device,
318                    1, // index 0 is the Gaussians buffer, so we use 1,
319                    [
320                        pos_buffer.as_entire_binding(),
321                        radius_buffer.as_entire_binding(),
322                    ],
323                )
324                .expect("selection expr bind group"),
325        ],
326    );
327
328    log::info!("Starting editing process");
329    let time = std::time::Instant::now();
330
331    log::debug!("Editing Gaussians");
332    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
333        label: Some("Edit Encoder"),
334    });
335
336    editor.apply(
337        &device,
338        &mut encoder,
339        [&selection_modifier as &dyn gs::Modifier<GaussianPod>],
340    );
341
342    queue.submit(Some(encoder.finish()));
343
344    #[allow(unused_must_use)]
345    device.poll(wgpu::PollType::Wait);
346
347    log::info!("Editing process completed in {:?}", time.elapsed());
348
349    log::debug!("Downloading Gaussians");
350    let modified_gaussians = gs::core::Gaussians {
351        gaussians: editor
352            .gaussians_buffer
353            .download_gaussians(&device, &queue)
354            .await
355            .expect("gaussians download"),
356    };
357
358    log::debug!("Writing modified Gaussians to output file");
359    let output_file = std::fs::File::create(&args.output).expect("output file");
360    let mut writer = std::io::BufWriter::new(output_file);
361    modified_gaussians
362        .write_ply(&mut writer)
363        .expect("write modified Gaussians to output file");
364
365    log::info!("Modified Gaussians written to {}", args.output);
366}