spirv_layout/lib.rs
#![allow(unknown_lints)]
#![warn(clippy::all, clippy::pedantic)]
// Pedantic lints this crate opts out of:
// - unreadable_literal: SPIR-V constants such as the magic number 0x07230203 are kept as
//   written in the specification.
// - too_many_lines: the instruction-parsing functions are long `match` statements.
// - must_use_candidate: presumably to avoid annotating every getter — confirm.
#![allow(
    clippy::unreadable_literal,
    clippy::too_many_lines,
    clippy::must_use_candidate
)]
8
9use std::{collections::HashMap, str::Utf8Error};
10
11use ops::{Dim, Id, Op};
12use thiserror::Error;
13
14mod ops;
15
16#[derive(Debug, Clone, Error)]
17pub enum Error {
18 #[error("{0}")]
19 Other(String),
20 #[error("invalid header")]
21 InvalidHeader,
22 #[error("invalid bytes in commmand")]
23 InvalidOp,
24 #[error("invalid id")]
25 InvalidId,
26 #[error("invalid utf-8 in string")]
27 StringFormat(#[from] Utf8Error),
28}
29
30pub type SpirvResult<T> = ::std::result::Result<T, Error>;
31
/// Stores the reflection info of a single SPIRV module.
///
/// Built from a SPIR-V word stream with [`Module::from_words`].
#[derive(Debug)]
pub struct Module {
    // Type declarations collected from the module, keyed by their SPIR-V result id.
    types: HashMap<u32, Type>,
    // The module's entry points, in the order their `OpEntryPoint` instructions appear.
    entry_points: Vec<EntryPoint>,
}
38
/// Describes a single `EntryPoint` in a SPIR-V module.
///
/// A SPIR-V module can have multiple entry points with different names, each defining a single shader.
#[derive(Debug)]
pub struct EntryPoint {
    /// The name of the entry point, used for identification
    pub name: String,
    /// The [`ExecutionModel`] of the entry point, selects which type of shader this entry point defines
    pub execution_model: ExecutionModel,
    /// All uniform variables used in this shader.
    ///
    /// Only variables that carry both a `DescriptorSet` and a `Binding` decoration are listed.
    pub uniforms: Vec<UniformVariable>,
    /// All push constant variables used in this shader
    pub push_constants: Vec<PushConstantVariable>,
    /// All inputs used in this shader.
    ///
    /// Only variables with a `Location` decoration are listed; variables without one
    /// (e.g. built-ins) are omitted.
    pub inputs: Vec<LocationVariable>,
    /// All outputs used in this shader.
    ///
    /// Only variables with a `Location` decoration are listed; variables without one
    /// (e.g. built-ins) are omitted.
    pub outputs: Vec<LocationVariable>,
}
57
58impl Module {
59 /// Generates reflection info from a given stream of `words`.
60 ///
61 /// # Errors
62 /// - [`Error::InvalidHeader`] if the SPIRV header is not valid
63 /// - [`Error::InvalidOp`] if the binary representation of any instruction in `words` is not valid
64 /// - [`Error::InvalidId`] if any type declaration in the SPIRV module reference non-existent IDs
65 /// - [`Error::StringFormat`] if any `OpCode` contains a String with invalid UTF-8 characters
66 /// - [`Error::Other`] if any other errors occur
67 pub fn from_words(mut words: &[u32]) -> SpirvResult<Self> {
68 // Check the SPIRV header magic number
69 if words.len() < 6 || words[0] != 0x07230203 {
70 return Err(Error::InvalidHeader);
71 }
72
73 // Skip the rest of the header (Should be parsed in the future)
74 words = &words[5..];
75
76 // decode all opcodes
77 let mut ops = Vec::new();
78 while !words.is_empty() {
79 let op = Op::decode(&mut words)?;
80 ops.push(op);
81 }
82
83 // All OpConstant values are stored in this Map
84 let mut constants = HashMap::new();
85 // All type declarations are stored in this Map
86 let mut types = HashMap::new();
87 // All variable declarations are stored in this Map
88 let mut vars = HashMap::new();
89 // All entry points declarations are stored in this Vec
90 let mut entries = Vec::new();
91
92 Self::collect_types_and_vars(&ops, &mut types, &mut constants, &mut vars, &mut entries)?;
93 Self::collect_decorations_and_names(&ops, &mut types, &mut vars);
94
95 // uniforms are all variables that are a pointer with a storage class of Uniform or UniformConstant
96 let uniforms: HashMap<_, _> = vars
97 .iter()
98 .filter_map(|(id, var)| {
99 if let Some(Type::Pointer {
100 storage_class: StorageClass::Uniform | StorageClass::UniformConstant,
101 pointed_type_id,
102 }) = types.get(&var.type_id)
103 {
104 Some((
105 *id,
106 UniformVariable {
107 set: var.set?,
108 binding: var.binding?,
109 type_id: *pointed_type_id, // for convenience, we store the pointed-to type instead of the pointer, since every uniform is a pointer
110 name: var.name.clone(),
111 },
112 ))
113 } else {
114 None
115 }
116 })
117 .collect();
118
119 let push_constants: HashMap<_, _> = vars
120 .iter()
121 .filter_map(|(id, var)| {
122 if let Some(Type::Pointer {
123 storage_class: StorageClass::PushConstant,
124 pointed_type_id,
125 }) = types.get(&var.type_id)
126 {
127 Some((
128 *id,
129 PushConstantVariable {
130 type_id: *pointed_type_id,
131 name: var.name.clone(),
132 },
133 ))
134 } else {
135 None
136 }
137 })
138 .collect();
139
140 let inputs: HashMap<_, _> = vars
141 .iter()
142 .filter_map(|(id, var)| {
143 if let Some(Type::Pointer {
144 storage_class: StorageClass::Input,
145 pointed_type_id,
146 }) = types.get(&var.type_id)
147 {
148 Some((
149 *id,
150 LocationVariable {
151 location: var.location?,
152 type_id: *pointed_type_id,
153 name: var.name.clone(),
154 },
155 ))
156 } else {
157 None
158 }
159 })
160 .collect();
161
162 let outputs: HashMap<_, _> = vars
163 .iter()
164 .filter_map(|(id, var)| {
165 if let Some(Type::Pointer {
166 storage_class: StorageClass::Output,
167 pointed_type_id,
168 }) = types.get(&var.type_id)
169 {
170 Some((
171 *id,
172 LocationVariable {
173 location: var.location?,
174 type_id: *pointed_type_id,
175 name: var.name.clone(),
176 },
177 ))
178 } else {
179 None
180 }
181 })
182 .collect();
183
184 let entry_points = entries
185 .iter()
186 .map(|e| {
187 let uniforms = e
188 .interface
189 .iter()
190 .filter_map(|id| uniforms.get(&id.0).cloned())
191 .collect();
192 let push_constants = e
193 .interface
194 .iter()
195 .filter_map(|id| push_constants.get(&id.0).cloned())
196 .collect();
197 let inputs = e
198 .interface
199 .iter()
200 .filter_map(|id| inputs.get(&id.0).cloned())
201 .collect();
202 let outputs = e
203 .interface
204 .iter()
205 .filter_map(|id| outputs.get(&id.0).cloned())
206 .collect();
207
208 EntryPoint {
209 name: e.name.clone(),
210 execution_model: e.execution_model,
211 uniforms,
212 push_constants,
213 inputs,
214 outputs,
215 }
216 })
217 .collect();
218
219 Ok(Self {
220 types,
221 entry_points,
222 })
223 }
224
225 /// Returns the [`Type`] definition indicated by `type_id`, or `None` if `type_id` is not a type.
226 pub fn get_type(&self, type_id: u32) -> Option<&Type> {
227 self.types.get(&type_id)
228 }
229
230 /// Returns the [`EntryPoint`] definitions contained in the given SPIR-V module
231 pub fn get_entry_points(&self) -> &[EntryPoint] {
232 &self.entry_points
233 }
234
235 fn get_type_size(&self, type_id: u32, stride: Option<u32>) -> Option<u32> {
236 if let Some(ty) = self.types.get(&type_id) {
237 match ty {
238 Type::Int32 | Type::UInt32 | Type::Float32 => Some(4),
239 Type::Vec2 => Some(8),
240 Type::Vec3 => Some(12),
241 Type::Vec4 => Some(16),
242 Type::Mat3 => stride.map(|stride| stride * 2 + 12), // two rows/columns + sizeof(Vec3)
243 Type::Mat4 => stride.map(|stride| stride * 3 + 16), // three rows/columns + sizeof(Vec4)
244 Type::Struct { elements, .. } => {
245 // Since there is no Size Decoration in SPIRV that tells us the size,
246 // we calculate it from the offset of the last member and its size.
247 let last_element = elements.iter().max_by_key(|e| e.offset.unwrap_or(0))?;
248 let offset = last_element.offset?;
249 let size = self.get_member_size(last_element)?;
250
251 Some(offset + size)
252 }
253 _ => None,
254 }
255 } else {
256 None
257 }
258 }
259
260 /// Returns the size of a given [`StructMember`], if known.
261 pub fn get_member_size(&self, member: &StructMember) -> Option<u32> {
262 self.get_type_size(member.type_id, Some(member.stride))
263 }
264
265 /// Returns the size of a given [`UniformVariable`], [`PushConstantVariable`] or [`LocationVariable`], if known.
266 pub fn get_var_size<T: Variable>(&self, var: &T) -> Option<u32> {
267 self.get_type_size(var.get_type_id(), None)
268 }
269
270 /// Parses all the Op*Decoration and Op*Name instructions
271 fn collect_decorations_and_names(
272 ops: &[Op],
273 types: &mut HashMap<u32, Type>,
274 vars: &mut HashMap<u32, RawVariable>,
275 ) {
276 for op in ops {
277 match op {
278 Op::OpName { target, name } => {
279 if let Some(target) = vars.get_mut(&target.0) {
280 target.name = Some(name.clone());
281 } else if let Some(Type::Struct { name: n, .. }) = types.get_mut(&target.0) {
282 *n = Some(name.clone());
283 }
284 }
285 Op::OpMemberName {
286 target,
287 member_index,
288 name,
289 } => {
290 if let Some(Type::Struct { elements, .. }) = types.get_mut(&target.0) {
291 if elements.len() > *member_index as usize {
292 elements[*member_index as usize].name = Some(name.clone());
293 }
294 }
295 }
296 Op::OpDecorate { target, decoration } => match decoration {
297 ops::Decoration::Binding { binding } => {
298 if let Some(target) = vars.get_mut(&target.0) {
299 target.binding = Some(*binding);
300 }
301 }
302 ops::Decoration::DescriptorSet { set } => {
303 if let Some(target) = vars.get_mut(&target.0) {
304 target.set = Some(*set);
305 }
306 }
307 ops::Decoration::Location { loc } => {
308 if let Some(target) = vars.get_mut(&target.0) {
309 target.location = Some(*loc);
310 }
311 }
312 _ => {}
313 },
314 Op::OpMemberDecorate {
315 target,
316 member_index,
317 decoration,
318 } => {
319 if let Some(Type::Struct { elements, .. }) = types.get_mut(&target.0) {
320 if elements.len() > *member_index as usize {
321 match decoration {
322 ops::Decoration::RowMajor {} => {
323 elements[*member_index as usize].row_major = true;
324 }
325 ops::Decoration::ColMajor {} => {
326 elements[*member_index as usize].row_major = false;
327 }
328 ops::Decoration::MatrixStride { stride } => {
329 elements[*member_index as usize].stride = *stride;
330 }
331 ops::Decoration::Offset { offset } => {
332 elements[*member_index as usize].offset = Some(*offset);
333 }
334 _ => {}
335 }
336 }
337 }
338 }
339 _ => {}
340 }
341 }
342 }
343
344 // Parses all the OpType* and OpVariable instructions
345 fn collect_types_and_vars(
346 ops: &[Op],
347 types: &mut HashMap<u32, Type>,
348 constants: &mut HashMap<u32, u32>,
349 vars: &mut HashMap<u32, RawVariable>,
350 entries: &mut Vec<RawEntryPoint>,
351 ) -> SpirvResult<()> {
352 for op in ops {
353 match op {
354 Op::OpTypeVoid { result } => {
355 types.insert(result.0, Type::Void);
356 }
357 Op::OpTypeBool { result } => {
358 types.insert(result.0, Type::Bool);
359 }
360 Op::OpTypeInt {
361 result,
362 width,
363 signed,
364 } => {
365 if *width != 32 {
366 types.insert(result.0, Type::Unknown);
367 } else if *signed == 0 {
368 types.insert(result.0, Type::UInt32);
369 } else {
370 types.insert(result.0, Type::Int32);
371 }
372 }
373 Op::OpTypeFloat { result, width } => {
374 if *width == 32 {
375 types.insert(result.0, Type::Float32);
376 } else {
377 types.insert(result.0, Type::Unknown);
378 }
379 }
380 Op::OpTypeVector {
381 result,
382 component_type,
383 component_count,
384 } => {
385 if let Some(t) = types.get(&component_type.0) {
386 if let Type::Float32 = t {
387 match component_count {
388 2 => {
389 types.insert(result.0, Type::Vec2);
390 }
391 3 => {
392 types.insert(result.0, Type::Vec3);
393 }
394 4 => {
395 types.insert(result.0, Type::Vec4);
396 }
397 _ => {
398 types.insert(result.0, Type::Unknown);
399 }
400 }
401 } else {
402 types.insert(result.0, Type::Unknown);
403 }
404 } else {
405 return Err(Error::InvalidId);
406 }
407 }
408 Op::OpTypeMatrix {
409 result,
410 column_type,
411 column_count,
412 } => {
413 let t = types
414 .get(&column_type.0)
415 .map(|column_type| match column_type {
416 Type::Vec3 if *column_count == 3 => Type::Mat3,
417 Type::Vec4 if *column_count == 4 => Type::Mat4,
418 _ => Type::Unknown,
419 })
420 .unwrap_or(Type::Unknown);
421 types.insert(result.0, t);
422 }
423 Op::OpTypeImage {
424 result,
425 sampled_type,
426 dim,
427 depth,
428 arrayed: _,
429 ms: _,
430 sampled,
431 format,
432 access: _,
433 } => {
434 let t = if let Some(Type::Float32) = types.get(&sampled_type.0) {
435 if let Dim::D2 {} = dim {
436 Type::Image2D {
437 depth: *depth != 0,
438 sampled: *sampled != 0,
439 format: *format,
440 }
441 } else {
442 Type::Unknown
443 }
444 } else {
445 Type::Unknown
446 };
447 types.insert(result.0, t);
448 }
449 Op::OpTypeSampler { result } => {
450 types.insert(result.0, Type::Sampler);
451 }
452 Op::OpTypeSampledImage { result, image_type } => {
453 let t = if let Some(Type::Image2D { .. }) = types.get(&image_type.0) {
454 Type::SampledImage {
455 image_type_id: image_type.0,
456 }
457 } else {
458 Type::Unknown
459 };
460 types.insert(result.0, t);
461 }
462 Op::OpTypeArray {
463 result,
464 element_type,
465 length,
466 } => {
467 if let Some(length) = constants.get(&length.0) {
468 types.insert(
469 result.0,
470 Type::Array {
471 element_type_id: element_type.0,
472 length: Some(*length),
473 },
474 );
475 } else {
476 return Err(Error::InvalidId);
477 }
478 }
479 Op::OpTypeRuntimeArray {
480 result,
481 element_type,
482 } => {
483 types.insert(
484 result.0,
485 Type::Array {
486 element_type_id: element_type.0,
487 length: None,
488 },
489 );
490 }
491 Op::OpTypeStruct {
492 result,
493 element_types,
494 } => {
495 types.insert(
496 result.0,
497 Type::Struct {
498 name: None,
499 elements: element_types
500 .iter()
501 .map(|e| StructMember {
502 name: None,
503 type_id: e.0,
504 offset: None,
505 row_major: true,
506 stride: 16,
507 })
508 .collect(),
509 },
510 );
511 }
512 Op::OpTypePointer {
513 result,
514 storage_class,
515 pointed_type,
516 } => {
517 types.insert(
518 result.0,
519 Type::Pointer {
520 storage_class: match storage_class {
521 ops::StorageClass::Unknown => StorageClass::Unknown,
522 ops::StorageClass::UniformConstant {}
523 | ops::StorageClass::Uniform {} => StorageClass::Uniform,
524 ops::StorageClass::PushConstant {} => StorageClass::PushConstant,
525 ops::StorageClass::Input {} => StorageClass::Input,
526 ops::StorageClass::Output {} => StorageClass::Output,
527 },
528 pointed_type_id: pointed_type.0,
529 },
530 );
531 }
532 Op::OpConstant {
533 result_type,
534 result,
535 value,
536 } => {
537 if let Some(Type::UInt32) = types.get(&result_type.0) {
538 if value.len() == 1 {
539 constants.insert(result.0, value[0]);
540 }
541 }
542 }
543 Op::OpVariable {
544 result_type,
545 result,
546 storage_class: _,
547 initializer: _,
548 } => {
549 vars.insert(
550 result.0,
551 RawVariable {
552 set: None,
553 binding: None,
554 location: None,
555 type_id: result_type.0,
556 name: None,
557 },
558 );
559 }
560 Op::OpEntryPoint {
561 execution_model,
562 func: _,
563 name,
564 interface,
565 } => {
566 entries.push(RawEntryPoint {
567 name: name.clone(),
568 execution_model: match execution_model {
569 ops::ExecutionModel::Unknown => {
570 return Err(Error::Other(
571 "Unknown execution model in entry point".to_string(),
572 ))
573 }
574 ops::ExecutionModel::Vertex {} => ExecutionModel::Vertex,
575 ops::ExecutionModel::Fragment {} => ExecutionModel::Fragment,
576 },
577 interface: interface.clone(),
578 });
579 }
580 _ => {}
581 }
582 }
583
584 Ok(())
585 }
586}
587
/// Represents a type declared in a SPIRV module.
///
/// Types are declared in a hierarchy, with e.g. pointers relying on previously declared types as pointed-to types.
#[derive(Debug)]
#[non_exhaustive]
pub enum Type {
    /// An unsupported type
    Unknown,
    /// The Void type
    Void,
    /// A boolean
    Bool,
    /// A signed 32-Bit integer
    Int32,
    /// An unsigned 32-Bit integer
    UInt32,
    /// A 32-Bit float
    Float32,
    /// A 2 component, 32-Bit vector (GLSL: vec2)
    Vec2,
    /// A 3 component, 32-Bit vector (GLSL: vec3)
    Vec3,
    /// A 4 component, 32-Bit vector (GLSL: vec4)
    Vec4,
    /// A 3x3, 32-Bit Matrix (GLSL: mat3)
    Mat3,
    /// A 4x4, 32-Bit Matrix (GLSL: mat4)
    Mat4,
    /// A 2D image whose sampled type is a 32-bit float
    Image2D {
        /// true if this image is a depth image
        depth: bool,
        /// true if this image can be sampled from
        sampled: bool,
        /// The SPIR-V `Image Format` operand of the image (0 means `Unknown`)
        format: u32,
    },
    /// An opaque sampler object
    Sampler,
    /// A combined image and sampler (Vulkan: CombinedImageSampler descriptor)
    SampledImage {
        /// type id of the image contained in the SampledImage
        image_type_id: u32,
    },
    /// Either a static array with known length (`length` is [`Some`]) or dynamic array with unknown length (`length` is [`None`])
    Array {
        /// type id of the contained type
        element_type_id: u32,
        /// length of the array (if known)
        length: Option<u32>,
    },
    /// A struct containing other types
    Struct {
        /// The struct's name as given by an `OpName` instruction (if any)
        name: Option<String>,
        /// members of the struct, in the order they appear in the SPIRV module (not necessarily ascending offsets)
        elements: Vec<StructMember>,
    },
    /// A pointer pointing to another type
    Pointer {
        /// The type of storage this pointer points to
        storage_class: StorageClass,
        /// The type id of the pointed-to type
        pointed_type_id: u32,
    },
}
653
/// Describes a single member of a [`Type::Struct`] type
#[derive(Debug)]
pub struct StructMember {
    /// The name of the member variable (if known)
    pub name: Option<String>,
    /// The type id of the member's [`Type`]
    pub type_id: u32,
    /// The offset within the struct of this member (if known, from an `Offset` decoration)
    pub offset: Option<u32>,
    /// For matrix members: whether this matrix is stored in row major order
    /// (`true` unless a `ColMajor` decoration is present)
    pub row_major: bool,
    /// For matrix members: The stride between rows/columns of the matrix
    /// (16 unless a `MatrixStride` decoration is present)
    pub stride: u32,
}
668
/// Describes what type of storage a pointer points to
#[derive(Debug)]
#[non_exhaustive]
pub enum StorageClass {
    /// A storage class this crate does not support
    Unknown,
    /// The pointer is a uniform variable (Uniform blocks)
    Uniform,
    /// The pointer is a uniform variable (Images, etc.)
    UniformConstant,
    /// The pointer is a push constant
    PushConstant,
    /// The pointer is an input variable
    Input,
    /// The pointer is an output variable
    Output,
}
685
/// The execution model of an [`EntryPoint`], i.e. the shader stage it implements.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum ExecutionModel {
    /// A Vertex Shader
    Vertex,
    /// A Fragment Shader
    Fragment,
}
695
// An `OpVariable` as collected during parsing, before it is classified into uniforms,
// push constants, inputs or outputs. Decoration fields start as `None` and are filled in
// from `OpDecorate` instructions.
#[derive(Debug, Clone)]
struct RawVariable {
    // `DescriptorSet` decoration, if present.
    set: Option<u32>,
    // `Binding` decoration, if present.
    binding: Option<u32>,
    // `Location` decoration, if present.
    location: Option<u32>,
    // The variable's result type id, which is the pointer type (not the pointee).
    type_id: u32,
    // Name from `OpName`, if present.
    name: Option<String>,
}
704
// An `OpEntryPoint` as collected during parsing, before its interface ids are resolved
// into variable descriptions.
#[derive(Debug)]
struct RawEntryPoint {
    name: String,
    execution_model: ExecutionModel,
    // The ids listed in the instruction's interface operands.
    interface: Vec<Id>,
}
711
/// Describes a uniform variable declared in a SPIRV module.
///
/// Only variables decorated with both a `DescriptorSet` and a `Binding` are reported as
/// uniforms, so `set` and `binding` are always known.
#[derive(Debug, Clone)]
pub struct UniformVariable {
    /// Which DescriptorSet the variable is contained in
    pub set: u32,
    /// Which binding within the DescriptorSet the variable occupies
    pub binding: u32,
    /// The type id of the variable's [`Type`] (the pointed-to type, not the pointer)
    pub type_id: u32,
    /// The variable's name (if known)
    pub name: Option<String>,
}
724
/// Describes a push constant variable declared in a SPIRV module
#[derive(Debug, Clone)]
pub struct PushConstantVariable {
    /// The type id of the variable's [`Type`] (the pointed-to type, not the pointer)
    pub type_id: u32,
    /// The variable's name (if known)
    pub name: Option<String>,
}
733
/// Describes an input or output variable declared in a SPIRV module
#[derive(Debug, Clone)]
pub struct LocationVariable {
    /// The location of the variable (e.g. GLSL `layout(location=XXX)`)
    pub location: u32,
    /// The type id of the variable's [`Type`] (the pointed-to type, not the pointer)
    pub type_id: u32,
    /// The variable's name (if known)
    pub name: Option<String>,
}
744
745mod private {
746 pub trait Variable {
747 fn get_type_id(&self) -> u32;
748 }
749}
750
751pub trait Variable: private::Variable {}
752impl<T: private::Variable> Variable for T {}
753
754impl private::Variable for UniformVariable {
755 fn get_type_id(&self) -> u32 {
756 self.type_id
757 }
758}
759impl private::Variable for PushConstantVariable {
760 fn get_type_id(&self) -> u32 {
761 self.type_id
762 }
763}
764impl private::Variable for LocationVariable {
765 fn get_type_id(&self) -> u32 {
766 self.type_id
767 }
768}