1use clap::{Parser, ValueEnum};
2use glam::*;
3
4use wgpu_3dgs_editor::{self as gs, core::BufferWrapper};
5
6#[derive(Parser, Debug)]
8#[command(
9 version,
10 about,
11 long_about = "\
12 A 3D Gaussian splatting editor to apply basic modifier to selected Gaussians in a model.
13 "
14)]
15struct Args {
16 #[arg(short, long)]
18 model: String,
19
20 #[arg(short, long, default_value = "target/output.ply")]
22 output: String,
23
24 #[arg(
26 short,
27 long,
28 allow_hyphen_values = true,
29 num_args = 3,
30 value_delimiter = ',',
31 default_value = "0.0,0.0,0.0"
32 )]
33 pos: Vec<f32>,
34
35 #[arg(
37 short,
38 long,
39 allow_hyphen_values = true,
40 num_args = 4,
41 value_delimiter = ',',
42 default_value = "0.0,0.0,0.0,1.0"
43 )]
44 rot: Vec<f32>,
45
46 #[arg(
48 short,
49 long,
50 allow_hyphen_values = true,
51 num_args = 3,
52 value_delimiter = ',',
53 default_value = "0.5,1.0,2.0"
54 )]
55 scale: Vec<f32>,
56
57 #[arg(long, value_enum, default_value_t = Shape::Sphere, ignore_case = true)]
59 shape: Shape,
60
61 #[arg(long, default_value = "1")]
63 repeat: u32,
64
65 #[arg(
67 long,
68 allow_hyphen_values = true,
69 num_args = 3,
70 value_delimiter = ',',
71 default_value = "2.0,0.0,0.0"
72 )]
73 offset: Vec<f32>,
74
75 #[arg(long)]
77 override_rgb: bool,
78
79 #[arg(
85 long,
86 allow_hyphen_values = true,
87 num_args = 3,
88 value_delimiter = ',',
89 default_value = "0.0,1.0,1.0"
90 )]
91 rgb_or_hsv: Vec<f32>,
92
93 #[arg(long, allow_hyphen_values = true, default_value = "1.0")]
95 alpha: f32,
96
97 #[arg(long, allow_hyphen_values = true, default_value = "0.0")]
101 contrast: f32,
102
103 #[arg(long, allow_hyphen_values = true, default_value = "0.0")]
107 exposure: f32,
108
109 #[arg(long, allow_hyphen_values = true, default_value = "1.0")]
113 gamma: f32,
114}
115
116#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
117enum Shape {
118 Sphere,
119 Box,
120}
121
122type GaussianPod = gs::core::GaussianPodWithShSingleCov3dRotScaleConfigs;
123
124#[tokio::main]
125async fn main() {
126 env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
127
128 let args = Args::parse();
129 let model_path = &args.model;
130 let pos = Vec3::from_slice(&args.pos);
131 let rot = Quat::from_slice(&args.rot);
132 let scale = Vec3::from_slice(&args.scale);
133 let shape = match args.shape {
134 Shape::Sphere => gs::SelectionBundle::<GaussianPod>::create_sphere_bundle,
135 Shape::Box => gs::SelectionBundle::<GaussianPod>::create_box_bundle,
136 };
137 let repeat = args.repeat;
138 let offset = Vec3::from_slice(&args.offset);
139
140 log::debug!("Creating wgpu instance");
141 let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor::default());
142
143 log::debug!("Requesting adapter");
144 let adapter = instance
145 .request_adapter(&wgpu::RequestAdapterOptions::default())
146 .await
147 .expect("adapter");
148
149 log::debug!("Requesting device");
150 let (device, queue) = adapter
151 .request_device(&wgpu::DeviceDescriptor {
152 label: Some("Device"),
153 required_features: wgpu::Features::empty(),
154 required_limits: adapter.limits(),
155 memory_hints: wgpu::MemoryHints::default(),
156 trace: wgpu::Trace::Off,
157 })
158 .await
159 .expect("device");
160
161 log::debug!("Creating gaussians");
162 let f = std::fs::File::open(model_path).expect("ply file");
163 let mut reader = std::io::BufReader::new(f);
164 let gaussians = gs::core::Gaussians::read_ply(&mut reader).expect("gaussians");
165
166 log::debug!("Creating editor");
167 let editor = gs::Editor::<GaussianPod>::new(&device, &gaussians);
168
169 log::debug!("Creating shape selection compute bundle");
170 let shape_selection = shape(&device);
171
172 log::debug!("Creating basic selection modifier");
173 let mut basic_selection_modifier = gs::SelectionModifier::new_with_basic_modifier(
174 &device,
175 &editor.gaussians_buffer,
176 &editor.model_transform_buffer,
177 &editor.gaussian_transform_buffer,
178 vec![shape_selection],
179 );
180
181 log::debug!("Configuring modifiers");
182 match args.override_rgb {
183 true => basic_selection_modifier
184 .modifier
185 .basic_color_modifiers_buffer
186 .update_with_override_rgb(
187 &queue,
188 Vec3::from_slice(&args.rgb_or_hsv),
189 args.alpha,
190 args.contrast,
191 args.exposure,
192 args.gamma,
193 ),
194 false => basic_selection_modifier
195 .modifier
196 .basic_color_modifiers_buffer
197 .update_with_hsv_modifiers(
198 &queue,
199 Vec3::from_slice(&args.rgb_or_hsv),
200 args.alpha,
201 args.contrast,
202 args.exposure,
203 args.gamma,
204 ),
205 }
206
207 log::debug!("Creating shape selection buffers");
208 let shape_selection_buffers = (0..repeat)
209 .map(|i| {
210 let offset_pos = pos + offset * i as f32;
211 let buffer = gs::InvTransformBuffer::new(&device);
212 buffer.update_with_scale_rot_pos(&queue, scale, rot, offset_pos);
213 buffer
214 })
215 .collect::<Vec<_>>();
216
217 log::debug!("Creating shape selection bind groups");
218 let shape_selection_bind_groups = shape_selection_buffers
219 .iter()
220 .map(|buffer| {
221 basic_selection_modifier.selection.bundles[0]
222 .create_bind_group(
223 &device,
224 1,
227 [buffer.buffer().as_entire_binding()],
228 )
229 .expect("bind group")
230 })
231 .collect::<Vec<_>>();
232
233 log::debug!("Creating selection expression");
234 basic_selection_modifier.selection_expr = shape_selection_bind_groups.into_iter().fold(
235 gs::SelectionExpr::Identity,
236 |acc, bind_group| {
237 acc.union(gs::SelectionExpr::selection(
238 0, vec![bind_group],
240 ))
241 },
242 );
243
244 log::info!("Starting editing process");
245 let time = std::time::Instant::now();
246
247 log::debug!("Editing Gaussians");
248 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
249 label: Some("Edit Encoder"),
250 });
251
252 editor.apply(
253 &device,
254 &mut encoder,
255 [&basic_selection_modifier as &dyn gs::Modifier<GaussianPod>],
256 );
257
258 queue.submit(Some(encoder.finish()));
259
260 #[allow(unused_must_use)]
261 device.poll(wgpu::PollType::Wait);
262
263 log::info!("Editing process completed in {:?}", time.elapsed());
264
265 log::debug!("Downloading Gaussians");
266 let modified_gaussians = gs::core::Gaussians {
267 gaussians: editor
268 .gaussians_buffer
269 .download_gaussians(&device, &queue)
270 .await
271 .expect("gaussians download"),
272 };
273
274 log::debug!("Writing modified Gaussians to output file");
275 let output_file = std::fs::File::create(&args.output).expect("output file");
276 let mut writer = std::io::BufWriter::new(output_file);
277 modified_gaussians
278 .write_ply(&mut writer)
279 .expect("write modified Gaussians to output file");
280
281 log::info!("Modified Gaussians written to {}", args.output);
282}