spirv_cross2/reflect/
constants.rs

1use crate::sealed::Sealed;
2use spirv_cross_sys::{spvc_constant, spvc_specialization_constant, TypeId};
3use std::mem::MaybeUninit;
4use std::ops::{Index, IndexMut};
5use std::slice;
6
7use crate::error::{SpirvCrossError, ToContextError};
8use crate::handle::{ConstantId, Handle};
9use crate::iter::impl_iterator;
10use crate::{error, Compiler, PhantomCompiler};
11use spirv_cross_sys as sys;
12
13mod gfx_maths;
14mod half;
15mod glam;
16
17/// A marker trait for types that can be represented as a scalar SPIR-V constant.
18pub trait ConstantScalar: Default + Sealed + Copy {
19    #[doc(hidden)]
20    unsafe fn get(constant: spvc_constant, column: u32, row: u32) -> Self;
21
22    #[doc(hidden)]
23    unsafe fn set(constant: spvc_constant, column: u32, row: u32, value: Self);
24}
25
// Implements `Sealed` + `ConstantScalar` for a primitive by forwarding straight
// to the matching `spirv_cross_sys` getter/setter pair.
//
// Usage: `impl_spvc_constant!(<getter> <setter> <primitive>);`
macro_rules! impl_spvc_constant {
    ($getter:ident $setter:ident $scalar:ty) => {
        impl Sealed for $scalar {}

        impl ConstantScalar for $scalar {
            unsafe fn get(constant: spvc_constant, column: u32, row: u32) -> Self {
                // SAFETY: forwarded verbatim; the caller upholds the trait contract.
                let raw = unsafe { ::spirv_cross_sys::$getter(constant, column, row) };
                raw as Self
            }

            unsafe fn set(constant: spvc_constant, column: u32, row: u32, value: Self) {
                // SAFETY: forwarded verbatim; the caller upholds the trait contract.
                unsafe { ::spirv_cross_sys::$setter(constant, column, row, value) };
            }
        }
    };
}
40
// Implements `Sealed` + `ConstantValue` for a single-column vector type.
//
// `$vector` must be constructible from `[$scalar; $width]` via `Into`, and the
// listed `$field`s are read in order to flatten it back into an array.
macro_rules! impl_vec_constant {
    ($vector:ty [$scalar:ty; $width:literal] for [$($field:ident),*]) => {
        impl $crate::sealed::Sealed for $vector {}

        impl $crate::reflect::constants::ConstantValue for $vector {
            const COLUMNS: usize = 1;
            const VECSIZE: usize = $width;
            type BaseType = $scalar;
            type BaseArrayType = [$scalar; $width];
            type ArrayType = [[$scalar; $width]; 1];

            fn from_array(value: Self::ArrayType) -> Self {
                // A vector is a matrix with exactly one column.
                let [column] = value;
                column.into()
            }

            fn to_array(value: Self) -> Self::ArrayType {
                [[$(value.$field),*]]
            }
        }
    };
}
61
62impl_spvc_constant!(spvc_constant_get_scalar_i8 spvc_constant_set_scalar_i8 i8);
63impl_spvc_constant!(spvc_constant_get_scalar_i16 spvc_constant_set_scalar_i16 i16);
64impl_spvc_constant!(spvc_constant_get_scalar_i32 spvc_constant_set_scalar_i32 i32);
65impl_spvc_constant!(spvc_constant_get_scalar_i64 spvc_constant_set_scalar_i64 i64);
66
67impl_spvc_constant!(spvc_constant_get_scalar_u8 spvc_constant_set_scalar_u8 u8);
68impl_spvc_constant!(spvc_constant_get_scalar_u16 spvc_constant_set_scalar_u16 u16);
69impl_spvc_constant!(spvc_constant_get_scalar_u32 spvc_constant_set_scalar_u32 u32);
70impl_spvc_constant!(spvc_constant_get_scalar_u64 spvc_constant_set_scalar_u64 u64);
71
72impl_spvc_constant!(spvc_constant_get_scalar_fp32 spvc_constant_set_scalar_fp32 f32);
73impl_spvc_constant!(spvc_constant_get_scalar_fp64 spvc_constant_set_scalar_fp64 f64);
74
75// implement manually for bool
76impl Sealed for bool {}
77impl ConstantScalar for bool {
78    unsafe fn get(constant: spvc_constant, column: u32, row: u32) -> Self {
79        unsafe {
80            sys::spvc_constant_get_scalar_u8(constant, column, row) != 0
81        }
82    }
83
84    unsafe fn set(constant: spvc_constant, column: u32, row: u32, value: Self) {
85        sys::spvc_constant_set_scalar_u8(constant, column, row, if value { 1 } else { 0 });
86    }
87}
88
89/// A SPIR-V specialization constant
90#[derive(Debug, Clone)]
91pub struct SpecializationConstant {
92    /// The handle to the constant.
93    pub id: Handle<ConstantId>,
94    /// The declared `constant_id` of the constant.
95    pub constant_id: u32,
96}
97
98/// Specialization constants for a workgroup size.
99#[derive(Debug, Clone)]
100pub struct WorkgroupSizeSpecializationConstants {
101    /// Workgroup size in _x_.
102    pub x: Option<SpecializationConstant>,
103    /// Workgroup size in _y_.
104    pub y: Option<SpecializationConstant>,
105    /// Workgroup size in _z_.
106    pub z: Option<SpecializationConstant>,
107    /// The constant ID of the builtin `WorkGroupSize`
108    pub builtin_workgroup_size_handle: Option<Handle<ConstantId>>,
109}
110
111/// An iterator over specialization constants, created by [`Compiler::specialization_constants`].
112pub struct SpecializationConstantIter<'a>(
113    PhantomCompiler,
114    slice::Iter<'a, spvc_specialization_constant>,
115);
116
117impl_iterator!(SpecializationConstantIter<'_>: SpecializationConstant as map |s, o: &spvc_specialization_constant| {
118    SpecializationConstant {
119        id: s.0.create_handle(o.id),
120        constant_id: o.constant_id,
121    }
122} for [1]);
123
124/// Iterator for specialization subconstants created by
125/// [`Compiler::specialization_sub_constants`].
126pub struct SpecializationSubConstantIter<'a>(PhantomCompiler, slice::Iter<'a, ConstantId>);
127
128impl_iterator!(SpecializationSubConstantIter<'_>: Handle<ConstantId> as map |s, o: &ConstantId| {
129    s.0.create_handle(*o)
130} for [1]);
131
132/// Reflection of specialization constants.
133impl<T> Compiler<T> {
134    // check bounds of the constant, otherwise you can write to arbitrary memory.
135    unsafe fn bounds_check_constant(
136        handle: spvc_constant,
137        column: u32,
138        row: u32,
139    ) -> error::Result<()> {
140        // SPIRConstant is at most mat4, so anything above that is OOB.
141        if column >= 4 || row >= 4 {
142            return Err(SpirvCrossError::IndexOutOfBounds { row, column });
143        }
144
145        let vecsize = sys::spvc_rs_constant_get_vecsize(handle);
146        let colsize = sys::spvc_rs_constant_get_matrix_colsize(handle);
147
148        if column >= colsize || row >= vecsize {
149            return Err(SpirvCrossError::IndexOutOfBounds { row, column });
150        }
151
152        Ok(())
153    }
154
155    /// Set the value of the specialization value at the given column and row.
156    ///
157    /// The type is inferred from the input, but it is not type checked against the SPIR-V.
158    ///
159    /// Using this function wrong is not unsafe, but could cause the output shader to
160    /// be invalid.
161    ///
162    /// [`Compiler::set_specialization_constant_value`] is more efficient and easier to use in
163    /// most cases, which will handle row and column for vector and matrix scalars. This function
164    /// remains to deal with more esoteric matrix shapes, or for getting only a single
165    /// element of a vector or matrix.
166    pub fn set_specialization_constant_scalar<S: ConstantScalar>(
167        &mut self,
168        handle: Handle<ConstantId>,
169        column: u32,
170        row: u32,
171        value: S,
172    ) -> error::Result<()> {
173        let constant = self.yield_id(handle)?;
174        unsafe {
175            // SAFETY: yield_id ensures safety.
176            let handle = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
177            Self::bounds_check_constant(handle, column, row)?;
178            S::set(handle, column, row, value)
179        }
180        Ok(())
181    }
182
183    /// Get the value of the specialization value at the given column and row.
184    ///
185    /// The type is inferred from the return value, and is not type-checked
186    /// against the input SPIR-V.
187    ///
188    /// If the inferred type differs from what is expected, an indeterminate
189    /// but initialized value will be returned.
190    ///
191    /// [`Compiler::specialization_constant_value`] is more efficient and easier to use in
192    /// most cases, which will handle row and column for vector and matrix scalars. This function
193    /// remains to deal with more esoteric matrix shapes, or for getting only a single
194    /// element of a vector or matrix.
195    pub fn specialization_constant_scalar<S: ConstantScalar>(
196        &self,
197        handle: Handle<ConstantId>,
198        column: u32,
199        row: u32,
200    ) -> error::Result<S> {
201        let constant = self.yield_id(handle)?;
202        unsafe {
203            // SAFETY: yield_id ensures safety.
204            let handle = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
205            Self::bounds_check_constant(handle, column, row)?;
206
207            Ok(S::get(handle, column, row))
208        }
209    }
210
211    /// Query declared specialization constants.
212    pub fn specialization_constants(&self) -> error::Result<SpecializationConstantIter<'static>> {
213        unsafe {
214            let mut constants = std::ptr::null();
215            let mut size = 0;
216            sys::spvc_compiler_get_specialization_constants(
217                self.ptr.as_ptr(),
218                &mut constants,
219                &mut size,
220            )
221            .ok(self)?;
222
223            // SAFETY: 'static is sound here.
224            // https://github.com/KhronosGroup/SPIRV-Cross/blob/main/spirv_cross_c.cpp#L2522
225            let slice = slice::from_raw_parts(constants, size);
226            Ok(SpecializationConstantIter(self.phantom(), slice.iter()))
227        }
228    }
229
230    /// Get subconstants for composite type specialization constants.
231    pub fn specialization_sub_constants(
232        &self,
233        constant: Handle<ConstantId>,
234    ) -> error::Result<SpecializationSubConstantIter> {
235        let id = self.yield_id(constant)?;
236        unsafe {
237            let constant = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), id);
238            let mut constants = std::ptr::null();
239            let mut size = 0;
240            sys::spvc_constant_get_subconstants(constant, &mut constants, &mut size);
241
242            Ok(SpecializationSubConstantIter(
243                self.phantom(),
244                slice::from_raw_parts(constants, size).iter(),
245            ))
246        }
247    }
248
249    /// In SPIR-V, the compute work group size can be represented by a constant vector, in which case
250    /// the LocalSize execution mode is ignored.
251    ///
252    /// This constant vector can be a constant vector, specialization constant vector, or partly specialized constant vector.
253    /// To modify and query work group dimensions which are specialization constants, constant values must be modified
254    /// directly via [`Compiler::set_specialization_constant_value`] rather than using LocalSize directly.
255    /// This function will return which constants should be modified.
256    ///
257    /// To modify dimensions which are *not* specialization constants, set_execution_mode should be used directly.
258    /// Arguments to set_execution_mode which are specialization constants are effectively ignored during compilation.
259    /// NOTE: This is somewhat different from how SPIR-V works. In SPIR-V, the constant vector will completely replace LocalSize,
260    /// while in this interface, LocalSize is only ignored for specialization constants.
261    ///
262    /// The specialization constant will be written to x, y and z arguments.
263    /// If the component is not a specialization constant, a zeroed out struct will be written.
264    /// The return value is the constant ID of the builtin WorkGroupSize, but this is not expected to be useful
265    /// for most use cases.
266    ///
267    /// If `LocalSizeId` is used, there is no uvec3 value representing the workgroup size, so the return value is 0,
268    /// but _x_, _y_ and _z_ are written as normal if the components are specialization constants.
269    pub fn work_group_size_specialization_constants(&self) -> WorkgroupSizeSpecializationConstants {
270        unsafe {
271            let mut x = MaybeUninit::zeroed();
272            let mut y = MaybeUninit::zeroed();
273            let mut z = MaybeUninit::zeroed();
274
275            let constant = sys::spvc_compiler_get_work_group_size_specialization_constants(
276                self.ptr.as_ptr(),
277                x.as_mut_ptr(),
278                y.as_mut_ptr(),
279                z.as_mut_ptr(),
280            );
281
282            let constant = self.create_handle_if_not_zero(constant);
283
284            let x = x.assume_init();
285            let y = y.assume_init();
286            let z = z.assume_init();
287
288            let x = self
289                .create_handle_if_not_zero(x.id)
290                .map(|id| SpecializationConstant {
291                    id,
292                    constant_id: x.constant_id,
293                });
294
295            let y = self
296                .create_handle_if_not_zero(y.id)
297                .map(|id| SpecializationConstant {
298                    id,
299                    constant_id: y.constant_id,
300                });
301
302            let z = self
303                .create_handle_if_not_zero(z.id)
304                .map(|id| SpecializationConstant {
305                    id,
306                    constant_id: z.constant_id,
307                });
308
309            WorkgroupSizeSpecializationConstants {
310                x,
311                y,
312                z,
313                builtin_workgroup_size_handle: constant,
314            }
315        }
316    }
317
318    /// Get the type of the specialization constant.
319    pub fn specialization_constant_type(
320        &self,
321        constant: Handle<ConstantId>,
322    ) -> error::Result<Handle<TypeId>> {
323        let constant = self.yield_id(constant)?;
324        let type_id = unsafe {
325            // SAFETY: yield_id ensures this is valid for the ID
326            let constant = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
327            self.create_handle(sys::spvc_constant_get_type(constant))
328        };
329
330        Ok(type_id)
331    }
332}
333
334/// A marker trait for types that can be represented as a SPIR-V constant.
335pub trait ConstantValue: Sealed + Sized {
336    // None of anything here is a public API.
337    // As soon as generic_const_expr is stable, we can get rid of
338    // almost all of this silliness.
339    #[doc(hidden)]
340    const COLUMNS: usize;
341    #[doc(hidden)]
342    const VECSIZE: usize;
343    #[doc(hidden)]
344    type BaseArrayType: Default + Index<usize, Output = Self::BaseType> + IndexMut<usize>;
345    #[doc(hidden)]
346    type ArrayType: Default + Index<usize, Output = Self::BaseArrayType> + IndexMut<usize>;
347    #[doc(hidden)]
348    type BaseType: ConstantScalar;
349
350    #[doc(hidden)]
351    fn from_array(value: Self::ArrayType) -> Self;
352
353    #[doc(hidden)]
354    fn to_array(value: Self) -> Self::ArrayType;
355}
356
357impl<T: ConstantScalar> ConstantValue for T {
358    const COLUMNS: usize = 1;
359    const VECSIZE: usize = 1;
360    type BaseArrayType = [T; 1];
361    type ArrayType = [[T; 1]; 1];
362    type BaseType = T;
363
364    fn from_array(value: Self::ArrayType) -> Self {
365        value[0][0]
366    }
367
368    fn to_array(value: Self) -> Self::ArrayType {
369        [[value]]
370    }
371}
372
373impl<T> Compiler<T> {
374    /// Get the value of the specialization value.
375    ///
376    /// The type is inferred from the return value, and is not type-checked
377    /// against the input SPIR-V.
378    ///
379    /// If the output type dimensions are too large for the constant,
380    /// [`SpirvCrossError::IndexOutOfBounds`] will be returned.
381    ///
382    /// If the inferred type differs from what is expected, an indeterminate
383    /// but initialized value will be returned.
384    pub fn specialization_constant_value<S: ConstantValue>(
385        &self,
386        handle: Handle<ConstantId>,
387    ) -> error::Result<S> {
388        let constant = self.yield_id(handle)?;
389        unsafe {
390            // SAFETY: yield_id ensures safety.
391            let handle = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
392            // Self::bounds_check_constant(handle, column, row)?;
393            let mut output = S::ArrayType::default();
394
395            // bounds check the limits of the type.
396            Self::bounds_check_constant(handle, S::COLUMNS as u32 - 1, S::VECSIZE as u32 - 1)?;
397
398            for column in 0..S::COLUMNS {
399                for row in 0..S::VECSIZE {
400                    let value = S::BaseType::get(handle, column as u32, row as u32);
401                    output[column][row] = value;
402                }
403            }
404            Ok(S::from_array(output))
405        }
406    }
407
408    /// Set the value of the specialization value.
409    ///
410    /// The type is inferred from the input, but it is not type checked against the SPIR-V.
411    ///
412    /// Using this function wrong is not unsafe, but could cause the output shader to
413    /// be invalid.
414    ///
415    /// If the input dimensions are too large for the constant type,
416    /// [`SpirvCrossError::IndexOutOfBounds`] will be returned.
417    pub fn set_specialization_constant_value<S: ConstantValue>(
418        &mut self,
419        handle: Handle<ConstantId>,
420        value: S,
421    ) -> error::Result<()> {
422        let constant = self.yield_id(handle)?;
423        unsafe {
424            // SAFETY: yield_id ensures safety.
425            let handle = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
426
427            // bounds check the limits of the type.
428            Self::bounds_check_constant(handle, S::COLUMNS as u32 - 1, S::VECSIZE as u32 - 1)?;
429
430            let value = S::to_array(value);
431            for column in 0..S::COLUMNS {
432                for row in 0..S::VECSIZE {
433                    S::BaseType::set(handle, column as u32, row as u32, value[column][row]);
434                }
435            }
436        }
437        Ok(())
438    }
439}
440
441pub(self) use impl_vec_constant;