1use crate::{
2 BoxedFn, CombineResults, CompilerRegistry, DynFn, Error, Errors, EvalType,
3 Expr, MethodCallData, Result, TypeInfo, parse_expr,
4};
5use std::marker::PhantomData;
6
7pub struct Compiler<Ctx> {
8 registry: CompilerRegistry,
9 ctx_ty: PhantomData<Ctx>,
10}
11
12impl<Ctx: EvalType> Compiler<Ctx> {
13 pub fn new() -> Result<Self> {
14 let mut registry = CompilerRegistry::default();
15
16 registry.register_type::<Ctx, i64>()?;
18 registry.register_type::<Ctx, f64>()?;
19 registry.register_type::<Ctx, String>()?;
20
21 registry.register_type::<Ctx, Ctx>()?;
23
24 Ok(Self {
25 registry,
26 ctx_ty: PhantomData,
27 })
28 }
29
30 pub fn compile<Ret: EvalType>(
31 &self,
32 input: &str,
33 ) -> Result<BoxedFn<Ctx, Ret>> {
34 let dyn_fn = self.compile_dyn(input)?;
35 let casted_fn = self.cast(dyn_fn, Ret::type_info())?;
36 casted_fn.downcast::<Ctx, Ret>()
37 }
38
39 pub fn compile_dyn(&self, input: &str) -> Result<DynFn> {
40 let (expr, mut errors) = parse_expr(input);
41
42 let Some(expr) = expr else {
43 return Err(errors);
44 };
45
46 let dyn_fn = match self.compile_expr(&expr) {
47 Ok(dyn_fn) => Some(dyn_fn),
48 Err(compile_errors) => {
49 errors.append(compile_errors);
50 None
51 }
52 };
53
54 if dyn_fn.is_none() && errors.is_empty() {
55 errors.append(Error::UnknownError);
56 }
57
58 if !errors.is_empty() {
59 return Err(errors);
60 }
61
62 Ok(dyn_fn.ok_or(Error::UnknownError)?)
63 }
64
65 fn compile_expr(&self, expr: &Expr) -> Result<DynFn> {
66 match expr {
67 Expr::Int(val) => {
68 let val = **val;
69 Ok(DynFn::new::<_, i64>(move |_ctx: &Ctx| val))
70 }
71
72 Expr::Float(val) => {
73 let val = **val;
74 Ok(DynFn::new::<_, f64>(move |_ctx: &Ctx| val))
75 }
76
77 Expr::String(s) => {
78 let s = s.clone();
79 Ok(DynFn::new::<_, String>(move |_ctx: &Ctx| {
80 (*s).clone().into()
81 }))
82 }
83
84 Expr::Var(var_name) => self.compile_variable(var_name),
85
86 Expr::UnOp(op, rhs) => self.compile_unary_op(**op, rhs),
87
88 Expr::BinOp(op, lhs, rhs) => self.compile_binary_op(**op, lhs, rhs),
89
90 Expr::FieldAccess(obj, field_name) => {
91 let obj_fn = self.compile_expr(obj)?;
92 self.compile_field_access(obj_fn, field_name)
93 }
94
95 Expr::FuncCall(func, args) => {
96 self.compile_function_call(func, args)
97 }
98
99 Expr::InvalidLiteral(err) => {
100 Err(Error::InvalidLiteral((**err).clone()))?
101 }
102
103 Expr::ParseError => Err(Errors::empty())?,
104 }
105 }
106
107 fn cast(&self, expr: DynFn, ty: TypeInfo) -> Result<DynFn> {
110 if expr.ret_type == ty {
111 return Ok(expr);
112 }
113 let key = (expr.ret_type, ty);
114 let Some(compile_cast) = self.registry.casts.get(&key) else {
115 Err(Error::CantCast {
116 from: expr.ret_type,
117 to: ty,
118 })?
119 };
120 compile_cast(expr)
121 }
122
123 fn cast_same_type(&self, a: DynFn, b: DynFn) -> Result<(DynFn, DynFn)> {
125 if a.ret_type == b.ret_type {
126 return Ok((a, b));
127 }
128
129 if let Ok(b_casted) = self.cast(b.clone(), a.ret_type) {
130 return Ok((a, b_casted));
131 }
132
133 if let Ok(a_casted) = self.cast(a.clone(), b.ret_type) {
134 return Ok((a_casted, b));
135 }
136
137 Err(Error::CantCastSameType(a.ret_type, b.ret_type))?
138 }
139
140 fn compile_field_access(
141 &self,
142 object: DynFn,
143 field: &str,
144 ) -> Result<DynFn> {
145 let Some(compile_fn) =
146 self.registry.field_access.get(&(object.ret_type, field))
147 else {
148 Err(Error::FieldNotFound {
149 ty: object.ret_type,
150 field: field.into(),
151 })?
152 };
153 compile_fn(object)
154 }
155
156 fn compile_method_call(
157 &self,
158 object: DynFn,
159 method: &str,
160 arguments: Vec<DynFn>,
161 ) -> Result<DynFn> {
162 let Some(MethodCallData {
163 compile_fn,
164 arg_types,
165 }) = self.registry.method_calls.get(&(object.ret_type, method))
166 else {
167 Err(Error::MethodNotFound {
168 ty: object.ret_type,
169 method: method.into(),
170 })?
171 };
172
173 let (_, arguments) = (
175 (arg_types.len() == arguments.len()).then_some(()).ok_or(
176 Error::ArgCountMismatch {
177 expected: arg_types.len(),
178 got: arguments.len(),
179 },
180 ),
181 arguments
182 .into_iter()
183 .zip(arg_types.iter().copied())
184 .map(|(arg, ty)| self.cast(arg, ty))
185 .collect::<Vec<_>>()
186 .all_ok(),
187 )
188 .all_ok()?;
189
190 compile_fn(object, arguments)
191 }
192
193 fn compile_unary_op(&self, op: crate::UnOp, rhs: &Expr) -> Result<DynFn> {
194 let rhs_fn = self.compile_expr(rhs)?;
195
196 let ty = rhs_fn.ret_type;
197
198 let Some(compile_fn) = self.registry.unary_operations.get(&(op, ty))
199 else {
200 Err(Error::UnknownUnaryOp { op, ty })?
201 };
202
203 compile_fn(rhs_fn)
204 }
205
206 fn compile_binary_op(
207 &self,
208 op: crate::BinOp,
209 lhs: &Expr,
210 rhs: &Expr,
211 ) -> Result<DynFn> {
212 let (lhs_fn, rhs_fn) =
213 (self.compile_expr(lhs), self.compile_expr(rhs)).all_ok()?;
214
215 let (lhs_fn, rhs_fn) = self.cast_same_type(lhs_fn, rhs_fn)?;
216 let ty = lhs_fn.ret_type;
217
218 let Some(compile_fn) = self.registry.binary_operations.get(&(op, ty))
219 else {
220 Err(Error::UnknownBinaryOp { op, ty })?
221 };
222
223 compile_fn(lhs_fn, rhs_fn)
224 }
225
226 fn compile_variable(&self, var_name: &str) -> Result<DynFn> {
227 let ctx_fn = Ctx::make_dyn_fn(|ctx: &Ctx| Ctx::to_ref_type(ctx));
228 self.compile_field_access(ctx_fn, var_name)
229 }
230
231 fn compile_function_call(
232 &self,
233 function: &Expr,
234 arguments: &[Expr],
235 ) -> Result<DynFn> {
236 let args_fns = arguments
237 .iter()
238 .map(|arg| self.compile_expr(arg))
239 .collect::<Vec<_>>()
240 .all_ok()?;
241
242 let ctx_fn = Ctx::make_dyn_fn(|ctx: &Ctx| Ctx::to_ref_type(ctx));
243
244 match function {
245 Expr::Var(var_name) => {
246 self.compile_method_call(ctx_fn, var_name, args_fns)
247 }
248 Expr::FieldAccess(obj, field_name) => {
249 let obj_fn = self.compile_expr(obj)?;
250 self.compile_method_call(obj_fn, field_name, args_fns)
251 }
252 _ => Err(Error::UnsupportedFunctionCall)?,
253 }
254 }
255}