Skip to main content

spirv_cross2/reflect/
types.rs

1use crate::error;
2use crate::Compiler;
3use spirv::StorageClass;
4use spirv_cross_sys::{BaseType, SpvId, VariableId};
5
6use crate::error::{SpirvCrossError, ToContextError};
7use crate::handle::Handle;
8use crate::handle::{ConstantId, TypeId};
9use crate::sealed::Sealed;
10use crate::string::CompilerStr;
11use spirv_cross_sys as sys;
12
/// The kind of scalar, i.e. how the bits of a scalar value are interpreted.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
#[repr(u8)]
pub enum ScalarKind {
    /// Signed integer.
    Int = 0,
    /// Unsigned integer.
    Uint = 1,
    /// Floating point number.
    Float = 2,
    /// Boolean.
    Bool = 3,
}

/// The bit width of a scalar.
///
/// The discriminant of each variant is its width in bits.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
#[repr(u8)]
pub enum BitWidth {
    /// 1 bit, padded to 1 byte.
    Bit = 1,
    /// 8 bits, 1 byte.
    Byte = 8,
    /// 16 bits, 2 bytes.
    HalfWord = 16,
    /// 32 bits, 4 bytes.
    Word = 32,
    /// 64 bits, 8 bytes.
    DoubleWord = 64,
}

impl BitWidth {
    /// Get the size of the bit width in bytes.
    ///
    /// Bit-sized types (booleans) are padded to a whole byte.
    pub const fn byte_size(&self) -> usize {
        match self {
            BitWidth::Bit => 1,
            BitWidth::Byte => 1,
            BitWidth::HalfWord => 2,
            BitWidth::Word => 4,
            BitWidth::DoubleWord => 8,
        }
    }
}

/// A scalar type.
///
/// `Scalar` is a small plain-data value: it is `Copy`, so it can be passed
/// around without cloning, and `Hash`, so it can be used as a map key.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct Scalar {
    /// How the value’s bits are to be interpreted.
    pub kind: ScalarKind,
    /// The size of the value in bits.
    pub size: BitWidth,
}
66
67impl TryFrom<BaseType> for Scalar {
68    type Error = SpirvCrossError;
69
70    fn try_from(value: BaseType) -> Result<Self, Self::Error> {
71        Ok(match value {
72            BaseType::Boolean => Scalar {
73                kind: ScalarKind::Bool,
74                size: BitWidth::Bit,
75            },
76            BaseType::Int8 => Scalar {
77                kind: ScalarKind::Int,
78                size: BitWidth::Byte,
79            },
80            BaseType::Int16 => Scalar {
81                kind: ScalarKind::Int,
82                size: BitWidth::HalfWord,
83            },
84            BaseType::Int32 => Scalar {
85                kind: ScalarKind::Int,
86                size: BitWidth::Word,
87            },
88            BaseType::Int64 => Scalar {
89                kind: ScalarKind::Int,
90                size: BitWidth::DoubleWord,
91            },
92            BaseType::Uint8 => Scalar {
93                kind: ScalarKind::Uint,
94                size: BitWidth::Byte,
95            },
96            BaseType::Uint16 => Scalar {
97                kind: ScalarKind::Uint,
98                size: BitWidth::HalfWord,
99            },
100            BaseType::Uint32 => Scalar {
101                kind: ScalarKind::Uint,
102                size: BitWidth::Word,
103            },
104            BaseType::Uint64 => Scalar {
105                kind: ScalarKind::Uint,
106                size: BitWidth::DoubleWord,
107            },
108            BaseType::Fp16 => Scalar {
109                kind: ScalarKind::Float,
110                size: BitWidth::HalfWord,
111            },
112            BaseType::Fp32 => Scalar {
113                kind: ScalarKind::Float,
114                size: BitWidth::Word,
115            },
116            BaseType::Fp64 => Scalar {
117                kind: ScalarKind::Float,
118                size: BitWidth::DoubleWord,
119            },
120
121            _ => {
122                return Err(SpirvCrossError::InvalidArgument(String::from(
123                    "Invalid base type used to instantiate a scalar",
124                )))
125            }
126        })
127    }
128}
129
/// A type definition.
///
/// Produced by [`Compiler::type_description`].
#[derive(Debug, Clone)]
pub struct Type<'a> {
    /// The SPIR-V ID of the type.
    pub id: Handle<TypeId>,
    /// The name of the type, if any.
    ///
    /// `None` when the compiler reports an empty name for the type.
    pub name: Option<CompilerStr<'a>>,
    /// Inner details about the type.
    pub inner: TypeInner<'a>,
    /// A size hint for the type,
    /// representing the minimum size the type could be.
    pub size_hint: TypeSizeHint,
}
143
/// Type definition for a struct member.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct StructMember<'a> {
    /// The type ID of the struct member.
    pub id: Handle<TypeId>,
    /// The type ID of the parent struct.
    pub struct_type: Handle<TypeId>,
    /// The name of the struct member.
    ///
    /// `None` when the compiler reports an empty member name.
    pub name: Option<CompilerStr<'a>>,
    /// The index of the member inside the struct.
    pub index: usize,
    /// The offset in bytes from the beginning of the struct.
    pub offset: u32,
    /// The declared size of the struct member, as reported by SPIRV-Cross.
    pub size: usize,
    /// The matrix stride of the member, if any.
    ///
    /// Matrix strides are only decorated on struct members.
    /// This is always `None` when the member's type has a single column
    /// (i.e. it is not a matrix).
    pub matrix_stride: Option<u32>,
    /// The array stride of the member, if any.
    ///
    /// Array strides are only decorated on struct members.
    /// This is always `None` when the member's type has no array dimensions.
    pub array_stride: Option<u32>,
}
168
/// Type definition for a struct.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct StructType<'a> {
    /// The type ID of the struct.
    pub id: Handle<TypeId>,
    /// The declared size of the struct in bytes, as reported by SPIRV-Cross.
    pub size: usize,
    /// The members of the struct, in member-index order.
    pub members: Vec<StructMember<'a>>,
}
179
/// Valid values that specify the dimensions of an array.
///
/// Most of the time, these will be [`ArrayDimension::Literal`].
/// If an array dimension is specified as a specialization constant,
/// then the dimension will be [`ArrayDimension::Constant`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ArrayDimension {
    /// A literal array dimension, i.e. `array[4]`.
    ///
    /// A runtime-sized (unbounded) dimension is reported as `Literal(0)`.
    Literal(u32),
    /// An array dimension specified as a specialization constant.
    ///
    /// This would show up in something like the following:
    ///
    /// ```glsl
    /// layout (constant_id = 0) const int SSAO_KERNEL_SIZE = 2;
    /// vec4[SSAO_KERNEL_SIZE] kernel;
    /// ```
    Constant(Handle<ConstantId>),
}
199
/// Class of image or texture handle.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ImageClass {
    /// Combined image samplers.
    Sampled {
        /// Whether this is a depth sampler (i.e. `samplerNDShadow`).
        depth: bool,
        /// Whether this is a multisampled image.
        multisampled: bool,
        /// Whether or not this image is arrayed.
        arrayed: bool,
    },
    /// Separate image, not combined with a sampler.
    Texture {
        /// Whether this is a multisampled image.
        multisampled: bool,
        /// Whether this image is arrayed.
        arrayed: bool,
    },
    /// Storage images.
    Storage {
        /// The image format of the storage image.
        format: spirv::ImageFormat,
    },
}
225
/// Type definition for an image or texture handle.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ImageType {
    /// The ID of the type.
    pub id: Handle<TypeId>,
    /// The ID of the type returned when the image is sampled or read from.
    pub sampled_type: Handle<TypeId>,
    /// The dimension of the image (1D, 2D, 3D, cube, buffer, ...).
    pub dimension: spirv::Dim,
    /// The class of the image: combined sampler, separate texture, or storage image.
    pub class: ImageClass,
}
238
/// Enum with additional type information, depending on the kind of type.
///
/// The design of this API is inspired heavily by [`naga::TypeInner`](https://docs.rs/naga/latest/naga/enum.TypeInner.html),
/// with some changes to fit SPIR-V.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum TypeInner<'a> {
    /// Unknown type.
    Unknown,
    /// The void type.
    Void,
    /// A pointer to another type.
    ///
    /// Atomics are represented as [`TypeInner::Pointer`] with
    /// the storage class [`StorageClass::AtomicCounter`].
    Pointer {
        /// A handle to the base type this points to.
        base: Handle<TypeId>,
        /// The storage class of the pointer.
        ///
        /// Atomics are represented as [`TypeInner::Pointer`] with
        /// the storage class [`StorageClass::AtomicCounter`].
        storage: StorageClass,
        /// Whether this pointer is a forward pointer (i.e. `base` is another pointer type).
        forward: bool,
    },
    /// A struct type.
    Struct(StructType<'a>),
    /// A scalar type.
    Scalar(Scalar),
    /// A vector type.
    ///
    /// For example, `vec4` would have a width of 4,
    /// and a scalar type with [`ScalarKind::Float`] and bit-width 32.
    Vector {
        /// The width of the vector, i.e. its number of components.
        width: u32,
        /// The scalar type of the vector.
        scalar: Scalar,
    },
    /// A matrix type.
    ///
    /// For example, `mat4` would have 4 columns, 4 rows,
    /// and a scalar type with [`ScalarKind::Float`] and bit-width 32.
    Matrix {
        /// The number of columns of the matrix type.
        columns: u32,
        /// The number of rows of the matrix type.
        rows: u32,
        /// The scalar type of the matrix.
        scalar: Scalar,
    },
    /// An array type.
    Array {
        /// The base type that the type is an array of.
        base: Handle<TypeId>,
        /// The storage class of the array.
        storage: StorageClass,
        /// The dimensions of the array.
        ///
        /// Most of the time, these will be [`ArrayDimension::Literal`].
        /// If an array dimension is specified as a specialization constant,
        /// then the dimension will be [`ArrayDimension::Constant`].
        ///
        /// The order of dimensions follows SPIR-V semantics, i.e. backwards compared to C-style
        /// declarations: the innermost dimension comes first and the outermost comes last.
        ///
        /// i.e. `int a[4][6]` will return as `[Literal(6), Literal(4)]`.
        dimensions: Vec<ArrayDimension>,
        /// The stride, in bytes, of the array’s elements, if this array type
        /// appears as a struct member.
        stride: Option<u32>,
    },
    /// A texture or image handle.
    Image(ImageType),
    /// An opaque acceleration structure.
    AccelerationStructure,
    /// An opaque sampler.
    Sampler,
}
318
319/// A size hole requiring the stride of a matrix,
320/// and whether the matrix is column or row major.
321///
322/// The hole is a `(usize, bool)` tuple, which
323/// is the stride of the matrix, and whether the
324/// matrix is row major. By default, the matrix is
325/// considered column major.
326#[derive(Debug, Clone)]
327pub struct MatrixStrideHole {
328    columns: usize,
329    rows: usize,
330    declared: usize,
331}
332
333impl Sealed for MatrixStrideHole {}
334impl ResolveSize for MatrixStrideHole {
335    type Hole = (usize, bool);
336
337    fn declared(&self) -> usize {
338        self.declared
339    }
340
341    fn resolve(&self, hole: Self::Hole) -> usize {
342        let (stride, is_row_major) = hole;
343        if is_row_major {
344            stride * self.rows
345        } else {
346            stride * self.columns
347        }
348    }
349}
350
/// A size hole requiring the number of elements in a runtime array.
///
/// This hole must be resolved with the number of elements in the runtime array.
#[derive(Debug, Clone)]
pub struct ArraySizeHole {
    /// Byte stride between consecutive elements of the runtime array.
    stride: usize,
    /// Declared size in bytes of the struct that ends in the runtime array.
    declared: usize,
}
359
360/// A size hole representing a missing or unknown array stride.
361///
362/// This hole must be resolved with a function that calculates the stride,
363/// given the size hint of the base type of the array.
364///
365/// The declared size of this hole is the number of elements
366/// times the declared size of the base type.
367#[derive(Debug, Clone)]
368pub struct UnknownStrideHole {
369    hint: Box<TypeSizeHint>,
370    count: usize,
371}
372
373impl Sealed for UnknownStrideHole {}
374impl ResolveSize for UnknownStrideHole {
375    type Hole = Box<dyn FnOnce(&TypeSizeHint) -> usize>;
376
377    fn declared(&self) -> usize {
378        self.count * self.hint.declared()
379    }
380
381    fn resolve(&self, hole: Self::Hole) -> usize {
382        self.count * hole(&self.hint)
383    }
384}
385
386impl ResolveSize for usize {
387    type Hole = core::convert::Infallible;
388
389    fn declared(&self) -> usize {
390        *self
391    }
392
393    fn resolve(&self, _hole: Self::Hole) -> usize {
394        self.declared()
395    }
396}
397
398impl ResolveSize for ArraySizeHole {
399    type Hole = usize;
400
401    fn declared(&self) -> usize {
402        self.declared
403    }
404
405    fn resolve(&self, count: Self::Hole) -> usize {
406        count * self.stride
407    }
408}
409
410impl Sealed for ArraySizeHole {}
411impl Sealed for usize {}
412
413/// A size hint for a type. This could be a statically known size,
414/// or need to resolve a hole before getting a more accurate.
415///
416/// Size hints resolve array sizes involving specialization constants.
417///
418/// If an array stride is found, it will calculate a statically known size with
419/// the array stride.
420#[derive(Debug, Clone)]
421pub enum TypeSizeHint {
422    /// A statically known type size hint.
423    Static(usize),
424    /// The size of a runtime array, which is missing an element count.
425    RuntimeArray(ArraySizeHole),
426    /// A matrix type.
427    Matrix(MatrixStrideHole),
428    /// The array stride is missing or unknowable.
429    UnknownArrayStride(UnknownStrideHole),
430}
431
432impl TypeSizeHint {
433    /// Get the statically known, declared size of a type hint,
434    /// ignoring any holes in the calculation.
435    pub fn declared(&self) -> usize {
436        match &self {
437            TypeSizeHint::Static(sz) => *sz,
438            TypeSizeHint::RuntimeArray(hole) => hole.declared(),
439            TypeSizeHint::UnknownArrayStride(hole) => hole.declared(),
440            TypeSizeHint::Matrix(hole) => hole.declared(),
441        }
442    }
443
444    /// Whether the size hint is statically known.
445    pub fn is_static(&self) -> bool {
446        matches!(self, TypeSizeHint::Static(_))
447    }
448}
449
/// Trait for size hints that need to be resolved against a hole.
///
/// Implementations are restricted by the crate-internal `Sealed` supertrait.
pub trait ResolveSize: Sealed {
    /// The type of the hole needed to resolve the size.
    type Hole;

    /// Get the declared size in bytes, regardless of any holes.
    fn declared(&self) -> usize;

    /// Resolve the size (in bytes) against the hole.
    fn resolve(&self, hole: Self::Hole) -> usize;
}
461
462/// Reflection of SPIR-V types.
463impl<T> Compiler<T> {
464    // None of the names here belong to the context, they belong to the compiler.
465    // so 'ctx is unsound to return.
466
467    fn process_struct(&self, struct_ty_id: TypeId) -> error::Result<StructType<'_>> {
468        unsafe {
469            let ty = sys::spvc_compiler_get_type_handle(self.ptr.as_ptr(), struct_ty_id);
470            let base_ty = sys::spvc_type_get_basetype(ty);
471            assert_eq!(base_ty, BaseType::Struct);
472
473            let mut struct_size = 0;
474            sys::spvc_compiler_get_declared_struct_size(self.ptr.as_ptr(), ty, &mut struct_size)
475                .ok(self)?;
476
477            let member_type_len = sys::spvc_type_get_num_member_types(ty);
478            let mut members = Vec::with_capacity(member_type_len as usize);
479            for i in 0..member_type_len {
480                let id = sys::spvc_type_get_member_type(ty, i);
481                let name = CompilerStr::from_ptr(
482                    sys::spvc_compiler_get_member_name(self.ptr.as_ptr(), struct_ty_id, i),
483                    self.ctx.drop_guard(),
484                );
485
486                let name = if name.as_ref().is_empty() {
487                    None
488                } else {
489                    Some(name)
490                };
491
492                let mut size = 0;
493                sys::spvc_compiler_get_declared_struct_member_size(
494                    self.ptr.as_ptr(),
495                    ty,
496                    i,
497                    &mut size,
498                )
499                .ok(self)?;
500
501                let mut offset = 0;
502                sys::spvc_compiler_type_struct_member_offset(self.ptr.as_ptr(), ty, i, &mut offset)
503                    .ok(self)?;
504
505                let member_ty = sys::spvc_compiler_get_type_handle(self.ptr.as_ptr(), id);
506
507                let matrix_stride = if sys::spvc_type_get_columns(member_ty) > 1 {
508                    let mut matrix_stride = 0;
509                    let _ = sys::spvc_compiler_type_struct_member_matrix_stride(
510                        self.ptr.as_ptr(),
511                        ty,
512                        i,
513                        &mut matrix_stride,
514                    )
515                    .ok(self);
516                    Some(matrix_stride)
517                } else {
518                   None
519                };
520
521                let array_stride = if sys::spvc_type_get_num_array_dimensions(member_ty) > 0 {
522                    let mut array_stride = 0;
523                    let _ = sys::spvc_compiler_type_struct_member_array_stride(
524                        self.ptr.as_ptr(),
525                        ty,
526                        i,
527                        &mut array_stride,
528                    )
529                    .ok(self);
530                    Some(array_stride)
531                } else {
532                   None
533                };
534
535                members.push(StructMember {
536                    name,
537                    id: self.create_handle(id),
538                    struct_type: self.create_handle(struct_ty_id),
539                    offset,
540                    size,
541                    index: i as usize,
542                    matrix_stride,
543                    array_stride,
544                })
545            }
546
547            Ok(StructType {
548                id: self.create_handle(struct_ty_id),
549                size: struct_size,
550                members,
551            })
552        }
553    }
554
555    fn process_vector(&self, id: TypeId, vec_width: u32) -> error::Result<TypeInner<'_>> {
556        unsafe {
557            let ty = sys::spvc_compiler_get_type_handle(self.ptr.as_ptr(), id);
558            let base_ty = sys::spvc_type_get_basetype(ty);
559            Ok(TypeInner::Vector {
560                width: vec_width,
561                scalar: base_ty.try_into()?,
562            })
563        }
564    }
565
566    fn process_matrix(&self, id: TypeId, rows: u32, columns: u32) -> error::Result<TypeInner<'_>> {
567        unsafe {
568            let ty = sys::spvc_compiler_get_type_handle(self.ptr.as_ptr(), id);
569            let base_ty = sys::spvc_type_get_basetype(ty);
570            Ok(TypeInner::Matrix {
571                rows,
572                columns,
573                scalar: base_ty.try_into()?,
574            })
575        }
576    }
577
578    fn process_array<'a>(
579        &self,
580        id: TypeId,
581        name: Option<CompilerStr<'a>>,
582    ) -> error::Result<Type<'a>> {
583        unsafe {
584            let ty = sys::spvc_compiler_get_type_handle(self.ptr.as_ptr(), id);
585            let base_type_id = sys::spvc_type_get_base_type_id(ty);
586
587            let array_dim_len = sys::spvc_type_get_num_array_dimensions(ty);
588
589            let mut array_dims = Vec::with_capacity(array_dim_len as usize);
590            for i in 0..array_dim_len {
591                array_dims.push(sys::spvc_type_get_array_dimension(ty, i))
592            }
593
594            let mut array_is_literal = Vec::with_capacity(array_dim_len as usize);
595            for i in 0..array_dim_len {
596                array_is_literal.push(sys::spvc_type_array_dimension_is_literal(ty, i))
597            }
598
599            let storage_class = sys::spvc_type_get_storage_class(ty);
600
601            let Some(storage_class) = spirv::StorageClass::from_u32(storage_class.0 as u32) else {
602                return Err(SpirvCrossError::InvalidSpirv(format!(
603                    "Unknown StorageClass found: {}",
604                    storage_class.0
605                )));
606            };
607
608            let array_dims = array_dims
609                .into_iter()
610                .enumerate()
611                .map(|(index, dim)| {
612                    if array_is_literal[index] {
613                        ArrayDimension::Literal(dim.0)
614                    } else {
615                        ArrayDimension::Constant(self.create_handle(ConstantId(dim)))
616                    }
617                })
618                .collect();
619
620            let id = self.create_handle(id);
621            let stride = self
622                .decoration(id, spirv::Decoration::ArrayStride)?
623                .and_then(|s| s.as_literal());
624
625            let inner = TypeInner::Array {
626                base: self.create_handle(base_type_id),
627                storage: storage_class,
628                dimensions: array_dims,
629                stride,
630            };
631
632            let size_hint = self.type_size_hint(&inner)?;
633
634            Ok(Type {
635                name,
636                id,
637                inner,
638                size_hint,
639            })
640        }
641    }
642
643    fn process_image(&self, id: TypeId) -> error::Result<ImageType> {
644        unsafe {
645            let ty = sys::spvc_compiler_get_type_handle(self.ptr.as_ptr(), id);
646            let base_ty = sys::spvc_type_get_basetype(ty);
647            let sampled_id = sys::spvc_type_get_image_sampled_type(ty);
648            let dimension = sys::spvc_type_get_image_dimension(ty);
649            let depth = sys::spvc_type_get_image_is_depth(ty);
650            let arrayed = sys::spvc_type_get_image_arrayed(ty);
651            let storage = sys::spvc_type_get_image_is_storage(ty);
652            let multisampled = sys::spvc_type_get_image_multisampled(ty);
653            let format = sys::spvc_type_get_image_storage_format(ty);
654
655            let Some(format) = spirv::ImageFormat::from_u32(format.0 as u32) else {
656                return Err(SpirvCrossError::InvalidSpirv(format!(
657                    "Unknown image format found: {}",
658                    format.0
659                )));
660            };
661
662            let Some(dimension) = spirv::Dim::from_u32(dimension.0 as u32) else {
663                return Err(SpirvCrossError::InvalidSpirv(format!(
664                    "Unknown image dimension found: {}",
665                    dimension.0
666                )));
667            };
668
669            let class = if storage {
670                ImageClass::Storage { format }
671            } else if base_ty == BaseType::SampledImage {
672                ImageClass::Sampled {
673                    depth,
674                    multisampled,
675                    arrayed,
676                }
677            } else {
678                ImageClass::Texture {
679                    multisampled,
680                    arrayed,
681                }
682            };
683
684            Ok(ImageType {
685                id: self.create_handle(id),
686                sampled_type: self.create_handle(sampled_id),
687                dimension,
688                class,
689            })
690        }
691    }
692
693    /// Get the type description for the given type ID.
694    ///
695    /// In most cases, a `base_type_id` should be passed in unless
696    /// pointer specifics are desired.
697    ///
698    /// Atomics are represented as `TypeInner::Pointer { storage: StorageClass::AtomicCounter, ... }`,
699    /// usually with a scalar base type.
700    pub fn type_description(&self, id: Handle<TypeId>) -> error::Result<Type<'_>> {
701        let id = self.yield_id(id)?;
702
703        unsafe {
704            let ty = sys::spvc_compiler_get_type_handle(self.ptr.as_ptr(), id);
705            let base_type_id = sys::spvc_type_get_base_type_id(ty);
706
707            let base_ty = sys::spvc_type_get_basetype(ty);
708            let name = CompilerStr::from_ptr(
709                sys::spvc_compiler_get_name(self.ptr.as_ptr(), id.0),
710                self.ctx.drop_guard(),
711            );
712
713            let name = if name.as_ref().is_empty() {
714                None
715            } else {
716                Some(name)
717            };
718
719            let array_dim_len = sys::spvc_type_get_num_array_dimensions(ty);
720            if array_dim_len != 0 {
721                return self.process_array(id, name);
722            }
723
724            // pointer types
725            if sys::spvc_rs_type_is_pointer(ty) {
726                let storage_class = sys::spvc_type_get_storage_class(ty);
727                let Some(storage_class) = spirv::StorageClass::from_u32(storage_class.0 as u32)
728                else {
729                    return Err(SpirvCrossError::InvalidSpirv(format!(
730                        "Unknown StorageClass found: {}",
731                        storage_class.0
732                    )));
733                };
734
735                let forward = sys::spvc_rs_type_is_forward_pointer(ty);
736
737                let inner = TypeInner::Pointer {
738                    base: self.create_handle(base_type_id),
739                    storage: storage_class,
740                    forward,
741                };
742
743                let size_hint = self.type_size_hint(&inner)?;
744
745                return Ok(Type {
746                    name,
747                    id: self.create_handle(id),
748                    inner,
749                    size_hint,
750                });
751            }
752
753            let vec_size = sys::spvc_type_get_vector_size(ty);
754            let columns = sys::spvc_type_get_columns(ty);
755
756            // Handle non-scalar case
757            let mut maybe_non_scalar = None;
758            if vec_size > 1 && columns == 1 {
759                maybe_non_scalar = Some(self.process_vector(id, vec_size)?);
760            }
761
762            if vec_size > 1 && columns > 1 {
763                maybe_non_scalar = Some(self.process_matrix(id, vec_size, columns)?);
764            }
765
766            let inner = match base_ty {
767                BaseType::Struct => {
768                    let ty = self.process_struct(id)?;
769                    TypeInner::Struct(ty)
770                }
771                BaseType::Image | BaseType::SampledImage => {
772                    TypeInner::Image(self.process_image(id)?)
773                }
774                BaseType::Sampler => TypeInner::Sampler,
775                BaseType::Boolean
776                | BaseType::Int8
777                | BaseType::Uint8
778                | BaseType::Int16
779                | BaseType::Uint16
780                | BaseType::Int32
781                | BaseType::Uint32
782                | BaseType::Int64
783                | BaseType::Uint64
784                | BaseType::Fp16
785                | BaseType::Fp32
786                | BaseType::Fp64 => {
787                    if let Some(prep) = maybe_non_scalar {
788                        prep
789                    } else {
790                        TypeInner::Scalar(base_ty.try_into()?)
791                    }
792                }
793
794                BaseType::Unknown => TypeInner::Unknown,
795                BaseType::Void => TypeInner::Void,
796
797                BaseType::AtomicCounter => {
798                    // This should be covered by the pointer type above.
799                    let storage_class = sys::spvc_type_get_storage_class(ty);
800                    let Some(storage_class) = spirv::StorageClass::from_u32(storage_class.0 as u32)
801                    else {
802                        return Err(SpirvCrossError::InvalidSpirv(format!(
803                            "Unknown StorageClass found: {}",
804                            storage_class.0
805                        )));
806                    };
807
808                    let forward = sys::spvc_rs_type_is_forward_pointer(ty);
809
810                    TypeInner::Pointer {
811                        base: self.create_handle(base_type_id),
812                        storage: storage_class,
813                        forward,
814                    }
815                }
816
817                BaseType::AccelerationStructure => TypeInner::AccelerationStructure,
818            };
819
820            let size_hint = self.type_size_hint(&inner)?;
821            let ty = Type {
822                name,
823                id: self.create_handle(id),
824                inner,
825                size_hint,
826            };
827            Ok(ty)
828        }
829    }
830
831    /// Get the minimum size of this type in bytes,
832    /// as declared in the shader.
833    ///
834    /// This will resolve array sizes involving specialization constants.
835    fn type_size_hint(&self, ty: &TypeInner) -> error::Result<TypeSizeHint> {
836        Ok(match ty {
837            TypeInner::Pointer { .. } => TypeSizeHint::Static(BitWidth::Word.byte_size()),
838            TypeInner::Struct(s) => {
839                if let Some(stride) = self.struct_has_runtime_array(s)? {
840                    TypeSizeHint::RuntimeArray(ArraySizeHole {
841                        stride: stride as usize,
842                        declared: s.size,
843                    })
844                } else {
845                    TypeSizeHint::Static(s.size)
846                }
847            }
848            TypeInner::Scalar(s) => TypeSizeHint::Static(s.size.byte_size()),
849            TypeInner::Vector { width, scalar } => {
850                TypeSizeHint::Static((*width as usize) * scalar.size.byte_size())
851            }
852
853            TypeInner::Matrix {
854                columns,
855                rows,
856                scalar,
857            } => {
858                // Matrices have alignment 4, so we get the next power of 4.
859                let rows_aligned = ((rows + 3) & !0x3) as usize;
860
861                let scalar_width = scalar.size.byte_size();
862                let columns = *columns as usize;
863                let declared = rows_aligned * scalar_width * columns;
864                TypeSizeHint::Matrix(MatrixStrideHole {
865                    columns,
866                    rows: *rows as usize,
867                    declared,
868                })
869            }
870            TypeInner::Array {
871                dimensions,
872                stride,
873                base,
874                ..
875            } => {
876                let mut count = 1usize;
877                for dim in dimensions.iter() {
878                    match dim {
879                        ArrayDimension::Literal(a) => count *= *a as usize,
880                        ArrayDimension::Constant(c) => {
881                            let value = self.specialization_constant_value::<u32>(*c)?;
882                            count *= value as usize;
883                        } // prod = prod * 1
884                    }
885                }
886
887                if let Some(stride) = stride {
888                    TypeSizeHint::Static(count * (*stride as usize))
889                } else {
890                    // resolve the size of the basetype
891                    let base_stride = self.type_description(*base)?.size_hint;
892                    if base_stride.is_static() {
893                        TypeSizeHint::Static(count * base_stride.declared())
894                    } else {
895                        TypeSizeHint::UnknownArrayStride(UnknownStrideHole {
896                            hint: Box::new(base_stride),
897                            count,
898                        })
899                    }
900                }
901            }
902            TypeInner::Image(_)
903            | TypeInner::AccelerationStructure
904            | TypeInner::Sampler
905            | TypeInner::Unknown
906            | TypeInner::Void => TypeSizeHint::Static(0),
907        })
908    }
909
910    /// Check if the struct has a runtime array. If so, return the stride
911    /// of the array.
912    fn struct_has_runtime_array(&self, struct_type: &StructType) -> error::Result<Option<u32>> {
913        if let Some(last) = struct_type.members.last() {
914            let Some(array_stride) = last.array_stride else {
915                return Ok(None);
916            };
917
918            let inner = self.type_description(last.id)?.inner;
919            if let TypeInner::Array { dimensions, .. } = inner {
920                if let Some(ArrayDimension::Literal(0)) = dimensions.first() {
921                    return Ok(Some(array_stride));
922                }
923            }
924        }
925
926        Ok(None)
927    }
928
929    /// Get the underlying type of the variable.
930    pub fn variable_type(
931        &self,
932        variable: impl Into<Handle<VariableId>>,
933    ) -> error::Result<Handle<TypeId>> {
934        let variable = variable.into();
935        let variable_id = self.yield_id(variable)?;
936
937        unsafe {
938            let mut type_id = TypeId(SpvId(0));
939            sys::spvc_rs_compiler_variable_get_type(self.ptr.as_ptr(), variable_id, &mut type_id)
940                .ok(self)?;
941
942            Ok(self.create_handle(type_id))
943        }
944    }
945}
946
#[cfg(test)]
mod test {
    use crate::error::SpirvCrossError;
    use crate::Compiler;
    use crate::{targets, Module};

    static BASIC_SPV: &[u8] = include_bytes!("../../basic.spv");

    /// Decode `BASIC_SPV` into native-endian SPIR-V words.
    ///
    /// Neither `include_bytes!` nor a `Vec<u8>` copy of it is guaranteed to be 4-byte
    /// aligned, so reinterpreting the bytes as `&[u32]` in place may panic. Copying each
    /// word into a freshly allocated `Vec<u32>` avoids relying on alignment.
    fn basic_spv_words() -> Vec<u32> {
        assert_eq!(
            BASIC_SPV.len() % 4,
            0,
            "basic.spv length must be a whole number of 32-bit words"
        );
        BASIC_SPV
            .chunks_exact(4)
            .map(|word| u32::from_ne_bytes([word[0], word[1], word[2], word[3]]))
            .collect()
    }

    #[test]
    pub fn get_stage_outputs() -> Result<(), SpirvCrossError> {
        let words = basic_spv_words();
        let module = Module::from_words(&words);

        let compiler: Compiler<targets::None> = Compiler::new(module)?;
        let resources = compiler.shader_resources()?.all_resources()?;

        let ty = compiler.type_description(resources.uniform_buffers[0].base_type_id)?;
        eprintln!("{ty:?}");

        // Dropping the compiler before using `resources` checks at compile time that the
        // collected resources do not borrow from it.
        drop(compiler);
        eprintln!("{resources:?}");
        Ok(())
    }

    #[test]
    pub fn set_member_name_validity_test() -> Result<(), SpirvCrossError> {
        let words = basic_spv_words();
        let module = Module::from_words(&words);

        let mut compiler: Compiler<targets::None> = Compiler::new(module)?;
        let resources = compiler.shader_resources()?.all_resources()?;

        let id = compiler
            .type_description(resources.uniform_buffers[0].base_type_id)?
            .id;

        let name = compiler.member_name(id, 0)?;
        assert_eq!(Some("MVP"), name.as_deref());

        compiler.set_member_name(id, 0, "NotMVP")?;

        let name = compiler.member_name(id, 0)?;
        assert_eq!(Some("NotMVP"), name.as_deref());

        // The rename must also be visible through a fresh reflection pass.
        let resources = compiler.shader_resources()?.all_resources()?;
        let ty = compiler.type_description(resources.uniform_buffers[0].base_type_id)?;
        let name = compiler.member_name(ty.id, 0)?;
        assert_eq!(Some("NotMVP"), name.as_deref());

        Ok(())
    }

    #[test]
    pub fn get_variable_type_test() -> Result<(), SpirvCrossError> {
        let words = basic_spv_words();
        let module = Module::from_words(&words);

        let compiler: Compiler<targets::None> = Compiler::new(module)?;
        let resources = compiler.shader_resources()?.all_resources()?;

        let buffer = &resources.uniform_buffers[0];
        assert_eq!(buffer.type_id.id(), compiler.variable_type(buffer.id)?.id());

        Ok(())
    }
}