1use crate::{
2 CameraBuffer, GaussiansDepthBuffer, IndirectArgsBuffer, IndirectIndicesBuffer,
3 PreprocessorCreateError, RadixSortIndirectArgsBuffer,
4 core::{
5 BufferWrapper, ComputeBundle, ComputeBundleBuilder, GaussianPod, GaussianTransformBuffer,
6 GaussiansBuffer, ModelTransformBuffer,
7 },
8 wesl_utils,
9};
10
11#[cfg(feature = "viewer-selection")]
12use crate::{editor::SelectionBuffer, selection};
13
14#[derive(Debug)]
18pub struct Preprocessor<G: GaussianPod, B = wgpu::BindGroup> {
19 #[allow(dead_code)]
21 bind_group_layout: wgpu::BindGroupLayout,
22 bind_group: B,
24 pre_bundle: ComputeBundle<()>,
26 bundle: ComputeBundle<()>,
28 post_bundle: ComputeBundle<()>,
30 gaussian_pod_marker: std::marker::PhantomData<G>,
32}
33
34impl<G: GaussianPod, B> Preprocessor<G, B> {
35 #[allow(clippy::too_many_arguments)]
37 pub fn create_bind_group(
38 &self,
39 device: &wgpu::Device,
40 camera: &CameraBuffer,
41 model_transform: &ModelTransformBuffer,
42 gaussian_transform: &GaussianTransformBuffer,
43 gaussians: &GaussiansBuffer<G>,
44 indirect_args: &IndirectArgsBuffer,
45 radix_sort_indirect_args: &RadixSortIndirectArgsBuffer,
46 indirect_indices: &IndirectIndicesBuffer,
47 gaussians_depth: &GaussiansDepthBuffer,
48 #[cfg(feature = "viewer-selection")] selection: &SelectionBuffer,
49 #[cfg(feature = "viewer-selection")]
50 invert_selection: &selection::PreprocessorInvertSelectionBuffer,
51 ) -> wgpu::BindGroup {
52 Preprocessor::create_bind_group_static(
53 device,
54 &self.bind_group_layout,
55 camera,
56 model_transform,
57 gaussian_transform,
58 gaussians,
59 indirect_args,
60 radix_sort_indirect_args,
61 indirect_indices,
62 gaussians_depth,
63 #[cfg(feature = "viewer-selection")]
64 selection,
65 #[cfg(feature = "viewer-selection")]
66 invert_selection,
67 )
68 }
69
70 pub fn workgroup_size(&self) -> u32 {
72 self.bundle.workgroup_size()
73 }
74
75 pub fn bind_group_layout(&self) -> &wgpu::BindGroupLayout {
77 &self.bind_group_layout
78 }
79
80 pub fn pre_bundle(&self) -> &ComputeBundle<()> {
82 &self.pre_bundle
83 }
84
85 pub fn bundle(&self) -> &ComputeBundle<()> {
87 &self.bundle
88 }
89
90 pub fn post_bundle(&self) -> &ComputeBundle<()> {
92 &self.post_bundle
93 }
94}
95
96impl<G: GaussianPod> Preprocessor<G> {
97 const LABEL: &str = "Preprocessor";
99
100 const MAIN_SHADER: &str = "wgpu_3dgs_viewer::preprocess";
102
103 pub const BIND_GROUP_LAYOUT_DESCRIPTOR: wgpu::BindGroupLayoutDescriptor<'static> =
105 wgpu::BindGroupLayoutDescriptor {
106 label: Some("Preprocessor Bind Group Layout"),
107 entries: &[
108 wgpu::BindGroupLayoutEntry {
110 binding: 0,
111 visibility: wgpu::ShaderStages::COMPUTE,
112 ty: wgpu::BindingType::Buffer {
113 ty: wgpu::BufferBindingType::Uniform,
114 has_dynamic_offset: false,
115 min_binding_size: None,
116 },
117 count: None,
118 },
119 wgpu::BindGroupLayoutEntry {
121 binding: 1,
122 visibility: wgpu::ShaderStages::COMPUTE,
123 ty: wgpu::BindingType::Buffer {
124 ty: wgpu::BufferBindingType::Uniform,
125 has_dynamic_offset: false,
126 min_binding_size: None,
127 },
128 count: None,
129 },
130 wgpu::BindGroupLayoutEntry {
132 binding: 2,
133 visibility: wgpu::ShaderStages::COMPUTE,
134 ty: wgpu::BindingType::Buffer {
135 ty: wgpu::BufferBindingType::Uniform,
136 has_dynamic_offset: false,
137 min_binding_size: None,
138 },
139 count: None,
140 },
141 wgpu::BindGroupLayoutEntry {
143 binding: 3,
144 visibility: wgpu::ShaderStages::COMPUTE,
145 ty: wgpu::BindingType::Buffer {
146 ty: wgpu::BufferBindingType::Storage { read_only: true },
147 has_dynamic_offset: false,
148 min_binding_size: None,
149 },
150 count: None,
151 },
152 wgpu::BindGroupLayoutEntry {
154 binding: 4,
155 visibility: wgpu::ShaderStages::COMPUTE,
156 ty: wgpu::BindingType::Buffer {
157 ty: wgpu::BufferBindingType::Storage { read_only: false },
158 has_dynamic_offset: false,
159 min_binding_size: None,
160 },
161 count: None,
162 },
163 wgpu::BindGroupLayoutEntry {
165 binding: 5,
166 visibility: wgpu::ShaderStages::COMPUTE,
167 ty: wgpu::BindingType::Buffer {
168 ty: wgpu::BufferBindingType::Storage { read_only: false },
169 has_dynamic_offset: false,
170 min_binding_size: None,
171 },
172 count: None,
173 },
174 wgpu::BindGroupLayoutEntry {
176 binding: 6,
177 visibility: wgpu::ShaderStages::COMPUTE,
178 ty: wgpu::BindingType::Buffer {
179 ty: wgpu::BufferBindingType::Storage { read_only: false },
180 has_dynamic_offset: false,
181 min_binding_size: None,
182 },
183 count: None,
184 },
185 wgpu::BindGroupLayoutEntry {
187 binding: 7,
188 visibility: wgpu::ShaderStages::COMPUTE,
189 ty: wgpu::BindingType::Buffer {
190 ty: wgpu::BufferBindingType::Storage { read_only: false },
191 has_dynamic_offset: false,
192 min_binding_size: None,
193 },
194 count: None,
195 },
196 #[cfg(feature = "viewer-selection")]
198 wgpu::BindGroupLayoutEntry {
199 binding: 8,
200 visibility: wgpu::ShaderStages::COMPUTE,
201 ty: wgpu::BindingType::Buffer {
202 ty: wgpu::BufferBindingType::Storage { read_only: true },
203 has_dynamic_offset: false,
204 min_binding_size: None,
205 },
206 count: None,
207 },
208 #[cfg(feature = "viewer-selection")]
210 wgpu::BindGroupLayoutEntry {
211 binding: 9,
212 visibility: wgpu::ShaderStages::COMPUTE,
213 ty: wgpu::BindingType::Buffer {
214 ty: wgpu::BufferBindingType::Uniform,
215 has_dynamic_offset: false,
216 min_binding_size: None,
217 },
218 count: None,
219 },
220 ],
221 };
222
223 #[allow(clippy::too_many_arguments)]
225 pub fn new(
226 device: &wgpu::Device,
227 camera: &CameraBuffer,
228 model_transform: &ModelTransformBuffer,
229 gaussian_transform: &GaussianTransformBuffer,
230 gaussians: &GaussiansBuffer<G>,
231 indirect_args: &IndirectArgsBuffer,
232 radix_sort_indirect_args: &RadixSortIndirectArgsBuffer,
233 indirect_indices: &IndirectIndicesBuffer,
234 gaussians_depth: &GaussiansDepthBuffer,
235 #[cfg(feature = "viewer-selection")] selection: &SelectionBuffer,
236 #[cfg(feature = "viewer-selection")]
237 invert_selection: &selection::PreprocessorInvertSelectionBuffer,
238 ) -> Result<Self, PreprocessorCreateError> {
239 if (device.limits().max_storage_buffer_binding_size as wgpu::BufferAddress)
240 < gaussians.buffer().size()
241 {
242 return Err(PreprocessorCreateError::ModelSizeExceedsDeviceLimit {
243 model_size: gaussians.buffer().size(),
244 device_limit: device.limits().max_storage_buffer_binding_size,
245 });
246 }
247
248 let this = Preprocessor::new_without_bind_group(device)?;
249
250 log::debug!("Creating preprocessor bind group");
251 let bind_group = this.create_bind_group(
252 device,
253 camera,
254 model_transform,
255 gaussian_transform,
256 gaussians,
257 indirect_args,
258 radix_sort_indirect_args,
259 indirect_indices,
260 gaussians_depth,
261 #[cfg(feature = "viewer-selection")]
262 selection,
263 #[cfg(feature = "viewer-selection")]
264 invert_selection,
265 );
266
267 Ok(Self {
268 bind_group_layout: this.bind_group_layout,
269 bind_group,
270 pre_bundle: this.pre_bundle,
271 bundle: this.bundle,
272 post_bundle: this.post_bundle,
273 gaussian_pod_marker: std::marker::PhantomData,
274 })
275 }
276
277 pub fn bind_group(&self) -> &wgpu::BindGroup {
279 &self.bind_group
280 }
281
282 pub fn preprocess(&self, encoder: &mut wgpu::CommandEncoder, gaussian_count: u32) {
284 self.pre_bundle.dispatch(encoder, 1, [&self.bind_group]);
285
286 self.bundle
287 .dispatch(encoder, gaussian_count, [&self.bind_group]);
288
289 self.post_bundle.dispatch(encoder, 1, [&self.bind_group]);
290 }
291
292 #[allow(clippy::too_many_arguments)]
294 fn create_bind_group_static(
295 device: &wgpu::Device,
296 bind_group_layout: &wgpu::BindGroupLayout,
297 camera: &CameraBuffer,
298 model_transform: &ModelTransformBuffer,
299 gaussian_transform: &GaussianTransformBuffer,
300 gaussians: &GaussiansBuffer<G>,
301 indirect_args: &IndirectArgsBuffer,
302 radix_sort_indirect_args: &RadixSortIndirectArgsBuffer,
303 indirect_indices: &IndirectIndicesBuffer,
304 gaussians_depth: &GaussiansDepthBuffer,
305 #[cfg(feature = "viewer-selection")] selection: &SelectionBuffer,
306 #[cfg(feature = "viewer-selection")]
307 invert_selection: &selection::PreprocessorInvertSelectionBuffer,
308 ) -> wgpu::BindGroup {
309 device.create_bind_group(&wgpu::BindGroupDescriptor {
310 label: Some("Preprocessor Bind Group"),
311 layout: bind_group_layout,
312 entries: &[
313 wgpu::BindGroupEntry {
315 binding: 0,
316 resource: camera.buffer().as_entire_binding(),
317 },
318 wgpu::BindGroupEntry {
320 binding: 1,
321 resource: model_transform.buffer().as_entire_binding(),
322 },
323 wgpu::BindGroupEntry {
325 binding: 2,
326 resource: gaussian_transform.buffer().as_entire_binding(),
327 },
328 wgpu::BindGroupEntry {
330 binding: 3,
331 resource: gaussians.buffer().as_entire_binding(),
332 },
333 wgpu::BindGroupEntry {
335 binding: 4,
336 resource: indirect_args.buffer().as_entire_binding(),
337 },
338 wgpu::BindGroupEntry {
340 binding: 5,
341 resource: radix_sort_indirect_args.buffer().as_entire_binding(),
342 },
343 wgpu::BindGroupEntry {
345 binding: 6,
346 resource: indirect_indices.buffer().as_entire_binding(),
347 },
348 wgpu::BindGroupEntry {
350 binding: 7,
351 resource: gaussians_depth.buffer().as_entire_binding(),
352 },
353 #[cfg(feature = "viewer-selection")]
355 wgpu::BindGroupEntry {
356 binding: 8,
357 resource: selection.buffer().as_entire_binding(),
358 },
359 #[cfg(feature = "viewer-selection")]
361 wgpu::BindGroupEntry {
362 binding: 9,
363 resource: invert_selection.buffer().as_entire_binding(),
364 },
365 ],
366 })
367 }
368}
369
370impl<G: GaussianPod> Preprocessor<G, ()> {
371 pub fn new_without_bind_group(device: &wgpu::Device) -> Result<Self, PreprocessorCreateError> {
376 let main_shader: wesl::ModulePath = Preprocessor::<G>::MAIN_SHADER
377 .parse()
378 .expect("preprocess module path");
379
380 let wesl_compile_options = wesl::CompileOptions {
381 features: wesl::Features {
382 flags: G::features()
383 .into_iter()
384 .chain(std::iter::once((
385 "selection_buffer",
386 cfg!(feature = "viewer-selection"),
387 )))
388 .map(|(k, v)| (k.to_string(), v.into()))
389 .collect(),
390 ..Default::default()
391 },
392 ..Default::default()
393 };
394
395 let bind_group_layout =
396 device.create_bind_group_layout(&Preprocessor::<G>::BIND_GROUP_LAYOUT_DESCRIPTOR);
397
398 let pre_bundle = ComputeBundleBuilder::new()
399 .label(format!("Pre {}", Preprocessor::<G>::LABEL).as_str())
400 .bind_group_layout(&Preprocessor::<G>::BIND_GROUP_LAYOUT_DESCRIPTOR)
401 .entry_point("pre")
402 .main_shader(main_shader.clone())
403 .wesl_compile_options(wesl_compile_options.clone())
404 .resolver(wesl_utils::resolver())
405 .build_without_bind_groups(device)?;
406
407 let bundle = ComputeBundleBuilder::new()
408 .label(Preprocessor::<G>::LABEL)
409 .bind_group_layout(&Preprocessor::<G>::BIND_GROUP_LAYOUT_DESCRIPTOR)
410 .entry_point("main")
411 .main_shader(main_shader.clone())
412 .wesl_compile_options(wesl_compile_options.clone())
413 .resolver(wesl_utils::resolver())
414 .build_without_bind_groups(device)?;
415
416 let post_bundle = ComputeBundleBuilder::new()
417 .label(format!("Post {}", Preprocessor::<G>::LABEL).as_str())
418 .bind_group_layout(&Preprocessor::<G>::BIND_GROUP_LAYOUT_DESCRIPTOR)
419 .entry_point("post")
420 .main_shader(main_shader)
421 .wesl_compile_options(wesl_compile_options)
422 .resolver(wesl_utils::resolver())
423 .build_without_bind_groups(device)?;
424
425 log::info!("Preprocessor created");
426
427 Ok(Self {
428 bind_group_layout,
429 bind_group: (),
430 pre_bundle,
431 bundle,
432 post_bundle,
433 gaussian_pod_marker: std::marker::PhantomData,
434 })
435 }
436
437 pub fn preprocess(
439 &self,
440 encoder: &mut wgpu::CommandEncoder,
441 bind_group: &wgpu::BindGroup,
442 gaussian_count: u32,
443 ) {
444 self.pre_bundle.dispatch(encoder, 1, [bind_group]);
445
446 self.bundle.dispatch(encoder, gaussian_count, [bind_group]);
447
448 self.post_bundle.dispatch(encoder, 1, [bind_group]);
449 }
450}