screen_13/driver/
shader.rs

1//! Shader resource types
2
3use {
4    super::{DescriptorSetLayout, DriverError, VertexInputState, device::Device},
5    ash::vk,
6    derive_builder::{Builder, UninitializedFieldError},
7    log::{debug, error, trace, warn},
8    ordered_float::OrderedFloat,
9    spirq::{
10        ReflectConfig,
11        entry_point::EntryPoint,
12        ty::{DescriptorType, ScalarType, Type, VectorType},
13        var::Variable,
14    },
15    std::{
16        collections::{BTreeMap, HashMap},
17        fmt::{Debug, Formatter},
18        iter::repeat_n,
19        mem::size_of_val,
20        ops::Deref,
21        sync::Arc,
22        thread::panicking,
23    },
24};
25
/// Descriptor bindings keyed by set/binding index, each paired with the shader stage flags used
/// for the corresponding `vk::DescriptorSetLayoutBinding`.
pub(crate) type DescriptorBindingMap = HashMap<Descriptor, (DescriptorInfo, vk::ShaderStageFlags)>;
27
28pub(crate) fn align_spriv(code: &[u8]) -> Result<&[u32], DriverError> {
29    let (prefix, code, suffix) = unsafe { code.align_to() };
30
31    if prefix.len() + suffix.len() == 0 {
32        Ok(code)
33    } else {
34        warn!("Invalid SPIR-V code");
35
36        Err(DriverError::InvalidData)
37    }
38}
39
40#[profiling::function]
41fn guess_immutable_sampler(binding_name: &str) -> SamplerInfo {
42    const INVALID_ERR: &str = "Invalid sampler specification";
43
44    let (texel_filter, mipmap_mode, address_modes) = if binding_name.contains("_sampler_") {
45        let spec = &binding_name[binding_name.len() - 3..];
46        let texel_filter = match &spec[0..1] {
47            "n" => vk::Filter::NEAREST,
48            "l" => vk::Filter::LINEAR,
49            _ => panic!("{INVALID_ERR}: {}", &spec[0..1]),
50        };
51
52        let mipmap_mode = match &spec[1..2] {
53            "n" => vk::SamplerMipmapMode::NEAREST,
54            "l" => vk::SamplerMipmapMode::LINEAR,
55            _ => panic!("{INVALID_ERR}: {}", &spec[1..2]),
56        };
57
58        let address_modes = match &spec[2..3] {
59            "b" => vk::SamplerAddressMode::CLAMP_TO_BORDER,
60            "e" => vk::SamplerAddressMode::CLAMP_TO_EDGE,
61            "m" => vk::SamplerAddressMode::MIRRORED_REPEAT,
62            "r" => vk::SamplerAddressMode::REPEAT,
63            _ => panic!("{INVALID_ERR}: {}", &spec[2..3]),
64        };
65
66        (texel_filter, mipmap_mode, address_modes)
67    } else {
68        debug!("image binding {binding_name} using default sampler");
69
70        (
71            vk::Filter::LINEAR,
72            vk::SamplerMipmapMode::LINEAR,
73            vk::SamplerAddressMode::REPEAT,
74        )
75    };
76    let anisotropy_enable = texel_filter == vk::Filter::LINEAR;
77    let mut info = SamplerInfoBuilder::default()
78        .mag_filter(texel_filter)
79        .min_filter(texel_filter)
80        .mipmap_mode(mipmap_mode)
81        .address_mode_u(address_modes)
82        .address_mode_v(address_modes)
83        .address_mode_w(address_modes)
84        .max_lod(vk::LOD_CLAMP_NONE)
85        .anisotropy_enable(anisotropy_enable);
86
87    if anisotropy_enable {
88        info = info.max_anisotropy(16.0);
89    }
90
91    info.build()
92}
93
/// A descriptor set index paired with a binding index.
///
/// This identifies a binding point declared by a shader; it is not a reference to any bound
/// descriptor. Ordering compares `set` first, then `binding`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Descriptor {
    /// Index of the descriptor set.
    pub set: u32,

    /// Index of the binding within the descriptor set.
    pub binding: u32,
}

impl From<u32> for Descriptor {
    /// Places `binding` in descriptor set zero.
    fn from(binding: u32) -> Self {
        Self::from((0, binding))
    }
}

impl From<(u32, u32)> for Descriptor {
    /// Interprets the tuple as `(set, binding)`.
    fn from(pair: (u32, u32)) -> Self {
        let (set, binding) = pair;

        Descriptor { set, binding }
    }
}
118
119#[derive(Clone, Copy, Debug)]
120pub(crate) enum DescriptorInfo {
121    AccelerationStructure(u32),
122    CombinedImageSampler(u32, SamplerInfo, bool), //count, sampler, is-manually-defined?
123    InputAttachment(u32, u32),                    //count, input index,
124    SampledImage(u32),
125    Sampler(u32, SamplerInfo, bool), //count, sampler, is-manually-defined?
126    StorageBuffer(u32),
127    StorageImage(u32),
128    StorageTexelBuffer(u32),
129    UniformBuffer(u32),
130    UniformTexelBuffer(u32),
131}
132
133impl DescriptorInfo {
134    pub fn binding_count(self) -> u32 {
135        match self {
136            Self::AccelerationStructure(binding_count) => binding_count,
137            Self::CombinedImageSampler(binding_count, ..) => binding_count,
138            Self::InputAttachment(binding_count, _) => binding_count,
139            Self::SampledImage(binding_count) => binding_count,
140            Self::Sampler(binding_count, ..) => binding_count,
141            Self::StorageBuffer(binding_count) => binding_count,
142            Self::StorageImage(binding_count) => binding_count,
143            Self::StorageTexelBuffer(binding_count) => binding_count,
144            Self::UniformBuffer(binding_count) => binding_count,
145            Self::UniformTexelBuffer(binding_count) => binding_count,
146        }
147    }
148
149    pub fn descriptor_type(self) -> vk::DescriptorType {
150        match self {
151            Self::AccelerationStructure(_) => vk::DescriptorType::ACCELERATION_STRUCTURE_KHR,
152            Self::CombinedImageSampler(..) => vk::DescriptorType::COMBINED_IMAGE_SAMPLER,
153            Self::InputAttachment(..) => vk::DescriptorType::INPUT_ATTACHMENT,
154            Self::SampledImage(_) => vk::DescriptorType::SAMPLED_IMAGE,
155            Self::Sampler(..) => vk::DescriptorType::SAMPLER,
156            Self::StorageBuffer(_) => vk::DescriptorType::STORAGE_BUFFER,
157            Self::StorageImage(_) => vk::DescriptorType::STORAGE_IMAGE,
158            Self::StorageTexelBuffer(_) => vk::DescriptorType::STORAGE_TEXEL_BUFFER,
159            Self::UniformBuffer(_) => vk::DescriptorType::UNIFORM_BUFFER,
160            Self::UniformTexelBuffer(_) => vk::DescriptorType::UNIFORM_TEXEL_BUFFER,
161        }
162    }
163
164    fn sampler_info(self) -> Option<SamplerInfo> {
165        match self {
166            Self::CombinedImageSampler(_, sampler_info, _) | Self::Sampler(_, sampler_info, _) => {
167                Some(sampler_info)
168            }
169            _ => None,
170        }
171    }
172
173    pub fn set_binding_count(&mut self, binding_count: u32) {
174        *match self {
175            Self::AccelerationStructure(binding_count) => binding_count,
176            Self::CombinedImageSampler(binding_count, ..) => binding_count,
177            Self::InputAttachment(binding_count, _) => binding_count,
178            Self::SampledImage(binding_count) => binding_count,
179            Self::Sampler(binding_count, ..) => binding_count,
180            Self::StorageBuffer(binding_count) => binding_count,
181            Self::StorageImage(binding_count) => binding_count,
182            Self::StorageTexelBuffer(binding_count) => binding_count,
183            Self::UniformBuffer(binding_count) => binding_count,
184            Self::UniformTexelBuffer(binding_count) => binding_count,
185        } = binding_count;
186    }
187}
188
189#[derive(Debug)]
190pub(crate) struct PipelineDescriptorInfo {
191    pub layouts: BTreeMap<u32, DescriptorSetLayout>,
192    pub pool_sizes: HashMap<u32, HashMap<vk::DescriptorType, u32>>,
193
194    #[allow(dead_code)]
195    samplers: Box<[Sampler]>,
196}
197
198impl PipelineDescriptorInfo {
199    #[profiling::function]
200    pub fn create(
201        device: &Arc<Device>,
202        descriptor_bindings: &DescriptorBindingMap,
203    ) -> Result<Self, DriverError> {
204        let descriptor_set_count = descriptor_bindings
205            .keys()
206            .map(|descriptor| descriptor.set)
207            .max()
208            .map(|set| set + 1)
209            .unwrap_or_default();
210        let mut layouts = BTreeMap::new();
211        let mut pool_sizes = HashMap::new();
212
213        //trace!("descriptor_bindings: {:#?}", &descriptor_bindings);
214
215        let mut sampler_info_binding_count = HashMap::<_, u32>::with_capacity(
216            descriptor_bindings
217                .values()
218                .filter(|(descriptor_info, _)| descriptor_info.sampler_info().is_some())
219                .count(),
220        );
221
222        for (sampler_info, binding_count) in
223            descriptor_bindings
224                .values()
225                .filter_map(|(descriptor_info, _)| {
226                    descriptor_info
227                        .sampler_info()
228                        .map(|sampler_info| (sampler_info, descriptor_info.binding_count()))
229                })
230        {
231            sampler_info_binding_count
232                .entry(sampler_info)
233                .and_modify(|sampler_info_binding_count| {
234                    *sampler_info_binding_count = binding_count.max(*sampler_info_binding_count);
235                })
236                .or_insert(binding_count);
237        }
238
239        let mut samplers = sampler_info_binding_count
240            .keys()
241            .copied()
242            .map(|sampler_info| {
243                Sampler::create(device, sampler_info).map(|sampler| (sampler_info, sampler))
244            })
245            .collect::<Result<HashMap<_, _>, _>>()?;
246        let immutable_samplers = sampler_info_binding_count
247            .iter()
248            .map(|(sampler_info, &binding_count)| {
249                (
250                    *sampler_info,
251                    repeat_n(*samplers[sampler_info], binding_count as _).collect::<Box<_>>(),
252                )
253            })
254            .collect::<HashMap<_, _>>();
255
256        for descriptor_set_idx in 0..descriptor_set_count {
257            let mut binding_counts = HashMap::<vk::DescriptorType, u32>::new();
258            let mut bindings = vec![];
259
260            for (descriptor, (descriptor_info, stage_flags)) in descriptor_bindings
261                .iter()
262                .filter(|(descriptor, _)| descriptor.set == descriptor_set_idx)
263            {
264                let descriptor_ty = descriptor_info.descriptor_type();
265                *binding_counts.entry(descriptor_ty).or_default() +=
266                    descriptor_info.binding_count();
267                let mut binding = vk::DescriptorSetLayoutBinding::default()
268                    .binding(descriptor.binding)
269                    .descriptor_count(descriptor_info.binding_count())
270                    .descriptor_type(descriptor_ty)
271                    .stage_flags(*stage_flags);
272
273                if let Some(immutable_samplers) =
274                    descriptor_info.sampler_info().map(|sampler_info| {
275                        &immutable_samplers[&sampler_info]
276                            [0..descriptor_info.binding_count() as usize]
277                    })
278                {
279                    binding = binding.immutable_samplers(immutable_samplers);
280                }
281
282                bindings.push(binding);
283            }
284
285            let pool_size = pool_sizes
286                .entry(descriptor_set_idx)
287                .or_insert_with(HashMap::new);
288
289            for (descriptor_ty, binding_count) in binding_counts.into_iter() {
290                *pool_size.entry(descriptor_ty).or_default() += binding_count;
291            }
292
293            //trace!("bindings: {:#?}", &bindings);
294
295            let mut create_info = vk::DescriptorSetLayoutCreateInfo::default().bindings(&bindings);
296
297            // The bindless flags have to be created for every descriptor set layout binding.
298            // [vulkan spec](https://www.khronos.org/registry/vulkan/specs/1.3-extensions/man/html/VkDescriptorSetLayoutBindingFlagsCreateInfo.html)
299            // Maybe using one vector and updating it would be more efficient.
300            let bindless_flags = vec![vk::DescriptorBindingFlags::PARTIALLY_BOUND; bindings.len()];
301            let mut bindless_flags = if device
302                .physical_device
303                .features_v1_2
304                .descriptor_binding_partially_bound
305            {
306                let bindless_flags = vk::DescriptorSetLayoutBindingFlagsCreateInfo::default()
307                    .binding_flags(&bindless_flags);
308                Some(bindless_flags)
309            } else {
310                None
311            };
312
313            if let Some(bindless_flags) = bindless_flags.as_mut() {
314                create_info = create_info.push_next(bindless_flags);
315            }
316
317            layouts.insert(
318                descriptor_set_idx,
319                DescriptorSetLayout::create(device, &create_info)?,
320            );
321        }
322
323        let samplers = samplers
324            .drain()
325            .map(|(_, sampler)| sampler)
326            .collect::<Box<_>>();
327
328        //trace!("layouts {:#?}", &layouts);
329        // trace!("pool_sizes {:#?}", &pool_sizes);
330
331        Ok(Self {
332            layouts,
333            pool_sizes,
334            samplers,
335        })
336    }
337}
338
339pub(crate) struct Sampler {
340    device: Arc<Device>,
341    sampler: vk::Sampler,
342}
343
344impl Sampler {
345    #[profiling::function]
346    pub fn create(device: &Arc<Device>, info: impl Into<SamplerInfo>) -> Result<Self, DriverError> {
347        let device = Arc::clone(device);
348        let info = info.into();
349
350        let sampler = unsafe {
351            device
352                .create_sampler(
353                    &vk::SamplerCreateInfo::default()
354                        .flags(info.flags)
355                        .mag_filter(info.mag_filter)
356                        .min_filter(info.min_filter)
357                        .mipmap_mode(info.mipmap_mode)
358                        .address_mode_u(info.address_mode_u)
359                        .address_mode_v(info.address_mode_v)
360                        .address_mode_w(info.address_mode_w)
361                        .mip_lod_bias(info.mip_lod_bias.0)
362                        .anisotropy_enable(info.anisotropy_enable)
363                        .max_anisotropy(info.max_anisotropy.0)
364                        .compare_enable(info.compare_enable)
365                        .compare_op(info.compare_op)
366                        .min_lod(info.min_lod.0)
367                        .max_lod(info.max_lod.0)
368                        .border_color(info.border_color)
369                        .unnormalized_coordinates(info.unnormalized_coordinates)
370                        .push_next(
371                            &mut vk::SamplerReductionModeCreateInfo::default()
372                                .reduction_mode(info.reduction_mode),
373                        ),
374                    None,
375                )
376                .map_err(|err| {
377                    warn!("{err}");
378
379                    match err {
380                        vk::Result::ERROR_OUT_OF_HOST_MEMORY
381                        | vk::Result::ERROR_OUT_OF_DEVICE_MEMORY => DriverError::OutOfMemory,
382                        _ => DriverError::Unsupported,
383                    }
384                })?
385        };
386
387        Ok(Self { device, sampler })
388    }
389}
390
391impl Debug for Sampler {
392    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
393        write!(f, "{:?}", self.sampler)
394    }
395}
396
397impl Deref for Sampler {
398    type Target = vk::Sampler;
399
400    fn deref(&self) -> &Self::Target {
401        &self.sampler
402    }
403}
404
405impl Drop for Sampler {
406    #[profiling::function]
407    fn drop(&mut self) {
408        if panicking() {
409            return;
410        }
411
412        unsafe {
413            self.device.destroy_sampler(self.sampler, None);
414        }
415    }
416}
417
/// Information used to create a [`vk::Sampler`] instance.
///
/// Values are usually built with a [`SamplerInfoBuilder`], obtained from [`SamplerInfo::LINEAR`],
/// [`SamplerInfo::NEAREST`] or [`SamplerInfo::to_builder`], and finished with
/// [`SamplerInfoBuilder::build`].
#[derive(Builder, Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[builder(
    build_fn(private, name = "fallible_build", error = "SamplerInfoBuilderError"),
    derive(Clone, Copy, Debug),
    pattern = "owned"
)]
#[non_exhaustive]
pub struct SamplerInfo {
    /// Bitmask specifying additional parameters of a sampler.
    #[builder(default)]
    pub flags: vk::SamplerCreateFlags,

    /// Specify the magnification filter to apply to texture lookups.
    ///
    /// The default value is [`vk::Filter::NEAREST`]
    #[builder(default)]
    pub mag_filter: vk::Filter,

    /// Specify the minification filter to apply to texture lookups.
    ///
    /// The default value is [`vk::Filter::NEAREST`]
    #[builder(default)]
    pub min_filter: vk::Filter,

    /// A value specifying the mipmap filter to apply to lookups.
    ///
    /// The default value is [`vk::SamplerMipmapMode::NEAREST`]
    #[builder(default)]
    pub mipmap_mode: vk::SamplerMipmapMode,

    /// A value specifying the addressing mode for U coordinates outside `[0, 1)`.
    ///
    /// The default value is [`vk::SamplerAddressMode::REPEAT`]
    #[builder(default)]
    pub address_mode_u: vk::SamplerAddressMode,

    /// A value specifying the addressing mode for V coordinates outside `[0, 1)`.
    ///
    /// The default value is [`vk::SamplerAddressMode::REPEAT`]
    #[builder(default)]
    pub address_mode_v: vk::SamplerAddressMode,

    /// A value specifying the addressing mode for W coordinates outside `[0, 1)`.
    ///
    /// The default value is [`vk::SamplerAddressMode::REPEAT`]
    #[builder(default)]
    pub address_mode_w: vk::SamplerAddressMode,

    /// The bias to be added to mipmap LOD calculation and bias provided by image sampling functions
    /// in SPIR-V, as described in the
    /// [LOD Operation](https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#textures-level-of-detail-operation)
    /// section.
    #[builder(default, setter(into))]
    pub mip_lod_bias: OrderedFloat<f32>,

    /// Enables anisotropic filtering, as described in the
    /// [Texel Anisotropic Filtering](https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#textures-texel-anisotropic-filtering)
    /// section.
    ///
    /// When enabling this, also set [`max_anisotropy`](Self::max_anisotropy).
    #[builder(default)]
    pub anisotropy_enable: bool,

    /// The anisotropy value clamp used by the sampler when `anisotropy_enable` is `true`.
    ///
    /// If `anisotropy_enable` is `false`, `max_anisotropy` is ignored.
    #[builder(default, setter(into))]
    pub max_anisotropy: OrderedFloat<f32>,

    /// Enables comparison against a reference value during lookups.
    #[builder(default)]
    pub compare_enable: bool,

    /// Specifies the comparison operator to apply to fetched data before filtering as described in
    /// the
    /// [Depth Compare Operation](https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#textures-depth-compare-operation)
    /// section.
    #[builder(default)]
    pub compare_op: vk::CompareOp,

    /// Used to clamp the
    /// [minimum of the computed LOD value](https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#textures-level-of-detail-operation).
    #[builder(default, setter(into))]
    pub min_lod: OrderedFloat<f32>,

    /// Used to clamp the
    /// [maximum of the computed LOD value](https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#textures-level-of-detail-operation).
    ///
    /// To avoid clamping the maximum value, set `max_lod` to the constant `vk::LOD_CLAMP_NONE`.
    #[builder(default, setter(into))]
    pub max_lod: OrderedFloat<f32>,

    /// Specifies the predefined border color to use.
    ///
    /// The default value is [`vk::BorderColor::FLOAT_TRANSPARENT_BLACK`]
    #[builder(default)]
    pub border_color: vk::BorderColor,

    /// Controls whether to use unnormalized or normalized texel coordinates to address texels of
    /// the image.
    ///
    /// When set to `true`, the range of the image coordinates used to lookup the texel is in the
    /// range of zero to the image size in each dimension.
    ///
    /// When set to `false` the range of image coordinates is zero to one.
    ///
    /// See
    /// [requirements](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkSamplerCreateInfo.html).
    #[builder(default)]
    pub unnormalized_coordinates: bool,

    /// Specifies sampler reduction mode.
    ///
    /// Setting magnification filter ([`mag_filter`](Self::mag_filter)) to [`vk::Filter::NEAREST`]
    /// disables sampler reduction mode.
    ///
    /// The default value is [`vk::SamplerReductionMode::WEIGHTED_AVERAGE`]
    ///
    /// See
    /// [requirements](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkSamplerCreateInfo.html).
    #[builder(default)]
    pub reduction_mode: vk::SamplerReductionMode,
}
540
541impl SamplerInfo {
542    /// Default sampler information with `mag_filter`, `min_filter` and `mipmap_mode` set to linear.
543    pub const LINEAR: SamplerInfoBuilder = SamplerInfoBuilder {
544        flags: None,
545        mag_filter: Some(vk::Filter::LINEAR),
546        min_filter: Some(vk::Filter::LINEAR),
547        mipmap_mode: Some(vk::SamplerMipmapMode::LINEAR),
548        address_mode_u: None,
549        address_mode_v: None,
550        address_mode_w: None,
551        mip_lod_bias: None,
552        anisotropy_enable: None,
553        max_anisotropy: None,
554        compare_enable: None,
555        compare_op: None,
556        min_lod: None,
557        max_lod: None,
558        border_color: None,
559        unnormalized_coordinates: None,
560        reduction_mode: None,
561    };
562
563    /// Default sampler information with `mag_filter`, `min_filter` and `mipmap_mode` set to
564    /// nearest.
565    pub const NEAREST: SamplerInfoBuilder = SamplerInfoBuilder {
566        flags: None,
567        mag_filter: Some(vk::Filter::NEAREST),
568        min_filter: Some(vk::Filter::NEAREST),
569        mipmap_mode: Some(vk::SamplerMipmapMode::NEAREST),
570        address_mode_u: None,
571        address_mode_v: None,
572        address_mode_w: None,
573        mip_lod_bias: None,
574        anisotropy_enable: None,
575        max_anisotropy: None,
576        compare_enable: None,
577        compare_op: None,
578        min_lod: None,
579        max_lod: None,
580        border_color: None,
581        unnormalized_coordinates: None,
582        reduction_mode: None,
583    };
584
585    /// Creates a default `SamplerInfoBuilder`.
586    #[allow(clippy::new_ret_no_self)]
587    #[deprecated = "Use SamplerInfo::default()"]
588    #[doc(hidden)]
589    pub fn new() -> SamplerInfoBuilder {
590        Self::default().to_builder()
591    }
592
593    /// Converts a `SamplerInfo` into a `SamplerInfoBuilder`.
594    #[inline(always)]
595    pub fn to_builder(self) -> SamplerInfoBuilder {
596        SamplerInfoBuilder {
597            flags: Some(self.flags),
598            mag_filter: Some(self.mag_filter),
599            min_filter: Some(self.min_filter),
600            mipmap_mode: Some(self.mipmap_mode),
601            address_mode_u: Some(self.address_mode_u),
602            address_mode_v: Some(self.address_mode_v),
603            address_mode_w: Some(self.address_mode_w),
604            mip_lod_bias: Some(self.mip_lod_bias),
605            anisotropy_enable: Some(self.anisotropy_enable),
606            max_anisotropy: Some(self.max_anisotropy),
607            compare_enable: Some(self.compare_enable),
608            compare_op: Some(self.compare_op),
609            min_lod: Some(self.min_lod),
610            max_lod: Some(self.max_lod),
611            border_color: Some(self.border_color),
612            unnormalized_coordinates: Some(self.unnormalized_coordinates),
613            reduction_mode: Some(self.reduction_mode),
614        }
615    }
616}
617
618impl Default for SamplerInfo {
619    fn default() -> Self {
620        Self {
621            flags: vk::SamplerCreateFlags::empty(),
622            mag_filter: vk::Filter::NEAREST,
623            min_filter: vk::Filter::NEAREST,
624            mipmap_mode: vk::SamplerMipmapMode::NEAREST,
625            address_mode_u: vk::SamplerAddressMode::REPEAT,
626            address_mode_v: vk::SamplerAddressMode::REPEAT,
627            address_mode_w: vk::SamplerAddressMode::REPEAT,
628            mip_lod_bias: OrderedFloat(0.0),
629            anisotropy_enable: false,
630            max_anisotropy: OrderedFloat(0.0),
631            compare_enable: false,
632            compare_op: vk::CompareOp::NEVER,
633            min_lod: OrderedFloat(0.0),
634            max_lod: OrderedFloat(0.0),
635            border_color: vk::BorderColor::FLOAT_TRANSPARENT_BLACK,
636            unnormalized_coordinates: false,
637            reduction_mode: vk::SamplerReductionMode::WEIGHTED_AVERAGE,
638        }
639    }
640}
641
642impl SamplerInfoBuilder {
643    /// Builds a new `SamplerInfo`.
644    #[inline(always)]
645    pub fn build(self) -> SamplerInfo {
646        let res = self.fallible_build();
647
648        #[cfg(test)]
649        let res = res.unwrap();
650
651        #[cfg(not(test))]
652        let res = unsafe { res.unwrap_unchecked() };
653
654        res
655    }
656}
657
658impl From<SamplerInfoBuilder> for SamplerInfo {
659    fn from(info: SamplerInfoBuilder) -> Self {
660        info.build()
661    }
662}
663
664#[derive(Debug)]
665struct SamplerInfoBuilderError;
666
667impl From<UninitializedFieldError> for SamplerInfoBuilderError {
668    fn from(_: UninitializedFieldError) -> Self {
669        Self
670    }
671}
672
/// Describes a shader program which runs on some pipeline stage.
///
/// Shaders are usually specified through the stage-specific constructors such as
/// [`Shader::new_vertex`], which return a [`ShaderBuilder`].
#[allow(missing_docs)]
#[derive(Builder, Clone)]
#[builder(
    build_fn(private, name = "fallible_build", error = "ShaderBuilderError"),
    derive(Clone, Debug),
    pattern = "owned"
)]
pub struct Shader {
    /// The name of the entry point which will be executed by this shader.
    ///
    /// The default value is `main`.
    #[builder(default = "\"main\".to_owned()")]
    pub entry_name: String,

    /// Data about Vulkan specialization constants.
    ///
    /// # Examples
    ///
    /// Basic usage (GLSL):
    ///
    /// ```
    /// # inline_spirv::inline_spirv!(r#"
    /// #version 460 core
    ///
    /// // Defaults to 6 if not set using Shader specialization_info!
    /// layout(constant_id = 0) const uint MY_COUNT = 6;
    ///
    /// layout(set = 0, binding = 0) uniform sampler2D my_samplers[MY_COUNT];
    ///
    /// void main()
    /// {
    ///     // Code uses MY_COUNT number of my_samplers here
    /// }
    /// # "#, comp);
    /// ```
    ///
    /// ```no_run
    /// # use std::sync::Arc;
    /// # use ash::vk;
    /// # use screen_13::driver::DriverError;
    /// # use screen_13::driver::device::{Device, DeviceInfo};
    /// # use screen_13::driver::shader::{Shader, SpecializationInfo};
    /// # fn main() -> Result<(), DriverError> {
    /// # let device = Arc::new(Device::create_headless(DeviceInfo::default())?);
    /// # let my_shader_code = [0u8; 1];
    /// // We instead specify 42 for MY_COUNT:
    /// let shader = Shader::new_fragment(my_shader_code.as_slice())
    ///     .specialization_info(SpecializationInfo::new(
    ///         [vk::SpecializationMapEntry {
    ///             constant_id: 0,
    ///             offset: 0,
    ///             size: 4,
    ///         }],
    ///         42u32.to_ne_bytes()
    ///     ));
    /// # Ok(()) }
    /// ```
    #[builder(default, setter(strip_option))]
    pub specialization_info: Option<SpecializationInfo>,

    /// Shader code.
    ///
    /// Although SPIR-V code is specified as `u32` values, this field uses `u8` in order to make
    /// loading from file simpler. You should always have a SPIR-V code length which is a multiple
    /// of four bytes, or an error will be returned during pipeline creation.
    pub spirv: Vec<u8>,

    /// The shader stage this structure applies to.
    pub stage: vk::ShaderStageFlags,

    // SPIR-V entry point data (spirq `EntryPoint`); `attachments` scans its `vars`.
    #[builder(private)]
    entry_point: EntryPoint,

    // Samplers associated with image descriptors, keyed by set/binding.
    // NOTE(review): presumably filled by builder methods outside this view — confirm.
    #[builder(default, private)]
    image_samplers: HashMap<Descriptor, SamplerInfo>,

    // Optional vertex input state for this shader.
    // NOTE(review): presumably an explicit override of reflected vertex inputs — confirm.
    #[builder(default, private, setter(strip_option))]
    vertex_input_state: Option<VertexInputState>,
}
753
754impl Shader {
755    /// Specifies a shader with the given `stage` and shader code values.
756    #[allow(clippy::new_ret_no_self)]
757    pub fn new(stage: vk::ShaderStageFlags, spirv: impl ShaderCode) -> ShaderBuilder {
758        ShaderBuilder::default()
759            .spirv(spirv.into_vec())
760            .stage(stage)
761    }
762
763    /// Creates a new ray trace shader.
764    ///
765    /// # Panics
766    ///
767    /// If the shader code is invalid or not a multiple of four bytes in length.
768    pub fn new_any_hit(spirv: impl ShaderCode) -> ShaderBuilder {
769        Self::new(vk::ShaderStageFlags::ANY_HIT_KHR, spirv)
770    }
771
772    /// Creates a new ray trace shader.
773    ///
774    /// # Panics
775    ///
776    /// If the shader code is invalid or not a multiple of four bytes in length.
777    pub fn new_callable(spirv: impl ShaderCode) -> ShaderBuilder {
778        Self::new(vk::ShaderStageFlags::CALLABLE_KHR, spirv)
779    }
780
781    /// Creates a new ray trace shader.
782    ///
783    /// # Panics
784    ///
785    /// If the shader code is invalid or not a multiple of four bytes in length.
786    pub fn new_closest_hit(spirv: impl ShaderCode) -> ShaderBuilder {
787        Self::new(vk::ShaderStageFlags::CLOSEST_HIT_KHR, spirv)
788    }
789
790    /// Creates a new compute shader.
791    ///
792    /// # Panics
793    ///
794    /// If the shader code is invalid or not a multiple of four bytes in length.
795    pub fn new_compute(spirv: impl ShaderCode) -> ShaderBuilder {
796        Self::new(vk::ShaderStageFlags::COMPUTE, spirv)
797    }
798
799    /// Creates a new fragment shader.
800    ///
801    /// # Panics
802    ///
803    /// If the shader code is invalid or not a multiple of four bytes in length.
804    pub fn new_fragment(spirv: impl ShaderCode) -> ShaderBuilder {
805        Self::new(vk::ShaderStageFlags::FRAGMENT, spirv)
806    }
807
808    /// Creates a new geometry shader.
809    ///
810    /// # Panics
811    ///
812    /// If the shader code is invalid or not a multiple of four bytes in length.
813    pub fn new_geometry(spirv: impl ShaderCode) -> ShaderBuilder {
814        Self::new(vk::ShaderStageFlags::GEOMETRY, spirv)
815    }
816
817    /// Creates a new ray trace shader.
818    ///
819    /// # Panics
820    ///
821    /// If the shader code is invalid or not a multiple of four bytes in length.
822    pub fn new_intersection(spirv: impl ShaderCode) -> ShaderBuilder {
823        Self::new(vk::ShaderStageFlags::INTERSECTION_KHR, spirv)
824    }
825
826    /// Creates a new mesh shader.
827    ///
828    /// # Panics
829    ///
830    /// If the shader code is invalid.
831    pub fn new_mesh(spirv: impl ShaderCode) -> ShaderBuilder {
832        Self::new(vk::ShaderStageFlags::MESH_EXT, spirv)
833    }
834
835    /// Creates a new ray trace shader.
836    ///
837    /// # Panics
838    ///
839    /// If the shader code is invalid or not a multiple of four bytes in length.
840    pub fn new_miss(spirv: impl ShaderCode) -> ShaderBuilder {
841        Self::new(vk::ShaderStageFlags::MISS_KHR, spirv)
842    }
843
844    /// Creates a new ray trace shader.
845    ///
846    /// # Panics
847    ///
848    /// If the shader code is invalid or not a multiple of four bytes in length.
849    pub fn new_ray_gen(spirv: impl ShaderCode) -> ShaderBuilder {
850        Self::new(vk::ShaderStageFlags::RAYGEN_KHR, spirv)
851    }
852
853    /// Creates a new mesh task shader.
854    ///
855    /// # Panics
856    ///
857    /// If the shader code is invalid.
858    pub fn new_task(spirv: impl ShaderCode) -> ShaderBuilder {
859        Self::new(vk::ShaderStageFlags::TASK_EXT, spirv)
860    }
861
862    /// Creates a new tesselation control shader.
863    ///
864    /// # Panics
865    ///
866    /// If the shader code is invalid or not a multiple of four bytes in length.
867    pub fn new_tesselation_ctrl(spirv: impl ShaderCode) -> ShaderBuilder {
868        Self::new(vk::ShaderStageFlags::TESSELLATION_CONTROL, spirv)
869    }
870
871    /// Creates a new tesselation evaluation shader.
872    ///
873    /// # Panics
874    ///
875    /// If the shader code is invalid or not a multiple of four bytes in length.
876    pub fn new_tesselation_eval(spirv: impl ShaderCode) -> ShaderBuilder {
877        Self::new(vk::ShaderStageFlags::TESSELLATION_EVALUATION, spirv)
878    }
879
880    /// Creates a new vertex shader.
881    ///
882    /// # Panics
883    ///
884    /// If the shader code is invalid or not a multiple of four bytes in length.
885    pub fn new_vertex(spirv: impl ShaderCode) -> ShaderBuilder {
886        Self::new(vk::ShaderStageFlags::VERTEX, spirv)
887    }
888
889    /// Returns the input and write attachments of a shader.
890    #[profiling::function]
891    pub(super) fn attachments(
892        &self,
893    ) -> (
894        impl Iterator<Item = u32> + '_,
895        impl Iterator<Item = u32> + '_,
896    ) {
897        (
898            self.entry_point.vars.iter().filter_map(|var| match var {
899                Variable::Descriptor {
900                    desc_ty: DescriptorType::InputAttachment(attachment),
901                    ..
902                } => Some(*attachment),
903                _ => None,
904            }),
905            self.entry_point.vars.iter().filter_map(|var| match var {
906                Variable::Output { location, .. } => Some(location.loc()),
907                _ => None,
908            }),
909        )
910    }
911
912    #[profiling::function]
913    pub(super) fn descriptor_bindings(&self) -> DescriptorBindingMap {
914        let mut res = DescriptorBindingMap::default();
915
916        for (name, descriptor, desc_ty, binding_count) in
917            self.entry_point.vars.iter().filter_map(|var| match var {
918                Variable::Descriptor {
919                    name,
920                    desc_bind,
921                    desc_ty,
922                    nbind,
923                    ..
924                } => Some((
925                    name,
926                    Descriptor {
927                        set: desc_bind.set(),
928                        binding: desc_bind.bind(),
929                    },
930                    desc_ty,
931                    *nbind,
932                )),
933                _ => None,
934            })
935        {
936            trace!(
937                "descriptor {}: {}.{} = {:?}[{}]",
938                name.as_deref().unwrap_or_default(),
939                descriptor.set,
940                descriptor.binding,
941                *desc_ty,
942                binding_count
943            );
944
945            let descriptor_info = match desc_ty {
946                DescriptorType::AccelStruct() => {
947                    DescriptorInfo::AccelerationStructure(binding_count)
948                }
949                DescriptorType::CombinedImageSampler() => {
950                    let (sampler_info, is_manually_defined) =
951                        self.image_sampler(descriptor, name.as_deref().unwrap_or_default());
952
953                    DescriptorInfo::CombinedImageSampler(
954                        binding_count,
955                        sampler_info,
956                        is_manually_defined,
957                    )
958                }
959                DescriptorType::InputAttachment(attachment) => {
960                    DescriptorInfo::InputAttachment(binding_count, *attachment)
961                }
962                DescriptorType::SampledImage() => DescriptorInfo::SampledImage(binding_count),
963                DescriptorType::Sampler() => {
964                    let (sampler_info, is_manually_defined) =
965                        self.image_sampler(descriptor, name.as_deref().unwrap_or_default());
966
967                    DescriptorInfo::Sampler(binding_count, sampler_info, is_manually_defined)
968                }
969                DescriptorType::StorageBuffer(_access_ty) => {
970                    DescriptorInfo::StorageBuffer(binding_count)
971                }
972                DescriptorType::StorageImage(_access_ty) => {
973                    DescriptorInfo::StorageImage(binding_count)
974                }
975                DescriptorType::StorageTexelBuffer(_access_ty) => {
976                    DescriptorInfo::StorageTexelBuffer(binding_count)
977                }
978                DescriptorType::UniformBuffer() => DescriptorInfo::UniformBuffer(binding_count),
979                DescriptorType::UniformTexelBuffer() => {
980                    DescriptorInfo::UniformTexelBuffer(binding_count)
981                }
982            };
983            res.insert(descriptor, (descriptor_info, self.stage));
984        }
985
986        res
987    }
988
989    fn image_sampler(&self, descriptor: Descriptor, name: &str) -> (SamplerInfo, bool) {
990        self.image_samplers
991            .get(&descriptor)
992            .copied()
993            .map(|sampler_info| (sampler_info, true))
994            .unwrap_or_else(|| (guess_immutable_sampler(name), false))
995    }
996
997    #[profiling::function]
998    pub(super) fn merge_descriptor_bindings(
999        descriptor_bindings: impl IntoIterator<Item = DescriptorBindingMap>,
1000    ) -> DescriptorBindingMap {
1001        fn merge_info(lhs: &mut DescriptorInfo, rhs: DescriptorInfo) -> bool {
1002            let (lhs_count, rhs_count) = match lhs {
1003                DescriptorInfo::AccelerationStructure(lhs) => {
1004                    if let DescriptorInfo::AccelerationStructure(rhs) = rhs {
1005                        (lhs, rhs)
1006                    } else {
1007                        return false;
1008                    }
1009                }
1010                DescriptorInfo::CombinedImageSampler(lhs, lhs_sampler, lhs_is_manually_defined) => {
1011                    if let DescriptorInfo::CombinedImageSampler(
1012                        rhs,
1013                        rhs_sampler,
1014                        rhs_is_manually_defined,
1015                    ) = rhs
1016                    {
1017                        // Allow one of the samplers to be manually defined (only one!)
1018                        if *lhs_is_manually_defined && rhs_is_manually_defined {
1019                            return false;
1020                        } else if rhs_is_manually_defined {
1021                            *lhs_sampler = rhs_sampler;
1022                        }
1023
1024                        (lhs, rhs)
1025                    } else {
1026                        return false;
1027                    }
1028                }
1029                DescriptorInfo::InputAttachment(lhs, lhs_idx) => {
1030                    if let DescriptorInfo::InputAttachment(rhs, rhs_idx) = rhs {
1031                        if *lhs_idx != rhs_idx {
1032                            return false;
1033                        }
1034
1035                        (lhs, rhs)
1036                    } else {
1037                        return false;
1038                    }
1039                }
1040                DescriptorInfo::SampledImage(lhs) => {
1041                    if let DescriptorInfo::SampledImage(rhs) = rhs {
1042                        (lhs, rhs)
1043                    } else {
1044                        return false;
1045                    }
1046                }
1047                DescriptorInfo::Sampler(lhs, lhs_sampler, lhs_is_manually_defined) => {
1048                    if let DescriptorInfo::Sampler(rhs, rhs_sampler, rhs_is_manually_defined) = rhs
1049                    {
1050                        // Allow one of the samplers to be manually defined (only one!)
1051                        if *lhs_is_manually_defined && rhs_is_manually_defined {
1052                            return false;
1053                        } else if rhs_is_manually_defined {
1054                            *lhs_sampler = rhs_sampler;
1055                        }
1056
1057                        (lhs, rhs)
1058                    } else {
1059                        return false;
1060                    }
1061                }
1062                DescriptorInfo::StorageBuffer(lhs) => {
1063                    if let DescriptorInfo::StorageBuffer(rhs) = rhs {
1064                        (lhs, rhs)
1065                    } else {
1066                        return false;
1067                    }
1068                }
1069                DescriptorInfo::StorageImage(lhs) => {
1070                    if let DescriptorInfo::StorageImage(rhs) = rhs {
1071                        (lhs, rhs)
1072                    } else {
1073                        return false;
1074                    }
1075                }
1076                DescriptorInfo::StorageTexelBuffer(lhs) => {
1077                    if let DescriptorInfo::StorageTexelBuffer(rhs) = rhs {
1078                        (lhs, rhs)
1079                    } else {
1080                        return false;
1081                    }
1082                }
1083                DescriptorInfo::UniformBuffer(lhs) => {
1084                    if let DescriptorInfo::UniformBuffer(rhs) = rhs {
1085                        (lhs, rhs)
1086                    } else {
1087                        return false;
1088                    }
1089                }
1090                DescriptorInfo::UniformTexelBuffer(lhs) => {
1091                    if let DescriptorInfo::UniformTexelBuffer(rhs) = rhs {
1092                        (lhs, rhs)
1093                    } else {
1094                        return false;
1095                    }
1096                }
1097            };
1098
1099            *lhs_count = rhs_count.max(*lhs_count);
1100
1101            true
1102        }
1103
1104        #[profiling::function]
1105        fn merge_pair(src: DescriptorBindingMap, dst: &mut DescriptorBindingMap) {
1106            for (descriptor_binding, (descriptor_info, descriptor_flags)) in src.into_iter() {
1107                if let Some((existing_info, existing_flags)) = dst.get_mut(&descriptor_binding) {
1108                    if !merge_info(existing_info, descriptor_info) {
1109                        panic!("Inconsistent shader descriptors ({descriptor_binding:?})");
1110                    }
1111
1112                    *existing_flags |= descriptor_flags;
1113                } else {
1114                    dst.insert(descriptor_binding, (descriptor_info, descriptor_flags));
1115                }
1116            }
1117        }
1118
1119        let mut descriptor_bindings = descriptor_bindings.into_iter();
1120        let mut res = descriptor_bindings.next().unwrap_or_default();
1121        for descriptor_binding in descriptor_bindings {
1122            merge_pair(descriptor_binding, &mut res);
1123        }
1124
1125        res
1126    }
1127
1128    #[profiling::function]
1129    pub(super) fn push_constant_range(&self) -> Option<vk::PushConstantRange> {
1130        self.entry_point
1131            .vars
1132            .iter()
1133            .filter_map(|var| match var {
1134                Variable::PushConstant {
1135                    ty: Type::Struct(ty),
1136                    ..
1137                } => Some(ty.members.clone()),
1138                _ => None,
1139            })
1140            .flatten()
1141            .map(|push_const| {
1142                let offset = push_const.offset.unwrap_or_default();
1143                let size = push_const
1144                    .ty
1145                    .nbyte()
1146                    .unwrap_or_default()
1147                    .next_multiple_of(4);
1148                offset..offset + size
1149            })
1150            .reduce(|a, b| a.start.min(b.start)..a.end.max(b.end))
1151            .map(|push_const| vk::PushConstantRange {
1152                stage_flags: self.stage,
1153                size: (push_const.end - push_const.start) as _,
1154                offset: push_const.start as _,
1155            })
1156    }
1157
1158    #[profiling::function]
1159    fn reflect_entry_point(
1160        entry_name: &str,
1161        spirv: &[u8],
1162        specialization_info: Option<&SpecializationInfo>,
1163    ) -> Result<EntryPoint, DriverError> {
1164        let mut config = ReflectConfig::new();
1165        config.ref_all_rscs(true).spv(spirv);
1166
1167        if let Some(spec_info) = specialization_info {
1168            for spec in &spec_info.map_entries {
1169                config.specialize(
1170                    spec.constant_id,
1171                    spec_info.data[spec.offset as usize..spec.offset as usize + spec.size].into(),
1172                );
1173            }
1174        }
1175
1176        let entry_points = config.reflect().map_err(|err| {
1177            error!("Unable to reflect spirv: {err}");
1178
1179            DriverError::InvalidData
1180        })?;
1181        let entry_point = entry_points
1182            .into_iter()
1183            .find(|entry_point| entry_point.name == entry_name)
1184            .ok_or_else(|| {
1185                error!("Entry point not found");
1186
1187                DriverError::InvalidData
1188            })?;
1189
1190        Ok(entry_point)
1191    }
1192
1193    #[profiling::function]
1194    pub(super) fn vertex_input(&self) -> VertexInputState {
1195        // Check for manually-specified vertex layout descriptions
1196        if let Some(vertex_input) = &self.vertex_input_state {
1197            return vertex_input.clone();
1198        }
1199
1200        fn scalar_format(ty: &ScalarType) -> vk::Format {
1201            match *ty {
1202                ScalarType::Float { bits } => match bits {
1203                    u8::BITS => vk::Format::R8_SNORM,
1204                    u16::BITS => vk::Format::R16_SFLOAT,
1205                    u32::BITS => vk::Format::R32_SFLOAT,
1206                    u64::BITS => vk::Format::R64_SFLOAT,
1207                    _ => unimplemented!("{bits}-bit float"),
1208                },
1209                ScalarType::Integer {
1210                    bits,
1211                    is_signed: false,
1212                } => match bits {
1213                    u8::BITS => vk::Format::R8_UINT,
1214                    u16::BITS => vk::Format::R16_UINT,
1215                    u32::BITS => vk::Format::R32_UINT,
1216                    u64::BITS => vk::Format::R64_UINT,
1217                    _ => unimplemented!("{bits}-bit unsigned integer"),
1218                },
1219                ScalarType::Integer {
1220                    bits,
1221                    is_signed: true,
1222                } => match bits {
1223                    u8::BITS => vk::Format::R8_SINT,
1224                    u16::BITS => vk::Format::R16_SINT,
1225                    u32::BITS => vk::Format::R32_SINT,
1226                    u64::BITS => vk::Format::R64_SINT,
1227                    _ => unimplemented!("{bits}-bit signed integer"),
1228                },
1229                _ => unimplemented!("{ty:?}"),
1230            }
1231        }
1232
1233        fn vector_format(ty: &VectorType) -> vk::Format {
1234            match *ty {
1235                VectorType {
1236                    scalar_ty: ScalarType::Float { bits },
1237                    nscalar,
1238                } => match (bits, nscalar) {
1239                    (u8::BITS, 2) => vk::Format::R8G8_SNORM,
1240                    (u8::BITS, 3) => vk::Format::R8G8B8_SNORM,
1241                    (u8::BITS, 4) => vk::Format::R8G8B8A8_SNORM,
1242                    (u16::BITS, 2) => vk::Format::R16G16_SFLOAT,
1243                    (u16::BITS, 3) => vk::Format::R16G16B16_SFLOAT,
1244                    (u16::BITS, 4) => vk::Format::R16G16B16A16_SFLOAT,
1245                    (u32::BITS, 2) => vk::Format::R32G32_SFLOAT,
1246                    (u32::BITS, 3) => vk::Format::R32G32B32_SFLOAT,
1247                    (u32::BITS, 4) => vk::Format::R32G32B32A32_SFLOAT,
1248                    (u64::BITS, 2) => vk::Format::R64G64_SFLOAT,
1249                    (u64::BITS, 3) => vk::Format::R64G64B64_SFLOAT,
1250                    (u64::BITS, 4) => vk::Format::R64G64B64A64_SFLOAT,
1251                    _ => unimplemented!("{bits}-bit vec{nscalar} float"),
1252                },
1253                VectorType {
1254                    scalar_ty:
1255                        ScalarType::Integer {
1256                            bits,
1257                            is_signed: false,
1258                        },
1259                    nscalar,
1260                } => match (bits, nscalar) {
1261                    (u8::BITS, 2) => vk::Format::R8G8_UINT,
1262                    (u8::BITS, 3) => vk::Format::R8G8B8_UINT,
1263                    (u8::BITS, 4) => vk::Format::R8G8B8A8_UINT,
1264                    (u16::BITS, 2) => vk::Format::R16G16_UINT,
1265                    (u16::BITS, 3) => vk::Format::R16G16B16_UINT,
1266                    (u16::BITS, 4) => vk::Format::R16G16B16A16_UINT,
1267                    (u32::BITS, 2) => vk::Format::R32G32_UINT,
1268                    (u32::BITS, 3) => vk::Format::R32G32B32_UINT,
1269                    (u32::BITS, 4) => vk::Format::R32G32B32A32_UINT,
1270                    (u64::BITS, 2) => vk::Format::R64G64_UINT,
1271                    (u64::BITS, 3) => vk::Format::R64G64B64_UINT,
1272                    (u64::BITS, 4) => vk::Format::R64G64B64A64_UINT,
1273                    _ => unimplemented!("{bits}-bit vec{nscalar} unsigned integer"),
1274                },
1275                VectorType {
1276                    scalar_ty:
1277                        ScalarType::Integer {
1278                            bits,
1279                            is_signed: true,
1280                        },
1281                    nscalar,
1282                } => match (bits, nscalar) {
1283                    (u8::BITS, 2) => vk::Format::R8G8_SINT,
1284                    (u8::BITS, 3) => vk::Format::R8G8B8_SINT,
1285                    (u8::BITS, 4) => vk::Format::R8G8B8A8_SINT,
1286                    (u16::BITS, 2) => vk::Format::R16G16_SINT,
1287                    (u16::BITS, 3) => vk::Format::R16G16B16_SINT,
1288                    (u16::BITS, 4) => vk::Format::R16G16B16A16_SINT,
1289                    (u32::BITS, 2) => vk::Format::R32G32_SINT,
1290                    (u32::BITS, 3) => vk::Format::R32G32B32_SINT,
1291                    (u32::BITS, 4) => vk::Format::R32G32B32A32_SINT,
1292                    (u64::BITS, 2) => vk::Format::R64G64_SINT,
1293                    (u64::BITS, 3) => vk::Format::R64G64B64_SINT,
1294                    (u64::BITS, 4) => vk::Format::R64G64B64A64_SINT,
1295                    _ => unimplemented!("{bits}-bit vec{nscalar} signed integer"),
1296                },
1297                _ => unimplemented!("{ty:?}"),
1298            }
1299        }
1300
1301        let mut input_rates_strides = HashMap::new();
1302        let mut vertex_attribute_descriptions = vec![];
1303
1304        for (name, location, ty) in self.entry_point.vars.iter().filter_map(|var| match var {
1305            Variable::Input { name, location, ty } => Some((name, location, ty)),
1306            _ => None,
1307        }) {
1308            let (binding, guessed_rate) = name
1309                .as_ref()
1310                .filter(|name| name.contains("_ibind") || name.contains("_vbind"))
1311                .map(|name| {
1312                    let binding = name[name.rfind("bind").unwrap()..]
1313                        .parse()
1314                        .unwrap_or_default();
1315                    let rate = if name.contains("_ibind") {
1316                        vk::VertexInputRate::INSTANCE
1317                    } else {
1318                        vk::VertexInputRate::VERTEX
1319                    };
1320
1321                    (binding, rate)
1322                })
1323                .unwrap_or_default();
1324            let (location, _) = location.into_inner();
1325            if let Some((input_rate, _)) = input_rates_strides.get(&binding) {
1326                assert_eq!(*input_rate, guessed_rate);
1327            }
1328
1329            let byte_stride = ty.nbyte().unwrap_or_default() as u32;
1330            let (input_rate, stride) = input_rates_strides.entry(binding).or_default();
1331            *input_rate = guessed_rate;
1332            *stride += byte_stride;
1333
1334            //trace!("{location} {:?} is {byte_stride} bytes", name);
1335
1336            vertex_attribute_descriptions.push(vk::VertexInputAttributeDescription {
1337                location,
1338                binding,
1339                format: match ty {
1340                    Type::Scalar(ty) => scalar_format(ty),
1341                    Type::Vector(ty) => vector_format(ty),
1342                    _ => unimplemented!("{:?}", ty),
1343                },
1344                offset: byte_stride, // Figured out below - this data is iter'd in an unknown order
1345            });
1346        }
1347
1348        vertex_attribute_descriptions.sort_unstable_by(|lhs, rhs| {
1349            let binding = lhs.binding.cmp(&rhs.binding);
1350            if binding.is_lt() {
1351                return binding;
1352            }
1353
1354            lhs.location.cmp(&rhs.location)
1355        });
1356
1357        let mut offset = 0;
1358        let mut offset_binding = 0;
1359
1360        for vertex_attribute_description in &mut vertex_attribute_descriptions {
1361            if vertex_attribute_description.binding != offset_binding {
1362                offset_binding = vertex_attribute_description.binding;
1363                offset = 0;
1364            }
1365
1366            let stride = vertex_attribute_description.offset;
1367            vertex_attribute_description.offset = offset;
1368            offset += stride;
1369
1370            debug!(
1371                "vertex attribute {}.{}: {:?} (offset={})",
1372                vertex_attribute_description.binding,
1373                vertex_attribute_description.location,
1374                vertex_attribute_description.format,
1375                vertex_attribute_description.offset,
1376            );
1377        }
1378
1379        let mut vertex_binding_descriptions = vec![];
1380        for (binding, (input_rate, stride)) in input_rates_strides.into_iter() {
1381            vertex_binding_descriptions.push(vk::VertexInputBindingDescription {
1382                binding,
1383                input_rate,
1384                stride,
1385            });
1386        }
1387
1388        VertexInputState {
1389            vertex_attribute_descriptions,
1390            vertex_binding_descriptions,
1391        }
1392    }
1393}
1394
1395impl Debug for Shader {
1396    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1397        // We don't want the default formatter bc vec u8
1398        // TODO: Better output message
1399        f.write_str("Shader")
1400    }
1401}
1402
1403impl From<ShaderBuilder> for Shader {
1404    fn from(shader: ShaderBuilder) -> Self {
1405        shader.build()
1406    }
1407}
1408
1409// HACK: https://github.com/colin-kiegel/rust-derive-builder/issues/56
1410impl ShaderBuilder {
1411    /// Specifies a shader with the given `stage` and shader code values.
1412    pub fn new(stage: vk::ShaderStageFlags, spirv: Vec<u8>) -> Self {
1413        Self::default().stage(stage).spirv(spirv)
1414    }
1415
1416    /// Builds a new `Shader`.
1417    pub fn build(mut self) -> Shader {
1418        let entry_name = self.entry_name.as_deref().unwrap_or("main");
1419        self.entry_point = Some(
1420            Shader::reflect_entry_point(
1421                entry_name,
1422                self.spirv.as_deref().unwrap(),
1423                self.specialization_info
1424                    .as_ref()
1425                    .map(|opt| opt.as_ref())
1426                    .unwrap_or_default(),
1427            )
1428            .unwrap_or_else(|_| panic!("invalid shader code for entry name \'{entry_name}\'")),
1429        );
1430
1431        self.fallible_build()
1432            .expect("All required fields set at initialization")
1433    }
1434
1435    /// Specifies a manually-defined image sampler.
1436    ///
1437    /// Sampled images, by default, use reflection to automatically assign image samplers. Each
1438    /// sampled image may use a suffix such as `_llr` or `_nne` for common linear/linear repeat or
1439    /// nearest/nearest clamp-to-edge samplers, respectively.
1440    ///
1441    /// See the [main documentation] for more information about automatic image samplers.
1442    ///
1443    /// Descriptor bindings may be specified as `(1, 2)` for descriptor set index `1` and binding
1444    /// index `2`, or if the descriptor set index is `0` simply specify `2` for the same case.
1445    ///
1446    /// _NOTE:_ When defining image samplers which are used in multiple stages of a single pipeline
1447    /// you must only call this function on one of the shader stages, it does not matter which one.
1448    ///
1449    /// # Panics
1450    ///
1451    /// Panics if two shader stages of the same pipeline define individual calls to `image_sampler`.
1452    ///
1453    /// [main documentation]: crate
1454    #[profiling::function]
1455    pub fn image_sampler(
1456        mut self,
1457        descriptor: impl Into<Descriptor>,
1458        info: impl Into<SamplerInfo>,
1459    ) -> Self {
1460        let descriptor = descriptor.into();
1461        let info = info.into();
1462
1463        if self.image_samplers.is_none() {
1464            self.image_samplers = Some(Default::default());
1465        }
1466
1467        self.image_samplers
1468            .as_mut()
1469            .unwrap()
1470            .insert(descriptor, info);
1471
1472        self
1473    }
1474
1475    /// Specifies a manually-defined vertex input layout.
1476    ///
1477    /// The vertex input layout, by default, uses reflection to automatically define vertex binding
1478    /// and attribute descriptions. Each vertex location is inferred to have 32-bit channels and be
1479    /// tightly packed in the vertex buffer. In this mode, a location with `_ibind_0` or `_vbind3`
1480    /// suffixes is inferred to use instance-rate on vertex buffer binding `0` or vertex rate on
1481    /// binding `3`, respectively.
1482    ///
1483    /// See the [main documentation] for more information about automatic vertex input layout.
1484    ///
1485    /// [main documentation]: crate
1486    #[profiling::function]
1487    pub fn vertex_input(
1488        mut self,
1489        bindings: impl Into<Vec<vk::VertexInputBindingDescription>>,
1490        attributes: impl Into<Vec<vk::VertexInputAttributeDescription>>,
1491    ) -> Self {
1492        self.vertex_input_state = Some(Some(VertexInputState {
1493            vertex_binding_descriptions: bindings.into(),
1494            vertex_attribute_descriptions: attributes.into(),
1495        }));
1496        self
1497    }
1498}
1499
/// Error produced when converting a derive_builder [`UninitializedFieldError`].
///
/// It carries no details: every missing field is reported the same way.
#[derive(Debug)]
struct ShaderBuilderError;
1502
1503impl From<UninitializedFieldError> for ShaderBuilderError {
1504    fn from(_: UninitializedFieldError) -> Self {
1505        Self
1506    }
1507}
1508
/// Trait for types which can be converted into shader code.
///
/// The implementations in this module accept SPIR-V either as bytes or as 32-bit words. Byte
/// inputs are expected to be a multiple of four bytes long; this is checked only in debug builds.
pub trait ShaderCode {
    /// Converts the instance into SPIR-V shader code specified as a byte array.
    fn into_vec(self) -> Vec<u8>;
}
1514
1515impl ShaderCode for &[u8] {
1516    fn into_vec(self) -> Vec<u8> {
1517        debug_assert_eq!(self.len() % 4, 0, "invalid spir-v code");
1518
1519        self.to_vec()
1520    }
1521}
1522
1523impl ShaderCode for &[u32] {
1524    fn into_vec(self) -> Vec<u8> {
1525        pub fn into_u8_slice<T>(t: &[T]) -> &[u8]
1526        where
1527            T: Sized,
1528        {
1529            use std::slice::from_raw_parts;
1530
1531            unsafe { from_raw_parts(t.as_ptr() as *const _, size_of_val(t)) }
1532        }
1533
1534        into_u8_slice(self).into_vec()
1535    }
1536}
1537
1538impl ShaderCode for Vec<u8> {
1539    fn into_vec(self) -> Vec<u8> {
1540        debug_assert_eq!(self.len() % 4, 0, "invalid spir-v code");
1541
1542        self
1543    }
1544}
1545
1546impl ShaderCode for Vec<u32> {
1547    fn into_vec(self) -> Vec<u8> {
1548        self.as_slice().into_vec()
1549    }
1550}
1551
/// Describes specialized constant values.
///
/// Each map entry selects `size` bytes starting at `offset` within `data`. Those bytes become
/// the value of the specialization constant identified by the entry's `constant_id`.
#[derive(Clone, Debug)]
pub struct SpecializationInfo {
    /// A buffer of data which holds the constant values.
    pub data: Vec<u8>,

    /// Mapping of locations within the constant value data which describe each individual constant.
    pub map_entries: Vec<vk::SpecializationMapEntry>,
}
1561
1562impl SpecializationInfo {
1563    /// Constructs a new `SpecializationInfo`.
1564    pub fn new(
1565        map_entries: impl Into<Vec<vk::SpecializationMapEntry>>,
1566        data: impl Into<Vec<u8>>,
1567    ) -> Self {
1568        Self {
1569            data: data.into(),
1570            map_entries: map_entries.into(),
1571        }
1572    }
1573}
1574
#[cfg(test)]
mod tests {
    use super::*;

    type Info = SamplerInfo;
    type Builder = SamplerInfoBuilder;

    /// Converting a default `SamplerInfo` into a builder and back must not change it.
    #[test]
    pub fn sampler_info() {
        let original = Info::default();
        let rebuilt = original.to_builder().build();

        assert_eq!(original, rebuilt);
    }

    /// A default builder must produce the same value as `SamplerInfo::default`.
    #[test]
    pub fn sampler_info_builder() {
        let expected = Info::default();
        let built = Builder::default().build();

        assert_eq!(expected, built);
    }
}