// zen_expression/compiler/compiler.rs

1use crate::compiler::error::{CompilerError, CompilerResult};
2use crate::compiler::opcode::{FetchFastTarget, Jump};
3use crate::compiler::{Compare, Opcode};
4use crate::functions::registry::FunctionRegistry;
5use crate::functions::{ClosureFunction, FunctionKind, InternalFunction, MethodRegistry};
6use crate::lexer::{ArithmeticOperator, ComparisonOperator, LogicalOperator, Operator};
7use crate::parser::Node;
8use rust_decimal::prelude::ToPrimitive;
9use rust_decimal::Decimal;
10use rust_decimal_macros::dec;
11use std::sync::Arc;
12
13#[derive(Debug)]
14pub struct Compiler {
15    bytecode: Vec<Opcode>,
16}
17
18impl Compiler {
19    pub fn new() -> Self {
20        Self {
21            bytecode: Default::default(),
22        }
23    }
24
25    pub fn compile(&mut self, root: &Node) -> CompilerResult<&[Opcode]> {
26        self.bytecode.clear();
27
28        CompilerInner::new(&mut self.bytecode, root).compile()?;
29        Ok(self.bytecode.as_slice())
30    }
31
32    pub fn get_bytecode(&self) -> &[Opcode] {
33        self.bytecode.as_slice()
34    }
35}
36
37#[derive(Debug)]
38struct CompilerInner<'arena, 'bytecode_ref> {
39    root: &'arena Node<'arena>,
40    bytecode: &'bytecode_ref mut Vec<Opcode>,
41}
42
43impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> {
44    pub fn new(bytecode: &'bytecode_ref mut Vec<Opcode>, root: &'arena Node<'arena>) -> Self {
45        Self { root, bytecode }
46    }
47
48    pub fn compile(&mut self) -> CompilerResult<()> {
49        self.compile_node(self.root)?;
50        Ok(())
51    }
52
53    fn emit(&mut self, op: Opcode) -> usize {
54        self.bytecode.push(op);
55        self.bytecode.len()
56    }
57
58    fn emit_loop<F>(&mut self, body: F) -> CompilerResult<()>
59    where
60        F: FnOnce(&mut Self) -> CompilerResult<()>,
61    {
62        let begin = self.bytecode.len();
63        let end = self.emit(Opcode::Jump(Jump::IfEnd, 0));
64
65        body(self)?;
66
67        self.emit(Opcode::IncrementIt);
68        let e = self.emit(Opcode::Jump(
69            Jump::Backward,
70            self.calc_backward_jump(begin) as u32,
71        ));
72        self.replace(end, Opcode::Jump(Jump::IfEnd, (e - end) as u32));
73        Ok(())
74    }
75
76    fn emit_cond<F>(&mut self, mut body: F)
77    where
78        F: FnMut(&mut Self),
79    {
80        let noop = self.emit(Opcode::Jump(Jump::IfFalse, 0));
81        self.emit(Opcode::Pop);
82
83        body(self);
84
85        let jmp = self.emit(Opcode::Jump(Jump::Forward, 0));
86        self.replace(noop, Opcode::Jump(Jump::IfFalse, (jmp - noop) as u32));
87        let e = self.emit(Opcode::Pop);
88        self.replace(jmp, Opcode::Jump(Jump::Forward, (e - jmp) as u32));
89    }
90
91    fn replace(&mut self, at: usize, op: Opcode) {
92        let _ = std::mem::replace(&mut self.bytecode[at - 1], op);
93    }
94
95    fn calc_backward_jump(&self, to: usize) -> usize {
96        self.bytecode.len() + 1 - to
97    }
98
99    fn compile_argument<T: ToString>(
100        &mut self,
101        function_kind: T,
102        arguments: &[&'arena Node<'arena>],
103        index: usize,
104    ) -> CompilerResult<usize> {
105        let arg = arguments
106            .get(index)
107            .ok_or_else(|| CompilerError::ArgumentNotFound {
108                index,
109                function: function_kind.to_string(),
110            })?;
111
112        self.compile_node(arg)
113    }
114
115    #[cfg_attr(feature = "stack-protection", recursive::recursive)]
116    fn compile_member_fast(&mut self, node: &'arena Node<'arena>) -> Option<Vec<FetchFastTarget>> {
117        match node {
118            Node::Root => Some(vec![FetchFastTarget::Root]),
119            Node::Identifier(v) => Some(vec![
120                FetchFastTarget::Root,
121                FetchFastTarget::String(Arc::from(*v)),
122            ]),
123            Node::Member { node, property } => {
124                let mut path = self.compile_member_fast(node)?;
125                match property {
126                    Node::String(v) => {
127                        path.push(FetchFastTarget::String(Arc::from(*v)));
128                        Some(path)
129                    }
130                    Node::Number(v) => {
131                        if let Some(idx) = v.to_u32() {
132                            path.push(FetchFastTarget::Number(idx));
133                            Some(path)
134                        } else {
135                            None
136                        }
137                    }
138                    _ => None,
139                }
140            }
141            _ => None,
142        }
143    }
144
145    #[cfg_attr(feature = "stack-protection", recursive::recursive)]
146    fn compile_node(&mut self, node: &'arena Node<'arena>) -> CompilerResult<usize> {
147        match node {
148            Node::Null => Ok(self.emit(Opcode::PushNull)),
149            Node::Bool(v) => Ok(self.emit(Opcode::PushBool(*v))),
150            Node::Number(v) => Ok(self.emit(Opcode::PushNumber(*v))),
151            Node::String(v) => Ok(self.emit(Opcode::PushString(Arc::from(*v)))),
152            Node::Pointer => Ok(self.emit(Opcode::Pointer)),
153            Node::Root => Ok(self.emit(Opcode::FetchRootEnv)),
154            Node::Array(v) => {
155                v.iter()
156                    .try_for_each(|&n| self.compile_node(n).map(|_| ()))?;
157                self.emit(Opcode::PushNumber(Decimal::from(v.len())));
158                Ok(self.emit(Opcode::Array))
159            }
160            Node::Object(v) => {
161                v.iter().try_for_each(|&(key, value)| {
162                    self.compile_node(key).map(|_| ())?;
163                    self.emit(Opcode::CallFunction {
164                        arg_count: 1,
165                        kind: FunctionKind::Internal(InternalFunction::String),
166                    });
167                    self.compile_node(value).map(|_| ())?;
168                    Ok(())
169                })?;
170
171                self.emit(Opcode::PushNumber(Decimal::from(v.len())));
172                Ok(self.emit(Opcode::Object))
173            }
174            Node::Identifier(v) => Ok(self.emit(Opcode::FetchEnv(Arc::from(*v)))),
175            Node::Closure(v) => self.compile_node(v),
176            Node::Parenthesized(v) => self.compile_node(v),
177            Node::Member {
178                node: n,
179                property: p,
180            } => {
181                if let Some(path) = self.compile_member_fast(node) {
182                    Ok(self.emit(Opcode::FetchFast(path)))
183                } else {
184                    self.compile_node(n)?;
185                    self.compile_node(p)?;
186                    Ok(self.emit(Opcode::Fetch))
187                }
188            }
189            Node::TemplateString(parts) => {
190                parts.iter().try_for_each(|&n| {
191                    self.compile_node(n).map(|_| ())?;
192                    self.emit(Opcode::CallFunction {
193                        arg_count: 1,
194                        kind: FunctionKind::Internal(InternalFunction::String),
195                    });
196                    Ok(())
197                })?;
198
199                self.emit(Opcode::PushNumber(Decimal::from(parts.len())));
200                self.emit(Opcode::Array);
201                self.emit(Opcode::PushString(Arc::from("")));
202                Ok(self.emit(Opcode::Join))
203            }
204            Node::Slice { node, to, from } => {
205                self.compile_node(node)?;
206                if let Some(t) = to {
207                    self.compile_node(t)?;
208                } else {
209                    self.emit(Opcode::Len);
210                    self.emit(Opcode::PushNumber(dec!(1)));
211                    self.emit(Opcode::Subtract);
212                }
213
214                if let Some(f) = from {
215                    self.compile_node(f)?;
216                } else {
217                    self.emit(Opcode::PushNumber(dec!(0)));
218                }
219
220                Ok(self.emit(Opcode::Slice))
221            }
222            Node::Interval {
223                left,
224                right,
225                left_bracket,
226                right_bracket,
227            } => {
228                self.compile_node(left)?;
229                self.compile_node(right)?;
230                Ok(self.emit(Opcode::Interval {
231                    left_bracket: *left_bracket,
232                    right_bracket: *right_bracket,
233                }))
234            }
235            Node::Conditional {
236                condition,
237                on_true,
238                on_false,
239            } => {
240                self.compile_node(condition)?;
241                let otherwise = self.emit(Opcode::Jump(Jump::IfFalse, 0));
242
243                self.emit(Opcode::Pop);
244                self.compile_node(on_true)?;
245                let end = self.emit(Opcode::Jump(Jump::Forward, 0));
246
247                self.replace(
248                    otherwise,
249                    Opcode::Jump(Jump::IfFalse, (end - otherwise) as u32),
250                );
251                self.emit(Opcode::Pop);
252                let b = self.compile_node(on_false)?;
253                self.replace(end, Opcode::Jump(Jump::Forward, (b - end) as u32));
254
255                Ok(b)
256            }
257            Node::Unary { node, operator } => {
258                let curr = self.compile_node(node)?;
259                match *operator {
260                    Operator::Arithmetic(ArithmeticOperator::Add) => Ok(curr),
261                    Operator::Arithmetic(ArithmeticOperator::Subtract) => {
262                        Ok(self.emit(Opcode::Negate))
263                    }
264                    Operator::Logical(LogicalOperator::Not) => Ok(self.emit(Opcode::Not)),
265                    _ => Err(CompilerError::UnknownUnaryOperator {
266                        operator: operator.to_string(),
267                    }),
268                }
269            }
270            Node::Binary {
271                left,
272                right,
273                operator,
274            } => match *operator {
275                Operator::Comparison(ComparisonOperator::Equal) => {
276                    self.compile_node(left)?;
277                    self.compile_node(right)?;
278
279                    Ok(self.emit(Opcode::Equal))
280                }
281                Operator::Comparison(ComparisonOperator::NotEqual) => {
282                    self.compile_node(left)?;
283                    self.compile_node(right)?;
284
285                    self.emit(Opcode::Equal);
286                    Ok(self.emit(Opcode::Not))
287                }
288                Operator::Logical(LogicalOperator::Or) => {
289                    self.compile_node(left)?;
290                    let end = self.emit(Opcode::Jump(Jump::IfTrue, 0));
291                    self.emit(Opcode::Pop);
292                    let r = self.compile_node(right)?;
293                    self.replace(end, Opcode::Jump(Jump::IfTrue, (r - end) as u32));
294
295                    Ok(r)
296                }
297                Operator::Logical(LogicalOperator::And) => {
298                    self.compile_node(left)?;
299                    let end = self.emit(Opcode::Jump(Jump::IfFalse, 0));
300                    self.emit(Opcode::Pop);
301                    let r = self.compile_node(right)?;
302                    self.replace(end, Opcode::Jump(Jump::IfFalse, (r - end) as u32));
303
304                    Ok(r)
305                }
306                Operator::Logical(LogicalOperator::NullishCoalescing) => {
307                    self.compile_node(left)?;
308                    let end = self.emit(Opcode::Jump(Jump::IfNotNull, 0));
309                    self.emit(Opcode::Pop);
310                    let r = self.compile_node(right)?;
311                    self.replace(end, Opcode::Jump(Jump::IfNotNull, (r - end) as u32));
312
313                    Ok(r)
314                }
315                Operator::Comparison(ComparisonOperator::In) => {
316                    self.compile_node(left)?;
317                    self.compile_node(right)?;
318                    Ok(self.emit(Opcode::In))
319                }
320                Operator::Comparison(ComparisonOperator::NotIn) => {
321                    self.compile_node(left)?;
322                    self.compile_node(right)?;
323                    self.emit(Opcode::In);
324                    Ok(self.emit(Opcode::Not))
325                }
326                Operator::Comparison(ComparisonOperator::LessThan) => {
327                    self.compile_node(left)?;
328                    self.compile_node(right)?;
329                    Ok(self.emit(Opcode::Compare(Compare::Less)))
330                }
331                Operator::Comparison(ComparisonOperator::LessThanOrEqual) => {
332                    self.compile_node(left)?;
333                    self.compile_node(right)?;
334                    Ok(self.emit(Opcode::Compare(Compare::LessOrEqual)))
335                }
336                Operator::Comparison(ComparisonOperator::GreaterThan) => {
337                    self.compile_node(left)?;
338                    self.compile_node(right)?;
339                    Ok(self.emit(Opcode::Compare(Compare::More)))
340                }
341                Operator::Comparison(ComparisonOperator::GreaterThanOrEqual) => {
342                    self.compile_node(left)?;
343                    self.compile_node(right)?;
344                    Ok(self.emit(Opcode::Compare(Compare::MoreOrEqual)))
345                }
346                Operator::Arithmetic(ArithmeticOperator::Add) => {
347                    self.compile_node(left)?;
348                    self.compile_node(right)?;
349                    Ok(self.emit(Opcode::Add))
350                }
351                Operator::Arithmetic(ArithmeticOperator::Subtract) => {
352                    self.compile_node(left)?;
353                    self.compile_node(right)?;
354                    Ok(self.emit(Opcode::Subtract))
355                }
356                Operator::Arithmetic(ArithmeticOperator::Multiply) => {
357                    self.compile_node(left)?;
358                    self.compile_node(right)?;
359                    Ok(self.emit(Opcode::Multiply))
360                }
361                Operator::Arithmetic(ArithmeticOperator::Divide) => {
362                    self.compile_node(left)?;
363                    self.compile_node(right)?;
364                    Ok(self.emit(Opcode::Divide))
365                }
366                Operator::Arithmetic(ArithmeticOperator::Modulus) => {
367                    self.compile_node(left)?;
368                    self.compile_node(right)?;
369                    Ok(self.emit(Opcode::Modulo))
370                }
371                Operator::Arithmetic(ArithmeticOperator::Power) => {
372                    self.compile_node(left)?;
373                    self.compile_node(right)?;
374                    Ok(self.emit(Opcode::Exponent))
375                }
376                _ => Err(CompilerError::UnknownBinaryOperator {
377                    operator: operator.to_string(),
378                }),
379            },
380            Node::FunctionCall { kind, arguments } => match kind {
381                FunctionKind::Internal(_) | FunctionKind::Deprecated(_) => {
382                    let function = FunctionRegistry::get_definition(kind).ok_or_else(|| {
383                        CompilerError::UnknownFunction {
384                            name: kind.to_string(),
385                        }
386                    })?;
387
388                    let min_params = function.required_parameters();
389                    let max_params = min_params + function.optional_parameters();
390                    if arguments.len() < min_params || arguments.len() > max_params {
391                        return Err(CompilerError::InvalidFunctionCall {
392                            name: kind.to_string(),
393                            message: "Invalid number of arguments".to_string(),
394                        });
395                    }
396
397                    for i in 0..arguments.len() {
398                        self.compile_argument(kind, arguments, i)?;
399                    }
400
401                    Ok(self.emit(Opcode::CallFunction {
402                        kind: kind.clone(),
403                        arg_count: arguments.len() as u32,
404                    }))
405                }
406                FunctionKind::Closure(c) => match c {
407                    ClosureFunction::All => {
408                        self.compile_argument(kind, arguments, 0)?;
409                        self.emit(Opcode::Begin);
410                        let mut loop_break: usize = 0;
411                        self.emit_loop(|c| {
412                            c.compile_argument(kind, arguments, 1)?;
413                            loop_break = c.emit(Opcode::Jump(Jump::IfFalse, 0));
414                            c.emit(Opcode::Pop);
415                            Ok(())
416                        })?;
417                        let e = self.emit(Opcode::PushBool(true));
418                        self.replace(
419                            loop_break,
420                            Opcode::Jump(Jump::IfFalse, (e - loop_break) as u32),
421                        );
422                        Ok(self.emit(Opcode::End))
423                    }
424                    ClosureFunction::None => {
425                        self.compile_argument(kind, arguments, 0)?;
426                        self.emit(Opcode::Begin);
427                        let mut loop_break: usize = 0;
428                        self.emit_loop(|c| {
429                            c.compile_argument(kind, arguments, 1)?;
430                            c.emit(Opcode::Not);
431                            loop_break = c.emit(Opcode::Jump(Jump::IfFalse, 0));
432                            c.emit(Opcode::Pop);
433                            Ok(())
434                        })?;
435                        let e = self.emit(Opcode::PushBool(true));
436                        self.replace(
437                            loop_break,
438                            Opcode::Jump(Jump::IfFalse, (e - loop_break) as u32),
439                        );
440                        Ok(self.emit(Opcode::End))
441                    }
442                    ClosureFunction::Some => {
443                        self.compile_argument(kind, arguments, 0)?;
444                        self.emit(Opcode::Begin);
445                        let mut loop_break: usize = 0;
446                        self.emit_loop(|c| {
447                            c.compile_argument(kind, arguments, 1)?;
448                            loop_break = c.emit(Opcode::Jump(Jump::IfTrue, 0));
449                            c.emit(Opcode::Pop);
450                            Ok(())
451                        })?;
452                        let e = self.emit(Opcode::PushBool(false));
453                        self.replace(
454                            loop_break,
455                            Opcode::Jump(Jump::IfTrue, (e - loop_break) as u32),
456                        );
457                        Ok(self.emit(Opcode::End))
458                    }
459                    ClosureFunction::One => {
460                        self.compile_argument(kind, arguments, 0)?;
461                        self.emit(Opcode::Begin);
462                        self.emit_loop(|c| {
463                            c.compile_argument(kind, arguments, 1)?;
464                            c.emit_cond(|c| {
465                                c.emit(Opcode::IncrementCount);
466                            });
467                            Ok(())
468                        })?;
469                        self.emit(Opcode::GetCount);
470                        self.emit(Opcode::PushNumber(dec!(1)));
471                        self.emit(Opcode::Equal);
472                        Ok(self.emit(Opcode::End))
473                    }
474                    ClosureFunction::Filter => {
475                        self.compile_argument(kind, arguments, 0)?;
476                        self.emit(Opcode::Begin);
477                        self.emit_loop(|c| {
478                            c.compile_argument(kind, arguments, 1)?;
479                            c.emit_cond(|c| {
480                                c.emit(Opcode::IncrementCount);
481                                c.emit(Opcode::Pointer);
482                            });
483                            Ok(())
484                        })?;
485                        self.emit(Opcode::GetCount);
486                        self.emit(Opcode::End);
487                        Ok(self.emit(Opcode::Array))
488                    }
489                    ClosureFunction::Map => {
490                        self.compile_argument(kind, arguments, 0)?;
491                        self.emit(Opcode::Begin);
492                        self.emit_loop(|c| {
493                            c.compile_argument(kind, arguments, 1)?;
494                            Ok(())
495                        })?;
496                        self.emit(Opcode::GetLen);
497                        self.emit(Opcode::End);
498                        Ok(self.emit(Opcode::Array))
499                    }
500                    ClosureFunction::FlatMap => {
501                        self.compile_argument(kind, arguments, 0)?;
502                        self.emit(Opcode::Begin);
503                        self.emit_loop(|c| {
504                            c.compile_argument(kind, arguments, 1)?;
505                            Ok(())
506                        })?;
507                        self.emit(Opcode::GetLen);
508                        self.emit(Opcode::End);
509                        self.emit(Opcode::Array);
510                        Ok(self.emit(Opcode::Flatten))
511                    }
512                    ClosureFunction::Count => {
513                        self.compile_argument(kind, arguments, 0)?;
514                        self.emit(Opcode::Begin);
515                        self.emit_loop(|c| {
516                            c.compile_argument(kind, arguments, 1)?;
517                            c.emit_cond(|c| {
518                                c.emit(Opcode::IncrementCount);
519                            });
520                            Ok(())
521                        })?;
522                        self.emit(Opcode::GetCount);
523                        Ok(self.emit(Opcode::End))
524                    }
525                },
526            },
527            Node::MethodCall {
528                kind,
529                this,
530                arguments,
531            } => {
532                let method = MethodRegistry::get_definition(kind).ok_or_else(|| {
533                    CompilerError::UnknownFunction {
534                        name: kind.to_string(),
535                    }
536                })?;
537
538                self.compile_node(this)?;
539
540                let min_params = method.required_parameters() - 1;
541                let max_params = min_params + method.optional_parameters();
542                if arguments.len() < min_params || arguments.len() > max_params {
543                    return Err(CompilerError::InvalidMethodCall {
544                        name: kind.to_string(),
545                        message: "Invalid number of arguments".to_string(),
546                    });
547                }
548
549                for i in 0..arguments.len() {
550                    self.compile_argument(kind, arguments, i)?;
551                }
552
553                Ok(self.emit(Opcode::CallMethod {
554                    kind: kind.clone(),
555                    arg_count: arguments.len() as u32,
556                }))
557            }
558            Node::Error { .. } => Err(CompilerError::UnexpectedErrorNode),
559        }
560    }
561}