1use {
4 super::{DescriptorSetLayout, DriverError, VertexInputState, device::Device},
5 ash::vk,
6 derive_builder::{Builder, UninitializedFieldError},
7 log::{debug, error, trace, warn},
8 ordered_float::OrderedFloat,
9 spirq::{
10 ReflectConfig,
11 entry_point::EntryPoint,
12 ty::{DescriptorType, ScalarType, SpirvType, Type},
13 var::Variable,
14 },
15 std::{
16 collections::{BTreeMap, HashMap},
17 fmt::{Debug, Formatter},
18 iter::repeat,
19 mem::size_of_val,
20 ops::Deref,
21 sync::Arc,
22 thread::panicking,
23 },
24};
25
/// Descriptor bindings keyed by set/binding, paired with the shader stages that use each one.
///
/// A single shader inserts its own stage (`Shader::descriptor_bindings`); when several stages are
/// merged, the stage flags of a shared binding are OR'ed together.
pub(crate) type DescriptorBindingMap = HashMap<Descriptor, (DescriptorInfo, vk::ShaderStageFlags)>;
27
28pub(crate) fn align_spriv(code: &[u8]) -> Result<&[u32], DriverError> {
29 let (prefix, code, suffix) = unsafe { code.align_to() };
30
31 if prefix.len() + suffix.len() == 0 {
32 Ok(code)
33 } else {
34 warn!("Invalid SPIR-V code");
35
36 Err(DriverError::InvalidData)
37 }
38}
39
40#[profiling::function]
41fn guess_immutable_sampler(binding_name: &str) -> SamplerInfo {
42 const INVALID_ERR: &str = "Invalid sampler specification";
43
44 let (texel_filter, mipmap_mode, address_modes) = if binding_name.contains("_sampler_") {
45 let spec = &binding_name[binding_name.len() - 3..];
46 let texel_filter = match &spec[0..1] {
47 "n" => vk::Filter::NEAREST,
48 "l" => vk::Filter::LINEAR,
49 _ => panic!("{INVALID_ERR}: {}", &spec[0..1]),
50 };
51
52 let mipmap_mode = match &spec[1..2] {
53 "n" => vk::SamplerMipmapMode::NEAREST,
54 "l" => vk::SamplerMipmapMode::LINEAR,
55 _ => panic!("{INVALID_ERR}: {}", &spec[1..2]),
56 };
57
58 let address_modes = match &spec[2..3] {
59 "b" => vk::SamplerAddressMode::CLAMP_TO_BORDER,
60 "e" => vk::SamplerAddressMode::CLAMP_TO_EDGE,
61 "m" => vk::SamplerAddressMode::MIRRORED_REPEAT,
62 "r" => vk::SamplerAddressMode::REPEAT,
63 _ => panic!("{INVALID_ERR}: {}", &spec[2..3]),
64 };
65
66 (texel_filter, mipmap_mode, address_modes)
67 } else {
68 debug!("image binding {binding_name} using default sampler");
69
70 (
71 vk::Filter::LINEAR,
72 vk::SamplerMipmapMode::LINEAR,
73 vk::SamplerAddressMode::REPEAT,
74 )
75 };
76 let anisotropy_enable = texel_filter == vk::Filter::LINEAR;
77 let mut info = SamplerInfoBuilder::default()
78 .mag_filter(texel_filter)
79 .min_filter(texel_filter)
80 .mipmap_mode(mipmap_mode)
81 .address_mode_u(address_modes)
82 .address_mode_v(address_modes)
83 .address_mode_w(address_modes)
84 .max_lod(vk::LOD_CLAMP_NONE)
85 .anisotropy_enable(anisotropy_enable);
86
87 if anisotropy_enable {
88 info = info.max_anisotropy(16.0);
89 }
90
91 info.build()
92}
93
/// Identifies a descriptor binding by descriptor set index and binding index.
///
/// Ordering compares `set` first and then `binding`.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Descriptor {
    /// Descriptor set index.
    pub set: u32,

    /// Binding index within `set`.
    pub binding: u32,
}

impl From<u32> for Descriptor {
    /// Interprets a bare `u32` as a binding index within descriptor set `0`.
    fn from(binding: u32) -> Self {
        Descriptor::from((0u32, binding))
    }
}

impl From<(u32, u32)> for Descriptor {
    /// Converts a `(set, binding)` pair.
    fn from(pair: (u32, u32)) -> Self {
        let (set, binding) = pair;

        Descriptor { set, binding }
    }
}
118
119#[derive(Clone, Copy, Debug)]
120pub(crate) enum DescriptorInfo {
121 AccelerationStructure(u32),
122 CombinedImageSampler(u32, SamplerInfo, bool), InputAttachment(u32, u32), SampledImage(u32),
125 Sampler(u32, SamplerInfo, bool), StorageBuffer(u32),
127 StorageImage(u32),
128 StorageTexelBuffer(u32),
129 UniformBuffer(u32),
130 UniformTexelBuffer(u32),
131}
132
133impl DescriptorInfo {
134 pub fn binding_count(self) -> u32 {
135 match self {
136 Self::AccelerationStructure(binding_count) => binding_count,
137 Self::CombinedImageSampler(binding_count, ..) => binding_count,
138 Self::InputAttachment(binding_count, _) => binding_count,
139 Self::SampledImage(binding_count) => binding_count,
140 Self::Sampler(binding_count, ..) => binding_count,
141 Self::StorageBuffer(binding_count) => binding_count,
142 Self::StorageImage(binding_count) => binding_count,
143 Self::StorageTexelBuffer(binding_count) => binding_count,
144 Self::UniformBuffer(binding_count) => binding_count,
145 Self::UniformTexelBuffer(binding_count) => binding_count,
146 }
147 }
148
149 pub fn descriptor_type(self) -> vk::DescriptorType {
150 match self {
151 Self::AccelerationStructure(_) => vk::DescriptorType::ACCELERATION_STRUCTURE_KHR,
152 Self::CombinedImageSampler(..) => vk::DescriptorType::COMBINED_IMAGE_SAMPLER,
153 Self::InputAttachment(..) => vk::DescriptorType::INPUT_ATTACHMENT,
154 Self::SampledImage(_) => vk::DescriptorType::SAMPLED_IMAGE,
155 Self::Sampler(..) => vk::DescriptorType::SAMPLER,
156 Self::StorageBuffer(_) => vk::DescriptorType::STORAGE_BUFFER,
157 Self::StorageImage(_) => vk::DescriptorType::STORAGE_IMAGE,
158 Self::StorageTexelBuffer(_) => vk::DescriptorType::STORAGE_TEXEL_BUFFER,
159 Self::UniformBuffer(_) => vk::DescriptorType::UNIFORM_BUFFER,
160 Self::UniformTexelBuffer(_) => vk::DescriptorType::UNIFORM_TEXEL_BUFFER,
161 }
162 }
163
164 fn sampler_info(self) -> Option<SamplerInfo> {
165 match self {
166 Self::CombinedImageSampler(_, sampler_info, _) | Self::Sampler(_, sampler_info, _) => {
167 Some(sampler_info)
168 }
169 _ => None,
170 }
171 }
172
173 pub fn set_binding_count(&mut self, binding_count: u32) {
174 *match self {
175 Self::AccelerationStructure(binding_count) => binding_count,
176 Self::CombinedImageSampler(binding_count, ..) => binding_count,
177 Self::InputAttachment(binding_count, _) => binding_count,
178 Self::SampledImage(binding_count) => binding_count,
179 Self::Sampler(binding_count, ..) => binding_count,
180 Self::StorageBuffer(binding_count) => binding_count,
181 Self::StorageImage(binding_count) => binding_count,
182 Self::StorageTexelBuffer(binding_count) => binding_count,
183 Self::UniformBuffer(binding_count) => binding_count,
184 Self::UniformTexelBuffer(binding_count) => binding_count,
185 } = binding_count;
186 }
187}
188
/// Descriptor set layouts and descriptor pool sizing derived from a [`DescriptorBindingMap`].
#[derive(Debug)]
pub(crate) struct PipelineDescriptorInfo {
    /// One layout per descriptor set index, from `0` up to the highest set used (empty sets
    /// included).
    pub layouts: BTreeMap<u32, DescriptorSetLayout>,
    /// Per descriptor set index: the total descriptor count needed for each descriptor type.
    pub pool_sizes: HashMap<u32, HashMap<vk::DescriptorType, u32>>,

    // Owns the samplers whose handles were baked into `layouts` as immutable samplers. Never
    // read; it is kept only so the samplers are not destroyed (`Sampler` destroys its handle on
    // drop) while this value is alive.
    #[allow(dead_code)]
    samplers: Box<[Sampler]>,
}
197
198impl PipelineDescriptorInfo {
199 #[profiling::function]
200 pub fn create(
201 device: &Arc<Device>,
202 descriptor_bindings: &DescriptorBindingMap,
203 ) -> Result<Self, DriverError> {
204 let descriptor_set_count = descriptor_bindings
205 .keys()
206 .map(|descriptor| descriptor.set)
207 .max()
208 .map(|set| set + 1)
209 .unwrap_or_default();
210 let mut layouts = BTreeMap::new();
211 let mut pool_sizes = HashMap::new();
212
213 let mut sampler_info_binding_count = HashMap::<_, u32>::with_capacity(
216 descriptor_bindings
217 .values()
218 .filter(|(descriptor_info, _)| descriptor_info.sampler_info().is_some())
219 .count(),
220 );
221
222 for (sampler_info, binding_count) in
223 descriptor_bindings
224 .values()
225 .filter_map(|(descriptor_info, _)| {
226 descriptor_info
227 .sampler_info()
228 .map(|sampler_info| (sampler_info, descriptor_info.binding_count()))
229 })
230 {
231 sampler_info_binding_count
232 .entry(sampler_info)
233 .and_modify(|sampler_info_binding_count| {
234 *sampler_info_binding_count = binding_count.max(*sampler_info_binding_count);
235 })
236 .or_insert(binding_count);
237 }
238
239 let mut samplers = sampler_info_binding_count
240 .keys()
241 .copied()
242 .map(|sampler_info| {
243 Sampler::create(device, sampler_info).map(|sampler| (sampler_info, sampler))
244 })
245 .collect::<Result<HashMap<_, _>, _>>()?;
246 let immutable_samplers = sampler_info_binding_count
247 .iter()
248 .map(|(sampler_info, &binding_count)| {
249 (
250 *sampler_info,
251 repeat(*samplers[sampler_info])
252 .take(binding_count as _)
253 .collect::<Box<_>>(),
254 )
255 })
256 .collect::<HashMap<_, _>>();
257
258 for descriptor_set_idx in 0..descriptor_set_count {
259 let mut binding_counts = HashMap::<vk::DescriptorType, u32>::new();
260 let mut bindings = vec![];
261
262 for (descriptor, (descriptor_info, stage_flags)) in descriptor_bindings
263 .iter()
264 .filter(|(descriptor, _)| descriptor.set == descriptor_set_idx)
265 {
266 let descriptor_ty = descriptor_info.descriptor_type();
267 *binding_counts.entry(descriptor_ty).or_default() +=
268 descriptor_info.binding_count();
269 let mut binding = vk::DescriptorSetLayoutBinding::default()
270 .binding(descriptor.binding)
271 .descriptor_count(descriptor_info.binding_count())
272 .descriptor_type(descriptor_ty)
273 .stage_flags(*stage_flags);
274
275 if let Some(immutable_samplers) =
276 descriptor_info.sampler_info().map(|sampler_info| {
277 &immutable_samplers[&sampler_info]
278 [0..descriptor_info.binding_count() as usize]
279 })
280 {
281 binding = binding.immutable_samplers(immutable_samplers);
282 }
283
284 bindings.push(binding);
285 }
286
287 let pool_size = pool_sizes
288 .entry(descriptor_set_idx)
289 .or_insert_with(HashMap::new);
290
291 for (descriptor_ty, binding_count) in binding_counts.into_iter() {
292 *pool_size.entry(descriptor_ty).or_default() += binding_count;
293 }
294
295 let mut create_info = vk::DescriptorSetLayoutCreateInfo::default().bindings(&bindings);
298
299 let bindless_flags = vec![vk::DescriptorBindingFlags::PARTIALLY_BOUND; bindings.len()];
303 let mut bindless_flags = if device
304 .physical_device
305 .features_v1_2
306 .descriptor_binding_partially_bound
307 {
308 let bindless_flags = vk::DescriptorSetLayoutBindingFlagsCreateInfo::default()
309 .binding_flags(&bindless_flags);
310 Some(bindless_flags)
311 } else {
312 None
313 };
314
315 if let Some(bindless_flags) = bindless_flags.as_mut() {
316 create_info = create_info.push_next(bindless_flags);
317 }
318
319 layouts.insert(
320 descriptor_set_idx,
321 DescriptorSetLayout::create(device, &create_info)?,
322 );
323 }
324
325 let samplers = samplers
326 .drain()
327 .map(|(_, sampler)| sampler)
328 .collect::<Box<_>>();
329
330 Ok(Self {
334 layouts,
335 pool_sizes,
336 samplers,
337 })
338 }
339}
340
341pub(crate) struct Sampler {
342 device: Arc<Device>,
343 sampler: vk::Sampler,
344}
345
346impl Sampler {
347 #[profiling::function]
348 pub fn create(device: &Arc<Device>, info: impl Into<SamplerInfo>) -> Result<Self, DriverError> {
349 let device = Arc::clone(device);
350 let info = info.into();
351
352 let sampler = unsafe {
353 device
354 .create_sampler(
355 &vk::SamplerCreateInfo::default()
356 .flags(info.flags)
357 .mag_filter(info.mag_filter)
358 .min_filter(info.min_filter)
359 .mipmap_mode(info.mipmap_mode)
360 .address_mode_u(info.address_mode_u)
361 .address_mode_v(info.address_mode_v)
362 .address_mode_w(info.address_mode_w)
363 .mip_lod_bias(info.mip_lod_bias.0)
364 .anisotropy_enable(info.anisotropy_enable)
365 .max_anisotropy(info.max_anisotropy.0)
366 .compare_enable(info.compare_enable)
367 .compare_op(info.compare_op)
368 .min_lod(info.min_lod.0)
369 .max_lod(info.max_lod.0)
370 .border_color(info.border_color)
371 .unnormalized_coordinates(info.unnormalized_coordinates)
372 .push_next(
373 &mut vk::SamplerReductionModeCreateInfo::default()
374 .reduction_mode(info.reduction_mode),
375 ),
376 None,
377 )
378 .map_err(|err| {
379 warn!("{err}");
380
381 match err {
382 vk::Result::ERROR_OUT_OF_HOST_MEMORY
383 | vk::Result::ERROR_OUT_OF_DEVICE_MEMORY => DriverError::OutOfMemory,
384 _ => DriverError::Unsupported,
385 }
386 })?
387 };
388
389 Ok(Self { device, sampler })
390 }
391}
392
393impl Debug for Sampler {
394 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
395 write!(f, "{:?}", self.sampler)
396 }
397}
398
399impl Deref for Sampler {
400 type Target = vk::Sampler;
401
402 fn deref(&self) -> &Self::Target {
403 &self.sampler
404 }
405}
406
407impl Drop for Sampler {
408 #[profiling::function]
409 fn drop(&mut self) {
410 if panicking() {
411 return;
412 }
413
414 unsafe {
415 self.device.destroy_sampler(self.sampler, None);
416 }
417 }
418}
419
/// Parameters used to create a [`Sampler`].
///
/// Each field is forwarded to the matching `vk::SamplerCreateInfo` field when the sampler is
/// created; `reduction_mode` is chained through `vk::SamplerReductionModeCreateInfo`.
/// The type is `Eq + Hash` so that bindings with identical configurations can share one
/// sampler object.
#[derive(Builder, Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[builder(
    build_fn(private, name = "fallible_build", error = "SamplerInfoBuilderError"),
    derive(Clone, Copy, Debug),
    pattern = "owned"
)]
#[non_exhaustive]
pub struct SamplerInfo {
    /// Forwarded to `vk::SamplerCreateInfo::flags`.
    #[builder(default)]
    pub flags: vk::SamplerCreateFlags,

    /// Forwarded to `vk::SamplerCreateInfo::mag_filter`.
    #[builder(default)]
    pub mag_filter: vk::Filter,

    /// Forwarded to `vk::SamplerCreateInfo::min_filter`.
    #[builder(default)]
    pub min_filter: vk::Filter,

    /// Forwarded to `vk::SamplerCreateInfo::mipmap_mode`.
    #[builder(default)]
    pub mipmap_mode: vk::SamplerMipmapMode,

    /// Forwarded to `vk::SamplerCreateInfo::address_mode_u`.
    #[builder(default)]
    pub address_mode_u: vk::SamplerAddressMode,

    /// Forwarded to `vk::SamplerCreateInfo::address_mode_v`.
    #[builder(default)]
    pub address_mode_v: vk::SamplerAddressMode,

    /// Forwarded to `vk::SamplerCreateInfo::address_mode_w`.
    #[builder(default)]
    pub address_mode_w: vk::SamplerAddressMode,

    /// Forwarded to `vk::SamplerCreateInfo::mip_lod_bias`; stored as `OrderedFloat` so the
    /// struct can implement `Eq` and `Hash`.
    #[builder(default, setter(into))]
    pub mip_lod_bias: OrderedFloat<f32>,

    /// Forwarded to `vk::SamplerCreateInfo::anisotropy_enable`.
    #[builder(default)]
    pub anisotropy_enable: bool,

    /// Forwarded to `vk::SamplerCreateInfo::max_anisotropy`; stored as `OrderedFloat` so the
    /// struct can implement `Eq` and `Hash`.
    #[builder(default, setter(into))]
    pub max_anisotropy: OrderedFloat<f32>,

    /// Forwarded to `vk::SamplerCreateInfo::compare_enable`.
    #[builder(default)]
    pub compare_enable: bool,

    /// Forwarded to `vk::SamplerCreateInfo::compare_op`.
    #[builder(default)]
    pub compare_op: vk::CompareOp,

    /// Forwarded to `vk::SamplerCreateInfo::min_lod`.
    #[builder(default, setter(into))]
    pub min_lod: OrderedFloat<f32>,

    /// Forwarded to `vk::SamplerCreateInfo::max_lod`. Note that the default is `0.0`;
    /// `vk::LOD_CLAMP_NONE` is used by the name-based sampler guess.
    #[builder(default, setter(into))]
    pub max_lod: OrderedFloat<f32>,

    /// Forwarded to `vk::SamplerCreateInfo::border_color`.
    #[builder(default)]
    pub border_color: vk::BorderColor,

    /// Forwarded to `vk::SamplerCreateInfo::unnormalized_coordinates`.
    #[builder(default)]
    pub unnormalized_coordinates: bool,

    /// Forwarded through a chained `vk::SamplerReductionModeCreateInfo`.
    #[builder(default)]
    pub reduction_mode: vk::SamplerReductionMode,
}
542
543impl SamplerInfo {
544 pub const LINEAR: SamplerInfoBuilder = SamplerInfoBuilder {
546 flags: None,
547 mag_filter: Some(vk::Filter::LINEAR),
548 min_filter: Some(vk::Filter::LINEAR),
549 mipmap_mode: Some(vk::SamplerMipmapMode::LINEAR),
550 address_mode_u: None,
551 address_mode_v: None,
552 address_mode_w: None,
553 mip_lod_bias: None,
554 anisotropy_enable: None,
555 max_anisotropy: None,
556 compare_enable: None,
557 compare_op: None,
558 min_lod: None,
559 max_lod: None,
560 border_color: None,
561 unnormalized_coordinates: None,
562 reduction_mode: None,
563 };
564
565 pub const NEAREST: SamplerInfoBuilder = SamplerInfoBuilder {
568 flags: None,
569 mag_filter: Some(vk::Filter::NEAREST),
570 min_filter: Some(vk::Filter::NEAREST),
571 mipmap_mode: Some(vk::SamplerMipmapMode::NEAREST),
572 address_mode_u: None,
573 address_mode_v: None,
574 address_mode_w: None,
575 mip_lod_bias: None,
576 anisotropy_enable: None,
577 max_anisotropy: None,
578 compare_enable: None,
579 compare_op: None,
580 min_lod: None,
581 max_lod: None,
582 border_color: None,
583 unnormalized_coordinates: None,
584 reduction_mode: None,
585 };
586
587 #[allow(clippy::new_ret_no_self)]
589 #[deprecated = "Use SamplerInfo::default()"]
590 #[doc(hidden)]
591 pub fn new() -> SamplerInfoBuilder {
592 Self::default().to_builder()
593 }
594
595 #[inline(always)]
597 pub fn to_builder(self) -> SamplerInfoBuilder {
598 SamplerInfoBuilder {
599 flags: Some(self.flags),
600 mag_filter: Some(self.mag_filter),
601 min_filter: Some(self.min_filter),
602 mipmap_mode: Some(self.mipmap_mode),
603 address_mode_u: Some(self.address_mode_u),
604 address_mode_v: Some(self.address_mode_v),
605 address_mode_w: Some(self.address_mode_w),
606 mip_lod_bias: Some(self.mip_lod_bias),
607 anisotropy_enable: Some(self.anisotropy_enable),
608 max_anisotropy: Some(self.max_anisotropy),
609 compare_enable: Some(self.compare_enable),
610 compare_op: Some(self.compare_op),
611 min_lod: Some(self.min_lod),
612 max_lod: Some(self.max_lod),
613 border_color: Some(self.border_color),
614 unnormalized_coordinates: Some(self.unnormalized_coordinates),
615 reduction_mode: Some(self.reduction_mode),
616 }
617 }
618}
619
620impl Default for SamplerInfo {
621 fn default() -> Self {
622 Self {
623 flags: vk::SamplerCreateFlags::empty(),
624 mag_filter: vk::Filter::NEAREST,
625 min_filter: vk::Filter::NEAREST,
626 mipmap_mode: vk::SamplerMipmapMode::NEAREST,
627 address_mode_u: vk::SamplerAddressMode::REPEAT,
628 address_mode_v: vk::SamplerAddressMode::REPEAT,
629 address_mode_w: vk::SamplerAddressMode::REPEAT,
630 mip_lod_bias: OrderedFloat(0.0),
631 anisotropy_enable: false,
632 max_anisotropy: OrderedFloat(0.0),
633 compare_enable: false,
634 compare_op: vk::CompareOp::NEVER,
635 min_lod: OrderedFloat(0.0),
636 max_lod: OrderedFloat(0.0),
637 border_color: vk::BorderColor::FLOAT_TRANSPARENT_BLACK,
638 unnormalized_coordinates: false,
639 reduction_mode: vk::SamplerReductionMode::WEIGHTED_AVERAGE,
640 }
641 }
642}
643
644impl SamplerInfoBuilder {
645 #[inline(always)]
647 pub fn build(self) -> SamplerInfo {
648 let res = self.fallible_build();
649
650 #[cfg(test)]
651 let res = res.unwrap();
652
653 #[cfg(not(test))]
654 let res = unsafe { res.unwrap_unchecked() };
655
656 res
657 }
658}
659
660impl From<SamplerInfoBuilder> for SamplerInfo {
661 fn from(info: SamplerInfoBuilder) -> Self {
662 info.build()
663 }
664}
665
666#[derive(Debug)]
667struct SamplerInfoBuilderError;
668
669impl From<UninitializedFieldError> for SamplerInfoBuilderError {
670 fn from(_: UninitializedFieldError) -> Self {
671 Self
672 }
673}
674
/// A single SPIR-V shader stage together with the reflection data of its entry point.
///
/// Construct it through [`ShaderBuilder`] (e.g. `Shader::new_vertex(code).build()`); building
/// reflects the SPIR-V and panics if the code or entry point is invalid.
#[allow(missing_docs)]
#[derive(Builder, Clone)]
#[builder(
    build_fn(private, name = "fallible_build", error = "ShaderBuilderError"),
    derive(Clone, Debug),
    pattern = "owned"
)]
pub struct Shader {
    /// Name of the entry point function to reflect; defaults to `"main"`.
    #[builder(default = "\"main\".to_owned()")]
    pub entry_name: String,

    /// Optional specialization constants applied while reflecting the entry point.
    #[builder(default, setter(strip_option))]
    pub specialization_info: Option<SpecializationInfo>,

    /// SPIR-V code as raw bytes.
    pub spirv: Vec<u8>,

    /// Pipeline stage this shader is used for.
    pub stage: vk::ShaderStageFlags,

    // Reflected entry point; filled in by `ShaderBuilder::build`.
    #[builder(private)]
    entry_point: EntryPoint,

    // Samplers assigned explicitly via `ShaderBuilder::image_sampler`, keyed by descriptor.
    #[builder(default, private)]
    image_samplers: HashMap<Descriptor, SamplerInfo>,

    // Explicit vertex input layout from `ShaderBuilder::vertex_input`; when unset it is
    // inferred from the entry point's inputs.
    #[builder(default, private, setter(strip_option))]
    vertex_input_state: Option<VertexInputState>,
}
755
756impl Shader {
757 #[allow(clippy::new_ret_no_self)]
759 pub fn new(stage: vk::ShaderStageFlags, spirv: impl ShaderCode) -> ShaderBuilder {
760 ShaderBuilder::default()
761 .spirv(spirv.into_vec())
762 .stage(stage)
763 }
764
765 pub fn new_any_hit(spirv: impl ShaderCode) -> ShaderBuilder {
771 Self::new(vk::ShaderStageFlags::ANY_HIT_KHR, spirv)
772 }
773
774 pub fn new_callable(spirv: impl ShaderCode) -> ShaderBuilder {
780 Self::new(vk::ShaderStageFlags::CALLABLE_KHR, spirv)
781 }
782
783 pub fn new_closest_hit(spirv: impl ShaderCode) -> ShaderBuilder {
789 Self::new(vk::ShaderStageFlags::CLOSEST_HIT_KHR, spirv)
790 }
791
792 pub fn new_compute(spirv: impl ShaderCode) -> ShaderBuilder {
798 Self::new(vk::ShaderStageFlags::COMPUTE, spirv)
799 }
800
801 pub fn new_fragment(spirv: impl ShaderCode) -> ShaderBuilder {
807 Self::new(vk::ShaderStageFlags::FRAGMENT, spirv)
808 }
809
810 pub fn new_geometry(spirv: impl ShaderCode) -> ShaderBuilder {
816 Self::new(vk::ShaderStageFlags::GEOMETRY, spirv)
817 }
818
819 pub fn new_intersection(spirv: impl ShaderCode) -> ShaderBuilder {
825 Self::new(vk::ShaderStageFlags::INTERSECTION_KHR, spirv)
826 }
827
828 pub fn new_mesh(spirv: impl ShaderCode) -> ShaderBuilder {
834 Self::new(vk::ShaderStageFlags::MESH_EXT, spirv)
835 }
836
837 pub fn new_miss(spirv: impl ShaderCode) -> ShaderBuilder {
843 Self::new(vk::ShaderStageFlags::MISS_KHR, spirv)
844 }
845
846 pub fn new_ray_gen(spirv: impl ShaderCode) -> ShaderBuilder {
852 Self::new(vk::ShaderStageFlags::RAYGEN_KHR, spirv)
853 }
854
855 pub fn new_task(spirv: impl ShaderCode) -> ShaderBuilder {
861 Self::new(vk::ShaderStageFlags::TASK_EXT, spirv)
862 }
863
864 pub fn new_tesselation_ctrl(spirv: impl ShaderCode) -> ShaderBuilder {
870 Self::new(vk::ShaderStageFlags::TESSELLATION_CONTROL, spirv)
871 }
872
873 pub fn new_tesselation_eval(spirv: impl ShaderCode) -> ShaderBuilder {
879 Self::new(vk::ShaderStageFlags::TESSELLATION_EVALUATION, spirv)
880 }
881
882 pub fn new_vertex(spirv: impl ShaderCode) -> ShaderBuilder {
888 Self::new(vk::ShaderStageFlags::VERTEX, spirv)
889 }
890
891 #[profiling::function]
893 pub(super) fn attachments(
894 &self,
895 ) -> (
896 impl Iterator<Item = u32> + '_,
897 impl Iterator<Item = u32> + '_,
898 ) {
899 (
900 self.entry_point.vars.iter().filter_map(|var| match var {
901 Variable::Descriptor {
902 desc_ty: DescriptorType::InputAttachment(attachment),
903 ..
904 } => Some(*attachment),
905 _ => None,
906 }),
907 self.entry_point.vars.iter().filter_map(|var| match var {
908 Variable::Output { location, .. } => Some(location.loc()),
909 _ => None,
910 }),
911 )
912 }
913
914 #[profiling::function]
915 pub(super) fn descriptor_bindings(&self) -> DescriptorBindingMap {
916 let mut res = DescriptorBindingMap::default();
917
918 for (name, descriptor, desc_ty, binding_count) in
919 self.entry_point.vars.iter().filter_map(|var| match var {
920 Variable::Descriptor {
921 name,
922 desc_bind,
923 desc_ty,
924 nbind,
925 ..
926 } => Some((
927 name,
928 Descriptor {
929 set: desc_bind.set(),
930 binding: desc_bind.bind(),
931 },
932 desc_ty,
933 *nbind,
934 )),
935 _ => None,
936 })
937 {
938 trace!(
939 "descriptor {}: {}.{} = {:?}[{}]",
940 name.as_deref().unwrap_or_default(),
941 descriptor.set,
942 descriptor.binding,
943 *desc_ty,
944 binding_count
945 );
946
947 let descriptor_info = match desc_ty {
948 DescriptorType::AccelStruct() => {
949 DescriptorInfo::AccelerationStructure(binding_count)
950 }
951 DescriptorType::CombinedImageSampler() => {
952 let (sampler_info, is_manually_defined) =
953 self.image_sampler(descriptor, name.as_deref().unwrap_or_default());
954
955 DescriptorInfo::CombinedImageSampler(
956 binding_count,
957 sampler_info,
958 is_manually_defined,
959 )
960 }
961 DescriptorType::InputAttachment(attachment) => {
962 DescriptorInfo::InputAttachment(binding_count, *attachment)
963 }
964 DescriptorType::SampledImage() => DescriptorInfo::SampledImage(binding_count),
965 DescriptorType::Sampler() => {
966 let (sampler_info, is_manually_defined) =
967 self.image_sampler(descriptor, name.as_deref().unwrap_or_default());
968
969 DescriptorInfo::Sampler(binding_count, sampler_info, is_manually_defined)
970 }
971 DescriptorType::StorageBuffer(_access_ty) => {
972 DescriptorInfo::StorageBuffer(binding_count)
973 }
974 DescriptorType::StorageImage(_access_ty) => {
975 DescriptorInfo::StorageImage(binding_count)
976 }
977 DescriptorType::StorageTexelBuffer(_access_ty) => {
978 DescriptorInfo::StorageTexelBuffer(binding_count)
979 }
980 DescriptorType::UniformBuffer() => DescriptorInfo::UniformBuffer(binding_count),
981 DescriptorType::UniformTexelBuffer() => {
982 DescriptorInfo::UniformTexelBuffer(binding_count)
983 }
984 };
985 res.insert(descriptor, (descriptor_info, self.stage));
986 }
987
988 res
989 }
990
991 fn image_sampler(&self, descriptor: Descriptor, name: &str) -> (SamplerInfo, bool) {
992 self.image_samplers
993 .get(&descriptor)
994 .copied()
995 .map(|sampler_info| (sampler_info, true))
996 .unwrap_or_else(|| (guess_immutable_sampler(name), false))
997 }
998
999 #[profiling::function]
1000 pub(super) fn merge_descriptor_bindings(
1001 descriptor_bindings: impl IntoIterator<Item = DescriptorBindingMap>,
1002 ) -> DescriptorBindingMap {
1003 fn merge_info(lhs: &mut DescriptorInfo, rhs: DescriptorInfo) -> bool {
1004 let (lhs_count, rhs_count) = match lhs {
1005 DescriptorInfo::AccelerationStructure(lhs) => {
1006 if let DescriptorInfo::AccelerationStructure(rhs) = rhs {
1007 (lhs, rhs)
1008 } else {
1009 return false;
1010 }
1011 }
1012 DescriptorInfo::CombinedImageSampler(lhs, lhs_sampler, lhs_is_manually_defined) => {
1013 if let DescriptorInfo::CombinedImageSampler(
1014 rhs,
1015 rhs_sampler,
1016 rhs_is_manually_defined,
1017 ) = rhs
1018 {
1019 if *lhs_is_manually_defined && rhs_is_manually_defined {
1021 return false;
1022 } else if rhs_is_manually_defined {
1023 *lhs_sampler = rhs_sampler;
1024 }
1025
1026 (lhs, rhs)
1027 } else {
1028 return false;
1029 }
1030 }
1031 DescriptorInfo::InputAttachment(lhs, lhs_idx) => {
1032 if let DescriptorInfo::InputAttachment(rhs, rhs_idx) = rhs {
1033 if *lhs_idx != rhs_idx {
1034 return false;
1035 }
1036
1037 (lhs, rhs)
1038 } else {
1039 return false;
1040 }
1041 }
1042 DescriptorInfo::SampledImage(lhs) => {
1043 if let DescriptorInfo::SampledImage(rhs) = rhs {
1044 (lhs, rhs)
1045 } else {
1046 return false;
1047 }
1048 }
1049 DescriptorInfo::Sampler(lhs, lhs_sampler, lhs_is_manually_defined) => {
1050 if let DescriptorInfo::Sampler(rhs, rhs_sampler, rhs_is_manually_defined) = rhs
1051 {
1052 if *lhs_is_manually_defined && rhs_is_manually_defined {
1054 return false;
1055 } else if rhs_is_manually_defined {
1056 *lhs_sampler = rhs_sampler;
1057 }
1058
1059 (lhs, rhs)
1060 } else {
1061 return false;
1062 }
1063 }
1064 DescriptorInfo::StorageBuffer(lhs) => {
1065 if let DescriptorInfo::StorageBuffer(rhs) = rhs {
1066 (lhs, rhs)
1067 } else {
1068 return false;
1069 }
1070 }
1071 DescriptorInfo::StorageImage(lhs) => {
1072 if let DescriptorInfo::StorageImage(rhs) = rhs {
1073 (lhs, rhs)
1074 } else {
1075 return false;
1076 }
1077 }
1078 DescriptorInfo::StorageTexelBuffer(lhs) => {
1079 if let DescriptorInfo::StorageTexelBuffer(rhs) = rhs {
1080 (lhs, rhs)
1081 } else {
1082 return false;
1083 }
1084 }
1085 DescriptorInfo::UniformBuffer(lhs) => {
1086 if let DescriptorInfo::UniformBuffer(rhs) = rhs {
1087 (lhs, rhs)
1088 } else {
1089 return false;
1090 }
1091 }
1092 DescriptorInfo::UniformTexelBuffer(lhs) => {
1093 if let DescriptorInfo::UniformTexelBuffer(rhs) = rhs {
1094 (lhs, rhs)
1095 } else {
1096 return false;
1097 }
1098 }
1099 };
1100
1101 *lhs_count = rhs_count.max(*lhs_count);
1102
1103 true
1104 }
1105
1106 #[profiling::function]
1107 fn merge_pair(src: DescriptorBindingMap, dst: &mut DescriptorBindingMap) {
1108 for (descriptor_binding, (descriptor_info, descriptor_flags)) in src.into_iter() {
1109 if let Some((existing_info, existing_flags)) = dst.get_mut(&descriptor_binding) {
1110 if !merge_info(existing_info, descriptor_info) {
1111 panic!("Inconsistent shader descriptors ({descriptor_binding:?})");
1112 }
1113
1114 *existing_flags |= descriptor_flags;
1115 } else {
1116 dst.insert(descriptor_binding, (descriptor_info, descriptor_flags));
1117 }
1118 }
1119 }
1120
1121 let mut descriptor_bindings = descriptor_bindings.into_iter();
1122 let mut res = descriptor_bindings.next().unwrap_or_default();
1123 for descriptor_binding in descriptor_bindings {
1124 merge_pair(descriptor_binding, &mut res);
1125 }
1126
1127 res
1128 }
1129
1130 #[profiling::function]
1131 pub(super) fn push_constant_range(&self) -> Option<vk::PushConstantRange> {
1132 self.entry_point
1133 .vars
1134 .iter()
1135 .filter_map(|var| match var {
1136 Variable::PushConstant {
1137 ty: Type::Struct(ty),
1138 ..
1139 } => Some(ty.members.clone()),
1140 _ => None,
1141 })
1142 .flatten()
1143 .map(|push_const| {
1144 let offset = push_const.offset.unwrap_or_default();
1145 let size = push_const
1146 .ty
1147 .nbyte()
1148 .unwrap_or_default()
1149 .next_multiple_of(4);
1150 offset..offset + size
1151 })
1152 .reduce(|a, b| a.start.min(b.start)..a.end.max(b.end))
1153 .map(|push_const| vk::PushConstantRange {
1154 stage_flags: self.stage,
1155 size: (push_const.end - push_const.start) as _,
1156 offset: push_const.start as _,
1157 })
1158 }
1159
1160 #[profiling::function]
1161 fn reflect_entry_point(
1162 entry_name: &str,
1163 spirv: &[u8],
1164 specialization_info: Option<&SpecializationInfo>,
1165 ) -> Result<EntryPoint, DriverError> {
1166 let mut config = ReflectConfig::new();
1167 config.ref_all_rscs(true).spv(spirv);
1168
1169 if let Some(spec_info) = specialization_info {
1170 for spec in &spec_info.map_entries {
1171 config.specialize(
1172 spec.constant_id,
1173 spec_info.data[spec.offset as usize..spec.offset as usize + spec.size].into(),
1174 );
1175 }
1176 }
1177
1178 let entry_points = config.reflect().map_err(|err| {
1179 error!("Unable to reflect spirv: {err}");
1180
1181 DriverError::InvalidData
1182 })?;
1183 let entry_point = entry_points
1184 .into_iter()
1185 .find(|entry_point| entry_point.name == entry_name)
1186 .ok_or_else(|| {
1187 error!("Entry point not found");
1188
1189 DriverError::InvalidData
1190 })?;
1191
1192 Ok(entry_point)
1193 }
1194
1195 #[profiling::function]
1196 pub(super) fn vertex_input(&self) -> VertexInputState {
1197 if let Some(vertex_input) = &self.vertex_input_state {
1199 return vertex_input.clone();
1200 }
1201
1202 fn scalar_format(ty: &ScalarType, byte_len: u32) -> vk::Format {
1203 match ty {
1204 ScalarType::Float { .. } => match byte_len {
1205 4 => vk::Format::R32_SFLOAT,
1206 8 => vk::Format::R32G32_SFLOAT,
1207 12 => vk::Format::R32G32B32_SFLOAT,
1208 16 => vk::Format::R32G32B32A32_SFLOAT,
1209 _ => unimplemented!("byte_len {byte_len}"),
1210 },
1211 ScalarType::Integer {
1212 is_signed: true, ..
1213 } => match byte_len {
1214 4 => vk::Format::R32_SINT,
1215 8 => vk::Format::R32G32_SINT,
1216 12 => vk::Format::R32G32B32_SINT,
1217 16 => vk::Format::R32G32B32A32_SINT,
1218 _ => unimplemented!("byte_len {byte_len}"),
1219 },
1220 ScalarType::Integer {
1221 is_signed: false, ..
1222 } => match byte_len {
1223 4 => vk::Format::R32_UINT,
1224 8 => vk::Format::R32G32_UINT,
1225 12 => vk::Format::R32G32B32_UINT,
1226 16 => vk::Format::R32G32B32A32_UINT,
1227 _ => unimplemented!("byte_len {byte_len}"),
1228 },
1229 _ => unimplemented!("{:?}", ty),
1230 }
1231 }
1232
1233 let mut input_rates_strides = HashMap::new();
1234 let mut vertex_attribute_descriptions = vec![];
1235
1236 for (name, location, ty) in self.entry_point.vars.iter().filter_map(|var| match var {
1237 Variable::Input { name, location, ty } => Some((name, location, ty)),
1238 _ => None,
1239 }) {
1240 let (binding, guessed_rate) = name
1241 .as_ref()
1242 .filter(|name| name.contains("_ibind") || name.contains("_vbind"))
1243 .map(|name| {
1244 let binding = name[name.rfind("bind").unwrap()..]
1245 .parse()
1246 .unwrap_or_default();
1247 let rate = if name.contains("_ibind") {
1248 vk::VertexInputRate::INSTANCE
1249 } else {
1250 vk::VertexInputRate::VERTEX
1251 };
1252
1253 (binding, rate)
1254 })
1255 .unwrap_or_default();
1256 let (location, _) = location.into_inner();
1257 if let Some((input_rate, _)) = input_rates_strides.get(&binding) {
1258 assert_eq!(*input_rate, guessed_rate);
1259 }
1260
1261 let byte_stride = ty.nbyte().unwrap_or_default() as u32;
1262 let (input_rate, stride) = input_rates_strides.entry(binding).or_default();
1263 *input_rate = guessed_rate;
1264 *stride += byte_stride;
1265
1266 vertex_attribute_descriptions.push(vk::VertexInputAttributeDescription {
1269 location,
1270 binding,
1271 format: match ty {
1272 Type::Scalar(ty) => scalar_format(ty, ty.nbyte().unwrap_or_default() as _),
1273 Type::Vector(ty) => scalar_format(&ty.scalar_ty, byte_stride),
1274 _ => unimplemented!("{:?}", ty),
1275 },
1276 offset: byte_stride, });
1278 }
1279
1280 vertex_attribute_descriptions.sort_unstable_by(|lhs, rhs| {
1281 let binding = lhs.binding.cmp(&rhs.binding);
1282 if binding.is_lt() {
1283 return binding;
1284 }
1285
1286 lhs.location.cmp(&rhs.location)
1287 });
1288
1289 let mut offset = 0;
1290 let mut offset_binding = 0;
1291
1292 for vertex_attribute_description in &mut vertex_attribute_descriptions {
1293 if vertex_attribute_description.binding != offset_binding {
1294 offset_binding = vertex_attribute_description.binding;
1295 offset = 0;
1296 }
1297
1298 let stride = vertex_attribute_description.offset;
1299 vertex_attribute_description.offset = offset;
1300 offset += stride;
1301
1302 debug!(
1303 "vertex attribute {}.{}: {:?} (offset={})",
1304 vertex_attribute_description.binding,
1305 vertex_attribute_description.location,
1306 vertex_attribute_description.format,
1307 vertex_attribute_description.offset,
1308 );
1309 }
1310
1311 let mut vertex_binding_descriptions = vec![];
1312 for (binding, (input_rate, stride)) in input_rates_strides.into_iter() {
1313 vertex_binding_descriptions.push(vk::VertexInputBindingDescription {
1314 binding,
1315 input_rate,
1316 stride,
1317 });
1318 }
1319
1320 VertexInputState {
1321 vertex_attribute_descriptions,
1322 vertex_binding_descriptions,
1323 }
1324 }
1325}
1326
1327impl Debug for Shader {
1328 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
1329 f.write_str("Shader")
1332 }
1333}
1334
1335impl From<ShaderBuilder> for Shader {
1336 fn from(shader: ShaderBuilder) -> Self {
1337 shader.build()
1338 }
1339}
1340
1341impl ShaderBuilder {
1343 pub fn new(stage: vk::ShaderStageFlags, spirv: Vec<u8>) -> Self {
1345 Self::default().stage(stage).spirv(spirv)
1346 }
1347
1348 pub fn build(mut self) -> Shader {
1350 let entry_name = self.entry_name.as_deref().unwrap_or("main");
1351 self.entry_point = Some(
1352 Shader::reflect_entry_point(
1353 entry_name,
1354 self.spirv.as_deref().unwrap(),
1355 self.specialization_info
1356 .as_ref()
1357 .map(|opt| opt.as_ref())
1358 .unwrap_or_default(),
1359 )
1360 .unwrap_or_else(|_| panic!("invalid shader code for entry name \'{entry_name}\'")),
1361 );
1362
1363 self.fallible_build()
1364 .expect("All required fields set at initialization")
1365 }
1366
1367 #[profiling::function]
1387 pub fn image_sampler(
1388 mut self,
1389 descriptor: impl Into<Descriptor>,
1390 info: impl Into<SamplerInfo>,
1391 ) -> Self {
1392 let descriptor = descriptor.into();
1393 let info = info.into();
1394
1395 if self.image_samplers.is_none() {
1396 self.image_samplers = Some(Default::default());
1397 }
1398
1399 self.image_samplers
1400 .as_mut()
1401 .unwrap()
1402 .insert(descriptor, info);
1403
1404 self
1405 }
1406
1407 #[profiling::function]
1419 pub fn vertex_input(
1420 mut self,
1421 bindings: &[vk::VertexInputBindingDescription],
1422 attributes: &[vk::VertexInputAttributeDescription],
1423 ) -> Self {
1424 self.vertex_input_state = Some(Some(VertexInputState {
1425 vertex_binding_descriptions: bindings.to_vec(),
1426 vertex_attribute_descriptions: attributes.to_vec(),
1427 }));
1428 self
1429 }
1430}
1431
1432#[derive(Debug)]
1433struct ShaderBuilderError;
1434
1435impl From<UninitializedFieldError> for ShaderBuilderError {
1436 fn from(_: UninitializedFieldError) -> Self {
1437 Self
1438 }
1439}
1440
/// Shader code that can be converted into the raw SPIR-V bytes stored by a shader.
pub trait ShaderCode {
    /// Consumes `self` and returns the code as bytes.
    fn into_vec(self) -> Vec<u8>;
}

impl ShaderCode for &[u8] {
    fn into_vec(self) -> Vec<u8> {
        // SPIR-V is a sequence of 32-bit words, so valid code is a multiple of 4 bytes long.
        debug_assert_eq!(self.len() % 4, 0, "invalid spir-v code");

        self.to_vec()
    }
}

impl ShaderCode for &[u32] {
    fn into_vec(self) -> Vec<u8> {
        // Copy each word's native-endian bytes. This yields exactly the bytes of the slice's
        // memory, without needing an `unsafe` reinterpretation.
        let mut bytes = Vec::with_capacity(size_of_val(self));
        for word in self {
            bytes.extend_from_slice(&word.to_ne_bytes());
        }

        bytes
    }
}

impl ShaderCode for Vec<u8> {
    fn into_vec(self) -> Vec<u8> {
        // SPIR-V is a sequence of 32-bit words, so valid code is a multiple of 4 bytes long.
        debug_assert_eq!(self.len() % 4, 0, "invalid spir-v code");

        self
    }
}

impl ShaderCode for Vec<u32> {
    fn into_vec(self) -> Vec<u8> {
        self.as_slice().into_vec()
    }
}
1483
1484#[derive(Clone, Debug)]
1486pub struct SpecializationInfo {
1487 pub data: Vec<u8>,
1489
1490 pub map_entries: Vec<vk::SpecializationMapEntry>,
1492}
1493
1494impl SpecializationInfo {
1495 pub fn new(
1497 map_entries: impl Into<Vec<vk::SpecializationMapEntry>>,
1498 data: impl Into<Vec<u8>>,
1499 ) -> Self {
1500 Self {
1501 data: data.into(),
1502 map_entries: map_entries.into(),
1503 }
1504 }
1505}
1506
#[cfg(test)]
mod tests {
    use super::*;

    type Info = SamplerInfo;
    type Builder = SamplerInfoBuilder;

    /// Round-tripping a default `SamplerInfo` through its builder must not change it.
    #[test]
    pub fn sampler_info() {
        let expected = Info::default();
        let round_tripped = expected.to_builder().build();

        assert_eq!(expected, round_tripped);
    }

    /// An untouched builder must produce the same value as `SamplerInfo::default()`.
    #[test]
    pub fn sampler_info_builder() {
        let from_builder = Builder::default().build();

        assert_eq!(Info::default(), from_builder);
    }
}