wgpu_3dgs_core/
compute_bundle.rs

1use crate::{ComputeBundleBuildError, ComputeBundleCreateError};
2
/// Build a debug label of the form `"<label> <component>"` for a GPU object owned by a
/// compute bundle.
///
/// `$label` is any expression with an `as_deref()` yielding `Option<&str>` (for example an
/// `Option<String>` field or an `Option<&str>` argument); when it is `None` the prefix
/// `"Compute Bundle"` is used instead. `$component` is anything implementing `Display`.
macro_rules! label_for_components {
    ($label:expr, $component:expr) => {
        format!(
            "{prefix} {component}",
            prefix = $label.as_deref().unwrap_or("Compute Bundle"),
            component = $component,
        )
    };
}
12
13/// A bundle of [`wgpu::ComputePipeline`], its [`wgpu::BindGroupLayout`]
14/// and optionally [`wgpu::BindGroup`].
15///
16/// ## Overview
17///
18/// This is an abstraction of a compute pipeline with its associated resources, so that any
19/// compute operations can be easily setup and dispatched.
20///
21/// It is recommended to use [`ComputeBundleBuilder`] to create a compute bundle
22///
23/// ## Shader Format
24///
25/// The compute shader is suggested to be in the following form:
26///
27/// ```wgsl
28/// override workgroup_size: u32;
29///
30/// @compute @workgroup_size(workgroup_size)
31/// fn main(@builtin(global_invocation_id) id: vec3<u32>) {
32///     let index = id.x;
33///
34///     if index >= arrayLength(&data) {
35///         return;
36///     }
37///
38///     // Do something with `data[index]`
39/// }
40/// ```
41///
42/// - `workgroup_size` is an overridable variable of type `u32`.
43/// - The entry point function (here `main`) must have the `@compute` attribute and a
44///   `@workgroup_size(workgroup_size)` attribute.
45/// - The entry point function is suggested to have a parameter with
46///   [`@builtin(global_invocation_id)`](https://www.w3.org/TR/WGSL/#global-invocation-id-builtin-value)
47///   attribute to get the global invocation ID for indexing into the data.
48#[derive(Debug, Clone)]
49pub struct ComputeBundle<B = wgpu::BindGroup> {
50    /// The label of the compute bundle.
51    label: Option<String>,
52    /// The workgroup size.
53    workgroup_size: u32,
54    /// The bind group layouts.
55    bind_group_layouts: Vec<wgpu::BindGroupLayout>,
56    /// The bind groups.
57    bind_groups: Vec<B>,
58    /// The compute pipeline.
59    pipeline: wgpu::ComputePipeline,
60}
61
62impl<B> ComputeBundle<B> {
63    /// Create the bind group at the given index.
64    ///
65    /// `index` refers to the index in [`ComputeBundle::bind_group_layouts`].
66    ///
67    /// Returns [`None`] if the `index` is out of bounds.
68    ///
69    /// As a good practice, if you are designing API for others to use, do not let the user
70    /// create bind groups manually as they will have to make sure the binding resources match
71    /// the layout.
72    pub fn create_bind_group<'a>(
73        &self,
74        device: &wgpu::Device,
75        index: usize,
76        resources: impl IntoIterator<Item = wgpu::BindingResource<'a>>,
77    ) -> Option<wgpu::BindGroup> {
78        Some(ComputeBundle::create_bind_group_static(
79            self.label.as_deref(),
80            device,
81            index,
82            self.bind_group_layouts().get(index)?,
83            resources,
84        ))
85    }
86
87    /// Get the number of invocations in one workgroup.
88    pub fn workgroup_size(&self) -> u32 {
89        self.workgroup_size
90    }
91
92    /// Get the label.
93    pub fn label(&self) -> Option<&str> {
94        self.label.as_deref()
95    }
96
97    /// Get the bind group layouts.
98    ///
99    /// This corresponds to the `bind_group_layout_descriptors` provided
100    /// when creating the compute bundle.
101    pub fn bind_group_layouts(&self) -> &[wgpu::BindGroupLayout] {
102        &self.bind_group_layouts
103    }
104
105    /// Get the compute pipeline.
106    pub fn pipeline(&self) -> &wgpu::ComputePipeline {
107        &self.pipeline
108    }
109
110    /// Dispatch the compute bundle for `count` instances with provided bind group.
111    ///
112    /// Each bind group in `bind_groups` corresponds to the bind group layout
113    /// at the same index in [`ComputeBundle::bind_group_layouts`].
114    pub fn dispatch_with_bind_groups<'a>(
115        &self,
116        encoder: &mut wgpu::CommandEncoder,
117        bind_groups: impl IntoIterator<Item = &'a wgpu::BindGroup>,
118        count: u32,
119    ) {
120        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
121            label: Some(label_for_components!(self.label, "Compute Pass").as_str()),
122            timestamp_writes: None,
123        });
124
125        pass.set_pipeline(&self.pipeline);
126
127        for (i, group) in bind_groups.into_iter().enumerate() {
128            pass.set_bind_group(i as u32, group, &[]);
129        }
130
131        pass.dispatch_workgroups(count.div_ceil(self.workgroup_size()), 1, 1);
132    }
133}
134
135impl ComputeBundle {
136    /// Create a new compute bundle.
137    ///
138    /// `shader_source` requires an overridable variable `workgroup_size` of `u32`, see docs of
139    /// [`ComputeBundle`].
140    #[allow(clippy::too_many_arguments)]
141    pub fn new<'a, 'b>(
142        label: Option<&str>,
143        device: &wgpu::Device,
144        bind_group_layout_descriptors: impl IntoIterator<Item = &'a wgpu::BindGroupLayoutDescriptor<'a>>,
145        resources: impl IntoIterator<Item = impl IntoIterator<Item = wgpu::BindingResource<'a>>>,
146        compilation_options: wgpu::PipelineCompilationOptions,
147        shader_source: wgpu::ShaderSource,
148        entry_point: &str,
149        workgroup_size: Option<u32>,
150    ) -> Result<Self, ComputeBundleCreateError> {
151        let this = ComputeBundle::new_without_bind_groups(
152            label,
153            device,
154            bind_group_layout_descriptors,
155            compilation_options,
156            shader_source,
157            entry_point,
158            workgroup_size,
159        )?;
160
161        let resources = resources.into_iter().collect::<Vec<_>>();
162
163        if resources.len() != this.bind_group_layouts.len() {
164            return Err(ComputeBundleCreateError::ResourceCountMismatch {
165                resource_count: resources.len(),
166                bind_group_layout_count: this.bind_group_layouts.len(),
167            });
168        }
169
170        log::debug!("Creating {} bind groups", label.unwrap_or("compute bundle"));
171        let bind_groups = this
172            .bind_group_layouts
173            .iter()
174            .zip(resources)
175            .enumerate()
176            .map(|(i, (layout, resources))| {
177                ComputeBundle::create_bind_group_static(this.label(), device, i, layout, resources)
178            })
179            .collect::<Vec<_>>();
180
181        Ok(Self {
182            label: label.map(String::from),
183            workgroup_size: this.workgroup_size,
184            bind_group_layouts: this.bind_group_layouts,
185            bind_groups,
186            pipeline: this.pipeline,
187        })
188    }
189
190    /// Get the bind groups.
191    pub fn bind_groups(&self) -> &[wgpu::BindGroup] {
192        &self.bind_groups
193    }
194
195    /// Dispatch the compute bundle for `count` instances.
196    pub fn dispatch(&self, encoder: &mut wgpu::CommandEncoder, count: u32) {
197        self.dispatch_with_bind_groups(encoder, self.bind_groups(), count);
198    }
199
200    /// Update the bind group at `index`.
201    ///
202    /// Returns [`Some`] of the previous bind group if it was updated,
203    /// or [`None`] if the index is out of bounds.
204    pub fn update_bind_group(
205        &mut self,
206        index: usize,
207        bind_group: wgpu::BindGroup,
208    ) -> Option<wgpu::BindGroup> {
209        if index >= self.bind_groups.len() {
210            return None;
211        }
212
213        Some(std::mem::replace(&mut self.bind_groups[index], bind_group))
214    }
215
216    /// Update the bind group at `index` with the provided binding resources.
217    ///
218    /// Returns [`Some`] of the previous bind group if it was updated,
219    /// or [`None`] if the index is out of bounds.
220    pub fn update_bind_group_with_binding_resources<'a>(
221        &mut self,
222        device: &wgpu::Device,
223        index: usize,
224        resources: impl IntoIterator<Item = wgpu::BindingResource<'a>>,
225    ) -> Option<wgpu::BindGroup> {
226        let bind_group = self.create_bind_group(device, index, resources)?;
227        self.update_bind_group(index, bind_group)
228    }
229
230    /// Create a bind group statically.
231    ///
232    /// `index` is only for labeling.
233    fn create_bind_group_static<'a>(
234        label: Option<&str>,
235        device: &wgpu::Device,
236        index: usize,
237        bind_group_layout: &wgpu::BindGroupLayout,
238        resources: impl IntoIterator<Item = wgpu::BindingResource<'a>>,
239    ) -> wgpu::BindGroup {
240        device.create_bind_group(&wgpu::BindGroupDescriptor {
241            label: Some(label_for_components!(label, format!("Bind Group {index}")).as_str()),
242            layout: bind_group_layout,
243            entries: &resources
244                .into_iter()
245                .enumerate()
246                .map(|(i, resource)| wgpu::BindGroupEntry {
247                    binding: i as u32,
248                    resource,
249                })
250                .collect::<Vec<_>>(),
251        })
252    }
253}
254
255impl ComputeBundle<()> {
256    /// Create a new compute bundle without internally managed bind group.
257    ///
258    /// To create a bind group with layout matched to one of the layout in this compute bundle,
259    /// use the [`ComputeBundle::create_bind_group`] method.
260    pub fn new_without_bind_groups<'a>(
261        label: Option<&str>,
262        device: &wgpu::Device,
263        bind_group_layout_descriptors: impl IntoIterator<Item = &'a wgpu::BindGroupLayoutDescriptor<'a>>,
264        compilation_options: wgpu::PipelineCompilationOptions,
265        shader_source: wgpu::ShaderSource,
266        entry_point: &str,
267        workgroup_size: Option<u32>,
268    ) -> Result<Self, ComputeBundleCreateError> {
269        let workgroup_size_limit = device
270            .limits()
271            .max_compute_workgroup_size_x
272            .min(device.limits().max_compute_invocations_per_workgroup);
273
274        let workgroup_size = workgroup_size.unwrap_or(workgroup_size_limit);
275
276        if workgroup_size > workgroup_size_limit {
277            return Err(ComputeBundleCreateError::WorkgroupSizeExceedsDeviceLimit {
278                workgroup_size,
279                device_limit: workgroup_size_limit,
280            });
281        }
282
283        log::debug!(
284            "Creating {} bind group layouts",
285            label.unwrap_or("compute bundle")
286        );
287        let bind_group_layouts = bind_group_layout_descriptors
288            .into_iter()
289            .map(|desc| device.create_bind_group_layout(desc))
290            .collect::<Vec<_>>();
291
292        log::debug!(
293            "Creating {} pipeline layout",
294            label.unwrap_or("compute bundle"),
295        );
296        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
297            label: Some(label_for_components!(label, "Pipeline Layout").as_str()),
298            bind_group_layouts: &bind_group_layouts.iter().collect::<Vec<_>>(),
299            push_constant_ranges: &[],
300        });
301
302        log::debug!(
303            "Creating {} shader module",
304            label.unwrap_or("compute bundle"),
305        );
306        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
307            label: Some(label_for_components!(label, "Shader").as_str()),
308            source: shader_source,
309        });
310
311        let constants = [
312            &[("workgroup_size", workgroup_size as f64)],
313            compilation_options.constants,
314        ]
315        .concat();
316
317        let compilation_options = wgpu::PipelineCompilationOptions {
318            constants: &constants,
319            zero_initialize_workgroup_memory: compilation_options.zero_initialize_workgroup_memory,
320        };
321
322        log::debug!("Creating {} pipeline", label.unwrap_or("compute bundle"),);
323        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
324            label: Some(label_for_components!(label, "Pipeline").as_str()),
325            layout: Some(&pipeline_layout),
326            module: &shader,
327            entry_point: Some(entry_point),
328            compilation_options: compilation_options.clone(),
329            cache: None,
330        });
331
332        log::info!("{} created", label.unwrap_or("Compute Bundle"));
333
334        Ok(Self {
335            label: label.map(String::from),
336            workgroup_size,
337            bind_group_layouts,
338            bind_groups: Vec::new(),
339            pipeline,
340        })
341    }
342
343    /// Dispatch the compute bundle for `count` instances.
344    pub fn dispatch<'a>(
345        &self,
346        encoder: &mut wgpu::CommandEncoder,
347        count: u32,
348        bind_groups: impl IntoIterator<Item = &'a wgpu::BindGroup>,
349    ) {
350        self.dispatch_with_bind_groups(encoder, bind_groups, count);
351    }
352}
353
354/// A builder for [`ComputeBundle`].
355///
356/// The shader is compiled using the WESL compiler,
357///
358/// The following fields should be set before calling [`ComputeBundleBuilder::build`] or
359/// [`ComputeBundleBuilder::build_without_bind_groups`]:
360/// - [`ComputeBundleBuilder::bind_group_layouts`]
361/// - [`ComputeBundleBuilder::resolver`]
362/// - [`ComputeBundleBuilder::entry_point`]
363/// - [`ComputeBundleBuilder::main_shader`]
364pub struct ComputeBundleBuilder<'a, R: wesl::Resolver = wesl::StandardResolver> {
365    pub label: Option<&'a str>,
366    pub bind_group_layouts: Vec<&'a wgpu::BindGroupLayoutDescriptor<'a>>,
367    pub pipeline_compile_options: wgpu::PipelineCompilationOptions<'a>,
368    pub entry_point: Option<&'a str>,
369    pub main_shader: Option<wesl::ModulePath>,
370    pub wesl_compile_options: wesl::CompileOptions,
371    pub resolver: Option<R>,
372    pub mangler: Box<dyn wesl::Mangler + Send + Sync + 'static>,
373    pub workgroup_size: Option<u32>,
374}
375
376impl ComputeBundleBuilder<'_> {
377    /// Create a new compute bundle builder.
378    pub fn new() -> Self {
379        Self {
380            label: None,
381            bind_group_layouts: Vec::new(),
382            pipeline_compile_options: wgpu::PipelineCompilationOptions::default(),
383            entry_point: None,
384            main_shader: None,
385            wesl_compile_options: wesl::CompileOptions::default(),
386            resolver: None,
387            mangler: Box::new(wesl::NoMangler),
388            workgroup_size: None,
389        }
390    }
391}
392
393impl<'a, R: wesl::Resolver> ComputeBundleBuilder<'a, R> {
394    /// Set the label of the compute bundle.
395    pub fn label(mut self, label: impl Into<&'a str>) -> Self {
396        self.label = Some(label.into());
397        self
398    }
399
400    /// Add a [`wgpu::BindGroupLayoutDescriptor`].
401    pub fn bind_group_layout(
402        mut self,
403        bind_group_layout: &'a wgpu::BindGroupLayoutDescriptor<'a>,
404    ) -> Self {
405        self.bind_group_layouts.push(bind_group_layout);
406        self
407    }
408
409    /// Add [`wgpu::BindGroupLayoutDescriptor`]s.
410    pub fn bind_group_layouts(
411        mut self,
412        bind_group_layouts: impl IntoIterator<Item = &'a wgpu::BindGroupLayoutDescriptor<'a>>,
413    ) -> Self {
414        self.bind_group_layouts.extend(bind_group_layouts);
415        self
416    }
417
418    /// Set the [`wgpu::PipelineCompilationOptions`].
419    pub fn pipeline_compile_options(
420        mut self,
421        compilation_options: wgpu::PipelineCompilationOptions<'a>,
422    ) -> Self {
423        self.pipeline_compile_options = compilation_options;
424        self
425    }
426
427    /// Set the entry point of the compute shader.
428    ///
429    /// This should be the function name of the entry point in the compute shader, which is
430    /// passed into [`wgpu::ComputePipelineDescriptor::entry_point`].
431    pub fn entry_point(mut self, main: &'a str) -> Self {
432        self.entry_point = Some(main);
433        self
434    }
435
436    /// Set the main shader of the compute bundle.
437    ///
438    /// The shader is required to have an overridable variable `workgroup_size` of `u32`, which is
439    /// set to the workgroup size of at the entry point of the compute pipeline.
440    pub fn main_shader(self, main: wesl::ModulePath) -> ComputeBundleBuilder<'a, R> {
441        ComputeBundleBuilder {
442            label: self.label,
443            bind_group_layouts: self.bind_group_layouts,
444            pipeline_compile_options: self.pipeline_compile_options,
445            entry_point: self.entry_point,
446            main_shader: Some(main),
447            wesl_compile_options: self.wesl_compile_options,
448            resolver: self.resolver,
449            mangler: self.mangler,
450            workgroup_size: self.workgroup_size,
451        }
452    }
453
454    /// Set the [`wesl::CompileOptions`].
455    pub fn wesl_compile_options(mut self, options: wesl::CompileOptions) -> Self {
456        self.wesl_compile_options = options;
457        self
458    }
459
460    /// Set the [`wesl::Resolver`].
461    pub fn resolver<S: wesl::Resolver>(self, resolver: S) -> ComputeBundleBuilder<'a, S> {
462        ComputeBundleBuilder {
463            label: self.label,
464            bind_group_layouts: self.bind_group_layouts,
465            pipeline_compile_options: self.pipeline_compile_options,
466            entry_point: self.entry_point,
467            main_shader: self.main_shader,
468            wesl_compile_options: self.wesl_compile_options,
469            resolver: Some(resolver),
470            mangler: self.mangler,
471            workgroup_size: self.workgroup_size,
472        }
473    }
474
475    /// Set the [`wesl::Mangler`].
476    pub fn mangler(
477        self,
478        mangler: impl wesl::Mangler + Send + Sync + 'static,
479    ) -> ComputeBundleBuilder<'a, R> {
480        ComputeBundleBuilder {
481            label: self.label,
482            bind_group_layouts: self.bind_group_layouts,
483            pipeline_compile_options: self.pipeline_compile_options,
484            entry_point: self.entry_point,
485            main_shader: self.main_shader,
486            wesl_compile_options: self.wesl_compile_options,
487            resolver: self.resolver,
488            mangler: Box::new(mangler),
489            workgroup_size: self.workgroup_size,
490        }
491    }
492
493    /// Set the workgroup size.
494    pub fn workgroup_size(mut self, workgroup_size: u32) -> Self {
495        self.workgroup_size = Some(workgroup_size);
496        self
497    }
498
499    /// Build the compute bundle with bindings.
500    pub fn build<'b>(
501        self,
502        device: &wgpu::Device,
503        resources: impl IntoIterator<Item = impl IntoIterator<Item = wgpu::BindingResource<'a>>>,
504    ) -> Result<ComputeBundle<wgpu::BindGroup>, ComputeBundleBuildError> {
505        if self.bind_group_layouts.is_empty() {
506            return Err(ComputeBundleBuildError::MissingBindGroupLayout);
507        }
508
509        let Some(resolver) = self.resolver else {
510            return Err(ComputeBundleBuildError::MissingResolver);
511        };
512
513        let Some(entry_point) = self.entry_point else {
514            return Err(ComputeBundleBuildError::MissingEntryPoint);
515        };
516
517        let Some(main_shader) = self.main_shader else {
518            return Err(ComputeBundleBuildError::MissingMainShader);
519        };
520
521        let shader_source = wgpu::ShaderSource::Wgsl(
522            wesl::compile_sourcemap(
523                &main_shader,
524                &resolver,
525                &self.mangler,
526                &self.wesl_compile_options,
527            )?
528            .to_string()
529            .into(),
530        );
531
532        ComputeBundle::new(
533            self.label,
534            device,
535            self.bind_group_layouts.into_iter().collect::<Vec<_>>(),
536            resources,
537            self.pipeline_compile_options,
538            shader_source,
539            entry_point,
540            self.workgroup_size,
541        )
542        .map_err(Into::into)
543    }
544
545    /// Build the compute bundle without bind groups.
546    pub fn build_without_bind_groups(
547        self,
548        device: &wgpu::Device,
549    ) -> Result<ComputeBundle<()>, ComputeBundleBuildError> {
550        if self.bind_group_layouts.is_empty() {
551            return Err(ComputeBundleBuildError::MissingBindGroupLayout);
552        }
553
554        let Some(resolver) = self.resolver else {
555            return Err(ComputeBundleBuildError::MissingResolver);
556        };
557
558        let Some(entry_point) = self.entry_point else {
559            return Err(ComputeBundleBuildError::MissingEntryPoint);
560        };
561
562        let Some(main_shader) = self.main_shader else {
563            return Err(ComputeBundleBuildError::MissingMainShader);
564        };
565
566        let shader_source = wgpu::ShaderSource::Wgsl(
567            wesl::compile_sourcemap(
568                &main_shader,
569                &resolver,
570                &self.mangler,
571                &self.wesl_compile_options,
572            )?
573            .to_string()
574            .into(),
575        );
576
577        Ok(ComputeBundle::new_without_bind_groups(
578            self.label,
579            device,
580            self.bind_group_layouts.into_iter().collect::<Vec<_>>(),
581            self.pipeline_compile_options,
582            shader_source,
583            entry_point,
584            self.workgroup_size,
585        )?)
586    }
587}
588
589impl Default for ComputeBundleBuilder<'_> {
590    fn default() -> Self {
591        Self::new()
592    }
593}