1use {
4 super::{DescriptorSetLayout, DriverError, VertexInputState, device::Device},
5 ash::vk,
6 derive_builder::{Builder, UninitializedFieldError},
7 log::{debug, error, trace, warn},
8 ordered_float::OrderedFloat,
9 spirq::{
10 ReflectConfig,
11 entry_point::EntryPoint,
12 parse::SpirvBinary,
13 spirv::ExecutionModel,
14 ty::{DescriptorType, ScalarType, Type, VectorType},
15 var::Variable,
16 },
17 std::{
18 collections::{BTreeMap, HashMap},
19 fmt::{Debug, Formatter},
20 iter::repeat_n,
21 ops::Deref,
22 thread::panicking,
23 },
24};
25
// Deprecated alias kept so older callers still compile; `SpecializationMap` replaces it.
#[allow(deprecated)]
#[deprecated = "use SpecializationMap struct"]
#[doc(hidden)]
pub type SpecializationInfo = self::deprecated::SpecializationInfo;

// Descriptor (set, binding) -> reflected descriptor description plus the shader stages that use
// it. `Shader::descriptor_bindings` fills it per stage; `merge_descriptor_bindings` ORs the stage
// flags together when combining stages.
pub(crate) type DescriptorBindingMap = HashMap<Descriptor, (DescriptorInfo, vk::ShaderStageFlags)>;
32
33#[profiling::function]
34fn guess_immutable_sampler(binding_name: &str) -> SamplerInfo {
35 const INVALID_ERR: &str = "Invalid sampler specification";
36
37 let (texel_filter, mipmap_mode, address_modes) = if binding_name.contains("_sampler_") {
38 let spec = &binding_name[binding_name.len() - 3..];
39 let texel_filter = match &spec[0..1] {
40 "n" => vk::Filter::NEAREST,
41 "l" => vk::Filter::LINEAR,
42 _ => panic!("{INVALID_ERR}: {}", &spec[0..1]),
43 };
44
45 let mipmap_mode = match &spec[1..2] {
46 "n" => vk::SamplerMipmapMode::NEAREST,
47 "l" => vk::SamplerMipmapMode::LINEAR,
48 _ => panic!("{INVALID_ERR}: {}", &spec[1..2]),
49 };
50
51 let address_modes = match &spec[2..3] {
52 "b" => vk::SamplerAddressMode::CLAMP_TO_BORDER,
53 "e" => vk::SamplerAddressMode::CLAMP_TO_EDGE,
54 "m" => vk::SamplerAddressMode::MIRRORED_REPEAT,
55 "r" => vk::SamplerAddressMode::REPEAT,
56 _ => panic!("{INVALID_ERR}: {}", &spec[2..3]),
57 };
58
59 (texel_filter, mipmap_mode, address_modes)
60 } else {
61 debug!("image binding {binding_name} using default sampler");
62
63 (
64 vk::Filter::LINEAR,
65 vk::SamplerMipmapMode::LINEAR,
66 vk::SamplerAddressMode::REPEAT,
67 )
68 };
69 let anisotropy_enable = texel_filter == vk::Filter::LINEAR;
70 let mut info = SamplerInfoBuilder::default()
71 .mag_filter(texel_filter)
72 .min_filter(texel_filter)
73 .mipmap_mode(mipmap_mode)
74 .address_mode_u(address_modes)
75 .address_mode_v(address_modes)
76 .address_mode_w(address_modes)
77 .max_lod(vk::LOD_CLAMP_NONE)
78 .anisotropy_enable(anisotropy_enable);
79
80 if anisotropy_enable {
81 info = info.max_anisotropy(16.0);
82 }
83
84 info.build()
85}
86
/// A descriptor location: a descriptor set index paired with a binding index in that set.
///
/// The derived ordering compares `set` first and then `binding`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Descriptor {
    /// Descriptor set index.
    pub set: u32,

    /// Binding index within the descriptor set.
    pub binding: u32,
}

/// A bare binding index refers to that binding in descriptor set 0.
impl From<u32> for Descriptor {
    fn from(binding: u32) -> Self {
        Self::from((0, binding))
    }
}

/// Builds a descriptor from a `(set, binding)` pair.
impl From<(u32, u32)> for Descriptor {
    fn from(pair: (u32, u32)) -> Self {
        let (set, binding) = pair;

        Descriptor { binding, set }
    }
}
110}
111
/// Reflected description of a single descriptor binding.
///
/// The first `u32` of every variant is the binding's descriptor count (array size); it is read by
/// `binding_count` and written by `set_binding_count`. `descriptor_type` maps each variant to the
/// matching `vk::DescriptorType`.
#[derive(Clone, Copy, Debug)]
pub(crate) enum DescriptorInfo {
    /// `ACCELERATION_STRUCTURE_KHR`: (count).
    AccelerationStructure(u32),
    /// `COMBINED_IMAGE_SAMPLER`: (count, immutable sampler description, whether the sampler was
    /// manually defined rather than guessed from the binding name).
    CombinedImageSampler(u32, SamplerInfo, bool),
    /// `INPUT_ATTACHMENT`: (count, input attachment index).
    InputAttachment(u32, u32),
    /// `SAMPLED_IMAGE`: (count).
    SampledImage(u32),
    /// `SAMPLER`: (count, immutable sampler description, whether the sampler was manually
    /// defined rather than guessed from the binding name).
    Sampler(u32, SamplerInfo, bool),
    /// `STORAGE_BUFFER`: (count).
    StorageBuffer(u32),
    /// `STORAGE_IMAGE`: (count).
    StorageImage(u32),
    /// `STORAGE_TEXEL_BUFFER`: (count).
    StorageTexelBuffer(u32),
    /// `UNIFORM_BUFFER`: (count).
    UniformBuffer(u32),
    /// `UNIFORM_TEXEL_BUFFER`: (count).
    UniformTexelBuffer(u32),
}
125
126impl DescriptorInfo {
127 pub fn binding_count(self) -> u32 {
128 match self {
129 Self::AccelerationStructure(binding_count) => binding_count,
130 Self::CombinedImageSampler(binding_count, ..) => binding_count,
131 Self::InputAttachment(binding_count, _) => binding_count,
132 Self::SampledImage(binding_count) => binding_count,
133 Self::Sampler(binding_count, ..) => binding_count,
134 Self::StorageBuffer(binding_count) => binding_count,
135 Self::StorageImage(binding_count) => binding_count,
136 Self::StorageTexelBuffer(binding_count) => binding_count,
137 Self::UniformBuffer(binding_count) => binding_count,
138 Self::UniformTexelBuffer(binding_count) => binding_count,
139 }
140 }
141
142 pub fn descriptor_type(self) -> vk::DescriptorType {
143 match self {
144 Self::AccelerationStructure(_) => vk::DescriptorType::ACCELERATION_STRUCTURE_KHR,
145 Self::CombinedImageSampler(..) => vk::DescriptorType::COMBINED_IMAGE_SAMPLER,
146 Self::InputAttachment(..) => vk::DescriptorType::INPUT_ATTACHMENT,
147 Self::SampledImage(_) => vk::DescriptorType::SAMPLED_IMAGE,
148 Self::Sampler(..) => vk::DescriptorType::SAMPLER,
149 Self::StorageBuffer(_) => vk::DescriptorType::STORAGE_BUFFER,
150 Self::StorageImage(_) => vk::DescriptorType::STORAGE_IMAGE,
151 Self::StorageTexelBuffer(_) => vk::DescriptorType::STORAGE_TEXEL_BUFFER,
152 Self::UniformBuffer(_) => vk::DescriptorType::UNIFORM_BUFFER,
153 Self::UniformTexelBuffer(_) => vk::DescriptorType::UNIFORM_TEXEL_BUFFER,
154 }
155 }
156
157 fn sampler_info(self) -> Option<SamplerInfo> {
158 match self {
159 Self::CombinedImageSampler(_, sampler_info, _) | Self::Sampler(_, sampler_info, _) => {
160 Some(sampler_info)
161 }
162 _ => None,
163 }
164 }
165
166 pub fn set_binding_count(&mut self, binding_count: u32) {
167 *match self {
168 Self::AccelerationStructure(binding_count) => binding_count,
169 Self::CombinedImageSampler(binding_count, ..) => binding_count,
170 Self::InputAttachment(binding_count, _) => binding_count,
171 Self::SampledImage(binding_count) => binding_count,
172 Self::Sampler(binding_count, ..) => binding_count,
173 Self::StorageBuffer(binding_count) => binding_count,
174 Self::StorageImage(binding_count) => binding_count,
175 Self::StorageTexelBuffer(binding_count) => binding_count,
176 Self::UniformBuffer(binding_count) => binding_count,
177 Self::UniformTexelBuffer(binding_count) => binding_count,
178 } = binding_count;
179 }
180}
181
/// Descriptor set layouts, descriptor pool sizes and immutable samplers built for one pipeline
/// by `PipelineDescriptorInfo::create`.
#[derive(Debug)]
pub(crate) struct PipelineDescriptorInfo {
    /// One layout per descriptor set index.
    pub layouts: BTreeMap<u32, DescriptorSetLayout>,
    /// Per descriptor set index: the total descriptor count required of each descriptor type.
    pub pool_sizes: HashMap<u32, HashMap<vk::DescriptorType, u32>>,

    // Immutable samplers whose handles are baked into `layouts`. Never read; owning them keeps
    // `Sampler`'s `Drop` from destroying the handles while this value is alive.
    #[allow(dead_code)]
    samplers: Box<[Sampler]>,
}
190
191impl PipelineDescriptorInfo {
192 #[profiling::function]
193 pub fn create(
194 device: &Device,
195 descriptor_bindings: &DescriptorBindingMap,
196 ) -> Result<Self, DriverError> {
197 let descriptor_set_count = descriptor_bindings
198 .keys()
199 .map(|descriptor| descriptor.set)
200 .max()
201 .map(|set| set + 1)
202 .unwrap_or_default();
203 let mut layouts = BTreeMap::new();
204 let mut pool_sizes = HashMap::new();
205
206 let mut sampler_info_binding_count = HashMap::<_, u32>::with_capacity(
209 descriptor_bindings
210 .values()
211 .filter(|(descriptor_info, _)| descriptor_info.sampler_info().is_some())
212 .count(),
213 );
214
215 for (sampler_info, binding_count) in
216 descriptor_bindings
217 .values()
218 .filter_map(|(descriptor_info, _)| {
219 descriptor_info
220 .sampler_info()
221 .map(|sampler_info| (sampler_info, descriptor_info.binding_count()))
222 })
223 {
224 sampler_info_binding_count
225 .entry(sampler_info)
226 .and_modify(|sampler_info_binding_count| {
227 *sampler_info_binding_count = binding_count.max(*sampler_info_binding_count);
228 })
229 .or_insert(binding_count);
230 }
231
232 let mut samplers = sampler_info_binding_count
233 .keys()
234 .copied()
235 .map(|sampler_info| {
236 Sampler::create(device, sampler_info).map(|sampler| (sampler_info, sampler))
237 })
238 .collect::<Result<HashMap<_, _>, _>>()?;
239 let immutable_samplers = sampler_info_binding_count
240 .iter()
241 .map(|(sampler_info, &binding_count)| {
242 (
243 *sampler_info,
244 repeat_n(*samplers[sampler_info], binding_count as _).collect::<Box<_>>(),
245 )
246 })
247 .collect::<HashMap<_, _>>();
248
249 for descriptor_set_idx in 0..descriptor_set_count {
250 let mut binding_counts = HashMap::<vk::DescriptorType, u32>::new();
251 let mut bindings = vec![];
252
253 for (descriptor, (descriptor_info, stage_flags)) in descriptor_bindings
254 .iter()
255 .filter(|(descriptor, _)| descriptor.set == descriptor_set_idx)
256 {
257 let descriptor_ty = descriptor_info.descriptor_type();
258 *binding_counts.entry(descriptor_ty).or_default() +=
259 descriptor_info.binding_count();
260 let mut binding = vk::DescriptorSetLayoutBinding::default()
261 .binding(descriptor.binding)
262 .descriptor_count(descriptor_info.binding_count())
263 .descriptor_type(descriptor_ty)
264 .stage_flags(*stage_flags);
265
266 if let Some(immutable_samplers) =
267 descriptor_info.sampler_info().map(|sampler_info| {
268 &immutable_samplers[&sampler_info]
269 [0..descriptor_info.binding_count() as usize]
270 })
271 {
272 binding = binding.immutable_samplers(immutable_samplers);
273 }
274
275 bindings.push(binding);
276 }
277
278 let pool_size = pool_sizes
279 .entry(descriptor_set_idx)
280 .or_insert_with(HashMap::new);
281
282 for (descriptor_ty, binding_count) in binding_counts.into_iter() {
283 *pool_size.entry(descriptor_ty).or_default() += binding_count;
284 }
285
286 let mut create_info = vk::DescriptorSetLayoutCreateInfo::default().bindings(&bindings);
289
290 let bindless_flags = vec![vk::DescriptorBindingFlags::PARTIALLY_BOUND; bindings.len()];
294 let mut bindless_flags = if device
295 .physical_device
296 .features_v1_2
297 .descriptor_binding_partially_bound
298 {
299 let bindless_flags = vk::DescriptorSetLayoutBindingFlagsCreateInfo::default()
300 .binding_flags(&bindless_flags);
301 Some(bindless_flags)
302 } else {
303 None
304 };
305
306 if let Some(bindless_flags) = bindless_flags.as_mut() {
307 create_info = create_info.push_next(bindless_flags);
308 }
309
310 layouts.insert(
311 descriptor_set_idx,
312 DescriptorSetLayout::create(device, &create_info)?,
313 );
314 }
315
316 let samplers = samplers
317 .drain()
318 .map(|(_, sampler)| sampler)
319 .collect::<Box<_>>();
320
321 Ok(Self {
325 layouts,
326 pool_sizes,
327 samplers,
328 })
329 }
330}
331
332pub(crate) struct Sampler {
333 device: Device,
334 sampler: vk::Sampler,
335}
336
337impl Sampler {
338 #[profiling::function]
339 pub fn create(device: &Device, info: impl Into<SamplerInfo>) -> Result<Self, DriverError> {
340 let device = device.clone();
341 let info = info.into();
342
343 let sampler = unsafe {
344 device
345 .create_sampler(
346 &vk::SamplerCreateInfo::default()
347 .flags(info.flags)
348 .mag_filter(info.mag_filter)
349 .min_filter(info.min_filter)
350 .mipmap_mode(info.mipmap_mode)
351 .address_mode_u(info.address_mode_u)
352 .address_mode_v(info.address_mode_v)
353 .address_mode_w(info.address_mode_w)
354 .mip_lod_bias(info.mip_lod_bias.0)
355 .anisotropy_enable(info.anisotropy_enable)
356 .max_anisotropy(info.max_anisotropy.0)
357 .compare_enable(info.compare_enable)
358 .compare_op(info.compare_op)
359 .min_lod(info.min_lod.0)
360 .max_lod(info.max_lod.0)
361 .border_color(info.border_color)
362 .unnormalized_coordinates(info.unnormalized_coordinates)
363 .push_next(
364 &mut vk::SamplerReductionModeCreateInfo::default()
365 .reduction_mode(info.reduction_mode),
366 ),
367 None,
368 )
369 .map_err(|err| match err {
370 vk::Result::ERROR_OUT_OF_HOST_MEMORY
371 | vk::Result::ERROR_OUT_OF_DEVICE_MEMORY => {
372 warn!("unable to create sampler: {err}");
373 DriverError::OutOfMemory
374 }
375 _ => {
376 warn!("unsupported sampler creation: {err}");
377 DriverError::Unsupported
378 }
379 })?
380 };
381
382 Ok(Self { device, sampler })
383 }
384}
385
386impl Debug for Sampler {
387 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
388 write!(f, "{:?}", self.sampler)
389 }
390}
391
392impl Deref for Sampler {
393 type Target = vk::Sampler;
394
395 fn deref(&self) -> &Self::Target {
396 &self.sampler
397 }
398}
399
400impl Drop for Sampler {
401 #[profiling::function]
402 fn drop(&mut self) {
403 if panicking() {
404 return;
405 }
406
407 unsafe {
408 self.device.destroy_sampler(self.sampler, None);
409 }
410 }
411}
412
/// Hashable description of a Vulkan sampler. `Sampler::create` forwards it field by field into a
/// `vk::SamplerCreateInfo`.
///
/// Floating-point fields are wrapped in `OrderedFloat` so the type can derive `Eq` and `Hash`,
/// which lets it serve as a map key (as in `PipelineDescriptorInfo::create`).
#[derive(Builder, Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[builder(
    build_fn(private, name = "fallible_build", error = "SamplerInfoBuilderError"),
    derive(Clone, Copy, Debug),
    pattern = "owned"
)]
pub struct SamplerInfo {
    /// Sampler creation flags (`flags`).
    #[builder(default)]
    pub flags: vk::SamplerCreateFlags,

    /// Magnification filter (`magFilter`).
    #[builder(default)]
    pub mag_filter: vk::Filter,

    /// Minification filter (`minFilter`).
    #[builder(default)]
    pub min_filter: vk::Filter,

    /// Mipmap filtering mode (`mipmapMode`).
    #[builder(default)]
    pub mipmap_mode: vk::SamplerMipmapMode,

    /// Addressing mode for the U coordinate (`addressModeU`).
    #[builder(default)]
    pub address_mode_u: vk::SamplerAddressMode,

    /// Addressing mode for the V coordinate (`addressModeV`).
    #[builder(default)]
    pub address_mode_v: vk::SamplerAddressMode,

    /// Addressing mode for the W coordinate (`addressModeW`).
    #[builder(default)]
    pub address_mode_w: vk::SamplerAddressMode,

    /// Bias added to the computed mip level of detail (`mipLodBias`).
    #[builder(default, setter(into))]
    pub mip_lod_bias: OrderedFloat<f32>,

    /// Enables anisotropic filtering (`anisotropyEnable`).
    #[builder(default)]
    pub anisotropy_enable: bool,

    /// Anisotropy clamp, used when `anisotropy_enable` is set (`maxAnisotropy`).
    #[builder(default, setter(into))]
    pub max_anisotropy: OrderedFloat<f32>,

    /// Enables depth comparison against `compare_op` (`compareEnable`).
    #[builder(default)]
    pub compare_enable: bool,

    /// Comparison operator used when `compare_enable` is set (`compareOp`).
    #[builder(default)]
    pub compare_op: vk::CompareOp,

    /// Minimum level-of-detail clamp (`minLod`).
    #[builder(default, setter(into))]
    pub min_lod: OrderedFloat<f32>,

    /// Maximum level-of-detail clamp (`maxLod`). `guess_immutable_sampler` sets
    /// `vk::LOD_CLAMP_NONE` here.
    #[builder(default, setter(into))]
    pub max_lod: OrderedFloat<f32>,

    /// Border color for border-clamped addressing (`borderColor`).
    #[builder(default)]
    pub border_color: vk::BorderColor,

    /// Use unnormalized texel coordinates (`unnormalizedCoordinates`).
    #[builder(default)]
    pub unnormalized_coordinates: bool,

    /// Filter reduction mode, chained via `vk::SamplerReductionModeCreateInfo`.
    #[builder(default)]
    pub reduction_mode: vk::SamplerReductionMode,
}
534
535impl SamplerInfo {
536 pub const LINEAR: SamplerInfoBuilder = SamplerInfoBuilder {
538 flags: None,
539 mag_filter: Some(vk::Filter::LINEAR),
540 min_filter: Some(vk::Filter::LINEAR),
541 mipmap_mode: Some(vk::SamplerMipmapMode::LINEAR),
542 address_mode_u: None,
543 address_mode_v: None,
544 address_mode_w: None,
545 mip_lod_bias: None,
546 anisotropy_enable: None,
547 max_anisotropy: None,
548 compare_enable: None,
549 compare_op: None,
550 min_lod: None,
551 max_lod: None,
552 border_color: None,
553 unnormalized_coordinates: None,
554 reduction_mode: None,
555 };
556
557 pub const NEAREST: SamplerInfoBuilder = SamplerInfoBuilder {
560 flags: None,
561 mag_filter: Some(vk::Filter::NEAREST),
562 min_filter: Some(vk::Filter::NEAREST),
563 mipmap_mode: Some(vk::SamplerMipmapMode::NEAREST),
564 address_mode_u: None,
565 address_mode_v: None,
566 address_mode_w: None,
567 mip_lod_bias: None,
568 anisotropy_enable: None,
569 max_anisotropy: None,
570 compare_enable: None,
571 compare_op: None,
572 min_lod: None,
573 max_lod: None,
574 border_color: None,
575 unnormalized_coordinates: None,
576 reduction_mode: None,
577 };
578
579 #[allow(clippy::new_ret_no_self)]
581 #[deprecated = "Use SamplerInfo::default()"]
582 #[doc(hidden)]
583 pub fn new() -> SamplerInfoBuilder {
584 Self::default().into_builder()
585 }
586
587 pub fn builder() -> SamplerInfoBuilder {
589 Default::default()
590 }
591
592 pub fn into_builder(self) -> SamplerInfoBuilder {
594 SamplerInfoBuilder {
595 flags: Some(self.flags),
596 mag_filter: Some(self.mag_filter),
597 min_filter: Some(self.min_filter),
598 mipmap_mode: Some(self.mipmap_mode),
599 address_mode_u: Some(self.address_mode_u),
600 address_mode_v: Some(self.address_mode_v),
601 address_mode_w: Some(self.address_mode_w),
602 mip_lod_bias: Some(self.mip_lod_bias),
603 anisotropy_enable: Some(self.anisotropy_enable),
604 max_anisotropy: Some(self.max_anisotropy),
605 compare_enable: Some(self.compare_enable),
606 compare_op: Some(self.compare_op),
607 min_lod: Some(self.min_lod),
608 max_lod: Some(self.max_lod),
609 border_color: Some(self.border_color),
610 unnormalized_coordinates: Some(self.unnormalized_coordinates),
611 reduction_mode: Some(self.reduction_mode),
612 }
613 }
614
615 #[deprecated = "use into_builder function"]
616 #[doc(hidden)]
617 pub fn to_builder(self) -> SamplerInfoBuilder {
618 self.into_builder()
619 }
620}
621
/// Nearest filtering and mipmapping, repeat addressing on all axes, no anisotropy, no depth
/// compare, transparent-black border and weighted-average reduction.
///
/// Note that `max_lod` is `0.0` here, unlike `guess_immutable_sampler`, which uses
/// `vk::LOD_CLAMP_NONE`.
impl Default for SamplerInfo {
    fn default() -> Self {
        Self {
            flags: vk::SamplerCreateFlags::empty(),
            mag_filter: vk::Filter::NEAREST,
            min_filter: vk::Filter::NEAREST,
            mipmap_mode: vk::SamplerMipmapMode::NEAREST,
            address_mode_u: vk::SamplerAddressMode::REPEAT,
            address_mode_v: vk::SamplerAddressMode::REPEAT,
            address_mode_w: vk::SamplerAddressMode::REPEAT,
            mip_lod_bias: OrderedFloat(0.0),
            anisotropy_enable: false,
            max_anisotropy: OrderedFloat(0.0),
            compare_enable: false,
            compare_op: vk::CompareOp::NEVER,
            min_lod: OrderedFloat(0.0),
            max_lod: OrderedFloat(0.0),
            border_color: vk::BorderColor::FLOAT_TRANSPARENT_BLACK,
            unnormalized_coordinates: false,
            reduction_mode: vk::SamplerReductionMode::WEIGHTED_AVERAGE,
        }
    }
}
645
646impl SamplerInfoBuilder {
647 #[inline(always)]
649 pub fn build(self) -> SamplerInfo {
650 self.fallible_build().expect("invalid sampler info")
651 }
652}
653
654impl From<SamplerInfoBuilder> for SamplerInfo {
655 fn from(info: SamplerInfoBuilder) -> Self {
656 info.build()
657 }
658}
659
660#[derive(Debug)]
661struct SamplerInfoBuilderError;
662
663impl From<UninitializedFieldError> for SamplerInfoBuilderError {
664 fn from(_: UninitializedFieldError) -> Self {
665 Self
666 }
667}
668
/// A single shader stage: SPIR-V code, the stage it runs in, and data reflected from its entry
/// point.
#[allow(missing_docs)]
#[derive(Builder, Clone)]
#[builder(
    build_fn(private, name = "fallible_build", error = "ShaderBuilderError"),
    derive(Clone, Debug),
    pattern = "owned"
)]
pub struct Shader {
    /// Name of the entry point to use; the builder defaults it to `"main"`.
    #[builder(default = "\"main\".to_owned()", setter(into))]
    pub entry_name: String,

    /// Optional specialization constant values for this shader.
    #[builder(default, setter(strip_option))]
    pub specialization: Option<SpecializationMap>,

    /// The SPIR-V code.
    #[builder(setter(into))]
    pub spirv: SpirvBinary,

    /// Pipeline stage this shader runs in. It has no builder default, so it must be set. It is
    /// recorded as the stage flags of this shader's descriptor bindings and push constant range.
    pub stage: vk::ShaderStageFlags,

    // Reflection data for the selected entry point; read by `attachments`,
    // `descriptor_bindings`, `push_constant_range` and `try_vertex_input`.
    #[builder(private)]
    entry_point: EntryPoint,

    // Manually specified samplers keyed by descriptor. `image_sampler` checks these before
    // guessing a sampler from the binding name.
    #[builder(default, private)]
    image_samplers: HashMap<Descriptor, SamplerInfo>,

    // Explicit vertex input state; when set, `try_vertex_input` returns it instead of reflecting.
    #[builder(default, private, setter(strip_option))]
    vertex_input_state: Option<VertexInputState>,
}
747
748impl Shader {
749 #[allow(clippy::new_ret_no_self)]
751 pub fn new(stage: vk::ShaderStageFlags, spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
752 ShaderBuilder::default().spirv(spirv).stage(stage)
753 }
754
755 pub fn new_any_hit(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
761 Self::new(vk::ShaderStageFlags::ANY_HIT_KHR, spirv)
762 }
763
764 pub fn new_callable(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
770 Self::new(vk::ShaderStageFlags::CALLABLE_KHR, spirv)
771 }
772
773 pub fn new_closest_hit(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
779 Self::new(vk::ShaderStageFlags::CLOSEST_HIT_KHR, spirv)
780 }
781
782 pub fn new_compute(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
788 Self::new(vk::ShaderStageFlags::COMPUTE, spirv)
789 }
790
791 pub fn new_fragment(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
797 Self::new(vk::ShaderStageFlags::FRAGMENT, spirv)
798 }
799
800 pub fn new_geometry(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
806 Self::new(vk::ShaderStageFlags::GEOMETRY, spirv)
807 }
808
809 pub fn new_intersection(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
815 Self::new(vk::ShaderStageFlags::INTERSECTION_KHR, spirv)
816 }
817
818 pub fn new_mesh(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
824 Self::new(vk::ShaderStageFlags::MESH_EXT, spirv)
825 }
826
827 pub fn new_miss(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
833 Self::new(vk::ShaderStageFlags::MISS_KHR, spirv)
834 }
835
836 pub fn new_ray_gen(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
842 Self::new(vk::ShaderStageFlags::RAYGEN_KHR, spirv)
843 }
844
845 pub fn new_task(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
851 Self::new(vk::ShaderStageFlags::TASK_EXT, spirv)
852 }
853
854 pub fn new_tessellation_ctrl(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
860 Self::new(vk::ShaderStageFlags::TESSELLATION_CONTROL, spirv)
861 }
862
863 #[deprecated = "use new_tessellation_ctrl function"]
864 #[doc(hidden)]
865 pub fn new_tesselation_ctrl(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
866 Self::new_tessellation_ctrl(spirv)
867 }
868
869 pub fn new_tessellation_eval(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
875 Self::new(vk::ShaderStageFlags::TESSELLATION_EVALUATION, spirv)
876 }
877
878 #[deprecated = "use new_tessellation_eval function"]
879 #[doc(hidden)]
880 pub fn new_tesselation_eval(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
881 Self::new_tessellation_eval(spirv)
882 }
883
884 pub fn new_vertex(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
890 Self::new(vk::ShaderStageFlags::VERTEX, spirv)
891 }
892
893 #[profiling::function]
895 pub(super) fn attachments(
896 &self,
897 ) -> (
898 impl Iterator<Item = u32> + '_,
899 impl Iterator<Item = u32> + '_,
900 ) {
901 (
902 self.entry_point.vars.iter().filter_map(|var| match var {
903 Variable::Descriptor {
904 desc_ty: DescriptorType::InputAttachment(attachment),
905 ..
906 } => Some(*attachment),
907 _ => None,
908 }),
909 self.entry_point.vars.iter().filter_map(|var| match var {
910 Variable::Output { location, .. } => Some(location.loc()),
911 _ => None,
912 }),
913 )
914 }
915
916 pub fn builder() -> ShaderBuilder {
918 Default::default()
919 }
920
921 #[profiling::function]
922 pub(super) fn descriptor_bindings(&self) -> DescriptorBindingMap {
923 let mut res = DescriptorBindingMap::default();
924
925 for (name, descriptor, desc_ty, binding_count) in
926 self.entry_point.vars.iter().filter_map(|var| match var {
927 Variable::Descriptor {
928 name,
929 desc_bind,
930 desc_ty,
931 nbind,
932 ..
933 } => Some((
934 name,
935 Descriptor {
936 set: desc_bind.set(),
937 binding: desc_bind.bind(),
938 },
939 desc_ty,
940 *nbind,
941 )),
942 _ => None,
943 })
944 {
945 trace!(
946 "descriptor {}: {}.{} = {:?}[{}]",
947 name.as_deref().unwrap_or_default(),
948 descriptor.set,
949 descriptor.binding,
950 *desc_ty,
951 binding_count
952 );
953
954 let descriptor_info = match desc_ty {
955 DescriptorType::AccelStruct() => {
956 DescriptorInfo::AccelerationStructure(binding_count)
957 }
958 DescriptorType::CombinedImageSampler() => {
959 let (sampler_info, is_manually_defined) =
960 self.image_sampler(descriptor, name.as_deref().unwrap_or_default());
961
962 DescriptorInfo::CombinedImageSampler(
963 binding_count,
964 sampler_info,
965 is_manually_defined,
966 )
967 }
968 DescriptorType::InputAttachment(attachment) => {
969 DescriptorInfo::InputAttachment(binding_count, *attachment)
970 }
971 DescriptorType::SampledImage() => DescriptorInfo::SampledImage(binding_count),
972 DescriptorType::Sampler() => {
973 let (sampler_info, is_manually_defined) =
974 self.image_sampler(descriptor, name.as_deref().unwrap_or_default());
975
976 DescriptorInfo::Sampler(binding_count, sampler_info, is_manually_defined)
977 }
978 DescriptorType::StorageBuffer(_access_ty) => {
979 DescriptorInfo::StorageBuffer(binding_count)
980 }
981 DescriptorType::StorageImage(_access_ty) => {
982 DescriptorInfo::StorageImage(binding_count)
983 }
984 DescriptorType::StorageTexelBuffer(_access_ty) => {
985 DescriptorInfo::StorageTexelBuffer(binding_count)
986 }
987 DescriptorType::UniformBuffer() => DescriptorInfo::UniformBuffer(binding_count),
988 DescriptorType::UniformTexelBuffer() => {
989 DescriptorInfo::UniformTexelBuffer(binding_count)
990 }
991 };
992 res.insert(descriptor, (descriptor_info, self.stage));
993 }
994
995 res
996 }
997
998 pub fn from_spirv(spirv: impl Into<SpirvBinary>) -> ShaderBuilder {
1000 ShaderBuilder::default().spirv(spirv)
1001 }
1002
1003 fn image_sampler(&self, descriptor: Descriptor, name: &str) -> (SamplerInfo, bool) {
1004 self.image_samplers
1005 .get(&descriptor)
1006 .copied()
1007 .map(|sampler_info| (sampler_info, true))
1008 .unwrap_or_else(|| (guess_immutable_sampler(name), false))
1009 }
1010
1011 #[profiling::function]
1012 pub(super) fn merge_descriptor_bindings(
1013 descriptor_bindings: impl IntoIterator<Item = DescriptorBindingMap>,
1014 ) -> DescriptorBindingMap {
1015 fn merge_info(lhs: &mut DescriptorInfo, rhs: DescriptorInfo) -> bool {
1016 let (lhs_count, rhs_count) = match lhs {
1017 DescriptorInfo::AccelerationStructure(lhs) => {
1018 if let DescriptorInfo::AccelerationStructure(rhs) = rhs {
1019 (lhs, rhs)
1020 } else {
1021 return false;
1022 }
1023 }
1024 DescriptorInfo::CombinedImageSampler(lhs, lhs_sampler, lhs_is_manually_defined) => {
1025 if let DescriptorInfo::CombinedImageSampler(
1026 rhs,
1027 rhs_sampler,
1028 rhs_is_manually_defined,
1029 ) = rhs
1030 {
1031 if *lhs_is_manually_defined && rhs_is_manually_defined {
1033 return false;
1034 } else if rhs_is_manually_defined {
1035 *lhs_sampler = rhs_sampler;
1036 }
1037
1038 (lhs, rhs)
1039 } else {
1040 return false;
1041 }
1042 }
1043 DescriptorInfo::InputAttachment(lhs, lhs_idx) => {
1044 if let DescriptorInfo::InputAttachment(rhs, rhs_idx) = rhs {
1045 if *lhs_idx != rhs_idx {
1046 return false;
1047 }
1048
1049 (lhs, rhs)
1050 } else {
1051 return false;
1052 }
1053 }
1054 DescriptorInfo::SampledImage(lhs) => {
1055 if let DescriptorInfo::SampledImage(rhs) = rhs {
1056 (lhs, rhs)
1057 } else {
1058 return false;
1059 }
1060 }
1061 DescriptorInfo::Sampler(lhs, lhs_sampler, lhs_is_manually_defined) => {
1062 if let DescriptorInfo::Sampler(rhs, rhs_sampler, rhs_is_manually_defined) = rhs
1063 {
1064 if *lhs_is_manually_defined && rhs_is_manually_defined {
1066 return false;
1067 } else if rhs_is_manually_defined {
1068 *lhs_sampler = rhs_sampler;
1069 }
1070
1071 (lhs, rhs)
1072 } else {
1073 return false;
1074 }
1075 }
1076 DescriptorInfo::StorageBuffer(lhs) => {
1077 if let DescriptorInfo::StorageBuffer(rhs) = rhs {
1078 (lhs, rhs)
1079 } else {
1080 return false;
1081 }
1082 }
1083 DescriptorInfo::StorageImage(lhs) => {
1084 if let DescriptorInfo::StorageImage(rhs) = rhs {
1085 (lhs, rhs)
1086 } else {
1087 return false;
1088 }
1089 }
1090 DescriptorInfo::StorageTexelBuffer(lhs) => {
1091 if let DescriptorInfo::StorageTexelBuffer(rhs) = rhs {
1092 (lhs, rhs)
1093 } else {
1094 return false;
1095 }
1096 }
1097 DescriptorInfo::UniformBuffer(lhs) => {
1098 if let DescriptorInfo::UniformBuffer(rhs) = rhs {
1099 (lhs, rhs)
1100 } else {
1101 return false;
1102 }
1103 }
1104 DescriptorInfo::UniformTexelBuffer(lhs) => {
1105 if let DescriptorInfo::UniformTexelBuffer(rhs) = rhs {
1106 (lhs, rhs)
1107 } else {
1108 return false;
1109 }
1110 }
1111 };
1112
1113 *lhs_count = rhs_count.max(*lhs_count);
1114
1115 true
1116 }
1117
1118 #[profiling::function]
1119 fn merge_pair(src: DescriptorBindingMap, dst: &mut DescriptorBindingMap) {
1120 for (descriptor_binding, (descriptor_info, descriptor_flags)) in src.into_iter() {
1121 if let Some((existing_info, existing_flags)) = dst.get_mut(&descriptor_binding) {
1122 if !merge_info(existing_info, descriptor_info) {
1123 panic!("Inconsistent shader descriptors ({descriptor_binding:?})");
1124 }
1125
1126 *existing_flags |= descriptor_flags;
1127 } else {
1128 dst.insert(descriptor_binding, (descriptor_info, descriptor_flags));
1129 }
1130 }
1131 }
1132
1133 let mut descriptor_bindings = descriptor_bindings.into_iter();
1134 let mut res = descriptor_bindings.next().unwrap_or_default();
1135 for descriptor_binding in descriptor_bindings {
1136 merge_pair(descriptor_binding, &mut res);
1137 }
1138
1139 res
1140 }
1141
1142 #[profiling::function]
1143 pub(super) fn push_constant_range(&self) -> Option<vk::PushConstantRange> {
1144 self.entry_point
1145 .vars
1146 .iter()
1147 .filter_map(|var| match var {
1148 Variable::PushConstant {
1149 ty: Type::Struct(ty),
1150 ..
1151 } => Some(ty.members.clone()),
1152 _ => None,
1153 })
1154 .flatten()
1155 .map(|push_const| {
1156 let offset = push_const.offset.unwrap_or_default();
1157 let size = push_const
1158 .ty
1159 .nbyte()
1160 .unwrap_or_default()
1161 .next_multiple_of(4);
1162 offset..offset + size
1163 })
1164 .reduce(|a, b| a.start.min(b.start)..a.end.max(b.end))
1165 .map(|push_const| vk::PushConstantRange {
1166 stage_flags: self.stage,
1167 size: (push_const.end - push_const.start) as _,
1168 offset: push_const.start as _,
1169 })
1170 }
1171
1172 #[profiling::function]
1173 fn reflect_entry_point(
1174 entry_name: &str,
1175 spirv: impl Into<SpirvBinary>,
1176 specialization: Option<&SpecializationMap>,
1177 ) -> Result<EntryPoint, DriverError> {
1178 let mut config = ReflectConfig::new();
1179 config.ref_all_rscs(true).spv(spirv);
1180
1181 if let Some(specialization) = specialization {
1182 for &vk::SpecializationMapEntry {
1183 constant_id,
1184 offset,
1185 size,
1186 } in &specialization.entries
1187 {
1188 config.specialize(
1189 constant_id,
1190 specialization.data[offset as usize..offset as usize + size].into(),
1191 );
1192 }
1193 }
1194
1195 let entry_points = config.reflect().map_err(|err| {
1196 error!("invalid spirv reflection data: {err}");
1197
1198 DriverError::InvalidData
1199 })?;
1200 let entry_point = entry_points
1201 .into_iter()
1202 .find(|entry_point| entry_point.name == entry_name)
1203 .ok_or_else(|| {
1204 error!("invalid shader entry point: not found");
1205
1206 DriverError::InvalidData
1207 })?;
1208
1209 Ok(entry_point)
1210 }
1211
1212 #[profiling::function]
1213 pub(super) fn try_vertex_input(&self) -> Result<VertexInputState, DriverError> {
1214 if let Some(vertex_input) = &self.vertex_input_state {
1216 return Ok(vertex_input.clone());
1217 }
1218
1219 fn scalar_format(ty: &ScalarType) -> Option<vk::Format> {
1220 match *ty {
1221 ScalarType::Float { bits } => match bits {
1222 u8::BITS => Some(vk::Format::R8_SNORM),
1223 u16::BITS => Some(vk::Format::R16_SFLOAT),
1224 u32::BITS => Some(vk::Format::R32_SFLOAT),
1225 u64::BITS => Some(vk::Format::R64_SFLOAT),
1226 _ => None,
1227 },
1228 ScalarType::Integer {
1229 bits,
1230 is_signed: false,
1231 } => match bits {
1232 u8::BITS => Some(vk::Format::R8_UINT),
1233 u16::BITS => Some(vk::Format::R16_UINT),
1234 u32::BITS => Some(vk::Format::R32_UINT),
1235 u64::BITS => Some(vk::Format::R64_UINT),
1236 _ => None,
1237 },
1238 ScalarType::Integer {
1239 bits,
1240 is_signed: true,
1241 } => match bits {
1242 u8::BITS => Some(vk::Format::R8_SINT),
1243 u16::BITS => Some(vk::Format::R16_SINT),
1244 u32::BITS => Some(vk::Format::R32_SINT),
1245 u64::BITS => Some(vk::Format::R64_SINT),
1246 _ => None,
1247 },
1248 _ => None,
1249 }
1250 }
1251
1252 fn vector_format(ty: &VectorType) -> Option<vk::Format> {
1253 match *ty {
1254 VectorType {
1255 scalar_ty: ScalarType::Float { bits },
1256 nscalar,
1257 } => match (bits, nscalar) {
1258 (u8::BITS, 2) => Some(vk::Format::R8G8_SNORM),
1259 (u8::BITS, 3) => Some(vk::Format::R8G8B8_SNORM),
1260 (u8::BITS, 4) => Some(vk::Format::R8G8B8A8_SNORM),
1261 (u16::BITS, 2) => Some(vk::Format::R16G16_SFLOAT),
1262 (u16::BITS, 3) => Some(vk::Format::R16G16B16_SFLOAT),
1263 (u16::BITS, 4) => Some(vk::Format::R16G16B16A16_SFLOAT),
1264 (u32::BITS, 2) => Some(vk::Format::R32G32_SFLOAT),
1265 (u32::BITS, 3) => Some(vk::Format::R32G32B32_SFLOAT),
1266 (u32::BITS, 4) => Some(vk::Format::R32G32B32A32_SFLOAT),
1267 (u64::BITS, 2) => Some(vk::Format::R64G64_SFLOAT),
1268 (u64::BITS, 3) => Some(vk::Format::R64G64B64_SFLOAT),
1269 (u64::BITS, 4) => Some(vk::Format::R64G64B64A64_SFLOAT),
1270 _ => None,
1271 },
1272 VectorType {
1273 scalar_ty:
1274 ScalarType::Integer {
1275 bits,
1276 is_signed: false,
1277 },
1278 nscalar,
1279 } => match (bits, nscalar) {
1280 (u8::BITS, 2) => Some(vk::Format::R8G8_UINT),
1281 (u8::BITS, 3) => Some(vk::Format::R8G8B8_UINT),
1282 (u8::BITS, 4) => Some(vk::Format::R8G8B8A8_UINT),
1283 (u16::BITS, 2) => Some(vk::Format::R16G16_UINT),
1284 (u16::BITS, 3) => Some(vk::Format::R16G16B16_UINT),
1285 (u16::BITS, 4) => Some(vk::Format::R16G16B16A16_UINT),
1286 (u32::BITS, 2) => Some(vk::Format::R32G32_UINT),
1287 (u32::BITS, 3) => Some(vk::Format::R32G32B32_UINT),
1288 (u32::BITS, 4) => Some(vk::Format::R32G32B32A32_UINT),
1289 (u64::BITS, 2) => Some(vk::Format::R64G64_UINT),
1290 (u64::BITS, 3) => Some(vk::Format::R64G64B64_UINT),
1291 (u64::BITS, 4) => Some(vk::Format::R64G64B64A64_UINT),
1292 _ => None,
1293 },
1294 VectorType {
1295 scalar_ty:
1296 ScalarType::Integer {
1297 bits,
1298 is_signed: true,
1299 },
1300 nscalar,
1301 } => match (bits, nscalar) {
1302 (u8::BITS, 2) => Some(vk::Format::R8G8_SINT),
1303 (u8::BITS, 3) => Some(vk::Format::R8G8B8_SINT),
1304 (u8::BITS, 4) => Some(vk::Format::R8G8B8A8_SINT),
1305 (u16::BITS, 2) => Some(vk::Format::R16G16_SINT),
1306 (u16::BITS, 3) => Some(vk::Format::R16G16B16_SINT),
1307 (u16::BITS, 4) => Some(vk::Format::R16G16B16A16_SINT),
1308 (u32::BITS, 2) => Some(vk::Format::R32G32_SINT),
1309 (u32::BITS, 3) => Some(vk::Format::R32G32B32_SINT),
1310 (u32::BITS, 4) => Some(vk::Format::R32G32B32A32_SINT),
1311 (u64::BITS, 2) => Some(vk::Format::R64G64_SINT),
1312 (u64::BITS, 3) => Some(vk::Format::R64G64B64_SINT),
1313 (u64::BITS, 4) => Some(vk::Format::R64G64B64A64_SINT),
1314 _ => None,
1315 },
1316 _ => None,
1317 }
1318 }
1319
1320 let mut input_rates_strides = HashMap::new();
1321 let mut vertex_attribute_descriptions = vec![];
1322
1323 for (name, location, ty) in self.entry_point.vars.iter().filter_map(|var| match var {
1324 Variable::Input { name, location, ty } => Some((name, location, ty)),
1325 _ => None,
1326 }) {
1327 let (binding, guessed_rate) = name
1328 .as_ref()
1329 .filter(|name| name.contains("_ibind") || name.contains("_vbind"))
1330 .map(|name| {
1331 let binding = name[name.rfind("bind").expect("missing bind suffix")..]
1332 .parse()
1333 .unwrap_or_default();
1334 let rate = if name.contains("_ibind") {
1335 vk::VertexInputRate::INSTANCE
1336 } else {
1337 vk::VertexInputRate::VERTEX
1338 };
1339
1340 (binding, rate)
1341 })
1342 .unwrap_or_default();
1343 let (location, _) = location.into_inner();
1344 if let Some((input_rate, _)) = input_rates_strides.get(&binding) {
1345 assert_eq!(*input_rate, guessed_rate);
1346 }
1347
1348 let byte_stride = ty.nbyte().unwrap_or_default() as u32;
1349 let (input_rate, stride) = input_rates_strides.entry(binding).or_default();
1350 *input_rate = guessed_rate;
1351 *stride += byte_stride;
1352
1353 let format = match ty {
1356 Type::Scalar(ty) => scalar_format(ty),
1357 Type::Vector(ty) => vector_format(ty),
1358 _ => None,
1359 }
1360 .ok_or_else(|| {
1361 warn!("unsupported reflected vertex input type: {ty:?}");
1362
1363 DriverError::Unsupported
1364 })?;
1365
1366 vertex_attribute_descriptions.push(vk::VertexInputAttributeDescription {
1367 location,
1368 binding,
1369 format,
1370 offset: byte_stride, });
1372 }
1373
1374 vertex_attribute_descriptions.sort_unstable_by(|lhs, rhs| {
1375 let binding = lhs.binding.cmp(&rhs.binding);
1376 if binding.is_lt() {
1377 return binding;
1378 }
1379
1380 lhs.location.cmp(&rhs.location)
1381 });
1382
1383 let mut offset = 0;
1384 let mut offset_binding = 0;
1385
1386 for vertex_attribute_description in &mut vertex_attribute_descriptions {
1387 if vertex_attribute_description.binding != offset_binding {
1388 offset_binding = vertex_attribute_description.binding;
1389 offset = 0;
1390 }
1391
1392 let stride = vertex_attribute_description.offset;
1393 vertex_attribute_description.offset = offset;
1394 offset += stride;
1395
1396 debug!(
1397 "vertex attribute {}.{}: {:?} (offset={})",
1398 vertex_attribute_description.binding,
1399 vertex_attribute_description.location,
1400 vertex_attribute_description.format,
1401 vertex_attribute_description.offset,
1402 );
1403 }
1404
1405 let mut vertex_binding_descriptions = vec![];
1406 for (binding, (input_rate, stride)) in input_rates_strides.into_iter() {
1407 vertex_binding_descriptions.push(vk::VertexInputBindingDescription {
1408 binding,
1409 input_rate,
1410 stride,
1411 });
1412 }
1413
1414 Ok(VertexInputState {
1415 vertex_attribute_descriptions,
1416 vertex_binding_descriptions,
1417 })
1418 }
1419
1420 #[profiling::function]
1421 pub(super) fn vertex_input(&self) -> VertexInputState {
1422 self.try_vertex_input()
1423 .expect("unsupported reflected vertex input layout")
1424 }
1425}
1426
1427impl Debug for Shader {
1428 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1429 f.write_str("Shader")
1432 }
1433}
1434
1435impl From<ShaderBuilder> for Shader {
1436 fn from(shader: ShaderBuilder) -> Self {
1437 shader.build()
1438 }
1439}
1440
1441impl<T> From<T> for Shader
1442where
1443 T: Into<SpirvBinary>,
1444{
1445 fn from(spirv: T) -> Self {
1446 Shader::from_spirv(spirv).build()
1447 }
1448}
1449
1450impl ShaderBuilder {
1452 pub fn new(stage: vk::ShaderStageFlags, spirv: Vec<u8>) -> Self {
1454 Self::default().stage(stage).spirv(spirv)
1455 }
1456
1457 pub fn build(self) -> Shader {
1459 let entry_name = self.entry_name.clone().unwrap_or_else(|| "main".to_owned());
1460
1461 self.try_build().unwrap_or_else(|_| {
1462 panic!("invalid or unsupported shader code for entry name '{entry_name}'")
1463 })
1464 }
1465
1466 #[profiling::function]
1486 pub fn image_sampler(
1487 mut self,
1488 descriptor: impl Into<Descriptor>,
1489 info: impl Into<SamplerInfo>,
1490 ) -> Self {
1491 let descriptor = descriptor.into();
1492 let info = info.into();
1493
1494 if self.image_samplers.is_none() {
1495 self.image_samplers = Some(Default::default());
1496 }
1497
1498 self.image_samplers
1499 .as_mut()
1500 .expect("missing image samplers")
1501 .insert(descriptor, info);
1502
1503 self
1504 }
1505
1506 pub fn try_build(mut self) -> Result<Shader, DriverError> {
1508 let entry_name = self.entry_name.as_deref().unwrap_or("main");
1509 let entry_point = Shader::reflect_entry_point(
1510 entry_name,
1511 self.spirv
1512 .as_ref()
1513 .map(|spirv| spirv.words())
1514 .expect("missing spirv code"),
1515 self.specialization
1516 .as_ref()
1517 .map(|opt| opt.as_ref())
1518 .unwrap_or_default(),
1519 )
1520 .map_err(|err| {
1521 warn!("invalid shader reflection entry point: {err}");
1522
1523 DriverError::InvalidData
1524 })?;
1525
1526 if self.stage.unwrap_or_default().is_empty() {
1527 self.stage = Some(match entry_point.exec_model {
1528 ExecutionModel::Vertex => vk::ShaderStageFlags::VERTEX,
1529 ExecutionModel::TessellationControl => vk::ShaderStageFlags::TESSELLATION_CONTROL,
1530 ExecutionModel::TessellationEvaluation => {
1531 vk::ShaderStageFlags::TESSELLATION_EVALUATION
1532 }
1533 ExecutionModel::Geometry => vk::ShaderStageFlags::GEOMETRY,
1534 ExecutionModel::Fragment => vk::ShaderStageFlags::FRAGMENT,
1535 ExecutionModel::GLCompute => vk::ShaderStageFlags::COMPUTE,
1536 ExecutionModel::Kernel => {
1537 warn!("unsupported shader execution model: kernel");
1538
1539 return Err(DriverError::Unsupported);
1540 }
1541 ExecutionModel::TaskNV => vk::ShaderStageFlags::TASK_EXT,
1542 ExecutionModel::MeshNV => vk::ShaderStageFlags::MESH_EXT,
1543 ExecutionModel::RayGenerationNV => vk::ShaderStageFlags::RAYGEN_KHR,
1544 ExecutionModel::IntersectionNV => vk::ShaderStageFlags::INTERSECTION_KHR,
1545 ExecutionModel::AnyHitNV => vk::ShaderStageFlags::ANY_HIT_KHR,
1546 ExecutionModel::ClosestHitNV => vk::ShaderStageFlags::CLOSEST_HIT_KHR,
1547 ExecutionModel::MissNV => vk::ShaderStageFlags::MISS_KHR,
1548 ExecutionModel::CallableNV => vk::ShaderStageFlags::CALLABLE_KHR,
1549 ExecutionModel::TaskEXT => vk::ShaderStageFlags::TASK_EXT,
1550 ExecutionModel::MeshEXT => vk::ShaderStageFlags::MESH_EXT,
1551 })
1552 }
1553
1554 self.entry_point = Some(entry_point);
1555
1556 self.fallible_build().map_err(|err| {
1557 warn!("invalid shader builder state: {err:?}");
1558
1559 DriverError::InvalidData
1560 })
1561 }
1562
1563 #[profiling::function]
1575 pub fn vertex_input(
1576 mut self,
1577 bindings: impl Into<Vec<vk::VertexInputBindingDescription>>,
1578 attributes: impl Into<Vec<vk::VertexInputAttributeDescription>>,
1579 ) -> Self {
1580 self.vertex_input_state = Some(Some(VertexInputState {
1581 vertex_binding_descriptions: bindings.into(),
1582 vertex_attribute_descriptions: attributes.into(),
1583 }));
1584 self
1585 }
1586}
1587
1588#[derive(Debug)]
1589struct ShaderBuilderError;
1590
1591impl From<UninitializedFieldError> for ShaderBuilderError {
1592 fn from(_: UninitializedFieldError) -> Self {
1593 Self
1594 }
1595}
1596
1597#[derive(Clone, Debug, Default)]
1599pub struct SpecializationMap {
1600 pub data: Vec<u8>,
1602
1603 pub entries: Vec<vk::SpecializationMapEntry>,
1606}
1607
1608impl SpecializationMap {
1609 pub fn new(data: impl Into<Vec<u8>>) -> Self {
1611 Self {
1612 data: data.into(),
1613 entries: Default::default(),
1614 }
1615 }
1616
1617 pub fn constant(mut self, constant_id: u32, offset: u32, size: usize) -> Self {
1619 self.set_constant(constant_id, offset, size);
1620 self
1621 }
1622
1623 pub fn set_constant(&mut self, constant_id: u32, offset: u32, size: usize) {
1625 self.entries.push(vk::SpecializationMapEntry {
1626 constant_id,
1627 offset,
1628 size,
1629 });
1630 }
1631}
1632
1633impl<'a> From<&'a SpecializationMap> for vk::SpecializationInfo<'a> {
1634 fn from(value: &'a SpecializationMap) -> Self {
1635 vk::SpecializationInfo::default()
1636 .map_entries(&value.entries)
1637 .data(&value.data)
1638 }
1639}
1640
1641mod deprecated {
1642 use {
1643 crate::driver::shader::{ShaderBuilder, SpecializationMap},
1644 ash::vk,
1645 };
1646
1647 #[derive(Clone, Debug)]
1648 pub struct SpecializationInfo {
1649 pub data: Vec<u8>,
1650 pub map_entries: Vec<vk::SpecializationMapEntry>,
1651 }
1652
1653 impl SpecializationInfo {
1654 pub fn new(
1655 map_entries: impl Into<Vec<vk::SpecializationMapEntry>>,
1656 data: impl Into<Vec<u8>>,
1657 ) -> Self {
1658 Self {
1659 data: data.into(),
1660 map_entries: map_entries.into(),
1661 }
1662 }
1663 }
1664
1665 impl ShaderBuilder {
1666 #[deprecated = "use specialization function"]
1667 #[doc(hidden)]
1668 pub fn specialization_info(self, info: SpecializationInfo) -> Self {
1669 let mut specialization = SpecializationMap::new(info.data);
1670
1671 for entry in &info.map_entries {
1672 specialization.set_constant(entry.constant_id, entry.offset, entry.size);
1673 }
1674
1675 self.specialization(specialization)
1676 }
1677 }
1678}
1679
#[cfg(test)]
mod test {
    use super::*;

    type Info = SamplerInfo;
    type Builder = SamplerInfoBuilder;

    /// A default `SamplerInfo` converted to a builder and rebuilt must be unchanged.
    #[test]
    pub fn sampler_info() {
        let original = Info::default();
        let rebuilt = original.into_builder().build();

        assert_eq!(original, rebuilt);
    }

    /// A default builder must produce the same value as `SamplerInfo::default()`.
    #[test]
    pub fn sampler_info_builder() {
        let from_default = Info::default();
        let from_builder = Builder::default().build();

        assert_eq!(from_default, from_builder);
    }
}