screen_13/driver/
compute.rs1use {
4 super::{
5 DriverError,
6 device::Device,
7 shader::{DescriptorBindingMap, PipelineDescriptorInfo, Shader, align_spriv},
8 },
9 ash::vk,
10 derive_builder::{Builder, UninitializedFieldError},
11 log::{trace, warn},
12 std::{ffi::CString, ops::Deref, sync::Arc, thread::panicking},
13};
14
/// A Vulkan compute pipeline, together with its pipeline layout and the descriptor information
/// reflected from the shader it was created from.
///
/// Created with [`ComputePipeline::create`]. Dereferences to the raw [`vk::Pipeline`] handle.
/// When dropped (and the current thread is not panicking), the pipeline and its layout are
/// destroyed using `device`.
#[derive(Debug)]
pub struct ComputePipeline {
    /// Descriptor bindings reflected from the shader. Any binding whose reflected binding count
    /// was zero has been given [`ComputePipelineInfo::bindless_descriptor_count`] instead.
    pub(crate) descriptor_bindings: DescriptorBindingMap,
    /// Descriptor info created from `descriptor_bindings`; its `layouts` supply the descriptor
    /// set layouts of `layout`.
    pub(crate) descriptor_info: PipelineDescriptorInfo,
    /// Device that created the Vulkan handles in this struct and destroys them on drop.
    device: Arc<Device>,
    /// Pipeline layout built from the descriptor set layouts and the push constant range.
    pub(crate) layout: vk::PipelineLayout,

    /// Information used to create this pipeline.
    pub info: ComputePipelineInfo,

    /// Optional name; `None` unless set with [`ComputePipeline::with_name`].
    ///
    /// NOTE(review): presumably used for debugging/logging — confirm with callers.
    pub name: Option<String>,

    /// Raw Vulkan pipeline handle (also reachable through `Deref`).
    pipeline: vk::Pipeline,
    /// Push constant range reflected from the shader, if the shader declares one.
    pub(crate) push_constants: Option<vk::PushConstantRange>,
}
42
43impl ComputePipeline {
44 #[profiling::function]
72 pub fn create(
73 device: &Arc<Device>,
74 info: impl Into<ComputePipelineInfo>,
75 shader: impl Into<Shader>,
76 ) -> Result<Self, DriverError> {
77 use std::slice::from_ref;
78
79 trace!("create");
80
81 let device = Arc::clone(device);
82 let info: ComputePipelineInfo = info.into();
83 let shader = shader.into();
84
85 let mut descriptor_bindings = shader.descriptor_bindings();
87 for (descriptor_info, _) in descriptor_bindings.values_mut() {
88 if descriptor_info.binding_count() == 0 {
89 descriptor_info.set_binding_count(info.bindless_descriptor_count);
90 }
91 }
92
93 let descriptor_info = PipelineDescriptorInfo::create(&device, &descriptor_bindings)?;
94 let descriptor_set_layouts = descriptor_info
95 .layouts
96 .values()
97 .map(|descriptor_set_layout| **descriptor_set_layout)
98 .collect::<Box<[_]>>();
99
100 unsafe {
101 let shader_module = device
102 .create_shader_module(
103 &vk::ShaderModuleCreateInfo::default().code(align_spriv(&shader.spirv)?),
104 None,
105 )
106 .map_err(|err| {
107 warn!("{err}");
108
109 DriverError::Unsupported
110 })?;
111 let entry_name = CString::new(shader.entry_name.as_bytes()).unwrap();
112 let mut stage_create_info = vk::PipelineShaderStageCreateInfo::default()
113 .module(shader_module)
114 .stage(shader.stage)
115 .name(&entry_name);
116 let specialization_info = shader.specialization_info.as_ref().map(|info| {
117 vk::SpecializationInfo::default()
118 .map_entries(&info.map_entries)
119 .data(&info.data)
120 });
121
122 if let Some(specialization_info) = &specialization_info {
123 stage_create_info = stage_create_info.specialization_info(specialization_info);
124 }
125
126 let mut layout_info =
127 vk::PipelineLayoutCreateInfo::default().set_layouts(&descriptor_set_layouts);
128
129 let push_constants = shader.push_constant_range();
130 if let Some(push_constants) = &push_constants {
131 layout_info = layout_info.push_constant_ranges(from_ref(push_constants));
132 }
133
134 let layout = device
135 .create_pipeline_layout(&layout_info, None)
136 .map_err(|err| {
137 warn!("{err}");
138
139 DriverError::Unsupported
140 })?;
141 let pipeline_info = vk::ComputePipelineCreateInfo::default()
142 .stage(stage_create_info)
143 .layout(layout);
144 let pipeline = device
145 .create_compute_pipelines(
146 Device::pipeline_cache(&device),
147 from_ref(&pipeline_info),
148 None,
149 )
150 .map_err(|(_, err)| {
151 warn!("{err}");
152
153 DriverError::Unsupported
154 })?[0];
155
156 device.destroy_shader_module(shader_module, None);
157
158 Ok(ComputePipeline {
159 descriptor_bindings,
160 descriptor_info,
161 device,
162 info,
163 layout,
164 name: None,
165 pipeline,
166 push_constants,
167 })
168 }
169 }
170
171 pub fn with_name(mut this: Self, name: impl Into<String>) -> Self {
173 this.name = Some(name.into());
174 this
175 }
176}
177
178impl Deref for ComputePipeline {
179 type Target = vk::Pipeline;
180
181 fn deref(&self) -> &Self::Target {
182 &self.pipeline
183 }
184}
185
186impl Drop for ComputePipeline {
187 #[profiling::function]
188 fn drop(&mut self) {
189 if panicking() {
190 return;
191 }
192
193 unsafe {
194 self.device.destroy_pipeline(self.pipeline, None);
195 self.device.destroy_pipeline_layout(self.layout, None);
196 }
197 }
198}
199
/// Information used to create a [`ComputePipeline`].
///
/// Construct it with [`Default`] or with the derived owned-pattern
/// [`ComputePipelineInfoBuilder`]. Finish the builder with its infallible `build` method or
/// through `From`/`Into`.
#[derive(Builder, Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[builder(
    build_fn(
        private,
        name = "fallible_build",
        error = "ComputePipelineInfoBuilderError"
    ),
    derive(Clone, Copy, Debug),
    pattern = "owned"
)]
#[non_exhaustive]
pub struct ComputePipelineInfo {
    /// Number of descriptors given to any shader descriptor binding whose reflected binding
    /// count is zero.
    ///
    /// NOTE(review): zero-count bindings are presumably runtime-sized ("bindless") descriptor
    /// arrays — confirm against `Shader::descriptor_bindings`.
    ///
    /// The default value is `8192`.
    #[builder(default = "8192")]
    pub bindless_descriptor_count: u32,
}
238
239impl ComputePipelineInfo {
240 #[inline(always)]
242 pub fn to_builder(self) -> ComputePipelineInfoBuilder {
243 ComputePipelineInfoBuilder {
244 bindless_descriptor_count: Some(self.bindless_descriptor_count),
245 }
246 }
247}
248
249impl Default for ComputePipelineInfo {
250 fn default() -> Self {
251 Self {
252 bindless_descriptor_count: 8192,
253 }
254 }
255}
256
257impl From<ComputePipelineInfoBuilder> for ComputePipelineInfo {
258 fn from(info: ComputePipelineInfoBuilder) -> Self {
259 info.build()
260 }
261}
262
263impl ComputePipelineInfoBuilder {
264 #[inline(always)]
266 pub fn build(self) -> ComputePipelineInfo {
267 let res = self.fallible_build();
268
269 #[cfg(test)]
270 let res = res.unwrap();
271
272 #[cfg(not(test))]
273 let res = unsafe { res.unwrap_unchecked() };
274
275 res
276 }
277}
278
/// Error type of the derived, private `fallible_build` method of
/// [`ComputePipelineInfoBuilder`].
///
/// It carries no details: any `UninitializedFieldError` from the builder is converted into this
/// unit value.
#[derive(Debug)]
struct ComputePipelineInfoBuilderError;
281
282impl From<UninitializedFieldError> for ComputePipelineInfoBuilderError {
283 fn from(_: UninitializedFieldError) -> Self {
284 Self
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-tripping the default info through `to_builder` and `build` must not change it.
    #[test]
    pub fn compute_pipeline_info() {
        let expected = ComputePipelineInfo::default();
        let rebuilt = expected.to_builder().build();

        assert_eq!(expected, rebuilt);
    }

    /// A default builder must produce the same value as `ComputePipelineInfo::default`.
    #[test]
    pub fn compute_pipeline_info_builder() {
        let expected = ComputePipelineInfo::default();
        let built = ComputePipelineInfoBuilder::default().build();

        assert_eq!(expected, built);
    }
}