use std::fmt;
use std::hash::Hash;
use std::rc::Rc;
pub mod reg;
pub mod walk;
pub use self::{reg::TypeRegistry, walk::Walk};
pub use crate::spirv::{Dim, ImageFormat, StorageClass};
/// Size, layout and access queries shared by the type descriptors in this module.
pub trait SpirvType {
    /// Minimum number of bytes a value of this type occupies, or `None` when no byte size
    /// is known for it.
    fn min_nbyte(&self) -> Option<usize>;
    /// Number of bytes a value of this type occupies. By default this is the same as
    /// [`SpirvType::min_nbyte`].
    fn nbyte(&self) -> Option<usize> {
        Self::min_nbyte(self)
    }
    /// `true` when [`SpirvType::nbyte`] reports a size.
    fn is_sized(&self) -> bool {
        let size = self.nbyte();
        size.is_some()
    }
    /// Byte offset of the member at `_index`. The default implementation reports none.
    fn member_offset(&self, _index: usize) -> Option<usize> {
        None
    }
    /// Access type associated with this type. The default implementation reports none.
    fn access_ty(&self) -> Option<AccessType> {
        None
    }
}
/// A scalar (non-composite) type.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum ScalarType {
    /// The `void` type.
    Void,
    /// The boolean type.
    Boolean,
    /// An integer that is `bits` wide; signed when `is_signed` is set.
    Integer { bits: u32, is_signed: bool },
    /// A floating-point number that is `bits` wide.
    Float { bits: u32 },
}
impl ScalarType {
    /// Signed integer of the given bit width.
    pub fn int(bits: u32) -> Self {
        let is_signed = true;
        ScalarType::Integer { bits, is_signed }
    }
    /// Unsigned integer of the given bit width.
    pub fn uint(bits: u32) -> Self {
        let is_signed = false;
        ScalarType::Integer { bits, is_signed }
    }
    /// Floating-point number of the given bit width.
    pub fn float(bits: u32) -> Self {
        ScalarType::Float { bits }
    }
    /// 32-bit signed integer.
    pub fn i32() -> Self {
        ScalarType::int(32)
    }
    /// 32-bit unsigned integer.
    pub fn u32() -> Self {
        ScalarType::uint(32)
    }
    /// 32-bit float.
    pub fn f32() -> Self {
        ScalarType::float(32)
    }
}
impl SpirvType for ScalarType {
    /// `void` and `bool` have no byte size; numeric scalars occupy `bits / 8` bytes.
    fn min_nbyte(&self) -> Option<usize> {
        let width = match self {
            Self::Void | Self::Boolean => return None,
            Self::Integer { bits, .. } | Self::Float { bits } => *bits,
        };
        Some((width / 8) as usize)
    }
}
impl fmt::Display for ScalarType {
    /// Renders `void`, `bool`, `i<bits>`, `u<bits>` or `f<bits>`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Void => f.write_str("void"),
            Self::Boolean => f.write_str("bool"),
            Self::Integer {
                bits,
                is_signed: true,
            } => write!(f, "i{}", bits),
            Self::Integer {
                bits,
                is_signed: false,
            } => write!(f, "u{}", bits),
            Self::Float { bits } => write!(f, "f{}", bits),
        }
    }
}
/// A vector of `nscalar` components, each of type `scalar_ty`.
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
pub struct VectorType {
    /// Component type.
    pub scalar_ty: ScalarType,
    /// Number of components.
    pub nscalar: u32,
}
impl SpirvType for VectorType {
    /// Component size times component count.
    ///
    /// Returns `None` if the component has no byte size or if the product does not fit in
    /// `usize`.
    fn min_nbyte(&self) -> Option<usize> {
        // `nscalar` comes from the parsed module; checked arithmetic turns an absurd count
        // into `None` instead of a debug panic or a silently wrapped size in release.
        self.scalar_ty
            .min_nbyte()?
            .checked_mul(self.nscalar as usize)
    }
}
impl fmt::Display for VectorType {
    /// Renders as `vec<nscalar><<scalar_ty>>`, e.g. `vec4<f32>`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "vec{}<{}>", self.nscalar, self.scalar_ty)
    }
}
/// Memory layout order of a matrix. The default is [`MatrixAxisOrder::ColumnMajor`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MatrixAxisOrder {
    ColumnMajor,
    RowMajor,
}
impl Default for MatrixAxisOrder {
    fn default() -> Self {
        Self::ColumnMajor
    }
}
/// A matrix made of `nvector` vectors of type `vector_ty`.
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
pub struct MatrixType {
    /// Type of each column/row vector.
    pub vector_ty: VectorType,
    /// Number of vectors.
    pub nvector: u32,
    /// Layout order; rendered as `AxisOrder?` when `None`.
    pub axis_order: Option<MatrixAxisOrder>,
    /// Byte stride between consecutive vectors; rendered as `?` when `None`.
    pub stride: Option<usize>,
}
impl SpirvType for MatrixType {
    /// `stride * nvector`.
    ///
    /// Returns `None` when the stride is unknown or when the product overflows `usize`.
    fn min_nbyte(&self) -> Option<usize> {
        // The stride comes from a decoration in the parsed module, so the multiplication is
        // checked rather than allowed to panic or wrap.
        self.stride?.checked_mul(self.nvector as usize)
    }
}
impl fmt::Display for MatrixType {
    /// Renders as `mat<nrow>x<ncol><<scalar>,<order>,<stride>>`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let major = match self.axis_order {
            Some(MatrixAxisOrder::ColumnMajor) => "ColumnMajor",
            Some(MatrixAxisOrder::RowMajor) => "RowMajor",
            None => "AxisOrder?",
        };
        let nrow = self.vector_ty.nscalar;
        let ncol = self.nvector;
        let scalar_ty = &self.vector_ty.scalar_ty;
        let stride = match self.stride {
            Some(x) => x.to_string(),
            None => "?".to_owned(),
        };
        write!(f, "mat{nrow}x{ncol}<{scalar_ty},{major},{stride}>")
    }
}
/// An image type together with the properties known about it.
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
pub struct ImageType {
    /// Sampled component type.
    pub scalar_ty: ScalarType,
    /// Image dimensionality.
    pub dim: Dim,
    /// Depth vs. color; rendered as `Depth?` when `None`.
    pub is_depth: Option<bool>,
    pub is_array: bool,
    pub is_multisampled: bool,
    /// Sampled vs. storage; rendered as `Sampled?` when `None`.
    pub is_sampled: Option<bool>,
    /// Image format.
    pub fmt: ImageFormat,
}
impl SpirvType for ImageType {
    /// Images have no byte size.
    fn min_nbyte(&self) -> Option<usize> {
        None
    }
}
impl fmt::Display for ImageType {
    /// Renders as `Image<dim>[Array][MS]<<scalar>,<Sampled|Storage>,<Depth|Color>,<fmt>>`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let scalar_ty = &self.scalar_ty;
        let is_sampled = match self.is_sampled {
            Some(true) => "Sampled",
            Some(false) => "Storage",
            None => "Sampled?",
        };
        let depth = match self.is_depth {
            Some(true) => "Depth",
            Some(false) => "Color",
            None => "Depth?",
        };
        // `Dim` debug names look like `Dim2D`. Strip the prefix instead of slicing at a fixed
        // index, which would panic on short names and garble names lacking the prefix.
        let dim_name = format!("{:?}", self.dim);
        let dim = dim_name.strip_prefix("Dim").unwrap_or(dim_name.as_str());
        let is_array = match self.is_array {
            true => "Array",
            false => "",
        };
        let is_multisampled = match self.is_multisampled {
            true => "MS",
            false => "",
        };
        write!(
            f,
            "Image{dim}{is_array}{is_multisampled}<{scalar_ty},{is_sampled},{depth},{:?}>",
            self.fmt
        )
    }
}
/// A standalone sampler type.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SamplerType {}
impl SpirvType for SamplerType {
    /// Samplers have no byte size.
    fn min_nbyte(&self) -> Option<usize> {
        Option::None
    }
}
impl fmt::Display for SamplerType {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "Sampler")
    }
}
/// A combined image + sampler, wrapping the description of its sampled image.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct CombinedImageSamplerType {
    pub sampled_image_ty: SampledImageType,
}
impl SpirvType for CombinedImageSamplerType {
    /// Combined image samplers have no byte size.
    fn min_nbyte(&self) -> Option<usize> {
        Option::None
    }
}
impl fmt::Display for CombinedImageSamplerType {
    /// Renders as `CombinedImageSampler<<sampled image>>`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let inner = &self.sampled_image_ty;
        write!(f, "CombinedImageSampler<{inner}>")
    }
}
/// A sampled image type.
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
pub struct SampledImageType {
    /// Sampled component type.
    pub scalar_ty: ScalarType,
    /// Image dimensionality.
    pub dim: Dim,
    /// Depth vs. color; rendered as `Depth?` when `None`.
    pub is_depth: Option<bool>,
    pub is_array: bool,
    pub is_multisampled: bool,
}
impl SpirvType for SampledImageType {
    /// Sampled images have no byte size.
    fn min_nbyte(&self) -> Option<usize> {
        None
    }
}
impl fmt::Display for SampledImageType {
    /// Renders as `SampledImage<dim>[Array][MS]<<scalar>,<Depth|Color>>`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let scalar_ty = &self.scalar_ty;
        // `Dim` debug names look like `Dim2D`. Strip the prefix rather than slicing at a
        // fixed index, which panics on short names and garbles names lacking the prefix.
        let dim_name = format!("{:?}", self.dim);
        let dim = dim_name.strip_prefix("Dim").unwrap_or(dim_name.as_str());
        let depth = match self.is_depth {
            Some(true) => "Depth",
            Some(false) => "Color",
            None => "Depth?",
        };
        let is_array = match self.is_array {
            true => "Array",
            false => "",
        };
        let is_multisampled = match self.is_multisampled {
            true => "MS",
            false => "",
        };
        write!(
            f,
            "SampledImage{dim}{is_array}{is_multisampled}<{scalar_ty},{depth}>"
        )
    }
}
/// A storage image type.
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
pub struct StorageImageType {
    /// Image dimensionality.
    pub dim: Dim,
    pub is_array: bool,
    pub is_multisampled: bool,
    /// Image format.
    pub fmt: ImageFormat,
}
impl SpirvType for StorageImageType {
    /// Storage images have no byte size.
    fn min_nbyte(&self) -> Option<usize> {
        None
    }
}
impl fmt::Display for StorageImageType {
    /// Renders as `StorageImage<dim>[Array][MS]<<fmt>>`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // `Dim` debug names look like `Dim2D`. Strip the prefix rather than slicing at a
        // fixed index, which panics on short names and garbles names lacking the prefix.
        let dim_name = format!("{:?}", self.dim);
        let dim = dim_name.strip_prefix("Dim").unwrap_or(dim_name.as_str());
        let is_array = match self.is_array {
            true => "Array",
            false => "",
        };
        let is_multisampled = match self.is_multisampled {
            true => "MS",
            false => "",
        };
        write!(
            f,
            "StorageImage{dim}{is_array}{is_multisampled}<{:?}>",
            self.fmt
        )
    }
}
/// Subpass input data with component type `scalar_ty`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct SubpassDataType {
    pub scalar_ty: ScalarType,
    pub is_multisampled: bool,
}
impl SpirvType for SubpassDataType {
    /// Subpass data has no byte size.
    fn min_nbyte(&self) -> Option<usize> {
        Option::None
    }
}
impl fmt::Display for SubpassDataType {
    /// Renders as `SubpassData[MS]<<scalar>>`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let ms_suffix = if self.is_multisampled { "MS" } else { "" };
        write!(f, "SubpassData{}<{}>", ms_suffix, self.scalar_ty)
    }
}
/// An array of `element_ty`. `nelement` is `None` for a runtime-sized array.
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
pub struct ArrayType {
    /// Element type.
    pub element_ty: Box<Type>,
    /// Element count, or `None` when the length is not fixed.
    pub nelement: Option<u32>,
    /// Byte stride between consecutive elements, if known.
    pub stride: Option<usize>,
}
impl SpirvType for ArrayType {
    /// `stride * max(nelement, 1)`, so a runtime-sized array counts one element.
    ///
    /// Returns `None` if the stride is unknown or the product overflows `usize`.
    fn min_nbyte(&self) -> Option<usize> {
        let nelement = self.nelement.unwrap_or(0).max(1) as usize;
        self.stride?.checked_mul(nelement)
    }
    /// `stride * nelement`, so a runtime-sized array counts zero elements.
    ///
    /// Returns `None` if the stride is unknown or the product overflows `usize`.
    ///
    /// NOTE(review): this yields `Some(0)` for a runtime-sized array with a known stride,
    /// which makes the default `is_sized` report such arrays as sized. Confirm this is
    /// intended.
    fn nbyte(&self) -> Option<usize> {
        self.stride?
            .checked_mul(self.nelement.unwrap_or(0) as usize)
    }
    /// Byte offset of element `member_index`: `stride * member_index`.
    ///
    /// Returns `None` if the stride is unknown or the offset overflows `usize`; the index is
    /// caller-supplied, so the multiplication is checked.
    fn member_offset(&self, member_index: usize) -> Option<usize> {
        self.stride?.checked_mul(member_index)
    }
    /// Arrays report the access type of their element type.
    fn access_ty(&self) -> Option<AccessType> {
        self.element_ty.access_ty()
    }
}
impl fmt::Display for ArrayType {
    /// Renders as `[elem; n]` for fixed-size arrays and `[elem]` otherwise.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        if let Some(nrepeat) = self.nelement {
            write!(f, "[{}; {}]", self.element_ty, nrepeat)
        } else {
            write!(f, "[{}]", self.element_ty)
        }
    }
}
/// Access permission of a resource.
///
/// The discriminants form a two-bit set: `ReadWrite` (3) is the union of the `ReadOnly` (1)
/// and `WriteOnly` (2) bits. `|` unions permissions, and `&` intersects them (yielding
/// `None` when nothing is shared).
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AccessType {
    ReadOnly = 1,
    WriteOnly = 2,
    ReadWrite = 3,
}
impl std::ops::BitOr<AccessType> for AccessType {
    type Output = AccessType;
    /// Union: equal operands stay as they are, and any mix widens to `ReadWrite`.
    fn bitor(self, rhs: AccessType) -> AccessType {
        if self == rhs {
            self
        } else {
            AccessType::ReadWrite
        }
    }
}
impl std::ops::BitAnd<AccessType> for AccessType {
    type Output = Option<AccessType>;
    /// Intersection of the permission bits; `None` when the operands share no bit.
    fn bitand(self, rhs: AccessType) -> Option<AccessType> {
        match (self as u32) & (rhs as u32) {
            1 => Some(AccessType::ReadOnly),
            2 => Some(AccessType::WriteOnly),
            3 => Some(AccessType::ReadWrite),
            _ => None,
        }
    }
}
/// A single member of a [`StructType`].
#[derive(PartialEq, Eq, Clone, Hash, Debug)]
pub struct StructMember {
    /// Member name, if one is known.
    pub name: Option<String>,
    /// Byte offset of the member inside the struct, if known.
    pub offset: Option<usize>,
    /// Member type.
    pub ty: Type,
    /// Access type of this member.
    pub access_ty: AccessType,
}
/// A struct type made of an ordered list of members.
#[derive(PartialEq, Eq, Default, Clone, Hash, Debug)]
pub struct StructType {
    /// Struct name, if one is known.
    pub name: Option<String>,
    /// Members in declaration order.
    pub members: Vec<StructMember>,
}
impl StructType {
    /// The struct's name as a string slice, if it has one.
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }
}
impl SpirvType for StructType {
    /// Offset of the last member plus that member's minimum size.
    ///
    /// Returns `None` for an empty struct, an unknown offset or size, or on `usize`
    /// overflow.
    ///
    /// NOTE(review): this assumes the last declared member has the highest offset. Confirm
    /// for inputs with out-of-order `Offset` decorations.
    fn min_nbyte(&self) -> Option<usize> {
        let last_member = self.members.last()?;
        // Offsets and sizes come from the parsed module; checked addition avoids a panic
        // or a wrapped size on malformed input.
        last_member
            .offset?
            .checked_add(last_member.ty.min_nbyte()?)
    }
    /// Offset of the last member plus that member's size. `None` in the same cases as
    /// [`SpirvType::min_nbyte`].
    fn nbyte(&self) -> Option<usize> {
        let last_member = self.members.last()?;
        last_member.offset?.checked_add(last_member.ty.nbyte()?)
    }
    /// Offset of the member at `member_index`; `None` if the index is out of range or the
    /// offset is unknown.
    fn member_offset(&self, member_index: usize) -> Option<usize> {
        self.members.get(member_index)?.offset
    }
    /// Union (via `|`) of all members' access types; `None` for a struct with no members.
    fn access_ty(&self) -> Option<AccessType> {
        self.members
            .iter()
            .map(|member| member.access_ty)
            .reduce(|acc, next| acc | next)
    }
}
impl fmt::Display for StructType {
    /// Renders as `Name { a: T, b: U }`. Unnamed members are shown by their index, and an
    /// unnamed struct omits the leading name.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        if let Some(name) = &self.name {
            write!(f, "{} {{ ", name)?;
        } else {
            f.write_str("{ ")?;
        }
        for (i, member) in self.members.iter().enumerate() {
            if i != 0 {
                f.write_str(", ")?;
            }
            if let Some(name) = &member.name {
                write!(f, "{}: {}", name, member.ty)?;
            } else {
                write!(f, "{}: {}", i, member.ty)?;
            }
        }
        f.write_str(" }")
    }
}
/// An acceleration structure type.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct AccelStructType {}
impl SpirvType for AccelStructType {
    /// Acceleration structures have no byte size.
    fn min_nbyte(&self) -> Option<usize> {
        Option::None
    }
}
impl fmt::Display for AccelStructType {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "AccelStruct")
    }
}
/// A device address value.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct DeviceAddressType {}
impl SpirvType for DeviceAddressType {
    /// An address is reported as the size of a `u64` (8 bytes).
    fn min_nbyte(&self) -> Option<usize> {
        const ADDRESS_NBYTE: usize = std::mem::size_of::<u64>();
        Some(ADDRESS_NBYTE)
    }
}
impl fmt::Display for DeviceAddressType {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "Address")
    }
}
/// A pointer to `pointee_ty` in storage class `store_cls`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct PointerType {
    pub pointee_ty: Box<Type>,
    pub store_cls: StorageClass,
}
impl SpirvType for PointerType {
    /// A pointer is reported as the size of a `u64` (8 bytes).
    fn min_nbyte(&self) -> Option<usize> {
        const POINTER_NBYTE: usize = std::mem::size_of::<u64>();
        Some(POINTER_NBYTE)
    }
    /// Pointers report the access type of the pointee.
    fn access_ty(&self) -> Option<AccessType> {
        let pointee: &Type = &self.pointee_ty;
        pointee.access_ty()
    }
}
impl fmt::Display for PointerType {
    /// Renders as `Pointer { <pointee> }`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let pointee: &Type = &self.pointee_ty;
        write!(f, "Pointer {{ {} }}", pointee)
    }
}
/// A ray query object type.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct RayQueryType {}
impl SpirvType for RayQueryType {
    /// Ray queries have no byte size.
    fn min_nbyte(&self) -> Option<usize> {
        Option::None
    }
}
impl fmt::Display for RayQueryType {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "RayQuery")
    }
}
// Generates `pub fn $name(&self) -> bool` predicates, one per listed variant of enum `$e`,
// each returning whether `self` is that (tuple) variant.
macro_rules! declr_ty_accessor {
    ([$e:ident] $($name:ident -> $ty:ident,)+) => {
        $(
            pub fn $name(&self) -> bool {
                matches!(self, $e::$ty(..))
            }
        )+
    };
}
// Generates `pub fn $name(&self) -> Option<&$inner_ty>` downcasts, one per listed variant of
// enum `$e`, each borrowing the variant's single payload when `self` matches.
macro_rules! declr_ty_downcast {
    ([$e:ident] $($name:ident -> $ty:ident($inner_ty:ident),)+) => {
        $(
            pub fn $name(&self) -> Option<&$inner_ty> {
                if let $e::$ty(inner) = self {
                    Some(inner)
                } else {
                    None
                }
            }
        )+
    };
}
/// Any type described by this module; each variant wraps the matching descriptor struct.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum Type {
    /// See [`ScalarType`].
    Scalar(ScalarType),
    /// See [`VectorType`].
    Vector(VectorType),
    /// See [`MatrixType`].
    Matrix(MatrixType),
    /// See [`ImageType`].
    Image(ImageType),
    /// See [`CombinedImageSamplerType`].
    CombinedImageSampler(CombinedImageSamplerType),
    /// See [`SampledImageType`].
    SampledImage(SampledImageType),
    /// See [`StorageImageType`].
    StorageImage(StorageImageType),
    /// See [`SamplerType`].
    Sampler(SamplerType),
    /// See [`SubpassDataType`].
    SubpassData(SubpassDataType),
    /// See [`ArrayType`].
    Array(ArrayType),
    /// See [`StructType`].
    Struct(StructType),
    /// See [`AccelStructType`].
    AccelStruct(AccelStructType),
    /// See [`DeviceAddressType`].
    DeviceAddress(DeviceAddressType),
    /// See [`PointerType`].
    DevicePointer(PointerType),
    /// See [`RayQueryType`].
    RayQuery(RayQueryType),
}
impl Type {
    /// Minimum byte size of this type; dispatches to [`SpirvType::min_nbyte`] of the
    /// wrapped descriptor.
    pub fn min_nbyte(&self) -> Option<usize> {
        match self {
            Type::Scalar(x) => x.min_nbyte(),
            Type::Vector(x) => x.min_nbyte(),
            Type::Matrix(x) => x.min_nbyte(),
            Type::Image(x) => x.min_nbyte(),
            Type::CombinedImageSampler(x) => x.min_nbyte(),
            Type::SampledImage(x) => x.min_nbyte(),
            Type::StorageImage(x) => x.min_nbyte(),
            Type::Sampler(x) => x.min_nbyte(),
            Type::SubpassData(x) => x.min_nbyte(),
            Type::Array(x) => x.min_nbyte(),
            Type::Struct(x) => x.min_nbyte(),
            Type::AccelStruct(x) => x.min_nbyte(),
            Type::DeviceAddress(x) => x.min_nbyte(),
            Type::DevicePointer(x) => x.min_nbyte(),
            Type::RayQuery(x) => x.min_nbyte(),
        }
    }
    /// Byte size of this type; dispatches to [`SpirvType::nbyte`] of the wrapped
    /// descriptor.
    pub fn nbyte(&self) -> Option<usize> {
        match self {
            Type::Scalar(x) => x.nbyte(),
            Type::Vector(x) => x.nbyte(),
            Type::Matrix(x) => x.nbyte(),
            Type::Image(x) => x.nbyte(),
            Type::CombinedImageSampler(x) => x.nbyte(),
            Type::SampledImage(x) => x.nbyte(),
            Type::StorageImage(x) => x.nbyte(),
            Type::Sampler(x) => x.nbyte(),
            Type::SubpassData(x) => x.nbyte(),
            Type::Array(x) => x.nbyte(),
            Type::Struct(x) => x.nbyte(),
            Type::AccelStruct(x) => x.nbyte(),
            Type::DeviceAddress(x) => x.nbyte(),
            Type::DevicePointer(x) => x.nbyte(),
            Type::RayQuery(x) => x.nbyte(),
        }
    }
    /// Byte offset of member `member_index`; dispatches to [`SpirvType::member_offset`] of
    /// the wrapped descriptor.
    pub fn member_offset(&self, member_index: usize) -> Option<usize> {
        match self {
            Type::Scalar(x) => x.member_offset(member_index),
            Type::Vector(x) => x.member_offset(member_index),
            Type::Matrix(x) => x.member_offset(member_index),
            Type::Image(x) => x.member_offset(member_index),
            Type::CombinedImageSampler(x) => x.member_offset(member_index),
            Type::SampledImage(x) => x.member_offset(member_index),
            Type::StorageImage(x) => x.member_offset(member_index),
            Type::Sampler(x) => x.member_offset(member_index),
            Type::SubpassData(x) => x.member_offset(member_index),
            Type::Array(x) => x.member_offset(member_index),
            Type::Struct(x) => x.member_offset(member_index),
            Type::AccelStruct(x) => x.member_offset(member_index),
            Type::DeviceAddress(x) => x.member_offset(member_index),
            Type::DevicePointer(x) => x.member_offset(member_index),
            Type::RayQuery(x) => x.member_offset(member_index),
        }
    }
    /// Access type of this type; dispatches to [`SpirvType::access_ty`] of the wrapped
    /// descriptor.
    pub fn access_ty(&self) -> Option<AccessType> {
        match self {
            Type::Scalar(x) => x.access_ty(),
            Type::Vector(x) => x.access_ty(),
            Type::Matrix(x) => x.access_ty(),
            Type::Image(x) => x.access_ty(),
            Type::CombinedImageSampler(x) => x.access_ty(),
            Type::SampledImage(x) => x.access_ty(),
            Type::StorageImage(x) => x.access_ty(),
            Type::Sampler(x) => x.access_ty(),
            Type::SubpassData(x) => x.access_ty(),
            Type::Array(x) => x.access_ty(),
            Type::Struct(x) => x.access_ty(),
            Type::AccelStruct(x) => x.access_ty(),
            Type::DeviceAddress(x) => x.access_ty(),
            Type::DevicePointer(x) => x.access_ty(),
            Type::RayQuery(x) => x.access_ty(),
        }
    }
    /// Creates a [`Walk`] over this type (see the `walk` module for what it visits).
    pub fn walk<'a>(&'a self) -> Walk<'a> {
        Walk::new(self)
    }
    // `is_*` predicates: one per variant, true when `self` is that variant.
    declr_ty_accessor! {
        [Type]
        is_scalar -> Scalar,
        is_vector -> Vector,
        is_matrix -> Matrix,
        is_image -> Image,
        is_sampler -> Sampler,
        is_combined_image_sampler -> CombinedImageSampler,
        is_sampled_image -> SampledImage,
        is_storage_image -> StorageImage,
        is_subpass_data -> SubpassData,
        is_array -> Array,
        is_struct -> Struct,
        is_accel_struct -> AccelStruct,
        is_device_address -> DeviceAddress,
        is_device_pointer -> DevicePointer,
        is_ray_query -> RayQuery,
    }
    // `as_*` downcasts: borrow the wrapped descriptor when `self` is that variant.
    declr_ty_downcast! {
        [Type]
        as_scalar -> Scalar(ScalarType),
        as_vector -> Vector(VectorType),
        as_matrix -> Matrix(MatrixType),
        as_image -> Image(ImageType),
        as_sampler -> Sampler(SamplerType),
        as_combined_image_sampler -> CombinedImageSampler(CombinedImageSamplerType),
        as_sampled_image -> SampledImage(SampledImageType),
        as_storage_image -> StorageImage(StorageImageType),
        as_subpass_data -> SubpassData(SubpassDataType),
        as_array -> Array(ArrayType),
        as_struct -> Struct(StructType),
        as_accel_struct -> AccelStruct(AccelStructType),
        as_device_address -> DeviceAddress(DeviceAddressType),
        as_device_pointer -> DevicePointer(PointerType),
        as_ray_query -> RayQuery(RayQueryType),
    }
    /// Recursive worker for [`Type::mutate`]. Array element types and struct member types
    /// are rewritten first; then `f` is applied to the rebuilt node. The closure is shared
    /// through `Rc` across the recursion.
    fn mutate_impl<F: Fn(Type) -> Type>(self, f: Rc<F>) -> Type {
        use Type::*;
        let out = match self {
            Array(src) => {
                let dst = ArrayType {
                    element_ty: Box::new(src.element_ty.mutate_impl(f.clone())),
                    nelement: src.nelement,
                    stride: src.stride,
                };
                Type::Array(dst)
            }
            Struct(src) => {
                let dst = StructType {
                    name: src.name,
                    members: src
                        .members
                        .into_iter()
                        .map(|x| StructMember {
                            name: x.name,
                            offset: x.offset,
                            ty: x.ty.mutate_impl(f.clone()),
                            access_ty: x.access_ty,
                        })
                        .collect(),
                };
                Type::Struct(dst)
            }
            // Other variants, including pointers, are handed to `f` without visiting
            // their contents.
            _ => self,
        };
        (*f)(out)
    }
    /// Rebuilds this type bottom-up, applying `f` to every node.
    ///
    /// Only arrays (element type) and structs (member types) are descended into; children
    /// are rewritten before their parent.
    pub fn mutate<F: Fn(Type) -> Type>(self, f: F) -> Type {
        self.mutate_impl(Rc::new(f))
    }
}
impl fmt::Display for Type {
    /// Delegates to the `Display` implementation of the wrapped descriptor.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Type::Scalar(inner) => fmt::Display::fmt(inner, f),
            Type::Vector(inner) => fmt::Display::fmt(inner, f),
            Type::Matrix(inner) => fmt::Display::fmt(inner, f),
            Type::Image(inner) => fmt::Display::fmt(inner, f),
            Type::CombinedImageSampler(inner) => fmt::Display::fmt(inner, f),
            Type::SampledImage(inner) => fmt::Display::fmt(inner, f),
            Type::StorageImage(inner) => fmt::Display::fmt(inner, f),
            Type::Sampler(inner) => fmt::Display::fmt(inner, f),
            Type::SubpassData(inner) => fmt::Display::fmt(inner, f),
            Type::Array(inner) => fmt::Display::fmt(inner, f),
            Type::Struct(inner) => fmt::Display::fmt(inner, f),
            Type::AccelStruct(inner) => fmt::Display::fmt(inner, f),
            Type::DeviceAddress(inner) => fmt::Display::fmt(inner, f),
            Type::DevicePointer(inner) => fmt::Display::fmt(inner, f),
            Type::RayQuery(inner) => fmt::Display::fmt(inner, f),
        }
    }
}
/// Kind of descriptor a resource is bound through. Variants that carry an [`AccessType`]
/// record the access associated with the resource.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum DescriptorType {
    Sampler(),
    CombinedImageSampler(),
    SampledImage(),
    StorageImage(AccessType),
    UniformTexelBuffer(),
    StorageTexelBuffer(AccessType),
    UniformBuffer(),
    StorageBuffer(AccessType),
    /// Presumably carries the input attachment index; confirm against the producer.
    InputAttachment(u32),
    AccelStruct(),
}