1use clap::{Parser, ValueEnum};
14use glam::*;
15
16use wgpu_3dgs_core::IterGaussian;
17use wgpu_3dgs_editor::{self as gs, core::BufferWrapper};
18
19#[derive(Parser, Debug)]
21#[command(
22 version,
23 about,
24 long_about = "\
25 A 3D Gaussian splatting editor to filter selected Gaussians in a model.
26 "
27)]
28struct Args {
29 #[arg(short, long, default_value = "examples/model.ply")]
31 model: String,
32
33 #[arg(short, long, default_value = "target/output.ply")]
35 output: String,
36
37 #[arg(
39 short,
40 long,
41 allow_hyphen_values = true,
42 num_args = 3,
43 value_delimiter = ',',
44 default_value = "0.0,0.0,0.0"
45 )]
46 pos: Vec<f32>,
47
48 #[arg(
50 short,
51 long,
52 allow_hyphen_values = true,
53 num_args = 4,
54 value_delimiter = ',',
55 default_value = "0.0,0.0,0.0,1.0"
56 )]
57 rot: Vec<f32>,
58
59 #[arg(
61 short,
62 long,
63 allow_hyphen_values = true,
64 num_args = 3,
65 value_delimiter = ',',
66 default_value = "0.5,1.0,2.0"
67 )]
68 scale: Vec<f32>,
69
70 #[arg(long, value_enum, default_value_t = Shape::Sphere, ignore_case = true)]
72 shape: Shape,
73
74 #[arg(long, default_value = "1")]
76 repeat: u32,
77
78 #[arg(
80 long,
81 allow_hyphen_values = true,
82 num_args = 3,
83 value_delimiter = ',',
84 default_value = "2.0,0.0,0.0"
85 )]
86 offset: Vec<f32>,
87}
88
89#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
90enum Shape {
91 Sphere,
92 Box,
93}
94
95type GaussianPod = gs::core::GaussianPodWithShSingleCov3dSingleConfigs;
96
97#[pollster::main]
98async fn main() {
99 env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
100
101 let args = Args::parse();
102 let model_path = &args.model;
103 let pos = Vec3::from_slice(&args.pos);
104 let rot = Quat::from_slice(&args.rot);
105 let scale = Vec3::from_slice(&args.scale);
106 let shape = match args.shape {
107 Shape::Sphere => gs::SelectionBundle::<GaussianPod>::create_sphere_bundle,
108 Shape::Box => gs::SelectionBundle::<GaussianPod>::create_box_bundle,
109 };
110 let repeat = args.repeat;
111 let offset = Vec3::from_slice(&args.offset);
112
113 log::debug!("Creating wgpu instance");
114 let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor::default());
115
116 log::debug!("Requesting adapter");
117 let adapter = instance
118 .request_adapter(&wgpu::RequestAdapterOptions::default())
119 .await
120 .expect("adapter");
121
122 log::debug!("Requesting device");
123 let (device, queue) = adapter
124 .request_device(&wgpu::DeviceDescriptor {
125 label: Some("Device"),
126 required_limits: adapter.limits(),
127 ..Default::default()
128 })
129 .await
130 .expect("device");
131
132 log::debug!("Creating gaussians");
133 let gaussians = [
134 gs::core::GaussiansSource::Ply,
135 gs::core::GaussiansSource::Spz,
136 ]
137 .into_iter()
138 .find_map(|source| gs::core::Gaussians::read_from_file(model_path, source).ok())
139 .expect("gaussians");
140
141 log::debug!("Creating gaussians buffer");
142 let gaussians_buffer = gs::core::GaussiansBuffer::<GaussianPod>::new(&device, &gaussians);
143
144 log::debug!("Creating model transform buffer");
145 let model_transform = gs::core::ModelTransformBuffer::new(&device);
146
147 log::debug!("Creating Gaussian transform buffer");
148 let gaussian_transform = gs::core::GaussianTransformBuffer::new(&device);
149
150 log::debug!("Creating shape selection compute bundle");
151 let shape_selection = shape(&device);
152
153 log::debug!("Creating selection bundle");
154 let selection_bundle = gs::SelectionBundle::<GaussianPod>::new(&device, vec![shape_selection]);
155
156 log::debug!("Creating shape selection buffers");
157 let shape_selection_buffers = (0..repeat)
158 .map(|i| {
159 let offset_pos = pos + offset * i as f32;
160 let buffer = gs::InvTransformBuffer::new(&device);
161 buffer.update_with_scale_rot_pos(&queue, scale, rot, offset_pos);
162 buffer
163 })
164 .collect::<Vec<_>>();
165
166 log::debug!("Creating shape selection bind groups");
167 let shape_selection_bind_groups = shape_selection_buffers
168 .iter()
169 .map(|buffer| {
170 selection_bundle.bundles[0]
171 .create_bind_group(
172 &device,
173 1,
176 [buffer.buffer().as_entire_binding()],
177 )
178 .expect("bind group")
179 })
180 .collect::<Vec<_>>();
181
182 log::debug!("Creating selection expression");
183 let selection_expr = shape_selection_bind_groups.into_iter().fold(
184 gs::SelectionExpr::Identity,
185 |acc, bind_group| {
186 acc.union(gs::SelectionExpr::selection(
187 0, vec![bind_group],
189 ))
190 },
191 );
192
193 log::debug!("Creating destination buffer");
194 let dest = gs::SelectionBuffer::new(&device, gaussians_buffer.len() as u32);
195
196 log::info!("Starting selection process");
197 let time = std::time::Instant::now();
198
199 log::debug!("Selecting Gaussians");
200 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
201 label: Some("Selection Encoder"),
202 });
203
204 selection_bundle.evaluate(
205 &device,
206 &mut encoder,
207 &selection_expr,
208 &dest,
209 &model_transform,
210 &gaussian_transform,
211 &gaussians_buffer,
212 );
213
214 queue.submit(Some(encoder.finish()));
215
216 device
217 .poll(wgpu::PollType::wait_indefinitely())
218 .expect("poll");
219
220 log::info!("Editing process completed in {:?}", time.elapsed());
221
222 log::debug!("Filtering Gaussians");
223 let selected_gaussians = dest
224 .download::<u32>(&device, &queue)
225 .await
226 .expect("selected download")
227 .iter()
228 .flat_map(|group| {
229 std::iter::repeat_n(group, 32)
230 .enumerate()
231 .map(|(i, g)| g & (1 << i) != 0)
232 })
233 .zip(gaussians.iter_gaussian())
234 .filter(|(selected, _)| *selected)
235 .map(|(_, g)| g)
236 .collect::<Vec<_>>();
237
238 let selected_gaussians = match &args.output[args.output.len().saturating_sub(4)..] {
239 ".ply" => gs::core::Gaussians::Ply(gs::core::PlyGaussians::from_iter(
240 selected_gaussians.into_iter(),
241 )),
242 ".spz" => {
243 gs::core::Gaussians::Spz(
244 gs::core::SpzGaussians::from_gaussians_with_options(
245 selected_gaussians,
246 &gs::core::SpzGaussiansFromGaussianSliceOptions {
247 version: 2, ..Default::default()
249 },
250 )
251 .expect("SpzGaussians from gaussians"),
252 )
253 }
254 _ => panic!("Unsupported output file extension, expected .ply or .spz"),
255 };
256
257 log::debug!("Writing modified Gaussians to output file");
258 selected_gaussians
259 .write_to_file(&args.output)
260 .expect("write modified Gaussians to output file");
261
262 log::info!("Filtered Gaussians written to {}", args.output);
263}