1use crate::{
2 CameraBuffer, GaussiansDepthBuffer, IndirectArgsBuffer, IndirectIndicesBuffer,
3 PreprocessorCreateError, RadixSortIndirectArgsBuffer,
4 core::{
5 BufferWrapper, ComputeBundle, ComputeBundleBuilder, GaussianPod, GaussianTransformBuffer,
6 GaussiansBuffer, ModelTransformBuffer,
7 },
8 wesl_utils,
9};
10
11#[cfg(feature = "viewer-selection")]
12use crate::{editor::SelectionBuffer, selection};
13
14#[derive(Debug)]
18pub struct Preprocessor<G: GaussianPod, B = wgpu::BindGroup> {
19 #[allow(dead_code)]
21 bind_group_layout: wgpu::BindGroupLayout,
22 bind_group: B,
24 pre_bundle: ComputeBundle<()>,
26 bundle: ComputeBundle<()>,
28 post_bundle: ComputeBundle<()>,
30 gaussian_pod_marker: std::marker::PhantomData<G>,
32}
33
34impl<G: GaussianPod, B> Preprocessor<G, B> {
35 #[allow(clippy::too_many_arguments)]
37 pub fn create_bind_group(
38 &self,
39 device: &wgpu::Device,
40 camera: &CameraBuffer,
41 model_transform: &ModelTransformBuffer,
42 gaussian_transform: &GaussianTransformBuffer,
43 gaussians: &GaussiansBuffer<G>,
44 indirect_args: &IndirectArgsBuffer,
45 radix_sort_indirect_args: &RadixSortIndirectArgsBuffer,
46 indirect_indices: &IndirectIndicesBuffer,
47 gaussians_depth: &GaussiansDepthBuffer,
48 #[cfg(feature = "viewer-selection")] selection: &SelectionBuffer,
49 #[cfg(feature = "viewer-selection")]
50 invert_selection: &selection::PreprocessorInvertSelectionBuffer,
51 ) -> wgpu::BindGroup {
52 Preprocessor::create_bind_group_static(
53 device,
54 &self.bind_group_layout,
55 camera,
56 model_transform,
57 gaussian_transform,
58 gaussians,
59 indirect_args,
60 radix_sort_indirect_args,
61 indirect_indices,
62 gaussians_depth,
63 #[cfg(feature = "viewer-selection")]
64 selection,
65 #[cfg(feature = "viewer-selection")]
66 invert_selection,
67 )
68 }
69
70 pub fn workgroup_size(&self) -> u32 {
72 self.bundle.workgroup_size()
73 }
74
75 pub fn bind_group_layout(&self) -> &wgpu::BindGroupLayout {
77 &self.bind_group_layout
78 }
79}
80
81impl<G: GaussianPod> Preprocessor<G> {
82 const LABEL: &str = "Preprocessor";
84
85 const MAIN_SHADER: &str = "wgpu_3dgs_viewer::preprocess";
87
88 pub const BIND_GROUP_LAYOUT_DESCRIPTOR: wgpu::BindGroupLayoutDescriptor<'static> =
90 wgpu::BindGroupLayoutDescriptor {
91 label: Some("Preprocessor Bind Group Layout"),
92 entries: &[
93 wgpu::BindGroupLayoutEntry {
95 binding: 0,
96 visibility: wgpu::ShaderStages::COMPUTE,
97 ty: wgpu::BindingType::Buffer {
98 ty: wgpu::BufferBindingType::Uniform,
99 has_dynamic_offset: false,
100 min_binding_size: None,
101 },
102 count: None,
103 },
104 wgpu::BindGroupLayoutEntry {
106 binding: 1,
107 visibility: wgpu::ShaderStages::COMPUTE,
108 ty: wgpu::BindingType::Buffer {
109 ty: wgpu::BufferBindingType::Uniform,
110 has_dynamic_offset: false,
111 min_binding_size: None,
112 },
113 count: None,
114 },
115 wgpu::BindGroupLayoutEntry {
117 binding: 2,
118 visibility: wgpu::ShaderStages::COMPUTE,
119 ty: wgpu::BindingType::Buffer {
120 ty: wgpu::BufferBindingType::Uniform,
121 has_dynamic_offset: false,
122 min_binding_size: None,
123 },
124 count: None,
125 },
126 wgpu::BindGroupLayoutEntry {
128 binding: 3,
129 visibility: wgpu::ShaderStages::COMPUTE,
130 ty: wgpu::BindingType::Buffer {
131 ty: wgpu::BufferBindingType::Storage { read_only: true },
132 has_dynamic_offset: false,
133 min_binding_size: None,
134 },
135 count: None,
136 },
137 wgpu::BindGroupLayoutEntry {
139 binding: 4,
140 visibility: wgpu::ShaderStages::COMPUTE,
141 ty: wgpu::BindingType::Buffer {
142 ty: wgpu::BufferBindingType::Storage { read_only: false },
143 has_dynamic_offset: false,
144 min_binding_size: None,
145 },
146 count: None,
147 },
148 wgpu::BindGroupLayoutEntry {
150 binding: 5,
151 visibility: wgpu::ShaderStages::COMPUTE,
152 ty: wgpu::BindingType::Buffer {
153 ty: wgpu::BufferBindingType::Storage { read_only: false },
154 has_dynamic_offset: false,
155 min_binding_size: None,
156 },
157 count: None,
158 },
159 wgpu::BindGroupLayoutEntry {
161 binding: 6,
162 visibility: wgpu::ShaderStages::COMPUTE,
163 ty: wgpu::BindingType::Buffer {
164 ty: wgpu::BufferBindingType::Storage { read_only: false },
165 has_dynamic_offset: false,
166 min_binding_size: None,
167 },
168 count: None,
169 },
170 wgpu::BindGroupLayoutEntry {
172 binding: 7,
173 visibility: wgpu::ShaderStages::COMPUTE,
174 ty: wgpu::BindingType::Buffer {
175 ty: wgpu::BufferBindingType::Storage { read_only: false },
176 has_dynamic_offset: false,
177 min_binding_size: None,
178 },
179 count: None,
180 },
181 #[cfg(feature = "viewer-selection")]
183 wgpu::BindGroupLayoutEntry {
184 binding: 8,
185 visibility: wgpu::ShaderStages::COMPUTE,
186 ty: wgpu::BindingType::Buffer {
187 ty: wgpu::BufferBindingType::Storage { read_only: true },
188 has_dynamic_offset: false,
189 min_binding_size: None,
190 },
191 count: None,
192 },
193 #[cfg(feature = "viewer-selection")]
195 wgpu::BindGroupLayoutEntry {
196 binding: 9,
197 visibility: wgpu::ShaderStages::COMPUTE,
198 ty: wgpu::BindingType::Buffer {
199 ty: wgpu::BufferBindingType::Uniform,
200 has_dynamic_offset: false,
201 min_binding_size: None,
202 },
203 count: None,
204 },
205 ],
206 };
207
208 #[allow(clippy::too_many_arguments)]
210 pub fn new(
211 device: &wgpu::Device,
212 camera: &CameraBuffer,
213 model_transform: &ModelTransformBuffer,
214 gaussian_transform: &GaussianTransformBuffer,
215 gaussians: &GaussiansBuffer<G>,
216 indirect_args: &IndirectArgsBuffer,
217 radix_sort_indirect_args: &RadixSortIndirectArgsBuffer,
218 indirect_indices: &IndirectIndicesBuffer,
219 gaussians_depth: &GaussiansDepthBuffer,
220 #[cfg(feature = "viewer-selection")] selection: &SelectionBuffer,
221 #[cfg(feature = "viewer-selection")]
222 invert_selection: &selection::PreprocessorInvertSelectionBuffer,
223 ) -> Result<Self, PreprocessorCreateError> {
224 if (device.limits().max_storage_buffer_binding_size as wgpu::BufferAddress)
225 < gaussians.buffer().size()
226 {
227 return Err(PreprocessorCreateError::ModelSizeExceedsDeviceLimit {
228 model_size: gaussians.buffer().size(),
229 device_limit: device.limits().max_storage_buffer_binding_size,
230 });
231 }
232
233 let this = Preprocessor::new_without_bind_group(device)?;
234
235 log::debug!("Creating preprocessor bind group");
236 let bind_group = this.create_bind_group(
237 device,
238 camera,
239 model_transform,
240 gaussian_transform,
241 gaussians,
242 indirect_args,
243 radix_sort_indirect_args,
244 indirect_indices,
245 gaussians_depth,
246 #[cfg(feature = "viewer-selection")]
247 selection,
248 #[cfg(feature = "viewer-selection")]
249 invert_selection,
250 );
251
252 Ok(Self {
253 bind_group_layout: this.bind_group_layout,
254 bind_group,
255 pre_bundle: this.pre_bundle,
256 bundle: this.bundle,
257 post_bundle: this.post_bundle,
258 gaussian_pod_marker: std::marker::PhantomData,
259 })
260 }
261
262 pub fn preprocess(&self, encoder: &mut wgpu::CommandEncoder, gaussian_count: u32) {
264 self.pre_bundle.dispatch(encoder, 1, [&self.bind_group]);
265
266 self.bundle
267 .dispatch(encoder, gaussian_count, [&self.bind_group]);
268
269 self.post_bundle.dispatch(encoder, 1, [&self.bind_group]);
270 }
271
272 #[allow(clippy::too_many_arguments)]
274 fn create_bind_group_static(
275 device: &wgpu::Device,
276 bind_group_layout: &wgpu::BindGroupLayout,
277 camera: &CameraBuffer,
278 model_transform: &ModelTransformBuffer,
279 gaussian_transform: &GaussianTransformBuffer,
280 gaussians: &GaussiansBuffer<G>,
281 indirect_args: &IndirectArgsBuffer,
282 radix_sort_indirect_args: &RadixSortIndirectArgsBuffer,
283 indirect_indices: &IndirectIndicesBuffer,
284 gaussians_depth: &GaussiansDepthBuffer,
285 #[cfg(feature = "viewer-selection")] selection: &SelectionBuffer,
286 #[cfg(feature = "viewer-selection")]
287 invert_selection: &selection::PreprocessorInvertSelectionBuffer,
288 ) -> wgpu::BindGroup {
289 device.create_bind_group(&wgpu::BindGroupDescriptor {
290 label: Some("Preprocessor Bind Group"),
291 layout: bind_group_layout,
292 entries: &[
293 wgpu::BindGroupEntry {
295 binding: 0,
296 resource: camera.buffer().as_entire_binding(),
297 },
298 wgpu::BindGroupEntry {
300 binding: 1,
301 resource: model_transform.buffer().as_entire_binding(),
302 },
303 wgpu::BindGroupEntry {
305 binding: 2,
306 resource: gaussian_transform.buffer().as_entire_binding(),
307 },
308 wgpu::BindGroupEntry {
310 binding: 3,
311 resource: gaussians.buffer().as_entire_binding(),
312 },
313 wgpu::BindGroupEntry {
315 binding: 4,
316 resource: indirect_args.buffer().as_entire_binding(),
317 },
318 wgpu::BindGroupEntry {
320 binding: 5,
321 resource: radix_sort_indirect_args.buffer().as_entire_binding(),
322 },
323 wgpu::BindGroupEntry {
325 binding: 6,
326 resource: indirect_indices.buffer().as_entire_binding(),
327 },
328 wgpu::BindGroupEntry {
330 binding: 7,
331 resource: gaussians_depth.buffer().as_entire_binding(),
332 },
333 #[cfg(feature = "viewer-selection")]
335 wgpu::BindGroupEntry {
336 binding: 8,
337 resource: selection.buffer().as_entire_binding(),
338 },
339 #[cfg(feature = "viewer-selection")]
341 wgpu::BindGroupEntry {
342 binding: 9,
343 resource: invert_selection.buffer().as_entire_binding(),
344 },
345 ],
346 })
347 }
348}
349
350impl<G: GaussianPod> Preprocessor<G, ()> {
351 pub fn new_without_bind_group(device: &wgpu::Device) -> Result<Self, PreprocessorCreateError> {
356 let main_shader: wesl::ModulePath = Preprocessor::<G>::MAIN_SHADER
357 .parse()
358 .expect("preprocess module path");
359
360 let wesl_compile_options = wesl::CompileOptions {
361 features: wesl::Features {
362 flags: G::features()
363 .into_iter()
364 .chain(std::iter::once((
365 "selection_buffer",
366 cfg!(feature = "viewer-selection"),
367 )))
368 .map(|(k, v)| (k.to_string(), v.into()))
369 .collect(),
370 ..Default::default()
371 },
372 ..Default::default()
373 };
374
375 let bind_group_layout =
376 device.create_bind_group_layout(&Preprocessor::<G>::BIND_GROUP_LAYOUT_DESCRIPTOR);
377
378 let pre_bundle = ComputeBundleBuilder::new()
379 .label(format!("Pre {}", Preprocessor::<G>::LABEL).as_str())
380 .bind_group_layout(&Preprocessor::<G>::BIND_GROUP_LAYOUT_DESCRIPTOR)
381 .entry_point("pre")
382 .main_shader(main_shader.clone())
383 .wesl_compile_options(wesl_compile_options.clone())
384 .resolver(wesl_utils::resolver())
385 .build_without_bind_groups(device)?;
386
387 let bundle = ComputeBundleBuilder::new()
388 .label(Preprocessor::<G>::LABEL)
389 .bind_group_layout(&Preprocessor::<G>::BIND_GROUP_LAYOUT_DESCRIPTOR)
390 .entry_point("main")
391 .main_shader(main_shader.clone())
392 .wesl_compile_options(wesl_compile_options.clone())
393 .resolver(wesl_utils::resolver())
394 .build_without_bind_groups(device)?;
395
396 let post_bundle = ComputeBundleBuilder::new()
397 .label(format!("Post {}", Preprocessor::<G>::LABEL).as_str())
398 .bind_group_layout(&Preprocessor::<G>::BIND_GROUP_LAYOUT_DESCRIPTOR)
399 .entry_point("post")
400 .main_shader(main_shader)
401 .wesl_compile_options(wesl_compile_options)
402 .resolver(wesl_utils::resolver())
403 .build_without_bind_groups(device)?;
404
405 log::info!("Preprocessor created");
406
407 Ok(Self {
408 bind_group_layout,
409 bind_group: (),
410 pre_bundle,
411 bundle,
412 post_bundle,
413 gaussian_pod_marker: std::marker::PhantomData,
414 })
415 }
416
417 pub fn preprocess(
419 &self,
420 encoder: &mut wgpu::CommandEncoder,
421 bind_group: &wgpu::BindGroup,
422 gaussian_count: u32,
423 ) {
424 self.pre_bundle.dispatch(encoder, 1, [bind_group]);
425
426 self.bundle.dispatch(encoder, gaussian_count, [bind_group]);
427
428 self.post_bundle.dispatch(encoder, 1, [bind_group]);
429 }
430}