filter_selection/
filter_selection.rs

1//! This example selects parts of the model, then filters the model to keep only the selected parts
2//! using [`SelectionBundle`](wgpu_3dgs_editor::SelectionBundle) and [`SelectionExpr`](wgpu_3dgs_editor::SelectionExpr).
3//!
//! For example, to select two spheres with scale (1, 1, 1), one at (0.5, 1.0, 0.5) and one
//! shifted by an offset of (2, 0, 0) to (2.5, 1.0, 0.5), then keep only the selected Gaussians:
6//!
7//! ```sh
8//! cargo run --example filter-selection -- \
9//!     -m "path/to/model.ply" \
10//!     -p 0.5 1.0 0.5 -s 1 1 1 --repeat 2 --offset 2.0 0.0 0.0
11//! ```
12
13use clap::{Parser, ValueEnum};
14use glam::*;
15
16use wgpu_3dgs_core::IterGaussian;
17use wgpu_3dgs_editor::{self as gs, core::BufferWrapper};
18
19/// The command line arguments.
20#[derive(Parser, Debug)]
21#[command(
22    version,
23    about,
24    long_about = "\
25    A 3D Gaussian splatting editor to filter selected Gaussians in a model.
26    "
27)]
28struct Args {
29    /// Path to the .ply file.
30    #[arg(short, long, default_value = "examples/model.ply")]
31    model: String,
32
33    /// The output path for the modified .ply file.
34    #[arg(short, long, default_value = "target/output.ply")]
35    output: String,
36
37    /// The position of the selection shape.
38    #[arg(
39        short,
40        long,
41        allow_hyphen_values = true,
42        num_args = 3,
43        value_delimiter = ',',
44        default_value = "0.0,0.0,0.0"
45    )]
46    pos: Vec<f32>,
47
48    /// The rotation of the selection shape.
49    #[arg(
50        short,
51        long,
52        allow_hyphen_values = true,
53        num_args = 4,
54        value_delimiter = ',',
55        default_value = "0.0,0.0,0.0,1.0"
56    )]
57    rot: Vec<f32>,
58
59    /// The scale of the selection shape.
60    #[arg(
61        short,
62        long,
63        allow_hyphen_values = true,
64        num_args = 3,
65        value_delimiter = ',',
66        default_value = "0.5,1.0,2.0"
67    )]
68    scale: Vec<f32>,
69
70    /// The shape of the selection.
71    #[arg(long, value_enum, default_value_t = Shape::Sphere, ignore_case = true)]
72    shape: Shape,
73
74    /// The number of times to run the selection.
75    #[arg(long, default_value = "1")]
76    repeat: u32,
77
78    /// The offset of each selection.
79    #[arg(
80        long,
81        allow_hyphen_values = true,
82        num_args = 3,
83        value_delimiter = ',',
84        default_value = "2.0,0.0,0.0"
85    )]
86    offset: Vec<f32>,
87}
88
89#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
90enum Shape {
91    Sphere,
92    Box,
93}
94
/// The Gaussian representation used to instantiate every generic GPU type in this example
/// (the Gaussians buffer and the selection bundles must agree on it).
type GaussianPod = gs::core::GaussianPodWithShSingleCov3dSingleConfigs;
96
97#[pollster::main]
98async fn main() {
99    env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
100
101    let args = Args::parse();
102    let model_path = &args.model;
103    let pos = Vec3::from_slice(&args.pos);
104    let rot = Quat::from_slice(&args.rot);
105    let scale = Vec3::from_slice(&args.scale);
106    let shape = match args.shape {
107        Shape::Sphere => gs::SelectionBundle::<GaussianPod>::create_sphere_bundle,
108        Shape::Box => gs::SelectionBundle::<GaussianPod>::create_box_bundle,
109    };
110    let repeat = args.repeat;
111    let offset = Vec3::from_slice(&args.offset);
112
113    log::debug!("Creating wgpu instance");
114    let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor::default());
115
116    log::debug!("Requesting adapter");
117    let adapter = instance
118        .request_adapter(&wgpu::RequestAdapterOptions::default())
119        .await
120        .expect("adapter");
121
122    log::debug!("Requesting device");
123    let (device, queue) = adapter
124        .request_device(&wgpu::DeviceDescriptor {
125            label: Some("Device"),
126            required_limits: adapter.limits(),
127            ..Default::default()
128        })
129        .await
130        .expect("device");
131
132    log::debug!("Creating gaussians");
133    let gaussians = [
134        gs::core::GaussiansSource::Ply,
135        gs::core::GaussiansSource::Spz,
136    ]
137    .into_iter()
138    .find_map(|source| gs::core::Gaussians::read_from_file(model_path, source).ok())
139    .expect("gaussians");
140
141    log::debug!("Creating gaussians buffer");
142    let gaussians_buffer = gs::core::GaussiansBuffer::<GaussianPod>::new(&device, &gaussians);
143
144    log::debug!("Creating model transform buffer");
145    let model_transform = gs::core::ModelTransformBuffer::new(&device);
146
147    log::debug!("Creating Gaussian transform buffer");
148    let gaussian_transform = gs::core::GaussianTransformBuffer::new(&device);
149
150    log::debug!("Creating shape selection compute bundle");
151    let shape_selection = shape(&device);
152
153    log::debug!("Creating selection bundle");
154    let selection_bundle = gs::SelectionBundle::<GaussianPod>::new(&device, vec![shape_selection]);
155
156    log::debug!("Creating shape selection buffers");
157    let shape_selection_buffers = (0..repeat)
158        .map(|i| {
159            let offset_pos = pos + offset * i as f32;
160            let buffer = gs::InvTransformBuffer::new(&device);
161            buffer.update_with_scale_rot_pos(&queue, scale, rot, offset_pos);
162            buffer
163        })
164        .collect::<Vec<_>>();
165
166    log::debug!("Creating shape selection bind groups");
167    let shape_selection_bind_groups = shape_selection_buffers
168        .iter()
169        .map(|buffer| {
170            selection_bundle.bundles[0]
171                .create_bind_group(
172                    &device,
173                    // index 0 is the Gaussians buffer, so we use 1,
174                    // see docs of create_sphere_bundle or create_box_bundle
175                    1,
176                    [buffer.buffer().as_entire_binding()],
177                )
178                .expect("bind group")
179        })
180        .collect::<Vec<_>>();
181
182    log::debug!("Creating selection expression");
183    let selection_expr = shape_selection_bind_groups.into_iter().fold(
184        gs::SelectionExpr::Identity,
185        |acc, bind_group| {
186            acc.union(gs::SelectionExpr::selection(
187                0, // the 0 here is the bundle index in the selection bundle
188                vec![bind_group],
189            ))
190        },
191    );
192
193    log::debug!("Creating destination buffer");
194    let dest = gs::SelectionBuffer::new(&device, gaussians_buffer.len() as u32);
195
196    log::info!("Starting selection process");
197    let time = std::time::Instant::now();
198
199    log::debug!("Selecting Gaussians");
200    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
201        label: Some("Selection Encoder"),
202    });
203
204    selection_bundle.evaluate(
205        &device,
206        &mut encoder,
207        &selection_expr,
208        &dest,
209        &model_transform,
210        &gaussian_transform,
211        &gaussians_buffer,
212    );
213
214    queue.submit(Some(encoder.finish()));
215
216    device
217        .poll(wgpu::PollType::wait_indefinitely())
218        .expect("poll");
219
220    log::info!("Editing process completed in {:?}", time.elapsed());
221
222    log::debug!("Filtering Gaussians");
223    let selected_gaussians = dest
224        .download::<u32>(&device, &queue)
225        .await
226        .expect("selected download")
227        .iter()
228        .flat_map(|group| {
229            std::iter::repeat_n(group, 32)
230                .enumerate()
231                .map(|(i, g)| g & (1 << i) != 0)
232        })
233        .zip(gaussians.iter_gaussian())
234        .filter(|(selected, _)| *selected)
235        .map(|(_, g)| g)
236        .collect::<Vec<_>>();
237
238    let selected_gaussians = match &args.output[args.output.len().saturating_sub(4)..] {
239        ".ply" => gs::core::Gaussians::Ply(gs::core::PlyGaussians::from_iter(
240            selected_gaussians.into_iter(),
241        )),
242        ".spz" => {
243            gs::core::Gaussians::Spz(
244                gs::core::SpzGaussians::from_gaussians_with_options(
245                    selected_gaussians,
246                    &gs::core::SpzGaussiansFromGaussianSliceOptions {
247                        version: 2, // Version 2 is more widely supported as of now
248                        ..Default::default()
249                    },
250                )
251                .expect("SpzGaussians from gaussians"),
252            )
253        }
254        _ => panic!("Unsupported output file extension, expected .ply or .spz"),
255    };
256
257    log::debug!("Writing modified Gaussians to output file");
258    selected_gaussians
259        .write_to_file(&args.output)
260        .expect("write modified Gaussians to output file");
261
262    log::info!("Filtered Gaussians written to {}", args.output);
263}