// modify_selection/modify_selection.rs

use std::io::Write;

use clap::{Parser, ValueEnum};
use glam::*;

use wgpu_3dgs_editor::{self as gs, core::BufferWrapper};

6/// The command line arguments.
7#[derive(Parser, Debug)]
8#[command(
9    version,
10    about,
11    long_about = "\
12    A 3D Gaussian splatting editor to apply basic modifier to selected Gaussians in a model.
13    "
14)]
15struct Args {
16    /// Path to the .ply file.
17    #[arg(short, long)]
18    model: String,
19
20    /// The output path for the modified .ply file.
21    #[arg(short, long, default_value = "target/output.ply")]
22    output: String,
23
24    /// The position of the selection shape.
25    #[arg(
26        short,
27        long,
28        allow_hyphen_values = true,
29        num_args = 3,
30        value_delimiter = ',',
31        default_value = "0.0,0.0,0.0"
32    )]
33    pos: Vec<f32>,
34
35    /// The rotation of the selection shape.
36    #[arg(
37        short,
38        long,
39        allow_hyphen_values = true,
40        num_args = 4,
41        value_delimiter = ',',
42        default_value = "0.0,0.0,0.0,1.0"
43    )]
44    rot: Vec<f32>,
45
46    /// The scale of the selection shape.
47    #[arg(
48        short,
49        long,
50        allow_hyphen_values = true,
51        num_args = 3,
52        value_delimiter = ',',
53        default_value = "0.5,1.0,2.0"
54    )]
55    scale: Vec<f32>,
56
57    /// The shape of the selection.
58    #[arg(long, value_enum, default_value_t = Shape::Sphere, ignore_case = true)]
59    shape: Shape,
60
61    /// The number of times to run the selection.
62    #[arg(long, default_value = "1")]
63    repeat: u32,
64
65    /// The offset of each selection.
66    #[arg(
67        long,
68        allow_hyphen_values = true,
69        num_args = 3,
70        value_delimiter = ',',
71        default_value = "2.0,0.0,0.0"
72    )]
73    offset: Vec<f32>,
74
75    /// Whether to override the RGB color of the selected Gaussians.
76    #[arg(long)]
77    override_rgb: bool,
78
79    /// If [`Args::override_rgb`], then it is used to override the RGB color,
80    /// otherwise it is used to apply HSV modifications.
81    ///
82    /// Normally hue (H) is in [0, 1], saturation (S) and value (V) are in [0, 2].
83    /// This function adds the hue and multiplies saturation and value.
84    #[arg(
85        long,
86        allow_hyphen_values = true,
87        num_args = 3,
88        value_delimiter = ',',
89        default_value = "0.0,1.0,1.0"
90    )]
91    rgb_or_hsv: Vec<f32>,
92
93    /// Alpha is multiplied with the original alpha.
94    #[arg(long, allow_hyphen_values = true, default_value = "1.0")]
95    alpha: f32,
96
97    /// Contrast is applied to the RGB color.
98    ///
99    /// Normally the range is [-1, 1].
100    #[arg(long, allow_hyphen_values = true, default_value = "0.0")]
101    contrast: f32,
102
103    /// Exposure is applied to the RGB color.
104    ///
105    /// Normally the range is [-5, 5].
106    #[arg(long, allow_hyphen_values = true, default_value = "0.0")]
107    exposure: f32,
108
109    /// Gamma is applied to the RGB color.
110    ///
111    /// Normally the range is [0, 5].
112    #[arg(long, allow_hyphen_values = true, default_value = "1.0")]
113    gamma: f32,
114}
115
116#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
117enum Shape {
118    Sphere,
119    Box,
120}
121
122type GaussianPod = gs::core::GaussianPodWithShSingleCov3dRotScaleConfigs;
123
124#[tokio::main]
125async fn main() {
126    env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
127
128    let args = Args::parse();
129    let model_path = &args.model;
130    let pos = Vec3::from_slice(&args.pos);
131    let rot = Quat::from_slice(&args.rot);
132    let scale = Vec3::from_slice(&args.scale);
133    let shape = match args.shape {
134        Shape::Sphere => gs::SelectionBundle::<GaussianPod>::create_sphere_bundle,
135        Shape::Box => gs::SelectionBundle::<GaussianPod>::create_box_bundle,
136    };
137    let repeat = args.repeat;
138    let offset = Vec3::from_slice(&args.offset);
139
140    log::debug!("Creating wgpu instance");
141    let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor::default());
142
143    log::debug!("Requesting adapter");
144    let adapter = instance
145        .request_adapter(&wgpu::RequestAdapterOptions::default())
146        .await
147        .expect("adapter");
148
149    log::debug!("Requesting device");
150    let (device, queue) = adapter
151        .request_device(&wgpu::DeviceDescriptor {
152            label: Some("Device"),
153            required_features: wgpu::Features::empty(),
154            required_limits: adapter.limits(),
155            memory_hints: wgpu::MemoryHints::default(),
156            trace: wgpu::Trace::Off,
157        })
158        .await
159        .expect("device");
160
161    log::debug!("Creating gaussians");
162    let f = std::fs::File::open(model_path).expect("ply file");
163    let mut reader = std::io::BufReader::new(f);
164    let gaussians = gs::core::Gaussians::read_ply(&mut reader).expect("gaussians");
165
166    log::debug!("Creating editor");
167    let editor = gs::Editor::<GaussianPod>::new(&device, &gaussians);
168
169    log::debug!("Creating shape selection compute bundle");
170    let shape_selection = shape(&device);
171
172    log::debug!("Creating basic selection modifier");
173    let mut basic_selection_modifier = gs::SelectionModifier::new_with_basic_modifier(
174        &device,
175        &editor.gaussians_buffer,
176        &editor.model_transform_buffer,
177        &editor.gaussian_transform_buffer,
178        vec![shape_selection],
179    );
180
181    log::debug!("Configuring modifiers");
182    match args.override_rgb {
183        true => basic_selection_modifier
184            .modifier
185            .basic_color_modifiers_buffer
186            .update_with_override_rgb(
187                &queue,
188                Vec3::from_slice(&args.rgb_or_hsv),
189                args.alpha,
190                args.contrast,
191                args.exposure,
192                args.gamma,
193            ),
194        false => basic_selection_modifier
195            .modifier
196            .basic_color_modifiers_buffer
197            .update_with_hsv_modifiers(
198                &queue,
199                Vec3::from_slice(&args.rgb_or_hsv),
200                args.alpha,
201                args.contrast,
202                args.exposure,
203                args.gamma,
204            ),
205    }
206
207    log::debug!("Creating shape selection buffers");
208    let shape_selection_buffers = (0..repeat)
209        .map(|i| {
210            let offset_pos = pos + offset * i as f32;
211            let buffer = gs::InvTransformBuffer::new(&device);
212            buffer.update_with_scale_rot_pos(&queue, scale, rot, offset_pos);
213            buffer
214        })
215        .collect::<Vec<_>>();
216
217    log::debug!("Creating shape selection bind groups");
218    let shape_selection_bind_groups = shape_selection_buffers
219        .iter()
220        .map(|buffer| {
221            basic_selection_modifier.selection.bundles[0]
222                .create_bind_group(
223                    &device,
224                    // index 0 is the Gaussians buffer, so we use 1,
225                    // see docs of create_sphere_bundle or create_box_bundle
226                    1,
227                    [buffer.buffer().as_entire_binding()],
228                )
229                .expect("bind group")
230        })
231        .collect::<Vec<_>>();
232
233    log::debug!("Creating selection expression");
234    basic_selection_modifier.selection_expr = shape_selection_bind_groups.into_iter().fold(
235        gs::SelectionExpr::Identity,
236        |acc, bind_group| {
237            acc.union(gs::SelectionExpr::selection(
238                0, // the 0 here is the bundle index in the selection bundle
239                vec![bind_group],
240            ))
241        },
242    );
243
244    log::info!("Starting editing process");
245    let time = std::time::Instant::now();
246
247    log::debug!("Editing Gaussians");
248    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
249        label: Some("Edit Encoder"),
250    });
251
252    editor.apply(
253        &device,
254        &mut encoder,
255        [&basic_selection_modifier as &dyn gs::Modifier<GaussianPod>],
256    );
257
258    queue.submit(Some(encoder.finish()));
259
260    #[allow(unused_must_use)]
261    device.poll(wgpu::PollType::Wait);
262
263    log::info!("Editing process completed in {:?}", time.elapsed());
264
265    log::debug!("Downloading Gaussians");
266    let modified_gaussians = gs::core::Gaussians {
267        gaussians: editor
268            .gaussians_buffer
269            .download_gaussians(&device, &queue)
270            .await
271            .expect("gaussians download"),
272    };
273
274    log::debug!("Writing modified Gaussians to output file");
275    let output_file = std::fs::File::create(&args.output).expect("output file");
276    let mut writer = std::io::BufWriter::new(output_file);
277    modified_gaussians
278        .write_ply(&mut writer)
279        .expect("write modified Gaussians to output file");
280
281    log::info!("Modified Gaussians written to {}", args.output);
282}