spirv_cross2/reflect/
types.rs

1use crate::error;
2use crate::Compiler;
3use spirv::StorageClass;
4use spirv_cross_sys::{BaseType, SpvId, VariableId};
5
6use crate::error::{SpirvCrossError, ToContextError};
7use crate::handle::Handle;
8use crate::handle::{ConstantId, TypeId};
9use crate::sealed::Sealed;
10use crate::string::CompilerStr;
11use spirv_cross_sys as sys;
12
/// The kind of a scalar, i.e. how the bits of the value are interpreted.
///
/// The enum is `#[repr(u8)]` and every variant has the explicit
/// discriminant shown next to it.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(u8)]
pub enum ScalarKind {
    /// Signed integer.
    Int = 0,
    /// Unsigned integer.
    Uint = 1,
    /// Floating point number.
    Float = 2,
    /// Boolean.
    Bool = 3,
}
26
/// The bit width of a scalar.
///
/// The discriminant of every variant is its width in bits.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[repr(u8)]
pub enum BitWidth {
    /// 1 bit, padded to 1 byte.
    Bit = 1,
    /// 8 bits, 1 byte.
    Byte = 8,
    /// 16 bits, 2 bytes.
    HalfWord = 16,
    /// 32 bits, 4 bytes.
    Word = 32,
    /// 64 bits, 8 bytes.
    DoubleWord = 64,
}

impl BitWidth {
    /// Get the number of bytes occupied by a value of this bit width.
    ///
    /// Bit-sized types are padded to a whole byte.
    pub const fn byte_size(&self) -> usize {
        // The discriminant is the width in bits.
        let bits = *self as usize;
        if bits < 8 {
            // Sub-byte widths still occupy one full byte.
            1
        } else {
            bits / 8
        }
    }
}
57
58/// A scalar type.
59#[derive(Debug, Clone, Eq, PartialEq)]
60pub struct Scalar {
61    /// How the value’s bits are to be interpreted.
62    pub kind: ScalarKind,
63    /// The size of the value in bits.
64    pub size: BitWidth,
65}
66
67impl TryFrom<BaseType> for Scalar {
68    type Error = SpirvCrossError;
69
70    fn try_from(value: BaseType) -> Result<Self, Self::Error> {
71        Ok(match value {
72            BaseType::Boolean => Scalar {
73                kind: ScalarKind::Bool,
74                size: BitWidth::Bit,
75            },
76            BaseType::Int8 => Scalar {
77                kind: ScalarKind::Int,
78                size: BitWidth::Byte,
79            },
80            BaseType::Int16 => Scalar {
81                kind: ScalarKind::Int,
82                size: BitWidth::HalfWord,
83            },
84            BaseType::Int32 => Scalar {
85                kind: ScalarKind::Int,
86                size: BitWidth::Word,
87            },
88            BaseType::Int64 => Scalar {
89                kind: ScalarKind::Int,
90                size: BitWidth::DoubleWord,
91            },
92            BaseType::Uint8 => Scalar {
93                kind: ScalarKind::Uint,
94                size: BitWidth::Byte,
95            },
96            BaseType::Uint16 => Scalar {
97                kind: ScalarKind::Uint,
98                size: BitWidth::HalfWord,
99            },
100            BaseType::Uint32 => Scalar {
101                kind: ScalarKind::Uint,
102                size: BitWidth::Word,
103            },
104            BaseType::Uint64 => Scalar {
105                kind: ScalarKind::Uint,
106                size: BitWidth::DoubleWord,
107            },
108            BaseType::Fp16 => Scalar {
109                kind: ScalarKind::Float,
110                size: BitWidth::HalfWord,
111            },
112            BaseType::Fp32 => Scalar {
113                kind: ScalarKind::Float,
114                size: BitWidth::Word,
115            },
116            BaseType::Fp64 => Scalar {
117                kind: ScalarKind::Float,
118                size: BitWidth::DoubleWord,
119            },
120
121            _ => {
122                return Err(SpirvCrossError::InvalidArgument(String::from(
123                    "Invalid base type used to instantiate a scalar",
124                )))
125            }
126        })
127    }
128}
129
/// A type definition.
///
/// Bundles the handle of a type with its optional name, the details of
/// what kind of type it is, and a size hint.
#[derive(Debug, Clone)]
pub struct Type<'a> {
    /// The SPIR-V ID of the type.
    pub id: Handle<TypeId>,
    /// The name of the type, if any.
    pub name: Option<CompilerStr<'a>>,
    /// Inner details about the type.
    pub inner: TypeInner<'a>,
    /// A size hint for the type,
    /// representing the minimum size the type could be.
    pub size_hint: TypeSizeHint,
}
143
/// Type definition for a struct member.
///
/// Describes a single member of a [`StructType`]: its type, its position
/// within the parent struct, and its layout decorations.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct StructMember<'a> {
    /// The type ID of the struct member.
    pub id: Handle<TypeId>,
    /// The type ID of the parent struct.
    pub struct_type: Handle<TypeId>,
    /// The name of the struct member, if any.
    pub name: Option<CompilerStr<'a>>,
    /// The index of the member inside the struct.
    pub index: usize,
    /// The offset in bytes from the beginning of the struct.
    pub offset: u32,
    /// The declared size of the struct member.
    pub size: usize,
    /// The matrix stride of the member, if any.
    ///
    /// Matrix strides are only decorated on struct members.
    pub matrix_stride: Option<u32>,
    /// The array stride of the member, if any.
    ///
    /// Array strides are only decorated on struct members.
    pub array_stride: Option<u32>,
}
168
/// Type definition for a struct.
///
/// Holds the declared size of the struct and a description of each member.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct StructType<'a> {
    /// The type ID of the struct.
    pub id: Handle<TypeId>,
    /// The size of the struct in bytes.
    pub size: usize,
    /// The members of the struct.
    pub members: Vec<StructMember<'a>>,
}
179
/// Valid values that specify the dimensions of an array.
///
/// Most of the time, these will be [`ArrayDimension::Literal`].
/// If an array dimension is specified as a specialization constant,
/// then the dimension will be [`ArrayDimension::Constant`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ArrayDimension {
    /// A literal array dimension, i.e. `array[4]`.
    Literal(u32),
    /// An array dimension specified as a specialization constant.
    ///
    /// The handle refers to the constant that holds the dimension.
    ///
    /// This would show up in something like the following:
    ///
    /// ```glsl
    /// layout (constant_id = 0) const int SSAO_KERNEL_SIZE = 2;
    /// vec4[SSAO_KERNEL_SIZE] kernel;
    /// ```
    Constant(Handle<ConstantId>),
}
199
/// Class of image or texture handle.
///
/// Distinguishes combined image samplers, separate images, and storage
/// images, each with the properties relevant to that class.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ImageClass {
    /// Combined image samplers.
    Sampled {
        /// Whether this is a depth sampler (i.e. `samplerNDShadow`).
        depth: bool,
        /// Whether this is a multisampled image.
        multisampled: bool,
        /// Whether this image is arrayed.
        arrayed: bool,
    },
    /// Separate image.
    Texture {
        /// Whether this is a multisampled image.
        multisampled: bool,
        /// Whether this image is arrayed.
        arrayed: bool,
    },
    /// Storage images.
    Storage {
        /// The image format of the storage image.
        format: spirv::ImageFormat,
    },
}
225
/// Type definition for an image or texture handle.
///
/// Combines the image's own type ID, the type produced by sampling or
/// reading it, its dimensionality, and its [`ImageClass`].
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct ImageType {
    /// The id of the type.
    pub id: Handle<TypeId>,
    /// The id of the type returned when the image is sampled or read from.
    pub sampled_type: Handle<TypeId>,
    /// The dimension of the image.
    pub dimension: spirv::Dim,
    /// The class of the image.
    pub class: ImageClass,
}
238
/// Enum with additional type information, depending on the kind of type.
///
/// The design of this API is inspired heavily by [`naga::TypeInner`](https://docs.rs/naga/latest/naga/enum.TypeInner.html),
/// with some changes to fit SPIR-V.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum TypeInner<'a> {
    /// Unknown type.
    Unknown,
    /// The void type.
    Void,
    /// A pointer to another type.
    ///
    /// Atomics are represented as [`TypeInner::Pointer`] with
    /// the storage class [`StorageClass::AtomicCounter`].
    Pointer {
        /// A handle to the base type this points to.
        base: Handle<TypeId>,
        /// The storage class of the pointer.
        ///
        /// Atomics are represented as [`TypeInner::Pointer`] with
        /// the storage class [`StorageClass::AtomicCounter`].
        storage: StorageClass,
        /// Whether this pointer is a forward pointer (i.e. `base` is another pointer type).
        forward: bool,
    },
    /// A struct type.
    Struct(StructType<'a>),
    /// A scalar type.
    Scalar(Scalar),
    /// A vector type.
    ///
    /// For example, `vec4` would have a width of 4,
    /// and a scalar type with [`ScalarKind::Float`] and bit-width 32.
    Vector {
        /// The width of the vector.
        width: u32,
        /// The scalar type of the vector.
        scalar: Scalar,
    },
    /// A matrix type.
    ///
    /// For example, `mat4` would have 4 columns, 4 rows,
    /// and a scalar type with [`ScalarKind::Float`] and bit-width 32.
    Matrix {
        /// The number of columns of the matrix type.
        columns: u32,
        /// The number of rows of the matrix type.
        rows: u32,
        /// The scalar type of the matrix.
        scalar: Scalar,
    },
    /// An array type.
    Array {
        /// The base type that the type is an array of.
        base: Handle<TypeId>,
        /// The storage class of the array.
        storage: StorageClass,
        /// The dimensions of the array.
        ///
        /// Most of the time, these will be [`ArrayDimension::Literal`].
        /// If an array dimension is specified as a specialization constant,
        /// then the dimension will be [`ArrayDimension::Constant`].
        ///
        /// The order of dimensions follows SPIR-V semantics, i.e. backwards compared to C-style
        /// declarations.
        ///
        /// i.e. `int a[4][6]` will return as `[Literal(6), Literal(4)]`.
        dimensions: Vec<ArrayDimension>,
        /// The stride, in bytes, of the array's elements, if this array type
        /// appears as a struct member.
        stride: Option<u32>,
    },
    /// A texture or image handle.
    Image(ImageType),
    /// An opaque acceleration structure.
    AccelerationStructure,
    /// An opaque sampler.
    Sampler,
}
318
319/// A size hole requiring the stride of a matrix,
320/// and whether the matrix is column or row major.
321///
322/// The hole is a `(usize, bool)` tuple, which
323/// is the stride of the matrix, and whether the
324/// matrix is row major. By default, the matrix is
325/// considered column major.
326#[derive(Debug, Clone)]
327pub struct MatrixStrideHole {
328    columns: usize,
329    rows: usize,
330    declared: usize,
331}
332
333impl Sealed for MatrixStrideHole {}
334impl ResolveSize for MatrixStrideHole {
335    type Hole = (usize, bool);
336
337    fn declared(&self) -> usize {
338        self.declared
339    }
340
341    fn resolve(&self, hole: Self::Hole) -> usize {
342        let (stride, is_row_major) = hole;
343        if is_row_major {
344            stride * self.rows
345        } else {
346            stride * self.columns
347        }
348    }
349}
350
/// A size hole requiring the number of elements in a runtime array.
///
/// This hole must be resolved with the size of the array,
/// i.e. its element count.
#[derive(Debug, Clone)]
pub struct ArraySizeHole {
    // Stride of one element of the runtime array.
    stride: usize,
    // Size reported by `declared()`, independent of the element count.
    declared: usize,
}
359
360/// A size hole representing a missing or unknown array stride.
361///
362/// This hole must be resolved with a function that calculates the stride,
363/// given the size hint of the base type of the array.
364///
365/// The declared size of this hole is the number of elements
366/// times the declared size of the base type.
367#[derive(Debug, Clone)]
368pub struct UnknownStrideHole {
369    hint: Box<TypeSizeHint>,
370    count: usize,
371}
372
373impl Sealed for UnknownStrideHole {}
374impl ResolveSize for UnknownStrideHole {
375    type Hole = Box<dyn FnOnce(&TypeSizeHint) -> usize>;
376
377    fn declared(&self) -> usize {
378        self.count * self.hint.declared()
379    }
380
381    fn resolve(&self, hole: Self::Hole) -> usize {
382        self.count * hole(&self.hint)
383    }
384}
385
386impl ResolveSize for usize {
387    type Hole = core::convert::Infallible;
388
389    fn declared(&self) -> usize {
390        *self
391    }
392
393    fn resolve(&self, _hole: Self::Hole) -> usize {
394        self.declared()
395    }
396}
397
398impl ResolveSize for ArraySizeHole {
399    type Hole = usize;
400
401    fn declared(&self) -> usize {
402        self.declared
403    }
404
405    fn resolve(&self, count: Self::Hole) -> usize {
406        count * self.stride
407    }
408}
409
410impl Sealed for ArraySizeHole {}
411impl Sealed for usize {}
412
413/// A size hint for a type. This could be a statically known size,
414/// or need to resolve a hole before getting a more accurate.
415///
416/// Size hints resolve array sizes involving specialization constants.
417///
418/// If an array stride is found, it will calculate a statically known size with
419/// the array stride.
420#[derive(Debug, Clone)]
421pub enum TypeSizeHint {
422    /// A statically known type size hint.
423    Static(usize),
424    /// The size of a runtime array, which is missing an element count.
425    RuntimeArray(ArraySizeHole),
426    /// A matrix type.
427    Matrix(MatrixStrideHole),
428    /// The array stride is missing or unknowable.
429    UnknownArrayStride(UnknownStrideHole),
430}
431
432impl TypeSizeHint {
433    /// Get the statically known, declared size of a type hint,
434    /// ignoring any holes in the calculation.
435    pub fn declared(&self) -> usize {
436        match &self {
437            TypeSizeHint::Static(sz) => *sz,
438            TypeSizeHint::RuntimeArray(hole) => hole.declared(),
439            TypeSizeHint::UnknownArrayStride(hole) => hole.declared(),
440            TypeSizeHint::Matrix(hole) => hole.declared(),
441        }
442    }
443
444    /// Whether the size hint is statically known.
445    pub fn is_static(&self) -> bool {
446        matches!(self, TypeSizeHint::Static(_))
447    }
448}
449
/// Trait for size hints that need to be resolved against a hole.
///
/// The trait has [`Sealed`] as a supertrait, so only types that also
/// implement `Sealed` can implement it.
pub trait ResolveSize: Sealed {
    /// The type of the hole needed to resolve the size.
    type Hole;

    /// Get the declared size in bytes, regardless of any holes.
    fn declared(&self) -> usize;

    /// Resolve the size (in bytes) against the hole.
    fn resolve(&self, hole: Self::Hole) -> usize;
}
461
462/// Reflection of SPIR-V types.
463impl<T> Compiler<T> {
464    // None of the names here belong to the context, they belong to the compiler.
465    // so 'ctx is unsound to return.
466
467    fn process_struct(&self, struct_ty_id: TypeId) -> error::Result<StructType> {
468        unsafe {
469            let ty = sys::spvc_compiler_get_type_handle(self.ptr.as_ptr(), struct_ty_id);
470            let base_ty = sys::spvc_type_get_basetype(ty);
471            assert_eq!(base_ty, BaseType::Struct);
472
473            let mut struct_size = 0;
474            sys::spvc_compiler_get_declared_struct_size(self.ptr.as_ptr(), ty, &mut struct_size)
475                .ok(self)?;
476
477            let member_type_len = sys::spvc_type_get_num_member_types(ty);
478            let mut members = Vec::with_capacity(member_type_len as usize);
479            for i in 0..member_type_len {
480                let id = sys::spvc_type_get_member_type(ty, i);
481                let name = CompilerStr::from_ptr(
482                    sys::spvc_compiler_get_member_name(self.ptr.as_ptr(), struct_ty_id, i),
483                    self.ctx.drop_guard(),
484                );
485
486                let name = if name.as_ref().is_empty() {
487                    None
488                } else {
489                    Some(name)
490                };
491
492                let mut size = 0;
493                sys::spvc_compiler_get_declared_struct_member_size(
494                    self.ptr.as_ptr(),
495                    ty,
496                    i,
497                    &mut size,
498                )
499                .ok(self)?;
500
501                let mut offset = 0;
502                sys::spvc_compiler_type_struct_member_offset(self.ptr.as_ptr(), ty, i, &mut offset)
503                    .ok(self)?;
504
505                let mut matrix_stride = 0;
506                let matrix_stride = sys::spvc_compiler_type_struct_member_matrix_stride(
507                    self.ptr.as_ptr(),
508                    ty,
509                    i,
510                    &mut matrix_stride,
511                )
512                .ok(self)
513                .ok()
514                .map(|_| matrix_stride);
515
516                let mut array_stride = 0;
517                let array_stride = sys::spvc_compiler_type_struct_member_array_stride(
518                    self.ptr.as_ptr(),
519                    ty,
520                    i,
521                    &mut array_stride,
522                )
523                .ok(self)
524                .ok()
525                .map(|_| array_stride);
526
527                members.push(StructMember {
528                    name,
529                    id: self.create_handle(id),
530                    struct_type: self.create_handle(struct_ty_id),
531                    offset,
532                    size,
533                    index: i as usize,
534                    matrix_stride,
535                    array_stride,
536                })
537            }
538
539            Ok(StructType {
540                id: self.create_handle(struct_ty_id),
541                size: struct_size,
542                members,
543            })
544        }
545    }
546
547    fn process_vector(&self, id: TypeId, vec_width: u32) -> error::Result<TypeInner> {
548        unsafe {
549            let ty = sys::spvc_compiler_get_type_handle(self.ptr.as_ptr(), id);
550            let base_ty = sys::spvc_type_get_basetype(ty);
551            Ok(TypeInner::Vector {
552                width: vec_width,
553                scalar: base_ty.try_into()?,
554            })
555        }
556    }
557
558    fn process_matrix(&self, id: TypeId, rows: u32, columns: u32) -> error::Result<TypeInner> {
559        unsafe {
560            let ty = sys::spvc_compiler_get_type_handle(self.ptr.as_ptr(), id);
561            let base_ty = sys::spvc_type_get_basetype(ty);
562            Ok(TypeInner::Matrix {
563                rows,
564                columns,
565                scalar: base_ty.try_into()?,
566            })
567        }
568    }
569
570    fn process_array<'a>(
571        &self,
572        id: TypeId,
573        name: Option<CompilerStr<'a>>,
574    ) -> error::Result<Type<'a>> {
575        unsafe {
576            let ty = sys::spvc_compiler_get_type_handle(self.ptr.as_ptr(), id);
577            let base_type_id = sys::spvc_type_get_base_type_id(ty);
578
579            let array_dim_len = sys::spvc_type_get_num_array_dimensions(ty);
580
581            let mut array_dims = Vec::with_capacity(array_dim_len as usize);
582            for i in 0..array_dim_len {
583                array_dims.push(sys::spvc_type_get_array_dimension(ty, i))
584            }
585
586            let mut array_is_literal = Vec::with_capacity(array_dim_len as usize);
587            for i in 0..array_dim_len {
588                array_is_literal.push(sys::spvc_type_array_dimension_is_literal(ty, i))
589            }
590
591            let storage_class = sys::spvc_type_get_storage_class(ty);
592
593            let Some(storage_class) = spirv::StorageClass::from_u32(storage_class.0 as u32) else {
594                return Err(SpirvCrossError::InvalidSpirv(format!(
595                    "Unknown StorageClass found: {}",
596                    storage_class.0
597                )));
598            };
599
600            let array_dims = array_dims
601                .into_iter()
602                .enumerate()
603                .map(|(index, dim)| {
604                    if array_is_literal[index] {
605                        ArrayDimension::Literal(dim.0)
606                    } else {
607                        ArrayDimension::Constant(self.create_handle(ConstantId(dim)))
608                    }
609                })
610                .collect();
611
612            let id = self.create_handle(id);
613            let stride = self
614                .decoration(id, spirv::Decoration::ArrayStride)?
615                .and_then(|s| s.as_literal());
616
617            let inner = TypeInner::Array {
618                base: self.create_handle(base_type_id),
619                storage: storage_class,
620                dimensions: array_dims,
621                stride,
622            };
623
624            let size_hint = self.type_size_hint(&inner)?;
625
626            Ok(Type {
627                name,
628                id,
629                inner,
630                size_hint,
631            })
632        }
633    }
634
635    fn process_image(&self, id: TypeId) -> error::Result<ImageType> {
636        unsafe {
637            let ty = sys::spvc_compiler_get_type_handle(self.ptr.as_ptr(), id);
638            let base_ty = sys::spvc_type_get_basetype(ty);
639            let sampled_id = sys::spvc_type_get_image_sampled_type(ty);
640            let dimension = sys::spvc_type_get_image_dimension(ty);
641            let depth = sys::spvc_type_get_image_is_depth(ty);
642            let arrayed = sys::spvc_type_get_image_arrayed(ty);
643            let storage = sys::spvc_type_get_image_is_storage(ty);
644            let multisampled = sys::spvc_type_get_image_multisampled(ty);
645            let format = sys::spvc_type_get_image_storage_format(ty);
646
647            let Some(format) = spirv::ImageFormat::from_u32(format.0 as u32) else {
648                return Err(SpirvCrossError::InvalidSpirv(format!(
649                    "Unknown image format found: {}",
650                    format.0
651                )));
652            };
653
654            let Some(dimension) = spirv::Dim::from_u32(dimension.0 as u32) else {
655                return Err(SpirvCrossError::InvalidSpirv(format!(
656                    "Unknown image dimension found: {}",
657                    dimension.0
658                )));
659            };
660
661            let class = if storage {
662                ImageClass::Storage { format }
663            } else if base_ty == BaseType::SampledImage {
664                ImageClass::Sampled {
665                    depth,
666                    multisampled,
667                    arrayed,
668                }
669            } else {
670                ImageClass::Texture {
671                    multisampled,
672                    arrayed,
673                }
674            };
675
676            Ok(ImageType {
677                id: self.create_handle(id),
678                sampled_type: self.create_handle(sampled_id),
679                dimension,
680                class,
681            })
682        }
683    }
684
685    /// Get the type description for the given type ID.
686    ///
687    /// In most cases, a `base_type_id` should be passed in unless
688    /// pointer specifics are desired.
689    ///
690    /// Atomics are represented as `TypeInner::Pointer { storage: StorageClass::AtomicCounter, ... }`,
691    /// usually with a scalar base type.
692    pub fn type_description(&self, id: Handle<TypeId>) -> error::Result<Type> {
693        let id = self.yield_id(id)?;
694
695        unsafe {
696            let ty = sys::spvc_compiler_get_type_handle(self.ptr.as_ptr(), id);
697            let base_type_id = sys::spvc_type_get_base_type_id(ty);
698
699            let base_ty = sys::spvc_type_get_basetype(ty);
700            let name = CompilerStr::from_ptr(
701                sys::spvc_compiler_get_name(self.ptr.as_ptr(), id.0),
702                self.ctx.drop_guard(),
703            );
704
705            let name = if name.as_ref().is_empty() {
706                None
707            } else {
708                Some(name)
709            };
710
711            let array_dim_len = sys::spvc_type_get_num_array_dimensions(ty);
712            if array_dim_len != 0 {
713                return self.process_array(id, name);
714            }
715
716            // pointer types
717            if sys::spvc_rs_type_is_pointer(ty) {
718                let storage_class = sys::spvc_type_get_storage_class(ty);
719                let Some(storage_class) = spirv::StorageClass::from_u32(storage_class.0 as u32)
720                else {
721                    return Err(SpirvCrossError::InvalidSpirv(format!(
722                        "Unknown StorageClass found: {}",
723                        storage_class.0
724                    )));
725                };
726
727                let forward = sys::spvc_rs_type_is_forward_pointer(ty);
728
729                let inner = TypeInner::Pointer {
730                    base: self.create_handle(base_type_id),
731                    storage: storage_class,
732                    forward,
733                };
734
735                let size_hint = self.type_size_hint(&inner)?;
736
737                return Ok(Type {
738                    name,
739                    id: self.create_handle(id),
740                    inner,
741                    size_hint,
742                });
743            }
744
745            let vec_size = sys::spvc_type_get_vector_size(ty);
746            let columns = sys::spvc_type_get_columns(ty);
747
748            // Handle non-scalar case
749            let mut maybe_non_scalar = None;
750            if vec_size > 1 && columns == 1 {
751                maybe_non_scalar = Some(self.process_vector(id, vec_size)?);
752            }
753
754            if vec_size > 1 && columns > 1 {
755                maybe_non_scalar = Some(self.process_matrix(id, vec_size, columns)?);
756            }
757
758            let inner = match base_ty {
759                BaseType::Struct => {
760                    let ty = self.process_struct(id)?;
761                    TypeInner::Struct(ty)
762                }
763                BaseType::Image | BaseType::SampledImage => {
764                    TypeInner::Image(self.process_image(id)?)
765                }
766                BaseType::Sampler => TypeInner::Sampler,
767                BaseType::Boolean
768                | BaseType::Int8
769                | BaseType::Uint8
770                | BaseType::Int16
771                | BaseType::Uint16
772                | BaseType::Int32
773                | BaseType::Uint32
774                | BaseType::Int64
775                | BaseType::Uint64
776                | BaseType::Fp16
777                | BaseType::Fp32
778                | BaseType::Fp64 => {
779                    if let Some(prep) = maybe_non_scalar {
780                        prep
781                    } else {
782                        TypeInner::Scalar(base_ty.try_into()?)
783                    }
784                }
785
786                BaseType::Unknown => TypeInner::Unknown,
787                BaseType::Void => TypeInner::Void,
788
789                BaseType::AtomicCounter => {
790                    // This should be covered by the pointer type above.
791                    let storage_class = sys::spvc_type_get_storage_class(ty);
792                    let Some(storage_class) = spirv::StorageClass::from_u32(storage_class.0 as u32)
793                    else {
794                        return Err(SpirvCrossError::InvalidSpirv(format!(
795                            "Unknown StorageClass found: {}",
796                            storage_class.0
797                        )));
798                    };
799
800                    let forward = sys::spvc_rs_type_is_forward_pointer(ty);
801
802                    TypeInner::Pointer {
803                        base: self.create_handle(base_type_id),
804                        storage: storage_class,
805                        forward,
806                    }
807                }
808
809                BaseType::AccelerationStructure => TypeInner::AccelerationStructure,
810            };
811
812            let size_hint = self.type_size_hint(&inner)?;
813            let ty = Type {
814                name,
815                id: self.create_handle(id),
816                inner,
817                size_hint,
818            };
819            Ok(ty)
820        }
821    }
822
823    /// Get the minimum size of this type in bytes,
824    /// as declared in the shader.
825    ///
826    /// This will resolve array sizes involving specialization constants.
827    fn type_size_hint(&self, ty: &TypeInner) -> error::Result<TypeSizeHint> {
828        Ok(match ty {
829            TypeInner::Pointer { .. } => TypeSizeHint::Static(BitWidth::Word.byte_size()),
830            TypeInner::Struct(s) => {
831                if let Some(stride) = self.struct_has_runtime_array(s)? {
832                    TypeSizeHint::RuntimeArray(ArraySizeHole {
833                        stride: stride as usize,
834                        declared: s.size,
835                    })
836                } else {
837                    TypeSizeHint::Static(s.size)
838                }
839            }
840            TypeInner::Scalar(s) => TypeSizeHint::Static(s.size.byte_size()),
841            TypeInner::Vector { width, scalar } => {
842                TypeSizeHint::Static((*width as usize) * scalar.size.byte_size())
843            }
844
845            TypeInner::Matrix {
846                columns,
847                rows,
848                scalar,
849            } => {
850                // Matrices have alignment 4, so we get the next power of 4.
851                let rows_aligned = (rows + 3 & !0x3) as usize;
852
853                let scalar_width = scalar.size.byte_size();
854                let columns = *columns as usize;
855                let declared = rows_aligned * scalar_width * columns;
856                TypeSizeHint::Matrix(MatrixStrideHole {
857                    columns,
858                    rows: *rows as usize,
859                    declared,
860                })
861            }
862            TypeInner::Array {
863                dimensions,
864                stride,
865                base,
866                ..
867            } => {
868                let mut count = 1usize;
869                for dim in dimensions.iter() {
870                    match dim {
871                        ArrayDimension::Literal(a) => count = count * (*a as usize),
872                        ArrayDimension::Constant(c) => {
873                            let value = self.specialization_constant_value::<u32>(*c)?;
874                            count = count * value as usize;
875                        } // prod = prod * 1
876                    }
877                }
878
879                if let Some(stride) = stride {
880                    TypeSizeHint::Static(count * (*stride as usize))
881                } else {
882                    // resolve the size of the basetype
883                    let base_stride = self.type_description(*base)?.size_hint;
884                    if base_stride.is_static() {
885                        TypeSizeHint::Static(count * base_stride.declared())
886                    } else {
887                        TypeSizeHint::UnknownArrayStride(UnknownStrideHole {
888                            hint: Box::new(base_stride),
889                            count,
890                        })
891                    }
892                }
893            }
894            TypeInner::Image(_)
895            | TypeInner::AccelerationStructure
896            | TypeInner::Sampler
897            | TypeInner::Unknown
898            | TypeInner::Void => TypeSizeHint::Static(0),
899        })
900    }
901
902    /// Check if the struct has a runtime array. If so, return the stride
903    /// of the array.
904    fn struct_has_runtime_array(&self, struct_type: &StructType) -> error::Result<Option<u32>> {
905        if let Some(last) = struct_type.members.last() {
906            let Some(array_stride) = last.array_stride else {
907                return Ok(None);
908            };
909
910            let inner = self.type_description(last.id)?.inner;
911            if let TypeInner::Array { dimensions, .. } = inner {
912                if let Some(ArrayDimension::Literal(0)) = dimensions.first() {
913                    return Ok(Some(array_stride));
914                }
915            }
916        }
917
918        Ok(None)
919    }
920
921    /// Get the underlying type of the variable.
922    pub fn variable_type(
923        &self,
924        variable: impl Into<Handle<VariableId>>,
925    ) -> error::Result<Handle<TypeId>> {
926        let variable = variable.into();
927        let variable_id = self.yield_id(variable)?;
928
929        unsafe {
930            let mut type_id = TypeId(SpvId(0));
931            sys::spvc_rs_compiler_variable_get_type(self.ptr.as_ptr(), variable_id, &mut type_id)
932                .ok(self)?;
933
934            Ok(self.create_handle(type_id))
935        }
936    }
937}
938
#[cfg(test)]
mod test {
    use crate::error::SpirvCrossError;
    use crate::Compiler;
    use crate::{targets, Module};

    /// SPIR-V binary fixture shared by every test in this module.
    static BASIC_SPV: &[u8] = include_bytes!("../../basic.spv");

    /// Reflected resources and a type description can be debug-formatted, and
    /// the resources can still be formatted after the compiler is dropped.
    #[test]
    pub fn get_stage_outputs() -> Result<(), SpirvCrossError> {
        let vec = Vec::from(BASIC_SPV);
        let words = Module::from_words(bytemuck::cast_slice(&vec));

        let compiler: Compiler<targets::None> = Compiler::new(words)?;
        let resources = compiler.shader_resources()?.all_resources()?;

        let ty = compiler.type_description(resources.uniform_buffers[0].base_type_id)?;
        eprintln!("{ty:?}");

        // Drop the compiler first: the resources are formatted only afterwards.
        drop(compiler);
        eprintln!("{resources:?}");
        eprintln!("{resources:?}");
        Ok(())
    }

    /// Renaming a struct member is visible through `member_name`, and the
    /// type can still be described after the rename.
    #[test]
    pub fn set_member_name_validity_test() -> Result<(), SpirvCrossError> {
        let vec = Vec::from(BASIC_SPV);
        let words = Module::from_words(bytemuck::cast_slice(&vec));

        let mut compiler: Compiler<targets::None> = Compiler::new(words)?;
        let resources = compiler.shader_resources()?.all_resources()?;

        let ty = compiler.type_description(resources.uniform_buffers[0].base_type_id)?;
        let id = ty.id;

        let name = compiler.member_name(id, 0)?;
        assert_eq!(Some("MVP"), name.as_deref());

        compiler.set_member_name(id, 0, "NotMVP")?;

        let name = compiler.member_name(id, 0)?;
        assert_eq!(Some("NotMVP"), name.as_deref());

        // Reflection is re-run after the rename; only success is checked, so
        // the description itself is deliberately left unused.
        let resources = compiler.shader_resources()?.all_resources()?;
        let _ty = compiler.type_description(resources.uniform_buffers[0].base_type_id)?;

        Ok(())
    }

    /// `variable_type` reports the same type ID that the reflected resource
    /// records for the first uniform buffer.
    #[test]
    pub fn get_variable_type_test() -> Result<(), SpirvCrossError> {
        let vec = Vec::from(BASIC_SPV);
        let words = Module::from_words(bytemuck::cast_slice(&vec));

        // Every call below takes `&self`, so the compiler does not need to be mutable.
        let compiler: Compiler<targets::None> = Compiler::new(words)?;
        let resources = compiler.shader_resources()?.all_resources()?;

        let variable = resources.uniform_buffers[0].id;
        assert_eq!(
            resources.uniform_buffers[0].type_id.id(),
            compiler.variable_type(variable)?.id()
        );

        eprintln!("{:?}", resources);
        Ok(())
    }
}