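//! `vm_spirv/api.rs` — public types of the SPIR-V backend: compiled modules,
//! kernels, and the table of external functions that are lowered either to
//! GLSL.std.450 extended instructions or to backend intrinsics.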

use dynamic::Type;
use rspirv::{binary::Disassemble, dr::Module};
use smol_str::SmolStr;
use std::rc::Rc;

6#[derive(Debug, Clone)]
7pub struct SpirvModule {
8    pub(crate) words: Vec<u32>,
9    pub(crate) module: Module,
10}
11
12impl SpirvModule {
13    pub fn words(&self) -> &[u32] {
14        &self.words
15    }
16
17    pub fn into_words(self) -> Vec<u32> {
18        self.words
19    }
20
21    pub fn disassemble(&self) -> String {
22        self.module.disassemble()
23    }
24}
25
26#[derive(Debug, Clone)]
27pub struct Kernel {
28    pub spirv: SpirvModule,
29    pub entry: SmolStr,
30    pub arg_tys: Vec<Type>,
31    pub ret_ty: Type,
32}
33
34#[derive(Debug, Clone)]
35pub struct ExternalFn {
36    pub full_name: SmolStr,
37    pub arg_tys: Vec<Type>,
38    pub ret_ty: Type,
39    pub kind: ExternalFnKind,
40}
41
42#[derive(Debug, Clone)]
43pub enum ExternalFnKind {
44    GlslUnary { float_op: spirv::GlslStd450Op, signed_int_op: Option<spirv::GlslStd450Op> },
45    GlslBinary { float_op: spirv::GlslStd450Op, signed_int_op: spirv::GlslStd450Op, unsigned_int_op: spirv::GlslStd450Op },
46    GlslFloatBinary { op: spirv::GlslStd450Op },
47    GlslFloatTernary { op: spirv::GlslStd450Op },
48    Builtin(BuiltinFn),
49}
50
/// Backend intrinsics exposed as external functions (registered in
/// `spirv_builtins`).
///
/// `Hash` is derived alongside `Eq`/`Ord` so this `Copy` enum can key
/// hash-based collections (`HashMap`/`HashSet`) as well as ordered ones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum BuiltinFn {
    /// Registered as `spirv::group_id`; takes no arguments and returns a
    /// 3-component `u32` vector.
    GroupId,
    /// Registered as `spirv::local_id`; takes no arguments and returns a
    /// 3-component `u32` vector.
    LocalId,
    /// Registered as `spirv::barrier`; takes no arguments and returns nothing.
    Barrier,
    /// Registered as `spirv::atomic_add`; takes two `u32`s and returns a `u32`.
    AtomicAdd,
}

59impl ExternalFn {
60    pub fn glsl_unary(full_name: impl Into<SmolStr>, arg_ty: Type, ret_ty: Type, float_op: spirv::GlslStd450Op, signed_int_op: Option<spirv::GlslStd450Op>) -> Self {
61        Self { full_name: full_name.into(), arg_tys: vec![arg_ty], ret_ty, kind: ExternalFnKind::GlslUnary { float_op, signed_int_op } }
62    }
63
64    pub fn glsl_binary(full_name: impl Into<SmolStr>, arg_ty: Type, ret_ty: Type, float_op: spirv::GlslStd450Op, signed_int_op: spirv::GlslStd450Op, unsigned_int_op: spirv::GlslStd450Op) -> Self {
65        Self { full_name: full_name.into(), arg_tys: vec![arg_ty.clone(), arg_ty], ret_ty, kind: ExternalFnKind::GlslBinary { float_op, signed_int_op, unsigned_int_op } }
66    }
67
68    pub fn glsl_float_binary(full_name: impl Into<SmolStr>, arg_ty: Type, ret_ty: Type, op: spirv::GlslStd450Op) -> Self {
69        Self { full_name: full_name.into(), arg_tys: vec![arg_ty.clone(), arg_ty], ret_ty, kind: ExternalFnKind::GlslFloatBinary { op } }
70    }
71
72    pub fn glsl_float_ternary(full_name: impl Into<SmolStr>, arg_ty: Type, ret_ty: Type, op: spirv::GlslStd450Op) -> Self {
73        Self { full_name: full_name.into(), arg_tys: vec![arg_ty.clone(), arg_ty.clone(), arg_ty], ret_ty, kind: ExternalFnKind::GlslFloatTernary { op } }
74    }
75
76    pub fn builtin(full_name: impl Into<SmolStr>, arg_tys: Vec<Type>, ret_ty: Type, builtin: BuiltinFn) -> Self {
77        Self { full_name: full_name.into(), arg_tys, ret_ty, kind: ExternalFnKind::Builtin(builtin) }
78    }
79}
80
81pub fn spirv_builtins() -> Vec<ExternalFn> {
82    vec![
83        ExternalFn::builtin("spirv::group_id", vec![], Type::Vec(Rc::new(Type::U32), 3), BuiltinFn::GroupId),
84        ExternalFn::builtin("spirv::local_id", vec![], Type::Vec(Rc::new(Type::U32), 3), BuiltinFn::LocalId),
85        ExternalFn::builtin("spirv::barrier", vec![], Type::Void, BuiltinFn::Barrier),
86        ExternalFn::builtin("spirv::atomic_add", vec![Type::U32, Type::U32], Type::U32, BuiltinFn::AtomicAdd),
87        ExternalFn::glsl_unary("abs", Type::F32, Type::F32, spirv::GlslStd450Op::FAbs, Some(spirv::GlslStd450Op::SAbs)),
88        ExternalFn::glsl_unary("sign", Type::F32, Type::F32, spirv::GlslStd450Op::FSign, Some(spirv::GlslStd450Op::SSign)),
89        ExternalFn::glsl_unary("floor", Type::F32, Type::F32, spirv::GlslStd450Op::Floor, None),
90        ExternalFn::glsl_unary("ceil", Type::F32, Type::F32, spirv::GlslStd450Op::Ceil, None),
91        ExternalFn::glsl_unary("round", Type::F32, Type::F32, spirv::GlslStd450Op::Round, None),
92        ExternalFn::glsl_unary("round_even", Type::F32, Type::F32, spirv::GlslStd450Op::RoundEven, None),
93        ExternalFn::glsl_unary("trunc", Type::F32, Type::F32, spirv::GlslStd450Op::Trunc, None),
94        ExternalFn::glsl_unary("fract", Type::F32, Type::F32, spirv::GlslStd450Op::Fract, None),
95        ExternalFn::glsl_unary("radians", Type::F32, Type::F32, spirv::GlslStd450Op::Radians, None),
96        ExternalFn::glsl_unary("degrees", Type::F32, Type::F32, spirv::GlslStd450Op::Degrees, None),
97        ExternalFn::glsl_unary("sin", Type::F32, Type::F32, spirv::GlslStd450Op::Sin, None),
98        ExternalFn::glsl_unary("cos", Type::F32, Type::F32, spirv::GlslStd450Op::Cos, None),
99        ExternalFn::glsl_unary("tan", Type::F32, Type::F32, spirv::GlslStd450Op::Tan, None),
100        ExternalFn::glsl_unary("asin", Type::F32, Type::F32, spirv::GlslStd450Op::Asin, None),
101        ExternalFn::glsl_unary("acos", Type::F32, Type::F32, spirv::GlslStd450Op::Acos, None),
102        ExternalFn::glsl_unary("atan", Type::F32, Type::F32, spirv::GlslStd450Op::Atan, None),
103        ExternalFn::glsl_unary("sinh", Type::F32, Type::F32, spirv::GlslStd450Op::Sinh, None),
104        ExternalFn::glsl_unary("cosh", Type::F32, Type::F32, spirv::GlslStd450Op::Cosh, None),
105        ExternalFn::glsl_unary("tanh", Type::F32, Type::F32, spirv::GlslStd450Op::Tanh, None),
106        ExternalFn::glsl_unary("asinh", Type::F32, Type::F32, spirv::GlslStd450Op::Asinh, None),
107        ExternalFn::glsl_unary("acosh", Type::F32, Type::F32, spirv::GlslStd450Op::Acosh, None),
108        ExternalFn::glsl_unary("atanh", Type::F32, Type::F32, spirv::GlslStd450Op::Atanh, None),
109        ExternalFn::glsl_unary("exp", Type::F32, Type::F32, spirv::GlslStd450Op::Exp, None),
110        ExternalFn::glsl_unary("log", Type::F32, Type::F32, spirv::GlslStd450Op::Log, None),
111        ExternalFn::glsl_unary("exp2", Type::F32, Type::F32, spirv::GlslStd450Op::Exp2, None),
112        ExternalFn::glsl_unary("log2", Type::F32, Type::F32, spirv::GlslStd450Op::Log2, None),
113        ExternalFn::glsl_unary("sqrt", Type::F32, Type::F32, spirv::GlslStd450Op::Sqrt, None),
114        ExternalFn::glsl_unary("inverse_sqrt", Type::F32, Type::F32, spirv::GlslStd450Op::InverseSqrt, None),
115        ExternalFn::glsl_float_binary("atan2", Type::F32, Type::F32, spirv::GlslStd450Op::Atan2),
116        ExternalFn::glsl_float_binary("pow", Type::F32, Type::F32, spirv::GlslStd450Op::Pow),
117        ExternalFn::glsl_float_binary("step", Type::F32, Type::F32, spirv::GlslStd450Op::Step),
118        ExternalFn::glsl_binary("min", Type::F32, Type::F32, spirv::GlslStd450Op::FMin, spirv::GlslStd450Op::SMin, spirv::GlslStd450Op::UMin),
119        ExternalFn::glsl_binary("max", Type::F32, Type::F32, spirv::GlslStd450Op::FMax, spirv::GlslStd450Op::SMax, spirv::GlslStd450Op::UMax),
120        ExternalFn::glsl_float_ternary("clamp", Type::F32, Type::F32, spirv::GlslStd450Op::FClamp),
121        ExternalFn::glsl_float_ternary("mix", Type::F32, Type::F32, spirv::GlslStd450Op::FMix),
122        ExternalFn::glsl_float_ternary("smoothstep", Type::F32, Type::F32, spirv::GlslStd450Op::SmoothStep),
123        ExternalFn::glsl_float_ternary("fma", Type::F32, Type::F32, spirv::GlslStd450Op::Fma),
124    ]
125}