screen_13/driver/
ray_trace.rs

1//! Ray tracing pipeline types
2
3use {
4    super::{
5        DriverError,
6        device::Device,
7        merge_push_constant_ranges,
8        physical_device::RayTraceProperties,
9        shader::{DescriptorBindingMap, PipelineDescriptorInfo, Shader, align_spriv},
10    },
11    ash::vk,
12    derive_builder::{Builder, UninitializedFieldError},
13    log::warn,
14    std::{ffi::CString, ops::Deref, sync::Arc, thread::panicking},
15};
16
17/// Smart pointer handle to a [pipeline] object.
18///
19/// Also contains information about the object.
20///
21/// ## `Deref` behavior
22///
23/// `RayTracePipeline` automatically dereferences to [`vk::Pipeline`] (via the [`Deref`]
24/// trait), so you can call `vk::Pipeline`'s methods on a value of type `RayTracePipeline`. To avoid
25/// name clashes with `vk::Pipeline`'s methods, the methods of `RayTracePipeline` itself are
26/// associated functions, called using [fully qualified syntax]:
27///
28/// [pipeline]: https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkPipeline.html
29/// [deref]: core::ops::Deref
30/// [fully qualified syntax]: https://doc.rust-lang.org/book/ch19-03-advanced-traits.html#fully-qualified-syntax-for-disambiguation-calling-methods-with-the-same-name
31#[derive(Debug)]
32pub struct RayTracePipeline {
33    pub(crate) descriptor_bindings: DescriptorBindingMap,
34    pub(crate) descriptor_info: PipelineDescriptorInfo,
35    device: Arc<Device>,
36
37    /// Information used to create this object.
38    pub info: RayTracePipelineInfo,
39
40    pub(crate) layout: vk::PipelineLayout,
41
42    /// A descriptive name used in debugging messages.
43    pub name: Option<String>,
44
45    pub(crate) push_constants: Vec<vk::PushConstantRange>,
46    pipeline: vk::Pipeline,
47    shader_modules: Vec<vk::ShaderModule>,
48    shader_group_handles: Vec<u8>,
49}
50
51impl RayTracePipeline {
52    /// Creates a new ray trace pipeline on the given device.
53    ///
54    /// The correct pipeline stages will be enabled based on the provided shaders. See [Shader] for
55    /// details on all available stages.
56    ///
57    /// The number and composition of the `shader_groups` parameter must match the actual shaders
58    /// provided.
59    ///
60    /// # Panics
61    ///
62    /// If shader code is not a multiple of four bytes.
63    ///
64    /// # Examples
65    ///
66    /// Basic usage:
67    ///
68    /// ```no_run
69    /// # use std::sync::Arc;
70    /// # use ash::vk;
71    /// # use screen_13::driver::DriverError;
72    /// # use screen_13::driver::device::{Device, DeviceInfo};
73    /// # use screen_13::driver::ray_trace::{RayTracePipeline, RayTracePipelineInfo, RayTraceShaderGroup};
74    /// # use screen_13::driver::shader::Shader;
75    /// # fn main() -> Result<(), DriverError> {
76    /// # let device = Arc::new(Device::create_headless(DeviceInfo::default())?);
77    /// # let my_rgen_code = [0u8; 1];
78    /// # let my_chit_code = [0u8; 1];
79    /// # let my_miss_code = [0u8; 1];
80    /// # let my_shadow_code = [0u8; 1];
81    /// // shader code is raw SPIR-V code as bytes
82    /// let info = RayTracePipelineInfo::default().to_builder().max_ray_recursion_depth(1);
83    /// let pipeline = RayTracePipeline::create(
84    ///     &device,
85    ///     info,
86    ///     [
87    ///         Shader::new_ray_gen(my_rgen_code.as_slice()),
88    ///         Shader::new_closest_hit(my_chit_code.as_slice()),
89    ///         Shader::new_miss(my_miss_code.as_slice()),
90    ///         Shader::new_miss(my_shadow_code.as_slice()),
91    ///     ],
92    ///     [
93    ///         RayTraceShaderGroup::new_general(0),
94    ///         RayTraceShaderGroup::new_triangles(1, None),
95    ///         RayTraceShaderGroup::new_general(2),
96    ///         RayTraceShaderGroup::new_general(3),
97    ///     ],
98    /// )?;
99    ///
100    /// assert_ne!(*pipeline, vk::Pipeline::null());
101    /// assert_eq!(pipeline.info.max_ray_recursion_depth, 1);
102    /// # Ok(()) }
103    /// ```
104    #[profiling::function]
105    pub fn create<S>(
106        device: &Arc<Device>,
107        info: impl Into<RayTracePipelineInfo>,
108        shaders: impl IntoIterator<Item = S>,
109        shader_groups: impl IntoIterator<Item = RayTraceShaderGroup>,
110    ) -> Result<Self, DriverError>
111    where
112        S: Into<Shader>,
113    {
114        let info = info.into();
115        let shader_groups = shader_groups
116            .into_iter()
117            .map(|shader_group| shader_group.into())
118            .collect::<Vec<_>>();
119        let group_count = shader_groups.len();
120
121        let shaders = shaders
122            .into_iter()
123            .map(|shader| shader.into())
124            .collect::<Vec<Shader>>();
125        let push_constants = shaders
126            .iter()
127            .map(|shader| shader.push_constant_range())
128            .filter_map(|mut push_const| push_const.take())
129            .collect::<Vec<_>>();
130
131        // Use SPIR-V reflection to get the types and counts of all descriptors
132        let mut descriptor_bindings = Shader::merge_descriptor_bindings(
133            shaders.iter().map(|shader| shader.descriptor_bindings()),
134        );
135        for (descriptor_info, _) in descriptor_bindings.values_mut() {
136            if descriptor_info.binding_count() == 0 {
137                descriptor_info.set_binding_count(info.bindless_descriptor_count);
138            }
139        }
140
141        let descriptor_info = PipelineDescriptorInfo::create(device, &descriptor_bindings)?;
142        let descriptor_set_layout_handles = descriptor_info
143            .layouts
144            .values()
145            .map(|descriptor_set_layout| **descriptor_set_layout)
146            .collect::<Box<[_]>>();
147
148        unsafe {
149            let layout = device
150                .create_pipeline_layout(
151                    &vk::PipelineLayoutCreateInfo::default()
152                        .set_layouts(&descriptor_set_layout_handles)
153                        .push_constant_ranges(&push_constants),
154                    None,
155                )
156                .map_err(|err| {
157                    warn!("{err}");
158
159                    DriverError::Unsupported
160                })?;
161            let entry_points: Box<[CString]> = shaders
162                .iter()
163                .map(|shader| CString::new(shader.entry_name.as_str()))
164                .collect::<Result<_, _>>()
165                .map_err(|err| {
166                    warn!("{err}");
167
168                    DriverError::InvalidData
169                })?;
170            let specialization_infos: Box<[Option<vk::SpecializationInfo>]> = shaders
171                .iter()
172                .map(|shader| {
173                    shader.specialization_info.as_ref().map(|info| {
174                        vk::SpecializationInfo::default()
175                            .data(&info.data)
176                            .map_entries(&info.map_entries)
177                    })
178                })
179                .collect();
180            let mut shader_stages: Vec<vk::PipelineShaderStageCreateInfo> =
181                Vec::with_capacity(shaders.len());
182            let mut shader_modules = Vec::with_capacity(shaders.len());
183            for (idx, shader) in shaders.iter().enumerate() {
184                let module = device
185                    .create_shader_module(
186                        &vk::ShaderModuleCreateInfo::default().code(align_spriv(&shader.spirv)?),
187                        None,
188                    )
189                    .map_err(|err| {
190                        warn!("{err}");
191
192                        device.destroy_pipeline_layout(layout, None);
193
194                        for module in shader_modules.drain(..) {
195                            device.destroy_shader_module(module, None);
196                        }
197
198                        DriverError::Unsupported
199                    })?;
200
201                shader_modules.push(module);
202
203                let mut stage = vk::PipelineShaderStageCreateInfo::default()
204                    .module(module)
205                    .name(entry_points[idx].as_ref())
206                    .stage(shader.stage);
207
208                if let Some(specialization_info) = &specialization_infos[idx] {
209                    stage = stage.specialization_info(specialization_info);
210                }
211
212                shader_stages.push(stage);
213            }
214
215            let mut dynamic_states = Vec::with_capacity(1);
216
217            if info.dynamic_stack_size {
218                dynamic_states.push(vk::DynamicState::RAY_TRACING_PIPELINE_STACK_SIZE_KHR);
219            }
220
221            let ray_trace_ext = device
222                .ray_trace_ext
223                .as_ref()
224                .ok_or(DriverError::Unsupported)?;
225            let pipeline = ray_trace_ext
226                .create_ray_tracing_pipelines(
227                    vk::DeferredOperationKHR::null(),
228                    Device::pipeline_cache(device),
229                    &[vk::RayTracingPipelineCreateInfoKHR::default()
230                        .stages(&shader_stages)
231                        .groups(&shader_groups)
232                        .max_pipeline_ray_recursion_depth(
233                            info.max_ray_recursion_depth.min(
234                                device
235                                    .physical_device
236                                    .ray_trace_properties
237                                    .as_ref()
238                                    .unwrap()
239                                    .max_ray_recursion_depth,
240                            ),
241                        )
242                        .layout(layout)
243                        .dynamic_state(
244                            &vk::PipelineDynamicStateCreateInfo::default()
245                                .dynamic_states(&dynamic_states),
246                        )],
247                    None,
248                )
249                .map_err(|(pipelines, err)| {
250                    warn!("{err}");
251
252                    for pipeline in pipelines {
253                        device.destroy_pipeline(pipeline, None);
254                    }
255
256                    device.destroy_pipeline_layout(layout, None);
257
258                    for shader_module in shader_modules.iter().copied() {
259                        device.destroy_shader_module(shader_module, None);
260                    }
261
262                    DriverError::Unsupported
263                })?[0];
264            let device = Arc::clone(device);
265            let &RayTraceProperties {
266                shader_group_handle_size,
267                ..
268            } = device
269                .physical_device
270                .ray_trace_properties
271                .as_ref()
272                .unwrap();
273
274            let push_constants = merge_push_constant_ranges(&push_constants);
275
276            // SAFETY:
277            // According to [vulkan spec](https://www.khronos.org/registry/vulkan/specs/1.3-extensions/man/html/vkGetRayTracingShaderGroupHandlesKHR.html)
278            // Valid usage of this function requires:
279            // 1. pipeline must be raytracing pipeline.
280            // 2. first_group must be less than the number of shader groups in the pipeline.
281            // 3. the sum of first group and group_count must be less or equal to the number of shader
282            //    modules in the pipeline.
283            // 4. data_size must be at least shader_group_handle_size * group_count.
284            // 5. pipeline must not have been created with VK_PIPELINE_CREATE_LIBRARY_BIT_KHR.
285            //
286            let shader_group_handles = {
287                ray_trace_ext.get_ray_tracing_shader_group_handles(
288                    pipeline,
289                    0,
290                    group_count as u32,
291                    group_count * shader_group_handle_size as usize,
292                )
293            }
294            .map_err(|_| DriverError::InvalidData)?;
295
296            Ok(Self {
297                descriptor_bindings,
298                descriptor_info,
299                device,
300                info,
301                layout,
302                name: None,
303                push_constants,
304                pipeline,
305                shader_modules,
306                shader_group_handles,
307            })
308        }
309    }
310
311    /// Function returning a handle to a shader group of this pipeline.
312    /// This can be used to construct a sbt.
313    ///
314    /// # Examples
315    ///
316    /// See
317    /// [ray_trace.rs](https://github.com/attackgoat/screen-13/blob/master/examples/ray_trace.rs)
318    /// for a detail example which constructs a shader binding table buffer using this function.
319    pub fn group_handle(this: &Self, idx: usize) -> Result<&[u8], DriverError> {
320        let &RayTraceProperties {
321            shader_group_handle_size,
322            ..
323        } = this
324            .device
325            .physical_device
326            .ray_trace_properties
327            .as_ref()
328            .ok_or(DriverError::Unsupported)?;
329        let start = idx * shader_group_handle_size as usize;
330        let end = start + shader_group_handle_size as usize;
331
332        Ok(&this.shader_group_handles[start..end])
333    }
334
335    /// Query ray trace pipeline shader group shader stack size.
336    ///
337    /// The return value is the ray tracing pipeline stack size in bytes for the specified shader as
338    /// called from the specified shader group.
339    #[profiling::function]
340    pub fn group_stack_size(
341        this: &Self,
342        group: u32,
343        group_shader: vk::ShaderGroupShaderKHR,
344    ) -> vk::DeviceSize {
345        unsafe {
346            // Safely use unchecked because ray_trace_ext is checked during pipeline creation
347            this.device
348                .ray_trace_ext
349                .as_ref()
350                .unwrap_unchecked()
351                .get_ray_tracing_shader_group_stack_size(this.pipeline, group, group_shader)
352        }
353    }
354
355    /// Sets the debugging name assigned to this pipeline.
356    pub fn with_name(mut this: Self, name: impl Into<String>) -> Self {
357        this.name = Some(name.into());
358        this
359    }
360}
361
362impl Deref for RayTracePipeline {
363    type Target = vk::Pipeline;
364
365    fn deref(&self) -> &Self::Target {
366        &self.pipeline
367    }
368}
369
370impl Drop for RayTracePipeline {
371    #[profiling::function]
372    fn drop(&mut self) {
373        if panicking() {
374            return;
375        }
376
377        unsafe {
378            self.device.destroy_pipeline(self.pipeline, None);
379            self.device.destroy_pipeline_layout(self.layout, None);
380        }
381
382        for shader_module in self.shader_modules.drain(..) {
383            unsafe {
384                self.device.destroy_shader_module(shader_module, None);
385            }
386        }
387    }
388}
389
390/// Information used to create a [`RayTracePipeline`] instance.
391#[derive(Builder, Clone, Copy, Debug, Eq, Hash, PartialEq)]
392#[builder(
393    build_fn(
394        private,
395        name = "fallible_build",
396        error = "RayTracePipelineInfoBuilderError"
397    ),
398    derive(Clone, Copy, Debug),
399    pattern = "owned"
400)]
401#[non_exhaustive]
402pub struct RayTracePipelineInfo {
403    /// The number of descriptors to allocate for a given binding when using bindless (unbounded)
404    /// syntax.
405    ///
406    /// The default is `8192`.
407    ///
408    /// # Examples
409    ///
410    /// Basic usage (GLSL):
411    ///
412    /// ```
413    /// # inline_spirv::inline_spirv!(r#"
414    /// #version 460 core
415    /// #extension GL_EXT_nonuniform_qualifier : require
416    ///
417    /// layout(set = 0, binding = 0, rgba8) readonly uniform image2D my_binding[];
418    ///
419    /// void main()
420    /// {
421    ///     // my_binding will have space for 8,192 images by default
422    /// }
423    /// # "#, rchit, vulkan1_2);
424    /// ```
425    #[builder(default = "8192")]
426    pub bindless_descriptor_count: u32,
427
428    /// Allow [setting the stack size dynamically] for a ray trace pipeline.
429    ///
430    /// When set, you must manually set the stack size during ray trace passes using
431    /// [`RayTrace::set_stack_size`](crate::graph::pass_ref::RayTrace::set_stack_size).
432    ///
433    /// [setting the stack size dynamically]: https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/vkCmdSetRayTracingPipelineStackSizeKHR.html
434    #[builder(default)]
435    pub dynamic_stack_size: bool,
436
437    /// The [maximum recursion depth] of shaders executed by this pipeline.
438    ///
439    /// The default is `16`.
440    ///
441    /// [maximum recursion depth]: https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#ray-tracing-recursion-depth
442    #[builder(default = "16")]
443    pub max_ray_recursion_depth: u32,
444}
445
446impl RayTracePipelineInfo {
447    /// Creates a default `RayTracePipelineInfoBuilder`.
448    pub fn builder() -> RayTracePipelineInfoBuilder {
449        Default::default()
450    }
451
452    /// Converts a `RayTracePipelineInfo` into a `RayTracePipelineInfoBuilder`.
453    #[inline(always)]
454    pub fn to_builder(self) -> RayTracePipelineInfoBuilder {
455        RayTracePipelineInfoBuilder {
456            bindless_descriptor_count: Some(self.bindless_descriptor_count),
457            dynamic_stack_size: Some(self.dynamic_stack_size),
458            max_ray_recursion_depth: Some(self.max_ray_recursion_depth),
459        }
460    }
461}
462
463impl Default for RayTracePipelineInfo {
464    fn default() -> Self {
465        Self {
466            bindless_descriptor_count: 8192,
467            dynamic_stack_size: false,
468            max_ray_recursion_depth: 16,
469        }
470    }
471}
472
473impl From<RayTracePipelineInfoBuilder> for RayTracePipelineInfo {
474    fn from(info: RayTracePipelineInfoBuilder) -> Self {
475        info.build()
476    }
477}
478
479impl RayTracePipelineInfoBuilder {
480    /// Builds a new `RayTracePipelineInfo`.
481    #[inline(always)]
482    pub fn build(self) -> RayTracePipelineInfo {
483        let res = self.fallible_build();
484
485        #[cfg(test)]
486        let res = res.unwrap();
487
488        #[cfg(not(test))]
489        let res = unsafe { res.unwrap_unchecked() };
490
491        res
492    }
493}
494
495#[derive(Debug)]
496struct RayTracePipelineInfoBuilderError;
497
498impl From<UninitializedFieldError> for RayTracePipelineInfoBuilderError {
499    fn from(_: UninitializedFieldError) -> Self {
500        Self
501    }
502}
503
504/// Describes the set of the shader stages to be included in each shader group in the ray trace
505/// pipeline.
506///
507/// See
508/// [VkRayTracingShaderGroupCreateInfoKHR](https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#VkRayTracingShaderGroupCreateInfoKHR).
509#[derive(Clone, Copy, Debug)]
510pub struct RayTraceShaderGroup {
511    /// The optional index of the any-hit shader in the group if the shader group has type of
512    /// [RayTraceShaderGroupType::TrianglesHitGroup] or
513    /// [RayTraceShaderGroupType::ProceduralHitGroup].
514    pub any_hit_shader: Option<u32>,
515
516    /// The optional index of the closest hit shader in the group if the shader group has type of
517    /// [RayTraceShaderGroupType::TrianglesHitGroup] or
518    /// [RayTraceShaderGroupType::ProceduralHitGroup].
519    pub closest_hit_shader: Option<u32>,
520
521    /// The index of the ray generation, miss, or callable shader in the group if the shader group
522    /// has type of [RayTraceShaderGroupType::General].
523    pub general_shader: Option<u32>,
524
525    /// The index of the intersection shader in the group if the shader group has type of
526    /// [RayTraceShaderGroupType::ProceduralHitGroup].
527    pub intersection_shader: Option<u32>,
528
529    /// The type of hit group specified in this structure.
530    pub ty: RayTraceShaderGroupType,
531}
532
533impl RayTraceShaderGroup {
534    fn new(
535        ty: RayTraceShaderGroupType,
536        general_shader: impl Into<Option<u32>>,
537        intersection_shader: impl Into<Option<u32>>,
538        closest_hit_shader: impl Into<Option<u32>>,
539        any_hit_shader: impl Into<Option<u32>>,
540    ) -> Self {
541        let any_hit_shader = any_hit_shader.into();
542        let closest_hit_shader = closest_hit_shader.into();
543        let general_shader = general_shader.into();
544        let intersection_shader = intersection_shader.into();
545
546        Self {
547            any_hit_shader,
548            closest_hit_shader,
549            general_shader,
550            intersection_shader,
551            ty,
552        }
553    }
554
555    /// Creates a new general-type shader group with the given general shader.
556    pub fn new_general(general_shader: impl Into<Option<u32>>) -> Self {
557        Self::new(
558            RayTraceShaderGroupType::General,
559            general_shader,
560            None,
561            None,
562            None,
563        )
564    }
565
566    /// Creates a new procedural-type shader group with the given intersection shader, and optional
567    /// closest-hit and any-hit shaders.
568    pub fn new_procedural(
569        intersection_shader: u32,
570        closest_hit_shader: impl Into<Option<u32>>,
571        any_hit_shader: impl Into<Option<u32>>,
572    ) -> Self {
573        Self::new(
574            RayTraceShaderGroupType::ProceduralHitGroup,
575            None,
576            intersection_shader,
577            closest_hit_shader,
578            any_hit_shader,
579        )
580    }
581
582    /// Creates a new triangles-type shader group with the given closest-hit shader and optional any-hit
583    /// shader.
584    pub fn new_triangles(closest_hit_shader: u32, any_hit_shader: impl Into<Option<u32>>) -> Self {
585        Self::new(
586            RayTraceShaderGroupType::TrianglesHitGroup,
587            None,
588            None,
589            closest_hit_shader,
590            any_hit_shader,
591        )
592    }
593}
594
595impl From<RayTraceShaderGroup> for vk::RayTracingShaderGroupCreateInfoKHR<'static> {
596    fn from(shader_group: RayTraceShaderGroup) -> Self {
597        vk::RayTracingShaderGroupCreateInfoKHR::default()
598            .ty(shader_group.ty.into())
599            .any_hit_shader(shader_group.any_hit_shader.unwrap_or(vk::SHADER_UNUSED_KHR))
600            .closest_hit_shader(
601                shader_group
602                    .closest_hit_shader
603                    .unwrap_or(vk::SHADER_UNUSED_KHR),
604            )
605            .general_shader(shader_group.general_shader.unwrap_or(vk::SHADER_UNUSED_KHR))
606            .intersection_shader(
607                shader_group
608                    .intersection_shader
609                    .unwrap_or(vk::SHADER_UNUSED_KHR),
610            )
611    }
612}
613
/// Describes a type of ray tracing shader group, which is a collection of shaders which run in the
/// specified mode.
///
/// Implements [`PartialEq`], [`Eq`] and [`Hash`] so that group types can be compared and used as
/// map or set keys.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum RayTraceShaderGroupType {
    /// A shader group with a general shader.
    General,

    /// A shader group with an intersection shader, and optional closest-hit and any-hit shaders.
    ProceduralHitGroup,

    /// A shader group with a closest-hit shader and optional any-hit shader.
    TrianglesHitGroup,
}
627
628impl From<RayTraceShaderGroupType> for vk::RayTracingShaderGroupTypeKHR {
629    fn from(ty: RayTraceShaderGroupType) -> Self {
630        match ty {
631            RayTraceShaderGroupType::General => vk::RayTracingShaderGroupTypeKHR::GENERAL,
632            RayTraceShaderGroupType::ProceduralHitGroup => {
633                vk::RayTracingShaderGroupTypeKHR::PROCEDURAL_HIT_GROUP
634            }
635            RayTraceShaderGroupType::TrianglesHitGroup => {
636                vk::RayTracingShaderGroupTypeKHR::TRIANGLES_HIT_GROUP
637            }
638        }
639    }
640}
641
#[cfg(test)]
mod tests {
    use super::*;

    type Info = RayTracePipelineInfo;
    type Builder = RayTracePipelineInfoBuilder;

    #[test]
    pub fn ray_trace_pipeline_info() {
        let info = Info::default();
        let builder = info.to_builder().build();

        assert_eq!(info, builder);
    }

    #[test]
    pub fn ray_trace_pipeline_info_builder() {
        let info = Info::default();
        let builder = Builder::default().build();

        assert_eq!(info, builder);
    }

    /// Non-default values must survive `to_builder` and be settable through the builder.
    #[test]
    pub fn ray_trace_pipeline_info_round_trip() {
        let info = Info {
            bindless_descriptor_count: 64,
            dynamic_stack_size: true,
            max_ray_recursion_depth: 2,
        };

        assert_eq!(info, info.to_builder().build());
        assert_eq!(
            info,
            Builder::default()
                .bindless_descriptor_count(64)
                .dynamic_stack_size(true)
                .max_ray_recursion_depth(2)
                .build()
        );
    }

    /// Unused shader slots must become `VK_SHADER_UNUSED_KHR`.
    #[test]
    pub fn ray_trace_shader_group_create_info() {
        let general: vk::RayTracingShaderGroupCreateInfoKHR<'static> =
            RayTraceShaderGroup::new_general(0).into();

        assert_eq!(general.ty, vk::RayTracingShaderGroupTypeKHR::GENERAL);
        assert_eq!(general.general_shader, 0);
        assert_eq!(general.closest_hit_shader, vk::SHADER_UNUSED_KHR);
        assert_eq!(general.any_hit_shader, vk::SHADER_UNUSED_KHR);
        assert_eq!(general.intersection_shader, vk::SHADER_UNUSED_KHR);

        let triangles: vk::RayTracingShaderGroupCreateInfoKHR<'static> =
            RayTraceShaderGroup::new_triangles(1, None).into();

        assert_eq!(
            triangles.ty,
            vk::RayTracingShaderGroupTypeKHR::TRIANGLES_HIT_GROUP
        );
        assert_eq!(triangles.general_shader, vk::SHADER_UNUSED_KHR);
        assert_eq!(triangles.closest_hit_shader, 1);
        assert_eq!(triangles.any_hit_shader, vk::SHADER_UNUSED_KHR);

        let procedural: vk::RayTracingShaderGroupCreateInfoKHR<'static> =
            RayTraceShaderGroup::new_procedural(2, None, 3).into();

        assert_eq!(
            procedural.ty,
            vk::RayTracingShaderGroupTypeKHR::PROCEDURAL_HIT_GROUP
        );
        assert_eq!(procedural.intersection_shader, 2);
        assert_eq!(procedural.closest_hit_shader, vk::SHADER_UNUSED_KHR);
        assert_eq!(procedural.any_hit_shader, 3);
    }
}