use crate::sealed::Sealed;
use spirv_cross_sys::{spvc_constant, spvc_specialization_constant, TypeId};
use std::mem::MaybeUninit;
use std::ops::{Index, IndexMut};
use std::slice;
use crate::error::{SpirvCrossError, ToContextError};
use crate::handle::{ConstantId, Handle};
use crate::{error, Compiler, PhantomCompiler};
use spirv_cross_sys as sys;
/// A scalar type that can be read from, or written to, a single component of a
/// SPIR-V constant.
///
/// This trait is sealed: it is implemented only within this crate, for the
/// integer and floating-point primitives in this module (plus `half::f16` when
/// the `f16` feature is enabled).
pub trait ConstantScalar: Default + Sealed + Copy {
    /// Writes `value` into component (`column`, `row`) of `constant`.
    ///
    /// # Safety
    /// `constant` must be a valid constant handle (the callers in this module
    /// obtain it from `spvc_compiler_get_constant_handle`), and
    /// (`column`, `row`) must lie within the constant's dimensions. The safe
    /// callers in this module establish that with a bounds check first.
    #[doc(hidden)]
    unsafe fn set(constant: spvc_constant, column: u32, row: u32, value: Self);

    /// Reads component (`column`, `row`) of `constant`.
    ///
    /// # Safety
    /// The requirements are the same as for `set`.
    #[doc(hidden)]
    unsafe fn get(constant: spvc_constant, column: u32, row: u32) -> Self;
}
/// Implements `Sealed` and `ConstantScalar` for a primitive type by forwarding
/// to a pair of `spirv_cross_sys` scalar accessor functions.
///
/// Invocation shape: `impl_spvc_constant!(getter_fn setter_fn primitive_type)`.
macro_rules! impl_spvc_constant {
    ($getter:ident $setter:ident $scalar:ty) => {
        impl Sealed for $scalar {}

        impl ConstantScalar for $scalar {
            unsafe fn get(constant: spvc_constant, column: u32, row: u32) -> Self {
                let raw = unsafe { ::spirv_cross_sys::$getter(constant, column, row) };
                // The cast covers getters whose C return type is wider than
                // `Self` (it is a no-op when the types already match).
                raw as Self
            }

            unsafe fn set(constant: spvc_constant, column: u32, row: u32, value: Self) {
                unsafe { ::spirv_cross_sys::$setter(constant, column, row, value) }
            }
        }
    };
}
// Unsigned integer scalars.
impl_spvc_constant!(spvc_constant_get_scalar_u8 spvc_constant_set_scalar_u8 u8);
impl_spvc_constant!(spvc_constant_get_scalar_u16 spvc_constant_set_scalar_u16 u16);
impl_spvc_constant!(spvc_constant_get_scalar_u32 spvc_constant_set_scalar_u32 u32);
impl_spvc_constant!(spvc_constant_get_scalar_u64 spvc_constant_set_scalar_u64 u64);

// Signed integer scalars.
impl_spvc_constant!(spvc_constant_get_scalar_i8 spvc_constant_set_scalar_i8 i8);
impl_spvc_constant!(spvc_constant_get_scalar_i16 spvc_constant_set_scalar_i16 i16);
impl_spvc_constant!(spvc_constant_get_scalar_i32 spvc_constant_set_scalar_i32 i32);
impl_spvc_constant!(spvc_constant_get_scalar_i64 spvc_constant_set_scalar_i64 i64);

// Floating-point scalars. Half precision is implemented by hand below,
// behind the `f16` feature.
impl_spvc_constant!(spvc_constant_get_scalar_fp32 spvc_constant_set_scalar_fp32 f32);
impl_spvc_constant!(spvc_constant_get_scalar_fp64 spvc_constant_set_scalar_fp64 f64);
#[cfg(feature = "f16")]
impl Sealed for half::f16 {}

/// Half-precision components. The fp16 getter returns the value widened to
/// `f32`, and the fp16 setter takes the raw 16-bit pattern.
#[cfg(feature = "f16")]
#[cfg_attr(docsrs, doc(cfg(feature = "f16")))]
impl ConstantScalar for half::f16 {
    unsafe fn get(constant: spvc_constant, column: u32, row: u32) -> Self {
        half::f16::from_f32(unsafe { sys::spvc_constant_get_scalar_fp16(constant, column, row) })
    }

    unsafe fn set(constant: spvc_constant, column: u32, row: u32, value: Self) {
        let bits = value.to_bits();
        unsafe { sys::spvc_constant_set_scalar_fp16(constant, column, row, bits) }
    }
}
/// A specialization constant reported by the compiler.
#[derive(Debug, Clone)]
pub struct SpecializationConstant {
/// Handle to the constant object in the module.
pub id: Handle<ConstantId>,
/// The specialization id reported by SPIRV-Cross for this constant
/// (presumably the `SpecId` decoration value; confirm against the C API).
pub constant_id: u32,
}
/// Workgroup-size specialization constants, as reported by
/// `spvc_compiler_get_work_group_size_specialization_constants`.
#[derive(Debug, Clone)]
pub struct WorkgroupSizeSpecializationConstants {
/// Constant for the X dimension, or `None` when SPIRV-Cross reported a zero id.
pub x: Option<SpecializationConstant>,
/// Constant for the Y dimension, or `None` when SPIRV-Cross reported a zero id.
pub y: Option<SpecializationConstant>,
/// Constant for the Z dimension, or `None` when SPIRV-Cross reported a zero id.
pub z: Option<SpecializationConstant>,
/// The id returned directly by the C call, or `None` if it was zero.
/// Presumably this is the constant backing the `WorkgroupSize` builtin; TODO confirm.
pub builtin_workgroup_size_handle: Option<Handle<ConstantId>>,
}
/// Iterator over a compiler's specialization constants, yielding
/// [`SpecializationConstant`] values.
pub struct SpecializationConstantIter<'a>(
/// Compiler marker used to turn raw ids into handles.
PhantomCompiler<'a>,
/// The remaining raw constants returned by SPIRV-Cross.
slice::Iter<'a, spvc_specialization_constant>,
);
impl ExactSizeIterator for SpecializationConstantIter<'_> {
    /// Number of specialization constants not yet yielded.
    fn len(&self) -> usize {
        let remaining = self.1.as_slice();
        remaining.len()
    }
}
impl Iterator for SpecializationConstantIter<'_> {
    type Item = SpecializationConstant;

    /// Yields the next specialization constant. The raw id is wrapped in a
    /// `Handle` via the compiler marker's `create_handle`.
    fn next(&mut self) -> Option<Self::Item> {
        let raw = self.1.next()?;
        Some(SpecializationConstant {
            id: self.0.create_handle(raw.id),
            constant_id: raw.constant_id,
        })
    }

    /// Reports the exact number of remaining items.
    ///
    /// This type implements `ExactSizeIterator`, whose contract requires an
    /// exact `size_hint`. The default `(0, None)` would break that contract
    /// and deny `collect` its pre-allocation hint.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.1.size_hint()
    }
}
impl<'ctx, T> Compiler<'ctx, T> {
unsafe fn bounds_check_constant(
handle: spvc_constant,
column: u32,
row: u32,
) -> error::Result<()> {
if column >= 4 || row >= 4 {
return Err(SpirvCrossError::IndexOutOfBounds { row, column });
}
let vecsize = sys::spvc_rs_constant_get_vecsize(handle);
let colsize = sys::spvc_rs_constant_get_matrix_colsize(handle);
if column >= colsize || row >= vecsize {
return Err(SpirvCrossError::IndexOutOfBounds { row, column });
}
Ok(())
}
pub fn set_specialization_constant_scalar<S: ConstantScalar>(
&mut self,
handle: Handle<ConstantId>,
column: u32,
row: u32,
value: S,
) -> error::Result<()> {
let constant = self.yield_id(handle)?;
unsafe {
let handle = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
Self::bounds_check_constant(handle, column, row)?;
S::set(handle, column, row, value)
}
Ok(())
}
pub fn specialization_constant_scalar<S: ConstantScalar>(
&self,
handle: Handle<ConstantId>,
column: u32,
row: u32,
) -> error::Result<S> {
let constant = self.yield_id(handle)?;
unsafe {
let handle = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
Self::bounds_check_constant(handle, column, row)?;
Ok(S::get(handle, column, row))
}
}
pub fn specialization_constants(&self) -> error::Result<SpecializationConstantIter<'ctx>> {
unsafe {
let mut constants = std::ptr::null();
let mut size = 0;
sys::spvc_compiler_get_specialization_constants(
self.ptr.as_ptr(),
&mut constants,
&mut size,
)
.ok(self)?;
let slice = slice::from_raw_parts(constants, size);
Ok(SpecializationConstantIter(self.phantom(), slice.iter()))
}
}
pub fn specialization_sub_constants(
&self,
constant: Handle<ConstantId>,
) -> error::Result<Vec<Handle<ConstantId>>> {
let id = self.yield_id(constant)?;
unsafe {
let constant = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), id);
let mut constants = std::ptr::null();
let mut size = 0;
sys::spvc_constant_get_subconstants(constant, &mut constants, &mut size);
Ok(slice::from_raw_parts(constants, size)
.iter()
.map(|id| self.create_handle(*id))
.collect())
}
}
pub fn work_group_size_specialization_constants(&self) -> WorkgroupSizeSpecializationConstants {
unsafe {
let mut x = MaybeUninit::zeroed();
let mut y = MaybeUninit::zeroed();
let mut z = MaybeUninit::zeroed();
let constant = sys::spvc_compiler_get_work_group_size_specialization_constants(
self.ptr.as_ptr(),
x.as_mut_ptr(),
y.as_mut_ptr(),
z.as_mut_ptr(),
);
let constant = self.create_handle_if_not_zero(constant);
let x = x.assume_init();
let y = y.assume_init();
let z = z.assume_init();
let x = self
.create_handle_if_not_zero(x.id)
.map(|id| SpecializationConstant {
id,
constant_id: x.constant_id,
});
let y = self
.create_handle_if_not_zero(y.id)
.map(|id| SpecializationConstant {
id,
constant_id: y.constant_id,
});
let z = self
.create_handle_if_not_zero(z.id)
.map(|id| SpecializationConstant {
id,
constant_id: z.constant_id,
});
WorkgroupSizeSpecializationConstants {
x,
y,
z,
builtin_workgroup_size_handle: constant,
}
}
}
pub fn specialization_constant_type(
&self,
constant: Handle<ConstantId>,
) -> error::Result<Handle<TypeId>> {
let constant = self.yield_id(constant)?;
let type_id = unsafe {
let constant = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), constant);
self.create_handle(sys::spvc_constant_get_type(constant))
};
Ok(type_id)
}
}
/// A value that can be read from, or written to, a whole specialization
/// constant: a scalar or, with the `gfx-math-types` feature, a vector or matrix.
///
/// Values are marshalled through a nested array of shape
/// `[[BaseType; VECSIZE]; COLUMNS]`, indexed as `array[column][row]`.
pub trait ConstantValue: Sealed + Sized {
    /// Scalar type of every component.
    #[doc(hidden)]
    type BaseType: ConstantScalar;
    /// One column of components (`VECSIZE` entries).
    #[doc(hidden)]
    type BaseArrayType: Default + Index<usize, Output = Self::BaseType> + IndexMut<usize>;
    /// All columns (`COLUMNS` entries of `BaseArrayType`).
    #[doc(hidden)]
    type ArrayType: Default + Index<usize, Output = Self::BaseArrayType> + IndexMut<usize>;

    /// Number of columns (1 for scalars and vectors).
    #[doc(hidden)]
    const COLUMNS: usize;
    /// Number of components in each column.
    #[doc(hidden)]
    const VECSIZE: usize;

    /// Rebuilds the value from its component array.
    #[doc(hidden)]
    fn from_array(value: Self::ArrayType) -> Self;
    /// Breaks the value down into its component array.
    #[doc(hidden)]
    fn to_array(value: Self) -> Self::ArrayType;
}
/// Every scalar is also a 1×1 constant value.
impl<T: ConstantScalar> ConstantValue for T {
    type BaseType = T;
    type BaseArrayType = [T; 1];
    type ArrayType = [[T; 1]; 1];

    const COLUMNS: usize = 1;
    const VECSIZE: usize = 1;

    fn from_array(value: Self::ArrayType) -> Self {
        let [[scalar]] = value;
        scalar
    }

    fn to_array(value: Self) -> Self::ArrayType {
        [[value; 1]; 1]
    }
}
#[cfg(feature = "gfx-math-types")]
#[cfg_attr(docsrs, doc(cfg(feature = "gfx-math-types")))]
mod gfx_maths_types {
    //! `ConstantValue` implementations for the `gfx_maths` vector and matrix types.
    //! Arrays are laid out as `array[column][row]`.
    use crate::reflect::ConstantValue;
    use crate::sealed::Sealed;
    use gfx_maths::{Mat4, Vec2, Vec3, Vec4};

    impl Sealed for Vec2 {}
    impl ConstantValue for Vec2 {
        type BaseType = f32;
        type BaseArrayType = [f32; 2];
        type ArrayType = [[f32; 2]; 1];
        const COLUMNS: usize = 1;
        const VECSIZE: usize = 2;

        fn from_array(value: Self::ArrayType) -> Self {
            let [column] = value;
            column.into()
        }

        fn to_array(value: Self) -> Self::ArrayType {
            let column = [value.x, value.y];
            [column]
        }
    }

    impl Sealed for Vec3 {}
    impl ConstantValue for Vec3 {
        type BaseType = f32;
        type BaseArrayType = [f32; 3];
        type ArrayType = [[f32; 3]; 1];
        const COLUMNS: usize = 1;
        const VECSIZE: usize = 3;

        fn from_array(value: Self::ArrayType) -> Self {
            let [column] = value;
            column.into()
        }

        fn to_array(value: Self) -> Self::ArrayType {
            let column = [value.x, value.y, value.z];
            [column]
        }
    }

    impl Sealed for Vec4 {}
    impl ConstantValue for Vec4 {
        type BaseType = f32;
        type BaseArrayType = [f32; 4];
        type ArrayType = [[f32; 4]; 1];
        const COLUMNS: usize = 1;
        const VECSIZE: usize = 4;

        fn from_array(value: Self::ArrayType) -> Self {
            let [column] = value;
            column.into()
        }

        fn to_array(value: Self) -> Self::ArrayType {
            let column = [value.x, value.y, value.z, value.w];
            [column]
        }
    }

    impl Sealed for Mat4 {}
    impl ConstantValue for Mat4 {
        type BaseType = f32;
        type BaseArrayType = [f32; 4];
        type ArrayType = [[f32; 4]; 4];
        const COLUMNS: usize = 4;
        const VECSIZE: usize = 4;

        fn from_array(value: Self::ArrayType) -> Self {
            value.into()
        }

        fn to_array(value: Self) -> Self::ArrayType {
            // Cell [c][r] of the output takes `value[(c, r)]`.
            let mut columns = [[0f32; 4]; 4];
            for (c, column) in columns.iter_mut().enumerate() {
                for (r, cell) in column.iter_mut().enumerate() {
                    *cell = value[(c, r)];
                }
            }
            columns
        }
    }
}
impl<'a, T> Compiler<'a, T> {
    /// Reads the whole specialization constant `handle` as an `S`.
    ///
    /// Each of the `S::COLUMNS` × `S::VECSIZE` components is fetched with the
    /// scalar getter of `S::BaseType`. Only the last column and row of `S` are
    /// bounds-checked against the constant, so a constant larger than `S` is
    /// accepted and only its leading components are read. The constant's scalar
    /// type is not checked here.
    ///
    /// # Errors
    /// Fails if `yield_id` rejects `handle`, or if `S` does not fit inside the
    /// constant's dimensions.
    pub fn specialization_constant_value<S: ConstantValue>(
        &self,
        handle: Handle<ConstantId>,
    ) -> error::Result<S> {
        let id = self.yield_id(handle)?;
        unsafe {
            let constant = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), id);
            let last_column = S::COLUMNS as u32 - 1;
            let last_row = S::VECSIZE as u32 - 1;
            Self::bounds_check_constant(constant, last_column, last_row)?;

            let mut components = S::ArrayType::default();
            for column in 0..S::COLUMNS {
                let cells = &mut components[column];
                for row in 0..S::VECSIZE {
                    cells[row] = S::BaseType::get(constant, column as u32, row as u32);
                }
            }
            Ok(S::from_array(components))
        }
    }

    /// Writes `value` into the specialization constant `handle`, component by component.
    ///
    /// The same partial bounds check as `specialization_constant_value` applies:
    /// components of the constant beyond `S`'s shape are left untouched.
    ///
    /// # Errors
    /// Fails if `yield_id` rejects `handle`, or if `S` does not fit inside the
    /// constant's dimensions.
    pub fn set_specialization_constant_value<S: ConstantValue>(
        &mut self,
        handle: Handle<ConstantId>,
        value: S,
    ) -> error::Result<()> {
        let id = self.yield_id(handle)?;
        unsafe {
            let constant = sys::spvc_compiler_get_constant_handle(self.ptr.as_ptr(), id);
            let last_column = S::COLUMNS as u32 - 1;
            let last_row = S::VECSIZE as u32 - 1;
            Self::bounds_check_constant(constant, last_column, last_row)?;

            let components = S::to_array(value);
            for column in 0..S::COLUMNS {
                let cells = &components[column];
                for row in 0..S::VECSIZE {
                    S::BaseType::set(constant, column as u32, row as u32, cells[row]);
                }
            }
        }
        Ok(())
    }
}