Skip to main content

multi_model/
multi_model.rs

//! This example uses the `multi-model` feature to render multiple models, drawing them back to
//! front in order of the distance from each model's centroid to the camera.
//!
//! For example, to use an offset of (10, 0, 0) between consecutive models, run:
//!
//! ```sh
//! cargo run --example multi-model --features="multi-model" -- -m "path/to/model1.ply" -m "path/to/model2.ply" --offset 10.0,0.0,0.0
//! ```
//!
//! To view more options and the controls, run with `--help`:
//!
//! ```sh
//! cargo run --example multi-model --features="multi-model" -- --help
//! ```
14
15use std::sync::Arc;
16
17use clap::Parser;
18use glam::*;
19use winit::{error::EventLoopError, event_loop::EventLoop, keyboard::KeyCode, window::Window};
20
21use wgpu_3dgs_viewer as gs;
22use wgpu_3dgs_viewer::core::{GaussiansSource, IterGaussian};
23
24mod utils;
25use utils::core;
26
27/// The command line arguments.
28#[derive(Parser, Debug)]
29#[command(
30    version,
31    about,
32    long_about = "\
33    A 3D Gaussian splatting viewer written in Rust using wgpu.\n\
34    \n\
35    Use W, A, S, D, Space, Shift to move, use mouse to rotate.\n\
36    "
37)]
38struct Args {
39    /// Path to the .ply file.
40    #[arg(short, long, num_args = 1..)]
41    models: Vec<String>,
42
43    /// The offset of each model.
44    #[arg(
45        short,
46        long,
47        num_args = 3,
48        value_delimiter = ',',
49        default_value = "10.0,0.0,0.0"
50    )]
51    offset: Vec<f32>,
52}
53
54fn main() -> Result<(), EventLoopError> {
55    env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
56
57    let event_loop = EventLoop::new()?;
58    event_loop.run_app(&mut core::App::<System>::new(Args::parse()))?;
59    Ok(())
60}
61
62/// The application system.
63#[allow(dead_code)]
64struct System {
65    surface: wgpu::Surface<'static>,
66    queue: wgpu::Queue,
67    device: wgpu::Device,
68    config: wgpu::SurfaceConfiguration,
69
70    camera: gs::Camera,
71    gaussians: Vec<gs::core::Gaussians>,
72    gaussian_centroids: Vec<Vec3>,
73    viewer: gs::MultiModelViewer<gs::DefaultGaussianPod, usize>,
74}
75
76impl core::System for System {
77    type Args = Args;
78
79    async fn init(window: Arc<Window>, args: &Args) -> Self {
80        let model_paths = &args.models;
81        let model_offset = Vec3::from_slice(&args.offset);
82        let size = window.inner_size();
83
84        log::debug!("Creating wgpu instance");
85        let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor::default());
86
87        log::debug!("Creating window surface");
88        let surface = instance.create_surface(window.clone()).expect("surface");
89
90        log::debug!("Requesting adapter");
91        let adapter = instance
92            .request_adapter(&wgpu::RequestAdapterOptions {
93                power_preference: wgpu::PowerPreference::HighPerformance,
94                compatible_surface: Some(&surface),
95                force_fallback_adapter: false,
96            })
97            .await
98            .expect("adapter");
99
100        log::debug!("Requesting device");
101        let (device, queue) = adapter
102            .request_device(&wgpu::DeviceDescriptor {
103                label: Some("Device"),
104                required_limits: adapter.limits(),
105                ..Default::default()
106            })
107            .await
108            .expect("device");
109
110        let surface_caps = surface.get_capabilities(&adapter);
111        let surface_format = surface_caps.formats[0];
112        let config = wgpu::SurfaceConfiguration {
113            usage: wgpu::TextureUsages::RENDER_ATTACHMENT,
114            format: surface_format,
115            width: size.width.max(1),
116            height: size.height.max(1),
117            present_mode: surface_caps.present_modes[0],
118            alpha_mode: surface_caps.alpha_modes[0],
119            view_formats: vec![surface_format.remove_srgb_suffix()],
120            desired_maximum_frame_latency: 2,
121        };
122
123        log::debug!("Configuring surface");
124        surface.configure(&device, &config);
125
126        log::debug!("Creating gaussians");
127        let gaussians = model_paths
128            .iter()
129            .map(|model_path| {
130                log::debug!("Reading model from {model_path}");
131                [GaussiansSource::Ply, GaussiansSource::Spz]
132                    .into_iter()
133                    .find_map(|source| gs::core::Gaussians::read_from_file(model_path, source).ok())
134                    .expect("gaussians")
135            })
136            .collect::<Vec<_>>();
137
138        log::debug!("Computing gaussian centroids");
139        let mut gaussian_centroids = gaussians
140            .iter()
141            .map(|g| {
142                let mut centroid = Vec3::ZERO;
143                for gaussian in g.iter_gaussian() {
144                    centroid += gaussian.pos;
145                }
146                centroid / g.len() as f32
147            })
148            .collect::<Vec<_>>();
149
150        log::debug!("Creating camera");
151        let camera = gs::Camera::new(0.1..1e4, 60f32.to_radians());
152
153        log::debug!("Creating viewer");
154        let mut viewer =
155            gs::MultiModelViewer::new(&device, config.view_formats[0]).expect("viewer");
156
157        let quat = Quat::from_axis_angle(Vec3::Z, 180f32.to_radians());
158        for (i, gaussians) in gaussians.iter().enumerate() {
159            let offset = model_offset * i as f32;
160
161            log::debug!("Pushing model {i}");
162
163            viewer.insert_model(&device, i, gaussians);
164            viewer
165                .update_model_transform(&queue, &i, offset, quat, Vec3::ONE)
166                .expect("update model");
167
168            gaussian_centroids[i] = quat.mul_vec3(gaussian_centroids[i]) + offset;
169        }
170
171        log::info!("System initialized");
172
173        Self {
174            surface,
175            device,
176            queue,
177            config,
178
179            camera,
180            gaussians,
181            gaussian_centroids,
182            viewer,
183        }
184    }
185
186    fn update(&mut self, input: &core::Input, delta_time: f32) {
187        const SPEED: f32 = 1.0;
188
189        let mut forward = 0.0;
190        if input.held_keys.contains(&KeyCode::KeyW) {
191            forward += SPEED * delta_time;
192        }
193        if input.held_keys.contains(&KeyCode::KeyS) {
194            forward -= SPEED * delta_time;
195        }
196
197        let mut right = 0.0;
198        if input.held_keys.contains(&KeyCode::KeyD) {
199            right += SPEED * delta_time;
200        }
201        if input.held_keys.contains(&KeyCode::KeyA) {
202            right -= SPEED * delta_time;
203        }
204
205        self.camera.move_by(forward, right);
206
207        let mut up = 0.0;
208        if input.held_keys.contains(&KeyCode::Space) {
209            up += SPEED * delta_time;
210        }
211        if input.held_keys.contains(&KeyCode::ShiftLeft) {
212            up -= SPEED * delta_time;
213        }
214
215        self.camera.move_up(up);
216
217        // Camera rotation
218        const SENSITIVITY: f32 = 0.15;
219
220        let yaw = input.mouse_diff.x * SENSITIVITY * delta_time;
221        let pitch = input.mouse_diff.y * SENSITIVITY * delta_time;
222
223        self.camera.pitch_by(-pitch);
224        self.camera.yaw_by(-yaw);
225
226        // Update the viewer
227        self.viewer.update_camera(
228            &self.queue,
229            &self.camera,
230            uvec2(self.config.width, self.config.height),
231        );
232    }
233
234    fn render(&mut self) {
235        let texture = match self.surface.get_current_texture() {
236            Ok(texture) => texture,
237            Err(e) => {
238                log::error!("Failed to get current texture: {e:?}");
239                return;
240            }
241        };
242        let texture_view = texture.texture.create_view(&wgpu::TextureViewDescriptor {
243            label: Some("Texture View"),
244            format: Some(self.config.view_formats[0]),
245            ..Default::default()
246        });
247
248        let mut encoder = self
249            .device
250            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
251                label: Some("Command Encoder"),
252            });
253
254        let mut render_keys = self
255            .gaussian_centroids
256            .iter()
257            .enumerate()
258            .map(|(i, centroid)| (i, centroid - self.camera.pos))
259            .collect::<Vec<_>>();
260
261        render_keys.sort_by(|(_, a), (_, b)| {
262            a.length()
263                .partial_cmp(&b.length())
264                .unwrap_or(std::cmp::Ordering::Equal)
265        });
266
267        let render_keys = render_keys
268            .into_iter()
269            .rev()
270            .map(|(i, _)| i)
271            .collect::<Vec<_>>();
272
273        self.viewer
274            .render(
275                &mut encoder,
276                &texture_view,
277                render_keys.iter().collect::<Vec<_>>().as_slice(),
278            )
279            .expect("render");
280
281        self.queue.submit(std::iter::once(encoder.finish()));
282        if let Err(e) = self.device.poll(wgpu::PollType::wait_indefinitely()) {
283            log::error!("Failed to poll device: {e:?}");
284        }
285        texture.present();
286    }
287
288    fn resize(&mut self, size: winit::dpi::PhysicalSize<u32>) {
289        if size.width > 0 && size.height > 0 {
290            self.config.width = size.width;
291            self.config.height = size.height;
292            self.surface.configure(&self.device, &self.config);
293        }
294    }
295}