filter_selection/
filter_selection.rs

1use clap::{Parser, ValueEnum};
2use glam::*;
3
4use wgpu_3dgs_editor::{
5    self as gs,
6    core::{BufferWrapper, DownloadableBufferWrapper},
7};
8
9/// The command line arguments.
10#[derive(Parser, Debug)]
11#[command(
12    version,
13    about,
14    long_about = "\
15    A 3D Gaussian splatting editor to filter selected Gaussians in a model.
16    "
17)]
18struct Args {
19    /// Path to the .ply file.
20    #[arg(short, long)]
21    model: String,
22
23    /// The output path for the modified .ply file.
24    #[arg(short, long, default_value = "target/output.ply")]
25    output: String,
26
27    /// The position of the selection shape.
28    #[arg(
29        short,
30        long,
31        allow_hyphen_values = true,
32        num_args = 3,
33        value_delimiter = ',',
34        default_value = "0.0,0.0,0.0"
35    )]
36    pos: Vec<f32>,
37
38    /// The rotation of the selection shape.
39    #[arg(
40        short,
41        long,
42        allow_hyphen_values = true,
43        num_args = 4,
44        value_delimiter = ',',
45        default_value = "0.0,0.0,0.0,1.0"
46    )]
47    rot: Vec<f32>,
48
49    /// The scale of the selection shape.
50    #[arg(
51        short,
52        long,
53        allow_hyphen_values = true,
54        num_args = 3,
55        value_delimiter = ',',
56        default_value = "0.5,1.0,2.0"
57    )]
58    scale: Vec<f32>,
59
60    /// The shape of the selection.
61    #[arg(long, value_enum, default_value_t = Shape::Sphere, ignore_case = true)]
62    shape: Shape,
63
64    /// The number of times to run the selection.
65    #[arg(long, default_value = "1")]
66    repeat: u32,
67
68    /// The offset of each selection.
69    #[arg(
70        long,
71        allow_hyphen_values = true,
72        num_args = 3,
73        value_delimiter = ',',
74        default_value = "2.0,0.0,0.0"
75    )]
76    offset: Vec<f32>,
77}
78
79#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
80enum Shape {
81    Sphere,
82    Box,
83}
84
85type GaussianPod = gs::core::GaussianPodWithShSingleCov3dSingleConfigs;
86
87#[tokio::main]
88async fn main() {
89    env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
90
91    let args = Args::parse();
92    let model_path = &args.model;
93    let pos = Vec3::from_slice(&args.pos);
94    let rot = Quat::from_slice(&args.rot);
95    let scale = Vec3::from_slice(&args.scale);
96    let shape = match args.shape {
97        Shape::Sphere => gs::SelectionBundle::<GaussianPod>::create_sphere_bundle,
98        Shape::Box => gs::SelectionBundle::<GaussianPod>::create_box_bundle,
99    };
100    let repeat = args.repeat;
101    let offset = Vec3::from_slice(&args.offset);
102
103    log::debug!("Creating wgpu instance");
104    let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor::default());
105
106    log::debug!("Requesting adapter");
107    let adapter = instance
108        .request_adapter(&wgpu::RequestAdapterOptions::default())
109        .await
110        .expect("adapter");
111
112    log::debug!("Requesting device");
113    let (device, queue) = adapter
114        .request_device(&wgpu::DeviceDescriptor {
115            label: Some("Device"),
116            required_features: wgpu::Features::empty(),
117            required_limits: adapter.limits(),
118            memory_hints: wgpu::MemoryHints::default(),
119            trace: wgpu::Trace::Off,
120        })
121        .await
122        .expect("device");
123
124    log::debug!("Creating gaussians");
125    let f = std::fs::File::open(model_path).expect("ply file");
126    let mut reader = std::io::BufReader::new(f);
127    let gaussians = gs::core::Gaussians::read_ply(&mut reader).expect("gaussians");
128
129    log::debug!("Creating gaussians buffer");
130    let gaussians_buffer =
131        gs::core::GaussiansBuffer::<GaussianPod>::new(&device, &gaussians.gaussians);
132
133    log::debug!("Creating model transform buffer");
134    let model_transform = gs::core::ModelTransformBuffer::new(&device);
135
136    log::debug!("Creating Gaussian transform buffer");
137    let gaussian_transform = gs::core::GaussianTransformBuffer::new(&device);
138
139    log::debug!("Creating shape selection compute bundle");
140    let shape_selection = shape(&device);
141
142    log::debug!("Creating selection bundle");
143    let selection_bundle = gs::SelectionBundle::<GaussianPod>::new(&device, vec![shape_selection]);
144
145    log::debug!("Creating shape selection buffers");
146    let shape_selection_buffers = (0..repeat)
147        .map(|i| {
148            let offset_pos = pos + offset * i as f32;
149            let buffer = gs::InvTransformBuffer::new(&device);
150            buffer.update_with_scale_rot_pos(&queue, scale, rot, offset_pos);
151            buffer
152        })
153        .collect::<Vec<_>>();
154
155    log::debug!("Creating shape selection bind groups");
156    let shape_selection_bind_groups = shape_selection_buffers
157        .iter()
158        .map(|buffer| {
159            selection_bundle.bundles[0]
160                .create_bind_group(
161                    &device,
162                    // index 0 is the Gaussians buffer, so we use 1,
163                    // see docs of create_sphere_bundle or create_box_bundle
164                    1,
165                    [buffer.buffer().as_entire_binding()],
166                )
167                .expect("bind group")
168        })
169        .collect::<Vec<_>>();
170
171    log::debug!("Creating selection expression");
172    let selection_expr = shape_selection_bind_groups.into_iter().fold(
173        gs::SelectionExpr::Identity,
174        |acc, bind_group| {
175            acc.union(gs::SelectionExpr::selection(
176                0, // the 0 here is the bundle index in the selection bundle
177                vec![bind_group],
178            ))
179        },
180    );
181
182    log::debug!("Creating destination buffer");
183    let dest = gs::SelectionBuffer::new(&device, gaussians_buffer.len() as u32);
184
185    log::info!("Starting selection process");
186    let time = std::time::Instant::now();
187
188    log::debug!("Selecting Gaussians");
189    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
190        label: Some("Selection Encoder"),
191    });
192
193    selection_bundle.evaluate(
194        &device,
195        &mut encoder,
196        &selection_expr,
197        &dest,
198        &model_transform,
199        &gaussian_transform,
200        &gaussians_buffer,
201    );
202
203    queue.submit(Some(encoder.finish()));
204
205    #[allow(unused_must_use)]
206    device.poll(wgpu::PollType::Wait);
207
208    log::info!("Editing process completed in {:?}", time.elapsed());
209
210    log::debug!("Filtering Gaussians");
211    let selected_gaussians = gs::core::Gaussians {
212        gaussians: dest
213            .download::<u32>(&device, &queue)
214            .await
215            .expect("selected download")
216            .iter()
217            .flat_map(|group| {
218                std::iter::repeat_n(group, 32)
219                    .enumerate()
220                    .map(|(i, g)| g & (1 << i) != 0)
221            })
222            .zip(gaussians.gaussians.iter())
223            .filter(|(selected, _)| *selected)
224            .map(|(_, g)| g.clone())
225            .collect::<Vec<_>>(),
226    };
227
228    log::debug!("Writing modified Gaussians to output file");
229    let output_file = std::fs::File::create(&args.output).expect("output file");
230    let mut writer = std::io::BufWriter::new(output_file);
231    selected_gaussians
232        .write_ply(&mut writer)
233        .expect("write modified Gaussians to output file");
234
235    log::info!("Filtered Gaussians written to {}", args.output);
236}