1use clap::{Parser, ValueEnum};
17use glam::*;
18
19use wgpu::util::DeviceExt;
20use wgpu_3dgs_editor::{
21 self as gs,
22 core::{BufferWrapper, GaussianPod as _},
23};
24
25#[derive(Parser, Debug)]
27#[command(
28 version,
29 about,
30 long_about = "\
31 A 3D Gaussian splatting editor to apply custom modifier to Gaussians in a model selected by a cylinder along the z-axis.
32 "
33)]
34struct Args {
35 #[arg(short, long, default_value = "examples/model.ply")]
37 model: String,
38
39 #[arg(short, long, default_value = "target/output.ply")]
41 output: String,
42
43 #[arg(
45 short,
46 long,
47 allow_hyphen_values = true,
48 num_args = 3,
49 value_delimiter = ',',
50 default_value = "0.0,0.0,0.0"
51 )]
52 pos: Vec<f32>,
53
54 #[arg(short, long, default_value_t = 3.0)]
56 radius: f32,
57}
58
59#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
60enum Factory {
61 Struct,
62 Closure,
63}
64
65type GaussianPod = gs::core::GaussianPodWithShSingleCov3dRotScaleConfigs;
66
67#[pollster::main]
68async fn main() {
69 env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
70
71 let args = Args::parse();
72 let model_path = &args.model;
73 let pos = Vec3::from_slice(&args.pos);
74 let radius = args.radius;
75
76 log::debug!("Creating wgpu instance");
77 let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor::default());
78
79 log::debug!("Requesting adapter");
80 let adapter = instance
81 .request_adapter(&wgpu::RequestAdapterOptions::default())
82 .await
83 .expect("adapter");
84
85 log::debug!("Requesting device");
86 let (device, queue) = adapter
87 .request_device(&wgpu::DeviceDescriptor {
88 label: Some("Device"),
89 required_limits: adapter.limits(),
90 ..Default::default()
91 })
92 .await
93 .expect("device");
94
95 log::debug!("Creating gaussians");
96 let gaussians = [
97 gs::core::GaussiansSource::Ply,
98 gs::core::GaussiansSource::Spz,
99 ]
100 .into_iter()
101 .find_map(|source| gs::core::Gaussians::read_from_file(model_path, source).ok())
102 .expect("gaussians");
103
104 log::debug!("Creating editor");
105 let editor = gs::Editor::<GaussianPod>::new(&device, &gaussians);
106
107 log::debug!("Creating buffers");
108 let pos_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
109 label: Some("Position Buffer"),
110 contents: bytemuck::bytes_of(&pos),
111 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
112 });
113
114 let radius_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
115 label: Some("Radius Buffer"),
116 contents: bytemuck::bytes_of(&radius),
117 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
118 });
119
120 const BIND_GROUP_LAYOUT_DESCRIPTOR: wgpu::BindGroupLayoutDescriptor =
121 wgpu::BindGroupLayoutDescriptor {
122 label: Some("Bind Group Layout"),
123 entries: &[
124 wgpu::BindGroupLayoutEntry {
126 binding: 0,
127 visibility: wgpu::ShaderStages::COMPUTE,
128 ty: wgpu::BindingType::Buffer {
129 ty: wgpu::BufferBindingType::Uniform,
130 has_dynamic_offset: false,
131 min_binding_size: None,
132 },
133 count: None,
134 },
135 wgpu::BindGroupLayoutEntry {
137 binding: 1,
138 visibility: wgpu::ShaderStages::COMPUTE,
139 ty: wgpu::BindingType::Buffer {
140 ty: wgpu::BufferBindingType::Uniform,
141 has_dynamic_offset: false,
142 min_binding_size: None,
143 },
144 count: None,
145 },
146 wgpu::BindGroupLayoutEntry {
148 binding: 2,
149 visibility: wgpu::ShaderStages::COMPUTE,
150 ty: wgpu::BindingType::Buffer {
151 ty: wgpu::BufferBindingType::Storage { read_only: true },
152 has_dynamic_offset: false,
153 min_binding_size: None,
154 },
155 count: None,
156 },
157 ],
158 };
159
160 log::debug!("Creating cylinder selection compute bundle");
161 let cylinder_selection_bundle = gs::core::ComputeBundleBuilder::new()
162 .label("Selection")
163 .bind_group_layouts([
164 &gs::SelectionBundle::<GaussianPod>::GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR,
165 &wgpu::BindGroupLayoutDescriptor {
166 entries: &BIND_GROUP_LAYOUT_DESCRIPTOR.entries[..2],
167 ..BIND_GROUP_LAYOUT_DESCRIPTOR
168 },
169 ])
170 .main_shader("package::selection".parse().unwrap())
171 .entry_point("main")
172 .wesl_compile_options(wesl::CompileOptions {
173 features: GaussianPod::wesl_features(),
174 ..Default::default()
175 })
176 .resolver({
177 let mut resolver =
178 wesl::StandardResolver::new("examples/shader/custom_modify_selection");
179 resolver.add_package(&gs::core::shader::PACKAGE);
180 resolver
181 })
182 .build_without_bind_groups(&device)
183 .map_err(|e| log::error!("{e}"))
184 .expect("selection bundle");
185
186 log::debug!("Creating custom modifier");
187 #[allow(dead_code)]
188 let modifier_factory = |selection_buffer: &gs::SelectionBuffer| {
189 log::debug!("Creating custom modifier compute bundle");
190 let modifier_bundle = gs::core::ComputeBundleBuilder::new()
191 .label("Modifier")
192 .bind_group_layouts([
193 &gs::MODIFIER_GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR,
194 &BIND_GROUP_LAYOUT_DESCRIPTOR,
195 ])
196 .resolver({
197 let mut resolver =
198 wesl::StandardResolver::new("examples/shader/custom_modify_selection");
199 resolver.add_package(&gs::core::shader::PACKAGE);
200 resolver.add_package(&gs::shader::PACKAGE);
201 resolver
202 })
203 .main_shader("package::modifier".parse().unwrap())
204 .entry_point("main")
205 .wesl_compile_options(wesl::CompileOptions {
206 features: GaussianPod::wesl_features(),
207 ..Default::default()
208 })
209 .build(
210 &device,
211 [
212 vec![
213 editor.gaussians_buffer.buffer().as_entire_binding(),
214 editor.model_transform_buffer.buffer().as_entire_binding(),
215 editor
216 .gaussian_transform_buffer
217 .buffer()
218 .as_entire_binding(),
219 ],
220 vec![
221 pos_buffer.as_entire_binding(),
222 radius_buffer.as_entire_binding(),
223 selection_buffer.buffer().as_entire_binding(),
224 ],
225 ],
226 )
227 .map_err(|e| log::error!("{e}"))
228 .expect("modifier bundle");
229
230 move |_device: &wgpu::Device,
232 encoder: &mut wgpu::CommandEncoder,
233 gaussians: &gs::core::GaussiansBuffer<GaussianPod>,
234 _model_transform: &gs::core::ModelTransformBuffer,
235 _gaussian_transform: &gs::core::GaussianTransformBuffer| {
236 modifier_bundle.dispatch(encoder, gaussians.len() as u32);
237 }
238 };
239
240 #[allow(dead_code)]
241 struct Modifier<G: gs::core::GaussianPod>(gs::core::ComputeBundle, std::marker::PhantomData<G>);
242
243 impl<G: gs::core::GaussianPod> Modifier<G> {
244 #[allow(dead_code)]
245 fn new(
246 device: &wgpu::Device,
247 editor: &gs::Editor<G>,
248 pos_buffer: &wgpu::Buffer,
249 radius_buffer: &wgpu::Buffer,
250 selection_buffer: &gs::SelectionBuffer,
251 ) -> Self {
252 log::debug!("Creating custom modifier compute bundle");
253 let modifier_bundle = gs::core::ComputeBundleBuilder::new()
254 .label("Modifier")
255 .bind_group_layouts([
256 &gs::MODIFIER_GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR,
257 &BIND_GROUP_LAYOUT_DESCRIPTOR,
258 ])
259 .resolver({
260 let mut resolver =
261 wesl::StandardResolver::new("examples/shader/custom_modify_selection");
262 resolver.add_package(&gs::core::shader::PACKAGE);
263 resolver.add_package(&gs::shader::PACKAGE);
264 resolver
265 })
266 .main_shader("package::modifier".parse().unwrap())
267 .entry_point("main")
268 .wesl_compile_options(wesl::CompileOptions {
269 features: GaussianPod::wesl_features(),
270 ..Default::default()
271 })
272 .build(
273 device,
274 [
275 vec![
276 editor.gaussians_buffer.buffer().as_entire_binding(),
277 editor.model_transform_buffer.buffer().as_entire_binding(),
278 editor
279 .gaussian_transform_buffer
280 .buffer()
281 .as_entire_binding(),
282 ],
283 vec![
284 pos_buffer.as_entire_binding(),
285 radius_buffer.as_entire_binding(),
286 selection_buffer.buffer().as_entire_binding(),
287 ],
288 ],
289 )
290 .map_err(|e| log::error!("{e}"))
291 .expect("modifier bundle");
292
293 Self(modifier_bundle, std::marker::PhantomData)
294 }
295 }
296
297 impl<G: gs::core::GaussianPod> gs::Modifier<G> for Modifier<G> {
298 fn apply(
299 &self,
300 _device: &wgpu::Device,
301 encoder: &mut wgpu::CommandEncoder,
302 gaussians: &gs::core::GaussiansBuffer<G>,
303 _model_transform: &gs::core::ModelTransformBuffer,
304 _gaussian_transform: &gs::core::GaussianTransformBuffer,
305 ) {
306 self.0.dispatch(encoder, gaussians.len() as u32);
307 }
308 }
309
310 log::debug!("Creating selection modifier");
311 let mut selection_modifier = gs::SelectionModifier::<GaussianPod, _>::new(
312 &device,
313 &editor.gaussians_buffer,
314 vec![cylinder_selection_bundle],
315 modifier_factory,
316 );
327
328 log::debug!("Creating selection expression");
329 selection_modifier.selection_expr = gs::SelectionExpr::selection(
330 0,
331 vec![
332 selection_modifier.selection.bundles[0]
333 .create_bind_group(
334 &device,
335 1, [
337 pos_buffer.as_entire_binding(),
338 radius_buffer.as_entire_binding(),
339 ],
340 )
341 .expect("selection expr bind group"),
342 ],
343 );
344
345 log::info!("Starting editing process");
346 let time = std::time::Instant::now();
347
348 log::debug!("Editing Gaussians");
349 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
350 label: Some("Edit Encoder"),
351 });
352
353 editor.apply(
354 &device,
355 &mut encoder,
356 [&selection_modifier as &dyn gs::Modifier<GaussianPod>],
357 );
358
359 queue.submit(Some(encoder.finish()));
360
361 device
362 .poll(wgpu::PollType::wait_indefinitely())
363 .expect("poll");
364
365 log::info!("Editing process completed in {:?}", time.elapsed());
366
367 log::debug!("Downloading Gaussians");
368 let modified_gaussians = editor
369 .gaussians_buffer
370 .download_gaussians(&device, &queue)
371 .await
372 .map(|gs| {
373 match &args.output[args.output.len().saturating_sub(4)..] {
374 ".ply" => {
375 gs::core::Gaussians::Ply(gs::core::PlyGaussians::from_iter(gs.into_iter()))
376 }
377 ".spz" => {
378 gs::core::Gaussians::Spz(
379 gs::core::SpzGaussians::from_gaussians_with_options(
380 gs,
381 &gs::core::SpzGaussiansFromGaussianSliceOptions {
382 version: 2, ..Default::default()
384 },
385 )
386 .expect("SpzGaussians from gaussians"),
387 )
388 }
389 _ => panic!("Unsupported output file extension, expected .ply or .spz"),
390 }
391 })
392 .expect("gaussians download");
393
394 log::debug!("Writing modified Gaussians to output file");
395 modified_gaussians
396 .write_to_file(&args.output)
397 .expect("write modified Gaussians to output file");
398
399 log::info!("Modified Gaussians written to {}", args.output);
400}