Skip to main content

spirv_cross2/reflect/
constants.rs

1use crate::sealed::Sealed;
2use spirv_cross_sys::{spvc_constant, spvc_specialization_constant, TypeId};
3use std::mem::MaybeUninit;
4use std::ops::{Index, IndexMut};
5use std::slice;
6
7use crate::error::{SpirvCrossError, ToContextError};
8use crate::handle::{ConstantId, Handle};
9use crate::iter::impl_iterator;
10use crate::{error, Compiler, PhantomCompiler};
11use spirv_cross_sys as sys;
12
13mod gfx_maths;
14mod glam;
15mod half;
16
17/// A marker trait for types that can be represented as a scalar SPIR-V constant.
18pub trait ConstantScalar: Default + Sealed + Copy {
19    #[doc(hidden)]
20    unsafe fn get(constant: spvc_constant, column: u32, row: u32) -> Self;
21
22    #[doc(hidden)]
23    unsafe fn set(constant: spvc_constant, column: u32, row: u32, value: Self);
24}
25
/// Implements `Sealed` and [`ConstantScalar`] for a primitive type by forwarding
/// to the matching `spirv_cross_sys` scalar getter and setter.
///
/// Invocation shape: `impl_spvc_constant!(getter_fn setter_fn primitive_type);`
macro_rules! impl_spvc_constant {
    ($getter:ident  $setter:ident $scalar:ty) => {
        impl Sealed for $scalar {}

        impl ConstantScalar for $scalar {
            unsafe fn get(constant: spvc_constant, column: u32, row: u32) -> Self {
                // The C getter's return type may be wider than `Self`; narrow it.
                let raw = unsafe { ::spirv_cross_sys::$getter(constant, column, row) };
                raw as Self
            }

            unsafe fn set(constant: spvc_constant, column: u32, row: u32, value: Self) {
                unsafe { ::spirv_cross_sys::$setter(constant, column, row, value) }
            }
        }
    };
}
40
/// Implements `Sealed` and `ConstantValue` for a single-column vector type whose
/// components are reachable as named fields.
///
/// Invocation shape: `impl_vec_constant!(VecType [scalar_type; N] for [x, y, ...]);`
/// The vector type must implement `From<[scalar_type; N]>`.
#[allow(unused_macros)]
macro_rules! impl_vec_constant {
    ($target:ty [$scalar:ty; $n:literal] for [$($field:ident),*]) => {
        impl $crate::sealed::Sealed for $target {}

        impl $crate::reflect::constants::ConstantValue for $target {
            const COLUMNS: usize = 1;
            const VECSIZE: usize = $n;
            type BaseArrayType = [$scalar; $n];
            type ArrayType = [[$scalar; $n]; 1];
            type BaseType = $scalar;

            fn from_array(value: Self::ArrayType) -> Self {
                // A vector is exactly one column.
                let [components] = value;
                components.into()
            }

            fn to_array(value: Self) -> Self::ArrayType {
                let components = [$(value.$field),*];
                [components]
            }
        }
    };
}
62
63impl_spvc_constant!(spvc_constant_get_scalar_i8 spvc_constant_set_scalar_i8 i8);
64impl_spvc_constant!(spvc_constant_get_scalar_i16 spvc_constant_set_scalar_i16 i16);
65impl_spvc_constant!(spvc_constant_get_scalar_i32 spvc_constant_set_scalar_i32 i32);
66impl_spvc_constant!(spvc_constant_get_scalar_i64 spvc_constant_set_scalar_i64 i64);
67
68impl_spvc_constant!(spvc_constant_get_scalar_u8 spvc_constant_set_scalar_u8 u8);
69impl_spvc_constant!(spvc_constant_get_scalar_u16 spvc_constant_set_scalar_u16 u16);
70impl_spvc_constant!(spvc_constant_get_scalar_u32 spvc_constant_set_scalar_u32 u32);
71impl_spvc_constant!(spvc_constant_get_scalar_u64 spvc_constant_set_scalar_u64 u64);
72
73impl_spvc_constant!(spvc_constant_get_scalar_fp32 spvc_constant_set_scalar_fp32 f32);
74impl_spvc_constant!(spvc_constant_get_scalar_fp64 spvc_constant_set_scalar_fp64 f64);
75
76// implement manually for bool
77impl Sealed for bool {}
78impl ConstantScalar for bool {
79    unsafe fn get(constant: spvc_constant, column: u32, row: u32) -> Self {
80        unsafe { sys::spvc_constant_get_scalar_u8(constant, column, row) != 0 }
81    }
82
83    unsafe fn set(constant: spvc_constant, column: u32, row: u32, value: Self) {
84        sys::spvc_constant_set_scalar_u8(constant, column, row, if value { 1 } else { 0 });
85    }
86}
87
88/// A SPIR-V specialization constant
89#[derive(Debug, Clone)]
90pub struct SpecializationConstant {
91    /// The handle to the constant.
92    pub id: Handle<ConstantId>,
93    /// The declared `constant_id` of the constant.
94    pub constant_id: u32,
95}
96
97/// Specialization constants for a workgroup size.
98#[derive(Debug, Clone)]
99pub struct WorkgroupSizeSpecializationConstants {
100    /// Workgroup size in _x_.
101    pub x: Option<SpecializationConstant>,
102    /// Workgroup size in _y_.
103    pub y: Option<SpecializationConstant>,
104    /// Workgroup size in _z_.
105    pub z: Option<SpecializationConstant>,
106    /// The constant ID of the builtin `WorkGroupSize`
107    pub builtin_workgroup_size_handle: Option<Handle<ConstantId>>,
108}
109
110/// An iterator over specialization constants, created by [`Compiler::specialization_constants`].
111pub struct SpecializationConstantIter<'a>(
112    PhantomCompiler,
113    slice::Iter<'a, spvc_specialization_constant>,
114);
115
116impl_iterator!(SpecializationConstantIter<'_>: SpecializationConstant as map |s, o: &spvc_specialization_constant| {
117    SpecializationConstant {
118        id: s.0.create_handle(o.id),
119        constant_id: o.constant_id,
120    }
121} for [1]);
122
123/// Iterator for specialization subconstants created by
124/// [`Compiler::specialization_sub_constants`].
125pub struct SpecializationSubConstantIter<'a>(PhantomCompiler, slice::Iter<'a, ConstantId>);
126
127impl_iterator!(SpecializationSubConstantIter<'_>: Handle<ConstantId> as map |s, o: &ConstantId| {
128    s.0.create_handle(*o)
129} for [1]);
130
131/// Reflection of specialization constants.
132impl<T> Compiler<T> {
133    // check bounds of the constant, otherwise you can write to arbitrary memory.
134    unsafe fn bounds_check_constant(
135        handle: spvc_constant,
136        column: u32,
137        row: u32,
138    ) -> error::Result<()> {
139        // SPIRConstant is at most mat4, so anything above that is OOB.
140        if column >= 4 || row >= 4 {
141            return Err(SpirvCrossError::IndexOutOfBounds { row, column });
142        }
143
144        let vecsize = sys::spvc_rs_constant_get_vecsize(handle);
145        let colsize = sys::spvc_rs_constant_get_matrix_colsize(handle);
146
147        if column >= colsize || row >= vecsize {
148            return Err(SpirvCrossError::IndexOutOfBounds { row, column });
149        }
150
151        Ok(())
152    }
153
154    /// Set the value of the specialization value at the given column and row.
155    ///
156    /// The type is inferred from the input, but it is not type checked against the SPIR-V.
157    ///
158    /// Using this function wrong is not unsafe, but could cause the output shader to
159    /// be invalid.
160    ///
161    /// [`Compiler::set_specialization_constant_value`] is more efficient and easier to use in
162    /// most cases, which will handle row and column for vector and matrix scalars. This function
163    /// remains to deal with more esoteric matrix shapes, or for getting only a single
164    /// element of a vector or matrix.
165    pub fn set_specialization_constant_scalar<S: ConstantScalar>(
166        &mut self,
167        handle: Handle<ConstantId>,
168        column: u32,
169        row: u32,
170        value: S,
171    ) -> error::Result<()> {
172        let constant = self.yield_id(handle)?;
173        unsafe {
174            // SAFETY: yield_id ensures safety.
175            let handle = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
176            Self::bounds_check_constant(handle, column, row)?;
177            S::set(handle, column, row, value)
178        }
179        Ok(())
180    }
181
182    /// Get the value of the specialization value at the given column and row.
183    ///
184    /// The type is inferred from the return value, and is not type-checked
185    /// against the input SPIR-V.
186    ///
187    /// If the inferred type differs from what is expected, an indeterminate
188    /// but initialized value will be returned.
189    ///
190    /// [`Compiler::specialization_constant_value`] is more efficient and easier to use in
191    /// most cases, which will handle row and column for vector and matrix scalars. This function
192    /// remains to deal with more esoteric matrix shapes, or for getting only a single
193    /// element of a vector or matrix.
194    pub fn specialization_constant_scalar<S: ConstantScalar>(
195        &self,
196        handle: Handle<ConstantId>,
197        column: u32,
198        row: u32,
199    ) -> error::Result<S> {
200        let constant = self.yield_id(handle)?;
201        unsafe {
202            // SAFETY: yield_id ensures safety.
203            let handle = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
204            Self::bounds_check_constant(handle, column, row)?;
205
206            Ok(S::get(handle, column, row))
207        }
208    }
209
210    /// Query declared specialization constants.
211    pub fn specialization_constants(&self) -> error::Result<SpecializationConstantIter<'static>> {
212        unsafe {
213            let mut constants = std::ptr::null();
214            let mut size = 0;
215            sys::spvc_compiler_get_specialization_constants(
216                self.ptr.as_ptr(),
217                &mut constants,
218                &mut size,
219            )
220            .ok(self)?;
221
222            // SAFETY: 'static is sound here.
223            // https://github.com/KhronosGroup/SPIRV-Cross/blob/main/spirv_cross_c.cpp#L2522
224            let slice = slice::from_raw_parts(constants, size);
225            Ok(SpecializationConstantIter(self.phantom(), slice.iter()))
226        }
227    }
228
229    /// Get subconstants for composite type specialization constants.
230    pub fn specialization_sub_constants(
231        &self,
232        constant: Handle<ConstantId>,
233    ) -> error::Result<SpecializationSubConstantIter<'_>> {
234        let id = self.yield_id(constant)?;
235        unsafe {
236            let constant = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), id);
237            let mut constants = std::ptr::null();
238            let mut size = 0;
239            sys::spvc_constant_get_subconstants(constant, &mut constants, &mut size);
240
241            Ok(SpecializationSubConstantIter(
242                self.phantom(),
243                slice::from_raw_parts(constants, size).iter(),
244            ))
245        }
246    }
247
248    /// In SPIR-V, the compute work group size can be represented by a constant vector, in which case
249    /// the LocalSize execution mode is ignored.
250    ///
251    /// This constant vector can be a constant vector, specialization constant vector, or partly specialized constant vector.
252    /// To modify and query work group dimensions which are specialization constants, constant values must be modified
253    /// directly via [`Compiler::set_specialization_constant_value`] rather than using LocalSize directly.
254    /// This function will return which constants should be modified.
255    ///
256    /// To modify dimensions which are *not* specialization constants, set_execution_mode should be used directly.
257    /// Arguments to set_execution_mode which are specialization constants are effectively ignored during compilation.
258    /// NOTE: This is somewhat different from how SPIR-V works. In SPIR-V, the constant vector will completely replace LocalSize,
259    /// while in this interface, LocalSize is only ignored for specialization constants.
260    ///
261    /// The specialization constant will be written to x, y and z arguments.
262    /// If the component is not a specialization constant, a zeroed out struct will be written.
263    /// The return value is the constant ID of the builtin WorkGroupSize, but this is not expected to be useful
264    /// for most use cases.
265    ///
266    /// If `LocalSizeId` is used, there is no uvec3 value representing the workgroup size, so the return value is 0,
267    /// but _x_, _y_ and _z_ are written as normal if the components are specialization constants.
268    pub fn work_group_size_specialization_constants(&self) -> WorkgroupSizeSpecializationConstants {
269        unsafe {
270            let mut x = MaybeUninit::zeroed();
271            let mut y = MaybeUninit::zeroed();
272            let mut z = MaybeUninit::zeroed();
273
274            let constant = sys::spvc_compiler_get_work_group_size_specialization_constants(
275                self.ptr.as_ptr(),
276                x.as_mut_ptr(),
277                y.as_mut_ptr(),
278                z.as_mut_ptr(),
279            );
280
281            let constant = self.create_handle_if_not_zero(constant);
282
283            let x = x.assume_init();
284            let y = y.assume_init();
285            let z = z.assume_init();
286
287            let x = self
288                .create_handle_if_not_zero(x.id)
289                .map(|id| SpecializationConstant {
290                    id,
291                    constant_id: x.constant_id,
292                });
293
294            let y = self
295                .create_handle_if_not_zero(y.id)
296                .map(|id| SpecializationConstant {
297                    id,
298                    constant_id: y.constant_id,
299                });
300
301            let z = self
302                .create_handle_if_not_zero(z.id)
303                .map(|id| SpecializationConstant {
304                    id,
305                    constant_id: z.constant_id,
306                });
307
308            WorkgroupSizeSpecializationConstants {
309                x,
310                y,
311                z,
312                builtin_workgroup_size_handle: constant,
313            }
314        }
315    }
316
317    /// Get the type of the specialization constant.
318    pub fn specialization_constant_type(
319        &self,
320        constant: Handle<ConstantId>,
321    ) -> error::Result<Handle<TypeId>> {
322        let constant = self.yield_id(constant)?;
323        let type_id = unsafe {
324            // SAFETY: yield_id ensures this is valid for the ID
325            let constant = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
326            self.create_handle(sys::spvc_constant_get_type(constant))
327        };
328
329        Ok(type_id)
330    }
331}
332
333/// A marker trait for types that can be represented as a SPIR-V constant.
334pub trait ConstantValue: Sealed + Sized {
335    // None of anything here is a public API.
336    // As soon as generic_const_expr is stable, we can get rid of
337    // almost all of this silliness.
338    #[doc(hidden)]
339    const COLUMNS: usize;
340    #[doc(hidden)]
341    const VECSIZE: usize;
342    #[doc(hidden)]
343    type BaseArrayType: Default + Index<usize, Output = Self::BaseType> + IndexMut<usize>;
344    #[doc(hidden)]
345    type ArrayType: Default + Index<usize, Output = Self::BaseArrayType> + IndexMut<usize>;
346    #[doc(hidden)]
347    type BaseType: ConstantScalar;
348
349    #[doc(hidden)]
350    fn from_array(value: Self::ArrayType) -> Self;
351
352    #[doc(hidden)]
353    fn to_array(value: Self) -> Self::ArrayType;
354}
355
356impl<T: ConstantScalar> ConstantValue for T {
357    const COLUMNS: usize = 1;
358    const VECSIZE: usize = 1;
359    type BaseArrayType = [T; 1];
360    type ArrayType = [[T; 1]; 1];
361    type BaseType = T;
362
363    fn from_array(value: Self::ArrayType) -> Self {
364        value[0][0]
365    }
366
367    fn to_array(value: Self) -> Self::ArrayType {
368        [[value]]
369    }
370}
371
372impl<T> Compiler<T> {
373    /// Get the value of the specialization value.
374    ///
375    /// The type is inferred from the return value, and is not type-checked
376    /// against the input SPIR-V.
377    ///
378    /// If the output type dimensions are too large for the constant,
379    /// [`SpirvCrossError::IndexOutOfBounds`] will be returned.
380    ///
381    /// If the inferred type differs from what is expected, an indeterminate
382    /// but initialized value will be returned.
383    pub fn specialization_constant_value<S: ConstantValue>(
384        &self,
385        handle: Handle<ConstantId>,
386    ) -> error::Result<S> {
387        let constant = self.yield_id(handle)?;
388        unsafe {
389            // SAFETY: yield_id ensures safety.
390            let handle = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
391            // Self::bounds_check_constant(handle, column, row)?;
392            let mut output = S::ArrayType::default();
393
394            // bounds check the limits of the type.
395            Self::bounds_check_constant(handle, S::COLUMNS as u32 - 1, S::VECSIZE as u32 - 1)?;
396
397            for column in 0..S::COLUMNS {
398                for row in 0..S::VECSIZE {
399                    let value = S::BaseType::get(handle, column as u32, row as u32);
400                    output[column][row] = value;
401                }
402            }
403            Ok(S::from_array(output))
404        }
405    }
406
407    /// Set the value of the specialization value.
408    ///
409    /// The type is inferred from the input, but it is not type checked against the SPIR-V.
410    ///
411    /// Using this function wrong is not unsafe, but could cause the output shader to
412    /// be invalid.
413    ///
414    /// If the input dimensions are too large for the constant type,
415    /// [`SpirvCrossError::IndexOutOfBounds`] will be returned.
416    pub fn set_specialization_constant_value<S: ConstantValue>(
417        &mut self,
418        handle: Handle<ConstantId>,
419        value: S,
420    ) -> error::Result<()> {
421        let constant = self.yield_id(handle)?;
422        unsafe {
423            // SAFETY: yield_id ensures safety.
424            let handle = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
425
426            // bounds check the limits of the type.
427            Self::bounds_check_constant(handle, S::COLUMNS as u32 - 1, S::VECSIZE as u32 - 1)?;
428
429            let value = S::to_array(value);
430            for column in 0..S::COLUMNS {
431                for row in 0..S::VECSIZE {
432                    S::BaseType::set(handle, column as u32, row as u32, value[column][row]);
433                }
434            }
435        }
436        Ok(())
437    }
438}
439
440#[allow(unused_imports)]
441#[allow(clippy::needless_pub_self)]
442pub(self) use impl_vec_constant;