1use clap::{Parser, ValueEnum};
2use glam::*;
3
4use wgpu_3dgs_editor::{
5 self as gs,
6 core::{BufferWrapper, DownloadableBufferWrapper},
7};
8
9#[derive(Parser, Debug)]
11#[command(
12 version,
13 about,
14 long_about = "\
15 A 3D Gaussian splatting editor to filter selected Gaussians in a model.
16 "
17)]
18struct Args {
19 #[arg(short, long)]
21 model: String,
22
23 #[arg(short, long, default_value = "target/output.ply")]
25 output: String,
26
27 #[arg(
29 short,
30 long,
31 allow_hyphen_values = true,
32 num_args = 3,
33 value_delimiter = ',',
34 default_value = "0.0,0.0,0.0"
35 )]
36 pos: Vec<f32>,
37
38 #[arg(
40 short,
41 long,
42 allow_hyphen_values = true,
43 num_args = 4,
44 value_delimiter = ',',
45 default_value = "0.0,0.0,0.0,1.0"
46 )]
47 rot: Vec<f32>,
48
49 #[arg(
51 short,
52 long,
53 allow_hyphen_values = true,
54 num_args = 3,
55 value_delimiter = ',',
56 default_value = "0.5,1.0,2.0"
57 )]
58 scale: Vec<f32>,
59
60 #[arg(long, value_enum, default_value_t = Shape::Sphere, ignore_case = true)]
62 shape: Shape,
63
64 #[arg(long, default_value = "1")]
66 repeat: u32,
67
68 #[arg(
70 long,
71 allow_hyphen_values = true,
72 num_args = 3,
73 value_delimiter = ',',
74 default_value = "2.0,0.0,0.0"
75 )]
76 offset: Vec<f32>,
77}
78
79#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
80enum Shape {
81 Sphere,
82 Box,
83}
84
85type GaussianPod = gs::core::GaussianPodWithShSingleCov3dSingleConfigs;
86
87#[tokio::main]
88async fn main() {
89 env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
90
91 let args = Args::parse();
92 let model_path = &args.model;
93 let pos = Vec3::from_slice(&args.pos);
94 let rot = Quat::from_slice(&args.rot);
95 let scale = Vec3::from_slice(&args.scale);
96 let shape = match args.shape {
97 Shape::Sphere => gs::SelectionBundle::<GaussianPod>::create_sphere_bundle,
98 Shape::Box => gs::SelectionBundle::<GaussianPod>::create_box_bundle,
99 };
100 let repeat = args.repeat;
101 let offset = Vec3::from_slice(&args.offset);
102
103 log::debug!("Creating wgpu instance");
104 let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor::default());
105
106 log::debug!("Requesting adapter");
107 let adapter = instance
108 .request_adapter(&wgpu::RequestAdapterOptions::default())
109 .await
110 .expect("adapter");
111
112 log::debug!("Requesting device");
113 let (device, queue) = adapter
114 .request_device(&wgpu::DeviceDescriptor {
115 label: Some("Device"),
116 required_features: wgpu::Features::empty(),
117 required_limits: adapter.limits(),
118 memory_hints: wgpu::MemoryHints::default(),
119 trace: wgpu::Trace::Off,
120 })
121 .await
122 .expect("device");
123
124 log::debug!("Creating gaussians");
125 let f = std::fs::File::open(model_path).expect("ply file");
126 let mut reader = std::io::BufReader::new(f);
127 let gaussians = gs::core::Gaussians::read_ply(&mut reader).expect("gaussians");
128
129 log::debug!("Creating gaussians buffer");
130 let gaussians_buffer =
131 gs::core::GaussiansBuffer::<GaussianPod>::new(&device, &gaussians.gaussians);
132
133 log::debug!("Creating model transform buffer");
134 let model_transform = gs::core::ModelTransformBuffer::new(&device);
135
136 log::debug!("Creating Gaussian transform buffer");
137 let gaussian_transform = gs::core::GaussianTransformBuffer::new(&device);
138
139 log::debug!("Creating shape selection compute bundle");
140 let shape_selection = shape(&device);
141
142 log::debug!("Creating selection bundle");
143 let selection_bundle = gs::SelectionBundle::<GaussianPod>::new(&device, vec![shape_selection]);
144
145 log::debug!("Creating shape selection buffers");
146 let shape_selection_buffers = (0..repeat)
147 .map(|i| {
148 let offset_pos = pos + offset * i as f32;
149 let buffer = gs::InvTransformBuffer::new(&device);
150 buffer.update_with_scale_rot_pos(&queue, scale, rot, offset_pos);
151 buffer
152 })
153 .collect::<Vec<_>>();
154
155 log::debug!("Creating shape selection bind groups");
156 let shape_selection_bind_groups = shape_selection_buffers
157 .iter()
158 .map(|buffer| {
159 selection_bundle.bundles[0]
160 .create_bind_group(
161 &device,
162 1,
165 [buffer.buffer().as_entire_binding()],
166 )
167 .expect("bind group")
168 })
169 .collect::<Vec<_>>();
170
171 log::debug!("Creating selection expression");
172 let selection_expr = shape_selection_bind_groups.into_iter().fold(
173 gs::SelectionExpr::Identity,
174 |acc, bind_group| {
175 acc.union(gs::SelectionExpr::selection(
176 0, vec![bind_group],
178 ))
179 },
180 );
181
182 log::debug!("Creating destination buffer");
183 let dest = gs::SelectionBuffer::new(&device, gaussians_buffer.len() as u32);
184
185 log::info!("Starting selection process");
186 let time = std::time::Instant::now();
187
188 log::debug!("Selecting Gaussians");
189 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
190 label: Some("Selection Encoder"),
191 });
192
193 selection_bundle.evaluate(
194 &device,
195 &mut encoder,
196 &selection_expr,
197 &dest,
198 &model_transform,
199 &gaussian_transform,
200 &gaussians_buffer,
201 );
202
203 queue.submit(Some(encoder.finish()));
204
205 #[allow(unused_must_use)]
206 device.poll(wgpu::PollType::Wait);
207
208 log::info!("Editing process completed in {:?}", time.elapsed());
209
210 log::debug!("Filtering Gaussians");
211 let selected_gaussians = gs::core::Gaussians {
212 gaussians: dest
213 .download::<u32>(&device, &queue)
214 .await
215 .expect("selected download")
216 .iter()
217 .flat_map(|group| {
218 std::iter::repeat_n(group, 32)
219 .enumerate()
220 .map(|(i, g)| g & (1 << i) != 0)
221 })
222 .zip(gaussians.gaussians.iter())
223 .filter(|(selected, _)| *selected)
224 .map(|(_, g)| g.clone())
225 .collect::<Vec<_>>(),
226 };
227
228 log::debug!("Writing modified Gaussians to output file");
229 let output_file = std::fs::File::create(&args.output).expect("output file");
230 let mut writer = std::io::BufWriter::new(output_file);
231 selected_gaussians
232 .write_ply(&mut writer)
233 .expect("write modified Gaussians to output file");
234
235 log::info!("Filtered Gaussians written to {}", args.output);
236}