1use crate::sealed::Sealed;
2use spirv_cross_sys::{spvc_constant, spvc_specialization_constant, TypeId};
3use std::mem::MaybeUninit;
4use std::ops::{Index, IndexMut};
5use std::slice;
6
7use crate::error::{SpirvCrossError, ToContextError};
8use crate::handle::{ConstantId, Handle};
9use crate::iter::impl_iterator;
10use crate::{error, Compiler, PhantomCompiler};
11use spirv_cross_sys as sys;
12
13mod gfx_maths;
14mod half;
15mod glam;
16
17pub trait ConstantScalar: Default + Sealed + Copy {
19 #[doc(hidden)]
20 unsafe fn get(constant: spvc_constant, column: u32, row: u32) -> Self;
21
22 #[doc(hidden)]
23 unsafe fn set(constant: spvc_constant, column: u32, row: u32, value: Self);
24}
25
/// Implements `Sealed` and [`ConstantScalar`] for the primitive `$scalar`, forwarding reads to
/// `spirv_cross_sys::$getter` and writes to `spirv_cross_sys::$setter`.
macro_rules! impl_spvc_constant {
    ($getter:ident $setter:ident $scalar:ty) => {
        impl Sealed for $scalar {}

        impl ConstantScalar for $scalar {
            unsafe fn get(constant: spvc_constant, column: u32, row: u32) -> Self {
                let raw = unsafe { ::spirv_cross_sys::$getter(constant, column, row) };
                raw as $scalar
            }

            unsafe fn set(constant: spvc_constant, column: u32, row: u32, value: Self) {
                unsafe { ::spirv_cross_sys::$setter(constant, column, row, value) };
            }
        }
    };
}
40
/// Implements `Sealed` and `ConstantValue` for a single-column vector type `$vector` holding
/// `$len` components of `$scalar`.
///
/// The listed fields are mapped, in order, to rows `0..$len` of column 0. Converting back
/// from the component array uses `Into<$vector>` for `[$scalar; $len]`.
macro_rules! impl_vec_constant {
    ($vector:ty [$scalar:ty; $len:literal] for [$($field:ident),*]) => {
        impl $crate::sealed::Sealed for $vector {}

        impl $crate::reflect::constants::ConstantValue for $vector {
            type BaseType = $scalar;
            type BaseArrayType = [$scalar; $len];
            type ArrayType = [[$scalar; $len]; 1];

            const COLUMNS: usize = 1;
            const VECSIZE: usize = $len;

            fn from_array(value: Self::ArrayType) -> Self {
                let [column] = value;
                column.into()
            }

            fn to_array(value: Self) -> Self::ArrayType {
                let column = [$(value.$field),*];
                [column]
            }
        }
    };
}
61
62impl_spvc_constant!(spvc_constant_get_scalar_i8 spvc_constant_set_scalar_i8 i8);
63impl_spvc_constant!(spvc_constant_get_scalar_i16 spvc_constant_set_scalar_i16 i16);
64impl_spvc_constant!(spvc_constant_get_scalar_i32 spvc_constant_set_scalar_i32 i32);
65impl_spvc_constant!(spvc_constant_get_scalar_i64 spvc_constant_set_scalar_i64 i64);
66
67impl_spvc_constant!(spvc_constant_get_scalar_u8 spvc_constant_set_scalar_u8 u8);
68impl_spvc_constant!(spvc_constant_get_scalar_u16 spvc_constant_set_scalar_u16 u16);
69impl_spvc_constant!(spvc_constant_get_scalar_u32 spvc_constant_set_scalar_u32 u32);
70impl_spvc_constant!(spvc_constant_get_scalar_u64 spvc_constant_set_scalar_u64 u64);
71
72impl_spvc_constant!(spvc_constant_get_scalar_fp32 spvc_constant_set_scalar_fp32 f32);
73impl_spvc_constant!(spvc_constant_get_scalar_fp64 spvc_constant_set_scalar_fp64 f64);
74
75impl Sealed for bool {}
77impl ConstantScalar for bool {
78 unsafe fn get(constant: spvc_constant, column: u32, row: u32) -> Self {
79 unsafe {
80 sys::spvc_constant_get_scalar_u8(constant, column, row) != 0
81 }
82 }
83
84 unsafe fn set(constant: spvc_constant, column: u32, row: u32, value: Self) {
85 sys::spvc_constant_set_scalar_u8(constant, column, row, if value { 1 } else { 0 });
86 }
87}
88
89#[derive(Debug, Clone)]
91pub struct SpecializationConstant {
92 pub id: Handle<ConstantId>,
94 pub constant_id: u32,
96}
97
98#[derive(Debug, Clone)]
100pub struct WorkgroupSizeSpecializationConstants {
101 pub x: Option<SpecializationConstant>,
103 pub y: Option<SpecializationConstant>,
105 pub z: Option<SpecializationConstant>,
107 pub builtin_workgroup_size_handle: Option<Handle<ConstantId>>,
109}
110
/// Iterator over a compiler's specialization constants, yielding [`SpecializationConstant`].
///
/// Created by `Compiler::specialization_constants`.
pub struct SpecializationConstantIter<'a>(
    // Used to create typed handles for each yielded constant.
    PhantomCompiler,
    // Raw entries returned by `spvc_compiler_get_specialization_constants`.
    slice::Iter<'a, spvc_specialization_constant>,
);

// Converts each raw `spvc_specialization_constant` into a `SpecializationConstant`,
// creating the `id` handle through the stored `PhantomCompiler`.
impl_iterator!(SpecializationConstantIter<'_>: SpecializationConstant as map |s, o: &spvc_specialization_constant| {
    SpecializationConstant {
        id: s.0.create_handle(o.id),
        constant_id: o.constant_id,
    }
} for [1]);
123
/// Iterator over the sub-constants of a composite constant, yielding a handle per entry.
///
/// Created by `Compiler::specialization_sub_constants`. The first field creates the handles;
/// the second holds the raw ids returned by `spvc_constant_get_subconstants`.
pub struct SpecializationSubConstantIter<'a>(PhantomCompiler, slice::Iter<'a, ConstantId>);

// Wraps each raw `ConstantId` in a handle created through the stored `PhantomCompiler`.
impl_iterator!(SpecializationSubConstantIter<'_>: Handle<ConstantId> as map |s, o: &ConstantId| {
    s.0.create_handle(*o)
} for [1]);
131
132impl<T> Compiler<T> {
134 unsafe fn bounds_check_constant(
136 handle: spvc_constant,
137 column: u32,
138 row: u32,
139 ) -> error::Result<()> {
140 if column >= 4 || row >= 4 {
142 return Err(SpirvCrossError::IndexOutOfBounds { row, column });
143 }
144
145 let vecsize = sys::spvc_rs_constant_get_vecsize(handle);
146 let colsize = sys::spvc_rs_constant_get_matrix_colsize(handle);
147
148 if column >= colsize || row >= vecsize {
149 return Err(SpirvCrossError::IndexOutOfBounds { row, column });
150 }
151
152 Ok(())
153 }
154
155 pub fn set_specialization_constant_scalar<S: ConstantScalar>(
167 &mut self,
168 handle: Handle<ConstantId>,
169 column: u32,
170 row: u32,
171 value: S,
172 ) -> error::Result<()> {
173 let constant = self.yield_id(handle)?;
174 unsafe {
175 let handle = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
177 Self::bounds_check_constant(handle, column, row)?;
178 S::set(handle, column, row, value)
179 }
180 Ok(())
181 }
182
183 pub fn specialization_constant_scalar<S: ConstantScalar>(
196 &self,
197 handle: Handle<ConstantId>,
198 column: u32,
199 row: u32,
200 ) -> error::Result<S> {
201 let constant = self.yield_id(handle)?;
202 unsafe {
203 let handle = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
205 Self::bounds_check_constant(handle, column, row)?;
206
207 Ok(S::get(handle, column, row))
208 }
209 }
210
211 pub fn specialization_constants(&self) -> error::Result<SpecializationConstantIter<'static>> {
213 unsafe {
214 let mut constants = std::ptr::null();
215 let mut size = 0;
216 sys::spvc_compiler_get_specialization_constants(
217 self.ptr.as_ptr(),
218 &mut constants,
219 &mut size,
220 )
221 .ok(self)?;
222
223 let slice = slice::from_raw_parts(constants, size);
226 Ok(SpecializationConstantIter(self.phantom(), slice.iter()))
227 }
228 }
229
230 pub fn specialization_sub_constants(
232 &self,
233 constant: Handle<ConstantId>,
234 ) -> error::Result<SpecializationSubConstantIter> {
235 let id = self.yield_id(constant)?;
236 unsafe {
237 let constant = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), id);
238 let mut constants = std::ptr::null();
239 let mut size = 0;
240 sys::spvc_constant_get_subconstants(constant, &mut constants, &mut size);
241
242 Ok(SpecializationSubConstantIter(
243 self.phantom(),
244 slice::from_raw_parts(constants, size).iter(),
245 ))
246 }
247 }
248
249 pub fn work_group_size_specialization_constants(&self) -> WorkgroupSizeSpecializationConstants {
270 unsafe {
271 let mut x = MaybeUninit::zeroed();
272 let mut y = MaybeUninit::zeroed();
273 let mut z = MaybeUninit::zeroed();
274
275 let constant = sys::spvc_compiler_get_work_group_size_specialization_constants(
276 self.ptr.as_ptr(),
277 x.as_mut_ptr(),
278 y.as_mut_ptr(),
279 z.as_mut_ptr(),
280 );
281
282 let constant = self.create_handle_if_not_zero(constant);
283
284 let x = x.assume_init();
285 let y = y.assume_init();
286 let z = z.assume_init();
287
288 let x = self
289 .create_handle_if_not_zero(x.id)
290 .map(|id| SpecializationConstant {
291 id,
292 constant_id: x.constant_id,
293 });
294
295 let y = self
296 .create_handle_if_not_zero(y.id)
297 .map(|id| SpecializationConstant {
298 id,
299 constant_id: y.constant_id,
300 });
301
302 let z = self
303 .create_handle_if_not_zero(z.id)
304 .map(|id| SpecializationConstant {
305 id,
306 constant_id: z.constant_id,
307 });
308
309 WorkgroupSizeSpecializationConstants {
310 x,
311 y,
312 z,
313 builtin_workgroup_size_handle: constant,
314 }
315 }
316 }
317
318 pub fn specialization_constant_type(
320 &self,
321 constant: Handle<ConstantId>,
322 ) -> error::Result<Handle<TypeId>> {
323 let constant = self.yield_id(constant)?;
324 let type_id = unsafe {
325 let constant = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
327 self.create_handle(sys::spvc_constant_get_type(constant))
328 };
329
330 Ok(type_id)
331 }
332}
333
334pub trait ConstantValue: Sealed + Sized {
336 #[doc(hidden)]
340 const COLUMNS: usize;
341 #[doc(hidden)]
342 const VECSIZE: usize;
343 #[doc(hidden)]
344 type BaseArrayType: Default + Index<usize, Output = Self::BaseType> + IndexMut<usize>;
345 #[doc(hidden)]
346 type ArrayType: Default + Index<usize, Output = Self::BaseArrayType> + IndexMut<usize>;
347 #[doc(hidden)]
348 type BaseType: ConstantScalar;
349
350 #[doc(hidden)]
351 fn from_array(value: Self::ArrayType) -> Self;
352
353 #[doc(hidden)]
354 fn to_array(value: Self) -> Self::ArrayType;
355}
356
357impl<T: ConstantScalar> ConstantValue for T {
358 const COLUMNS: usize = 1;
359 const VECSIZE: usize = 1;
360 type BaseArrayType = [T; 1];
361 type ArrayType = [[T; 1]; 1];
362 type BaseType = T;
363
364 fn from_array(value: Self::ArrayType) -> Self {
365 value[0][0]
366 }
367
368 fn to_array(value: Self) -> Self::ArrayType {
369 [[value]]
370 }
371}
372
373impl<T> Compiler<T> {
374 pub fn specialization_constant_value<S: ConstantValue>(
385 &self,
386 handle: Handle<ConstantId>,
387 ) -> error::Result<S> {
388 let constant = self.yield_id(handle)?;
389 unsafe {
390 let handle = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
392 let mut output = S::ArrayType::default();
394
395 Self::bounds_check_constant(handle, S::COLUMNS as u32 - 1, S::VECSIZE as u32 - 1)?;
397
398 for column in 0..S::COLUMNS {
399 for row in 0..S::VECSIZE {
400 let value = S::BaseType::get(handle, column as u32, row as u32);
401 output[column][row] = value;
402 }
403 }
404 Ok(S::from_array(output))
405 }
406 }
407
408 pub fn set_specialization_constant_value<S: ConstantValue>(
418 &mut self,
419 handle: Handle<ConstantId>,
420 value: S,
421 ) -> error::Result<()> {
422 let constant = self.yield_id(handle)?;
423 unsafe {
424 let handle = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
426
427 Self::bounds_check_constant(handle, S::COLUMNS as u32 - 1, S::VECSIZE as u32 - 1)?;
429
430 let value = S::to_array(value);
431 for column in 0..S::COLUMNS {
432 for row in 0..S::VECSIZE {
433 S::BaseType::set(handle, column as u32, row as u32, value[column][row]);
434 }
435 }
436 }
437 Ok(())
438 }
439}
440
441pub(self) use impl_vec_constant;