ug/
cache.rs

//! Compilation cache utilities: argument-id normalization of kernels and per-device caches of
//! compiled functions.
use crate::lang::op::{self, ArgId, Ast};
use crate::Result;
use std::collections::HashMap;
use std::sync::Arc;
5
6type Ssa = Vec<crate::lang::ssa::Instr>;
7
8#[derive(Debug, Hash, PartialEq, Eq)]
9pub struct NormalizedKernel {
10    pub(crate) args: Vec<op::Arg>,
11    pub(crate) ops: Vec<op::Store>,
12}
13
14impl NormalizedKernel {
15    pub fn new(k: &op::Kernel) -> Result<Self> {
16        fn walk(ast: &Ast, arg_map: &HashMap<ArgId, ArgId>) -> Result<Ast> {
17            use op::AstInner as A;
18            match ast.inner.as_ref() {
19                A::Id { .. } => crate::bail!("unexpected id node"),
20                A::Load { src, layout } => {
21                    let src = match arg_map.get(src) {
22                        None => crate::bail!("BUG: missing arg id {src:?}"),
23                        Some(id) => *id,
24                    };
25                    op::load(src, layout.clone(), ast.dtype)
26                }
27                A::Unary { op, arg } => {
28                    let arg = walk(arg, arg_map)?;
29                    op::unary(*op, arg)
30                }
31                A::Binary { op, lhs, rhs } => {
32                    let lhs = walk(lhs, arg_map)?;
33                    let rhs = walk(rhs, arg_map)?;
34                    op::binary(*op, lhs, rhs)
35                }
36                A::Const(cst) => op::cst(*cst),
37                A::Reduce { op, arg, dim } => {
38                    let arg = walk(arg, arg_map)?;
39                    op::reduce(*op, arg, *dim)
40                }
41                A::Layout { arg, op } => {
42                    let arg = walk(arg, arg_map)?;
43                    let inner = A::Layout { arg, op: op.clone() };
44                    Ok(Ast {
45                        inner: std::sync::Arc::new(inner),
46                        dtype: ast.dtype(),
47                        shape: ast.shape().clone(),
48                    })
49                }
50            }
51        }
52
53        let mut arg_map = HashMap::new();
54        let mut args = Vec::with_capacity(k.args.len());
55        let mut ops = Vec::with_capacity(k.ops.len());
56        for (id, arg) in k.args.iter().enumerate() {
57            let id = ArgId::from_usize(id);
58            arg_map.insert(arg.id(), id);
59            args.push(op::Arg::new(id, arg.type_()));
60        }
61        for op in k.ops.iter() {
62            let op::Store { dst, layout, value } = op;
63            let dst = match arg_map.get(dst) {
64                None => crate::bail!("BUG: missing arg id {dst:?}"),
65                Some(id) => *id,
66            };
67            let value = walk(value, &arg_map)?;
68            ops.push(op::store(dst, layout.clone(), value)?)
69        }
70        Ok(Self { args, ops })
71    }
72}
73
74pub struct CompilationCache<D: crate::Device> {
75    op_cache: HashMap<NormalizedKernel, std::sync::Arc<D::Func>>,
76    ssa_cache: HashMap<Ssa, std::sync::Arc<D::Func>>,
77}
78
79impl<D: crate::Device> Default for CompilationCache<D> {
80    fn default() -> Self {
81        Self { op_cache: Default::default(), ssa_cache: Default::default() }
82    }
83}
84
85impl<D: crate::Device> CompilationCache<D> {
86    pub fn get(&self, kernel: &NormalizedKernel) -> Option<std::sync::Arc<D::Func>> {
87        self.op_cache.get(kernel).cloned()
88    }
89
90    pub fn insert(&mut self, kernel: NormalizedKernel, func: std::sync::Arc<D::Func>) {
91        self.op_cache.insert(kernel, func);
92    }
93
94    pub fn get_ssa(&self, kernel: &Ssa) -> Option<std::sync::Arc<D::Func>> {
95        self.ssa_cache.get(kernel).cloned()
96    }
97
98    pub fn insert_ssa(&mut self, kernel: Ssa, func: std::sync::Arc<D::Func>) {
99        self.ssa_cache.insert(kernel, func);
100    }
101}