Skip to main content

simple/
simple.rs

//! This example renders a simple 3D Gaussian splatting model.
//!
//! For example, to render the model with a size of 1.2 and max standard deviation of 2.0, run:
//!
//! ```sh
//! cargo run --example simple -- -m "path/to/model.ply" --size 1.2 --std-dev 2.0
//! ```
//!
//! To view more options and the controls, run with `--help`:
//!
//! ```sh
//! cargo run --example simple -- --help
//! ```
14
15use std::sync::Arc;
16
17use clap::Parser;
18use colored::Colorize;
19use glam::*;
20use winit::{error::EventLoopError, event_loop::EventLoop, keyboard::KeyCode, window::Window};
21
22use wgpu_3dgs_viewer as gs;
23use wgpu_3dgs_viewer::core::{GaussianMaxStdDev, GaussiansSource};
24
25mod utils;
26use utils::core;
27
28/// The command line arguments.
29#[derive(Parser, Debug)]
30#[command(
31    version,
32    about,
33    long_about = "\
34    A 3D Gaussian splatting viewer written in Rust using wgpu.\n\
35    \n\
36    Use W, A, S, D, Space, Shift to move, use mouse to rotate.\n\
37    "
38)]
39struct Args {
40    /// Path to the .ply file.
41    #[arg(short, long)]
42    model: String,
43
44    /// The size of the Gaussians.
45    #[arg(long, default_value_t = 1.0)]
46    size: f32,
47
48    /// The display mode of the Gaussians.
49    #[arg(long, value_enum, default_value_t = DisplayMode::Splat, ignore_case = true)]
50    mode: DisplayMode,
51
52    /// The SH degree of the Gaussians.
53    #[arg(long, default_value_t = 3, value_parser = clap::value_parser!(u8).range(0..=3))]
54    sh_degree: u8,
55
56    /// Whether to hide SH0.
57    #[arg(long, default_value_t)]
58    no_sh0: bool,
59
60    /// The max standard deviation for the Gaussians.
61    #[arg(long, default_value_t = 3.0)]
62    std_dev: f32,
63}
64
65#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
66enum DisplayMode {
67    Splat,
68    Ellipse,
69    Point,
70}
71
72fn main() -> Result<(), EventLoopError> {
73    env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init();
74
75    let args = Args::parse();
76
77    if !(0.0..=3.0).contains(&args.std_dev) {
78        eprintln!(
79            "{} invalid value '{}' for '{}': {} is not in 0.0..=3.0\n\nFor more information, try '{}'.",
80            "error:".red().bold(),
81            args.std_dev.to_string().yellow(),
82            "--std-dev <STD_DEV>".bold(),
83            args.std_dev,
84            "--help".bold()
85        );
86        std::process::exit(1);
87    }
88
89    let event_loop = EventLoop::new()?;
90    event_loop.run_app(&mut core::App::<System>::new(args))?;
91    Ok(())
92}
93
94/// The application system.
95#[allow(dead_code)]
96struct System {
97    surface: wgpu::Surface<'static>,
98    queue: wgpu::Queue,
99    device: wgpu::Device,
100    config: wgpu::SurfaceConfiguration,
101
102    camera: gs::Camera,
103    gaussians: gs::core::Gaussians,
104    viewer: gs::Viewer,
105}
106
107impl core::System for System {
108    type Args = Args;
109
110    async fn init(window: Arc<Window>, args: &Args) -> Self {
111        let model_path = &args.model;
112        let size = window.inner_size();
113
114        log::debug!("Creating wgpu instance");
115        let instance =
116            wgpu::Instance::new(wgpu::InstanceDescriptor::new_without_display_handle_from_env());
117
118        log::debug!("Creating window surface");
119        let surface = instance.create_surface(window.clone()).expect("surface");
120
121        log::debug!("Requesting adapter");
122        let adapter = instance
123            .request_adapter(&wgpu::RequestAdapterOptions {
124                power_preference: wgpu::PowerPreference::HighPerformance,
125                compatible_surface: Some(&surface),
126                force_fallback_adapter: false,
127            })
128            .await
129            .expect("adapter");
130
131        log::debug!("Requesting device");
132        let (device, queue) = adapter
133            .request_device(&wgpu::DeviceDescriptor {
134                label: Some("Device"),
135                required_limits: adapter.limits(),
136                ..Default::default()
137            })
138            .await
139            .expect("device");
140
141        let surface_caps = surface.get_capabilities(&adapter);
142        let surface_format = surface_caps.formats[0];
143        let config = wgpu::SurfaceConfiguration {
144            usage: wgpu::TextureUsages::RENDER_ATTACHMENT,
145            format: surface_format,
146            width: size.width.max(1),
147            height: size.height.max(1),
148            present_mode: surface_caps.present_modes[0],
149            alpha_mode: surface_caps.alpha_modes[0],
150            view_formats: vec![surface_format.remove_srgb_suffix()],
151            desired_maximum_frame_latency: 2,
152        };
153
154        log::debug!("Configuring surface");
155        surface.configure(&device, &config);
156
157        log::debug!("Creating gaussians");
158        let gaussians = [GaussiansSource::Ply, GaussiansSource::Spz]
159            .into_iter()
160            .find_map(|source| gs::core::Gaussians::read_from_file(model_path, source).ok())
161            .expect("gaussians");
162
163        log::debug!("Creating camera");
164        let camera = gs::Camera::new(0.1..1e4, 60f32.to_radians());
165
166        log::debug!("Creating viewer");
167        let mut viewer =
168            gs::Viewer::new(&device, config.view_formats[0], &gaussians).expect("viewer");
169        viewer.update_model_transform(
170            &queue,
171            Vec3::ZERO,
172            Quat::from_axis_angle(Vec3::Z, 180f32.to_radians()),
173            Vec3::ONE,
174        );
175
176        viewer.update_gaussian_transform(
177            &queue,
178            args.size,
179            match args.mode {
180                DisplayMode::Splat => gs::core::GaussianDisplayMode::Splat,
181                DisplayMode::Ellipse => gs::core::GaussianDisplayMode::Ellipse,
182                DisplayMode::Point => gs::core::GaussianDisplayMode::Point,
183            },
184            gs::core::GaussianShDegree::new(args.sh_degree).expect("sh degree"),
185            args.no_sh0,
186            GaussianMaxStdDev::new(args.std_dev).expect("max std dev"),
187        );
188
189        log::info!("System initialized");
190
191        Self {
192            surface,
193            device,
194            queue,
195            config,
196
197            camera,
198            gaussians,
199            viewer,
200        }
201    }
202
203    fn update(&mut self, input: &core::Input, delta_time: f32) {
204        // Camera movement
205        const SPEED: f32 = 1.0;
206
207        let mut forward = 0.0;
208        if input.held_keys.contains(&KeyCode::KeyW) {
209            forward += SPEED * delta_time;
210        }
211        if input.held_keys.contains(&KeyCode::KeyS) {
212            forward -= SPEED * delta_time;
213        }
214
215        let mut right = 0.0;
216        if input.held_keys.contains(&KeyCode::KeyD) {
217            right += SPEED * delta_time;
218        }
219        if input.held_keys.contains(&KeyCode::KeyA) {
220            right -= SPEED * delta_time;
221        }
222
223        self.camera.move_by(forward, right);
224
225        let mut up = 0.0;
226        if input.held_keys.contains(&KeyCode::Space) {
227            up += SPEED * delta_time;
228        }
229        if input.held_keys.contains(&KeyCode::ShiftLeft) {
230            up -= SPEED * delta_time;
231        }
232
233        self.camera.move_up(up);
234
235        // Camera rotation
236        const SENSITIVITY: f32 = 0.15;
237
238        let yaw = input.mouse_diff.x * SENSITIVITY * delta_time;
239        let pitch = input.mouse_diff.y * SENSITIVITY * delta_time;
240
241        self.camera.pitch_by(-pitch);
242        self.camera.yaw_by(-yaw);
243
244        // Update the viewer
245        self.viewer.update_camera(
246            &self.queue,
247            &self.camera,
248            uvec2(self.config.width, self.config.height),
249        );
250    }
251
252    fn render(&mut self) {
253        let texture = match self.surface.get_current_texture() {
254            wgpu::CurrentSurfaceTexture::Success(texture)
255            | wgpu::CurrentSurfaceTexture::Suboptimal(texture) => texture,
256            e => {
257                log::error!("Failed to get current texture: {e:?}");
258                return;
259            }
260        };
261        let texture_view = texture.texture.create_view(&wgpu::TextureViewDescriptor {
262            label: Some("Texture View"),
263            format: Some(self.config.view_formats[0]),
264            ..Default::default()
265        });
266
267        let mut encoder = self
268            .device
269            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
270                label: Some("Command Encoder"),
271            });
272
273        self.viewer.render(&mut encoder, &texture_view);
274
275        self.queue.submit(std::iter::once(encoder.finish()));
276        if let Err(e) = self.device.poll(wgpu::PollType::wait_indefinitely()) {
277            log::error!("Failed to poll device: {e:?}");
278        }
279        texture.present();
280    }
281
282    fn resize(&mut self, size: winit::dpi::PhysicalSize<u32>) {
283        if size.width > 0 && size.height > 0 {
284            self.config.width = size.width;
285            self.config.height = size.height;
286            self.surface.configure(&self.device, &self.config);
287        }
288    }
289}