Skip to main content

vk_graph/driver/
shader.rs

1//! Shader resource types
2
3use {
4    super::{DescriptorSetLayout, DriverError, VertexInputState, device::Device},
5    ash::vk,
6    derive_builder::{Builder, UninitializedFieldError},
7    log::{debug, error, trace, warn},
8    ordered_float::OrderedFloat,
9    spirq::{
10        ReflectConfig,
11        entry_point::EntryPoint,
12        parse::SpirvBinary,
13        spirv::ExecutionModel,
14        ty::{DescriptorType, ScalarType, Type, VectorType},
15        var::Variable,
16    },
17    std::{
18        collections::{BTreeMap, HashMap},
19        fmt::{Debug, Formatter},
20        iter::repeat_n,
21        ops::Deref,
22        thread::panicking,
23    },
24};
25
/// Deprecated alias kept for backwards compatibility; new code should use the
/// `SpecializationMap` struct instead.
#[allow(deprecated)]
#[deprecated = "use SpecializationMap struct"]
#[doc(hidden)]
pub type SpecializationInfo = self::deprecated::SpecializationInfo;

/// Maps each descriptor binding point to its descriptor information and the shader stage flags
/// used for that binding's descriptor set layout entry.
pub(crate) type DescriptorBindingMap = HashMap<Descriptor, (DescriptorInfo, vk::ShaderStageFlags)>;
32
33#[profiling::function]
34fn guess_immutable_sampler(binding_name: &str) -> SamplerInfo {
35    const INVALID_ERR: &str = "Invalid sampler specification";
36
37    let (texel_filter, mipmap_mode, address_modes) = if binding_name.contains("_sampler_") {
38        let spec = &binding_name[binding_name.len() - 3..];
39        let texel_filter = match &spec[0..1] {
40            "n" => vk::Filter::NEAREST,
41            "l" => vk::Filter::LINEAR,
42            _ => panic!("{INVALID_ERR}: {}", &spec[0..1]),
43        };
44
45        let mipmap_mode = match &spec[1..2] {
46            "n" => vk::SamplerMipmapMode::NEAREST,
47            "l" => vk::SamplerMipmapMode::LINEAR,
48            _ => panic!("{INVALID_ERR}: {}", &spec[1..2]),
49        };
50
51        let address_modes = match &spec[2..3] {
52            "b" => vk::SamplerAddressMode::CLAMP_TO_BORDER,
53            "e" => vk::SamplerAddressMode::CLAMP_TO_EDGE,
54            "m" => vk::SamplerAddressMode::MIRRORED_REPEAT,
55            "r" => vk::SamplerAddressMode::REPEAT,
56            _ => panic!("{INVALID_ERR}: {}", &spec[2..3]),
57        };
58
59        (texel_filter, mipmap_mode, address_modes)
60    } else {
61        debug!("image binding {binding_name} using default sampler");
62
63        (
64            vk::Filter::LINEAR,
65            vk::SamplerMipmapMode::LINEAR,
66            vk::SamplerAddressMode::REPEAT,
67        )
68    };
69    let anisotropy_enable = texel_filter == vk::Filter::LINEAR;
70    let mut info = SamplerInfoBuilder::default()
71        .mag_filter(texel_filter)
72        .min_filter(texel_filter)
73        .mipmap_mode(mipmap_mode)
74        .address_mode_u(address_modes)
75        .address_mode_v(address_modes)
76        .address_mode_w(address_modes)
77        .max_lod(vk::LOD_CLAMP_NONE)
78        .anisotropy_enable(anisotropy_enable);
79
80    if anisotropy_enable {
81        info = info.max_anisotropy(16.0);
82    }
83
84    info.build()
85}
86
/// A descriptor binding point within a shader: a descriptor set index paired with a binding
/// index.
///
/// This names a location in the shader interface; it is not a reference to a bound descriptor.
/// Ordering compares `set` first and `binding` second.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Descriptor {
    /// Descriptor set index
    pub set: u32,

    /// Descriptor binding index
    pub binding: u32,
}

impl From<u32> for Descriptor {
    /// Interprets a bare binding index as a binding within descriptor set zero.
    fn from(binding: u32) -> Self {
        Self::from((0, binding))
    }
}

impl From<(u32, u32)> for Descriptor {
    /// Converts a `(set, binding)` pair.
    fn from(pair: (u32, u32)) -> Self {
        let (set, binding) = pair;

        Self { set, binding }
    }
}
111
112#[derive(Clone, Copy, Debug)]
113pub(crate) enum DescriptorInfo {
114    AccelerationStructure(u32),
115    CombinedImageSampler(u32, SamplerInfo, bool), //count, sampler, is-manually-defined?
116    InputAttachment(u32, u32),                    //count, input index,
117    SampledImage(u32),
118    Sampler(u32, SamplerInfo, bool), //count, sampler, is-manually-defined?
119    StorageBuffer(u32),
120    StorageImage(u32),
121    StorageTexelBuffer(u32),
122    UniformBuffer(u32),
123    UniformTexelBuffer(u32),
124}
125
126impl DescriptorInfo {
127    pub fn binding_count(self) -> u32 {
128        match self {
129            Self::AccelerationStructure(binding_count) => binding_count,
130            Self::CombinedImageSampler(binding_count, ..) => binding_count,
131            Self::InputAttachment(binding_count, _) => binding_count,
132            Self::SampledImage(binding_count) => binding_count,
133            Self::Sampler(binding_count, ..) => binding_count,
134            Self::StorageBuffer(binding_count) => binding_count,
135            Self::StorageImage(binding_count) => binding_count,
136            Self::StorageTexelBuffer(binding_count) => binding_count,
137            Self::UniformBuffer(binding_count) => binding_count,
138            Self::UniformTexelBuffer(binding_count) => binding_count,
139        }
140    }
141
142    pub fn descriptor_type(self) -> vk::DescriptorType {
143        match self {
144            Self::AccelerationStructure(_) => vk::DescriptorType::ACCELERATION_STRUCTURE_KHR,
145            Self::CombinedImageSampler(..) => vk::DescriptorType::COMBINED_IMAGE_SAMPLER,
146            Self::InputAttachment(..) => vk::DescriptorType::INPUT_ATTACHMENT,
147            Self::SampledImage(_) => vk::DescriptorType::SAMPLED_IMAGE,
148            Self::Sampler(..) => vk::DescriptorType::SAMPLER,
149            Self::StorageBuffer(_) => vk::DescriptorType::STORAGE_BUFFER,
150            Self::StorageImage(_) => vk::DescriptorType::STORAGE_IMAGE,
151            Self::StorageTexelBuffer(_) => vk::DescriptorType::STORAGE_TEXEL_BUFFER,
152            Self::UniformBuffer(_) => vk::DescriptorType::UNIFORM_BUFFER,
153            Self::UniformTexelBuffer(_) => vk::DescriptorType::UNIFORM_TEXEL_BUFFER,
154        }
155    }
156
157    fn sampler_info(self) -> Option<SamplerInfo> {
158        match self {
159            Self::CombinedImageSampler(_, sampler_info, _) | Self::Sampler(_, sampler_info, _) => {
160                Some(sampler_info)
161            }
162            _ => None,
163        }
164    }
165
166    pub fn set_binding_count(&mut self, binding_count: u32) {
167        *match self {
168            Self::AccelerationStructure(binding_count) => binding_count,
169            Self::CombinedImageSampler(binding_count, ..) => binding_count,
170            Self::InputAttachment(binding_count, _) => binding_count,
171            Self::SampledImage(binding_count) => binding_count,
172            Self::Sampler(binding_count, ..) => binding_count,
173            Self::StorageBuffer(binding_count) => binding_count,
174            Self::StorageImage(binding_count) => binding_count,
175            Self::StorageTexelBuffer(binding_count) => binding_count,
176            Self::UniformBuffer(binding_count) => binding_count,
177            Self::UniformTexelBuffer(binding_count) => binding_count,
178        } = binding_count;
179    }
180}
181
182#[derive(Debug)]
183pub(crate) struct PipelineDescriptorInfo {
184    pub layouts: BTreeMap<u32, DescriptorSetLayout>,
185    pub pool_sizes: HashMap<u32, HashMap<vk::DescriptorType, u32>>,
186
187    #[allow(dead_code)]
188    samplers: Box<[Sampler]>,
189}
190
191impl PipelineDescriptorInfo {
192    #[profiling::function]
193    pub fn create(
194        device: &Device,
195        descriptor_bindings: &DescriptorBindingMap,
196    ) -> Result<Self, DriverError> {
197        let descriptor_set_count = descriptor_bindings
198            .keys()
199            .map(|descriptor| descriptor.set)
200            .max()
201            .map(|set| set + 1)
202            .unwrap_or_default();
203        let mut layouts = BTreeMap::new();
204        let mut pool_sizes = HashMap::new();
205
206        //trace!("descriptor_bindings: {:#?}", &descriptor_bindings);
207
208        let mut sampler_info_binding_count = HashMap::<_, u32>::with_capacity(
209            descriptor_bindings
210                .values()
211                .filter(|(descriptor_info, _)| descriptor_info.sampler_info().is_some())
212                .count(),
213        );
214
215        for (sampler_info, binding_count) in
216            descriptor_bindings
217                .values()
218                .filter_map(|(descriptor_info, _)| {
219                    descriptor_info
220                        .sampler_info()
221                        .map(|sampler_info| (sampler_info, descriptor_info.binding_count()))
222                })
223        {
224            sampler_info_binding_count
225                .entry(sampler_info)
226                .and_modify(|sampler_info_binding_count| {
227                    *sampler_info_binding_count = binding_count.max(*sampler_info_binding_count);
228                })
229                .or_insert(binding_count);
230        }
231
232        let mut samplers = sampler_info_binding_count
233            .keys()
234            .copied()
235            .map(|sampler_info| {
236                Sampler::create(device, sampler_info).map(|sampler| (sampler_info, sampler))
237            })
238            .collect::<Result<HashMap<_, _>, _>>()?;
239        let immutable_samplers = sampler_info_binding_count
240            .iter()
241            .map(|(sampler_info, &binding_count)| {
242                (
243                    *sampler_info,
244                    repeat_n(*samplers[sampler_info], binding_count as _).collect::<Box<_>>(),
245                )
246            })
247            .collect::<HashMap<_, _>>();
248
249        for descriptor_set_idx in 0..descriptor_set_count {
250            let mut binding_counts = HashMap::<vk::DescriptorType, u32>::new();
251            let mut bindings = vec![];
252
253            for (descriptor, (descriptor_info, stage_flags)) in descriptor_bindings
254                .iter()
255                .filter(|(descriptor, _)| descriptor.set == descriptor_set_idx)
256            {
257                let descriptor_ty = descriptor_info.descriptor_type();
258                *binding_counts.entry(descriptor_ty).or_default() +=
259                    descriptor_info.binding_count();
260                let mut binding = vk::DescriptorSetLayoutBinding::default()
261                    .binding(descriptor.binding)
262                    .descriptor_count(descriptor_info.binding_count())
263                    .descriptor_type(descriptor_ty)
264                    .stage_flags(*stage_flags);
265
266                if let Some(immutable_samplers) =
267                    descriptor_info.sampler_info().map(|sampler_info| {
268                        &immutable_samplers[&sampler_info]
269                            [0..descriptor_info.binding_count() as usize]
270                    })
271                {
272                    binding = binding.immutable_samplers(immutable_samplers);
273                }
274
275                bindings.push(binding);
276            }
277
278            let pool_size = pool_sizes
279                .entry(descriptor_set_idx)
280                .or_insert_with(HashMap::new);
281
282            for (descriptor_ty, binding_count) in binding_counts.into_iter() {
283                *pool_size.entry(descriptor_ty).or_default() += binding_count;
284            }
285
286            //trace!("bindings: {:#?}", &bindings);
287
288            let mut create_info = vk::DescriptorSetLayoutCreateInfo::default().bindings(&bindings);
289
290            // The bindless flags have to be created for every descriptor set layout binding.
291            // See [`VkDescriptorSetLayoutBindingFlagsCreateInfo`](https://registry.khronos.org/vulkan/specs/latest/man/html/VkDescriptorSetLayoutBindingFlagsCreateInfo.html).
292            // Maybe using one vector and updating it would be more efficient.
293            let bindless_flags = vec![vk::DescriptorBindingFlags::PARTIALLY_BOUND; bindings.len()];
294            let mut bindless_flags = if device
295                .physical_device
296                .features_v1_2
297                .descriptor_binding_partially_bound
298            {
299                let bindless_flags = vk::DescriptorSetLayoutBindingFlagsCreateInfo::default()
300                    .binding_flags(&bindless_flags);
301                Some(bindless_flags)
302            } else {
303                None
304            };
305
306            if let Some(bindless_flags) = bindless_flags.as_mut() {
307                create_info = create_info.push_next(bindless_flags);
308            }
309
310            layouts.insert(
311                descriptor_set_idx,
312                DescriptorSetLayout::create(device, &create_info)?,
313            );
314        }
315
316        let samplers = samplers
317            .drain()
318            .map(|(_, sampler)| sampler)
319            .collect::<Box<_>>();
320
321        //trace!("layouts {:#?}", &layouts);
322        // trace!("pool_sizes {:#?}", &pool_sizes);
323
324        Ok(Self {
325            layouts,
326            pool_sizes,
327            samplers,
328        })
329    }
330}
331
332pub(crate) struct Sampler {
333    device: Device,
334    sampler: vk::Sampler,
335}
336
337impl Sampler {
338    #[profiling::function]
339    pub fn create(device: &Device, info: impl Into<SamplerInfo>) -> Result<Self, DriverError> {
340        let device = device.clone();
341        let info = info.into();
342
343        let sampler = unsafe {
344            device
345                .create_sampler(
346                    &vk::SamplerCreateInfo::default()
347                        .flags(info.flags)
348                        .mag_filter(info.mag_filter)
349                        .min_filter(info.min_filter)
350                        .mipmap_mode(info.mipmap_mode)
351                        .address_mode_u(info.address_mode_u)
352                        .address_mode_v(info.address_mode_v)
353                        .address_mode_w(info.address_mode_w)
354                        .mip_lod_bias(info.mip_lod_bias.0)
355                        .anisotropy_enable(info.anisotropy_enable)
356                        .max_anisotropy(info.max_anisotropy.0)
357                        .compare_enable(info.compare_enable)
358                        .compare_op(info.compare_op)
359                        .min_lod(info.min_lod.0)
360                        .max_lod(info.max_lod.0)
361                        .border_color(info.border_color)
362                        .unnormalized_coordinates(info.unnormalized_coordinates)
363                        .push_next(
364                            &mut vk::SamplerReductionModeCreateInfo::default()
365                                .reduction_mode(info.reduction_mode),
366                        ),
367                    None,
368                )
369                .map_err(|err| match err {
370                    vk::Result::ERROR_OUT_OF_HOST_MEMORY
371                    | vk::Result::ERROR_OUT_OF_DEVICE_MEMORY => {
372                        warn!("unable to create sampler: {err}");
373                        DriverError::OutOfMemory
374                    }
375                    _ => {
376                        warn!("unsupported sampler creation: {err}");
377                        DriverError::Unsupported
378                    }
379                })?
380        };
381
382        Ok(Self { device, sampler })
383    }
384}
385
386impl Debug for Sampler {
387    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
388        write!(f, "{:?}", self.sampler)
389    }
390}
391
392impl Deref for Sampler {
393    type Target = vk::Sampler;
394
395    fn deref(&self) -> &Self::Target {
396        &self.sampler
397    }
398}
399
400impl Drop for Sampler {
401    #[profiling::function]
402    fn drop(&mut self) {
403        if panicking() {
404            return;
405        }
406
407        unsafe {
408            self.device.destroy_sampler(self.sampler, None);
409        }
410    }
411}
412
413/// Information used to create a [`vk::Sampler`] instance.
414#[derive(Builder, Clone, Copy, Debug, Eq, Hash, PartialEq)]
415#[builder(
416    build_fn(private, name = "fallible_build", error = "SamplerInfoBuilderError"),
417    derive(Clone, Copy, Debug),
418    pattern = "owned"
419)]
420pub struct SamplerInfo {
421    /// Bitmask specifying additional parameters of a sampler.
422    #[builder(default)]
423    pub flags: vk::SamplerCreateFlags,
424
425    /// Specify the magnification filter to apply to texture lookups.
426    ///
427    /// The default value is [`vk::Filter::NEAREST`]
428    #[builder(default)]
429    pub mag_filter: vk::Filter,
430
431    /// Specify the minification filter to apply to texture lookups.
432    ///
433    /// The default value is [`vk::Filter::NEAREST`]
434    #[builder(default)]
435    pub min_filter: vk::Filter,
436
437    /// A value specifying the mipmap filter to apply to lookups.
438    ///
439    /// The default value is [`vk::SamplerMipmapMode::NEAREST`]
440    #[builder(default)]
441    pub mipmap_mode: vk::SamplerMipmapMode,
442
443    /// A value specifying the addressing mode for U coordinates outside `[0, 1)`.
444    ///
445    /// The default value is [`vk::SamplerAddressMode::REPEAT`]
446    #[builder(default)]
447    pub address_mode_u: vk::SamplerAddressMode,
448
449    /// A value specifying the addressing mode for V coordinates outside `[0, 1)`.
450    ///
451    /// The default value is [`vk::SamplerAddressMode::REPEAT`]
452    #[builder(default)]
453    pub address_mode_v: vk::SamplerAddressMode,
454
455    /// A value specifying the addressing mode for W coordinates outside `[0, 1)`.
456    ///
457    /// The default value is [`vk::SamplerAddressMode::REPEAT`]
458    #[builder(default)]
459    pub address_mode_w: vk::SamplerAddressMode,
460
461    /// The bias to be added to mipmap LOD calculation and bias provided by image sampling
462    /// functions in SPIR-V, as described in the
463    /// See [`VkSamplerCreateInfo`](https://registry.khronos.org/vulkan/specs/latest/man/html/VkSamplerCreateInfo.html).
464    /// section.
465    #[builder(default, setter(into))]
466    pub mip_lod_bias: OrderedFloat<f32>,
467
468    /// Enables anisotropic filtering, as described in the
469    /// See [`VkSamplerCreateInfo`](https://registry.khronos.org/vulkan/specs/latest/man/html/VkSamplerCreateInfo.html).
470    /// section
471    #[builder(default)]
472    pub anisotropy_enable: bool,
473
474    /// The anisotropy value clamp used by the sampler when `anisotropy_enable` is `true`.
475    ///
476    /// If `anisotropy_enable` is `false`, max_anisotropy is ignored.
477    #[builder(default, setter(into))]
478    pub max_anisotropy: OrderedFloat<f32>,
479
480    /// Enables comparison against a reference value during lookups.
481    #[builder(default)]
482    pub compare_enable: bool,
483
484    /// Specifies the comparison operator to apply to fetched data before filtering as described in
485    /// the
486    /// See [`VkSamplerCreateInfo`](https://registry.khronos.org/vulkan/specs/latest/man/html/VkSamplerCreateInfo.html).
487    /// section.
488    #[builder(default)]
489    pub compare_op: vk::CompareOp,
490
491    /// Used to clamp the
492    /// See [`VkSamplerCreateInfo`](https://registry.khronos.org/vulkan/specs/latest/man/html/VkSamplerCreateInfo.html).
493    #[builder(default, setter(into))]
494    pub min_lod: OrderedFloat<f32>,
495
496    /// Used to clamp the
497    /// See [`VkSamplerCreateInfo`](https://registry.khronos.org/vulkan/specs/latest/man/html/VkSamplerCreateInfo.html).
498    ///
499    /// To avoid clamping the maximum value, set maxLod to the constant `vk::LOD_CLAMP_NONE`.
500    #[builder(default, setter(into))]
501    pub max_lod: OrderedFloat<f32>,
502
503    /// Secifies the predefined border color to use.
504    ///
505    /// The default value is [`vk::BorderColor::FLOAT_TRANSPARENT_BLACK`]
506    #[builder(default)]
507    pub border_color: vk::BorderColor,
508
509    /// Controls whether to use unnormalized or normalized texel coordinates to address texels of
510    /// the image.
511    ///
512    /// When set to `true`, the range of the image coordinates used to lookup the texel is in the
513    /// range of zero to the image size in each dimension.
514    ///
515    /// When set to `false` the range of image coordinates is zero to one.
516    ///
517    /// See
518    /// See [`VkSamplerCreateInfo`](https://registry.khronos.org/vulkan/specs/latest/man/html/VkSamplerCreateInfo.html).
519    #[builder(default)]
520    pub unnormalized_coordinates: bool,
521
522    /// Specifies sampler reduction mode.
523    ///
524    /// Setting magnification filter ([`mag_filter`](Self::mag_filter)) to [`vk::Filter::NEAREST`]
525    /// disables sampler reduction mode.
526    ///
527    /// The default value is [`vk::SamplerReductionMode::WEIGHTED_AVERAGE`]
528    ///
529    /// See
530    /// See [`VkSamplerCreateInfo`](https://registry.khronos.org/vulkan/specs/latest/man/html/VkSamplerCreateInfo.html).
531    #[builder(default)]
532    pub reduction_mode: vk::SamplerReductionMode,
533}
534
535impl SamplerInfo {
536    /// Default sampler information with `mag_filter`, `min_filter` and `mipmap_mode` set to linear.
537    pub const LINEAR: SamplerInfoBuilder = SamplerInfoBuilder {
538        flags: None,
539        mag_filter: Some(vk::Filter::LINEAR),
540        min_filter: Some(vk::Filter::LINEAR),
541        mipmap_mode: Some(vk::SamplerMipmapMode::LINEAR),
542        address_mode_u: None,
543        address_mode_v: None,
544        address_mode_w: None,
545        mip_lod_bias: None,
546        anisotropy_enable: None,
547        max_anisotropy: None,
548        compare_enable: None,
549        compare_op: None,
550        min_lod: None,
551        max_lod: None,
552        border_color: None,
553        unnormalized_coordinates: None,
554        reduction_mode: None,
555    };
556
557    /// Default sampler information with `mag_filter`, `min_filter` and `mipmap_mode` set to
558    /// nearest.
559    pub const NEAREST: SamplerInfoBuilder = SamplerInfoBuilder {
560        flags: None,
561        mag_filter: Some(vk::Filter::NEAREST),
562        min_filter: Some(vk::Filter::NEAREST),
563        mipmap_mode: Some(vk::SamplerMipmapMode::NEAREST),
564        address_mode_u: None,
565        address_mode_v: None,
566        address_mode_w: None,
567        mip_lod_bias: None,
568        anisotropy_enable: None,
569        max_anisotropy: None,
570        compare_enable: None,
571        compare_op: None,
572        min_lod: None,
573        max_lod: None,
574        border_color: None,
575        unnormalized_coordinates: None,
576        reduction_mode: None,
577    };
578
579    /// Creates a default `SamplerInfoBuilder`.
580    #[allow(clippy::new_ret_no_self)]
581    #[deprecated = "Use SamplerInfo::default()"]
582    #[doc(hidden)]
583    pub fn new() -> SamplerInfoBuilder {
584        Self::default().into_builder()
585    }
586
587    /// Creates a default `SamplerInfoBuilder`.
588    pub fn builder() -> SamplerInfoBuilder {
589        Default::default()
590    }
591
592    /// Converts a `SamplerInfo` into a `SamplerInfoBuilder`.
593    pub fn into_builder(self) -> SamplerInfoBuilder {
594        SamplerInfoBuilder {
595            flags: Some(self.flags),
596            mag_filter: Some(self.mag_filter),
597            min_filter: Some(self.min_filter),
598            mipmap_mode: Some(self.mipmap_mode),
599            address_mode_u: Some(self.address_mode_u),
600            address_mode_v: Some(self.address_mode_v),
601            address_mode_w: Some(self.address_mode_w),
602            mip_lod_bias: Some(self.mip_lod_bias),
603            anisotropy_enable: Some(self.anisotropy_enable),
604            max_anisotropy: Some(self.max_anisotropy),
605            compare_enable: Some(self.compare_enable),
606            compare_op: Some(self.compare_op),
607            min_lod: Some(self.min_lod),
608            max_lod: Some(self.max_lod),
609            border_color: Some(self.border_color),
610            unnormalized_coordinates: Some(self.unnormalized_coordinates),
611            reduction_mode: Some(self.reduction_mode),
612        }
613    }
614
615    #[deprecated = "use into_builder function"]
616    #[doc(hidden)]
617    pub fn to_builder(self) -> SamplerInfoBuilder {
618        self.into_builder()
619    }
620}
621
622impl Default for SamplerInfo {
623    fn default() -> Self {
624        Self {
625            flags: vk::SamplerCreateFlags::empty(),
626            mag_filter: vk::Filter::NEAREST,
627            min_filter: vk::Filter::NEAREST,
628            mipmap_mode: vk::SamplerMipmapMode::NEAREST,
629            address_mode_u: vk::SamplerAddressMode::REPEAT,
630            address_mode_v: vk::SamplerAddressMode::REPEAT,
631            address_mode_w: vk::SamplerAddressMode::REPEAT,
632            mip_lod_bias: OrderedFloat(0.0),
633            anisotropy_enable: false,
634            max_anisotropy: OrderedFloat(0.0),
635            compare_enable: false,
636            compare_op: vk::CompareOp::NEVER,
637            min_lod: OrderedFloat(0.0),
638            max_lod: OrderedFloat(0.0),
639            border_color: vk::BorderColor::FLOAT_TRANSPARENT_BLACK,
640            unnormalized_coordinates: false,
641            reduction_mode: vk::SamplerReductionMode::WEIGHTED_AVERAGE,
642        }
643    }
644}
645
646impl SamplerInfoBuilder {
647    /// Builds a new `SamplerInfo`.
648    #[inline(always)]
649    pub fn build(self) -> SamplerInfo {
650        self.fallible_build().expect("invalid sampler info")
651    }
652}
653
654impl From<SamplerInfoBuilder> for SamplerInfo {
655    fn from(info: SamplerInfoBuilder) -> Self {
656        info.build()
657    }
658}
659
660#[derive(Debug)]
661struct SamplerInfoBuilderError;
662
663impl From<UninitializedFieldError> for SamplerInfoBuilderError {
664    fn from(_: UninitializedFieldError) -> Self {
665        Self
666    }
667}
668
669/// Describes a shader program which runs on some pipeline stage.
670#[allow(missing_docs)]
671#[derive(Builder, Clone)]
672#[builder(
673    build_fn(private, name = "fallible_build", error = "ShaderBuilderError"),
674    derive(Clone, Debug),
675    pattern = "owned"
676)]
677pub struct Shader {
678    /// The name of the entry point which will be executed by this shader.
679    ///
680    /// The default value is `main`.
681    #[builder(default = "\"main\".to_owned()", setter(into))]
682    pub entry_name: String,
683
684    /// Data about Vulkan specialization constants.
685    ///
686    /// # Examples
687    ///
688    /// Basic usage (GLSL):
689    ///
690    /// ```
691    /// # vk_shader_macros::glsl!(r#"
692    /// #version 460 core
693    /// #pragma shader_stage(compute)
694    ///
695    /// // Defaults to 6 if not set using Shader specialization!
696    /// layout(constant_id = 0) const uint MY_COUNT = 6;
697    ///
698    /// layout(set = 0, binding = 0) uniform sampler2D my_samplers[MY_COUNT];
699    ///
700    /// void main()
701    /// {
702    ///     // Code uses MY_COUNT number of my_samplers here
703    /// }
704    /// # "#);
705    /// ```
706    ///
707    /// ```no_run
708    /// # use std::sync::Arc;
709    /// # use ash::vk;
710    /// # use vk_graph::driver::DriverError;
711    /// # use vk_graph::driver::device::{Device, DeviceInfo};
712    /// # use vk_graph::driver::shader::{Shader, SpecializationMap};
713    /// # fn main() -> Result<(), DriverError> {
714    /// # let device = Device::new(DeviceInfo::default())?;
715    /// # let my_shader_code = [0u8; 1];
716    /// // We instead specify 42 for MY_COUNT:
717    /// let shader = Shader::new_fragment(my_shader_code.as_slice())
718    ///     .specialization(
719    ///         SpecializationMap::new(42u32.to_ne_bytes())
720    ///             .constant(0, 0, 4)
721    ///     );
722    /// # Ok(()) }
723    /// ```
724    #[builder(default, setter(strip_option))]
725    pub specialization: Option<SpecializationMap>,
726
727    /// Shader code.
728    ///
729    /// Although SPIR-V code is specified as `u32` values, this field uses `u8` in order to make
730    /// loading from file simpler. You should always have a SPIR-V code length which is a multiple
731    /// of four bytes, or an error will be returned during pipeline creation.
732    #[builder(setter(into))]
733    pub spirv: SpirvBinary,
734
735    /// The shader stage this structure applies to.
736    pub stage: vk::ShaderStageFlags,
737
738    #[builder(private)]
739    entry_point: EntryPoint,
740
741    #[builder(default, private)]
742    image_samplers: HashMap<Descriptor, SamplerInfo>,
743
744    #[builder(default, private, setter(strip_option))]
745    vertex_input_state: Option<VertexInputState>,
746}
747
748impl Shader {
749    /// Specifies a shader with the given `stage` and shader code.
750    #[allow(clippy::new_ret_no_self)]
751    pub fn new(stage: vk::ShaderStageFlags, spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
752        ShaderBuilder::default().spirv(spirv).stage(stage)
753    }
754
755    /// Creates a new ray trace shader.
756    ///
757    /// # Panics
758    ///
759    /// If the shader code is invalid or not a multiple of four bytes in length.
760    pub fn new_any_hit(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
761        Self::new(vk::ShaderStageFlags::ANY_HIT_KHR, spirv)
762    }
763
764    /// Creates a new ray trace shader.
765    ///
766    /// # Panics
767    ///
768    /// If the shader code is invalid or not a multiple of four bytes in length.
769    pub fn new_callable(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
770        Self::new(vk::ShaderStageFlags::CALLABLE_KHR, spirv)
771    }
772
773    /// Creates a new ray trace shader.
774    ///
775    /// # Panics
776    ///
777    /// If the shader code is invalid or not a multiple of four bytes in length.
778    pub fn new_closest_hit(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
779        Self::new(vk::ShaderStageFlags::CLOSEST_HIT_KHR, spirv)
780    }
781
782    /// Creates a new compute shader.
783    ///
784    /// # Panics
785    ///
786    /// If the shader code is invalid or not a multiple of four bytes in length.
787    pub fn new_compute(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
788        Self::new(vk::ShaderStageFlags::COMPUTE, spirv)
789    }
790
791    /// Creates a new fragment shader.
792    ///
793    /// # Panics
794    ///
795    /// If the shader code is invalid or not a multiple of four bytes in length.
796    pub fn new_fragment(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
797        Self::new(vk::ShaderStageFlags::FRAGMENT, spirv)
798    }
799
800    /// Creates a new geometry shader.
801    ///
802    /// # Panics
803    ///
804    /// If the shader code is invalid or not a multiple of four bytes in length.
805    pub fn new_geometry(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
806        Self::new(vk::ShaderStageFlags::GEOMETRY, spirv)
807    }
808
809    /// Creates a new ray trace shader.
810    ///
811    /// # Panics
812    ///
813    /// If the shader code is invalid or not a multiple of four bytes in length.
814    pub fn new_intersection(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
815        Self::new(vk::ShaderStageFlags::INTERSECTION_KHR, spirv)
816    }
817
818    /// Creates a new mesh shader.
819    ///
820    /// # Panics
821    ///
822    /// If the shader code is invalid.
823    pub fn new_mesh(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
824        Self::new(vk::ShaderStageFlags::MESH_EXT, spirv)
825    }
826
827    /// Creates a new ray trace shader.
828    ///
829    /// # Panics
830    ///
831    /// If the shader code is invalid or not a multiple of four bytes in length.
832    pub fn new_miss(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
833        Self::new(vk::ShaderStageFlags::MISS_KHR, spirv)
834    }
835
836    /// Creates a new ray trace shader.
837    ///
838    /// # Panics
839    ///
840    /// If the shader code is invalid or not a multiple of four bytes in length.
841    pub fn new_ray_gen(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
842        Self::new(vk::ShaderStageFlags::RAYGEN_KHR, spirv)
843    }
844
845    /// Creates a new mesh task shader.
846    ///
847    /// # Panics
848    ///
849    /// If the shader code is invalid.
850    pub fn new_task(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
851        Self::new(vk::ShaderStageFlags::TASK_EXT, spirv)
852    }
853
854    /// Creates a new tessellation control shader.
855    ///
856    /// # Panics
857    ///
858    /// If the shader code is invalid or not a multiple of four bytes in length.
859    pub fn new_tessellation_ctrl(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
860        Self::new(vk::ShaderStageFlags::TESSELLATION_CONTROL, spirv)
861    }
862
863    #[deprecated = "use new_tessellation_ctrl function"]
864    #[doc(hidden)]
865    pub fn new_tesselation_ctrl(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
866        Self::new_tessellation_ctrl(spirv)
867    }
868
869    /// Creates a new tessellation evaluation shader.
870    ///
871    /// # Panics
872    ///
873    /// If the shader code is invalid or not a multiple of four bytes in length.
874    pub fn new_tessellation_eval(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
875        Self::new(vk::ShaderStageFlags::TESSELLATION_EVALUATION, spirv)
876    }
877
878    #[deprecated = "use new_tessellation_eval function"]
879    #[doc(hidden)]
880    pub fn new_tesselation_eval(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
881        Self::new_tessellation_eval(spirv)
882    }
883
884    /// Creates a new vertex shader.
885    ///
886    /// # Panics
887    ///
888    /// If the shader code is invalid or not a multiple of four bytes in length.
889    pub fn new_vertex(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
890        Self::new(vk::ShaderStageFlags::VERTEX, spirv)
891    }
892
893    /// Returns the input and write attachments of a shader.
894    #[profiling::function]
895    pub(super) fn attachments(
896        &self,
897    ) -> (
898        impl Iterator<Item = u32> + '_,
899        impl Iterator<Item = u32> + '_,
900    ) {
901        (
902            self.entry_point.vars.iter().filter_map(|var| match var {
903                Variable::Descriptor {
904                    desc_ty: DescriptorType::InputAttachment(attachment),
905                    ..
906                } => Some(*attachment),
907                _ => None,
908            }),
909            self.entry_point.vars.iter().filter_map(|var| match var {
910                Variable::Output { location, .. } => Some(location.loc()),
911                _ => None,
912            }),
913        )
914    }
915
916    /// Creates a default `ShaderBuilder`.
917    pub fn builder() -> ShaderBuilder {
918        Default::default()
919    }
920
921    #[profiling::function]
922    pub(super) fn descriptor_bindings(&self) -> DescriptorBindingMap {
923        let mut res = DescriptorBindingMap::default();
924
925        for (name, descriptor, desc_ty, binding_count) in
926            self.entry_point.vars.iter().filter_map(|var| match var {
927                Variable::Descriptor {
928                    name,
929                    desc_bind,
930                    desc_ty,
931                    nbind,
932                    ..
933                } => Some((
934                    name,
935                    Descriptor {
936                        set: desc_bind.set(),
937                        binding: desc_bind.bind(),
938                    },
939                    desc_ty,
940                    *nbind,
941                )),
942                _ => None,
943            })
944        {
945            trace!(
946                "descriptor {}: {}.{} = {:?}[{}]",
947                name.as_deref().unwrap_or_default(),
948                descriptor.set,
949                descriptor.binding,
950                *desc_ty,
951                binding_count
952            );
953
954            let descriptor_info = match desc_ty {
955                DescriptorType::AccelStruct() => {
956                    DescriptorInfo::AccelerationStructure(binding_count)
957                }
958                DescriptorType::CombinedImageSampler() => {
959                    let (sampler_info, is_manually_defined) =
960                        self.image_sampler(descriptor, name.as_deref().unwrap_or_default());
961
962                    DescriptorInfo::CombinedImageSampler(
963                        binding_count,
964                        sampler_info,
965                        is_manually_defined,
966                    )
967                }
968                DescriptorType::InputAttachment(attachment) => {
969                    DescriptorInfo::InputAttachment(binding_count, *attachment)
970                }
971                DescriptorType::SampledImage() => DescriptorInfo::SampledImage(binding_count),
972                DescriptorType::Sampler() => {
973                    let (sampler_info, is_manually_defined) =
974                        self.image_sampler(descriptor, name.as_deref().unwrap_or_default());
975
976                    DescriptorInfo::Sampler(binding_count, sampler_info, is_manually_defined)
977                }
978                DescriptorType::StorageBuffer(_access_ty) => {
979                    DescriptorInfo::StorageBuffer(binding_count)
980                }
981                DescriptorType::StorageImage(_access_ty) => {
982                    DescriptorInfo::StorageImage(binding_count)
983                }
984                DescriptorType::StorageTexelBuffer(_access_ty) => {
985                    DescriptorInfo::StorageTexelBuffer(binding_count)
986                }
987                DescriptorType::UniformBuffer() => DescriptorInfo::UniformBuffer(binding_count),
988                DescriptorType::UniformTexelBuffer() => {
989                    DescriptorInfo::UniformTexelBuffer(binding_count)
990                }
991            };
992            res.insert(descriptor, (descriptor_info, self.stage));
993        }
994
995        res
996    }
997
998    /// Specifies a shader with the given shader code.
999    pub fn from_spirv(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
1000        ShaderBuilder::default().spirv(spirv)
1001    }
1002
1003    fn image_sampler(&self, descriptor: Descriptor, name: &str) -> (SamplerInfo, bool) {
1004        self.image_samplers
1005            .get(&descriptor)
1006            .copied()
1007            .map(|sampler_info| (sampler_info, true))
1008            .unwrap_or_else(|| (guess_immutable_sampler(name), false))
1009    }
1010
1011    #[profiling::function]
1012    pub(super) fn merge_descriptor_bindings(
1013        descriptor_bindings: impl IntoIterator<Item = DescriptorBindingMap>,
1014    ) -> DescriptorBindingMap {
1015        fn merge_info(lhs: &mut DescriptorInfo, rhs: DescriptorInfo) -> bool {
1016            let (lhs_count, rhs_count) = match lhs {
1017                DescriptorInfo::AccelerationStructure(lhs) => {
1018                    if let DescriptorInfo::AccelerationStructure(rhs) = rhs {
1019                        (lhs, rhs)
1020                    } else {
1021                        return false;
1022                    }
1023                }
1024                DescriptorInfo::CombinedImageSampler(lhs, lhs_sampler, lhs_is_manually_defined) => {
1025                    if let DescriptorInfo::CombinedImageSampler(
1026                        rhs,
1027                        rhs_sampler,
1028                        rhs_is_manually_defined,
1029                    ) = rhs
1030                    {
1031                        // Allow one of the samplers to be manually defined (only one!)
1032                        if *lhs_is_manually_defined && rhs_is_manually_defined {
1033                            return false;
1034                        } else if rhs_is_manually_defined {
1035                            *lhs_sampler = rhs_sampler;
1036                        }
1037
1038                        (lhs, rhs)
1039                    } else {
1040                        return false;
1041                    }
1042                }
1043                DescriptorInfo::InputAttachment(lhs, lhs_idx) => {
1044                    if let DescriptorInfo::InputAttachment(rhs, rhs_idx) = rhs {
1045                        if *lhs_idx != rhs_idx {
1046                            return false;
1047                        }
1048
1049                        (lhs, rhs)
1050                    } else {
1051                        return false;
1052                    }
1053                }
1054                DescriptorInfo::SampledImage(lhs) => {
1055                    if let DescriptorInfo::SampledImage(rhs) = rhs {
1056                        (lhs, rhs)
1057                    } else {
1058                        return false;
1059                    }
1060                }
1061                DescriptorInfo::Sampler(lhs, lhs_sampler, lhs_is_manually_defined) => {
1062                    if let DescriptorInfo::Sampler(rhs, rhs_sampler, rhs_is_manually_defined) = rhs
1063                    {
1064                        // Allow one of the samplers to be manually defined (only one!)
1065                        if *lhs_is_manually_defined && rhs_is_manually_defined {
1066                            return false;
1067                        } else if rhs_is_manually_defined {
1068                            *lhs_sampler = rhs_sampler;
1069                        }
1070
1071                        (lhs, rhs)
1072                    } else {
1073                        return false;
1074                    }
1075                }
1076                DescriptorInfo::StorageBuffer(lhs) => {
1077                    if let DescriptorInfo::StorageBuffer(rhs) = rhs {
1078                        (lhs, rhs)
1079                    } else {
1080                        return false;
1081                    }
1082                }
1083                DescriptorInfo::StorageImage(lhs) => {
1084                    if let DescriptorInfo::StorageImage(rhs) = rhs {
1085                        (lhs, rhs)
1086                    } else {
1087                        return false;
1088                    }
1089                }
1090                DescriptorInfo::StorageTexelBuffer(lhs) => {
1091                    if let DescriptorInfo::StorageTexelBuffer(rhs) = rhs {
1092                        (lhs, rhs)
1093                    } else {
1094                        return false;
1095                    }
1096                }
1097                DescriptorInfo::UniformBuffer(lhs) => {
1098                    if let DescriptorInfo::UniformBuffer(rhs) = rhs {
1099                        (lhs, rhs)
1100                    } else {
1101                        return false;
1102                    }
1103                }
1104                DescriptorInfo::UniformTexelBuffer(lhs) => {
1105                    if let DescriptorInfo::UniformTexelBuffer(rhs) = rhs {
1106                        (lhs, rhs)
1107                    } else {
1108                        return false;
1109                    }
1110                }
1111            };
1112
1113            *lhs_count = rhs_count.max(*lhs_count);
1114
1115            true
1116        }
1117
1118        #[profiling::function]
1119        fn merge_pair(src: DescriptorBindingMap, dst: &mut DescriptorBindingMap) {
1120            for (descriptor_binding, (descriptor_info, descriptor_flags)) in src.into_iter() {
1121                if let Some((existing_info, existing_flags)) = dst.get_mut(&descriptor_binding) {
1122                    if !merge_info(existing_info, descriptor_info) {
1123                        panic!("Inconsistent shader descriptors ({descriptor_binding:?})");
1124                    }
1125
1126                    *existing_flags |= descriptor_flags;
1127                } else {
1128                    dst.insert(descriptor_binding, (descriptor_info, descriptor_flags));
1129                }
1130            }
1131        }
1132
1133        let mut descriptor_bindings = descriptor_bindings.into_iter();
1134        let mut res = descriptor_bindings.next().unwrap_or_default();
1135        for descriptor_binding in descriptor_bindings {
1136            merge_pair(descriptor_binding, &mut res);
1137        }
1138
1139        res
1140    }
1141
1142    #[profiling::function]
1143    pub(super) fn push_constant_range(&self) -> Option<vk::PushConstantRange> {
1144        self.entry_point
1145            .vars
1146            .iter()
1147            .filter_map(|var| match var {
1148                Variable::PushConstant {
1149                    ty: Type::Struct(ty),
1150                    ..
1151                } => Some(ty.members.clone()),
1152                _ => None,
1153            })
1154            .flatten()
1155            .map(|push_const| {
1156                let offset = push_const.offset.unwrap_or_default();
1157                let size = push_const
1158                    .ty
1159                    .nbyte()
1160                    .unwrap_or_default()
1161                    .next_multiple_of(4);
1162                offset..offset + size
1163            })
1164            .reduce(|a, b| a.start.min(b.start)..a.end.max(b.end))
1165            .map(|push_const| vk::PushConstantRange {
1166                stage_flags: self.stage,
1167                size: (push_const.end - push_const.start) as _,
1168                offset: push_const.start as _,
1169            })
1170    }
1171
1172    #[profiling::function]
1173    fn reflect_entry_point(
1174        entry_name: &str,
1175        spirv: impl Into<SpirvBinary>,
1176        specialization: Option<&SpecializationMap>,
1177    ) -> Result<EntryPoint, DriverError> {
1178        let mut config = ReflectConfig::new();
1179        config.ref_all_rscs(true).spv(spirv);
1180
1181        if let Some(specialization) = specialization {
1182            for &vk::SpecializationMapEntry {
1183                constant_id,
1184                offset,
1185                size,
1186            } in &specialization.entries
1187            {
1188                config.specialize(
1189                    constant_id,
1190                    specialization.data[offset as usize..offset as usize + size].into(),
1191                );
1192            }
1193        }
1194
1195        let entry_points = config.reflect().map_err(|err| {
1196            error!("invalid spirv reflection data: {err}");
1197
1198            DriverError::InvalidData
1199        })?;
1200        let entry_point = entry_points
1201            .into_iter()
1202            .find(|entry_point| entry_point.name == entry_name)
1203            .ok_or_else(|| {
1204                error!("invalid shader entry point: not found");
1205
1206                DriverError::InvalidData
1207            })?;
1208
1209        Ok(entry_point)
1210    }
1211
1212    #[profiling::function]
1213    pub(super) fn try_vertex_input(&self) -> Result<VertexInputState, DriverError> {
1214        // Check for manually-specified vertex layout descriptions
1215        if let Some(vertex_input) = &self.vertex_input_state {
1216            return Ok(vertex_input.clone());
1217        }
1218
1219        fn scalar_format(ty: &ScalarType) -> Option<vk::Format> {
1220            match *ty {
1221                ScalarType::Float { bits } => match bits {
1222                    u8::BITS => Some(vk::Format::R8_SNORM),
1223                    u16::BITS => Some(vk::Format::R16_SFLOAT),
1224                    u32::BITS => Some(vk::Format::R32_SFLOAT),
1225                    u64::BITS => Some(vk::Format::R64_SFLOAT),
1226                    _ => None,
1227                },
1228                ScalarType::Integer {
1229                    bits,
1230                    is_signed: false,
1231                } => match bits {
1232                    u8::BITS => Some(vk::Format::R8_UINT),
1233                    u16::BITS => Some(vk::Format::R16_UINT),
1234                    u32::BITS => Some(vk::Format::R32_UINT),
1235                    u64::BITS => Some(vk::Format::R64_UINT),
1236                    _ => None,
1237                },
1238                ScalarType::Integer {
1239                    bits,
1240                    is_signed: true,
1241                } => match bits {
1242                    u8::BITS => Some(vk::Format::R8_SINT),
1243                    u16::BITS => Some(vk::Format::R16_SINT),
1244                    u32::BITS => Some(vk::Format::R32_SINT),
1245                    u64::BITS => Some(vk::Format::R64_SINT),
1246                    _ => None,
1247                },
1248                _ => None,
1249            }
1250        }
1251
1252        fn vector_format(ty: &VectorType) -> Option<vk::Format> {
1253            match *ty {
1254                VectorType {
1255                    scalar_ty: ScalarType::Float { bits },
1256                    nscalar,
1257                } => match (bits, nscalar) {
1258                    (u8::BITS, 2) => Some(vk::Format::R8G8_SNORM),
1259                    (u8::BITS, 3) => Some(vk::Format::R8G8B8_SNORM),
1260                    (u8::BITS, 4) => Some(vk::Format::R8G8B8A8_SNORM),
1261                    (u16::BITS, 2) => Some(vk::Format::R16G16_SFLOAT),
1262                    (u16::BITS, 3) => Some(vk::Format::R16G16B16_SFLOAT),
1263                    (u16::BITS, 4) => Some(vk::Format::R16G16B16A16_SFLOAT),
1264                    (u32::BITS, 2) => Some(vk::Format::R32G32_SFLOAT),
1265                    (u32::BITS, 3) => Some(vk::Format::R32G32B32_SFLOAT),
1266                    (u32::BITS, 4) => Some(vk::Format::R32G32B32A32_SFLOAT),
1267                    (u64::BITS, 2) => Some(vk::Format::R64G64_SFLOAT),
1268                    (u64::BITS, 3) => Some(vk::Format::R64G64B64_SFLOAT),
1269                    (u64::BITS, 4) => Some(vk::Format::R64G64B64A64_SFLOAT),
1270                    _ => None,
1271                },
1272                VectorType {
1273                    scalar_ty:
1274                        ScalarType::Integer {
1275                            bits,
1276                            is_signed: false,
1277                        },
1278                    nscalar,
1279                } => match (bits, nscalar) {
1280                    (u8::BITS, 2) => Some(vk::Format::R8G8_UINT),
1281                    (u8::BITS, 3) => Some(vk::Format::R8G8B8_UINT),
1282                    (u8::BITS, 4) => Some(vk::Format::R8G8B8A8_UINT),
1283                    (u16::BITS, 2) => Some(vk::Format::R16G16_UINT),
1284                    (u16::BITS, 3) => Some(vk::Format::R16G16B16_UINT),
1285                    (u16::BITS, 4) => Some(vk::Format::R16G16B16A16_UINT),
1286                    (u32::BITS, 2) => Some(vk::Format::R32G32_UINT),
1287                    (u32::BITS, 3) => Some(vk::Format::R32G32B32_UINT),
1288                    (u32::BITS, 4) => Some(vk::Format::R32G32B32A32_UINT),
1289                    (u64::BITS, 2) => Some(vk::Format::R64G64_UINT),
1290                    (u64::BITS, 3) => Some(vk::Format::R64G64B64_UINT),
1291                    (u64::BITS, 4) => Some(vk::Format::R64G64B64A64_UINT),
1292                    _ => None,
1293                },
1294                VectorType {
1295                    scalar_ty:
1296                        ScalarType::Integer {
1297                            bits,
1298                            is_signed: true,
1299                        },
1300                    nscalar,
1301                } => match (bits, nscalar) {
1302                    (u8::BITS, 2) => Some(vk::Format::R8G8_SINT),
1303                    (u8::BITS, 3) => Some(vk::Format::R8G8B8_SINT),
1304                    (u8::BITS, 4) => Some(vk::Format::R8G8B8A8_SINT),
1305                    (u16::BITS, 2) => Some(vk::Format::R16G16_SINT),
1306                    (u16::BITS, 3) => Some(vk::Format::R16G16B16_SINT),
1307                    (u16::BITS, 4) => Some(vk::Format::R16G16B16A16_SINT),
1308                    (u32::BITS, 2) => Some(vk::Format::R32G32_SINT),
1309                    (u32::BITS, 3) => Some(vk::Format::R32G32B32_SINT),
1310                    (u32::BITS, 4) => Some(vk::Format::R32G32B32A32_SINT),
1311                    (u64::BITS, 2) => Some(vk::Format::R64G64_SINT),
1312                    (u64::BITS, 3) => Some(vk::Format::R64G64B64_SINT),
1313                    (u64::BITS, 4) => Some(vk::Format::R64G64B64A64_SINT),
1314                    _ => None,
1315                },
1316                _ => None,
1317            }
1318        }
1319
1320        let mut input_rates_strides = HashMap::new();
1321        let mut vertex_attribute_descriptions = vec![];
1322
1323        for (name, location, ty) in self.entry_point.vars.iter().filter_map(|var| match var {
1324            Variable::Input { name, location, ty } => Some((name, location, ty)),
1325            _ => None,
1326        }) {
1327            let (binding, guessed_rate) = name
1328                .as_ref()
1329                .filter(|name| name.contains("_ibind") || name.contains("_vbind"))
1330                .map(|name| {
1331                    let binding = name[name.rfind("bind").expect("missing bind suffix")..]
1332                        .parse()
1333                        .unwrap_or_default();
1334                    let rate = if name.contains("_ibind") {
1335                        vk::VertexInputRate::INSTANCE
1336                    } else {
1337                        vk::VertexInputRate::VERTEX
1338                    };
1339
1340                    (binding, rate)
1341                })
1342                .unwrap_or_default();
1343            let (location, _) = location.into_inner();
1344            if let Some((input_rate, _)) = input_rates_strides.get(&binding) {
1345                assert_eq!(*input_rate, guessed_rate);
1346            }
1347
1348            let byte_stride = ty.nbyte().unwrap_or_default() as u32;
1349            let (input_rate, stride) = input_rates_strides.entry(binding).or_default();
1350            *input_rate = guessed_rate;
1351            *stride += byte_stride;
1352
1353            //trace!("{location} {:?} is {byte_stride} bytes", name);
1354
1355            let format = match ty {
1356                Type::Scalar(ty) => scalar_format(ty),
1357                Type::Vector(ty) => vector_format(ty),
1358                _ => None,
1359            }
1360            .ok_or_else(|| {
1361                warn!("unsupported reflected vertex input type: {ty:?}");
1362
1363                DriverError::Unsupported
1364            })?;
1365
1366            vertex_attribute_descriptions.push(vk::VertexInputAttributeDescription {
1367                location,
1368                binding,
1369                format,
1370                offset: byte_stride, // Figured out below - this data is iter'd in an unknown order
1371            });
1372        }
1373
1374        vertex_attribute_descriptions.sort_unstable_by(|lhs, rhs| {
1375            let binding = lhs.binding.cmp(&rhs.binding);
1376            if binding.is_lt() {
1377                return binding;
1378            }
1379
1380            lhs.location.cmp(&rhs.location)
1381        });
1382
1383        let mut offset = 0;
1384        let mut offset_binding = 0;
1385
1386        for vertex_attribute_description in &mut vertex_attribute_descriptions {
1387            if vertex_attribute_description.binding != offset_binding {
1388                offset_binding = vertex_attribute_description.binding;
1389                offset = 0;
1390            }
1391
1392            let stride = vertex_attribute_description.offset;
1393            vertex_attribute_description.offset = offset;
1394            offset += stride;
1395
1396            debug!(
1397                "vertex attribute {}.{}: {:?} (offset={})",
1398                vertex_attribute_description.binding,
1399                vertex_attribute_description.location,
1400                vertex_attribute_description.format,
1401                vertex_attribute_description.offset,
1402            );
1403        }
1404
1405        let mut vertex_binding_descriptions = vec![];
1406        for (binding, (input_rate, stride)) in input_rates_strides.into_iter() {
1407            vertex_binding_descriptions.push(vk::VertexInputBindingDescription {
1408                binding,
1409                input_rate,
1410                stride,
1411            });
1412        }
1413
1414        Ok(VertexInputState {
1415            vertex_attribute_descriptions,
1416            vertex_binding_descriptions,
1417        })
1418    }
1419
1420    #[profiling::function]
1421    pub(super) fn vertex_input(&self) -> VertexInputState {
1422        self.try_vertex_input()
1423            .expect("unsupported reflected vertex input layout")
1424    }
1425}
1426
1427impl Debug for Shader {
1428    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1429        // We don't want the default formatter bc vec u8
1430        // TODO: Better output message
1431        f.write_str("Shader")
1432    }
1433}
1434
1435impl From<ShaderBuilder> for Shader {
1436    fn from(shader: ShaderBuilder) -> Self {
1437        shader.build()
1438    }
1439}
1440
1441impl<T> From<T> for Shader
1442where
1443    T: Into<SpirvBinary>,
1444{
1445    fn from(spirv: T) -> Self {
1446        Shader::from_spirv(spirv).build()
1447    }
1448}
1449
1450// HACK: https://github.com/colin-kiegel/rust-derive-builder/issues/56
1451impl ShaderBuilder {
1452    /// Specifies a shader with the given `stage` and shader code values.
1453    pub fn new(stage: vk::ShaderStageFlags, spirv: Vec<u8>) -> Self {
1454        Self::default().stage(stage).spirv(spirv)
1455    }
1456
1457    /// Builds a new `Shader`.
1458    pub fn build(self) -> Shader {
1459        let entry_name = self.entry_name.clone().unwrap_or_else(|| "main".to_owned());
1460
1461        self.try_build().unwrap_or_else(|_| {
1462            panic!("invalid or unsupported shader code for entry name '{entry_name}'")
1463        })
1464    }
1465
1466    /// Specifies a manually-defined image sampler.
1467    ///
1468    /// Sampled images, by default, use reflection to automatically assign image samplers. Each
1469    /// sampled image may use a suffix such as `_llr` or `_nne` for common linear/linear repeat or
1470    /// nearest/nearest clamp-to-edge samplers, respectively.
1471    ///
1472    /// See the [main documentation] for more information about automatic image samplers.
1473    ///
1474    /// Descriptor bindings may be specified as `(1, 2)` for descriptor set index `1` and binding
1475    /// index `2`, or if the descriptor set index is `0` simply specify `2` for the same case.
1476    ///
1477    /// _NOTE:_ When defining image samplers which are used in multiple stages of a single pipeline
1478    /// you must only call this function on one of the shader stages, it does not matter which one.
1479    ///
1480    /// # Panics
1481    ///
1482    /// Panics if two shader stages of the same pipeline define individual calls to `image_sampler`.
1483    ///
1484    /// [main documentation]: crate
1485    #[profiling::function]
1486    pub fn image_sampler(
1487        mut self,
1488        descriptor: impl Into<Descriptor>,
1489        info: impl Into<SamplerInfo>,
1490    ) -> Self {
1491        let descriptor = descriptor.into();
1492        let info = info.into();
1493
1494        if self.image_samplers.is_none() {
1495            self.image_samplers = Some(Default::default());
1496        }
1497
1498        self.image_samplers
1499            .as_mut()
1500            .expect("missing image samplers")
1501            .insert(descriptor, info);
1502
1503        self
1504    }
1505
1506    /// Attempts to build a new `Shader`.
1507    pub fn try_build(mut self) -> Result<Shader, DriverError> {
1508        let entry_name = self.entry_name.as_deref().unwrap_or("main");
1509        let entry_point = Shader::reflect_entry_point(
1510            entry_name,
1511            self.spirv
1512                .as_ref()
1513                .map(|spirv| spirv.words())
1514                .expect("missing spirv code"),
1515            self.specialization
1516                .as_ref()
1517                .map(|opt| opt.as_ref())
1518                .unwrap_or_default(),
1519        )
1520        .map_err(|err| {
1521            warn!("invalid shader reflection entry point: {err}");
1522
1523            DriverError::InvalidData
1524        })?;
1525
1526        if self.stage.unwrap_or_default().is_empty() {
1527            self.stage = Some(match entry_point.exec_model {
1528                ExecutionModel::Vertex => vk::ShaderStageFlags::VERTEX,
1529                ExecutionModel::TessellationControl => vk::ShaderStageFlags::TESSELLATION_CONTROL,
1530                ExecutionModel::TessellationEvaluation => {
1531                    vk::ShaderStageFlags::TESSELLATION_EVALUATION
1532                }
1533                ExecutionModel::Geometry => vk::ShaderStageFlags::GEOMETRY,
1534                ExecutionModel::Fragment => vk::ShaderStageFlags::FRAGMENT,
1535                ExecutionModel::GLCompute => vk::ShaderStageFlags::COMPUTE,
1536                ExecutionModel::Kernel => {
1537                    warn!("unsupported shader execution model: kernel");
1538
1539                    return Err(DriverError::Unsupported);
1540                }
1541                ExecutionModel::TaskNV => vk::ShaderStageFlags::TASK_EXT,
1542                ExecutionModel::MeshNV => vk::ShaderStageFlags::MESH_EXT,
1543                ExecutionModel::RayGenerationNV => vk::ShaderStageFlags::RAYGEN_KHR,
1544                ExecutionModel::IntersectionNV => vk::ShaderStageFlags::INTERSECTION_KHR,
1545                ExecutionModel::AnyHitNV => vk::ShaderStageFlags::ANY_HIT_KHR,
1546                ExecutionModel::ClosestHitNV => vk::ShaderStageFlags::CLOSEST_HIT_KHR,
1547                ExecutionModel::MissNV => vk::ShaderStageFlags::MISS_KHR,
1548                ExecutionModel::CallableNV => vk::ShaderStageFlags::CALLABLE_KHR,
1549                ExecutionModel::TaskEXT => vk::ShaderStageFlags::TASK_EXT,
1550                ExecutionModel::MeshEXT => vk::ShaderStageFlags::MESH_EXT,
1551            })
1552        }
1553
1554        self.entry_point = Some(entry_point);
1555
1556        self.fallible_build().map_err(|err| {
1557            warn!("invalid shader builder state: {err:?}");
1558
1559            DriverError::InvalidData
1560        })
1561    }
1562
1563    /// Specifies a manually-defined vertex input layout.
1564    ///
1565    /// The vertex input layout, by default, uses reflection to automatically define vertex binding
1566    /// and attribute descriptions. Each vertex location is inferred to have 32-bit channels and be
1567    /// tightly packed in the vertex buffer. In this mode, a location with `_ibind_0` or `_vbind3`
1568    /// suffixes is inferred to use instance-rate on vertex buffer binding `0` or vertex rate on
1569    /// binding `3`, respectively.
1570    ///
1571    /// See the [main documentation] for more information about automatic vertex input layout.
1572    ///
1573    /// [main documentation]: crate
1574    #[profiling::function]
1575    pub fn vertex_input(
1576        mut self,
1577        bindings: impl Into<Vec<vk::VertexInputBindingDescription>>,
1578        attributes: impl Into<Vec<vk::VertexInputAttributeDescription>>,
1579    ) -> Self {
1580        self.vertex_input_state = Some(Some(VertexInputState {
1581            vertex_binding_descriptions: bindings.into(),
1582            vertex_attribute_descriptions: attributes.into(),
1583        }));
1584        self
1585    }
1586}
1587
1588#[derive(Debug)]
1589struct ShaderBuilderError;
1590
1591impl From<UninitializedFieldError> for ShaderBuilderError {
1592    fn from(_: UninitializedFieldError) -> Self {
1593        Self
1594    }
1595}
1596
1597/// Describes specialized constant values.
1598#[derive(Clone, Debug, Default)]
1599pub struct SpecializationMap {
1600    /// A buffer of data which holds the constant values.
1601    pub data: Vec<u8>,
1602
1603    /// Mapping of locations within the constant value data which describe each individual
1604    /// constant.
1605    pub entries: Vec<vk::SpecializationMapEntry>,
1606}
1607
1608impl SpecializationMap {
1609    /// Constructs a new `SpecializationMap`.
1610    pub fn new(data: impl Into<Vec<u8>>) -> Self {
1611        Self {
1612            data: data.into(),
1613            entries: Default::default(),
1614        }
1615    }
1616
1617    /// Adds a single constant offset and size to the map and returns the map for further building.
1618    pub fn constant(mut self, constant_id: u32, offset: u32, size: usize) -> Self {
1619        self.set_constant(constant_id, offset, size);
1620        self
1621    }
1622
1623    /// Adds a single constant offset and size to the map and returns the map for further building.
1624    pub fn set_constant(&mut self, constant_id: u32, offset: u32, size: usize) {
1625        self.entries.push(vk::SpecializationMapEntry {
1626            constant_id,
1627            offset,
1628            size,
1629        });
1630    }
1631}
1632
1633impl<'a> From<&'a SpecializationMap> for vk::SpecializationInfo<'a> {
1634    fn from(value: &'a SpecializationMap) -> Self {
1635        vk::SpecializationInfo::default()
1636            .map_entries(&value.entries)
1637            .data(&value.data)
1638    }
1639}
1640
1641mod deprecated {
1642    use {
1643        crate::driver::shader::{ShaderBuilder, SpecializationMap},
1644        ash::vk,
1645    };
1646
1647    #[derive(Clone, Debug)]
1648    pub struct SpecializationInfo {
1649        pub data: Vec<u8>,
1650        pub map_entries: Vec<vk::SpecializationMapEntry>,
1651    }
1652
1653    impl SpecializationInfo {
1654        pub fn new(
1655            map_entries: impl Into<Vec<vk::SpecializationMapEntry>>,
1656            data: impl Into<Vec<u8>>,
1657        ) -> Self {
1658            Self {
1659                data: data.into(),
1660                map_entries: map_entries.into(),
1661            }
1662        }
1663    }
1664
1665    impl ShaderBuilder {
1666        #[deprecated = "use specialization function"]
1667        #[doc(hidden)]
1668        pub fn specialization_info(self, info: SpecializationInfo) -> Self {
1669            let mut specialization = SpecializationMap::new(info.data);
1670
1671            for entry in &info.map_entries {
1672                specialization.set_constant(entry.constant_id, entry.offset, entry.size);
1673            }
1674
1675            self.specialization(specialization)
1676        }
1677    }
1678}
1679
#[cfg(test)]
mod test {
    use super::*;

    /// A default `SamplerInfo` must come back unchanged after a round trip through its builder.
    #[test]
    pub fn sampler_info() {
        let expected = SamplerInfo::default();
        let round_tripped = expected.into_builder().build();

        assert_eq!(expected, round_tripped);
    }

    /// A default-constructed builder must produce the same value as `SamplerInfo::default()`.
    #[test]
    pub fn sampler_info_builder() {
        let expected = SamplerInfo::default();
        let built = SamplerInfoBuilder::default().build();

        assert_eq!(expected, built);
    }
}