spirq_core/
func.rs

1use fnv::{FnvHashMap as HashMap, FnvHashSet as HashSet};
2
3use crate::{
4    constant::Constant,
5    error::{anyhow, Result},
6};
7
/// Numeric id identifying a variable (presumably its SPIR-V result id).
type VariableId = u32;
/// Numeric id identifying a function; used as the key of `FunctionRegistry`.
type FunctionId = u32;
10
11/// SPIR-V execution mode.
12#[derive(PartialEq, Eq, Hash, Clone, Debug)]
13pub struct ExecutionMode {
14    pub exec_mode: spirv::ExecutionMode,
15    pub operands: Vec<Constant>,
16}
17
18/// Function reflection intermediate.
19#[derive(Default, Debug, Clone)]
20pub struct Function {
21    pub name: Option<String>,
22    pub accessed_vars: HashSet<VariableId>,
23    pub callees: HashSet<FunctionId>,
24    pub exec_modes: Vec<ExecutionMode>,
25}
26
27#[derive(Default)]
28pub struct FunctionRegistry {
29    func_map: HashMap<FunctionId, Function>,
30}
31impl FunctionRegistry {
32    pub fn set(&mut self, id: FunctionId, func: Function) -> Result<()> {
33        use std::collections::hash_map::Entry;
34        match self.func_map.entry(id) {
35            Entry::Vacant(entry) => {
36                entry.insert(func);
37                Ok(())
38            }
39            _ => Err(anyhow!("function id {} already existed", id)),
40        }
41    }
42
43    pub fn get(&self, id: FunctionId) -> Result<&Function> {
44        self.func_map
45            .get(&id)
46            .ok_or(anyhow!("function id {} is not found", id))
47    }
48    pub fn get_mut(&mut self, id: FunctionId) -> Result<&mut Function> {
49        self.func_map
50            .get_mut(&id)
51            .ok_or(anyhow!("function id {} is not found", id))
52    }
53    pub fn get_by_name(&self, name: &str) -> Result<&Function> {
54        self.func_map
55            .values()
56            .find(|x| {
57                if let Some(nm) = x.name.as_ref() {
58                    nm == name
59                } else {
60                    false
61                }
62            })
63            .ok_or(anyhow!("function {} is not found", name))
64    }
65
66    pub fn collect_ordered(&self) -> Vec<Function> {
67        let mut out: Vec<_> = self.func_map.iter().collect();
68        out.sort_by_key(|x| x.0);
69        out.into_iter().map(|x| x.1.clone()).collect()
70    }
71
72    fn collect_fn_vars_impl(&self, func: FunctionId, vars: &mut Vec<VariableId>) {
73        if let Ok(func) = self.get(func) {
74            vars.extend(func.accessed_vars.iter());
75            for call in func.callees.iter() {
76                self.collect_fn_vars_impl(*call, vars);
77            }
78        }
79    }
80    pub fn collect_fn_vars(&self, func: FunctionId) -> Vec<VariableId> {
81        let mut accessed_vars = Vec::new();
82        self.collect_fn_vars_impl(func, &mut accessed_vars);
83        accessed_vars
84    }
85}