1use fnv::{FnvHashMap as HashMap, FnvHashSet as HashSet};
2
3use crate::{
4 constant::Constant,
5 error::{anyhow, Result},
6};
7
/// Numeric id of a variable tracked by this module
/// (presumably a SPIR-V result id of an `OpVariable` — TODO confirm).
type VariableId = u32;
/// Numeric id under which a [`Function`] is registered in [`FunctionRegistry`]
/// (presumably the SPIR-V result id of its `OpFunction` — TODO confirm).
type FunctionId = u32;
10
/// One execution mode attached to a function, together with its operands.
#[derive(PartialEq, Eq, Hash, Clone, Debug)]
pub struct ExecutionMode {
    /// The SPIR-V execution mode kind.
    pub exec_mode: spirv::ExecutionMode,
    /// Operands supplied with the mode, in order
    /// (presumably the extra literals of `OpExecutionMode` — TODO confirm).
    pub operands: Vec<Constant>,
}
17
/// Per-function facts recorded for a function stored in [`FunctionRegistry`].
#[derive(Default, Debug, Clone)]
pub struct Function {
    /// Optional name of the function; [`FunctionRegistry::get_by_name`]
    /// matches against it. `None` when no name is known
    /// (presumably taken from `OpName` — TODO confirm).
    pub name: Option<String>,
    /// Ids of variables accessed by this function. Presumably only its own
    /// accesses: [`FunctionRegistry::collect_fn_vars`] adds the callees'
    /// accesses by walking `callees` — verify against the producer.
    pub accessed_vars: HashSet<VariableId>,
    /// Ids of the functions this function calls.
    pub callees: HashSet<FunctionId>,
    /// Execution modes declared for this function.
    pub exec_modes: Vec<ExecutionMode>,
}
26
27#[derive(Default)]
28pub struct FunctionRegistry {
29 func_map: HashMap<FunctionId, Function>,
30}
31impl FunctionRegistry {
32 pub fn set(&mut self, id: FunctionId, func: Function) -> Result<()> {
33 use std::collections::hash_map::Entry;
34 match self.func_map.entry(id) {
35 Entry::Vacant(entry) => {
36 entry.insert(func);
37 Ok(())
38 }
39 _ => Err(anyhow!("function id {} already existed", id)),
40 }
41 }
42
43 pub fn get(&self, id: FunctionId) -> Result<&Function> {
44 self.func_map
45 .get(&id)
46 .ok_or(anyhow!("function id {} is not found", id))
47 }
48 pub fn get_mut(&mut self, id: FunctionId) -> Result<&mut Function> {
49 self.func_map
50 .get_mut(&id)
51 .ok_or(anyhow!("function id {} is not found", id))
52 }
53 pub fn get_by_name(&self, name: &str) -> Result<&Function> {
54 self.func_map
55 .values()
56 .find(|x| {
57 if let Some(nm) = x.name.as_ref() {
58 nm == name
59 } else {
60 false
61 }
62 })
63 .ok_or(anyhow!("function {} is not found", name))
64 }
65
66 pub fn collect_ordered(&self) -> Vec<Function> {
67 let mut out: Vec<_> = self.func_map.iter().collect();
68 out.sort_by_key(|x| x.0);
69 out.into_iter().map(|x| x.1.clone()).collect()
70 }
71
72 fn collect_fn_vars_impl(&self, func: FunctionId, vars: &mut Vec<VariableId>) {
73 if let Ok(func) = self.get(func) {
74 vars.extend(func.accessed_vars.iter());
75 for call in func.callees.iter() {
76 self.collect_fn_vars_impl(*call, vars);
77 }
78 }
79 }
80 pub fn collect_fn_vars(&self, func: FunctionId) -> Vec<VariableId> {
81 let mut accessed_vars = Vec::new();
82 self.collect_fn_vars_impl(func, &mut accessed_vars);
83 accessed_vars
84 }
85}