1use {
4 super::{DescriptorSetLayout, DriverError, VertexInputState, device::Device},
5 ash::vk,
6 derive_builder::{Builder, UninitializedFieldError},
7 log::{debug, error, trace, warn},
8 ordered_float::OrderedFloat,
9 spirq::{
10 ReflectConfig,
11 entry_point::EntryPoint,
12 ty::{DescriptorType, ScalarType, Type, VectorType},
13 var::Variable,
14 },
15 std::{
16 collections::{BTreeMap, HashMap},
17 fmt::{Debug, Formatter},
18 iter::repeat_n,
19 mem::size_of_val,
20 ops::Deref,
21 sync::Arc,
22 thread::panicking,
23 },
24};
25
/// Descriptor bindings reflected from one or more shaders, keyed by `(set, binding)`.
///
/// Each value holds the reflected `DescriptorInfo` together with the shader stage flags that
/// access the binding; when maps from several shaders are merged the stage flags are OR-ed
/// together.
pub(crate) type DescriptorBindingMap = HashMap<Descriptor, (DescriptorInfo, vk::ShaderStageFlags)>;
27
28pub(crate) fn align_spriv(code: &[u8]) -> Result<&[u32], DriverError> {
29 let (prefix, code, suffix) = unsafe { code.align_to() };
30
31 if prefix.len() + suffix.len() == 0 {
32 Ok(code)
33 } else {
34 warn!("Invalid SPIR-V code");
35
36 Err(DriverError::InvalidData)
37 }
38}
39
40#[profiling::function]
41fn guess_immutable_sampler(binding_name: &str) -> SamplerInfo {
42 const INVALID_ERR: &str = "Invalid sampler specification";
43
44 let (texel_filter, mipmap_mode, address_modes) = if binding_name.contains("_sampler_") {
45 let spec = &binding_name[binding_name.len() - 3..];
46 let texel_filter = match &spec[0..1] {
47 "n" => vk::Filter::NEAREST,
48 "l" => vk::Filter::LINEAR,
49 _ => panic!("{INVALID_ERR}: {}", &spec[0..1]),
50 };
51
52 let mipmap_mode = match &spec[1..2] {
53 "n" => vk::SamplerMipmapMode::NEAREST,
54 "l" => vk::SamplerMipmapMode::LINEAR,
55 _ => panic!("{INVALID_ERR}: {}", &spec[1..2]),
56 };
57
58 let address_modes = match &spec[2..3] {
59 "b" => vk::SamplerAddressMode::CLAMP_TO_BORDER,
60 "e" => vk::SamplerAddressMode::CLAMP_TO_EDGE,
61 "m" => vk::SamplerAddressMode::MIRRORED_REPEAT,
62 "r" => vk::SamplerAddressMode::REPEAT,
63 _ => panic!("{INVALID_ERR}: {}", &spec[2..3]),
64 };
65
66 (texel_filter, mipmap_mode, address_modes)
67 } else {
68 debug!("image binding {binding_name} using default sampler");
69
70 (
71 vk::Filter::LINEAR,
72 vk::SamplerMipmapMode::LINEAR,
73 vk::SamplerAddressMode::REPEAT,
74 )
75 };
76 let anisotropy_enable = texel_filter == vk::Filter::LINEAR;
77 let mut info = SamplerInfoBuilder::default()
78 .mag_filter(texel_filter)
79 .min_filter(texel_filter)
80 .mipmap_mode(mipmap_mode)
81 .address_mode_u(address_modes)
82 .address_mode_v(address_modes)
83 .address_mode_w(address_modes)
84 .max_lod(vk::LOD_CLAMP_NONE)
85 .anisotropy_enable(anisotropy_enable);
86
87 if anisotropy_enable {
88 info = info.max_anisotropy(16.0);
89 }
90
91 info.build()
92}
93
/// Identifies a single descriptor by its set index and its binding index within that set.
///
/// The derived ordering compares `set` first and then `binding`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Descriptor {
    /// Descriptor set index.
    pub set: u32,

    /// Binding index within the descriptor set.
    pub binding: u32,
}

impl From<u32> for Descriptor {
    /// Creates a descriptor for `binding` within set zero.
    fn from(binding: u32) -> Self {
        Self::from((0, binding))
    }
}

impl From<(u32, u32)> for Descriptor {
    /// Creates a descriptor from a `(set, binding)` pair.
    fn from(pair: (u32, u32)) -> Self {
        let (set, binding) = pair;

        Descriptor { set, binding }
    }
}
118
119#[derive(Clone, Copy, Debug)]
120pub(crate) enum DescriptorInfo {
121 AccelerationStructure(u32),
122 CombinedImageSampler(u32, SamplerInfo, bool), InputAttachment(u32, u32), SampledImage(u32),
125 Sampler(u32, SamplerInfo, bool), StorageBuffer(u32),
127 StorageImage(u32),
128 StorageTexelBuffer(u32),
129 UniformBuffer(u32),
130 UniformTexelBuffer(u32),
131}
132
133impl DescriptorInfo {
134 pub fn binding_count(self) -> u32 {
135 match self {
136 Self::AccelerationStructure(binding_count) => binding_count,
137 Self::CombinedImageSampler(binding_count, ..) => binding_count,
138 Self::InputAttachment(binding_count, _) => binding_count,
139 Self::SampledImage(binding_count) => binding_count,
140 Self::Sampler(binding_count, ..) => binding_count,
141 Self::StorageBuffer(binding_count) => binding_count,
142 Self::StorageImage(binding_count) => binding_count,
143 Self::StorageTexelBuffer(binding_count) => binding_count,
144 Self::UniformBuffer(binding_count) => binding_count,
145 Self::UniformTexelBuffer(binding_count) => binding_count,
146 }
147 }
148
149 pub fn descriptor_type(self) -> vk::DescriptorType {
150 match self {
151 Self::AccelerationStructure(_) => vk::DescriptorType::ACCELERATION_STRUCTURE_KHR,
152 Self::CombinedImageSampler(..) => vk::DescriptorType::COMBINED_IMAGE_SAMPLER,
153 Self::InputAttachment(..) => vk::DescriptorType::INPUT_ATTACHMENT,
154 Self::SampledImage(_) => vk::DescriptorType::SAMPLED_IMAGE,
155 Self::Sampler(..) => vk::DescriptorType::SAMPLER,
156 Self::StorageBuffer(_) => vk::DescriptorType::STORAGE_BUFFER,
157 Self::StorageImage(_) => vk::DescriptorType::STORAGE_IMAGE,
158 Self::StorageTexelBuffer(_) => vk::DescriptorType::STORAGE_TEXEL_BUFFER,
159 Self::UniformBuffer(_) => vk::DescriptorType::UNIFORM_BUFFER,
160 Self::UniformTexelBuffer(_) => vk::DescriptorType::UNIFORM_TEXEL_BUFFER,
161 }
162 }
163
164 fn sampler_info(self) -> Option<SamplerInfo> {
165 match self {
166 Self::CombinedImageSampler(_, sampler_info, _) | Self::Sampler(_, sampler_info, _) => {
167 Some(sampler_info)
168 }
169 _ => None,
170 }
171 }
172
173 pub fn set_binding_count(&mut self, binding_count: u32) {
174 *match self {
175 Self::AccelerationStructure(binding_count) => binding_count,
176 Self::CombinedImageSampler(binding_count, ..) => binding_count,
177 Self::InputAttachment(binding_count, _) => binding_count,
178 Self::SampledImage(binding_count) => binding_count,
179 Self::Sampler(binding_count, ..) => binding_count,
180 Self::StorageBuffer(binding_count) => binding_count,
181 Self::StorageImage(binding_count) => binding_count,
182 Self::StorageTexelBuffer(binding_count) => binding_count,
183 Self::UniformBuffer(binding_count) => binding_count,
184 Self::UniformTexelBuffer(binding_count) => binding_count,
185 } = binding_count;
186 }
187}
188
189#[derive(Debug)]
190pub(crate) struct PipelineDescriptorInfo {
191 pub layouts: BTreeMap<u32, DescriptorSetLayout>,
192 pub pool_sizes: HashMap<u32, HashMap<vk::DescriptorType, u32>>,
193
194 #[allow(dead_code)]
195 samplers: Box<[Sampler]>,
196}
197
198impl PipelineDescriptorInfo {
199 #[profiling::function]
200 pub fn create(
201 device: &Arc<Device>,
202 descriptor_bindings: &DescriptorBindingMap,
203 ) -> Result<Self, DriverError> {
204 let descriptor_set_count = descriptor_bindings
205 .keys()
206 .map(|descriptor| descriptor.set)
207 .max()
208 .map(|set| set + 1)
209 .unwrap_or_default();
210 let mut layouts = BTreeMap::new();
211 let mut pool_sizes = HashMap::new();
212
213 let mut sampler_info_binding_count = HashMap::<_, u32>::with_capacity(
216 descriptor_bindings
217 .values()
218 .filter(|(descriptor_info, _)| descriptor_info.sampler_info().is_some())
219 .count(),
220 );
221
222 for (sampler_info, binding_count) in
223 descriptor_bindings
224 .values()
225 .filter_map(|(descriptor_info, _)| {
226 descriptor_info
227 .sampler_info()
228 .map(|sampler_info| (sampler_info, descriptor_info.binding_count()))
229 })
230 {
231 sampler_info_binding_count
232 .entry(sampler_info)
233 .and_modify(|sampler_info_binding_count| {
234 *sampler_info_binding_count = binding_count.max(*sampler_info_binding_count);
235 })
236 .or_insert(binding_count);
237 }
238
239 let mut samplers = sampler_info_binding_count
240 .keys()
241 .copied()
242 .map(|sampler_info| {
243 Sampler::create(device, sampler_info).map(|sampler| (sampler_info, sampler))
244 })
245 .collect::<Result<HashMap<_, _>, _>>()?;
246 let immutable_samplers = sampler_info_binding_count
247 .iter()
248 .map(|(sampler_info, &binding_count)| {
249 (
250 *sampler_info,
251 repeat_n(*samplers[sampler_info], binding_count as _).collect::<Box<_>>(),
252 )
253 })
254 .collect::<HashMap<_, _>>();
255
256 for descriptor_set_idx in 0..descriptor_set_count {
257 let mut binding_counts = HashMap::<vk::DescriptorType, u32>::new();
258 let mut bindings = vec![];
259
260 for (descriptor, (descriptor_info, stage_flags)) in descriptor_bindings
261 .iter()
262 .filter(|(descriptor, _)| descriptor.set == descriptor_set_idx)
263 {
264 let descriptor_ty = descriptor_info.descriptor_type();
265 *binding_counts.entry(descriptor_ty).or_default() +=
266 descriptor_info.binding_count();
267 let mut binding = vk::DescriptorSetLayoutBinding::default()
268 .binding(descriptor.binding)
269 .descriptor_count(descriptor_info.binding_count())
270 .descriptor_type(descriptor_ty)
271 .stage_flags(*stage_flags);
272
273 if let Some(immutable_samplers) =
274 descriptor_info.sampler_info().map(|sampler_info| {
275 &immutable_samplers[&sampler_info]
276 [0..descriptor_info.binding_count() as usize]
277 })
278 {
279 binding = binding.immutable_samplers(immutable_samplers);
280 }
281
282 bindings.push(binding);
283 }
284
285 let pool_size = pool_sizes
286 .entry(descriptor_set_idx)
287 .or_insert_with(HashMap::new);
288
289 for (descriptor_ty, binding_count) in binding_counts.into_iter() {
290 *pool_size.entry(descriptor_ty).or_default() += binding_count;
291 }
292
293 let mut create_info = vk::DescriptorSetLayoutCreateInfo::default().bindings(&bindings);
296
297 let bindless_flags = vec![vk::DescriptorBindingFlags::PARTIALLY_BOUND; bindings.len()];
301 let mut bindless_flags = if device
302 .physical_device
303 .features_v1_2
304 .descriptor_binding_partially_bound
305 {
306 let bindless_flags = vk::DescriptorSetLayoutBindingFlagsCreateInfo::default()
307 .binding_flags(&bindless_flags);
308 Some(bindless_flags)
309 } else {
310 None
311 };
312
313 if let Some(bindless_flags) = bindless_flags.as_mut() {
314 create_info = create_info.push_next(bindless_flags);
315 }
316
317 layouts.insert(
318 descriptor_set_idx,
319 DescriptorSetLayout::create(device, &create_info)?,
320 );
321 }
322
323 let samplers = samplers
324 .drain()
325 .map(|(_, sampler)| sampler)
326 .collect::<Box<_>>();
327
328 Ok(Self {
332 layouts,
333 pool_sizes,
334 samplers,
335 })
336 }
337}
338
339pub(crate) struct Sampler {
340 device: Arc<Device>,
341 sampler: vk::Sampler,
342}
343
344impl Sampler {
345 #[profiling::function]
346 pub fn create(device: &Arc<Device>, info: impl Into<SamplerInfo>) -> Result<Self, DriverError> {
347 let device = Arc::clone(device);
348 let info = info.into();
349
350 let sampler = unsafe {
351 device
352 .create_sampler(
353 &vk::SamplerCreateInfo::default()
354 .flags(info.flags)
355 .mag_filter(info.mag_filter)
356 .min_filter(info.min_filter)
357 .mipmap_mode(info.mipmap_mode)
358 .address_mode_u(info.address_mode_u)
359 .address_mode_v(info.address_mode_v)
360 .address_mode_w(info.address_mode_w)
361 .mip_lod_bias(info.mip_lod_bias.0)
362 .anisotropy_enable(info.anisotropy_enable)
363 .max_anisotropy(info.max_anisotropy.0)
364 .compare_enable(info.compare_enable)
365 .compare_op(info.compare_op)
366 .min_lod(info.min_lod.0)
367 .max_lod(info.max_lod.0)
368 .border_color(info.border_color)
369 .unnormalized_coordinates(info.unnormalized_coordinates)
370 .push_next(
371 &mut vk::SamplerReductionModeCreateInfo::default()
372 .reduction_mode(info.reduction_mode),
373 ),
374 None,
375 )
376 .map_err(|err| {
377 warn!("{err}");
378
379 match err {
380 vk::Result::ERROR_OUT_OF_HOST_MEMORY
381 | vk::Result::ERROR_OUT_OF_DEVICE_MEMORY => DriverError::OutOfMemory,
382 _ => DriverError::Unsupported,
383 }
384 })?
385 };
386
387 Ok(Self { device, sampler })
388 }
389}
390
391impl Debug for Sampler {
392 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
393 write!(f, "{:?}", self.sampler)
394 }
395}
396
397impl Deref for Sampler {
398 type Target = vk::Sampler;
399
400 fn deref(&self) -> &Self::Target {
401 &self.sampler
402 }
403}
404
405impl Drop for Sampler {
406 #[profiling::function]
407 fn drop(&mut self) {
408 if panicking() {
409 return;
410 }
411
412 unsafe {
413 self.device.destroy_sampler(self.sampler, None);
414 }
415 }
416}
417
/// Information used to create a sampler.
///
/// Each field is passed to the `VkSamplerCreateInfo` member of the same name, except
/// `reduction_mode`, which is passed through a chained `VkSamplerReductionModeCreateInfo`.
/// Fields left unset on [`SamplerInfoBuilder`] take the `Default` value of their own type.
#[derive(Builder, Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[builder(
    build_fn(private, name = "fallible_build", error = "SamplerInfoBuilderError"),
    derive(Clone, Copy, Debug),
    pattern = "owned"
)]
#[non_exhaustive]
pub struct SamplerInfo {
    /// Sampler creation flags.
    #[builder(default)]
    pub flags: vk::SamplerCreateFlags,

    /// Filter applied to magnified lookups.
    #[builder(default)]
    pub mag_filter: vk::Filter,

    /// Filter applied to minified lookups.
    #[builder(default)]
    pub min_filter: vk::Filter,

    /// Filter applied between mipmap levels.
    #[builder(default)]
    pub mipmap_mode: vk::SamplerMipmapMode,

    /// Addressing mode for the U coordinate.
    #[builder(default)]
    pub address_mode_u: vk::SamplerAddressMode,

    /// Addressing mode for the V coordinate.
    #[builder(default)]
    pub address_mode_v: vk::SamplerAddressMode,

    /// Addressing mode for the W coordinate.
    #[builder(default)]
    pub address_mode_w: vk::SamplerAddressMode,

    /// Bias added to the computed level of detail.
    #[builder(default, setter(into))]
    pub mip_lod_bias: OrderedFloat<f32>,

    /// Enables anisotropic filtering.
    #[builder(default)]
    pub anisotropy_enable: bool,

    /// Anisotropy clamp value, relevant when `anisotropy_enable` is `true`.
    #[builder(default, setter(into))]
    pub max_anisotropy: OrderedFloat<f32>,

    /// Enables comparison against a reference value during lookups.
    #[builder(default)]
    pub compare_enable: bool,

    /// Comparison operator, relevant when `compare_enable` is `true`.
    #[builder(default)]
    pub compare_op: vk::CompareOp,

    /// Lower clamp for the computed level of detail.
    #[builder(default, setter(into))]
    pub min_lod: OrderedFloat<f32>,

    /// Upper clamp for the computed level of detail; `vk::LOD_CLAMP_NONE` disables the clamp.
    #[builder(default, setter(into))]
    pub max_lod: OrderedFloat<f32>,

    /// Predefined border color used by clamp-to-border addressing.
    #[builder(default)]
    pub border_color: vk::BorderColor,

    /// Uses unnormalized texel coordinates instead of normalized coordinates.
    #[builder(default)]
    pub unnormalized_coordinates: bool,

    /// Reduction applied to filtered texels, passed via `VkSamplerReductionModeCreateInfo`.
    #[builder(default)]
    pub reduction_mode: vk::SamplerReductionMode,
}
540
541impl SamplerInfo {
542 pub const LINEAR: SamplerInfoBuilder = SamplerInfoBuilder {
544 flags: None,
545 mag_filter: Some(vk::Filter::LINEAR),
546 min_filter: Some(vk::Filter::LINEAR),
547 mipmap_mode: Some(vk::SamplerMipmapMode::LINEAR),
548 address_mode_u: None,
549 address_mode_v: None,
550 address_mode_w: None,
551 mip_lod_bias: None,
552 anisotropy_enable: None,
553 max_anisotropy: None,
554 compare_enable: None,
555 compare_op: None,
556 min_lod: None,
557 max_lod: None,
558 border_color: None,
559 unnormalized_coordinates: None,
560 reduction_mode: None,
561 };
562
563 pub const NEAREST: SamplerInfoBuilder = SamplerInfoBuilder {
566 flags: None,
567 mag_filter: Some(vk::Filter::NEAREST),
568 min_filter: Some(vk::Filter::NEAREST),
569 mipmap_mode: Some(vk::SamplerMipmapMode::NEAREST),
570 address_mode_u: None,
571 address_mode_v: None,
572 address_mode_w: None,
573 mip_lod_bias: None,
574 anisotropy_enable: None,
575 max_anisotropy: None,
576 compare_enable: None,
577 compare_op: None,
578 min_lod: None,
579 max_lod: None,
580 border_color: None,
581 unnormalized_coordinates: None,
582 reduction_mode: None,
583 };
584
585 #[allow(clippy::new_ret_no_self)]
587 #[deprecated = "Use SamplerInfo::default()"]
588 #[doc(hidden)]
589 pub fn new() -> SamplerInfoBuilder {
590 Self::default().to_builder()
591 }
592
593 #[inline(always)]
595 pub fn to_builder(self) -> SamplerInfoBuilder {
596 SamplerInfoBuilder {
597 flags: Some(self.flags),
598 mag_filter: Some(self.mag_filter),
599 min_filter: Some(self.min_filter),
600 mipmap_mode: Some(self.mipmap_mode),
601 address_mode_u: Some(self.address_mode_u),
602 address_mode_v: Some(self.address_mode_v),
603 address_mode_w: Some(self.address_mode_w),
604 mip_lod_bias: Some(self.mip_lod_bias),
605 anisotropy_enable: Some(self.anisotropy_enable),
606 max_anisotropy: Some(self.max_anisotropy),
607 compare_enable: Some(self.compare_enable),
608 compare_op: Some(self.compare_op),
609 min_lod: Some(self.min_lod),
610 max_lod: Some(self.max_lod),
611 border_color: Some(self.border_color),
612 unnormalized_coordinates: Some(self.unnormalized_coordinates),
613 reduction_mode: Some(self.reduction_mode),
614 }
615 }
616}
617
618impl Default for SamplerInfo {
619 fn default() -> Self {
620 Self {
621 flags: vk::SamplerCreateFlags::empty(),
622 mag_filter: vk::Filter::NEAREST,
623 min_filter: vk::Filter::NEAREST,
624 mipmap_mode: vk::SamplerMipmapMode::NEAREST,
625 address_mode_u: vk::SamplerAddressMode::REPEAT,
626 address_mode_v: vk::SamplerAddressMode::REPEAT,
627 address_mode_w: vk::SamplerAddressMode::REPEAT,
628 mip_lod_bias: OrderedFloat(0.0),
629 anisotropy_enable: false,
630 max_anisotropy: OrderedFloat(0.0),
631 compare_enable: false,
632 compare_op: vk::CompareOp::NEVER,
633 min_lod: OrderedFloat(0.0),
634 max_lod: OrderedFloat(0.0),
635 border_color: vk::BorderColor::FLOAT_TRANSPARENT_BLACK,
636 unnormalized_coordinates: false,
637 reduction_mode: vk::SamplerReductionMode::WEIGHTED_AVERAGE,
638 }
639 }
640}
641
642impl SamplerInfoBuilder {
643 #[inline(always)]
645 pub fn build(self) -> SamplerInfo {
646 let res = self.fallible_build();
647
648 #[cfg(test)]
649 let res = res.unwrap();
650
651 #[cfg(not(test))]
652 let res = unsafe { res.unwrap_unchecked() };
653
654 res
655 }
656}
657
658impl From<SamplerInfoBuilder> for SamplerInfo {
659 fn from(info: SamplerInfoBuilder) -> Self {
660 info.build()
661 }
662}
663
664#[derive(Debug)]
665struct SamplerInfoBuilderError;
666
667impl From<UninitializedFieldError> for SamplerInfoBuilderError {
668 fn from(_: UninitializedFieldError) -> Self {
669 Self
670 }
671}
672
/// A shader stage: SPIR-V code plus the reflection data extracted from it.
///
/// Construct with [`Shader::new`] or one of the stage-specific constructors, each of which
/// returns a [`ShaderBuilder`].
#[allow(missing_docs)]
#[derive(Builder, Clone)]
#[builder(
    build_fn(private, name = "fallible_build", error = "ShaderBuilderError"),
    derive(Clone, Debug),
    pattern = "owned"
)]
pub struct Shader {
    /// Name of the entry point function; defaults to `"main"`.
    #[builder(default = "\"main\".to_owned()")]
    pub entry_name: String,

    /// Optional specialization constants. The builder setter takes the value directly
    /// (`strip_option`).
    #[builder(default, setter(strip_option))]
    pub specialization_info: Option<SpecializationInfo>,

    /// Shader code as SPIR-V bytes.
    pub spirv: Vec<u8>,

    /// Stage this shader runs in; also used as the stage flags of its reflected descriptor
    /// bindings and push constant range.
    pub stage: vk::ShaderStageFlags,

    /// Reflection data for the entry point (private builder setter).
    #[builder(private)]
    entry_point: EntryPoint,

    /// Manually specified samplers keyed by descriptor; descriptors without an entry get a
    /// sampler guessed from their binding name.
    #[builder(default, private)]
    image_samplers: HashMap<Descriptor, SamplerInfo>,

    /// Explicit vertex input state; when set, `vertex_input` returns it instead of a layout
    /// reflected from the shader inputs.
    #[builder(default, private, setter(strip_option))]
    vertex_input_state: Option<VertexInputState>,
}
753
754impl Shader {
755 #[allow(clippy::new_ret_no_self)]
757 pub fn new(stage: vk::ShaderStageFlags, spirv: impl ShaderCode) -> ShaderBuilder {
758 ShaderBuilder::default()
759 .spirv(spirv.into_vec())
760 .stage(stage)
761 }
762
763 pub fn new_any_hit(spirv: impl ShaderCode) -> ShaderBuilder {
769 Self::new(vk::ShaderStageFlags::ANY_HIT_KHR, spirv)
770 }
771
772 pub fn new_callable(spirv: impl ShaderCode) -> ShaderBuilder {
778 Self::new(vk::ShaderStageFlags::CALLABLE_KHR, spirv)
779 }
780
781 pub fn new_closest_hit(spirv: impl ShaderCode) -> ShaderBuilder {
787 Self::new(vk::ShaderStageFlags::CLOSEST_HIT_KHR, spirv)
788 }
789
790 pub fn new_compute(spirv: impl ShaderCode) -> ShaderBuilder {
796 Self::new(vk::ShaderStageFlags::COMPUTE, spirv)
797 }
798
799 pub fn new_fragment(spirv: impl ShaderCode) -> ShaderBuilder {
805 Self::new(vk::ShaderStageFlags::FRAGMENT, spirv)
806 }
807
808 pub fn new_geometry(spirv: impl ShaderCode) -> ShaderBuilder {
814 Self::new(vk::ShaderStageFlags::GEOMETRY, spirv)
815 }
816
817 pub fn new_intersection(spirv: impl ShaderCode) -> ShaderBuilder {
823 Self::new(vk::ShaderStageFlags::INTERSECTION_KHR, spirv)
824 }
825
826 pub fn new_mesh(spirv: impl ShaderCode) -> ShaderBuilder {
832 Self::new(vk::ShaderStageFlags::MESH_EXT, spirv)
833 }
834
835 pub fn new_miss(spirv: impl ShaderCode) -> ShaderBuilder {
841 Self::new(vk::ShaderStageFlags::MISS_KHR, spirv)
842 }
843
844 pub fn new_ray_gen(spirv: impl ShaderCode) -> ShaderBuilder {
850 Self::new(vk::ShaderStageFlags::RAYGEN_KHR, spirv)
851 }
852
853 pub fn new_task(spirv: impl ShaderCode) -> ShaderBuilder {
859 Self::new(vk::ShaderStageFlags::TASK_EXT, spirv)
860 }
861
862 pub fn new_tesselation_ctrl(spirv: impl ShaderCode) -> ShaderBuilder {
868 Self::new(vk::ShaderStageFlags::TESSELLATION_CONTROL, spirv)
869 }
870
871 pub fn new_tesselation_eval(spirv: impl ShaderCode) -> ShaderBuilder {
877 Self::new(vk::ShaderStageFlags::TESSELLATION_EVALUATION, spirv)
878 }
879
880 pub fn new_vertex(spirv: impl ShaderCode) -> ShaderBuilder {
886 Self::new(vk::ShaderStageFlags::VERTEX, spirv)
887 }
888
889 #[profiling::function]
891 pub(super) fn attachments(
892 &self,
893 ) -> (
894 impl Iterator<Item = u32> + '_,
895 impl Iterator<Item = u32> + '_,
896 ) {
897 (
898 self.entry_point.vars.iter().filter_map(|var| match var {
899 Variable::Descriptor {
900 desc_ty: DescriptorType::InputAttachment(attachment),
901 ..
902 } => Some(*attachment),
903 _ => None,
904 }),
905 self.entry_point.vars.iter().filter_map(|var| match var {
906 Variable::Output { location, .. } => Some(location.loc()),
907 _ => None,
908 }),
909 )
910 }
911
912 #[profiling::function]
913 pub(super) fn descriptor_bindings(&self) -> DescriptorBindingMap {
914 let mut res = DescriptorBindingMap::default();
915
916 for (name, descriptor, desc_ty, binding_count) in
917 self.entry_point.vars.iter().filter_map(|var| match var {
918 Variable::Descriptor {
919 name,
920 desc_bind,
921 desc_ty,
922 nbind,
923 ..
924 } => Some((
925 name,
926 Descriptor {
927 set: desc_bind.set(),
928 binding: desc_bind.bind(),
929 },
930 desc_ty,
931 *nbind,
932 )),
933 _ => None,
934 })
935 {
936 trace!(
937 "descriptor {}: {}.{} = {:?}[{}]",
938 name.as_deref().unwrap_or_default(),
939 descriptor.set,
940 descriptor.binding,
941 *desc_ty,
942 binding_count
943 );
944
945 let descriptor_info = match desc_ty {
946 DescriptorType::AccelStruct() => {
947 DescriptorInfo::AccelerationStructure(binding_count)
948 }
949 DescriptorType::CombinedImageSampler() => {
950 let (sampler_info, is_manually_defined) =
951 self.image_sampler(descriptor, name.as_deref().unwrap_or_default());
952
953 DescriptorInfo::CombinedImageSampler(
954 binding_count,
955 sampler_info,
956 is_manually_defined,
957 )
958 }
959 DescriptorType::InputAttachment(attachment) => {
960 DescriptorInfo::InputAttachment(binding_count, *attachment)
961 }
962 DescriptorType::SampledImage() => DescriptorInfo::SampledImage(binding_count),
963 DescriptorType::Sampler() => {
964 let (sampler_info, is_manually_defined) =
965 self.image_sampler(descriptor, name.as_deref().unwrap_or_default());
966
967 DescriptorInfo::Sampler(binding_count, sampler_info, is_manually_defined)
968 }
969 DescriptorType::StorageBuffer(_access_ty) => {
970 DescriptorInfo::StorageBuffer(binding_count)
971 }
972 DescriptorType::StorageImage(_access_ty) => {
973 DescriptorInfo::StorageImage(binding_count)
974 }
975 DescriptorType::StorageTexelBuffer(_access_ty) => {
976 DescriptorInfo::StorageTexelBuffer(binding_count)
977 }
978 DescriptorType::UniformBuffer() => DescriptorInfo::UniformBuffer(binding_count),
979 DescriptorType::UniformTexelBuffer() => {
980 DescriptorInfo::UniformTexelBuffer(binding_count)
981 }
982 };
983 res.insert(descriptor, (descriptor_info, self.stage));
984 }
985
986 res
987 }
988
989 fn image_sampler(&self, descriptor: Descriptor, name: &str) -> (SamplerInfo, bool) {
990 self.image_samplers
991 .get(&descriptor)
992 .copied()
993 .map(|sampler_info| (sampler_info, true))
994 .unwrap_or_else(|| (guess_immutable_sampler(name), false))
995 }
996
997 #[profiling::function]
998 pub(super) fn merge_descriptor_bindings(
999 descriptor_bindings: impl IntoIterator<Item = DescriptorBindingMap>,
1000 ) -> DescriptorBindingMap {
1001 fn merge_info(lhs: &mut DescriptorInfo, rhs: DescriptorInfo) -> bool {
1002 let (lhs_count, rhs_count) = match lhs {
1003 DescriptorInfo::AccelerationStructure(lhs) => {
1004 if let DescriptorInfo::AccelerationStructure(rhs) = rhs {
1005 (lhs, rhs)
1006 } else {
1007 return false;
1008 }
1009 }
1010 DescriptorInfo::CombinedImageSampler(lhs, lhs_sampler, lhs_is_manually_defined) => {
1011 if let DescriptorInfo::CombinedImageSampler(
1012 rhs,
1013 rhs_sampler,
1014 rhs_is_manually_defined,
1015 ) = rhs
1016 {
1017 if *lhs_is_manually_defined && rhs_is_manually_defined {
1019 return false;
1020 } else if rhs_is_manually_defined {
1021 *lhs_sampler = rhs_sampler;
1022 }
1023
1024 (lhs, rhs)
1025 } else {
1026 return false;
1027 }
1028 }
1029 DescriptorInfo::InputAttachment(lhs, lhs_idx) => {
1030 if let DescriptorInfo::InputAttachment(rhs, rhs_idx) = rhs {
1031 if *lhs_idx != rhs_idx {
1032 return false;
1033 }
1034
1035 (lhs, rhs)
1036 } else {
1037 return false;
1038 }
1039 }
1040 DescriptorInfo::SampledImage(lhs) => {
1041 if let DescriptorInfo::SampledImage(rhs) = rhs {
1042 (lhs, rhs)
1043 } else {
1044 return false;
1045 }
1046 }
1047 DescriptorInfo::Sampler(lhs, lhs_sampler, lhs_is_manually_defined) => {
1048 if let DescriptorInfo::Sampler(rhs, rhs_sampler, rhs_is_manually_defined) = rhs
1049 {
1050 if *lhs_is_manually_defined && rhs_is_manually_defined {
1052 return false;
1053 } else if rhs_is_manually_defined {
1054 *lhs_sampler = rhs_sampler;
1055 }
1056
1057 (lhs, rhs)
1058 } else {
1059 return false;
1060 }
1061 }
1062 DescriptorInfo::StorageBuffer(lhs) => {
1063 if let DescriptorInfo::StorageBuffer(rhs) = rhs {
1064 (lhs, rhs)
1065 } else {
1066 return false;
1067 }
1068 }
1069 DescriptorInfo::StorageImage(lhs) => {
1070 if let DescriptorInfo::StorageImage(rhs) = rhs {
1071 (lhs, rhs)
1072 } else {
1073 return false;
1074 }
1075 }
1076 DescriptorInfo::StorageTexelBuffer(lhs) => {
1077 if let DescriptorInfo::StorageTexelBuffer(rhs) = rhs {
1078 (lhs, rhs)
1079 } else {
1080 return false;
1081 }
1082 }
1083 DescriptorInfo::UniformBuffer(lhs) => {
1084 if let DescriptorInfo::UniformBuffer(rhs) = rhs {
1085 (lhs, rhs)
1086 } else {
1087 return false;
1088 }
1089 }
1090 DescriptorInfo::UniformTexelBuffer(lhs) => {
1091 if let DescriptorInfo::UniformTexelBuffer(rhs) = rhs {
1092 (lhs, rhs)
1093 } else {
1094 return false;
1095 }
1096 }
1097 };
1098
1099 *lhs_count = rhs_count.max(*lhs_count);
1100
1101 true
1102 }
1103
1104 #[profiling::function]
1105 fn merge_pair(src: DescriptorBindingMap, dst: &mut DescriptorBindingMap) {
1106 for (descriptor_binding, (descriptor_info, descriptor_flags)) in src.into_iter() {
1107 if let Some((existing_info, existing_flags)) = dst.get_mut(&descriptor_binding) {
1108 if !merge_info(existing_info, descriptor_info) {
1109 panic!("Inconsistent shader descriptors ({descriptor_binding:?})");
1110 }
1111
1112 *existing_flags |= descriptor_flags;
1113 } else {
1114 dst.insert(descriptor_binding, (descriptor_info, descriptor_flags));
1115 }
1116 }
1117 }
1118
1119 let mut descriptor_bindings = descriptor_bindings.into_iter();
1120 let mut res = descriptor_bindings.next().unwrap_or_default();
1121 for descriptor_binding in descriptor_bindings {
1122 merge_pair(descriptor_binding, &mut res);
1123 }
1124
1125 res
1126 }
1127
1128 #[profiling::function]
1129 pub(super) fn push_constant_range(&self) -> Option<vk::PushConstantRange> {
1130 self.entry_point
1131 .vars
1132 .iter()
1133 .filter_map(|var| match var {
1134 Variable::PushConstant {
1135 ty: Type::Struct(ty),
1136 ..
1137 } => Some(ty.members.clone()),
1138 _ => None,
1139 })
1140 .flatten()
1141 .map(|push_const| {
1142 let offset = push_const.offset.unwrap_or_default();
1143 let size = push_const
1144 .ty
1145 .nbyte()
1146 .unwrap_or_default()
1147 .next_multiple_of(4);
1148 offset..offset + size
1149 })
1150 .reduce(|a, b| a.start.min(b.start)..a.end.max(b.end))
1151 .map(|push_const| vk::PushConstantRange {
1152 stage_flags: self.stage,
1153 size: (push_const.end - push_const.start) as _,
1154 offset: push_const.start as _,
1155 })
1156 }
1157
1158 #[profiling::function]
1159 fn reflect_entry_point(
1160 entry_name: &str,
1161 spirv: &[u8],
1162 specialization_info: Option<&SpecializationInfo>,
1163 ) -> Result<EntryPoint, DriverError> {
1164 let mut config = ReflectConfig::new();
1165 config.ref_all_rscs(true).spv(spirv);
1166
1167 if let Some(spec_info) = specialization_info {
1168 for spec in &spec_info.map_entries {
1169 config.specialize(
1170 spec.constant_id,
1171 spec_info.data[spec.offset as usize..spec.offset as usize + spec.size].into(),
1172 );
1173 }
1174 }
1175
1176 let entry_points = config.reflect().map_err(|err| {
1177 error!("Unable to reflect spirv: {err}");
1178
1179 DriverError::InvalidData
1180 })?;
1181 let entry_point = entry_points
1182 .into_iter()
1183 .find(|entry_point| entry_point.name == entry_name)
1184 .ok_or_else(|| {
1185 error!("Entry point not found");
1186
1187 DriverError::InvalidData
1188 })?;
1189
1190 Ok(entry_point)
1191 }
1192
1193 #[profiling::function]
1194 pub(super) fn vertex_input(&self) -> VertexInputState {
1195 if let Some(vertex_input) = &self.vertex_input_state {
1197 return vertex_input.clone();
1198 }
1199
1200 fn scalar_format(ty: &ScalarType) -> vk::Format {
1201 match *ty {
1202 ScalarType::Float { bits } => match bits {
1203 u8::BITS => vk::Format::R8_SNORM,
1204 u16::BITS => vk::Format::R16_SFLOAT,
1205 u32::BITS => vk::Format::R32_SFLOAT,
1206 u64::BITS => vk::Format::R64_SFLOAT,
1207 _ => unimplemented!("{bits}-bit float"),
1208 },
1209 ScalarType::Integer {
1210 bits,
1211 is_signed: false,
1212 } => match bits {
1213 u8::BITS => vk::Format::R8_UINT,
1214 u16::BITS => vk::Format::R16_UINT,
1215 u32::BITS => vk::Format::R32_UINT,
1216 u64::BITS => vk::Format::R64_UINT,
1217 _ => unimplemented!("{bits}-bit unsigned integer"),
1218 },
1219 ScalarType::Integer {
1220 bits,
1221 is_signed: true,
1222 } => match bits {
1223 u8::BITS => vk::Format::R8_SINT,
1224 u16::BITS => vk::Format::R16_SINT,
1225 u32::BITS => vk::Format::R32_SINT,
1226 u64::BITS => vk::Format::R64_SINT,
1227 _ => unimplemented!("{bits}-bit signed integer"),
1228 },
1229 _ => unimplemented!("{ty:?}"),
1230 }
1231 }
1232
1233 fn vector_format(ty: &VectorType) -> vk::Format {
1234 match *ty {
1235 VectorType {
1236 scalar_ty: ScalarType::Float { bits },
1237 nscalar,
1238 } => match (bits, nscalar) {
1239 (u8::BITS, 2) => vk::Format::R8G8_SNORM,
1240 (u8::BITS, 3) => vk::Format::R8G8B8_SNORM,
1241 (u8::BITS, 4) => vk::Format::R8G8B8A8_SNORM,
1242 (u16::BITS, 2) => vk::Format::R16G16_SFLOAT,
1243 (u16::BITS, 3) => vk::Format::R16G16B16_SFLOAT,
1244 (u16::BITS, 4) => vk::Format::R16G16B16A16_SFLOAT,
1245 (u32::BITS, 2) => vk::Format::R32G32_SFLOAT,
1246 (u32::BITS, 3) => vk::Format::R32G32B32_SFLOAT,
1247 (u32::BITS, 4) => vk::Format::R32G32B32A32_SFLOAT,
1248 (u64::BITS, 2) => vk::Format::R64G64_SFLOAT,
1249 (u64::BITS, 3) => vk::Format::R64G64B64_SFLOAT,
1250 (u64::BITS, 4) => vk::Format::R64G64B64A64_SFLOAT,
1251 _ => unimplemented!("{bits}-bit vec{nscalar} float"),
1252 },
1253 VectorType {
1254 scalar_ty:
1255 ScalarType::Integer {
1256 bits,
1257 is_signed: false,
1258 },
1259 nscalar,
1260 } => match (bits, nscalar) {
1261 (u8::BITS, 2) => vk::Format::R8G8_UINT,
1262 (u8::BITS, 3) => vk::Format::R8G8B8_UINT,
1263 (u8::BITS, 4) => vk::Format::R8G8B8A8_UINT,
1264 (u16::BITS, 2) => vk::Format::R16G16_UINT,
1265 (u16::BITS, 3) => vk::Format::R16G16B16_UINT,
1266 (u16::BITS, 4) => vk::Format::R16G16B16A16_UINT,
1267 (u32::BITS, 2) => vk::Format::R32G32_UINT,
1268 (u32::BITS, 3) => vk::Format::R32G32B32_UINT,
1269 (u32::BITS, 4) => vk::Format::R32G32B32A32_UINT,
1270 (u64::BITS, 2) => vk::Format::R64G64_UINT,
1271 (u64::BITS, 3) => vk::Format::R64G64B64_UINT,
1272 (u64::BITS, 4) => vk::Format::R64G64B64A64_UINT,
1273 _ => unimplemented!("{bits}-bit vec{nscalar} unsigned integer"),
1274 },
1275 VectorType {
1276 scalar_ty:
1277 ScalarType::Integer {
1278 bits,
1279 is_signed: true,
1280 },
1281 nscalar,
1282 } => match (bits, nscalar) {
1283 (u8::BITS, 2) => vk::Format::R8G8_SINT,
1284 (u8::BITS, 3) => vk::Format::R8G8B8_SINT,
1285 (u8::BITS, 4) => vk::Format::R8G8B8A8_SINT,
1286 (u16::BITS, 2) => vk::Format::R16G16_SINT,
1287 (u16::BITS, 3) => vk::Format::R16G16B16_SINT,
1288 (u16::BITS, 4) => vk::Format::R16G16B16A16_SINT,
1289 (u32::BITS, 2) => vk::Format::R32G32_SINT,
1290 (u32::BITS, 3) => vk::Format::R32G32B32_SINT,
1291 (u32::BITS, 4) => vk::Format::R32G32B32A32_SINT,
1292 (u64::BITS, 2) => vk::Format::R64G64_SINT,
1293 (u64::BITS, 3) => vk::Format::R64G64B64_SINT,
1294 (u64::BITS, 4) => vk::Format::R64G64B64A64_SINT,
1295 _ => unimplemented!("{bits}-bit vec{nscalar} signed integer"),
1296 },
1297 _ => unimplemented!("{ty:?}"),
1298 }
1299 }
1300
1301 let mut input_rates_strides = HashMap::new();
1302 let mut vertex_attribute_descriptions = vec![];
1303
1304 for (name, location, ty) in self.entry_point.vars.iter().filter_map(|var| match var {
1305 Variable::Input { name, location, ty } => Some((name, location, ty)),
1306 _ => None,
1307 }) {
1308 let (binding, guessed_rate) = name
1309 .as_ref()
1310 .filter(|name| name.contains("_ibind") || name.contains("_vbind"))
1311 .map(|name| {
1312 let binding = name[name.rfind("bind").unwrap()..]
1313 .parse()
1314 .unwrap_or_default();
1315 let rate = if name.contains("_ibind") {
1316 vk::VertexInputRate::INSTANCE
1317 } else {
1318 vk::VertexInputRate::VERTEX
1319 };
1320
1321 (binding, rate)
1322 })
1323 .unwrap_or_default();
1324 let (location, _) = location.into_inner();
1325 if let Some((input_rate, _)) = input_rates_strides.get(&binding) {
1326 assert_eq!(*input_rate, guessed_rate);
1327 }
1328
1329 let byte_stride = ty.nbyte().unwrap_or_default() as u32;
1330 let (input_rate, stride) = input_rates_strides.entry(binding).or_default();
1331 *input_rate = guessed_rate;
1332 *stride += byte_stride;
1333
1334 vertex_attribute_descriptions.push(vk::VertexInputAttributeDescription {
1337 location,
1338 binding,
1339 format: match ty {
1340 Type::Scalar(ty) => scalar_format(ty),
1341 Type::Vector(ty) => vector_format(ty),
1342 _ => unimplemented!("{:?}", ty),
1343 },
1344 offset: byte_stride, });
1346 }
1347
1348 vertex_attribute_descriptions.sort_unstable_by(|lhs, rhs| {
1349 let binding = lhs.binding.cmp(&rhs.binding);
1350 if binding.is_lt() {
1351 return binding;
1352 }
1353
1354 lhs.location.cmp(&rhs.location)
1355 });
1356
1357 let mut offset = 0;
1358 let mut offset_binding = 0;
1359
1360 for vertex_attribute_description in &mut vertex_attribute_descriptions {
1361 if vertex_attribute_description.binding != offset_binding {
1362 offset_binding = vertex_attribute_description.binding;
1363 offset = 0;
1364 }
1365
1366 let stride = vertex_attribute_description.offset;
1367 vertex_attribute_description.offset = offset;
1368 offset += stride;
1369
1370 debug!(
1371 "vertex attribute {}.{}: {:?} (offset={})",
1372 vertex_attribute_description.binding,
1373 vertex_attribute_description.location,
1374 vertex_attribute_description.format,
1375 vertex_attribute_description.offset,
1376 );
1377 }
1378
1379 let mut vertex_binding_descriptions = vec![];
1380 for (binding, (input_rate, stride)) in input_rates_strides.into_iter() {
1381 vertex_binding_descriptions.push(vk::VertexInputBindingDescription {
1382 binding,
1383 input_rate,
1384 stride,
1385 });
1386 }
1387
1388 VertexInputState {
1389 vertex_attribute_descriptions,
1390 vertex_binding_descriptions,
1391 }
1392 }
1393}
1394
1395impl Debug for Shader {
1396 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1397 f.write_str("Shader")
1400 }
1401}
1402
1403impl From<ShaderBuilder> for Shader {
1404 fn from(shader: ShaderBuilder) -> Self {
1405 shader.build()
1406 }
1407}
1408
1409impl ShaderBuilder {
1411 pub fn new(stage: vk::ShaderStageFlags, spirv: Vec<u8>) -> Self {
1413 Self::default().stage(stage).spirv(spirv)
1414 }
1415
1416 pub fn build(mut self) -> Shader {
1418 let entry_name = self.entry_name.as_deref().unwrap_or("main");
1419 self.entry_point = Some(
1420 Shader::reflect_entry_point(
1421 entry_name,
1422 self.spirv.as_deref().unwrap(),
1423 self.specialization_info
1424 .as_ref()
1425 .map(|opt| opt.as_ref())
1426 .unwrap_or_default(),
1427 )
1428 .unwrap_or_else(|_| panic!("invalid shader code for entry name \'{entry_name}\'")),
1429 );
1430
1431 self.fallible_build()
1432 .expect("All required fields set at initialization")
1433 }
1434
1435 #[profiling::function]
1455 pub fn image_sampler(
1456 mut self,
1457 descriptor: impl Into<Descriptor>,
1458 info: impl Into<SamplerInfo>,
1459 ) -> Self {
1460 let descriptor = descriptor.into();
1461 let info = info.into();
1462
1463 if self.image_samplers.is_none() {
1464 self.image_samplers = Some(Default::default());
1465 }
1466
1467 self.image_samplers
1468 .as_mut()
1469 .unwrap()
1470 .insert(descriptor, info);
1471
1472 self
1473 }
1474
1475 #[profiling::function]
1487 pub fn vertex_input(
1488 mut self,
1489 bindings: impl Into<Vec<vk::VertexInputBindingDescription>>,
1490 attributes: impl Into<Vec<vk::VertexInputAttributeDescription>>,
1491 ) -> Self {
1492 self.vertex_input_state = Some(Some(VertexInputState {
1493 vertex_binding_descriptions: bindings.into(),
1494 vertex_attribute_descriptions: attributes.into(),
1495 }));
1496 self
1497 }
1498}
1499
1500#[derive(Debug)]
1501struct ShaderBuilderError;
1502
1503impl From<UninitializedFieldError> for ShaderBuilderError {
1504 fn from(_: UninitializedFieldError) -> Self {
1505 Self
1506 }
1507}
1508
/// Conversion of SPIR-V shader code, given either as bytes or as 32-bit words, into an owned
/// byte vector.
pub trait ShaderCode {
    /// Consumes `self` and returns the code as a byte vector.
    fn into_vec(self) -> Vec<u8>;
}

impl ShaderCode for &[u8] {
    /// Copies the bytes.
    ///
    /// In debug builds, asserts that the length is a multiple of four, since SPIR-V is a
    /// stream of 32-bit words.
    fn into_vec(self) -> Vec<u8> {
        debug_assert_eq!(self.len() % 4, 0, "invalid spir-v code");

        self.to_vec()
    }
}

impl ShaderCode for &[u32] {
    /// Flattens the words into bytes in native byte order.
    ///
    /// Native order matches a plain in-memory reinterpretation of the words, so the result is
    /// byte-for-byte what a `u32` slice viewed as `u8` would contain, without needing `unsafe`.
    fn into_vec(self) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(size_of_val(self));
        for word in self {
            bytes.extend_from_slice(&word.to_ne_bytes());
        }

        bytes
    }
}

impl ShaderCode for Vec<u8> {
    /// Returns the vector unchanged.
    ///
    /// In debug builds, asserts that the length is a multiple of four.
    fn into_vec(self) -> Vec<u8> {
        debug_assert_eq!(self.len() % 4, 0, "invalid spir-v code");

        self
    }
}

impl ShaderCode for Vec<u32> {
    /// Flattens the words into bytes in native byte order (see the `&[u32]` impl).
    fn into_vec(self) -> Vec<u8> {
        self.as_slice().into_vec()
    }
}
1551
1552#[derive(Clone, Debug)]
1554pub struct SpecializationInfo {
1555 pub data: Vec<u8>,
1557
1558 pub map_entries: Vec<vk::SpecializationMapEntry>,
1560}
1561
1562impl SpecializationInfo {
1563 pub fn new(
1565 map_entries: impl Into<Vec<vk::SpecializationMapEntry>>,
1566 data: impl Into<Vec<u8>>,
1567 ) -> Self {
1568 Self {
1569 data: data.into(),
1570 map_entries: map_entries.into(),
1571 }
1572 }
1573}
1574
#[cfg(test)]
mod tests {
    use super::*;

    type Info = SamplerInfo;
    type Builder = SamplerInfoBuilder;

    /// Converting a default `SamplerInfo` to a builder and back yields an equal value.
    #[test]
    pub fn sampler_info() {
        let original = Info::default();

        assert_eq!(original, original.to_builder().build());
    }

    /// A default builder produces the same value as `SamplerInfo::default()`.
    #[test]
    pub fn sampler_info_builder() {
        assert_eq!(Info::default(), Builder::default().build());
    }
}