// typed_eval/compiler/compiler.rs

1use crate::{
2    BoxedFn, CombineResults, CompilerRegistry, DynFn, Error, Errors, EvalType,
3    Expr, MethodCallData, Result, TypeInfo, parse_expr,
4};
5use std::marker::PhantomData;
6
7pub struct Compiler<Ctx> {
8    registry: CompilerRegistry,
9    ctx_ty: PhantomData<Ctx>,
10}
11
12impl<Ctx: EvalType> Compiler<Ctx> {
13    pub fn new() -> Result<Self> {
14        let mut registry = CompilerRegistry::default();
15
16        // Register literal types
17        registry.register_type::<Ctx, i64>()?;
18        registry.register_type::<Ctx, f64>()?;
19        registry.register_type::<Ctx, String>()?;
20
21        // Register context type
22        registry.register_type::<Ctx, Ctx>()?;
23
24        Ok(Self {
25            registry,
26            ctx_ty: PhantomData,
27        })
28    }
29
30    pub fn compile<Ret: EvalType>(
31        &self,
32        input: &str,
33    ) -> Result<BoxedFn<Ctx, Ret>> {
34        let dyn_fn = self.compile_dyn(input)?;
35        let casted_fn = self.cast(dyn_fn, Ret::type_info())?;
36        casted_fn.downcast::<Ctx, Ret>()
37    }
38
39    pub fn compile_dyn(&self, input: &str) -> Result<DynFn> {
40        let (expr, mut errors) = parse_expr(input);
41
42        let Some(expr) = expr else {
43            return Err(errors);
44        };
45
46        let dyn_fn = match self.compile_expr(&expr) {
47            Ok(dyn_fn) => Some(dyn_fn),
48            Err(compile_errors) => {
49                errors.append(compile_errors);
50                None
51            }
52        };
53
54        if dyn_fn.is_none() && errors.is_empty() {
55            errors.append(Error::UnknownError);
56        }
57
58        if !errors.is_empty() {
59            return Err(errors);
60        }
61
62        Ok(dyn_fn.ok_or(Error::UnknownError)?)
63    }
64
65    fn compile_expr(&self, expr: &Expr) -> Result<DynFn> {
66        match expr {
67            Expr::Int(val) => {
68                let val = **val;
69                Ok(DynFn::new::<_, i64>(move |_ctx: &Ctx| val))
70            }
71
72            Expr::Float(val) => {
73                let val = **val;
74                Ok(DynFn::new::<_, f64>(move |_ctx: &Ctx| val))
75            }
76
77            Expr::String(s) => {
78                let s = s.clone();
79                Ok(DynFn::new::<_, String>(move |_ctx: &Ctx| {
80                    (*s).clone().into()
81                }))
82            }
83
84            Expr::Var(var_name) => self.compile_variable(var_name),
85
86            Expr::UnOp(op, rhs) => self.compile_unary_op(**op, rhs),
87
88            Expr::BinOp(op, lhs, rhs) => self.compile_binary_op(**op, lhs, rhs),
89
90            Expr::FieldAccess(obj, field_name) => {
91                let obj_fn = self.compile_expr(obj)?;
92                self.compile_field_access(obj_fn, field_name)
93            }
94
95            Expr::FuncCall(func, args) => {
96                self.compile_function_call(func, args)
97            }
98
99            Expr::InvalidLiteral(err) => {
100                Err(Error::InvalidLiteral((**err).clone()))?
101            }
102
103            Expr::ParseError => Err(Errors::empty())?,
104        }
105    }
106
107    /// try to cast expression to given type
108    /// on success returned DynFn will have ret_type matching ty
109    fn cast(&self, expr: DynFn, ty: TypeInfo) -> Result<DynFn> {
110        if expr.ret_type == ty {
111            return Ok(expr);
112        }
113        let key = (expr.ret_type, ty);
114        let Some(compile_cast) = self.registry.casts.get(&key) else {
115            Err(Error::CantCast {
116                from: expr.ret_type,
117                to: ty,
118            })?
119        };
120        compile_cast(expr)
121    }
122
123    /// try to cast two expressions so that they have the same ret_type
124    fn cast_same_type(&self, a: DynFn, b: DynFn) -> Result<(DynFn, DynFn)> {
125        if a.ret_type == b.ret_type {
126            return Ok((a, b));
127        }
128
129        if let Ok(b_casted) = self.cast(b.clone(), a.ret_type) {
130            return Ok((a, b_casted));
131        }
132
133        if let Ok(a_casted) = self.cast(a.clone(), b.ret_type) {
134            return Ok((a_casted, b));
135        }
136
137        Err(Error::CantCastSameType(a.ret_type, b.ret_type))?
138    }
139
140    fn compile_field_access(
141        &self,
142        object: DynFn,
143        field: &str,
144    ) -> Result<DynFn> {
145        let Some(compile_fn) =
146            self.registry.field_access.get(&(object.ret_type, field))
147        else {
148            Err(Error::FieldNotFound {
149                ty: object.ret_type,
150                field: field.into(),
151            })?
152        };
153        compile_fn(object)
154    }
155
156    fn compile_method_call(
157        &self,
158        object: DynFn,
159        method: &str,
160        arguments: Vec<DynFn>,
161    ) -> Result<DynFn> {
162        let Some(MethodCallData {
163            compile_fn,
164            arg_types,
165        }) = self.registry.method_calls.get(&(object.ret_type, method))
166        else {
167            Err(Error::MethodNotFound {
168                ty: object.ret_type,
169                method: method.into(),
170            })?
171        };
172
173        // cast arguments to arg_types
174        let (_, arguments) = (
175            (arg_types.len() == arguments.len()).then_some(()).ok_or(
176                Error::ArgCountMismatch {
177                    expected: arg_types.len(),
178                    got: arguments.len(),
179                },
180            ),
181            arguments
182                .into_iter()
183                .zip(arg_types.iter().copied())
184                .map(|(arg, ty)| self.cast(arg, ty))
185                .collect::<Vec<_>>()
186                .all_ok(),
187        )
188            .all_ok()?;
189
190        compile_fn(object, arguments)
191    }
192
193    fn compile_unary_op(&self, op: crate::UnOp, rhs: &Expr) -> Result<DynFn> {
194        let rhs_fn = self.compile_expr(rhs)?;
195
196        let ty = rhs_fn.ret_type;
197
198        let Some(compile_fn) = self.registry.unary_operations.get(&(op, ty))
199        else {
200            Err(Error::UnknownUnaryOp { op, ty })?
201        };
202
203        compile_fn(rhs_fn)
204    }
205
206    fn compile_binary_op(
207        &self,
208        op: crate::BinOp,
209        lhs: &Expr,
210        rhs: &Expr,
211    ) -> Result<DynFn> {
212        let (lhs_fn, rhs_fn) =
213            (self.compile_expr(lhs), self.compile_expr(rhs)).all_ok()?;
214
215        let (lhs_fn, rhs_fn) = self.cast_same_type(lhs_fn, rhs_fn)?;
216        let ty = lhs_fn.ret_type;
217
218        let Some(compile_fn) = self.registry.binary_operations.get(&(op, ty))
219        else {
220            Err(Error::UnknownBinaryOp { op, ty })?
221        };
222
223        compile_fn(lhs_fn, rhs_fn)
224    }
225
226    fn compile_variable(&self, var_name: &str) -> Result<DynFn> {
227        let ctx_fn = Ctx::make_dyn_fn(|ctx: &Ctx| Ctx::to_ref_type(ctx));
228        self.compile_field_access(ctx_fn, var_name)
229    }
230
231    fn compile_function_call(
232        &self,
233        function: &Expr,
234        arguments: &[Expr],
235    ) -> Result<DynFn> {
236        let args_fns = arguments
237            .iter()
238            .map(|arg| self.compile_expr(arg))
239            .collect::<Vec<_>>()
240            .all_ok()?;
241
242        let ctx_fn = Ctx::make_dyn_fn(|ctx: &Ctx| Ctx::to_ref_type(ctx));
243
244        match function {
245            Expr::Var(var_name) => {
246                self.compile_method_call(ctx_fn, var_name, args_fns)
247            }
248            Expr::FieldAccess(obj, field_name) => {
249                let obj_fn = self.compile_expr(obj)?;
250                self.compile_method_call(obj_fn, field_name, args_fns)
251            }
252            _ => Err(Error::UnsupportedFunctionCall)?,
253        }
254    }
255}