Skip to main content

vm_spirv/
lib.rs

1mod api;
2mod constants;
3mod context;
4mod expr;
5mod externs;
6mod memory;
7mod ops;
8mod stmt;
9mod symbols;
10mod types;
11
12pub use api::{BuiltinFn, ExternalFn, ExternalFnKind, Kernel, SpirvModule, spirv_builtins};
13
14use anyhow::{Context, Result, bail};
15use compiler::{Compiler, Symbol, substitute_stmt, substitute_type};
16use dynamic::Type;
17use parser::{Span, Stmt, StmtKind};
18use std::{collections::BTreeMap, path::Path};
19
20use crate::{
21    context::SpirvCompiler,
22    externs::register_externs,
23    symbols::{collect_type_defs, collect_user_fns, collect_workgroup_statics},
24};
25
26pub fn compile_source(source: impl AsRef<[u8]>, module_name: &str, fn_name: &str) -> Result<Kernel> {
27    compile_source_with_externs(source, module_name, fn_name, spirv_builtins())
28}
29
30pub fn compile_source_with_workgroup_size(source: impl AsRef<[u8]>, module_name: &str, fn_name: &str, workgroup_size: [u32; 3]) -> Result<Kernel> {
31    compile_source_with_externs_and_workgroup_size(source, module_name, fn_name, spirv_builtins(), workgroup_size)
32}
33
34pub fn compile_file_with_workgroup_size(path: impl AsRef<Path>, module_name: &str, fn_name: &str, workgroup_size: [u32; 3]) -> Result<Kernel> {
35    compile_file_with_generic_args_and_workgroup_size(path, module_name, fn_name, &[], workgroup_size)
36}
37
38pub fn compile_file_with_generic_args_and_workgroup_size(path: impl AsRef<Path>, module_name: &str, fn_name: &str, generic_args: &[Type], workgroup_size: [u32; 3]) -> Result<Kernel> {
39    let mut compiler = Compiler::new();
40    let externs = register_externs(&mut compiler, spirv_builtins())?;
41    compiler.import_file(module_name, path)?;
42    compile_function_with_externs_generic_args_and_workgroup_size(&mut compiler, module_name, fn_name, externs, generic_args, workgroup_size)
43}
44
45pub fn compile_source_with_externs(source: impl AsRef<[u8]>, module_name: &str, fn_name: &str, externs: impl IntoIterator<Item = ExternalFn>) -> Result<Kernel> {
46    compile_source_with_externs_and_workgroup_size(source, module_name, fn_name, externs, [1, 1, 1])
47}
48
49pub fn compile_source_with_externs_and_workgroup_size(source: impl AsRef<[u8]>, module_name: &str, fn_name: &str, externs: impl IntoIterator<Item = ExternalFn>, workgroup_size: [u32; 3]) -> Result<Kernel> {
50    compile_source_with_externs_generic_args_and_workgroup_size(source, module_name, fn_name, externs, &[], workgroup_size)
51}
52
53pub fn compile_source_with_externs_generic_args_and_workgroup_size(
54    source: impl AsRef<[u8]>,
55    module_name: &str,
56    fn_name: &str,
57    externs: impl IntoIterator<Item = ExternalFn>,
58    generic_args: &[Type],
59    workgroup_size: [u32; 3],
60) -> Result<Kernel> {
61    let mut compiler = Compiler::new();
62    let externs = register_externs(&mut compiler, externs)?;
63    compiler.import_code(module_name, source.as_ref().to_vec())?;
64    compile_function_with_externs_generic_args_and_workgroup_size(&mut compiler, module_name, fn_name, externs, generic_args, workgroup_size)
65}
66
67pub fn compile_function(compiler: &mut Compiler, module_name: &str, fn_name: &str) -> Result<Kernel> {
68    let externs = register_externs(compiler, spirv_builtins())?;
69    compile_function_with_externs(compiler, module_name, fn_name, externs)
70}
71
72pub fn compile_function_with_externs(compiler: &mut Compiler, module_name: &str, fn_name: &str, externs: BTreeMap<u32, ExternalFnKind>) -> Result<Kernel> {
73    compile_function_with_externs_and_workgroup_size(compiler, module_name, fn_name, externs, [1, 1, 1])
74}
75
76pub fn compile_function_with_externs_and_workgroup_size(compiler: &mut Compiler, module_name: &str, fn_name: &str, externs: BTreeMap<u32, ExternalFnKind>, workgroup_size: [u32; 3]) -> Result<Kernel> {
77    compile_function_with_externs_generic_args_and_workgroup_size(compiler, module_name, fn_name, externs, &[], workgroup_size)
78}
79
80pub fn compile_function_with_externs_generic_args_and_workgroup_size(
81    compiler: &mut Compiler,
82    module_name: &str,
83    fn_name: &str,
84    externs: BTreeMap<u32, ExternalFnKind>,
85    generic_args: &[Type],
86    workgroup_size: [u32; 3],
87) -> Result<Kernel> {
88    let full_name = format!("{module_name}::{fn_name}");
89    let id = compiler.symbols.get_id(&full_name).or_else(|_| compiler.symbols.get_id(fn_name)).with_context(|| format!("function {full_name} not found"))?;
90    let symbol = compiler.symbols.get_symbol(id)?.1.clone();
91    let Symbol::Fn { ty, args, generic_params, cap, body, is_pub: _ } = symbol else {
92        bail!("{full_name} is not a zust function");
93    };
94    let Type::Fn { tys: decl_arg_tys, ret: _ } = ty else {
95        bail!("{full_name} has non-function type {ty:?}");
96    };
97    let (arg_tys, body) = specialize_entry_function(compiler, module_name, &args, &generic_params, generic_args, &decl_arg_tys, body.as_ref(), cap)?;
98    let ret_ty = compiler.infer_fn_with_params(id, &arg_tys, generic_args)?;
99    let ret_ty = compiler.symbols.get_type(&ret_ty)?;
100    let type_defs = collect_type_defs(compiler);
101    let user_fns = collect_user_fns(compiler)?;
102    let workgroup_static_tys = collect_workgroup_statics(compiler)?;
103    let builder = SpirvCompiler::new(externs, user_fns, type_defs, workgroup_static_tys, compiler.clone(), workgroup_size);
104    let spirv = builder.compile_kernel(&arg_tys, ret_ty.clone(), &body)?;
105    Ok(Kernel { spirv, entry: "main".into(), arg_tys: arg_tys.clone(), ret_ty })
106}
107
108fn specialize_entry_function(
109    compiler: &mut Compiler,
110    module_name: &str,
111    args: &[smol_str::SmolStr],
112    generic_params: &[Type],
113    generic_args: &[Type],
114    decl_arg_tys: &[Type],
115    body: &Stmt,
116    cap: compiler::Capture,
117) -> Result<(Vec<Type>, Stmt)> {
118    if generic_params.is_empty() {
119        let arg_tys = decl_arg_tys.iter().map(|ty| resolve_entry_type(compiler, module_name, ty)).collect::<Result<Vec<_>>>()?;
120        return Ok((arg_tys, body.clone()));
121    }
122    if generic_params.len() != generic_args.len() {
123        bail!("entry function expects {} generic args, got {}", generic_params.len(), generic_args.len());
124    }
125    let body = substitute_stmt(body, generic_params, generic_args);
126    let mut arg_tys = decl_arg_tys.iter().map(|ty| resolve_entry_type(compiler, module_name, &substitute_type(ty, generic_params, generic_args))).collect::<Result<Vec<_>>>()?;
127    let mut cap = cap;
128    let compiled = compiler.compile_fn(args, &mut arg_tys, body, &mut cap)?;
129    Ok((arg_tys, Stmt::new(StmtKind::Block(compiled), Span::default())))
130}
131
132fn resolve_entry_type(compiler: &Compiler, module_name: &str, ty: &Type) -> Result<Type> {
133    compiler.symbols.get_type(ty).or_else(|_| compiler.symbols.get_type(&qualify_entry_type(module_name, ty)))
134}
135
136fn qualify_entry_type(module_name: &str, ty: &Type) -> Type {
137    match ty {
138        Type::Ident { name, params } if !name.contains("::") && name.as_str() != "Vec" => {
139            Type::Ident { name: format!("{module_name}::{name}").into(), params: params.iter().map(|param| qualify_entry_type(module_name, param)).collect() }
140        }
141        Type::Ident { name, params } => Type::Ident { name: name.clone(), params: params.iter().map(|param| qualify_entry_type(module_name, param)).collect() },
142        Type::Array(elem, len) => Type::Array(std::rc::Rc::new(qualify_entry_type(module_name, elem)), *len),
143        Type::ArrayParam(elem, len) => Type::ArrayParam(std::rc::Rc::new(qualify_entry_type(module_name, elem)), std::rc::Rc::new(qualify_entry_type(module_name, len))),
144        Type::Vec(elem, len) => Type::Vec(std::rc::Rc::new(qualify_entry_type(module_name, elem)), *len),
145        Type::Tuple(items) => Type::Tuple(items.iter().map(|item| qualify_entry_type(module_name, item)).collect()),
146        Type::Struct { params, fields } => {
147            Type::Struct { params: params.iter().map(|param| qualify_entry_type(module_name, param)).collect(), fields: fields.iter().map(|(name, ty)| (name.clone(), qualify_entry_type(module_name, ty))).collect() }
148        }
149        Type::Fn { tys, ret } => Type::Fn { tys: tys.iter().map(|ty| qualify_entry_type(module_name, ty)).collect(), ret: std::rc::Rc::new(qualify_entry_type(module_name, ret)) },
150        Type::Symbol { id, params } => Type::Symbol { id: *id, params: params.iter().map(|param| qualify_entry_type(module_name, param)).collect() },
151        other => other.clone(),
152    }
153}
154
155#[cfg(test)]
156mod tests;