Skip to main content

modify_selection/
modify_selection.rs

//! This example selects parts of a model and then applies a modifier using
2//! [`SelectionModifier::new_with_basic_modifier`](wgpu_3dgs_editor::SelectionModifier::new_with_basic_modifier).
3//!
4//! For example, to select a sphere at (0.5, 1.0, 0.5) with scale (1, 1, 1),
5//! repeat the selection 2 times with an offset (2, 0, 0), and decrease the contrast of the selected Gaussians:
6//!
7//! ```sh
8//! cargo run --example modify-selection -- \
9//!     -m "path/to/model.ply" \
10//!     -p 0.5 1.0 0.5 -s 1 1 1 --repeat 2 --offset 2.0 0.0 0.0 --contrast "-1.0"
11//! ```
12
13use clap::{Parser, ValueEnum};
14use glam::*;
15
16use wgpu_3dgs_editor::{self as gs, core::BufferWrapper};
17
18/// The command line arguments.
19#[derive(Parser, Debug)]
20#[command(
21    version,
22    about,
23    long_about = "\
24    A 3D Gaussian splatting editor to apply basic modifier to selected Gaussians in a model.
25    "
26)]
27struct Args {
28    /// Path to the .ply file.
29    #[arg(short, long, default_value = "examples/model.ply")]
30    model: String,
31
32    /// The output path for the modified .ply file.
33    #[arg(short, long, default_value = "target/output.ply")]
34    output: String,
35
36    /// The position of the selection shape.
37    #[arg(
38        short,
39        long,
40        allow_hyphen_values = true,
41        num_args = 3,
42        value_delimiter = ',',
43        default_value = "0.0,0.0,0.0"
44    )]
45    pos: Vec<f32>,
46
47    /// The rotation of the selection shape.
48    #[arg(
49        short,
50        long,
51        allow_hyphen_values = true,
52        num_args = 4,
53        value_delimiter = ',',
54        default_value = "0.0,0.0,0.0,1.0"
55    )]
56    rot: Vec<f32>,
57
58    /// The scale of the selection shape.
59    #[arg(
60        short,
61        long,
62        allow_hyphen_values = true,
63        num_args = 3,
64        value_delimiter = ',',
65        default_value = "0.5,1.0,2.0"
66    )]
67    scale: Vec<f32>,
68
69    /// The shape of the selection.
70    #[arg(long, value_enum, default_value_t = Shape::Sphere, ignore_case = true)]
71    shape: Shape,
72
73    /// The number of times to run the selection.
74    #[arg(long, default_value = "1")]
75    repeat: u32,
76
77    /// The offset of each selection.
78    #[arg(
79        long,
80        allow_hyphen_values = true,
81        num_args = 3,
82        value_delimiter = ',',
83        default_value = "2.0,0.0,0.0"
84    )]
85    offset: Vec<f32>,
86
87    /// Whether to override the RGB color of the selected Gaussians.
88    #[arg(long)]
89    override_rgb: bool,
90
91    /// If [`Args::override_rgb`], then it is used to override the RGB color,
92    /// otherwise it is used to apply HSV modifications.
93    ///
94    /// Normally hue (H) is in [0, 1], saturation (S) and value (V) are in [0, 2].
95    /// This function adds the hue and multiplies saturation and value.
96    #[arg(
97        long,
98        allow_hyphen_values = true,
99        num_args = 3,
100        value_delimiter = ',',
101        default_value = "0.0,1.0,1.0"
102    )]
103    rgb_or_hsv: Vec<f32>,
104
105    /// Alpha is multiplied with the original alpha.
106    #[arg(long, allow_hyphen_values = true, default_value = "1.0")]
107    alpha: f32,
108
109    /// Contrast is applied to the RGB color.
110    ///
111    /// Normally the range is [-1, 1].
112    #[arg(long, allow_hyphen_values = true, default_value = "0.0")]
113    contrast: f32,
114
115    /// Exposure is applied to the RGB color.
116    ///
117    /// Normally the range is [-5, 5].
118    #[arg(long, allow_hyphen_values = true, default_value = "0.0")]
119    exposure: f32,
120
121    /// Gamma is applied to the RGB color.
122    ///
123    /// Normally the range is [0, 5].
124    #[arg(long, allow_hyphen_values = true, default_value = "1.0")]
125    gamma: f32,
126}
127
128#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
129enum Shape {
130    Sphere,
131    Box,
132}
133
134type GaussianPod = gs::core::GaussianPodWithShSingleCov3dRotScaleConfigs;
135
136#[pollster::main]
137async fn main() {
138    env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
139
140    let args = Args::parse();
141    let model_path = &args.model;
142    let pos = Vec3::from_slice(&args.pos);
143    let rot = Quat::from_slice(&args.rot);
144    let scale = Vec3::from_slice(&args.scale);
145    let shape = match args.shape {
146        Shape::Sphere => gs::SelectionBundle::<GaussianPod>::create_sphere_bundle,
147        Shape::Box => gs::SelectionBundle::<GaussianPod>::create_box_bundle,
148    };
149    let repeat = args.repeat;
150    let offset = Vec3::from_slice(&args.offset);
151
152    log::debug!("Creating wgpu instance");
153    let instance =
154        wgpu::Instance::new(wgpu::InstanceDescriptor::new_without_display_handle_from_env());
155
156    log::debug!("Requesting adapter");
157    let adapter = instance
158        .request_adapter(&wgpu::RequestAdapterOptions::default())
159        .await
160        .expect("adapter");
161
162    log::debug!("Requesting device");
163    let (device, queue) = adapter
164        .request_device(&wgpu::DeviceDescriptor {
165            label: Some("Device"),
166            required_limits: adapter.limits(),
167            ..Default::default()
168        })
169        .await
170        .expect("device");
171
172    log::debug!("Creating gaussians");
173    let gaussians = [
174        gs::core::GaussiansSource::Ply,
175        gs::core::GaussiansSource::Spz,
176    ]
177    .into_iter()
178    .find_map(|source| gs::core::Gaussians::read_from_file(model_path, source).ok())
179    .expect("gaussians");
180
181    log::debug!("Creating editor");
182    let editor = gs::Editor::<GaussianPod>::new(&device, &gaussians);
183
184    log::debug!("Creating shape selection compute bundle");
185    let shape_selection = shape(&device);
186
187    log::debug!("Creating basic selection modifier");
188    let mut basic_selection_modifier = gs::SelectionModifier::new_with_basic_modifier(
189        &device,
190        &editor.gaussians_buffer,
191        &editor.model_transform_buffer,
192        &editor.gaussian_transform_buffer,
193        vec![shape_selection],
194    );
195
196    log::debug!("Configuring modifiers");
197    basic_selection_modifier
198        .modifier
199        .basic_color_modifiers_buffer
200        .update(
201            &queue,
202            match args.override_rgb {
203                true => gs::BasicColorRgbOverrideOrHsvModifiersPod::new_rgb_override,
204                false => gs::BasicColorRgbOverrideOrHsvModifiersPod::new_hsv_modifiers,
205            }(Vec3::from_slice(&args.rgb_or_hsv)),
206            args.alpha,
207            args.contrast,
208            args.exposure,
209            args.gamma,
210        );
211
212    log::debug!("Creating shape selection buffers");
213    let shape_selection_buffers = (0..repeat)
214        .map(|i| {
215            let offset_pos = pos + offset * i as f32;
216            let buffer = gs::InvTransformBuffer::new(&device);
217            buffer.update_with_scale_rot_pos(&queue, scale, rot, offset_pos);
218            buffer
219        })
220        .collect::<Vec<_>>();
221
222    log::debug!("Creating shape selection bind groups");
223    let shape_selection_bind_groups = shape_selection_buffers
224        .iter()
225        .map(|buffer| {
226            basic_selection_modifier.selection.bundles[0]
227                .create_bind_group(
228                    &device,
229                    // index 0 is the Gaussians buffer, so we use 1,
230                    // see docs of create_sphere_bundle or create_box_bundle
231                    1,
232                    [buffer.buffer().as_entire_binding()],
233                )
234                .expect("bind group")
235        })
236        .collect::<Vec<_>>();
237
238    log::debug!("Creating selection expression");
239    basic_selection_modifier.selection_expr = shape_selection_bind_groups.into_iter().fold(
240        gs::SelectionExpr::Identity,
241        |acc, bind_group| {
242            acc.union(gs::SelectionExpr::selection(
243                0, // the 0 here is the bundle index in the selection bundle
244                vec![bind_group],
245            ))
246        },
247    );
248
249    log::info!("Starting editing process");
250    let time = std::time::Instant::now();
251
252    log::debug!("Editing Gaussians");
253    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
254        label: Some("Edit Encoder"),
255    });
256
257    editor.apply(
258        &device,
259        &mut encoder,
260        [&basic_selection_modifier as &dyn gs::Modifier<GaussianPod>],
261    );
262
263    queue.submit(Some(encoder.finish()));
264
265    device
266        .poll(wgpu::PollType::wait_indefinitely())
267        .expect("poll");
268
269    log::info!("Editing process completed in {:?}", time.elapsed());
270
271    log::debug!("Downloading Gaussians");
272    let modified_gaussians = editor
273        .gaussians_buffer
274        .download_gaussians(&device, &queue)
275        .await
276        .map(|gs| {
277            match &args.output[args.output.len().saturating_sub(4)..] {
278                ".ply" => {
279                    gs::core::Gaussians::Ply(gs::core::PlyGaussians::from_iter(gs.into_iter()))
280                }
281                ".spz" => {
282                    gs::core::Gaussians::Spz(
283                        gs::core::SpzGaussians::from_gaussians_with_options(
284                            gs,
285                            &gs::core::SpzGaussiansFromGaussianSliceOptions {
286                                version: 2, // Version 2 is more widely supported as of now
287                                ..Default::default()
288                            },
289                        )
290                        .expect("SpzGaussians from gaussians"),
291                    )
292                }
293                _ => panic!("Unsupported output file extension, expected .ply or .spz"),
294            }
295        })
296        .expect("gaussians download");
297
298    log::debug!("Writing modified Gaussians to output file");
299    modified_gaussians
300        .write_to_file(&args.output)
301        .expect("write modified Gaussians to output file");
302
303    log::info!("Modified Gaussians written to {}", args.output);
304}