1use {
4 super::{
5 DriverError,
6 device::Device,
7 merge_push_constant_ranges,
8 physical_device::RayTraceProperties,
9 shader::{DescriptorBindingMap, PipelineDescriptorInfo, Shader, align_spriv},
10 },
11 ash::vk,
12 derive_builder::{Builder, UninitializedFieldError},
13 log::warn,
14 std::{ffi::CString, ops::Deref, sync::Arc, thread::panicking},
15};
16
/// A Vulkan ray tracing pipeline together with the objects created alongside it.
///
/// Dereferences to the underlying [`vk::Pipeline`] handle. Dropping a value destroys the
/// pipeline, the pipeline layout and all shader modules (unless the thread is panicking).
#[derive(Debug)]
pub struct RayTracePipeline {
    /// Descriptor bindings merged from every shader given to [`RayTracePipeline::create`].
    pub(crate) descriptor_bindings: DescriptorBindingMap,
    /// Descriptor information built from `descriptor_bindings`; its set layouts are used to
    /// create `layout`.
    pub(crate) descriptor_info: PipelineDescriptorInfo,
    /// The device which created this pipeline and which is used to destroy it on drop.
    device: Arc<Device>,

    /// The information which was used to create this pipeline.
    pub info: RayTracePipelineInfo,

    /// The pipeline layout created from the descriptor set layouts and push constant ranges.
    pub(crate) layout: vk::PipelineLayout,

    /// An optional name for this pipeline; `None` until set (see `with_name`).
    pub name: Option<String>,

    /// Push constant ranges of all shaders after being passed through
    /// `merge_push_constant_ranges`.
    pub(crate) push_constants: Vec<vk::PushConstantRange>,
    /// The native pipeline handle.
    pipeline: vk::Pipeline,
    /// One shader module per shader, destroyed together with the pipeline.
    shader_modules: Vec<vk::ShaderModule>,
    /// Shader group handles of all groups stored back-to-back; each handle occupies
    /// `shader_group_handle_size` bytes as reported by the physical device.
    shader_group_handles: Vec<u8>,
}
50
51impl RayTracePipeline {
52 #[profiling::function]
105 pub fn create<S>(
106 device: &Arc<Device>,
107 info: impl Into<RayTracePipelineInfo>,
108 shaders: impl IntoIterator<Item = S>,
109 shader_groups: impl IntoIterator<Item = RayTraceShaderGroup>,
110 ) -> Result<Self, DriverError>
111 where
112 S: Into<Shader>,
113 {
114 let info = info.into();
115 let shader_groups = shader_groups
116 .into_iter()
117 .map(|shader_group| shader_group.into())
118 .collect::<Vec<_>>();
119 let group_count = shader_groups.len();
120
121 let shaders = shaders
122 .into_iter()
123 .map(|shader| shader.into())
124 .collect::<Vec<Shader>>();
125 let push_constants = shaders
126 .iter()
127 .map(|shader| shader.push_constant_range())
128 .filter_map(|mut push_const| push_const.take())
129 .collect::<Vec<_>>();
130
131 let mut descriptor_bindings = Shader::merge_descriptor_bindings(
133 shaders.iter().map(|shader| shader.descriptor_bindings()),
134 );
135 for (descriptor_info, _) in descriptor_bindings.values_mut() {
136 if descriptor_info.binding_count() == 0 {
137 descriptor_info.set_binding_count(info.bindless_descriptor_count);
138 }
139 }
140
141 let descriptor_info = PipelineDescriptorInfo::create(device, &descriptor_bindings)?;
142 let descriptor_set_layout_handles = descriptor_info
143 .layouts
144 .values()
145 .map(|descriptor_set_layout| **descriptor_set_layout)
146 .collect::<Box<[_]>>();
147
148 unsafe {
149 let layout = device
150 .create_pipeline_layout(
151 &vk::PipelineLayoutCreateInfo::default()
152 .set_layouts(&descriptor_set_layout_handles)
153 .push_constant_ranges(&push_constants),
154 None,
155 )
156 .map_err(|err| {
157 warn!("{err}");
158
159 DriverError::Unsupported
160 })?;
161 let entry_points: Box<[CString]> = shaders
162 .iter()
163 .map(|shader| CString::new(shader.entry_name.as_str()))
164 .collect::<Result<_, _>>()
165 .map_err(|err| {
166 warn!("{err}");
167
168 DriverError::InvalidData
169 })?;
170 let specialization_infos: Box<[Option<vk::SpecializationInfo>]> = shaders
171 .iter()
172 .map(|shader| {
173 shader.specialization_info.as_ref().map(|info| {
174 vk::SpecializationInfo::default()
175 .data(&info.data)
176 .map_entries(&info.map_entries)
177 })
178 })
179 .collect();
180 let mut shader_stages: Vec<vk::PipelineShaderStageCreateInfo> =
181 Vec::with_capacity(shaders.len());
182 let mut shader_modules = Vec::with_capacity(shaders.len());
183 for (idx, shader) in shaders.iter().enumerate() {
184 let module = device
185 .create_shader_module(
186 &vk::ShaderModuleCreateInfo::default().code(align_spriv(&shader.spirv)?),
187 None,
188 )
189 .map_err(|err| {
190 warn!("{err}");
191
192 device.destroy_pipeline_layout(layout, None);
193
194 for module in shader_modules.drain(..) {
195 device.destroy_shader_module(module, None);
196 }
197
198 DriverError::Unsupported
199 })?;
200
201 shader_modules.push(module);
202
203 let mut stage = vk::PipelineShaderStageCreateInfo::default()
204 .module(module)
205 .name(entry_points[idx].as_ref())
206 .stage(shader.stage);
207
208 if let Some(specialization_info) = &specialization_infos[idx] {
209 stage = stage.specialization_info(specialization_info);
210 }
211
212 shader_stages.push(stage);
213 }
214
215 let mut dynamic_states = Vec::with_capacity(1);
216
217 if info.dynamic_stack_size {
218 dynamic_states.push(vk::DynamicState::RAY_TRACING_PIPELINE_STACK_SIZE_KHR);
219 }
220
221 let ray_trace_ext = device
222 .ray_trace_ext
223 .as_ref()
224 .ok_or(DriverError::Unsupported)?;
225 let pipeline = ray_trace_ext
226 .create_ray_tracing_pipelines(
227 vk::DeferredOperationKHR::null(),
228 Device::pipeline_cache(device),
229 &[vk::RayTracingPipelineCreateInfoKHR::default()
230 .stages(&shader_stages)
231 .groups(&shader_groups)
232 .max_pipeline_ray_recursion_depth(
233 info.max_ray_recursion_depth.min(
234 device
235 .physical_device
236 .ray_trace_properties
237 .as_ref()
238 .unwrap()
239 .max_ray_recursion_depth,
240 ),
241 )
242 .layout(layout)
243 .dynamic_state(
244 &vk::PipelineDynamicStateCreateInfo::default()
245 .dynamic_states(&dynamic_states),
246 )],
247 None,
248 )
249 .map_err(|(pipelines, err)| {
250 warn!("{err}");
251
252 for pipeline in pipelines {
253 device.destroy_pipeline(pipeline, None);
254 }
255
256 device.destroy_pipeline_layout(layout, None);
257
258 for shader_module in shader_modules.iter().copied() {
259 device.destroy_shader_module(shader_module, None);
260 }
261
262 DriverError::Unsupported
263 })?[0];
264 let device = Arc::clone(device);
265 let &RayTraceProperties {
266 shader_group_handle_size,
267 ..
268 } = device
269 .physical_device
270 .ray_trace_properties
271 .as_ref()
272 .unwrap();
273
274 let push_constants = merge_push_constant_ranges(&push_constants);
275
276 let shader_group_handles = {
287 ray_trace_ext.get_ray_tracing_shader_group_handles(
288 pipeline,
289 0,
290 group_count as u32,
291 group_count * shader_group_handle_size as usize,
292 )
293 }
294 .map_err(|_| DriverError::InvalidData)?;
295
296 Ok(Self {
297 descriptor_bindings,
298 descriptor_info,
299 device,
300 info,
301 layout,
302 name: None,
303 push_constants,
304 pipeline,
305 shader_modules,
306 shader_group_handles,
307 })
308 }
309 }
310
311 pub fn group_handle(this: &Self, idx: usize) -> Result<&[u8], DriverError> {
320 let &RayTraceProperties {
321 shader_group_handle_size,
322 ..
323 } = this
324 .device
325 .physical_device
326 .ray_trace_properties
327 .as_ref()
328 .ok_or(DriverError::Unsupported)?;
329 let start = idx * shader_group_handle_size as usize;
330 let end = start + shader_group_handle_size as usize;
331
332 Ok(&this.shader_group_handles[start..end])
333 }
334
335 #[profiling::function]
340 pub fn group_stack_size(
341 this: &Self,
342 group: u32,
343 group_shader: vk::ShaderGroupShaderKHR,
344 ) -> vk::DeviceSize {
345 unsafe {
346 this.device
348 .ray_trace_ext
349 .as_ref()
350 .unwrap_unchecked()
351 .get_ray_tracing_shader_group_stack_size(this.pipeline, group, group_shader)
352 }
353 }
354
355 pub fn with_name(mut this: Self, name: impl Into<String>) -> Self {
357 this.name = Some(name.into());
358 this
359 }
360}
361
362impl Deref for RayTracePipeline {
363 type Target = vk::Pipeline;
364
365 fn deref(&self) -> &Self::Target {
366 &self.pipeline
367 }
368}
369
370impl Drop for RayTracePipeline {
371 #[profiling::function]
372 fn drop(&mut self) {
373 if panicking() {
374 return;
375 }
376
377 unsafe {
378 self.device.destroy_pipeline(self.pipeline, None);
379 self.device.destroy_pipeline_layout(self.layout, None);
380 }
381
382 for shader_module in self.shader_modules.drain(..) {
383 unsafe {
384 self.device.destroy_shader_module(shader_module, None);
385 }
386 }
387 }
388}
389
/// Information used to create a [`RayTracePipeline`] instance.
///
/// Values are built using [`RayTracePipelineInfo::builder`] or taken from `Default`.
#[derive(Builder, Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[builder(
    build_fn(
        private,
        name = "fallible_build",
        error = "RayTracePipelineInfoBuilderError"
    ),
    derive(Clone, Copy, Debug),
    pattern = "owned"
)]
#[non_exhaustive]
pub struct RayTracePipelineInfo {
    /// The number of bindings given to any descriptor binding which reports a binding count
    /// of zero when the pipeline is created.
    ///
    /// The default value is `8192`.
    #[builder(default = "8192")]
    pub bindless_descriptor_count: u32,

    /// When `true`, the pipeline is created with the
    /// `RAY_TRACING_PIPELINE_STACK_SIZE_KHR` dynamic state enabled.
    ///
    /// The default value is `false`.
    #[builder(default)]
    pub dynamic_stack_size: bool,

    /// The requested maximum ray recursion depth; pipeline creation clamps this to the
    /// maximum reported by the physical device.
    ///
    /// The default value is `16`.
    #[builder(default = "16")]
    pub max_ray_recursion_depth: u32,
}
445
446impl RayTracePipelineInfo {
447 pub fn builder() -> RayTracePipelineInfoBuilder {
449 Default::default()
450 }
451
452 #[inline(always)]
454 pub fn to_builder(self) -> RayTracePipelineInfoBuilder {
455 RayTracePipelineInfoBuilder {
456 bindless_descriptor_count: Some(self.bindless_descriptor_count),
457 dynamic_stack_size: Some(self.dynamic_stack_size),
458 max_ray_recursion_depth: Some(self.max_ray_recursion_depth),
459 }
460 }
461}
462
463impl Default for RayTracePipelineInfo {
464 fn default() -> Self {
465 Self {
466 bindless_descriptor_count: 8192,
467 dynamic_stack_size: false,
468 max_ray_recursion_depth: 16,
469 }
470 }
471}
472
473impl From<RayTracePipelineInfoBuilder> for RayTracePipelineInfo {
474 fn from(info: RayTracePipelineInfoBuilder) -> Self {
475 info.build()
476 }
477}
478
479impl RayTracePipelineInfoBuilder {
480 #[inline(always)]
482 pub fn build(self) -> RayTracePipelineInfo {
483 let res = self.fallible_build();
484
485 #[cfg(test)]
486 let res = res.unwrap();
487
488 #[cfg(not(test))]
489 let res = unsafe { res.unwrap_unchecked() };
490
491 res
492 }
493}
494
/// Error type of the private `fallible_build` function generated for
/// `RayTracePipelineInfoBuilder`; it carries no information.
#[derive(Debug)]
struct RayTracePipelineInfoBuilderError;
497
498impl From<UninitializedFieldError> for RayTracePipelineInfoBuilderError {
499 fn from(_: UninitializedFieldError) -> Self {
500 Self
501 }
502}
503
/// Describes one shader group of a [`RayTracePipeline`].
///
/// Each shader field holds an index into the shaders given at pipeline creation; `None` is
/// converted to `vk::SHADER_UNUSED_KHR` when the Vulkan create info is built.
#[derive(Clone, Copy, Debug)]
pub struct RayTraceShaderGroup {
    /// Index of the any-hit shader, if any.
    pub any_hit_shader: Option<u32>,

    /// Index of the closest-hit shader, if any.
    pub closest_hit_shader: Option<u32>,

    /// Index of the general shader, if any.
    pub general_shader: Option<u32>,

    /// Index of the intersection shader, if any.
    pub intersection_shader: Option<u32>,

    /// The type of this shader group.
    pub ty: RayTraceShaderGroupType,
}
532
533impl RayTraceShaderGroup {
534 fn new(
535 ty: RayTraceShaderGroupType,
536 general_shader: impl Into<Option<u32>>,
537 intersection_shader: impl Into<Option<u32>>,
538 closest_hit_shader: impl Into<Option<u32>>,
539 any_hit_shader: impl Into<Option<u32>>,
540 ) -> Self {
541 let any_hit_shader = any_hit_shader.into();
542 let closest_hit_shader = closest_hit_shader.into();
543 let general_shader = general_shader.into();
544 let intersection_shader = intersection_shader.into();
545
546 Self {
547 any_hit_shader,
548 closest_hit_shader,
549 general_shader,
550 intersection_shader,
551 ty,
552 }
553 }
554
555 pub fn new_general(general_shader: impl Into<Option<u32>>) -> Self {
557 Self::new(
558 RayTraceShaderGroupType::General,
559 general_shader,
560 None,
561 None,
562 None,
563 )
564 }
565
566 pub fn new_procedural(
569 intersection_shader: u32,
570 closest_hit_shader: impl Into<Option<u32>>,
571 any_hit_shader: impl Into<Option<u32>>,
572 ) -> Self {
573 Self::new(
574 RayTraceShaderGroupType::ProceduralHitGroup,
575 None,
576 intersection_shader,
577 closest_hit_shader,
578 any_hit_shader,
579 )
580 }
581
582 pub fn new_triangles(closest_hit_shader: u32, any_hit_shader: impl Into<Option<u32>>) -> Self {
585 Self::new(
586 RayTraceShaderGroupType::TrianglesHitGroup,
587 None,
588 None,
589 closest_hit_shader,
590 any_hit_shader,
591 )
592 }
593}
594
595impl From<RayTraceShaderGroup> for vk::RayTracingShaderGroupCreateInfoKHR<'static> {
596 fn from(shader_group: RayTraceShaderGroup) -> Self {
597 vk::RayTracingShaderGroupCreateInfoKHR::default()
598 .ty(shader_group.ty.into())
599 .any_hit_shader(shader_group.any_hit_shader.unwrap_or(vk::SHADER_UNUSED_KHR))
600 .closest_hit_shader(
601 shader_group
602 .closest_hit_shader
603 .unwrap_or(vk::SHADER_UNUSED_KHR),
604 )
605 .general_shader(shader_group.general_shader.unwrap_or(vk::SHADER_UNUSED_KHR))
606 .intersection_shader(
607 shader_group
608 .intersection_shader
609 .unwrap_or(vk::SHADER_UNUSED_KHR),
610 )
611 }
612}
613
/// The type of a [`RayTraceShaderGroup`]; converts into `vk::RayTracingShaderGroupTypeKHR`.
#[derive(Clone, Copy, Debug)]
pub enum RayTraceShaderGroupType {
    /// Maps to `vk::RayTracingShaderGroupTypeKHR::GENERAL`.
    General,

    /// Maps to `vk::RayTracingShaderGroupTypeKHR::PROCEDURAL_HIT_GROUP`.
    ProceduralHitGroup,

    /// Maps to `vk::RayTracingShaderGroupTypeKHR::TRIANGLES_HIT_GROUP`.
    TrianglesHitGroup,
}
627
628impl From<RayTraceShaderGroupType> for vk::RayTracingShaderGroupTypeKHR {
629 fn from(ty: RayTraceShaderGroupType) -> Self {
630 match ty {
631 RayTraceShaderGroupType::General => vk::RayTracingShaderGroupTypeKHR::GENERAL,
632 RayTraceShaderGroupType::ProceduralHitGroup => {
633 vk::RayTracingShaderGroupTypeKHR::PROCEDURAL_HIT_GROUP
634 }
635 RayTraceShaderGroupType::TrianglesHitGroup => {
636 vk::RayTracingShaderGroupTypeKHR::TRIANGLES_HIT_GROUP
637 }
638 }
639 }
640}
641
#[cfg(test)]
mod tests {
    use super::*;

    type Info = RayTracePipelineInfo;
    type Builder = RayTracePipelineInfoBuilder;

    /// Converting an info into a builder and building it again yields the same info.
    #[test]
    pub fn ray_trace_pipeline_info() {
        let expected = Info::default();
        let round_tripped = expected.to_builder().build();

        assert_eq!(expected, round_tripped);
    }

    /// A builder with no fields set builds the same values as `Default`.
    #[test]
    pub fn ray_trace_pipeline_info_builder() {
        let expected = Info::default();
        let built = Builder::default().build();

        assert_eq!(expected, built);
    }
}