1use clap::{Parser, ValueEnum};
17use glam::*;
18
19use wgpu::util::DeviceExt;
20use wgpu_3dgs_editor::{
21 self as gs,
22 core::{BufferWrapper, GaussianPod as _},
23};
24
25#[derive(Parser, Debug)]
27#[command(
28 version,
29 about,
30 long_about = "\
31 A 3D Gaussian splatting editor to apply custom modifier to Gaussians in a model selected by a cylinder along the z-axis.
32 "
33)]
34struct Args {
35 #[arg(short, long, default_value = "examples/model.ply")]
37 model: String,
38
39 #[arg(short, long, default_value = "target/output.ply")]
41 output: String,
42
43 #[arg(
45 short,
46 long,
47 allow_hyphen_values = true,
48 num_args = 3,
49 value_delimiter = ',',
50 default_value = "0.0,0.0,0.0"
51 )]
52 pos: Vec<f32>,
53
54 #[arg(short, long, default_value_t = 3.0)]
56 radius: f32,
57}
58
59#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
60enum Factory {
61 Struct,
62 Closure,
63}
64
// Concrete Gaussian POD layout used throughout this example: the editor, the
// selection bundle, and the WESL shader features are all instantiated with it.
// NOTE(review): judging only by the name, this is the "SH single" config with 3D
// covariance stored as rotation + scale; confirm against the wgpu_3dgs_core docs.
type GaussianPod = gs::core::GaussianPodWithShSingleCov3dRotScaleConfigs;
66
67#[pollster::main]
68async fn main() {
69 env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
70
71 let args = Args::parse();
72 let model_path = &args.model;
73 let pos = Vec3::from_slice(&args.pos);
74 let radius = args.radius;
75
76 log::debug!("Creating wgpu instance");
77 let instance =
78 wgpu::Instance::new(wgpu::InstanceDescriptor::new_without_display_handle_from_env());
79
80 log::debug!("Requesting adapter");
81 let adapter = instance
82 .request_adapter(&wgpu::RequestAdapterOptions::default())
83 .await
84 .expect("adapter");
85
86 log::debug!("Requesting device");
87 let (device, queue) = adapter
88 .request_device(&wgpu::DeviceDescriptor {
89 label: Some("Device"),
90 required_limits: adapter.limits(),
91 ..Default::default()
92 })
93 .await
94 .expect("device");
95
96 log::debug!("Creating gaussians");
97 let gaussians = [
98 gs::core::GaussiansSource::Ply,
99 gs::core::GaussiansSource::Spz,
100 ]
101 .into_iter()
102 .find_map(|source| gs::core::Gaussians::read_from_file(model_path, source).ok())
103 .expect("gaussians");
104
105 log::debug!("Creating editor");
106 let editor = gs::Editor::<GaussianPod>::new(&device, &gaussians);
107
108 log::debug!("Creating buffers");
109 let pos_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
110 label: Some("Position Buffer"),
111 contents: bytemuck::bytes_of(&pos),
112 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
113 });
114
115 let radius_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
116 label: Some("Radius Buffer"),
117 contents: bytemuck::bytes_of(&radius),
118 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
119 });
120
121 const BIND_GROUP_LAYOUT_DESCRIPTOR: wgpu::BindGroupLayoutDescriptor =
122 wgpu::BindGroupLayoutDescriptor {
123 label: Some("Bind Group Layout"),
124 entries: &[
125 wgpu::BindGroupLayoutEntry {
127 binding: 0,
128 visibility: wgpu::ShaderStages::COMPUTE,
129 ty: wgpu::BindingType::Buffer {
130 ty: wgpu::BufferBindingType::Uniform,
131 has_dynamic_offset: false,
132 min_binding_size: None,
133 },
134 count: None,
135 },
136 wgpu::BindGroupLayoutEntry {
138 binding: 1,
139 visibility: wgpu::ShaderStages::COMPUTE,
140 ty: wgpu::BindingType::Buffer {
141 ty: wgpu::BufferBindingType::Uniform,
142 has_dynamic_offset: false,
143 min_binding_size: None,
144 },
145 count: None,
146 },
147 wgpu::BindGroupLayoutEntry {
149 binding: 2,
150 visibility: wgpu::ShaderStages::COMPUTE,
151 ty: wgpu::BindingType::Buffer {
152 ty: wgpu::BufferBindingType::Storage { read_only: true },
153 has_dynamic_offset: false,
154 min_binding_size: None,
155 },
156 count: None,
157 },
158 ],
159 };
160
161 log::debug!("Creating cylinder selection compute bundle");
162 let cylinder_selection_bundle = gs::core::ComputeBundleBuilder::new()
163 .label("Selection")
164 .bind_group_layouts([
165 &gs::SelectionBundle::<GaussianPod>::GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR,
166 &wgpu::BindGroupLayoutDescriptor {
167 entries: &BIND_GROUP_LAYOUT_DESCRIPTOR.entries[..2],
168 ..BIND_GROUP_LAYOUT_DESCRIPTOR
169 },
170 ])
171 .main_shader("package::selection".parse().unwrap())
172 .entry_point("main")
173 .wesl_compile_options(wesl::CompileOptions {
174 features: GaussianPod::wesl_features(),
175 ..Default::default()
176 })
177 .resolver({
178 let mut resolver =
179 wesl::StandardResolver::new("examples/shader/custom_modify_selection");
180 resolver.add_package(&gs::core::shader::PACKAGE);
181 resolver
182 })
183 .build_without_bind_groups(&device)
184 .map_err(|e| log::error!("{e}"))
185 .expect("selection bundle");
186
187 log::debug!("Creating custom modifier");
188 #[allow(dead_code)]
189 let modifier_factory = |selection_buffer: &gs::SelectionBuffer| {
190 log::debug!("Creating custom modifier compute bundle");
191 let modifier_bundle = gs::core::ComputeBundleBuilder::new()
192 .label("Modifier")
193 .bind_group_layouts([
194 &gs::MODIFIER_GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR,
195 &BIND_GROUP_LAYOUT_DESCRIPTOR,
196 ])
197 .resolver({
198 let mut resolver =
199 wesl::StandardResolver::new("examples/shader/custom_modify_selection");
200 resolver.add_package(&gs::core::shader::PACKAGE);
201 resolver.add_package(&gs::shader::PACKAGE);
202 resolver
203 })
204 .main_shader("package::modifier".parse().unwrap())
205 .entry_point("main")
206 .wesl_compile_options(wesl::CompileOptions {
207 features: GaussianPod::wesl_features(),
208 ..Default::default()
209 })
210 .build(
211 &device,
212 [
213 vec![
214 editor.gaussians_buffer.buffer().as_entire_binding(),
215 editor.model_transform_buffer.buffer().as_entire_binding(),
216 editor
217 .gaussian_transform_buffer
218 .buffer()
219 .as_entire_binding(),
220 ],
221 vec![
222 pos_buffer.as_entire_binding(),
223 radius_buffer.as_entire_binding(),
224 selection_buffer.buffer().as_entire_binding(),
225 ],
226 ],
227 )
228 .map_err(|e| log::error!("{e}"))
229 .expect("modifier bundle");
230
231 move |_device: &wgpu::Device,
233 encoder: &mut wgpu::CommandEncoder,
234 gaussians: &gs::core::GaussiansBuffer<GaussianPod>,
235 _model_transform: &gs::core::ModelTransformBuffer,
236 _gaussian_transform: &gs::core::GaussianTransformBuffer| {
237 modifier_bundle.dispatch(encoder, gaussians.len() as u32);
238 }
239 };
240
241 #[allow(dead_code)]
242 struct Modifier<G: gs::core::GaussianPod>(gs::core::ComputeBundle, std::marker::PhantomData<G>);
243
244 impl<G: gs::core::GaussianPod> Modifier<G> {
245 #[allow(dead_code)]
246 fn new(
247 device: &wgpu::Device,
248 editor: &gs::Editor<G>,
249 pos_buffer: &wgpu::Buffer,
250 radius_buffer: &wgpu::Buffer,
251 selection_buffer: &gs::SelectionBuffer,
252 ) -> Self {
253 log::debug!("Creating custom modifier compute bundle");
254 let modifier_bundle = gs::core::ComputeBundleBuilder::new()
255 .label("Modifier")
256 .bind_group_layouts([
257 &gs::MODIFIER_GAUSSIANS_BIND_GROUP_LAYOUT_DESCRIPTOR,
258 &BIND_GROUP_LAYOUT_DESCRIPTOR,
259 ])
260 .resolver({
261 let mut resolver =
262 wesl::StandardResolver::new("examples/shader/custom_modify_selection");
263 resolver.add_package(&gs::core::shader::PACKAGE);
264 resolver.add_package(&gs::shader::PACKAGE);
265 resolver
266 })
267 .main_shader("package::modifier".parse().unwrap())
268 .entry_point("main")
269 .wesl_compile_options(wesl::CompileOptions {
270 features: GaussianPod::wesl_features(),
271 ..Default::default()
272 })
273 .build(
274 device,
275 [
276 vec![
277 editor.gaussians_buffer.buffer().as_entire_binding(),
278 editor.model_transform_buffer.buffer().as_entire_binding(),
279 editor
280 .gaussian_transform_buffer
281 .buffer()
282 .as_entire_binding(),
283 ],
284 vec![
285 pos_buffer.as_entire_binding(),
286 radius_buffer.as_entire_binding(),
287 selection_buffer.buffer().as_entire_binding(),
288 ],
289 ],
290 )
291 .map_err(|e| log::error!("{e}"))
292 .expect("modifier bundle");
293
294 Self(modifier_bundle, std::marker::PhantomData)
295 }
296 }
297
298 impl<G: gs::core::GaussianPod> gs::Modifier<G> for Modifier<G> {
299 fn apply(
300 &self,
301 _device: &wgpu::Device,
302 encoder: &mut wgpu::CommandEncoder,
303 gaussians: &gs::core::GaussiansBuffer<G>,
304 _model_transform: &gs::core::ModelTransformBuffer,
305 _gaussian_transform: &gs::core::GaussianTransformBuffer,
306 ) {
307 self.0.dispatch(encoder, gaussians.len() as u32);
308 }
309 }
310
311 log::debug!("Creating selection modifier");
312 let mut selection_modifier = gs::SelectionModifier::<GaussianPod, _>::new(
313 &device,
314 &editor.gaussians_buffer,
315 vec![cylinder_selection_bundle],
316 modifier_factory,
317 );
328
329 log::debug!("Creating selection expression");
330 selection_modifier.selection_expr = gs::SelectionExpr::selection(
331 0,
332 vec![
333 selection_modifier.selection.bundles[0]
334 .create_bind_group(
335 &device,
336 1, [
338 pos_buffer.as_entire_binding(),
339 radius_buffer.as_entire_binding(),
340 ],
341 )
342 .expect("selection expr bind group"),
343 ],
344 );
345
346 log::info!("Starting editing process");
347 let time = std::time::Instant::now();
348
349 log::debug!("Editing Gaussians");
350 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
351 label: Some("Edit Encoder"),
352 });
353
354 editor.apply(
355 &device,
356 &mut encoder,
357 [&selection_modifier as &dyn gs::Modifier<GaussianPod>],
358 );
359
360 queue.submit(Some(encoder.finish()));
361
362 device
363 .poll(wgpu::PollType::wait_indefinitely())
364 .expect("poll");
365
366 log::info!("Editing process completed in {:?}", time.elapsed());
367
368 log::debug!("Downloading Gaussians");
369 let modified_gaussians = editor
370 .gaussians_buffer
371 .download_gaussians(&device, &queue)
372 .await
373 .map(|gs| {
374 match &args.output[args.output.len().saturating_sub(4)..] {
375 ".ply" => {
376 gs::core::Gaussians::Ply(gs::core::PlyGaussians::from_iter(gs.into_iter()))
377 }
378 ".spz" => {
379 gs::core::Gaussians::Spz(
380 gs::core::SpzGaussians::from_gaussians_with_options(
381 gs,
382 &gs::core::SpzGaussiansFromGaussianSliceOptions {
383 version: 2, ..Default::default()
385 },
386 )
387 .expect("SpzGaussians from gaussians"),
388 )
389 }
390 _ => panic!("Unsupported output file extension, expected .ply or .spz"),
391 }
392 })
393 .expect("gaussians download");
394
395 log::debug!("Writing modified Gaussians to output file");
396 modified_gaussians
397 .write_to_file(&args.output)
398 .expect("write modified Gaussians to output file");
399
400 log::info!("Modified Gaussians written to {}", args.output);
401}