spirv_layout/
lib.rs

1#![allow(unknown_lints)]
2#![warn(clippy::all, clippy::pedantic)]
3#![allow(
4    clippy::unreadable_literal,
5    clippy::too_many_lines,
6    clippy::must_use_candidate
7)]
8
9use std::{collections::HashMap, str::Utf8Error};
10
11use ops::{Dim, Id, Op};
12use thiserror::Error;
13
14mod ops;
15
16#[derive(Debug, Clone, Error)]
17pub enum Error {
18    #[error("{0}")]
19    Other(String),
20    #[error("invalid header")]
21    InvalidHeader,
22    #[error("invalid bytes in commmand")]
23    InvalidOp,
24    #[error("invalid id")]
25    InvalidId,
26    #[error("invalid utf-8 in string")]
27    StringFormat(#[from] Utf8Error),
28}
29
30pub type SpirvResult<T> = ::std::result::Result<T, Error>;
31
32/// Stores the reflection info of a single SPIRV module.
33#[derive(Debug)]
34pub struct Module {
35    types: HashMap<u32, Type>,
36    entry_points: Vec<EntryPoint>,
37}
38
39/// Describes a single `EntryPoint` in a SPIR-V module.
40///
41/// A SPIR-V module can have multiple entry points with different names, each defining a single shader.
42#[derive(Debug)]
43pub struct EntryPoint {
44    /// The name of the entry point, used for identification
45    pub name: String,
46    /// The [`ExecutionModel`] of the entry point, selects which type of shader this entry point defines
47    pub execution_model: ExecutionModel,
48    /// All uniform variables used in this shader
49    pub uniforms: Vec<UniformVariable>,
50    /// All push constant variables used in this shader
51    pub push_constants: Vec<PushConstantVariable>,
52    /// All inputs used in this shader
53    pub inputs: Vec<LocationVariable>,
54    /// All outputs used in this shader
55    pub outputs: Vec<LocationVariable>,
56}
57
58impl Module {
59    /// Generates reflection info from a given stream of `words`.
60    ///
61    /// # Errors
62    /// - [`Error::InvalidHeader`] if the SPIRV header is not valid
63    /// - [`Error::InvalidOp`] if the binary representation of any instruction in `words` is not valid
64    /// - [`Error::InvalidId`] if any type declaration in the SPIRV module reference non-existent IDs
65    /// - [`Error::StringFormat`] if any `OpCode` contains a String with invalid UTF-8 characters
66    /// - [`Error::Other`] if any other errors occur
67    pub fn from_words(mut words: &[u32]) -> SpirvResult<Self> {
68        // Check the SPIRV header magic number
69        if words.len() < 6 || words[0] != 0x07230203 {
70            return Err(Error::InvalidHeader);
71        }
72
73        // Skip the rest of the header (Should be parsed in the future)
74        words = &words[5..];
75
76        // decode all opcodes
77        let mut ops = Vec::new();
78        while !words.is_empty() {
79            let op = Op::decode(&mut words)?;
80            ops.push(op);
81        }
82
83        // All OpConstant values are stored in this Map
84        let mut constants = HashMap::new();
85        // All type declarations are stored in this Map
86        let mut types = HashMap::new();
87        // All variable declarations are stored in this Map
88        let mut vars = HashMap::new();
89        // All entry points declarations are stored in this Vec
90        let mut entries = Vec::new();
91
92        Self::collect_types_and_vars(&ops, &mut types, &mut constants, &mut vars, &mut entries)?;
93        Self::collect_decorations_and_names(&ops, &mut types, &mut vars);
94
95        // uniforms are all variables that are a pointer with a storage class of Uniform or UniformConstant
96        let uniforms: HashMap<_, _> = vars
97            .iter()
98            .filter_map(|(id, var)| {
99                if let Some(Type::Pointer {
100                    storage_class: StorageClass::Uniform | StorageClass::UniformConstant,
101                    pointed_type_id,
102                }) = types.get(&var.type_id)
103                {
104                    Some((
105                        *id,
106                        UniformVariable {
107                            set: var.set?,
108                            binding: var.binding?,
109                            type_id: *pointed_type_id, // for convenience, we store the pointed-to type instead of the pointer, since every uniform is a pointer
110                            name: var.name.clone(),
111                        },
112                    ))
113                } else {
114                    None
115                }
116            })
117            .collect();
118
119        let push_constants: HashMap<_, _> = vars
120            .iter()
121            .filter_map(|(id, var)| {
122                if let Some(Type::Pointer {
123                    storage_class: StorageClass::PushConstant,
124                    pointed_type_id,
125                }) = types.get(&var.type_id)
126                {
127                    Some((
128                        *id,
129                        PushConstantVariable {
130                            type_id: *pointed_type_id,
131                            name: var.name.clone(),
132                        },
133                    ))
134                } else {
135                    None
136                }
137            })
138            .collect();
139
140        let inputs: HashMap<_, _> = vars
141            .iter()
142            .filter_map(|(id, var)| {
143                if let Some(Type::Pointer {
144                    storage_class: StorageClass::Input,
145                    pointed_type_id,
146                }) = types.get(&var.type_id)
147                {
148                    Some((
149                        *id,
150                        LocationVariable {
151                            location: var.location?,
152                            type_id: *pointed_type_id,
153                            name: var.name.clone(),
154                        },
155                    ))
156                } else {
157                    None
158                }
159            })
160            .collect();
161
162        let outputs: HashMap<_, _> = vars
163            .iter()
164            .filter_map(|(id, var)| {
165                if let Some(Type::Pointer {
166                    storage_class: StorageClass::Output,
167                    pointed_type_id,
168                }) = types.get(&var.type_id)
169                {
170                    Some((
171                        *id,
172                        LocationVariable {
173                            location: var.location?,
174                            type_id: *pointed_type_id,
175                            name: var.name.clone(),
176                        },
177                    ))
178                } else {
179                    None
180                }
181            })
182            .collect();
183
184        let entry_points = entries
185            .iter()
186            .map(|e| {
187                let uniforms = e
188                    .interface
189                    .iter()
190                    .filter_map(|id| uniforms.get(&id.0).cloned())
191                    .collect();
192                let push_constants = e
193                    .interface
194                    .iter()
195                    .filter_map(|id| push_constants.get(&id.0).cloned())
196                    .collect();
197                let inputs = e
198                    .interface
199                    .iter()
200                    .filter_map(|id| inputs.get(&id.0).cloned())
201                    .collect();
202                let outputs = e
203                    .interface
204                    .iter()
205                    .filter_map(|id| outputs.get(&id.0).cloned())
206                    .collect();
207
208                EntryPoint {
209                    name: e.name.clone(),
210                    execution_model: e.execution_model,
211                    uniforms,
212                    push_constants,
213                    inputs,
214                    outputs,
215                }
216            })
217            .collect();
218
219        Ok(Self {
220            types,
221            entry_points,
222        })
223    }
224
225    /// Returns the [`Type`] definition indicated by `type_id`, or `None` if `type_id` is not a type.
226    pub fn get_type(&self, type_id: u32) -> Option<&Type> {
227        self.types.get(&type_id)
228    }
229
230    /// Returns the [`EntryPoint`] definitions contained in the given SPIR-V module
231    pub fn get_entry_points(&self) -> &[EntryPoint] {
232        &self.entry_points
233    }
234
235    fn get_type_size(&self, type_id: u32, stride: Option<u32>) -> Option<u32> {
236        if let Some(ty) = self.types.get(&type_id) {
237            match ty {
238                Type::Int32 | Type::UInt32 | Type::Float32 => Some(4),
239                Type::Vec2 => Some(8),
240                Type::Vec3 => Some(12),
241                Type::Vec4 => Some(16),
242                Type::Mat3 => stride.map(|stride| stride * 2 + 12), // two rows/columns + sizeof(Vec3)
243                Type::Mat4 => stride.map(|stride| stride * 3 + 16), // three rows/columns + sizeof(Vec4)
244                Type::Struct { elements, .. } => {
245                    // Since there is no Size Decoration in SPIRV that tells us the size,
246                    // we calculate it from the offset of the last member and its size.
247                    let last_element = elements.iter().max_by_key(|e| e.offset.unwrap_or(0))?;
248                    let offset = last_element.offset?;
249                    let size = self.get_member_size(last_element)?;
250
251                    Some(offset + size)
252                }
253                _ => None,
254            }
255        } else {
256            None
257        }
258    }
259
260    /// Returns the size of a given [`StructMember`], if known.
261    pub fn get_member_size(&self, member: &StructMember) -> Option<u32> {
262        self.get_type_size(member.type_id, Some(member.stride))
263    }
264
265    /// Returns the size of a given [`UniformVariable`], [`PushConstantVariable`] or [`LocationVariable`], if known.
266    pub fn get_var_size<T: Variable>(&self, var: &T) -> Option<u32> {
267        self.get_type_size(var.get_type_id(), None)
268    }
269
270    /// Parses all the Op*Decoration and Op*Name instructions
271    fn collect_decorations_and_names(
272        ops: &[Op],
273        types: &mut HashMap<u32, Type>,
274        vars: &mut HashMap<u32, RawVariable>,
275    ) {
276        for op in ops {
277            match op {
278                Op::OpName { target, name } => {
279                    if let Some(target) = vars.get_mut(&target.0) {
280                        target.name = Some(name.clone());
281                    } else if let Some(Type::Struct { name: n, .. }) = types.get_mut(&target.0) {
282                        *n = Some(name.clone());
283                    }
284                }
285                Op::OpMemberName {
286                    target,
287                    member_index,
288                    name,
289                } => {
290                    if let Some(Type::Struct { elements, .. }) = types.get_mut(&target.0) {
291                        if elements.len() > *member_index as usize {
292                            elements[*member_index as usize].name = Some(name.clone());
293                        }
294                    }
295                }
296                Op::OpDecorate { target, decoration } => match decoration {
297                    ops::Decoration::Binding { binding } => {
298                        if let Some(target) = vars.get_mut(&target.0) {
299                            target.binding = Some(*binding);
300                        }
301                    }
302                    ops::Decoration::DescriptorSet { set } => {
303                        if let Some(target) = vars.get_mut(&target.0) {
304                            target.set = Some(*set);
305                        }
306                    }
307                    ops::Decoration::Location { loc } => {
308                        if let Some(target) = vars.get_mut(&target.0) {
309                            target.location = Some(*loc);
310                        }
311                    }
312                    _ => {}
313                },
314                Op::OpMemberDecorate {
315                    target,
316                    member_index,
317                    decoration,
318                } => {
319                    if let Some(Type::Struct { elements, .. }) = types.get_mut(&target.0) {
320                        if elements.len() > *member_index as usize {
321                            match decoration {
322                                ops::Decoration::RowMajor {} => {
323                                    elements[*member_index as usize].row_major = true;
324                                }
325                                ops::Decoration::ColMajor {} => {
326                                    elements[*member_index as usize].row_major = false;
327                                }
328                                ops::Decoration::MatrixStride { stride } => {
329                                    elements[*member_index as usize].stride = *stride;
330                                }
331                                ops::Decoration::Offset { offset } => {
332                                    elements[*member_index as usize].offset = Some(*offset);
333                                }
334                                _ => {}
335                            }
336                        }
337                    }
338                }
339                _ => {}
340            }
341        }
342    }
343
344    // Parses all the OpType* and OpVariable instructions
345    fn collect_types_and_vars(
346        ops: &[Op],
347        types: &mut HashMap<u32, Type>,
348        constants: &mut HashMap<u32, u32>,
349        vars: &mut HashMap<u32, RawVariable>,
350        entries: &mut Vec<RawEntryPoint>,
351    ) -> SpirvResult<()> {
352        for op in ops {
353            match op {
354                Op::OpTypeVoid { result } => {
355                    types.insert(result.0, Type::Void);
356                }
357                Op::OpTypeBool { result } => {
358                    types.insert(result.0, Type::Bool);
359                }
360                Op::OpTypeInt {
361                    result,
362                    width,
363                    signed,
364                } => {
365                    if *width != 32 {
366                        types.insert(result.0, Type::Unknown);
367                    } else if *signed == 0 {
368                        types.insert(result.0, Type::UInt32);
369                    } else {
370                        types.insert(result.0, Type::Int32);
371                    }
372                }
373                Op::OpTypeFloat { result, width } => {
374                    if *width == 32 {
375                        types.insert(result.0, Type::Float32);
376                    } else {
377                        types.insert(result.0, Type::Unknown);
378                    }
379                }
380                Op::OpTypeVector {
381                    result,
382                    component_type,
383                    component_count,
384                } => {
385                    if let Some(t) = types.get(&component_type.0) {
386                        if let Type::Float32 = t {
387                            match component_count {
388                                2 => {
389                                    types.insert(result.0, Type::Vec2);
390                                }
391                                3 => {
392                                    types.insert(result.0, Type::Vec3);
393                                }
394                                4 => {
395                                    types.insert(result.0, Type::Vec4);
396                                }
397                                _ => {
398                                    types.insert(result.0, Type::Unknown);
399                                }
400                            }
401                        } else {
402                            types.insert(result.0, Type::Unknown);
403                        }
404                    } else {
405                        return Err(Error::InvalidId);
406                    }
407                }
408                Op::OpTypeMatrix {
409                    result,
410                    column_type,
411                    column_count,
412                } => {
413                    let t = types
414                        .get(&column_type.0)
415                        .map(|column_type| match column_type {
416                            Type::Vec3 if *column_count == 3 => Type::Mat3,
417                            Type::Vec4 if *column_count == 4 => Type::Mat4,
418                            _ => Type::Unknown,
419                        })
420                        .unwrap_or(Type::Unknown);
421                    types.insert(result.0, t);
422                }
423                Op::OpTypeImage {
424                    result,
425                    sampled_type,
426                    dim,
427                    depth,
428                    arrayed: _,
429                    ms: _,
430                    sampled,
431                    format,
432                    access: _,
433                } => {
434                    let t = if let Some(Type::Float32) = types.get(&sampled_type.0) {
435                        if let Dim::D2 {} = dim {
436                            Type::Image2D {
437                                depth: *depth != 0,
438                                sampled: *sampled != 0,
439                                format: *format,
440                            }
441                        } else {
442                            Type::Unknown
443                        }
444                    } else {
445                        Type::Unknown
446                    };
447                    types.insert(result.0, t);
448                }
449                Op::OpTypeSampler { result } => {
450                    types.insert(result.0, Type::Sampler);
451                }
452                Op::OpTypeSampledImage { result, image_type } => {
453                    let t = if let Some(Type::Image2D { .. }) = types.get(&image_type.0) {
454                        Type::SampledImage {
455                            image_type_id: image_type.0,
456                        }
457                    } else {
458                        Type::Unknown
459                    };
460                    types.insert(result.0, t);
461                }
462                Op::OpTypeArray {
463                    result,
464                    element_type,
465                    length,
466                } => {
467                    if let Some(length) = constants.get(&length.0) {
468                        types.insert(
469                            result.0,
470                            Type::Array {
471                                element_type_id: element_type.0,
472                                length: Some(*length),
473                            },
474                        );
475                    } else {
476                        return Err(Error::InvalidId);
477                    }
478                }
479                Op::OpTypeRuntimeArray {
480                    result,
481                    element_type,
482                } => {
483                    types.insert(
484                        result.0,
485                        Type::Array {
486                            element_type_id: element_type.0,
487                            length: None,
488                        },
489                    );
490                }
491                Op::OpTypeStruct {
492                    result,
493                    element_types,
494                } => {
495                    types.insert(
496                        result.0,
497                        Type::Struct {
498                            name: None,
499                            elements: element_types
500                                .iter()
501                                .map(|e| StructMember {
502                                    name: None,
503                                    type_id: e.0,
504                                    offset: None,
505                                    row_major: true,
506                                    stride: 16,
507                                })
508                                .collect(),
509                        },
510                    );
511                }
512                Op::OpTypePointer {
513                    result,
514                    storage_class,
515                    pointed_type,
516                } => {
517                    types.insert(
518                        result.0,
519                        Type::Pointer {
520                            storage_class: match storage_class {
521                                ops::StorageClass::Unknown => StorageClass::Unknown,
522                                ops::StorageClass::UniformConstant {}
523                                | ops::StorageClass::Uniform {} => StorageClass::Uniform,
524                                ops::StorageClass::PushConstant {} => StorageClass::PushConstant,
525                                ops::StorageClass::Input {} => StorageClass::Input,
526                                ops::StorageClass::Output {} => StorageClass::Output,
527                            },
528                            pointed_type_id: pointed_type.0,
529                        },
530                    );
531                }
532                Op::OpConstant {
533                    result_type,
534                    result,
535                    value,
536                } => {
537                    if let Some(Type::UInt32) = types.get(&result_type.0) {
538                        if value.len() == 1 {
539                            constants.insert(result.0, value[0]);
540                        }
541                    }
542                }
543                Op::OpVariable {
544                    result_type,
545                    result,
546                    storage_class: _,
547                    initializer: _,
548                } => {
549                    vars.insert(
550                        result.0,
551                        RawVariable {
552                            set: None,
553                            binding: None,
554                            location: None,
555                            type_id: result_type.0,
556                            name: None,
557                        },
558                    );
559                }
560                Op::OpEntryPoint {
561                    execution_model,
562                    func: _,
563                    name,
564                    interface,
565                } => {
566                    entries.push(RawEntryPoint {
567                        name: name.clone(),
568                        execution_model: match execution_model {
569                            ops::ExecutionModel::Unknown => {
570                                return Err(Error::Other(
571                                    "Unknown execution model in entry point".to_string(),
572                                ))
573                            }
574                            ops::ExecutionModel::Vertex {} => ExecutionModel::Vertex,
575                            ops::ExecutionModel::Fragment {} => ExecutionModel::Fragment,
576                        },
577                        interface: interface.clone(),
578                    });
579                }
580                _ => {}
581            }
582        }
583
584        Ok(())
585    }
586}
587
588/// Represents a type declared in a SPIRV module.
589///
590/// Types are declared in a hierarchy, with e.g. pointers relying on previously declared types as pointed-to types.
591#[derive(Debug)]
592#[non_exhaustive]
593pub enum Type {
594    /// An unsupported type
595    Unknown,
596    /// The Void type
597    Void,
598    /// A boolean
599    Bool,
600    /// A signed 32-Bit integer
601    Int32,
602    /// An unsigned 32-Bit integer
603    UInt32,
604    /// A 32-Bit float
605    Float32,
606    /// A 2 component, 32-Bit vector (GLSL: vec2)
607    Vec2,
608    /// A 3 component, 32-Bit vector (GLSL: vec3)
609    Vec3,
610    /// A 4 component, 32-Bit vector (GLSL: vec4)
611    Vec4,
612    /// A 3x3, 32-Bit Matrix (GLSL: mat3)
613    Mat3,
614    /// A 4x4, 32-Bit Matrix (GLSL: mat4)
615    Mat4,
616    /// A 2D image
617    Image2D {
618        /// true if this image is a depth image
619        depth: bool,
620        /// true if this image can be sampled from
621        sampled: bool,
622        /// SPIRV code of the images format (should always be 0 in Vulkan)
623        format: u32,
624    },
625    /// An opaque sampler object
626    Sampler,
627    /// A combined image and sampler (Vulkan: CombinedImageSampler descriptor)
628    SampledImage {
629        /// type id of the image contained in the SampledImage
630        image_type_id: u32,
631    },
632    /// Either a static array with known length (`length` is [`Some`]) or dynamic array with unknown length (`length` is [`None`])
633    Array {
634        /// type id of the contained type
635        element_type_id: u32,
636        /// length of the array (if known)
637        length: Option<u32>,
638    },
639    /// A struct containing other types
640    Struct {
641        name: Option<String>,
642        /// members of the struct, in the order they appear in the SPIRV module (not necessarily ascending offsets)
643        elements: Vec<StructMember>,
644    },
645    /// A pointer pointing to another type
646    Pointer {
647        /// The type of storage this pointer points to
648        storage_class: StorageClass,
649        /// The type id of the pointed-to type
650        pointed_type_id: u32,
651    },
652}
653
/// One member of a [`Type::Struct`].
#[derive(Debug)]
pub struct StructMember {
    /// Member name from `OpMemberName`, if present
    pub name: Option<String>,
    /// Type id of the member's [`Type`]
    pub type_id: u32,
    /// Byte offset of the member inside the struct, if an `Offset` decoration was given
    pub offset: Option<u32>,
    /// Only relevant for matrices: `true` when stored row-major, `false` when column-major
    pub row_major: bool,
    /// Only relevant for matrices: byte distance between consecutive rows/columns
    pub stride: u32,
}
668
/// Describes what type of storage a pointer points to
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum StorageClass {
    /// A storage class this crate does not model
    Unknown,
    /// The pointer is a uniform variable (Uniform blocks)
    Uniform,
    /// The pointer is a uniform variable (Images, etc.)
    UniformConstant,
    /// The pointer is a push constant
    PushConstant,
    /// The pointer is an input variable
    Input,
    /// The pointer is an output variable
    Output,
}
685
/// The execution model of an [`EntryPoint`], i.e. which kind of shader it defines.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ExecutionModel {
    /// A Vertex Shader
    Vertex,
    /// A Fragment Shader
    Fragment,
}
695
/// An `OpVariable` as collected while parsing, before it is classified by storage class.
#[derive(Debug, Clone)]
struct RawVariable {
    /// `DescriptorSet` decoration, if any
    set: Option<u32>,
    /// `Binding` decoration, if any
    binding: Option<u32>,
    /// `Location` decoration, if any
    location: Option<u32>,
    /// Result type of the variable (a pointer type id)
    type_id: u32,
    /// Name from `OpName`, if any
    name: Option<String>,
}
704
705#[derive(Debug)]
706struct RawEntryPoint {
707    name: String,
708    execution_model: ExecutionModel,
709    interface: Vec<Id>,
710}
711
/// Describes a uniform variable declared in a SPIRV module.
///
/// Only variables decorated with both a descriptor set and a binding are reported, so both
/// values are always present.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UniformVariable {
    /// The descriptor set containing the variable (GLSL `layout(set=XXX)`)
    pub set: u32,
    /// The binding of the variable within its descriptor set (GLSL `layout(binding=XXX)`)
    pub binding: u32,
    /// The type id of the variable's [`Type`] (the pointed-to type, not the pointer)
    pub type_id: u32,
    /// The variable's name (if known)
    pub name: Option<String>,
}
724
/// Describes a push constant variable declared in a SPIRV module
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PushConstantVariable {
    /// The type id of the variable's [`Type`] (the pointed-to type, not the pointer)
    pub type_id: u32,
    /// The variable's name (if known)
    pub name: Option<String>,
}
733
/// Describes an input or output variable declared in a SPIRV module
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LocationVariable {
    /// The location of the variable (e.g. GLSL `layout(location=XXX)`)
    pub location: u32,
    /// The type id of the variable's [`Type`] (the pointed-to type, not the pointer)
    pub type_id: u32,
    /// The variable's name (if known)
    pub name: Option<String>,
}
744
/// Home of the supertrait that seals the public `Variable` trait. Downstream crates cannot
/// name this private module, so they cannot implement the supertrait.
mod private {
    /// Access to the type id of a reflected variable.
    pub trait Variable {
        /// Returns the id of the variable's type.
        fn get_type_id(&self) -> u32;
    }
}
750
751pub trait Variable: private::Variable {}
752impl<T: private::Variable> Variable for T {}
753
754impl private::Variable for UniformVariable {
755    fn get_type_id(&self) -> u32 {
756        self.type_id
757    }
758}
759impl private::Variable for PushConstantVariable {
760    fn get_type_id(&self) -> u32 {
761        self.type_id
762    }
763}
764impl private::Variable for LocationVariable {
765    fn get_type_id(&self) -> u32 {
766        self.type_id
767    }
768}