1use clap::{Parser, ValueEnum};
14use glam::*;
15
16use wgpu_3dgs_editor::{self as gs, core::BufferWrapper};
17
18#[derive(Parser, Debug)]
20#[command(
21 version,
22 about,
23 long_about = "\
24 A 3D Gaussian splatting editor to apply basic modifier to selected Gaussians in a model.
25 "
26)]
27struct Args {
28 #[arg(short, long, default_value = "examples/model.ply")]
30 model: String,
31
32 #[arg(short, long, default_value = "target/output.ply")]
34 output: String,
35
36 #[arg(
38 short,
39 long,
40 allow_hyphen_values = true,
41 num_args = 3,
42 value_delimiter = ',',
43 default_value = "0.0,0.0,0.0"
44 )]
45 pos: Vec<f32>,
46
47 #[arg(
49 short,
50 long,
51 allow_hyphen_values = true,
52 num_args = 4,
53 value_delimiter = ',',
54 default_value = "0.0,0.0,0.0,1.0"
55 )]
56 rot: Vec<f32>,
57
58 #[arg(
60 short,
61 long,
62 allow_hyphen_values = true,
63 num_args = 3,
64 value_delimiter = ',',
65 default_value = "0.5,1.0,2.0"
66 )]
67 scale: Vec<f32>,
68
69 #[arg(long, value_enum, default_value_t = Shape::Sphere, ignore_case = true)]
71 shape: Shape,
72
73 #[arg(long, default_value = "1")]
75 repeat: u32,
76
77 #[arg(
79 long,
80 allow_hyphen_values = true,
81 num_args = 3,
82 value_delimiter = ',',
83 default_value = "2.0,0.0,0.0"
84 )]
85 offset: Vec<f32>,
86
87 #[arg(long)]
89 override_rgb: bool,
90
91 #[arg(
97 long,
98 allow_hyphen_values = true,
99 num_args = 3,
100 value_delimiter = ',',
101 default_value = "0.0,1.0,1.0"
102 )]
103 rgb_or_hsv: Vec<f32>,
104
105 #[arg(long, allow_hyphen_values = true, default_value = "1.0")]
107 alpha: f32,
108
109 #[arg(long, allow_hyphen_values = true, default_value = "0.0")]
113 contrast: f32,
114
115 #[arg(long, allow_hyphen_values = true, default_value = "0.0")]
119 exposure: f32,
120
121 #[arg(long, allow_hyphen_values = true, default_value = "1.0")]
125 gamma: f32,
126}
127
128#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
129enum Shape {
130 Sphere,
131 Box,
132}
133
134type GaussianPod = gs::core::GaussianPodWithShSingleCov3dRotScaleConfigs;
135
136#[pollster::main]
137async fn main() {
138 env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
139
140 let args = Args::parse();
141 let model_path = &args.model;
142 let pos = Vec3::from_slice(&args.pos);
143 let rot = Quat::from_slice(&args.rot);
144 let scale = Vec3::from_slice(&args.scale);
145 let shape = match args.shape {
146 Shape::Sphere => gs::SelectionBundle::<GaussianPod>::create_sphere_bundle,
147 Shape::Box => gs::SelectionBundle::<GaussianPod>::create_box_bundle,
148 };
149 let repeat = args.repeat;
150 let offset = Vec3::from_slice(&args.offset);
151
152 log::debug!("Creating wgpu instance");
153 let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor::default());
154
155 log::debug!("Requesting adapter");
156 let adapter = instance
157 .request_adapter(&wgpu::RequestAdapterOptions::default())
158 .await
159 .expect("adapter");
160
161 log::debug!("Requesting device");
162 let (device, queue) = adapter
163 .request_device(&wgpu::DeviceDescriptor {
164 label: Some("Device"),
165 required_limits: adapter.limits(),
166 ..Default::default()
167 })
168 .await
169 .expect("device");
170
171 log::debug!("Creating gaussians");
172 let gaussians = [
173 gs::core::GaussiansSource::Ply,
174 gs::core::GaussiansSource::Spz,
175 ]
176 .into_iter()
177 .find_map(|source| gs::core::Gaussians::read_from_file(model_path, source).ok())
178 .expect("gaussians");
179
180 log::debug!("Creating editor");
181 let editor = gs::Editor::<GaussianPod>::new(&device, &gaussians);
182
183 log::debug!("Creating shape selection compute bundle");
184 let shape_selection = shape(&device);
185
186 log::debug!("Creating basic selection modifier");
187 let mut basic_selection_modifier = gs::SelectionModifier::new_with_basic_modifier(
188 &device,
189 &editor.gaussians_buffer,
190 &editor.model_transform_buffer,
191 &editor.gaussian_transform_buffer,
192 vec![shape_selection],
193 );
194
195 log::debug!("Configuring modifiers");
196 basic_selection_modifier
197 .modifier
198 .basic_color_modifiers_buffer
199 .update(
200 &queue,
201 match args.override_rgb {
202 true => gs::BasicColorRgbOverrideOrHsvModifiersPod::new_rgb_override,
203 false => gs::BasicColorRgbOverrideOrHsvModifiersPod::new_hsv_modifiers,
204 }(Vec3::from_slice(&args.rgb_or_hsv)),
205 args.alpha,
206 args.contrast,
207 args.exposure,
208 args.gamma,
209 );
210
211 log::debug!("Creating shape selection buffers");
212 let shape_selection_buffers = (0..repeat)
213 .map(|i| {
214 let offset_pos = pos + offset * i as f32;
215 let buffer = gs::InvTransformBuffer::new(&device);
216 buffer.update_with_scale_rot_pos(&queue, scale, rot, offset_pos);
217 buffer
218 })
219 .collect::<Vec<_>>();
220
221 log::debug!("Creating shape selection bind groups");
222 let shape_selection_bind_groups = shape_selection_buffers
223 .iter()
224 .map(|buffer| {
225 basic_selection_modifier.selection.bundles[0]
226 .create_bind_group(
227 &device,
228 1,
231 [buffer.buffer().as_entire_binding()],
232 )
233 .expect("bind group")
234 })
235 .collect::<Vec<_>>();
236
237 log::debug!("Creating selection expression");
238 basic_selection_modifier.selection_expr = shape_selection_bind_groups.into_iter().fold(
239 gs::SelectionExpr::Identity,
240 |acc, bind_group| {
241 acc.union(gs::SelectionExpr::selection(
242 0, vec![bind_group],
244 ))
245 },
246 );
247
248 log::info!("Starting editing process");
249 let time = std::time::Instant::now();
250
251 log::debug!("Editing Gaussians");
252 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
253 label: Some("Edit Encoder"),
254 });
255
256 editor.apply(
257 &device,
258 &mut encoder,
259 [&basic_selection_modifier as &dyn gs::Modifier<GaussianPod>],
260 );
261
262 queue.submit(Some(encoder.finish()));
263
264 device
265 .poll(wgpu::PollType::wait_indefinitely())
266 .expect("poll");
267
268 log::info!("Editing process completed in {:?}", time.elapsed());
269
270 log::debug!("Downloading Gaussians");
271 let modified_gaussians = editor
272 .gaussians_buffer
273 .download_gaussians(&device, &queue)
274 .await
275 .map(|gs| {
276 match &args.output[args.output.len().saturating_sub(4)..] {
277 ".ply" => {
278 gs::core::Gaussians::Ply(gs::core::PlyGaussians::from_iter(gs.into_iter()))
279 }
280 ".spz" => {
281 gs::core::Gaussians::Spz(
282 gs::core::SpzGaussians::from_gaussians_with_options(
283 gs,
284 &gs::core::SpzGaussiansFromGaussianSliceOptions {
285 version: 2, ..Default::default()
287 },
288 )
289 .expect("SpzGaussians from gaussians"),
290 )
291 }
292 _ => panic!("Unsupported output file extension, expected .ply or .spz"),
293 }
294 })
295 .expect("gaussians download");
296
297 log::debug!("Writing modified Gaussians to output file");
298 modified_gaussians
299 .write_to_file(&args.output)
300 .expect("write modified Gaussians to output file");
301
302 log::info!("Modified Gaussians written to {}", args.output);
303}