// screen_13/driver/compute.rs

//! Compute pipeline types.
2
3use {
4    super::{
5        DriverError,
6        device::Device,
7        shader::{DescriptorBindingMap, PipelineDescriptorInfo, Shader, align_spriv},
8    },
9    ash::vk,
10    derive_builder::{Builder, UninitializedFieldError},
11    log::{trace, warn},
12    std::{ffi::CString, ops::Deref, sync::Arc, thread::panicking},
13};
14
15/// Smart pointer handle to a [pipeline] object.
16///
17/// Also contains information about the object.
18///
19/// ## `Deref` behavior
20///
21/// `ComputePipeline` automatically dereferences to [`vk::Pipeline`] (via the [`Deref`]
22/// trait), so you can call `vk::Pipeline`'s methods on a value of type `ComputePipeline`.
23///
24/// [pipeline]: https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkPipeline.html
25/// [deref]: core::ops::Deref
26#[derive(Debug)]
27pub struct ComputePipeline {
28    pub(crate) descriptor_bindings: DescriptorBindingMap,
29    pub(crate) descriptor_info: PipelineDescriptorInfo,
30    device: Arc<Device>,
31    pub(crate) layout: vk::PipelineLayout,
32
33    /// Information used to create this object.
34    pub info: ComputePipelineInfo,
35
36    /// A descriptive name used in debugging messages.
37    pub name: Option<String>,
38
39    pipeline: vk::Pipeline,
40    pub(crate) push_constants: Option<vk::PushConstantRange>,
41}
42
43impl ComputePipeline {
44    /// Creates a new compute pipeline on the given device.
45    ///
46    /// # Panics
47    ///
48    /// If shader code is not a multiple of four bytes.
49    ///
50    /// # Examples
51    ///
52    /// Basic usage:
53    ///
54    /// ```no_run
55    /// # use std::sync::Arc;
56    /// # use ash::vk;
57    /// # use screen_13::driver::DriverError;
58    /// # use screen_13::driver::device::{Device, DeviceInfo};
59    /// # use screen_13::driver::compute::{ComputePipeline, ComputePipelineInfo};
60    /// # use screen_13::driver::shader::{Shader};
61    /// # fn main() -> Result<(), DriverError> {
62    /// # let device = Arc::new(Device::create_headless(DeviceInfo::default())?);
63    /// # let my_shader_code = [0u8; 1];
64    /// // my_shader_code is raw SPIR-V code as bytes
65    /// let shader = Shader::new_compute(my_shader_code.as_slice());
66    /// let pipeline = ComputePipeline::create(&device, ComputePipelineInfo::default(), shader)?;
67    ///
68    /// assert_ne!(*pipeline, vk::Pipeline::null());
69    /// # Ok(()) }
70    /// ```
71    #[profiling::function]
72    pub fn create(
73        device: &Arc<Device>,
74        info: impl Into<ComputePipelineInfo>,
75        shader: impl Into<Shader>,
76    ) -> Result<Self, DriverError> {
77        use std::slice::from_ref;
78
79        trace!("create");
80
81        let device = Arc::clone(device);
82        let info: ComputePipelineInfo = info.into();
83        let shader = shader.into();
84
85        // Use SPIR-V reflection to get the types and counts of all descriptors
86        let mut descriptor_bindings = shader.descriptor_bindings();
87        for (descriptor_info, _) in descriptor_bindings.values_mut() {
88            if descriptor_info.binding_count() == 0 {
89                descriptor_info.set_binding_count(info.bindless_descriptor_count);
90            }
91        }
92
93        let descriptor_info = PipelineDescriptorInfo::create(&device, &descriptor_bindings)?;
94        let descriptor_set_layouts = descriptor_info
95            .layouts
96            .values()
97            .map(|descriptor_set_layout| **descriptor_set_layout)
98            .collect::<Box<[_]>>();
99
100        unsafe {
101            let shader_module = device
102                .create_shader_module(
103                    &vk::ShaderModuleCreateInfo::default().code(align_spriv(&shader.spirv)?),
104                    None,
105                )
106                .map_err(|err| {
107                    warn!("{err}");
108
109                    DriverError::Unsupported
110                })?;
111            let entry_name = CString::new(shader.entry_name.as_bytes()).unwrap();
112            let mut stage_create_info = vk::PipelineShaderStageCreateInfo::default()
113                .module(shader_module)
114                .stage(shader.stage)
115                .name(&entry_name);
116            let specialization_info = shader.specialization_info.as_ref().map(|info| {
117                vk::SpecializationInfo::default()
118                    .map_entries(&info.map_entries)
119                    .data(&info.data)
120            });
121
122            if let Some(specialization_info) = &specialization_info {
123                stage_create_info = stage_create_info.specialization_info(specialization_info);
124            }
125
126            let mut layout_info =
127                vk::PipelineLayoutCreateInfo::default().set_layouts(&descriptor_set_layouts);
128
129            let push_constants = shader.push_constant_range();
130            if let Some(push_constants) = &push_constants {
131                layout_info = layout_info.push_constant_ranges(from_ref(push_constants));
132            }
133
134            let layout = device
135                .create_pipeline_layout(&layout_info, None)
136                .map_err(|err| {
137                    warn!("{err}");
138
139                    DriverError::Unsupported
140                })?;
141            let pipeline_info = vk::ComputePipelineCreateInfo::default()
142                .stage(stage_create_info)
143                .layout(layout);
144            let pipeline = device
145                .create_compute_pipelines(
146                    Device::pipeline_cache(&device),
147                    from_ref(&pipeline_info),
148                    None,
149                )
150                .map_err(|(_, err)| {
151                    warn!("{err}");
152
153                    DriverError::Unsupported
154                })?[0];
155
156            device.destroy_shader_module(shader_module, None);
157
158            Ok(ComputePipeline {
159                descriptor_bindings,
160                descriptor_info,
161                device,
162                info,
163                layout,
164                name: None,
165                pipeline,
166                push_constants,
167            })
168        }
169    }
170
171    /// Sets the debugging name assigned to this pipeline.
172    pub fn with_name(mut this: Self, name: impl Into<String>) -> Self {
173        this.name = Some(name.into());
174        this
175    }
176}
177
178impl Deref for ComputePipeline {
179    type Target = vk::Pipeline;
180
181    fn deref(&self) -> &Self::Target {
182        &self.pipeline
183    }
184}
185
186impl Drop for ComputePipeline {
187    #[profiling::function]
188    fn drop(&mut self) {
189        if panicking() {
190            return;
191        }
192
193        unsafe {
194            self.device.destroy_pipeline(self.pipeline, None);
195            self.device.destroy_pipeline_layout(self.layout, None);
196        }
197    }
198}
199
200/// Information used to create a [`ComputePipeline`] instance.
201#[derive(Builder, Clone, Copy, Debug, Eq, Hash, PartialEq)]
202#[builder(
203    build_fn(
204        private,
205        name = "fallible_build",
206        error = "ComputePipelineInfoBuilderError"
207    ),
208    derive(Clone, Copy, Debug),
209    pattern = "owned"
210)]
211#[non_exhaustive]
212pub struct ComputePipelineInfo {
213    /// The number of descriptors to allocate for a given binding when using bindless (unbounded)
214    /// syntax.
215    ///
216    /// The default is `8192`.
217    ///
218    /// # Examples
219    ///
220    /// Basic usage (GLSL):
221    ///
222    /// ```
223    /// # inline_spirv::inline_spirv!(r#"
224    /// #version 460 core
225    /// #extension GL_EXT_nonuniform_qualifier : require
226    ///
227    /// layout(set = 0, binding = 0, rgba8) writeonly uniform image2D my_binding[];
228    ///
229    /// void main()
230    /// {
231    ///     // my_binding will have space for 8,192 images by default
232    /// }
233    /// # "#, comp);
234    /// ```
235    #[builder(default = "8192")]
236    pub bindless_descriptor_count: u32,
237}
238
239impl ComputePipelineInfo {
240    /// Converts a `ComputePipelineInfo` into a `ComputePipelineInfoBuilder`.
241    #[inline(always)]
242    pub fn to_builder(self) -> ComputePipelineInfoBuilder {
243        ComputePipelineInfoBuilder {
244            bindless_descriptor_count: Some(self.bindless_descriptor_count),
245        }
246    }
247}
248
249impl Default for ComputePipelineInfo {
250    fn default() -> Self {
251        Self {
252            bindless_descriptor_count: 8192,
253        }
254    }
255}
256
257impl From<ComputePipelineInfoBuilder> for ComputePipelineInfo {
258    fn from(info: ComputePipelineInfoBuilder) -> Self {
259        info.build()
260    }
261}
262
263impl ComputePipelineInfoBuilder {
264    /// Builds a new `ComputePipelineInfo`.
265    #[inline(always)]
266    pub fn build(self) -> ComputePipelineInfo {
267        let res = self.fallible_build();
268
269        #[cfg(test)]
270        let res = res.unwrap();
271
272        #[cfg(not(test))]
273        let res = unsafe { res.unwrap_unchecked() };
274
275        res
276    }
277}
278
/// Error type returned by the builder's private `fallible_build` function.
///
/// This is a unit struct, so it carries no detail about what went wrong.
#[derive(Debug)]
struct ComputePipelineInfoBuilderError;
281
282impl From<UninitializedFieldError> for ComputePipelineInfoBuilderError {
283    fn from(_: UninitializedFieldError) -> Self {
284        Self
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    type Info = ComputePipelineInfo;
    type Builder = ComputePipelineInfoBuilder;

    #[test]
    pub fn compute_pipeline_info() {
        let info = Info::default();
        let builder = info.to_builder().build();

        assert_eq!(info, builder);
    }

    #[test]
    pub fn compute_pipeline_info_builder() {
        let info = Info::default();
        let builder = Builder::default().build();

        assert_eq!(info, builder);
    }

    /// The documented default for bindless descriptor counts.
    #[test]
    pub fn compute_pipeline_info_default_bindless_count() {
        assert_eq!(Info::default().bindless_descriptor_count, 8192);
        assert_eq!(Builder::default().build().bindless_descriptor_count, 8192);
    }

    /// A non-default value must survive `to_builder` followed by `build`.
    ///
    /// The default-only round trip cannot tell a faithful copy apart from a builder that simply
    /// falls back to its default.
    #[test]
    pub fn compute_pipeline_info_round_trip_custom() {
        let info = Builder::default().bindless_descriptor_count(16).build();

        assert_eq!(info.bindless_descriptor_count, 16);
        assert_eq!(info, info.to_builder().build());
        assert_eq!(Info::from(info.to_builder()), info);
    }
}