1use crate::lang::op::{self, ArgId, Ast};
3use crate::Result;
4use std::collections::HashMap;
5
/// A kernel in SSA form: a flat list of SSA instructions. Used below as the
/// key type of the SSA-level compilation cache.
type Ssa = Vec<crate::lang::ssa::Instr>;
7
8#[derive(Debug, Hash, PartialEq, Eq)]
9pub struct NormalizedKernel {
10 pub(crate) args: Vec<op::Arg>,
11 pub(crate) ops: Vec<op::Store>,
12}
13
14impl NormalizedKernel {
15 pub fn new(k: &op::Kernel) -> Result<Self> {
16 fn walk(ast: &Ast, arg_map: &HashMap<ArgId, ArgId>) -> Result<Ast> {
17 use op::AstInner as A;
18 match ast.inner.as_ref() {
19 A::Id { .. } => crate::bail!("unexpected id node"),
20 A::Load { src, layout } => {
21 let src = match arg_map.get(src) {
22 None => crate::bail!("BUG: missing arg id {src:?}"),
23 Some(id) => *id,
24 };
25 op::load(src, layout.clone(), ast.dtype)
26 }
27 A::Unary { op, arg } => {
28 let arg = walk(arg, arg_map)?;
29 op::unary(*op, arg)
30 }
31 A::Binary { op, lhs, rhs } => {
32 let lhs = walk(lhs, arg_map)?;
33 let rhs = walk(rhs, arg_map)?;
34 op::binary(*op, lhs, rhs)
35 }
36 A::Const(cst) => op::cst(*cst),
37 A::Reduce { op, arg, dim } => {
38 let arg = walk(arg, arg_map)?;
39 op::reduce(*op, arg, *dim)
40 }
41 A::Layout { arg, op } => {
42 let arg = walk(arg, arg_map)?;
43 let inner = A::Layout { arg, op: op.clone() };
44 Ok(Ast {
45 inner: std::sync::Arc::new(inner),
46 dtype: ast.dtype(),
47 shape: ast.shape().clone(),
48 })
49 }
50 }
51 }
52
53 let mut arg_map = HashMap::new();
54 let mut args = Vec::with_capacity(k.args.len());
55 let mut ops = Vec::with_capacity(k.ops.len());
56 for (id, arg) in k.args.iter().enumerate() {
57 let id = ArgId::from_usize(id);
58 arg_map.insert(arg.id(), id);
59 args.push(op::Arg::new(id, arg.type_()));
60 }
61 for op in k.ops.iter() {
62 let op::Store { dst, layout, value } = op;
63 let dst = match arg_map.get(dst) {
64 None => crate::bail!("BUG: missing arg id {dst:?}"),
65 Some(id) => *id,
66 };
67 let value = walk(value, &arg_map)?;
68 ops.push(op::store(dst, layout.clone(), value)?)
69 }
70 Ok(Self { args, ops })
71 }
72}
73
74pub struct CompilationCache<D: crate::Device> {
75 op_cache: HashMap<NormalizedKernel, std::sync::Arc<D::Func>>,
76 ssa_cache: HashMap<Ssa, std::sync::Arc<D::Func>>,
77}
78
79impl<D: crate::Device> Default for CompilationCache<D> {
80 fn default() -> Self {
81 Self { op_cache: Default::default(), ssa_cache: Default::default() }
82 }
83}
84
85impl<D: crate::Device> CompilationCache<D> {
86 pub fn get(&self, kernel: &NormalizedKernel) -> Option<std::sync::Arc<D::Func>> {
87 self.op_cache.get(kernel).cloned()
88 }
89
90 pub fn insert(&mut self, kernel: NormalizedKernel, func: std::sync::Arc<D::Func>) {
91 self.op_cache.insert(kernel, func);
92 }
93
94 pub fn get_ssa(&self, kernel: &Ssa) -> Option<std::sync::Arc<D::Func>> {
95 self.ssa_cache.get(kernel).cloned()
96 }
97
98 pub fn insert_ssa(&mut self, kernel: Ssa, func: std::sync::Arc<D::Func>) {
99 self.ssa_cache.insert(kernel, func);
100 }
101}