1use crate::{ComputeBundleBuildError, ComputeBundleCreateError};
2
/// Builds a debug label of the form `"<label> <component>"`.
///
/// `$label` is any `Option` whose contents deref to `str` (e.g. `Option<String>` or
/// `Option<&str>`); when it is `None` the prefix falls back to `"Compute Bundle"`.
/// `$component` is anything implementing `Display`.
macro_rules! label_for_components {
    ($label:expr, $component:expr) => {
        ::std::format!(
            "{prefix} {component}",
            prefix = $label.as_deref().unwrap_or("Compute Bundle"),
            component = $component,
        )
    };
}
12
/// A compute pipeline bundled with its bind group layouts and, optionally, its bind groups.
///
/// `B` is the element type of the stored bind groups. With the default
/// `wgpu::BindGroup`, `ComputeBundle::new` creates one bind group per layout.
/// `ComputeBundle<()>` is produced by `ComputeBundle::new_without_bind_groups` and
/// stores no bind groups; those are supplied by the caller at dispatch time.
#[derive(Debug, Clone)]
pub struct ComputeBundle<B = wgpu::BindGroup> {
    /// Optional label, used as the prefix of the label of every wgpu object created
    /// for this bundle (falls back to "Compute Bundle" when `None`).
    label: Option<String>,
    /// Invocations per workgroup. It is passed to the shader as the `workgroup_size`
    /// pipeline constant and used as the divisor when computing the dispatch count.
    workgroup_size: u32,
    /// Bind group layouts, in bind group index order (the order used to build the
    /// pipeline layout).
    bind_group_layouts: Vec<wgpu::BindGroupLayout>,
    /// Bind groups matching `bind_group_layouts` by index; empty for `ComputeBundle<()>`.
    bind_groups: Vec<B>,
    /// The compute pipeline.
    pipeline: wgpu::ComputePipeline,
}
61
62impl<B> ComputeBundle<B> {
63 pub fn create_bind_group<'a>(
73 &self,
74 device: &wgpu::Device,
75 index: usize,
76 resources: impl IntoIterator<Item = wgpu::BindingResource<'a>>,
77 ) -> Option<wgpu::BindGroup> {
78 Some(ComputeBundle::create_bind_group_static(
79 self.label.as_deref(),
80 device,
81 index,
82 self.bind_group_layouts().get(index)?,
83 resources,
84 ))
85 }
86
87 pub fn workgroup_size(&self) -> u32 {
89 self.workgroup_size
90 }
91
92 pub fn label(&self) -> Option<&str> {
94 self.label.as_deref()
95 }
96
97 pub fn bind_group_layouts(&self) -> &[wgpu::BindGroupLayout] {
102 &self.bind_group_layouts
103 }
104
105 pub fn pipeline(&self) -> &wgpu::ComputePipeline {
107 &self.pipeline
108 }
109
110 pub fn dispatch_with_bind_groups<'a>(
115 &self,
116 encoder: &mut wgpu::CommandEncoder,
117 bind_groups: impl IntoIterator<Item = &'a wgpu::BindGroup>,
118 count: u32,
119 ) {
120 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
121 label: Some(label_for_components!(self.label, "Compute Pass").as_str()),
122 timestamp_writes: None,
123 });
124
125 pass.set_pipeline(&self.pipeline);
126
127 for (i, group) in bind_groups.into_iter().enumerate() {
128 pass.set_bind_group(i as u32, group, &[]);
129 }
130
131 pass.dispatch_workgroups(count.div_ceil(self.workgroup_size()), 1, 1);
132 }
133}
134
135impl ComputeBundle {
136 #[allow(clippy::too_many_arguments)]
141 pub fn new<'a, 'b>(
142 label: Option<&str>,
143 device: &wgpu::Device,
144 bind_group_layout_descriptors: impl IntoIterator<Item = &'a wgpu::BindGroupLayoutDescriptor<'a>>,
145 resources: impl IntoIterator<Item = impl IntoIterator<Item = wgpu::BindingResource<'a>>>,
146 compilation_options: wgpu::PipelineCompilationOptions,
147 shader_source: wgpu::ShaderSource,
148 entry_point: &str,
149 workgroup_size: Option<u32>,
150 ) -> Result<Self, ComputeBundleCreateError> {
151 let this = ComputeBundle::new_without_bind_groups(
152 label,
153 device,
154 bind_group_layout_descriptors,
155 compilation_options,
156 shader_source,
157 entry_point,
158 workgroup_size,
159 )?;
160
161 let resources = resources.into_iter().collect::<Vec<_>>();
162
163 if resources.len() != this.bind_group_layouts.len() {
164 return Err(ComputeBundleCreateError::ResourceCountMismatch {
165 resource_count: resources.len(),
166 bind_group_layout_count: this.bind_group_layouts.len(),
167 });
168 }
169
170 log::debug!("Creating {} bind groups", label.unwrap_or("compute bundle"));
171 let bind_groups = this
172 .bind_group_layouts
173 .iter()
174 .zip(resources)
175 .enumerate()
176 .map(|(i, (layout, resources))| {
177 ComputeBundle::create_bind_group_static(this.label(), device, i, layout, resources)
178 })
179 .collect::<Vec<_>>();
180
181 Ok(Self {
182 label: label.map(String::from),
183 workgroup_size: this.workgroup_size,
184 bind_group_layouts: this.bind_group_layouts,
185 bind_groups,
186 pipeline: this.pipeline,
187 })
188 }
189
190 pub fn bind_groups(&self) -> &[wgpu::BindGroup] {
192 &self.bind_groups
193 }
194
195 pub fn dispatch(&self, encoder: &mut wgpu::CommandEncoder, count: u32) {
197 self.dispatch_with_bind_groups(encoder, self.bind_groups(), count);
198 }
199
200 pub fn update_bind_group(
205 &mut self,
206 index: usize,
207 bind_group: wgpu::BindGroup,
208 ) -> Option<wgpu::BindGroup> {
209 if index >= self.bind_groups.len() {
210 return None;
211 }
212
213 Some(std::mem::replace(&mut self.bind_groups[index], bind_group))
214 }
215
216 pub fn update_bind_group_with_binding_resources<'a>(
221 &mut self,
222 device: &wgpu::Device,
223 index: usize,
224 resources: impl IntoIterator<Item = wgpu::BindingResource<'a>>,
225 ) -> Option<wgpu::BindGroup> {
226 let bind_group = self.create_bind_group(device, index, resources)?;
227 self.update_bind_group(index, bind_group)
228 }
229
230 fn create_bind_group_static<'a>(
234 label: Option<&str>,
235 device: &wgpu::Device,
236 index: usize,
237 bind_group_layout: &wgpu::BindGroupLayout,
238 resources: impl IntoIterator<Item = wgpu::BindingResource<'a>>,
239 ) -> wgpu::BindGroup {
240 device.create_bind_group(&wgpu::BindGroupDescriptor {
241 label: Some(label_for_components!(label, format!("Bind Group {index}")).as_str()),
242 layout: bind_group_layout,
243 entries: &resources
244 .into_iter()
245 .enumerate()
246 .map(|(i, resource)| wgpu::BindGroupEntry {
247 binding: i as u32,
248 resource,
249 })
250 .collect::<Vec<_>>(),
251 })
252 }
253}
254
255impl ComputeBundle<()> {
256 pub fn new_without_bind_groups<'a>(
261 label: Option<&str>,
262 device: &wgpu::Device,
263 bind_group_layout_descriptors: impl IntoIterator<Item = &'a wgpu::BindGroupLayoutDescriptor<'a>>,
264 compilation_options: wgpu::PipelineCompilationOptions,
265 shader_source: wgpu::ShaderSource,
266 entry_point: &str,
267 workgroup_size: Option<u32>,
268 ) -> Result<Self, ComputeBundleCreateError> {
269 let workgroup_size_limit = device
270 .limits()
271 .max_compute_workgroup_size_x
272 .min(device.limits().max_compute_invocations_per_workgroup);
273
274 let workgroup_size = workgroup_size.unwrap_or(workgroup_size_limit);
275
276 if workgroup_size > workgroup_size_limit {
277 return Err(ComputeBundleCreateError::WorkgroupSizeExceedsDeviceLimit {
278 workgroup_size,
279 device_limit: workgroup_size_limit,
280 });
281 }
282
283 log::debug!(
284 "Creating {} bind group layouts",
285 label.unwrap_or("compute bundle")
286 );
287 let bind_group_layouts = bind_group_layout_descriptors
288 .into_iter()
289 .map(|desc| device.create_bind_group_layout(desc))
290 .collect::<Vec<_>>();
291
292 log::debug!(
293 "Creating {} pipeline layout",
294 label.unwrap_or("compute bundle"),
295 );
296 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
297 label: Some(label_for_components!(label, "Pipeline Layout").as_str()),
298 bind_group_layouts: &bind_group_layouts.iter().collect::<Vec<_>>(),
299 push_constant_ranges: &[],
300 });
301
302 log::debug!(
303 "Creating {} shader module",
304 label.unwrap_or("compute bundle"),
305 );
306 let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
307 label: Some(label_for_components!(label, "Shader").as_str()),
308 source: shader_source,
309 });
310
311 let constants = [
312 &[("workgroup_size", workgroup_size as f64)],
313 compilation_options.constants,
314 ]
315 .concat();
316
317 let compilation_options = wgpu::PipelineCompilationOptions {
318 constants: &constants,
319 zero_initialize_workgroup_memory: compilation_options.zero_initialize_workgroup_memory,
320 };
321
322 log::debug!("Creating {} pipeline", label.unwrap_or("compute bundle"),);
323 let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
324 label: Some(label_for_components!(label, "Pipeline").as_str()),
325 layout: Some(&pipeline_layout),
326 module: &shader,
327 entry_point: Some(entry_point),
328 compilation_options: compilation_options.clone(),
329 cache: None,
330 });
331
332 log::info!("{} created", label.unwrap_or("Compute Bundle"));
333
334 Ok(Self {
335 label: label.map(String::from),
336 workgroup_size,
337 bind_group_layouts,
338 bind_groups: Vec::new(),
339 pipeline,
340 })
341 }
342
343 pub fn dispatch<'a>(
345 &self,
346 encoder: &mut wgpu::CommandEncoder,
347 count: u32,
348 bind_groups: impl IntoIterator<Item = &'a wgpu::BindGroup>,
349 ) {
350 self.dispatch_with_bind_groups(encoder, bind_groups, count);
351 }
352}
353
/// Builder for a [`ComputeBundle`] whose shader is compiled from WESL.
///
/// Building requires at least one bind group layout, a resolver, an entry point and a
/// main shader. A missing one yields the matching `ComputeBundleBuildError` variant.
pub struct ComputeBundleBuilder<'a, R: wesl::Resolver = wesl::StandardResolver> {
    /// Label prefix for the wgpu objects created by the bundle.
    pub label: Option<&'a str>,
    /// Bind group layout descriptors, in bind group index order.
    pub bind_group_layouts: Vec<&'a wgpu::BindGroupLayoutDescriptor<'a>>,
    /// Options forwarded to pipeline creation. The bundle adds its own
    /// `workgroup_size` constant.
    pub pipeline_compile_options: wgpu::PipelineCompilationOptions<'a>,
    /// Name of the shader entry point (required).
    pub entry_point: Option<&'a str>,
    /// Module path of the root WESL module to compile (required).
    pub main_shader: Option<wesl::ModulePath>,
    /// Options passed to `wesl::compile_sourcemap`.
    pub wesl_compile_options: wesl::CompileOptions,
    /// Resolver passed to `wesl::compile_sourcemap` (required; `new` leaves it unset).
    pub resolver: Option<R>,
    /// Mangler passed to `wesl::compile_sourcemap`. `new` sets it to `wesl::NoMangler`.
    pub mangler: Box<dyn wesl::Mangler + Send + Sync + 'static>,
    /// Workgroup size. `None` selects the largest size the device allows.
    pub workgroup_size: Option<u32>,
}
375
376impl ComputeBundleBuilder<'_> {
377 pub fn new() -> Self {
379 Self {
380 label: None,
381 bind_group_layouts: Vec::new(),
382 pipeline_compile_options: wgpu::PipelineCompilationOptions::default(),
383 entry_point: None,
384 main_shader: None,
385 wesl_compile_options: wesl::CompileOptions::default(),
386 resolver: None,
387 mangler: Box::new(wesl::NoMangler),
388 workgroup_size: None,
389 }
390 }
391}
392
393impl<'a, R: wesl::Resolver> ComputeBundleBuilder<'a, R> {
394 pub fn label(mut self, label: impl Into<&'a str>) -> Self {
396 self.label = Some(label.into());
397 self
398 }
399
400 pub fn bind_group_layout(
402 mut self,
403 bind_group_layout: &'a wgpu::BindGroupLayoutDescriptor<'a>,
404 ) -> Self {
405 self.bind_group_layouts.push(bind_group_layout);
406 self
407 }
408
409 pub fn bind_group_layouts(
411 mut self,
412 bind_group_layouts: impl IntoIterator<Item = &'a wgpu::BindGroupLayoutDescriptor<'a>>,
413 ) -> Self {
414 self.bind_group_layouts.extend(bind_group_layouts);
415 self
416 }
417
418 pub fn pipeline_compile_options(
420 mut self,
421 compilation_options: wgpu::PipelineCompilationOptions<'a>,
422 ) -> Self {
423 self.pipeline_compile_options = compilation_options;
424 self
425 }
426
427 pub fn entry_point(mut self, main: &'a str) -> Self {
432 self.entry_point = Some(main);
433 self
434 }
435
436 pub fn main_shader(self, main: wesl::ModulePath) -> ComputeBundleBuilder<'a, R> {
441 ComputeBundleBuilder {
442 label: self.label,
443 bind_group_layouts: self.bind_group_layouts,
444 pipeline_compile_options: self.pipeline_compile_options,
445 entry_point: self.entry_point,
446 main_shader: Some(main),
447 wesl_compile_options: self.wesl_compile_options,
448 resolver: self.resolver,
449 mangler: self.mangler,
450 workgroup_size: self.workgroup_size,
451 }
452 }
453
454 pub fn wesl_compile_options(mut self, options: wesl::CompileOptions) -> Self {
456 self.wesl_compile_options = options;
457 self
458 }
459
460 pub fn resolver<S: wesl::Resolver>(self, resolver: S) -> ComputeBundleBuilder<'a, S> {
462 ComputeBundleBuilder {
463 label: self.label,
464 bind_group_layouts: self.bind_group_layouts,
465 pipeline_compile_options: self.pipeline_compile_options,
466 entry_point: self.entry_point,
467 main_shader: self.main_shader,
468 wesl_compile_options: self.wesl_compile_options,
469 resolver: Some(resolver),
470 mangler: self.mangler,
471 workgroup_size: self.workgroup_size,
472 }
473 }
474
475 pub fn mangler(
477 self,
478 mangler: impl wesl::Mangler + Send + Sync + 'static,
479 ) -> ComputeBundleBuilder<'a, R> {
480 ComputeBundleBuilder {
481 label: self.label,
482 bind_group_layouts: self.bind_group_layouts,
483 pipeline_compile_options: self.pipeline_compile_options,
484 entry_point: self.entry_point,
485 main_shader: self.main_shader,
486 wesl_compile_options: self.wesl_compile_options,
487 resolver: self.resolver,
488 mangler: Box::new(mangler),
489 workgroup_size: self.workgroup_size,
490 }
491 }
492
493 pub fn workgroup_size(mut self, workgroup_size: u32) -> Self {
495 self.workgroup_size = Some(workgroup_size);
496 self
497 }
498
499 pub fn build<'b>(
501 self,
502 device: &wgpu::Device,
503 resources: impl IntoIterator<Item = impl IntoIterator<Item = wgpu::BindingResource<'a>>>,
504 ) -> Result<ComputeBundle<wgpu::BindGroup>, ComputeBundleBuildError> {
505 if self.bind_group_layouts.is_empty() {
506 return Err(ComputeBundleBuildError::MissingBindGroupLayout);
507 }
508
509 let Some(resolver) = self.resolver else {
510 return Err(ComputeBundleBuildError::MissingResolver);
511 };
512
513 let Some(entry_point) = self.entry_point else {
514 return Err(ComputeBundleBuildError::MissingEntryPoint);
515 };
516
517 let Some(main_shader) = self.main_shader else {
518 return Err(ComputeBundleBuildError::MissingMainShader);
519 };
520
521 let shader_source = wgpu::ShaderSource::Wgsl(
522 wesl::compile_sourcemap(
523 &main_shader,
524 &resolver,
525 &self.mangler,
526 &self.wesl_compile_options,
527 )?
528 .to_string()
529 .into(),
530 );
531
532 ComputeBundle::new(
533 self.label,
534 device,
535 self.bind_group_layouts.into_iter().collect::<Vec<_>>(),
536 resources,
537 self.pipeline_compile_options,
538 shader_source,
539 entry_point,
540 self.workgroup_size,
541 )
542 .map_err(Into::into)
543 }
544
545 pub fn build_without_bind_groups(
547 self,
548 device: &wgpu::Device,
549 ) -> Result<ComputeBundle<()>, ComputeBundleBuildError> {
550 if self.bind_group_layouts.is_empty() {
551 return Err(ComputeBundleBuildError::MissingBindGroupLayout);
552 }
553
554 let Some(resolver) = self.resolver else {
555 return Err(ComputeBundleBuildError::MissingResolver);
556 };
557
558 let Some(entry_point) = self.entry_point else {
559 return Err(ComputeBundleBuildError::MissingEntryPoint);
560 };
561
562 let Some(main_shader) = self.main_shader else {
563 return Err(ComputeBundleBuildError::MissingMainShader);
564 };
565
566 let shader_source = wgpu::ShaderSource::Wgsl(
567 wesl::compile_sourcemap(
568 &main_shader,
569 &resolver,
570 &self.mangler,
571 &self.wesl_compile_options,
572 )?
573 .to_string()
574 .into(),
575 );
576
577 Ok(ComputeBundle::new_without_bind_groups(
578 self.label,
579 device,
580 self.bind_group_layouts.into_iter().collect::<Vec<_>>(),
581 self.pipeline_compile_options,
582 shader_source,
583 entry_point,
584 self.workgroup_size,
585 )?)
586 }
587}
588
589impl Default for ComputeBundleBuilder<'_> {
590 fn default() -> Self {
591 Self::new()
592 }
593}