screen_13/driver/
shader.rs

1//! Shader resource types
2
3use {
4    super::{DescriptorSetLayout, DriverError, VertexInputState, device::Device},
5    ash::vk,
6    derive_builder::{Builder, UninitializedFieldError},
7    log::{debug, error, trace, warn},
8    ordered_float::OrderedFloat,
9    spirq::{
10        ReflectConfig,
11        entry_point::EntryPoint,
12        ty::{DescriptorType, ScalarType, SpirvType, Type},
13        var::Variable,
14    },
15    std::{
16        collections::{BTreeMap, HashMap},
17        fmt::{Debug, Formatter},
18        iter::repeat,
19        mem::size_of_val,
20        ops::Deref,
21        sync::Arc,
22        thread::panicking,
23    },
24};
25
26pub(crate) type DescriptorBindingMap = HashMap<Descriptor, (DescriptorInfo, vk::ShaderStageFlags)>;
27
28pub(crate) fn align_spriv(code: &[u8]) -> Result<&[u32], DriverError> {
29    let (prefix, code, suffix) = unsafe { code.align_to() };
30
31    if prefix.len() + suffix.len() == 0 {
32        Ok(code)
33    } else {
34        warn!("Invalid SPIR-V code");
35
36        Err(DriverError::InvalidData)
37    }
38}
39
40#[profiling::function]
41fn guess_immutable_sampler(binding_name: &str) -> SamplerInfo {
42    const INVALID_ERR: &str = "Invalid sampler specification";
43
44    let (texel_filter, mipmap_mode, address_modes) = if binding_name.contains("_sampler_") {
45        let spec = &binding_name[binding_name.len() - 3..];
46        let texel_filter = match &spec[0..1] {
47            "n" => vk::Filter::NEAREST,
48            "l" => vk::Filter::LINEAR,
49            _ => panic!("{INVALID_ERR}: {}", &spec[0..1]),
50        };
51
52        let mipmap_mode = match &spec[1..2] {
53            "n" => vk::SamplerMipmapMode::NEAREST,
54            "l" => vk::SamplerMipmapMode::LINEAR,
55            _ => panic!("{INVALID_ERR}: {}", &spec[1..2]),
56        };
57
58        let address_modes = match &spec[2..3] {
59            "b" => vk::SamplerAddressMode::CLAMP_TO_BORDER,
60            "e" => vk::SamplerAddressMode::CLAMP_TO_EDGE,
61            "m" => vk::SamplerAddressMode::MIRRORED_REPEAT,
62            "r" => vk::SamplerAddressMode::REPEAT,
63            _ => panic!("{INVALID_ERR}: {}", &spec[2..3]),
64        };
65
66        (texel_filter, mipmap_mode, address_modes)
67    } else {
68        debug!("image binding {binding_name} using default sampler");
69
70        (
71            vk::Filter::LINEAR,
72            vk::SamplerMipmapMode::LINEAR,
73            vk::SamplerAddressMode::REPEAT,
74        )
75    };
76    let anisotropy_enable = texel_filter == vk::Filter::LINEAR;
77    let mut info = SamplerInfoBuilder::default()
78        .mag_filter(texel_filter)
79        .min_filter(texel_filter)
80        .mipmap_mode(mipmap_mode)
81        .address_mode_u(address_modes)
82        .address_mode_v(address_modes)
83        .address_mode_w(address_modes)
84        .max_lod(vk::LOD_CLAMP_NONE)
85        .anisotropy_enable(anisotropy_enable);
86
87    if anisotropy_enable {
88        info = info.max_anisotropy(16.0);
89    }
90
91    info.build()
92}
93
/// A descriptor set index paired with a binding index.
///
/// This names a binding point declared by a shader in general terms; it does not refer to any
/// descriptor which has actually been bound.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Descriptor {
    /// Descriptor set index
    pub set: u32,

    /// Descriptor binding index
    pub binding: u32,
}

impl From<u32> for Descriptor {
    /// Treats a bare binding index as a binding within descriptor set `0`.
    fn from(binding: u32) -> Self {
        Self::from((0u32, binding))
    }
}

impl From<(u32, u32)> for Descriptor {
    /// Treats the tuple as `(set, binding)`.
    fn from(pair: (u32, u32)) -> Self {
        let (set, binding) = pair;

        Descriptor { set, binding }
    }
}
118
119#[derive(Clone, Copy, Debug)]
120pub(crate) enum DescriptorInfo {
121    AccelerationStructure(u32),
122    CombinedImageSampler(u32, SamplerInfo, bool), //count, sampler, is-manually-defined?
123    InputAttachment(u32, u32),                    //count, input index,
124    SampledImage(u32),
125    Sampler(u32, SamplerInfo, bool), //count, sampler, is-manually-defined?
126    StorageBuffer(u32),
127    StorageImage(u32),
128    StorageTexelBuffer(u32),
129    UniformBuffer(u32),
130    UniformTexelBuffer(u32),
131}
132
133impl DescriptorInfo {
134    pub fn binding_count(self) -> u32 {
135        match self {
136            Self::AccelerationStructure(binding_count) => binding_count,
137            Self::CombinedImageSampler(binding_count, ..) => binding_count,
138            Self::InputAttachment(binding_count, _) => binding_count,
139            Self::SampledImage(binding_count) => binding_count,
140            Self::Sampler(binding_count, ..) => binding_count,
141            Self::StorageBuffer(binding_count) => binding_count,
142            Self::StorageImage(binding_count) => binding_count,
143            Self::StorageTexelBuffer(binding_count) => binding_count,
144            Self::UniformBuffer(binding_count) => binding_count,
145            Self::UniformTexelBuffer(binding_count) => binding_count,
146        }
147    }
148
149    pub fn descriptor_type(self) -> vk::DescriptorType {
150        match self {
151            Self::AccelerationStructure(_) => vk::DescriptorType::ACCELERATION_STRUCTURE_KHR,
152            Self::CombinedImageSampler(..) => vk::DescriptorType::COMBINED_IMAGE_SAMPLER,
153            Self::InputAttachment(..) => vk::DescriptorType::INPUT_ATTACHMENT,
154            Self::SampledImage(_) => vk::DescriptorType::SAMPLED_IMAGE,
155            Self::Sampler(..) => vk::DescriptorType::SAMPLER,
156            Self::StorageBuffer(_) => vk::DescriptorType::STORAGE_BUFFER,
157            Self::StorageImage(_) => vk::DescriptorType::STORAGE_IMAGE,
158            Self::StorageTexelBuffer(_) => vk::DescriptorType::STORAGE_TEXEL_BUFFER,
159            Self::UniformBuffer(_) => vk::DescriptorType::UNIFORM_BUFFER,
160            Self::UniformTexelBuffer(_) => vk::DescriptorType::UNIFORM_TEXEL_BUFFER,
161        }
162    }
163
164    fn sampler_info(self) -> Option<SamplerInfo> {
165        match self {
166            Self::CombinedImageSampler(_, sampler_info, _) | Self::Sampler(_, sampler_info, _) => {
167                Some(sampler_info)
168            }
169            _ => None,
170        }
171    }
172
173    pub fn set_binding_count(&mut self, binding_count: u32) {
174        *match self {
175            Self::AccelerationStructure(binding_count) => binding_count,
176            Self::CombinedImageSampler(binding_count, ..) => binding_count,
177            Self::InputAttachment(binding_count, _) => binding_count,
178            Self::SampledImage(binding_count) => binding_count,
179            Self::Sampler(binding_count, ..) => binding_count,
180            Self::StorageBuffer(binding_count) => binding_count,
181            Self::StorageImage(binding_count) => binding_count,
182            Self::StorageTexelBuffer(binding_count) => binding_count,
183            Self::UniformBuffer(binding_count) => binding_count,
184            Self::UniformTexelBuffer(binding_count) => binding_count,
185        } = binding_count;
186    }
187}
188
189#[derive(Debug)]
190pub(crate) struct PipelineDescriptorInfo {
191    pub layouts: BTreeMap<u32, DescriptorSetLayout>,
192    pub pool_sizes: HashMap<u32, HashMap<vk::DescriptorType, u32>>,
193
194    #[allow(dead_code)]
195    samplers: Box<[Sampler]>,
196}
197
198impl PipelineDescriptorInfo {
199    #[profiling::function]
200    pub fn create(
201        device: &Arc<Device>,
202        descriptor_bindings: &DescriptorBindingMap,
203    ) -> Result<Self, DriverError> {
204        let descriptor_set_count = descriptor_bindings
205            .keys()
206            .map(|descriptor| descriptor.set)
207            .max()
208            .map(|set| set + 1)
209            .unwrap_or_default();
210        let mut layouts = BTreeMap::new();
211        let mut pool_sizes = HashMap::new();
212
213        //trace!("descriptor_bindings: {:#?}", &descriptor_bindings);
214
215        let mut sampler_info_binding_count = HashMap::<_, u32>::with_capacity(
216            descriptor_bindings
217                .values()
218                .filter(|(descriptor_info, _)| descriptor_info.sampler_info().is_some())
219                .count(),
220        );
221
222        for (sampler_info, binding_count) in
223            descriptor_bindings
224                .values()
225                .filter_map(|(descriptor_info, _)| {
226                    descriptor_info
227                        .sampler_info()
228                        .map(|sampler_info| (sampler_info, descriptor_info.binding_count()))
229                })
230        {
231            sampler_info_binding_count
232                .entry(sampler_info)
233                .and_modify(|sampler_info_binding_count| {
234                    *sampler_info_binding_count = binding_count.max(*sampler_info_binding_count);
235                })
236                .or_insert(binding_count);
237        }
238
239        let mut samplers = sampler_info_binding_count
240            .keys()
241            .copied()
242            .map(|sampler_info| {
243                Sampler::create(device, sampler_info).map(|sampler| (sampler_info, sampler))
244            })
245            .collect::<Result<HashMap<_, _>, _>>()?;
246        let immutable_samplers = sampler_info_binding_count
247            .iter()
248            .map(|(sampler_info, &binding_count)| {
249                (
250                    *sampler_info,
251                    repeat(*samplers[sampler_info])
252                        .take(binding_count as _)
253                        .collect::<Box<_>>(),
254                )
255            })
256            .collect::<HashMap<_, _>>();
257
258        for descriptor_set_idx in 0..descriptor_set_count {
259            let mut binding_counts = HashMap::<vk::DescriptorType, u32>::new();
260            let mut bindings = vec![];
261
262            for (descriptor, (descriptor_info, stage_flags)) in descriptor_bindings
263                .iter()
264                .filter(|(descriptor, _)| descriptor.set == descriptor_set_idx)
265            {
266                let descriptor_ty = descriptor_info.descriptor_type();
267                *binding_counts.entry(descriptor_ty).or_default() +=
268                    descriptor_info.binding_count();
269                let mut binding = vk::DescriptorSetLayoutBinding::default()
270                    .binding(descriptor.binding)
271                    .descriptor_count(descriptor_info.binding_count())
272                    .descriptor_type(descriptor_ty)
273                    .stage_flags(*stage_flags);
274
275                if let Some(immutable_samplers) =
276                    descriptor_info.sampler_info().map(|sampler_info| {
277                        &immutable_samplers[&sampler_info]
278                            [0..descriptor_info.binding_count() as usize]
279                    })
280                {
281                    binding = binding.immutable_samplers(immutable_samplers);
282                }
283
284                bindings.push(binding);
285            }
286
287            let pool_size = pool_sizes
288                .entry(descriptor_set_idx)
289                .or_insert_with(HashMap::new);
290
291            for (descriptor_ty, binding_count) in binding_counts.into_iter() {
292                *pool_size.entry(descriptor_ty).or_default() += binding_count;
293            }
294
295            //trace!("bindings: {:#?}", &bindings);
296
297            let mut create_info = vk::DescriptorSetLayoutCreateInfo::default().bindings(&bindings);
298
299            // The bindless flags have to be created for every descriptor set layout binding.
300            // [vulkan spec](https://www.khronos.org/registry/vulkan/specs/1.3-extensions/man/html/VkDescriptorSetLayoutBindingFlagsCreateInfo.html)
301            // Maybe using one vector and updating it would be more efficient.
302            let bindless_flags = vec![vk::DescriptorBindingFlags::PARTIALLY_BOUND; bindings.len()];
303            let mut bindless_flags = if device
304                .physical_device
305                .features_v1_2
306                .descriptor_binding_partially_bound
307            {
308                let bindless_flags = vk::DescriptorSetLayoutBindingFlagsCreateInfo::default()
309                    .binding_flags(&bindless_flags);
310                Some(bindless_flags)
311            } else {
312                None
313            };
314
315            if let Some(bindless_flags) = bindless_flags.as_mut() {
316                create_info = create_info.push_next(bindless_flags);
317            }
318
319            layouts.insert(
320                descriptor_set_idx,
321                DescriptorSetLayout::create(device, &create_info)?,
322            );
323        }
324
325        let samplers = samplers
326            .drain()
327            .map(|(_, sampler)| sampler)
328            .collect::<Box<_>>();
329
330        //trace!("layouts {:#?}", &layouts);
331        // trace!("pool_sizes {:#?}", &pool_sizes);
332
333        Ok(Self {
334            layouts,
335            pool_sizes,
336            samplers,
337        })
338    }
339}
340
341pub(crate) struct Sampler {
342    device: Arc<Device>,
343    sampler: vk::Sampler,
344}
345
346impl Sampler {
347    #[profiling::function]
348    pub fn create(device: &Arc<Device>, info: impl Into<SamplerInfo>) -> Result<Self, DriverError> {
349        let device = Arc::clone(device);
350        let info = info.into();
351
352        let sampler = unsafe {
353            device
354                .create_sampler(
355                    &vk::SamplerCreateInfo::default()
356                        .flags(info.flags)
357                        .mag_filter(info.mag_filter)
358                        .min_filter(info.min_filter)
359                        .mipmap_mode(info.mipmap_mode)
360                        .address_mode_u(info.address_mode_u)
361                        .address_mode_v(info.address_mode_v)
362                        .address_mode_w(info.address_mode_w)
363                        .mip_lod_bias(info.mip_lod_bias.0)
364                        .anisotropy_enable(info.anisotropy_enable)
365                        .max_anisotropy(info.max_anisotropy.0)
366                        .compare_enable(info.compare_enable)
367                        .compare_op(info.compare_op)
368                        .min_lod(info.min_lod.0)
369                        .max_lod(info.max_lod.0)
370                        .border_color(info.border_color)
371                        .unnormalized_coordinates(info.unnormalized_coordinates)
372                        .push_next(
373                            &mut vk::SamplerReductionModeCreateInfo::default()
374                                .reduction_mode(info.reduction_mode),
375                        ),
376                    None,
377                )
378                .map_err(|err| {
379                    warn!("{err}");
380
381                    match err {
382                        vk::Result::ERROR_OUT_OF_HOST_MEMORY
383                        | vk::Result::ERROR_OUT_OF_DEVICE_MEMORY => DriverError::OutOfMemory,
384                        _ => DriverError::Unsupported,
385                    }
386                })?
387        };
388
389        Ok(Self { device, sampler })
390    }
391}
392
393impl Debug for Sampler {
394    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
395        write!(f, "{:?}", self.sampler)
396    }
397}
398
399impl Deref for Sampler {
400    type Target = vk::Sampler;
401
402    fn deref(&self) -> &Self::Target {
403        &self.sampler
404    }
405}
406
407impl Drop for Sampler {
408    #[profiling::function]
409    fn drop(&mut self) {
410        if panicking() {
411            return;
412        }
413
414        unsafe {
415            self.device.destroy_sampler(self.sampler, None);
416        }
417    }
418}
419
420/// Information used to create a [`vk::Sampler`] instance.
421#[derive(Builder, Clone, Copy, Debug, Eq, Hash, PartialEq)]
422#[builder(
423    build_fn(private, name = "fallible_build", error = "SamplerInfoBuilderError"),
424    derive(Clone, Copy, Debug),
425    pattern = "owned"
426)]
427#[non_exhaustive]
428pub struct SamplerInfo {
429    /// Bitmask specifying additional parameters of a sampler.
430    #[builder(default)]
431    pub flags: vk::SamplerCreateFlags,
432
433    /// Specify the magnification filter to apply to texture lookups.
434    ///
435    /// The default value is [`vk::Filter::NEAREST`]
436    #[builder(default)]
437    pub mag_filter: vk::Filter,
438
439    /// Specify the minification filter to apply to texture lookups.
440    ///
441    /// The default value is [`vk::Filter::NEAREST`]
442    #[builder(default)]
443    pub min_filter: vk::Filter,
444
445    /// A value specifying the mipmap filter to apply to lookups.
446    ///
447    /// The default value is [`vk::SamplerMipmapMode::NEAREST`]
448    #[builder(default)]
449    pub mipmap_mode: vk::SamplerMipmapMode,
450
451    /// A value specifying the addressing mode for U coordinates outside `[0, 1)`.
452    ///
453    /// The default value is [`vk::SamplerAddressMode::REPEAT`]
454    #[builder(default)]
455    pub address_mode_u: vk::SamplerAddressMode,
456
457    /// A value specifying the addressing mode for V coordinates outside `[0, 1)`.
458    ///
459    /// The default value is [`vk::SamplerAddressMode::REPEAT`]
460    #[builder(default)]
461    pub address_mode_v: vk::SamplerAddressMode,
462
463    /// A value specifying the addressing mode for W coordinates outside `[0, 1)`.
464    ///
465    /// The default value is [`vk::SamplerAddressMode::REPEAT`]
466    #[builder(default)]
467    pub address_mode_w: vk::SamplerAddressMode,
468
469    /// The bias to be added to mipmap LOD calculation and bias provided by image sampling functions
470    /// in SPIR-V, as described in the
471    /// [LOD Operation](https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#textures-level-of-detail-operation)
472    /// section.
473    #[builder(default, setter(into))]
474    pub mip_lod_bias: OrderedFloat<f32>,
475
476    /// Enables anisotropic filtering, as described in the
477    /// [Texel Anisotropic Filtering](https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#textures-texel-anisotropic-filtering)
478    /// section
479    #[builder(default)]
480    pub anisotropy_enable: bool,
481
482    /// The anisotropy value clamp used by the sampler when `anisotropy_enable` is `true`.
483    ///
484    /// If `anisotropy_enable` is `false`, max_anisotropy is ignored.
485    #[builder(default, setter(into))]
486    pub max_anisotropy: OrderedFloat<f32>,
487
488    /// Enables comparison against a reference value during lookups.
489    #[builder(default)]
490    pub compare_enable: bool,
491
492    /// Specifies the comparison operator to apply to fetched data before filtering as described in
493    /// the
494    /// [Depth Compare Operation](https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#textures-depth-compare-operation)
495    /// section.
496    #[builder(default)]
497    pub compare_op: vk::CompareOp,
498
499    /// Used to clamp the
500    /// [minimum of the computed LOD value](https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#textures-level-of-detail-operation).
501    #[builder(default, setter(into))]
502    pub min_lod: OrderedFloat<f32>,
503
504    /// Used to clamp the
505    /// [maximum of the computed LOD value](https://registry.khronos.org/vulkan/specs/1.3-extensions/html/vkspec.html#textures-level-of-detail-operation).
506    ///
507    /// To avoid clamping the maximum value, set maxLod to the constant `vk::LOD_CLAMP_NONE`.
508    #[builder(default, setter(into))]
509    pub max_lod: OrderedFloat<f32>,
510
511    /// Secifies the predefined border color to use.
512    ///
513    /// The default value is [`vk::BorderColor::FLOAT_TRANSPARENT_BLACK`]
514    #[builder(default)]
515    pub border_color: vk::BorderColor,
516
517    /// Controls whether to use unnormalized or normalized texel coordinates to address texels of
518    /// the image.
519    ///
520    /// When set to `true`, the range of the image coordinates used to lookup the texel is in the
521    /// range of zero to the image size in each dimension.
522    ///
523    /// When set to `false` the range of image coordinates is zero to one.
524    ///
525    /// See
526    /// [requirements](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkSamplerCreateInfo.html).
527    #[builder(default)]
528    pub unnormalized_coordinates: bool,
529
530    /// Specifies sampler reduction mode.
531    ///
532    /// Setting magnification filter ([`mag_filter`](Self::mag_filter)) to [`vk::Filter::NEAREST`]
533    /// disables sampler reduction mode.
534    ///
535    /// The default value is [`vk::SamplerReductionMode::WEIGHTED_AVERAGE`]
536    ///
537    /// See
538    /// [requirements](https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkSamplerCreateInfo.html).
539    #[builder(default)]
540    pub reduction_mode: vk::SamplerReductionMode,
541}
542
543impl SamplerInfo {
544    /// Default sampler information with `mag_filter`, `min_filter` and `mipmap_mode` set to linear.
545    pub const LINEAR: SamplerInfoBuilder = SamplerInfoBuilder {
546        flags: None,
547        mag_filter: Some(vk::Filter::LINEAR),
548        min_filter: Some(vk::Filter::LINEAR),
549        mipmap_mode: Some(vk::SamplerMipmapMode::LINEAR),
550        address_mode_u: None,
551        address_mode_v: None,
552        address_mode_w: None,
553        mip_lod_bias: None,
554        anisotropy_enable: None,
555        max_anisotropy: None,
556        compare_enable: None,
557        compare_op: None,
558        min_lod: None,
559        max_lod: None,
560        border_color: None,
561        unnormalized_coordinates: None,
562        reduction_mode: None,
563    };
564
565    /// Default sampler information with `mag_filter`, `min_filter` and `mipmap_mode` set to
566    /// nearest.
567    pub const NEAREST: SamplerInfoBuilder = SamplerInfoBuilder {
568        flags: None,
569        mag_filter: Some(vk::Filter::NEAREST),
570        min_filter: Some(vk::Filter::NEAREST),
571        mipmap_mode: Some(vk::SamplerMipmapMode::NEAREST),
572        address_mode_u: None,
573        address_mode_v: None,
574        address_mode_w: None,
575        mip_lod_bias: None,
576        anisotropy_enable: None,
577        max_anisotropy: None,
578        compare_enable: None,
579        compare_op: None,
580        min_lod: None,
581        max_lod: None,
582        border_color: None,
583        unnormalized_coordinates: None,
584        reduction_mode: None,
585    };
586
587    /// Creates a default `SamplerInfoBuilder`.
588    #[allow(clippy::new_ret_no_self)]
589    #[deprecated = "Use SamplerInfo::default()"]
590    #[doc(hidden)]
591    pub fn new() -> SamplerInfoBuilder {
592        Self::default().to_builder()
593    }
594
595    /// Converts a `SamplerInfo` into a `SamplerInfoBuilder`.
596    #[inline(always)]
597    pub fn to_builder(self) -> SamplerInfoBuilder {
598        SamplerInfoBuilder {
599            flags: Some(self.flags),
600            mag_filter: Some(self.mag_filter),
601            min_filter: Some(self.min_filter),
602            mipmap_mode: Some(self.mipmap_mode),
603            address_mode_u: Some(self.address_mode_u),
604            address_mode_v: Some(self.address_mode_v),
605            address_mode_w: Some(self.address_mode_w),
606            mip_lod_bias: Some(self.mip_lod_bias),
607            anisotropy_enable: Some(self.anisotropy_enable),
608            max_anisotropy: Some(self.max_anisotropy),
609            compare_enable: Some(self.compare_enable),
610            compare_op: Some(self.compare_op),
611            min_lod: Some(self.min_lod),
612            max_lod: Some(self.max_lod),
613            border_color: Some(self.border_color),
614            unnormalized_coordinates: Some(self.unnormalized_coordinates),
615            reduction_mode: Some(self.reduction_mode),
616        }
617    }
618}
619
620impl Default for SamplerInfo {
621    fn default() -> Self {
622        Self {
623            flags: vk::SamplerCreateFlags::empty(),
624            mag_filter: vk::Filter::NEAREST,
625            min_filter: vk::Filter::NEAREST,
626            mipmap_mode: vk::SamplerMipmapMode::NEAREST,
627            address_mode_u: vk::SamplerAddressMode::REPEAT,
628            address_mode_v: vk::SamplerAddressMode::REPEAT,
629            address_mode_w: vk::SamplerAddressMode::REPEAT,
630            mip_lod_bias: OrderedFloat(0.0),
631            anisotropy_enable: false,
632            max_anisotropy: OrderedFloat(0.0),
633            compare_enable: false,
634            compare_op: vk::CompareOp::NEVER,
635            min_lod: OrderedFloat(0.0),
636            max_lod: OrderedFloat(0.0),
637            border_color: vk::BorderColor::FLOAT_TRANSPARENT_BLACK,
638            unnormalized_coordinates: false,
639            reduction_mode: vk::SamplerReductionMode::WEIGHTED_AVERAGE,
640        }
641    }
642}
643
644impl SamplerInfoBuilder {
645    /// Builds a new `SamplerInfo`.
646    #[inline(always)]
647    pub fn build(self) -> SamplerInfo {
648        let res = self.fallible_build();
649
650        #[cfg(test)]
651        let res = res.unwrap();
652
653        #[cfg(not(test))]
654        let res = unsafe { res.unwrap_unchecked() };
655
656        res
657    }
658}
659
660impl From<SamplerInfoBuilder> for SamplerInfo {
661    fn from(info: SamplerInfoBuilder) -> Self {
662        info.build()
663    }
664}
665
666#[derive(Debug)]
667struct SamplerInfoBuilderError;
668
669impl From<UninitializedFieldError> for SamplerInfoBuilderError {
670    fn from(_: UninitializedFieldError) -> Self {
671        Self
672    }
673}
674
/// Describes a shader program which runs on some pipeline stage.
///
/// Create one with [`Shader::new`] or a stage-specific constructor such as
/// [`Shader::new_vertex`], which return a [`ShaderBuilder`].
// NOTE(review): presumably allowed so the items generated by `derive(Builder)` need no docs —
// confirm.
#[allow(missing_docs)]
#[derive(Builder, Clone)]
#[builder(
    build_fn(private, name = "fallible_build", error = "ShaderBuilderError"),
    derive(Clone, Debug),
    pattern = "owned"
)]
pub struct Shader {
    /// The name of the entry point which will be executed by this shader.
    ///
    /// The default value is `main`.
    #[builder(default = "\"main\".to_owned()")]
    pub entry_name: String,

    /// Data about Vulkan specialization constants.
    ///
    /// # Examples
    ///
    /// Basic usage (GLSL):
    ///
    /// ```
    /// # inline_spirv::inline_spirv!(r#"
    /// #version 460 core
    ///
    /// // Defaults to 6 if not set using Shader specialization_info!
    /// layout(constant_id = 0) const uint MY_COUNT = 6;
    ///
    /// layout(set = 0, binding = 0) uniform sampler2D my_samplers[MY_COUNT];
    ///
    /// void main()
    /// {
    ///     // Code uses MY_COUNT number of my_samplers here
    /// }
    /// # "#, comp);
    /// ```
    ///
    /// ```no_run
    /// # use std::sync::Arc;
    /// # use ash::vk;
    /// # use screen_13::driver::DriverError;
    /// # use screen_13::driver::device::{Device, DeviceInfo};
    /// # use screen_13::driver::shader::{Shader, SpecializationInfo};
    /// # fn main() -> Result<(), DriverError> {
    /// # let device = Arc::new(Device::create_headless(DeviceInfo::default())?);
    /// # let my_shader_code = [0u8; 1];
    /// // We instead specify 42 for MY_COUNT:
    /// let shader = Shader::new_fragment(my_shader_code.as_slice())
    ///     .specialization_info(SpecializationInfo::new(
    ///         [vk::SpecializationMapEntry {
    ///             constant_id: 0,
    ///             offset: 0,
    ///             size: 4,
    ///         }],
    ///         42u32.to_ne_bytes()
    ///     ));
    /// # Ok(()) }
    /// ```
    #[builder(default, setter(strip_option))]
    pub specialization_info: Option<SpecializationInfo>,

    /// Shader code.
    ///
    /// Although SPIR-V code is specified as `u32` values, this field uses `u8` in order to make
    /// loading from file simpler. You should always have a SPIR-V code length which is a multiple
    /// of four bytes, or an error will be returned during pipeline creation.
    pub spirv: Vec<u8>,

    /// The shader stage this structure applies to.
    pub stage: vk::ShaderStageFlags,

    // Entry point data from `spirq`; its `vars` are inspected by methods such as `attachments`.
    // The builder setter is private and there is no default, so it must be set internally.
    #[builder(private)]
    entry_point: EntryPoint,

    // Sampler configuration per descriptor binding point; only settable internally.
    // NOTE(review): presumably overrides for image bindings' immutable samplers — confirm how
    // it is populated.
    #[builder(default, private)]
    image_samplers: HashMap<Descriptor, SamplerInfo>,

    // Optional vertex input state; only settable through the private builder setter.
    #[builder(default, private, setter(strip_option))]
    vertex_input_state: Option<VertexInputState>,
}
755
756impl Shader {
757    /// Specifies a shader with the given `stage` and shader code values.
758    #[allow(clippy::new_ret_no_self)]
759    pub fn new(stage: vk::ShaderStageFlags, spirv: impl ShaderCode) -> ShaderBuilder {
760        ShaderBuilder::default()
761            .spirv(spirv.into_vec())
762            .stage(stage)
763    }
764
765    /// Creates a new ray trace shader.
766    ///
767    /// # Panics
768    ///
769    /// If the shader code is invalid or not a multiple of four bytes in length.
770    pub fn new_any_hit(spirv: impl ShaderCode) -> ShaderBuilder {
771        Self::new(vk::ShaderStageFlags::ANY_HIT_KHR, spirv)
772    }
773
774    /// Creates a new ray trace shader.
775    ///
776    /// # Panics
777    ///
778    /// If the shader code is invalid or not a multiple of four bytes in length.
779    pub fn new_callable(spirv: impl ShaderCode) -> ShaderBuilder {
780        Self::new(vk::ShaderStageFlags::CALLABLE_KHR, spirv)
781    }
782
783    /// Creates a new ray trace shader.
784    ///
785    /// # Panics
786    ///
787    /// If the shader code is invalid or not a multiple of four bytes in length.
788    pub fn new_closest_hit(spirv: impl ShaderCode) -> ShaderBuilder {
789        Self::new(vk::ShaderStageFlags::CLOSEST_HIT_KHR, spirv)
790    }
791
792    /// Creates a new compute shader.
793    ///
794    /// # Panics
795    ///
796    /// If the shader code is invalid or not a multiple of four bytes in length.
797    pub fn new_compute(spirv: impl ShaderCode) -> ShaderBuilder {
798        Self::new(vk::ShaderStageFlags::COMPUTE, spirv)
799    }
800
801    /// Creates a new fragment shader.
802    ///
803    /// # Panics
804    ///
805    /// If the shader code is invalid or not a multiple of four bytes in length.
806    pub fn new_fragment(spirv: impl ShaderCode) -> ShaderBuilder {
807        Self::new(vk::ShaderStageFlags::FRAGMENT, spirv)
808    }
809
810    /// Creates a new geometry shader.
811    ///
812    /// # Panics
813    ///
814    /// If the shader code is invalid or not a multiple of four bytes in length.
815    pub fn new_geometry(spirv: impl ShaderCode) -> ShaderBuilder {
816        Self::new(vk::ShaderStageFlags::GEOMETRY, spirv)
817    }
818
819    /// Creates a new ray trace shader.
820    ///
821    /// # Panics
822    ///
823    /// If the shader code is invalid or not a multiple of four bytes in length.
824    pub fn new_intersection(spirv: impl ShaderCode) -> ShaderBuilder {
825        Self::new(vk::ShaderStageFlags::INTERSECTION_KHR, spirv)
826    }
827
828    /// Creates a new mesh shader.
829    ///
830    /// # Panics
831    ///
832    /// If the shader code is invalid.
833    pub fn new_mesh(spirv: impl ShaderCode) -> ShaderBuilder {
834        Self::new(vk::ShaderStageFlags::MESH_EXT, spirv)
835    }
836
837    /// Creates a new ray trace shader.
838    ///
839    /// # Panics
840    ///
841    /// If the shader code is invalid or not a multiple of four bytes in length.
842    pub fn new_miss(spirv: impl ShaderCode) -> ShaderBuilder {
843        Self::new(vk::ShaderStageFlags::MISS_KHR, spirv)
844    }
845
846    /// Creates a new ray trace shader.
847    ///
848    /// # Panics
849    ///
850    /// If the shader code is invalid or not a multiple of four bytes in length.
851    pub fn new_ray_gen(spirv: impl ShaderCode) -> ShaderBuilder {
852        Self::new(vk::ShaderStageFlags::RAYGEN_KHR, spirv)
853    }
854
855    /// Creates a new mesh task shader.
856    ///
857    /// # Panics
858    ///
859    /// If the shader code is invalid.
860    pub fn new_task(spirv: impl ShaderCode) -> ShaderBuilder {
861        Self::new(vk::ShaderStageFlags::TASK_EXT, spirv)
862    }
863
864    /// Creates a new tesselation control shader.
865    ///
866    /// # Panics
867    ///
868    /// If the shader code is invalid or not a multiple of four bytes in length.
869    pub fn new_tesselation_ctrl(spirv: impl ShaderCode) -> ShaderBuilder {
870        Self::new(vk::ShaderStageFlags::TESSELLATION_CONTROL, spirv)
871    }
872
873    /// Creates a new tesselation evaluation shader.
874    ///
875    /// # Panics
876    ///
877    /// If the shader code is invalid or not a multiple of four bytes in length.
878    pub fn new_tesselation_eval(spirv: impl ShaderCode) -> ShaderBuilder {
879        Self::new(vk::ShaderStageFlags::TESSELLATION_EVALUATION, spirv)
880    }
881
882    /// Creates a new vertex shader.
883    ///
884    /// # Panics
885    ///
886    /// If the shader code is invalid or not a multiple of four bytes in length.
887    pub fn new_vertex(spirv: impl ShaderCode) -> ShaderBuilder {
888        Self::new(vk::ShaderStageFlags::VERTEX, spirv)
889    }
890
891    /// Returns the input and write attachments of a shader.
892    #[profiling::function]
893    pub(super) fn attachments(
894        &self,
895    ) -> (
896        impl Iterator<Item = u32> + '_,
897        impl Iterator<Item = u32> + '_,
898    ) {
899        (
900            self.entry_point.vars.iter().filter_map(|var| match var {
901                Variable::Descriptor {
902                    desc_ty: DescriptorType::InputAttachment(attachment),
903                    ..
904                } => Some(*attachment),
905                _ => None,
906            }),
907            self.entry_point.vars.iter().filter_map(|var| match var {
908                Variable::Output { location, .. } => Some(location.loc()),
909                _ => None,
910            }),
911        )
912    }
913
914    #[profiling::function]
915    pub(super) fn descriptor_bindings(&self) -> DescriptorBindingMap {
916        let mut res = DescriptorBindingMap::default();
917
918        for (name, descriptor, desc_ty, binding_count) in
919            self.entry_point.vars.iter().filter_map(|var| match var {
920                Variable::Descriptor {
921                    name,
922                    desc_bind,
923                    desc_ty,
924                    nbind,
925                    ..
926                } => Some((
927                    name,
928                    Descriptor {
929                        set: desc_bind.set(),
930                        binding: desc_bind.bind(),
931                    },
932                    desc_ty,
933                    *nbind,
934                )),
935                _ => None,
936            })
937        {
938            trace!(
939                "descriptor {}: {}.{} = {:?}[{}]",
940                name.as_deref().unwrap_or_default(),
941                descriptor.set,
942                descriptor.binding,
943                *desc_ty,
944                binding_count
945            );
946
947            let descriptor_info = match desc_ty {
948                DescriptorType::AccelStruct() => {
949                    DescriptorInfo::AccelerationStructure(binding_count)
950                }
951                DescriptorType::CombinedImageSampler() => {
952                    let (sampler_info, is_manually_defined) =
953                        self.image_sampler(descriptor, name.as_deref().unwrap_or_default());
954
955                    DescriptorInfo::CombinedImageSampler(
956                        binding_count,
957                        sampler_info,
958                        is_manually_defined,
959                    )
960                }
961                DescriptorType::InputAttachment(attachment) => {
962                    DescriptorInfo::InputAttachment(binding_count, *attachment)
963                }
964                DescriptorType::SampledImage() => DescriptorInfo::SampledImage(binding_count),
965                DescriptorType::Sampler() => {
966                    let (sampler_info, is_manually_defined) =
967                        self.image_sampler(descriptor, name.as_deref().unwrap_or_default());
968
969                    DescriptorInfo::Sampler(binding_count, sampler_info, is_manually_defined)
970                }
971                DescriptorType::StorageBuffer(_access_ty) => {
972                    DescriptorInfo::StorageBuffer(binding_count)
973                }
974                DescriptorType::StorageImage(_access_ty) => {
975                    DescriptorInfo::StorageImage(binding_count)
976                }
977                DescriptorType::StorageTexelBuffer(_access_ty) => {
978                    DescriptorInfo::StorageTexelBuffer(binding_count)
979                }
980                DescriptorType::UniformBuffer() => DescriptorInfo::UniformBuffer(binding_count),
981                DescriptorType::UniformTexelBuffer() => {
982                    DescriptorInfo::UniformTexelBuffer(binding_count)
983                }
984            };
985            res.insert(descriptor, (descriptor_info, self.stage));
986        }
987
988        res
989    }
990
991    fn image_sampler(&self, descriptor: Descriptor, name: &str) -> (SamplerInfo, bool) {
992        self.image_samplers
993            .get(&descriptor)
994            .copied()
995            .map(|sampler_info| (sampler_info, true))
996            .unwrap_or_else(|| (guess_immutable_sampler(name), false))
997    }
998
999    #[profiling::function]
1000    pub(super) fn merge_descriptor_bindings(
1001        descriptor_bindings: impl IntoIterator<Item = DescriptorBindingMap>,
1002    ) -> DescriptorBindingMap {
1003        fn merge_info(lhs: &mut DescriptorInfo, rhs: DescriptorInfo) -> bool {
1004            let (lhs_count, rhs_count) = match lhs {
1005                DescriptorInfo::AccelerationStructure(lhs) => {
1006                    if let DescriptorInfo::AccelerationStructure(rhs) = rhs {
1007                        (lhs, rhs)
1008                    } else {
1009                        return false;
1010                    }
1011                }
1012                DescriptorInfo::CombinedImageSampler(lhs, lhs_sampler, lhs_is_manually_defined) => {
1013                    if let DescriptorInfo::CombinedImageSampler(
1014                        rhs,
1015                        rhs_sampler,
1016                        rhs_is_manually_defined,
1017                    ) = rhs
1018                    {
1019                        // Allow one of the samplers to be manually defined (only one!)
1020                        if *lhs_is_manually_defined && rhs_is_manually_defined {
1021                            return false;
1022                        } else if rhs_is_manually_defined {
1023                            *lhs_sampler = rhs_sampler;
1024                        }
1025
1026                        (lhs, rhs)
1027                    } else {
1028                        return false;
1029                    }
1030                }
1031                DescriptorInfo::InputAttachment(lhs, lhs_idx) => {
1032                    if let DescriptorInfo::InputAttachment(rhs, rhs_idx) = rhs {
1033                        if *lhs_idx != rhs_idx {
1034                            return false;
1035                        }
1036
1037                        (lhs, rhs)
1038                    } else {
1039                        return false;
1040                    }
1041                }
1042                DescriptorInfo::SampledImage(lhs) => {
1043                    if let DescriptorInfo::SampledImage(rhs) = rhs {
1044                        (lhs, rhs)
1045                    } else {
1046                        return false;
1047                    }
1048                }
1049                DescriptorInfo::Sampler(lhs, lhs_sampler, lhs_is_manually_defined) => {
1050                    if let DescriptorInfo::Sampler(rhs, rhs_sampler, rhs_is_manually_defined) = rhs
1051                    {
1052                        // Allow one of the samplers to be manually defined (only one!)
1053                        if *lhs_is_manually_defined && rhs_is_manually_defined {
1054                            return false;
1055                        } else if rhs_is_manually_defined {
1056                            *lhs_sampler = rhs_sampler;
1057                        }
1058
1059                        (lhs, rhs)
1060                    } else {
1061                        return false;
1062                    }
1063                }
1064                DescriptorInfo::StorageBuffer(lhs) => {
1065                    if let DescriptorInfo::StorageBuffer(rhs) = rhs {
1066                        (lhs, rhs)
1067                    } else {
1068                        return false;
1069                    }
1070                }
1071                DescriptorInfo::StorageImage(lhs) => {
1072                    if let DescriptorInfo::StorageImage(rhs) = rhs {
1073                        (lhs, rhs)
1074                    } else {
1075                        return false;
1076                    }
1077                }
1078                DescriptorInfo::StorageTexelBuffer(lhs) => {
1079                    if let DescriptorInfo::StorageTexelBuffer(rhs) = rhs {
1080                        (lhs, rhs)
1081                    } else {
1082                        return false;
1083                    }
1084                }
1085                DescriptorInfo::UniformBuffer(lhs) => {
1086                    if let DescriptorInfo::UniformBuffer(rhs) = rhs {
1087                        (lhs, rhs)
1088                    } else {
1089                        return false;
1090                    }
1091                }
1092                DescriptorInfo::UniformTexelBuffer(lhs) => {
1093                    if let DescriptorInfo::UniformTexelBuffer(rhs) = rhs {
1094                        (lhs, rhs)
1095                    } else {
1096                        return false;
1097                    }
1098                }
1099            };
1100
1101            *lhs_count = rhs_count.max(*lhs_count);
1102
1103            true
1104        }
1105
1106        #[profiling::function]
1107        fn merge_pair(src: DescriptorBindingMap, dst: &mut DescriptorBindingMap) {
1108            for (descriptor_binding, (descriptor_info, descriptor_flags)) in src.into_iter() {
1109                if let Some((existing_info, existing_flags)) = dst.get_mut(&descriptor_binding) {
1110                    if !merge_info(existing_info, descriptor_info) {
1111                        panic!("Inconsistent shader descriptors ({descriptor_binding:?})");
1112                    }
1113
1114                    *existing_flags |= descriptor_flags;
1115                } else {
1116                    dst.insert(descriptor_binding, (descriptor_info, descriptor_flags));
1117                }
1118            }
1119        }
1120
1121        let mut descriptor_bindings = descriptor_bindings.into_iter();
1122        let mut res = descriptor_bindings.next().unwrap_or_default();
1123        for descriptor_binding in descriptor_bindings {
1124            merge_pair(descriptor_binding, &mut res);
1125        }
1126
1127        res
1128    }
1129
1130    #[profiling::function]
1131    pub(super) fn push_constant_range(&self) -> Option<vk::PushConstantRange> {
1132        self.entry_point
1133            .vars
1134            .iter()
1135            .filter_map(|var| match var {
1136                Variable::PushConstant {
1137                    ty: Type::Struct(ty),
1138                    ..
1139                } => Some(ty.members.clone()),
1140                _ => None,
1141            })
1142            .flatten()
1143            .map(|push_const| {
1144                let offset = push_const.offset.unwrap_or_default();
1145                let size = push_const
1146                    .ty
1147                    .nbyte()
1148                    .unwrap_or_default()
1149                    .next_multiple_of(4);
1150                offset..offset + size
1151            })
1152            .reduce(|a, b| a.start.min(b.start)..a.end.max(b.end))
1153            .map(|push_const| vk::PushConstantRange {
1154                stage_flags: self.stage,
1155                size: (push_const.end - push_const.start) as _,
1156                offset: push_const.start as _,
1157            })
1158    }
1159
1160    #[profiling::function]
1161    fn reflect_entry_point(
1162        entry_name: &str,
1163        spirv: &[u8],
1164        specialization_info: Option<&SpecializationInfo>,
1165    ) -> Result<EntryPoint, DriverError> {
1166        let mut config = ReflectConfig::new();
1167        config.ref_all_rscs(true).spv(spirv);
1168
1169        if let Some(spec_info) = specialization_info {
1170            for spec in &spec_info.map_entries {
1171                config.specialize(
1172                    spec.constant_id,
1173                    spec_info.data[spec.offset as usize..spec.offset as usize + spec.size].into(),
1174                );
1175            }
1176        }
1177
1178        let entry_points = config.reflect().map_err(|err| {
1179            error!("Unable to reflect spirv: {err}");
1180
1181            DriverError::InvalidData
1182        })?;
1183        let entry_point = entry_points
1184            .into_iter()
1185            .find(|entry_point| entry_point.name == entry_name)
1186            .ok_or_else(|| {
1187                error!("Entry point not found");
1188
1189                DriverError::InvalidData
1190            })?;
1191
1192        Ok(entry_point)
1193    }
1194
1195    #[profiling::function]
1196    pub(super) fn vertex_input(&self) -> VertexInputState {
1197        // Check for manually-specified vertex layout descriptions
1198        if let Some(vertex_input) = &self.vertex_input_state {
1199            return vertex_input.clone();
1200        }
1201
1202        fn scalar_format(ty: &ScalarType, byte_len: u32) -> vk::Format {
1203            match ty {
1204                ScalarType::Float { .. } => match byte_len {
1205                    4 => vk::Format::R32_SFLOAT,
1206                    8 => vk::Format::R32G32_SFLOAT,
1207                    12 => vk::Format::R32G32B32_SFLOAT,
1208                    16 => vk::Format::R32G32B32A32_SFLOAT,
1209                    _ => unimplemented!("byte_len {byte_len}"),
1210                },
1211                ScalarType::Integer {
1212                    is_signed: true, ..
1213                } => match byte_len {
1214                    4 => vk::Format::R32_SINT,
1215                    8 => vk::Format::R32G32_SINT,
1216                    12 => vk::Format::R32G32B32_SINT,
1217                    16 => vk::Format::R32G32B32A32_SINT,
1218                    _ => unimplemented!("byte_len {byte_len}"),
1219                },
1220                ScalarType::Integer {
1221                    is_signed: false, ..
1222                } => match byte_len {
1223                    4 => vk::Format::R32_UINT,
1224                    8 => vk::Format::R32G32_UINT,
1225                    12 => vk::Format::R32G32B32_UINT,
1226                    16 => vk::Format::R32G32B32A32_UINT,
1227                    _ => unimplemented!("byte_len {byte_len}"),
1228                },
1229                _ => unimplemented!("{:?}", ty),
1230            }
1231        }
1232
1233        let mut input_rates_strides = HashMap::new();
1234        let mut vertex_attribute_descriptions = vec![];
1235
1236        for (name, location, ty) in self.entry_point.vars.iter().filter_map(|var| match var {
1237            Variable::Input { name, location, ty } => Some((name, location, ty)),
1238            _ => None,
1239        }) {
1240            let (binding, guessed_rate) = name
1241                .as_ref()
1242                .filter(|name| name.contains("_ibind") || name.contains("_vbind"))
1243                .map(|name| {
1244                    let binding = name[name.rfind("bind").unwrap()..]
1245                        .parse()
1246                        .unwrap_or_default();
1247                    let rate = if name.contains("_ibind") {
1248                        vk::VertexInputRate::INSTANCE
1249                    } else {
1250                        vk::VertexInputRate::VERTEX
1251                    };
1252
1253                    (binding, rate)
1254                })
1255                .unwrap_or_default();
1256            let (location, _) = location.into_inner();
1257            if let Some((input_rate, _)) = input_rates_strides.get(&binding) {
1258                assert_eq!(*input_rate, guessed_rate);
1259            }
1260
1261            let byte_stride = ty.nbyte().unwrap_or_default() as u32;
1262            let (input_rate, stride) = input_rates_strides.entry(binding).or_default();
1263            *input_rate = guessed_rate;
1264            *stride += byte_stride;
1265
1266            //trace!("{location} {:?} is {byte_stride} bytes", name);
1267
1268            vertex_attribute_descriptions.push(vk::VertexInputAttributeDescription {
1269                location,
1270                binding,
1271                format: match ty {
1272                    Type::Scalar(ty) => scalar_format(ty, ty.nbyte().unwrap_or_default() as _),
1273                    Type::Vector(ty) => scalar_format(&ty.scalar_ty, byte_stride),
1274                    _ => unimplemented!("{:?}", ty),
1275                },
1276                offset: byte_stride, // Figured out below - this data is iter'd in an unknown order
1277            });
1278        }
1279
1280        vertex_attribute_descriptions.sort_unstable_by(|lhs, rhs| {
1281            let binding = lhs.binding.cmp(&rhs.binding);
1282            if binding.is_lt() {
1283                return binding;
1284            }
1285
1286            lhs.location.cmp(&rhs.location)
1287        });
1288
1289        let mut offset = 0;
1290        let mut offset_binding = 0;
1291
1292        for vertex_attribute_description in &mut vertex_attribute_descriptions {
1293            if vertex_attribute_description.binding != offset_binding {
1294                offset_binding = vertex_attribute_description.binding;
1295                offset = 0;
1296            }
1297
1298            let stride = vertex_attribute_description.offset;
1299            vertex_attribute_description.offset = offset;
1300            offset += stride;
1301
1302            debug!(
1303                "vertex attribute {}.{}: {:?} (offset={})",
1304                vertex_attribute_description.binding,
1305                vertex_attribute_description.location,
1306                vertex_attribute_description.format,
1307                vertex_attribute_description.offset,
1308            );
1309        }
1310
1311        let mut vertex_binding_descriptions = vec![];
1312        for (binding, (input_rate, stride)) in input_rates_strides.into_iter() {
1313            vertex_binding_descriptions.push(vk::VertexInputBindingDescription {
1314                binding,
1315                input_rate,
1316                stride,
1317            });
1318        }
1319
1320        VertexInputState {
1321            vertex_attribute_descriptions,
1322            vertex_binding_descriptions,
1323        }
1324    }
1325}
1326
1327impl Debug for Shader {
1328    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1329        // We don't want the default formatter bc vec u8
1330        // TODO: Better output message
1331        f.write_str("Shader")
1332    }
1333}
1334
1335impl From<ShaderBuilder> for Shader {
1336    fn from(shader: ShaderBuilder) -> Self {
1337        shader.build()
1338    }
1339}
1340
1341// HACK: https://github.com/colin-kiegel/rust-derive-builder/issues/56
1342impl ShaderBuilder {
1343    /// Specifies a shader with the given `stage` and shader code values.
1344    pub fn new(stage: vk::ShaderStageFlags, spirv: Vec<u8>) -> Self {
1345        Self::default().stage(stage).spirv(spirv)
1346    }
1347
1348    /// Builds a new `Shader`.
1349    pub fn build(mut self) -> Shader {
1350        let entry_name = self.entry_name.as_deref().unwrap_or("main");
1351        self.entry_point = Some(
1352            Shader::reflect_entry_point(
1353                entry_name,
1354                self.spirv.as_deref().unwrap(),
1355                self.specialization_info
1356                    .as_ref()
1357                    .map(|opt| opt.as_ref())
1358                    .unwrap_or_default(),
1359            )
1360            .unwrap_or_else(|_| panic!("invalid shader code for entry name \'{entry_name}\'")),
1361        );
1362
1363        self.fallible_build()
1364            .expect("All required fields set at initialization")
1365    }
1366
1367    /// Specifies a manually-defined image sampler.
1368    ///
1369    /// Sampled images, by default, use reflection to automatically assign image samplers. Each
1370    /// sampled image may use a suffix such as `_llr` or `_nne` for common linear/linear repeat or
1371    /// nearest/nearest clamp-to-edge samplers, respectively.
1372    ///
1373    /// See the [main documentation] for more information about automatic image samplers.
1374    ///
1375    /// Descriptor bindings may be specified as `(1, 2)` for descriptor set index `1` and binding
1376    /// index `2`, or if the descriptor set index is `0` simply specify `2` for the same case.
1377    ///
1378    /// _NOTE:_ When defining image samplers which are used in multiple stages of a single pipeline
1379    /// you must only call this function on one of the shader stages, it does not matter which one.
1380    ///
1381    /// # Panics
1382    ///
1383    /// Panics if two shader stages of the same pipeline define individual calls to `image_sampler`.
1384    ///
1385    /// [main documentation]: crate
1386    #[profiling::function]
1387    pub fn image_sampler(
1388        mut self,
1389        descriptor: impl Into<Descriptor>,
1390        info: impl Into<SamplerInfo>,
1391    ) -> Self {
1392        let descriptor = descriptor.into();
1393        let info = info.into();
1394
1395        if self.image_samplers.is_none() {
1396            self.image_samplers = Some(Default::default());
1397        }
1398
1399        self.image_samplers
1400            .as_mut()
1401            .unwrap()
1402            .insert(descriptor, info);
1403
1404        self
1405    }
1406
1407    /// Specifies a manually-defined vertex input layout.
1408    ///
1409    /// The vertex input layout, by default, uses reflection to automatically define vertex binding
1410    /// and attribute descriptions. Each vertex location is inferred to have 32-bit channels and be
1411    /// tightly packed in the vertex buffer. In this mode, a location with `_ibind_0` or `_vbind3`
1412    /// suffixes is inferred to use instance-rate on vertex buffer binding `0` or vertex rate on
1413    /// binding `3`, respectively.
1414    ///
1415    /// See the [main documentation] for more information about automatic vertex input layout.
1416    ///
1417    /// [main documentation]: crate
1418    #[profiling::function]
1419    pub fn vertex_input(
1420        mut self,
1421        bindings: &[vk::VertexInputBindingDescription],
1422        attributes: &[vk::VertexInputAttributeDescription],
1423    ) -> Self {
1424        self.vertex_input_state = Some(Some(VertexInputState {
1425            vertex_binding_descriptions: bindings.to_vec(),
1426            vertex_attribute_descriptions: attributes.to_vec(),
1427        }));
1428        self
1429    }
1430}
1431
/// Error type for the derived `ShaderBuilder`'s fallible build; it carries no details.
///
/// It can be created from a `derive_builder` `UninitializedFieldError`. NOTE(review): presumably
/// selected through a `build_fn(error = ...)` attribute on `Shader` outside this view — confirm.
#[derive(Debug)]
struct ShaderBuilderError;
1434
1435impl From<UninitializedFieldError> for ShaderBuilderError {
1436    fn from(_: UninitializedFieldError) -> Self {
1437        Self
1438    }
1439}
1440
/// Trait for types which can be converted into shader code.
///
/// Implemented for SPIR-V supplied as bytes (`&[u8]`, `Vec<u8>`) or as 32-bit words (`&[u32]`,
/// `Vec<u32>`). The byte-based implementations debug-assert that the length is a multiple of
/// four bytes.
pub trait ShaderCode {
    /// Converts the instance into SPIR-V shader code specified as a byte array.
    fn into_vec(self) -> Vec<u8>;
}
1446
1447impl ShaderCode for &[u8] {
1448    fn into_vec(self) -> Vec<u8> {
1449        debug_assert_eq!(self.len() % 4, 0, "invalid spir-v code");
1450
1451        self.to_vec()
1452    }
1453}
1454
1455impl ShaderCode for &[u32] {
1456    fn into_vec(self) -> Vec<u8> {
1457        pub fn into_u8_slice<T>(t: &[T]) -> &[u8]
1458        where
1459            T: Sized,
1460        {
1461            use std::slice::from_raw_parts;
1462
1463            unsafe { from_raw_parts(t.as_ptr() as *const _, size_of_val(t)) }
1464        }
1465
1466        into_u8_slice(self).into_vec()
1467    }
1468}
1469
1470impl ShaderCode for Vec<u8> {
1471    fn into_vec(self) -> Vec<u8> {
1472        debug_assert_eq!(self.len() % 4, 0, "invalid spir-v code");
1473
1474        self
1475    }
1476}
1477
1478impl ShaderCode for Vec<u32> {
1479    fn into_vec(self) -> Vec<u8> {
1480        self.as_slice().into_vec()
1481    }
1482}
1483
/// Describes specialized constant values.
///
/// Each entry of `map_entries` selects the bytes `offset..offset + size` of `data` as the value of
/// the specialization constant identified by its `constant_id`.
#[derive(Clone, Debug)]
pub struct SpecializationInfo {
    /// A buffer of data which holds the constant values.
    pub data: Vec<u8>,

    /// Mapping of locations within the constant value data which describe each individual constant.
    pub map_entries: Vec<vk::SpecializationMapEntry>,
}
1493
1494impl SpecializationInfo {
1495    /// Constructs a new `SpecializationInfo`.
1496    pub fn new(
1497        map_entries: impl Into<Vec<vk::SpecializationMapEntry>>,
1498        data: impl Into<Vec<u8>>,
1499    ) -> Self {
1500        Self {
1501            data: data.into(),
1502            map_entries: map_entries.into(),
1503        }
1504    }
1505}
1506
#[cfg(test)]
mod tests {
    use super::*;

    /// A default `SamplerInfo` survives a round trip through its builder.
    #[test]
    pub fn sampler_info() {
        let default_info = SamplerInfo::default();
        let round_tripped = default_info.to_builder().build();

        assert_eq!(default_info, round_tripped);
    }

    /// A default builder produces the default `SamplerInfo`.
    #[test]
    pub fn sampler_info_builder() {
        let expected = SamplerInfo::default();
        let built = SamplerInfoBuilder::default().build();

        assert_eq!(expected, built);
    }
}