// zen_expression/compiler/compiler.rs

1use crate::compiler::error::{CompilerError, CompilerResult};
2use crate::compiler::opcode::{FetchFastTarget, Jump};
3use crate::compiler::{Compare, Opcode};
4use crate::functions::registry::FunctionRegistry;
5use crate::functions::{ClosureFunction, FunctionKind, InternalFunction, MethodRegistry};
6use crate::lexer::{ArithmeticOperator, ComparisonOperator, LogicalOperator, Operator};
7use crate::parser::Node;
8use rust_decimal::prelude::ToPrimitive;
9use rust_decimal::Decimal;
10use rust_decimal_macros::dec;
11use std::sync::Arc;
12
13#[derive(Debug)]
14pub struct Compiler {
15    bytecode: Vec<Opcode>,
16}
17
18impl Compiler {
19    pub fn new() -> Self {
20        Self {
21            bytecode: Default::default(),
22        }
23    }
24
25    pub fn compile(&mut self, root: &Node) -> CompilerResult<&[Opcode]> {
26        self.bytecode.clear();
27
28        CompilerInner::new(&mut self.bytecode, root).compile()?;
29        Ok(self.bytecode.as_slice())
30    }
31
32    pub fn get_bytecode(&self) -> &[Opcode] {
33        self.bytecode.as_slice()
34    }
35}
36
37#[derive(Debug)]
38struct CompilerInner<'arena, 'bytecode_ref> {
39    root: &'arena Node<'arena>,
40    bytecode: &'bytecode_ref mut Vec<Opcode>,
41}
42
43impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> {
44    pub fn new(bytecode: &'bytecode_ref mut Vec<Opcode>, root: &'arena Node<'arena>) -> Self {
45        Self { root, bytecode }
46    }
47
48    pub fn compile(&mut self) -> CompilerResult<()> {
49        self.compile_node(self.root)?;
50        Ok(())
51    }
52
53    fn emit(&mut self, op: Opcode) -> usize {
54        self.bytecode.push(op);
55        self.bytecode.len()
56    }
57
58    fn emit_loop<F>(&mut self, body: F) -> CompilerResult<()>
59    where
60        F: FnOnce(&mut Self) -> CompilerResult<()>,
61    {
62        let begin = self.bytecode.len();
63        let end = self.emit(Opcode::Jump(Jump::IfEnd, 0));
64
65        body(self)?;
66
67        self.emit(Opcode::IncrementIt);
68        let e = self.emit(Opcode::Jump(
69            Jump::Backward,
70            self.calc_backward_jump(begin) as u32,
71        ));
72        self.replace(end, Opcode::Jump(Jump::IfEnd, (e - end) as u32));
73        Ok(())
74    }
75
76    fn emit_cond<F>(&mut self, mut body: F)
77    where
78        F: FnMut(&mut Self),
79    {
80        let noop = self.emit(Opcode::Jump(Jump::IfFalse, 0));
81        self.emit(Opcode::Pop);
82
83        body(self);
84
85        let jmp = self.emit(Opcode::Jump(Jump::Forward, 0));
86        self.replace(noop, Opcode::Jump(Jump::IfFalse, (jmp - noop) as u32));
87        let e = self.emit(Opcode::Pop);
88        self.replace(jmp, Opcode::Jump(Jump::Forward, (e - jmp) as u32));
89    }
90
91    fn replace(&mut self, at: usize, op: Opcode) {
92        let _ = std::mem::replace(&mut self.bytecode[at - 1], op);
93    }
94
95    fn calc_backward_jump(&self, to: usize) -> usize {
96        self.bytecode.len() + 1 - to
97    }
98
99    fn compile_argument<T: ToString>(
100        &mut self,
101        function_kind: T,
102        arguments: &[&'arena Node<'arena>],
103        index: usize,
104    ) -> CompilerResult<usize> {
105        let arg = arguments
106            .get(index)
107            .ok_or_else(|| CompilerError::ArgumentNotFound {
108                index,
109                function: function_kind.to_string(),
110            })?;
111
112        self.compile_node(arg)
113    }
114
115    #[cfg_attr(feature = "stack-protection", recursive::recursive)]
116    fn compile_member_fast(&mut self, node: &'arena Node<'arena>) -> Option<Vec<FetchFastTarget>> {
117        match node {
118            Node::Root => Some(vec![FetchFastTarget::Root]),
119            Node::Identifier(v) => Some(vec![
120                FetchFastTarget::Begin,
121                FetchFastTarget::String(Arc::from(*v)),
122            ]),
123            Node::Member { node, property } => {
124                let mut path = self.compile_member_fast(node)?;
125                match property {
126                    Node::String(v) => {
127                        path.push(FetchFastTarget::String(Arc::from(*v)));
128                        Some(path)
129                    }
130                    Node::Number(v) => {
131                        if let Some(idx) = v.to_u32() {
132                            path.push(FetchFastTarget::Number(idx));
133                            Some(path)
134                        } else {
135                            None
136                        }
137                    }
138                    _ => None,
139                }
140            }
141            _ => None,
142        }
143    }
144
145    #[cfg_attr(feature = "stack-protection", recursive::recursive)]
146    fn compile_node(&mut self, node: &'arena Node<'arena>) -> CompilerResult<usize> {
147        match node {
148            Node::Null => Ok(self.emit(Opcode::PushNull)),
149            Node::Bool(v) => Ok(self.emit(Opcode::PushBool(*v))),
150            Node::Number(v) => Ok(self.emit(Opcode::PushNumber(*v))),
151            Node::String(v) => Ok(self.emit(Opcode::PushString(Arc::from(*v)))),
152            Node::Pointer => Ok(self.emit(Opcode::Pointer)),
153            Node::Root => Ok(self.emit(Opcode::FetchRootEnv)),
154            Node::Array(v) => {
155                v.iter()
156                    .try_for_each(|&n| self.compile_node(n).map(|_| ()))?;
157                self.emit(Opcode::PushNumber(Decimal::from(v.len())));
158                Ok(self.emit(Opcode::Array))
159            }
160            Node::Object(v) => {
161                v.iter().try_for_each(|&(key, value)| {
162                    self.compile_node(key).map(|_| ())?;
163                    self.emit(Opcode::CallFunction {
164                        arg_count: 1,
165                        kind: FunctionKind::Internal(InternalFunction::String),
166                    });
167                    self.compile_node(value).map(|_| ())?;
168                    Ok(())
169                })?;
170
171                self.emit(Opcode::PushNumber(Decimal::from(v.len())));
172                Ok(self.emit(Opcode::Object))
173            }
174            Node::Assignments { list, output } => {
175                self.emit(Opcode::AssignedObjectBegin);
176                list.iter().try_for_each(|&(key, value)| {
177                    self.compile_node(key).map(|_| ())?;
178                    self.compile_node(value).map(|_| ())?;
179                    self.emit(Opcode::AssignedObjectStep);
180
181                    Ok(())
182                })?;
183
184                if let Some(output) = output {
185                    self.compile_node(output).map(|_| ())?;
186                }
187
188                Ok(self.emit(Opcode::AssignedObjectEnd {
189                    with_return: output.is_some(),
190                }))
191            }
192            Node::Identifier(v) => Ok(self.emit(Opcode::FetchEnv(Arc::from(*v)))),
193            Node::Closure(v) => self.compile_node(v),
194            Node::Parenthesized(v) => self.compile_node(v),
195            Node::Member {
196                node: n,
197                property: p,
198            } => {
199                if let Some(path) = self.compile_member_fast(node) {
200                    Ok(self.emit(Opcode::FetchFast(path)))
201                } else {
202                    self.compile_node(n)?;
203                    self.compile_node(p)?;
204                    Ok(self.emit(Opcode::Fetch))
205                }
206            }
207            Node::TemplateString(parts) => {
208                parts.iter().try_for_each(|&n| {
209                    self.compile_node(n).map(|_| ())?;
210                    self.emit(Opcode::CallFunction {
211                        arg_count: 1,
212                        kind: FunctionKind::Internal(InternalFunction::String),
213                    });
214                    Ok(())
215                })?;
216
217                self.emit(Opcode::PushNumber(Decimal::from(parts.len())));
218                self.emit(Opcode::Array);
219                self.emit(Opcode::PushString(Arc::from("")));
220                Ok(self.emit(Opcode::Join))
221            }
222            Node::Slice { node, to, from } => {
223                self.compile_node(node)?;
224                if let Some(t) = to {
225                    self.compile_node(t)?;
226                } else {
227                    self.emit(Opcode::Len);
228                    self.emit(Opcode::PushNumber(dec!(1)));
229                    self.emit(Opcode::Subtract);
230                }
231
232                if let Some(f) = from {
233                    self.compile_node(f)?;
234                } else {
235                    self.emit(Opcode::PushNumber(dec!(0)));
236                }
237
238                Ok(self.emit(Opcode::Slice))
239            }
240            Node::Interval {
241                left,
242                right,
243                left_bracket,
244                right_bracket,
245            } => {
246                self.compile_node(left)?;
247                self.compile_node(right)?;
248                Ok(self.emit(Opcode::Interval {
249                    left_bracket: *left_bracket,
250                    right_bracket: *right_bracket,
251                }))
252            }
253            Node::Conditional {
254                condition,
255                on_true,
256                on_false,
257            } => {
258                self.compile_node(condition)?;
259                let otherwise = self.emit(Opcode::Jump(Jump::IfFalse, 0));
260
261                self.emit(Opcode::Pop);
262                self.compile_node(on_true)?;
263                let end = self.emit(Opcode::Jump(Jump::Forward, 0));
264
265                self.replace(
266                    otherwise,
267                    Opcode::Jump(Jump::IfFalse, (end - otherwise) as u32),
268                );
269                self.emit(Opcode::Pop);
270                let b = self.compile_node(on_false)?;
271                self.replace(end, Opcode::Jump(Jump::Forward, (b - end) as u32));
272
273                Ok(b)
274            }
275            Node::Unary { node, operator } => {
276                let curr = self.compile_node(node)?;
277                match *operator {
278                    Operator::Arithmetic(ArithmeticOperator::Add) => Ok(curr),
279                    Operator::Arithmetic(ArithmeticOperator::Subtract) => {
280                        Ok(self.emit(Opcode::Negate))
281                    }
282                    Operator::Logical(LogicalOperator::Not) => Ok(self.emit(Opcode::Not)),
283                    _ => Err(CompilerError::UnknownUnaryOperator {
284                        operator: operator.to_string(),
285                    }),
286                }
287            }
288            Node::Binary {
289                left,
290                right,
291                operator,
292            } => match *operator {
293                Operator::Comparison(ComparisonOperator::Equal) => {
294                    self.compile_node(left)?;
295                    self.compile_node(right)?;
296
297                    Ok(self.emit(Opcode::Equal))
298                }
299                Operator::Comparison(ComparisonOperator::NotEqual) => {
300                    self.compile_node(left)?;
301                    self.compile_node(right)?;
302
303                    self.emit(Opcode::Equal);
304                    Ok(self.emit(Opcode::Not))
305                }
306                Operator::Logical(LogicalOperator::Or) => {
307                    self.compile_node(left)?;
308                    let end = self.emit(Opcode::Jump(Jump::IfTrue, 0));
309                    self.emit(Opcode::Pop);
310                    let r = self.compile_node(right)?;
311                    self.replace(end, Opcode::Jump(Jump::IfTrue, (r - end) as u32));
312
313                    Ok(r)
314                }
315                Operator::Logical(LogicalOperator::And) => {
316                    self.compile_node(left)?;
317                    let end = self.emit(Opcode::Jump(Jump::IfFalse, 0));
318                    self.emit(Opcode::Pop);
319                    let r = self.compile_node(right)?;
320                    self.replace(end, Opcode::Jump(Jump::IfFalse, (r - end) as u32));
321
322                    Ok(r)
323                }
324                Operator::Logical(LogicalOperator::NullishCoalescing) => {
325                    self.compile_node(left)?;
326                    let end = self.emit(Opcode::Jump(Jump::IfNotNull, 0));
327                    self.emit(Opcode::Pop);
328                    let r = self.compile_node(right)?;
329                    self.replace(end, Opcode::Jump(Jump::IfNotNull, (r - end) as u32));
330
331                    Ok(r)
332                }
333                Operator::Comparison(ComparisonOperator::In) => {
334                    self.compile_node(left)?;
335                    self.compile_node(right)?;
336                    Ok(self.emit(Opcode::In))
337                }
338                Operator::Comparison(ComparisonOperator::NotIn) => {
339                    self.compile_node(left)?;
340                    self.compile_node(right)?;
341                    self.emit(Opcode::In);
342                    Ok(self.emit(Opcode::Not))
343                }
344                Operator::Comparison(ComparisonOperator::LessThan) => {
345                    self.compile_node(left)?;
346                    self.compile_node(right)?;
347                    Ok(self.emit(Opcode::Compare(Compare::Less)))
348                }
349                Operator::Comparison(ComparisonOperator::LessThanOrEqual) => {
350                    self.compile_node(left)?;
351                    self.compile_node(right)?;
352                    Ok(self.emit(Opcode::Compare(Compare::LessOrEqual)))
353                }
354                Operator::Comparison(ComparisonOperator::GreaterThan) => {
355                    self.compile_node(left)?;
356                    self.compile_node(right)?;
357                    Ok(self.emit(Opcode::Compare(Compare::More)))
358                }
359                Operator::Comparison(ComparisonOperator::GreaterThanOrEqual) => {
360                    self.compile_node(left)?;
361                    self.compile_node(right)?;
362                    Ok(self.emit(Opcode::Compare(Compare::MoreOrEqual)))
363                }
364                Operator::Arithmetic(ArithmeticOperator::Add) => {
365                    self.compile_node(left)?;
366                    self.compile_node(right)?;
367                    Ok(self.emit(Opcode::Add))
368                }
369                Operator::Arithmetic(ArithmeticOperator::Subtract) => {
370                    self.compile_node(left)?;
371                    self.compile_node(right)?;
372                    Ok(self.emit(Opcode::Subtract))
373                }
374                Operator::Arithmetic(ArithmeticOperator::Multiply) => {
375                    self.compile_node(left)?;
376                    self.compile_node(right)?;
377                    Ok(self.emit(Opcode::Multiply))
378                }
379                Operator::Arithmetic(ArithmeticOperator::Divide) => {
380                    self.compile_node(left)?;
381                    self.compile_node(right)?;
382                    Ok(self.emit(Opcode::Divide))
383                }
384                Operator::Arithmetic(ArithmeticOperator::Modulus) => {
385                    self.compile_node(left)?;
386                    self.compile_node(right)?;
387                    Ok(self.emit(Opcode::Modulo))
388                }
389                Operator::Arithmetic(ArithmeticOperator::Power) => {
390                    self.compile_node(left)?;
391                    self.compile_node(right)?;
392                    Ok(self.emit(Opcode::Exponent))
393                }
394                _ => Err(CompilerError::UnknownBinaryOperator {
395                    operator: operator.to_string(),
396                }),
397            },
398            Node::FunctionCall { kind, arguments } => match kind {
399                FunctionKind::Internal(_) | FunctionKind::Deprecated(_) => {
400                    let function = FunctionRegistry::get_definition(kind).ok_or_else(|| {
401                        CompilerError::UnknownFunction {
402                            name: kind.to_string(),
403                        }
404                    })?;
405
406                    let min_params = function.required_parameters();
407                    let max_params = min_params + function.optional_parameters();
408                    if arguments.len() < min_params || arguments.len() > max_params {
409                        return Err(CompilerError::InvalidFunctionCall {
410                            name: kind.to_string(),
411                            message: "Invalid number of arguments".to_string(),
412                        });
413                    }
414
415                    for i in 0..arguments.len() {
416                        self.compile_argument(kind, arguments, i)?;
417                    }
418
419                    Ok(self.emit(Opcode::CallFunction {
420                        kind: kind.clone(),
421                        arg_count: arguments.len() as u32,
422                    }))
423                }
424                FunctionKind::Closure(c) => match c {
425                    ClosureFunction::All => {
426                        self.compile_argument(kind, arguments, 0)?;
427                        self.emit(Opcode::Begin);
428                        let mut loop_break: usize = 0;
429                        self.emit_loop(|c| {
430                            c.compile_argument(kind, arguments, 1)?;
431                            loop_break = c.emit(Opcode::Jump(Jump::IfFalse, 0));
432                            c.emit(Opcode::Pop);
433                            Ok(())
434                        })?;
435                        let e = self.emit(Opcode::PushBool(true));
436                        self.replace(
437                            loop_break,
438                            Opcode::Jump(Jump::IfFalse, (e - loop_break) as u32),
439                        );
440                        Ok(self.emit(Opcode::End))
441                    }
442                    ClosureFunction::None => {
443                        self.compile_argument(kind, arguments, 0)?;
444                        self.emit(Opcode::Begin);
445                        let mut loop_break: usize = 0;
446                        self.emit_loop(|c| {
447                            c.compile_argument(kind, arguments, 1)?;
448                            c.emit(Opcode::Not);
449                            loop_break = c.emit(Opcode::Jump(Jump::IfFalse, 0));
450                            c.emit(Opcode::Pop);
451                            Ok(())
452                        })?;
453                        let e = self.emit(Opcode::PushBool(true));
454                        self.replace(
455                            loop_break,
456                            Opcode::Jump(Jump::IfFalse, (e - loop_break) as u32),
457                        );
458                        Ok(self.emit(Opcode::End))
459                    }
460                    ClosureFunction::Some => {
461                        self.compile_argument(kind, arguments, 0)?;
462                        self.emit(Opcode::Begin);
463                        let mut loop_break: usize = 0;
464                        self.emit_loop(|c| {
465                            c.compile_argument(kind, arguments, 1)?;
466                            loop_break = c.emit(Opcode::Jump(Jump::IfTrue, 0));
467                            c.emit(Opcode::Pop);
468                            Ok(())
469                        })?;
470                        let e = self.emit(Opcode::PushBool(false));
471                        self.replace(
472                            loop_break,
473                            Opcode::Jump(Jump::IfTrue, (e - loop_break) as u32),
474                        );
475                        Ok(self.emit(Opcode::End))
476                    }
477                    ClosureFunction::One => {
478                        self.compile_argument(kind, arguments, 0)?;
479                        self.emit(Opcode::Begin);
480                        self.emit_loop(|c| {
481                            c.compile_argument(kind, arguments, 1)?;
482                            c.emit_cond(|c| {
483                                c.emit(Opcode::IncrementCount);
484                            });
485                            Ok(())
486                        })?;
487                        self.emit(Opcode::GetCount);
488                        self.emit(Opcode::PushNumber(dec!(1)));
489                        self.emit(Opcode::Equal);
490                        Ok(self.emit(Opcode::End))
491                    }
492                    ClosureFunction::Filter => {
493                        self.compile_argument(kind, arguments, 0)?;
494                        self.emit(Opcode::Begin);
495                        self.emit_loop(|c| {
496                            c.compile_argument(kind, arguments, 1)?;
497                            c.emit_cond(|c| {
498                                c.emit(Opcode::IncrementCount);
499                                c.emit(Opcode::Pointer);
500                            });
501                            Ok(())
502                        })?;
503                        self.emit(Opcode::GetCount);
504                        self.emit(Opcode::End);
505                        Ok(self.emit(Opcode::Array))
506                    }
507                    ClosureFunction::Map => {
508                        self.compile_argument(kind, arguments, 0)?;
509                        self.emit(Opcode::Begin);
510                        self.emit_loop(|c| {
511                            c.compile_argument(kind, arguments, 1)?;
512                            Ok(())
513                        })?;
514                        self.emit(Opcode::GetLen);
515                        self.emit(Opcode::End);
516                        Ok(self.emit(Opcode::Array))
517                    }
518                    ClosureFunction::FlatMap => {
519                        self.compile_argument(kind, arguments, 0)?;
520                        self.emit(Opcode::Begin);
521                        self.emit_loop(|c| {
522                            c.compile_argument(kind, arguments, 1)?;
523                            Ok(())
524                        })?;
525                        self.emit(Opcode::GetLen);
526                        self.emit(Opcode::End);
527                        self.emit(Opcode::Array);
528                        Ok(self.emit(Opcode::Flatten))
529                    }
530                    ClosureFunction::Count => {
531                        self.compile_argument(kind, arguments, 0)?;
532                        self.emit(Opcode::Begin);
533                        self.emit_loop(|c| {
534                            c.compile_argument(kind, arguments, 1)?;
535                            c.emit_cond(|c| {
536                                c.emit(Opcode::IncrementCount);
537                            });
538                            Ok(())
539                        })?;
540                        self.emit(Opcode::GetCount);
541                        Ok(self.emit(Opcode::End))
542                    }
543                },
544            },
545            Node::MethodCall {
546                kind,
547                this,
548                arguments,
549            } => {
550                let method = MethodRegistry::get_definition(kind).ok_or_else(|| {
551                    CompilerError::UnknownFunction {
552                        name: kind.to_string(),
553                    }
554                })?;
555
556                self.compile_node(this)?;
557
558                let min_params = method.required_parameters() - 1;
559                let max_params = min_params + method.optional_parameters();
560                if arguments.len() < min_params || arguments.len() > max_params {
561                    return Err(CompilerError::InvalidMethodCall {
562                        name: kind.to_string(),
563                        message: "Invalid number of arguments".to_string(),
564                    });
565                }
566
567                for i in 0..arguments.len() {
568                    self.compile_argument(kind, arguments, i)?;
569                }
570
571                Ok(self.emit(Opcode::CallMethod {
572                    kind: kind.clone(),
573                    arg_count: arguments.len() as u32,
574                }))
575            }
576            Node::Error { .. } => Err(CompilerError::UnexpectedErrorNode),
577        }
578    }
579}