1use clap::{Parser, ValueEnum};
14use glam::*;
15
16use wgpu_3dgs_core::IterGaussian;
17use wgpu_3dgs_editor::{self as gs, core::BufferWrapper};
18
19#[derive(Parser, Debug)]
21#[command(
22 version,
23 about,
24 long_about = "\
25 A 3D Gaussian splatting editor to filter selected Gaussians in a model.
26 "
27)]
28struct Args {
29 #[arg(short, long, default_value = "examples/model.ply")]
31 model: String,
32
33 #[arg(short, long, default_value = "target/output.ply")]
35 output: String,
36
37 #[arg(
39 short,
40 long,
41 allow_hyphen_values = true,
42 num_args = 3,
43 value_delimiter = ',',
44 default_value = "0.0,0.0,0.0"
45 )]
46 pos: Vec<f32>,
47
48 #[arg(
50 short,
51 long,
52 allow_hyphen_values = true,
53 num_args = 4,
54 value_delimiter = ',',
55 default_value = "0.0,0.0,0.0,1.0"
56 )]
57 rot: Vec<f32>,
58
59 #[arg(
61 short,
62 long,
63 allow_hyphen_values = true,
64 num_args = 3,
65 value_delimiter = ',',
66 default_value = "0.5,1.0,2.0"
67 )]
68 scale: Vec<f32>,
69
70 #[arg(long, value_enum, default_value_t = Shape::Sphere, ignore_case = true)]
72 shape: Shape,
73
74 #[arg(long, default_value = "1")]
76 repeat: u32,
77
78 #[arg(
80 long,
81 allow_hyphen_values = true,
82 num_args = 3,
83 value_delimiter = ',',
84 default_value = "2.0,0.0,0.0"
85 )]
86 offset: Vec<f32>,
87}
88
89#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
90enum Shape {
91 Sphere,
92 Box,
93}
94
95type GaussianPod = gs::core::GaussianPodWithShSingleCov3dSingleConfigs;
96
97#[pollster::main]
98async fn main() {
99 env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
100
101 let args = Args::parse();
102 let model_path = &args.model;
103 let pos = Vec3::from_slice(&args.pos);
104 let rot = Quat::from_slice(&args.rot);
105 let scale = Vec3::from_slice(&args.scale);
106 let shape = match args.shape {
107 Shape::Sphere => gs::SelectionBundle::<GaussianPod>::create_sphere_bundle,
108 Shape::Box => gs::SelectionBundle::<GaussianPod>::create_box_bundle,
109 };
110 let repeat = args.repeat;
111 let offset = Vec3::from_slice(&args.offset);
112
113 log::debug!("Creating wgpu instance");
114 let instance =
115 wgpu::Instance::new(wgpu::InstanceDescriptor::new_without_display_handle_from_env());
116
117 log::debug!("Requesting adapter");
118 let adapter = instance
119 .request_adapter(&wgpu::RequestAdapterOptions::default())
120 .await
121 .expect("adapter");
122
123 log::debug!("Requesting device");
124 let (device, queue) = adapter
125 .request_device(&wgpu::DeviceDescriptor {
126 label: Some("Device"),
127 required_limits: adapter.limits(),
128 ..Default::default()
129 })
130 .await
131 .expect("device");
132
133 log::debug!("Creating gaussians");
134 let gaussians = [
135 gs::core::GaussiansSource::Ply,
136 gs::core::GaussiansSource::Spz,
137 ]
138 .into_iter()
139 .find_map(|source| gs::core::Gaussians::read_from_file(model_path, source).ok())
140 .expect("gaussians");
141
142 log::debug!("Creating gaussians buffer");
143 let gaussians_buffer = gs::core::GaussiansBuffer::<GaussianPod>::new(&device, &gaussians);
144
145 log::debug!("Creating model transform buffer");
146 let model_transform = gs::core::ModelTransformBuffer::new(&device);
147
148 log::debug!("Creating Gaussian transform buffer");
149 let gaussian_transform = gs::core::GaussianTransformBuffer::new(&device);
150
151 log::debug!("Creating shape selection compute bundle");
152 let shape_selection = shape(&device);
153
154 log::debug!("Creating selection bundle");
155 let selection_bundle = gs::SelectionBundle::<GaussianPod>::new(&device, vec![shape_selection]);
156
157 log::debug!("Creating shape selection buffers");
158 let shape_selection_buffers = (0..repeat)
159 .map(|i| {
160 let offset_pos = pos + offset * i as f32;
161 let buffer = gs::InvTransformBuffer::new(&device);
162 buffer.update_with_scale_rot_pos(&queue, scale, rot, offset_pos);
163 buffer
164 })
165 .collect::<Vec<_>>();
166
167 log::debug!("Creating shape selection bind groups");
168 let shape_selection_bind_groups = shape_selection_buffers
169 .iter()
170 .map(|buffer| {
171 selection_bundle.bundles[0]
172 .create_bind_group(
173 &device,
174 1,
177 [buffer.buffer().as_entire_binding()],
178 )
179 .expect("bind group")
180 })
181 .collect::<Vec<_>>();
182
183 log::debug!("Creating selection expression");
184 let selection_expr = shape_selection_bind_groups.into_iter().fold(
185 gs::SelectionExpr::Identity,
186 |acc, bind_group| {
187 acc.union(gs::SelectionExpr::selection(
188 0, vec![bind_group],
190 ))
191 },
192 );
193
194 log::debug!("Creating destination buffer");
195 let dest = gs::SelectionBuffer::new(&device, gaussians_buffer.len() as u32);
196
197 log::info!("Starting selection process");
198 let time = std::time::Instant::now();
199
200 log::debug!("Selecting Gaussians");
201 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
202 label: Some("Selection Encoder"),
203 });
204
205 selection_bundle.evaluate(
206 &device,
207 &mut encoder,
208 &selection_expr,
209 &dest,
210 &model_transform,
211 &gaussian_transform,
212 &gaussians_buffer,
213 );
214
215 queue.submit(Some(encoder.finish()));
216
217 device
218 .poll(wgpu::PollType::wait_indefinitely())
219 .expect("poll");
220
221 log::info!("Editing process completed in {:?}", time.elapsed());
222
223 log::debug!("Filtering Gaussians");
224 let selected_gaussians = dest
225 .download::<u32>(&device, &queue)
226 .await
227 .expect("selected download")
228 .iter()
229 .flat_map(|group| {
230 std::iter::repeat_n(group, 32)
231 .enumerate()
232 .map(|(i, g)| g & (1 << i) != 0)
233 })
234 .zip(gaussians.iter_gaussian())
235 .filter(|(selected, _)| *selected)
236 .map(|(_, g)| g)
237 .collect::<Vec<_>>();
238
239 let selected_gaussians = match &args.output[args.output.len().saturating_sub(4)..] {
240 ".ply" => gs::core::Gaussians::Ply(gs::core::PlyGaussians::from_iter(
241 selected_gaussians.into_iter(),
242 )),
243 ".spz" => {
244 gs::core::Gaussians::Spz(
245 gs::core::SpzGaussians::from_gaussians_with_options(
246 selected_gaussians,
247 &gs::core::SpzGaussiansFromGaussianSliceOptions {
248 version: 2, ..Default::default()
250 },
251 )
252 .expect("SpzGaussians from gaussians"),
253 )
254 }
255 _ => panic!("Unsupported output file extension, expected .ply or .spz"),
256 };
257
258 log::debug!("Writing modified Gaussians to output file");
259 selected_gaussians
260 .write_to_file(&args.output)
261 .expect("write modified Gaussians to output file");
262
263 log::info!("Filtered Gaussians written to {}", args.output);
264}