simple_wgpu/
compute_pipeline.rs

1use wgpu::PipelineCompilationOptions;
2
3use crate::{
4    bind_group::BindGroup, context::Context, pipeline_layout::PipelineLayout, shader::EntryPoint,
5};
6
7/// A compute pipeline
8///
9/// Loosely equivalent to [wgpu::ComputePipeline]
10#[derive(Clone, Debug)]
11pub struct ComputePipeline {
12    entry_point: EntryPoint,
13    label: Option<String>,
14}
15
16#[derive(Clone, Hash, PartialEq, Eq)]
17pub(crate) struct ComputePipelineCacheKey {
18    layout: PipelineLayout,
19    entry_point: EntryPoint,
20}
21
22impl ComputePipeline {
23    pub(crate) fn get_or_build(
24        &self,
25        context: &Context,
26        bind_groups: &[BindGroup],
27    ) -> wgpu::ComputePipeline {
28        let layout = PipelineLayout {
29            bind_group_layouts: bind_groups.iter().map(|b| b.build_layout()).collect(),
30        };
31
32        let key = ComputePipelineCacheKey {
33            layout: layout.clone(),
34            entry_point: self.entry_point.clone(),
35        };
36
37        let mut pipeline_cache = context.caches.compute_pipeline_cache.borrow_mut();
38
39        pipeline_cache
40            .get_or_insert_with(key, || {
41                let layout = layout.get_or_build(context);
42
43                context
44                    .device()
45                    .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
46                        layout: Some(&layout),
47                        module: &self.entry_point.shader,
48                        entry_point: Some(&self.entry_point.entry_point),
49                        label: self.label.as_deref(),
50                        cache: None,
51                        compilation_options: PipelineCompilationOptions::default(),
52                    })
53            })
54            .clone()
55    }
56}
57
58/// Builds a [ComputePipeline]
59#[derive(Clone)]
60pub struct ComputePipelineBuilder {
61    entry_point: EntryPoint,
62    label: Option<String>,
63}
64
65impl ComputePipelineBuilder {
66    pub fn with_entry_point(entry_point: &EntryPoint) -> Self {
67        Self {
68            entry_point: entry_point.clone(),
69            label: None,
70        }
71    }
72
73    /// Set the optional debug name. This may appear in error messages and GPU profiler traces
74    pub fn label(mut self, label: &str) -> Self {
75        self.label = Some(label.into());
76        self
77    }
78
79    pub fn build(self) -> ComputePipeline {
80        ComputePipeline {
81            entry_point: self.entry_point,
82            label: self.label,
83        }
84    }
85}