// typed_eval/compiler/compiler_registry.rs

1use crate::{BinOp, DynFn, Error, EvalType, Result, TypeInfo, UnOp};
2use std::{
3    collections::{HashMap, HashSet, hash_map::Entry},
4    hash::Hash,
5    marker::PhantomData,
6};
7
8type CastKey = (TypeInfo, TypeInfo);
9type CompileCastFunc = Box<dyn Fn(DynFn) -> Result<DynFn>>;
10
11type UnOpKey = (UnOp, TypeInfo);
12type CompileUnOpFunc = Box<dyn Fn(DynFn) -> Result<DynFn>>;
13
14type BinOpKey = (BinOp, TypeInfo);
15type CompileBinOpFunc = Box<dyn Fn(DynFn, DynFn) -> Result<DynFn>>;
16
17type FieldAccessKey = (TypeInfo, &'static str);
18type FieldAccessFunc = Box<dyn Fn(DynFn) -> Result<DynFn>>;
19
20type MethodCallKey = (TypeInfo, &'static str);
21
22pub(crate) struct MethodCallData {
23    pub compile_fn: Box<dyn Fn(DynFn, Vec<DynFn>) -> Result<DynFn>>,
24    pub arg_types: Vec<TypeInfo>,
25}
26
27pub struct RegistryAccess<'r, Ctx, T> {
28    pub(super) registry: &'r mut CompilerRegistry,
29    ty: PhantomData<(Ctx, T)>,
30}
31
32impl<'r, Ctx, T> RegistryAccess<'r, Ctx, T> {
33    pub(crate) fn new(registry: &'r mut CompilerRegistry) -> Self {
34        Self {
35            registry,
36            ty: PhantomData,
37        }
38    }
39}
40
41#[derive(Default)]
42pub(crate) struct CompilerRegistry {
43    registered_types: HashSet<TypeInfo>,
44    pub(crate) casts: HashMap<CastKey, CompileCastFunc>,
45    pub(crate) unary_operations: HashMap<UnOpKey, CompileUnOpFunc>,
46    pub(crate) binary_operations: HashMap<BinOpKey, CompileBinOpFunc>,
47    pub(crate) field_access: HashMap<FieldAccessKey, FieldAccessFunc>,
48    pub(crate) method_calls: HashMap<MethodCallKey, MethodCallData>,
49}
50
51impl CompilerRegistry {
52    pub fn register_type<Ctx: EvalType, T: EvalType>(&mut self) -> Result<()> {
53        let type_id = T::type_info();
54
55        if self.registered_types.insert(type_id) {
56            T::register_methods(RegistryAccess::<Ctx, T>::new(self))?;
57            T::register(RegistryAccess::<Ctx, T>::new(self))?;
58        }
59
60        Ok(())
61    }
62}
63
64impl<'r, Ctx: EvalType, T: EvalType> RegistryAccess<'r, Ctx, T> {
65    pub fn register_type<T2: EvalType>(&mut self) -> Result<()> {
66        self.registry.register_type::<Ctx, T2>()
67    }
68
69    // cast from T to To
70    pub fn register_cast<To>(
71        &mut self,
72        cast_fn: for<'a> fn(&'a Ctx, T::RefType<'a>) -> To::RefType<'a>,
73    ) -> Result<()>
74    where
75        To: EvalType,
76    {
77        let (from, to) = (T::type_info(), To::type_info());
78
79        let compile_func = Box::new(move |from: DynFn| -> Result<DynFn> {
80            let from = from.downcast::<Ctx, T>()?;
81            Ok(To::make_dyn_fn(move |ctx| cast_fn(ctx, from(ctx))))
82        });
83
84        try_insert(&mut self.registry.casts, (from, to), compile_func, || {
85            Error::DuplicateCast { from, to }
86        })
87    }
88
89    pub fn register_un_op(
90        &mut self,
91        op: UnOp,
92        un_op_fn: for<'a> fn(&'a Ctx, T::RefType<'a>) -> T::RefType<'a>,
93    ) -> Result<()> {
94        let ty = T::type_info();
95
96        let compile_func = Box::new(move |rhs: DynFn| -> Result<DynFn> {
97            let rhs = rhs.downcast::<Ctx, T>()?;
98            Ok(T::make_dyn_fn(move |ctx| un_op_fn(ctx, rhs(ctx))))
99        });
100
101        try_insert(
102            &mut self.registry.unary_operations,
103            (op, ty),
104            compile_func,
105            || Error::DuplicateUnOp { op, ty },
106        )
107    }
108
109    pub fn register_bin_op(
110        &mut self,
111        op: BinOp,
112        bin_op_fn: for<'a> fn(
113            &'a Ctx,
114            T::RefType<'a>,
115            T::RefType<'a>,
116        ) -> T::RefType<'a>,
117    ) -> Result<()> {
118        let ty = T::type_info();
119
120        let compile_func =
121            Box::new(move |lhs: DynFn, rhs: DynFn| -> Result<DynFn> {
122                let lhs = lhs.downcast::<Ctx, T>()?;
123                let rhs = rhs.downcast::<Ctx, T>()?;
124                Ok(T::make_dyn_fn(move |ctx| {
125                    bin_op_fn(ctx, lhs(ctx), rhs(ctx))
126                }))
127            });
128
129        try_insert(
130            &mut self.registry.binary_operations,
131            (op, ty),
132            compile_func,
133            || Error::DuplicateBinOp { op, ty },
134        )
135    }
136
137    // access field on type T
138    pub fn register_field_access<Field>(
139        &mut self,
140        field: &'static str,
141        getter: for<'a> fn(&'a T) -> Field::RefType<'a>,
142    ) -> Result<()>
143    where
144        T: for<'a> EvalType<RefType<'a> = &'a T>,
145        Field: EvalType,
146    {
147        self.register_type::<Field>()?;
148
149        let ty = T::type_info();
150
151        let compile_func = Box::new(move |obj: DynFn| -> Result<DynFn> {
152            let obj = obj.downcast::<Ctx, T>()?;
153            Ok(Field::make_dyn_fn(move |ctx| getter(obj(ctx))))
154        });
155
156        try_insert(
157            &mut self.registry.field_access,
158            (ty, field),
159            compile_func,
160            || Error::DuplicateField { ty, field },
161        )
162    }
163}
164
165pub(super) fn try_insert<K, V>(
166    map: &mut HashMap<K, V>,
167    key: K,
168    value: V,
169    make_err: impl FnOnce() -> Error,
170) -> Result<()>
171where
172    K: Hash + Eq,
173{
174    match map.entry(key) {
175        Entry::Occupied(_) => Err(make_err())?,
176        Entry::Vacant(vacant) => {
177            vacant.insert(value);
178        }
179    }
180    Ok(())
181}