zen_expression/compiler/
compiler.rs

1use crate::compiler::error::{CompilerError, CompilerResult};
2use crate::compiler::opcode::{FetchFastTarget, Jump};
3use crate::compiler::{Compare, Opcode};
4use crate::functions::registry::FunctionRegistry;
5use crate::functions::{ClosureFunction, FunctionKind, InternalFunction, MethodRegistry};
6use crate::lexer::{ArithmeticOperator, ComparisonOperator, LogicalOperator, Operator};
7use crate::parser::Node;
8use rust_decimal::prelude::ToPrimitive;
9use rust_decimal::Decimal;
10use rust_decimal_macros::dec;
11use std::sync::Arc;
12
13#[derive(Debug)]
14pub struct Compiler {
15    bytecode: Vec<Opcode>,
16}
17
18impl Compiler {
19    pub fn new() -> Self {
20        Self {
21            bytecode: Default::default(),
22        }
23    }
24
25    pub fn compile(&mut self, root: &Node) -> CompilerResult<&[Opcode]> {
26        self.bytecode.clear();
27
28        CompilerInner::new(&mut self.bytecode, root).compile()?;
29        Ok(self.bytecode.as_slice())
30    }
31
32    pub fn get_bytecode(&self) -> &[Opcode] {
33        self.bytecode.as_slice()
34    }
35}
36
37#[derive(Debug)]
38struct CompilerInner<'arena, 'bytecode_ref> {
39    root: &'arena Node<'arena>,
40    bytecode: &'bytecode_ref mut Vec<Opcode>,
41    closure_aliases: Vec<Option<&'arena str>>,
42}
43
44impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> {
45    pub fn new(bytecode: &'bytecode_ref mut Vec<Opcode>, root: &'arena Node<'arena>) -> Self {
46        Self {
47            root,
48            bytecode,
49            closure_aliases: Vec::new(),
50        }
51    }
52
53    fn lookup_alias(&self, name: &str) -> Option<u32> {
54        self.closure_aliases
55            .iter()
56            .enumerate()
57            .rev()
58            .find_map(|(index, &alias)| {
59                alias.and_then(|a| {
60                    (a == name).then_some((self.closure_aliases.len() - 1 - index) as u32)
61                })
62            })
63    }
64
65    pub fn compile(&mut self) -> CompilerResult<()> {
66        self.compile_node(self.root)?;
67        Ok(())
68    }
69
70    fn emit(&mut self, op: Opcode) -> usize {
71        self.bytecode.push(op);
72        self.bytecode.len()
73    }
74
75    fn emit_loop<F>(&mut self, body: F) -> CompilerResult<()>
76    where
77        F: FnOnce(&mut Self) -> CompilerResult<()>,
78    {
79        let begin = self.bytecode.len();
80        let end = self.emit(Opcode::Jump(Jump::IfEnd, 0));
81
82        body(self)?;
83
84        self.emit(Opcode::IncrementIt);
85        let e = self.emit(Opcode::Jump(
86            Jump::Backward,
87            self.calc_backward_jump(begin) as u32,
88        ));
89        self.replace(end, Opcode::Jump(Jump::IfEnd, (e - end) as u32));
90        Ok(())
91    }
92
93    fn emit_cond<F>(&mut self, mut body: F)
94    where
95        F: FnMut(&mut Self),
96    {
97        let noop = self.emit(Opcode::Jump(Jump::IfFalse, 0));
98        self.emit(Opcode::Pop);
99
100        body(self);
101
102        let jmp = self.emit(Opcode::Jump(Jump::Forward, 0));
103        self.replace(noop, Opcode::Jump(Jump::IfFalse, (jmp - noop) as u32));
104        let e = self.emit(Opcode::Pop);
105        self.replace(jmp, Opcode::Jump(Jump::Forward, (e - jmp) as u32));
106    }
107
108    fn replace(&mut self, at: usize, op: Opcode) {
109        let _ = std::mem::replace(&mut self.bytecode[at - 1], op);
110    }
111
112    fn calc_backward_jump(&self, to: usize) -> usize {
113        self.bytecode.len() + 1 - to
114    }
115
116    fn compile_argument<T: ToString>(
117        &mut self,
118        function_kind: T,
119        arguments: &[&'arena Node<'arena>],
120        index: usize,
121    ) -> CompilerResult<usize> {
122        let arg = arguments
123            .get(index)
124            .ok_or_else(|| CompilerError::ArgumentNotFound {
125                index,
126                function: function_kind.to_string(),
127            })?;
128
129        self.compile_node(arg)
130    }
131
132    #[cfg_attr(not(target_family = "wasm"), recursive::recursive)]
133    fn compile_member_fast(&mut self, node: &'arena Node<'arena>) -> Option<Vec<FetchFastTarget>> {
134        match node {
135            Node::Root => Some(vec![FetchFastTarget::Root]),
136            Node::Identifier(v) => self.lookup_alias(v).is_none().then(|| {
137                vec![
138                    FetchFastTarget::Begin,
139                    FetchFastTarget::String(Arc::from(*v)),
140                ]
141            }),
142            Node::Member { node, property } => {
143                let mut path = self.compile_member_fast(node)?;
144                match property {
145                    Node::String(v) => {
146                        path.push(FetchFastTarget::String(Arc::from(*v)));
147                        Some(path)
148                    }
149                    Node::Number(v) => {
150                        if let Some(idx) = v.to_u32() {
151                            path.push(FetchFastTarget::Number(idx));
152                            Some(path)
153                        } else {
154                            None
155                        }
156                    }
157                    _ => None,
158                }
159            }
160            _ => None,
161        }
162    }
163
164    #[cfg_attr(not(target_family = "wasm"), recursive::recursive)]
165    fn compile_node(&mut self, node: &'arena Node<'arena>) -> CompilerResult<usize> {
166        match node {
167            Node::Null => Ok(self.emit(Opcode::PushNull)),
168            Node::Bool(v) => Ok(self.emit(Opcode::PushBool(*v))),
169            Node::Number(v) => Ok(self.emit(Opcode::PushNumber(*v))),
170            Node::String(v) => Ok(self.emit(Opcode::PushString(Arc::from(*v)))),
171            Node::Pointer => Ok(self.emit(Opcode::Pointer(0))),
172            Node::Root => Ok(self.emit(Opcode::FetchRootEnv)),
173            Node::Array(v) => {
174                v.iter()
175                    .try_for_each(|&n| self.compile_node(n).map(|_| ()))?;
176                self.emit(Opcode::PushNumber(Decimal::from(v.len())));
177                Ok(self.emit(Opcode::Array))
178            }
179            Node::Object(v) => {
180                v.iter().try_for_each(|&(key, value)| {
181                    self.compile_node(key).map(|_| ())?;
182                    self.emit(Opcode::CallFunction {
183                        arg_count: 1,
184                        kind: FunctionKind::Internal(InternalFunction::String),
185                    });
186                    self.compile_node(value).map(|_| ())?;
187                    Ok(())
188                })?;
189
190                self.emit(Opcode::PushNumber(Decimal::from(v.len())));
191                Ok(self.emit(Opcode::Object))
192            }
193            Node::Assignments { list, output } => {
194                self.emit(Opcode::AssignedObjectBegin);
195                list.iter().try_for_each(|&(key, value)| {
196                    self.compile_node(key).map(|_| ())?;
197                    self.compile_node(value).map(|_| ())?;
198                    self.emit(Opcode::AssignedObjectStep);
199
200                    Ok(())
201                })?;
202
203                if let Some(output) = output {
204                    self.compile_node(output).map(|_| ())?;
205                }
206
207                Ok(self.emit(Opcode::AssignedObjectEnd {
208                    with_return: output.is_some(),
209                }))
210            }
211            Node::Identifier(v) => Ok(self.emit(
212                self.lookup_alias(v)
213                    .map_or_else(|| Opcode::FetchEnv(Arc::from(*v)), Opcode::Pointer),
214            )),
215            Node::Closure { body, alias } => {
216                self.closure_aliases.push(*alias);
217                let result = self.compile_node(body);
218                self.closure_aliases.pop();
219                result
220            }
221            Node::Parenthesized(v) => self.compile_node(v),
222            Node::Member {
223                node: n,
224                property: p,
225            } => {
226                if let Some(path) = self.compile_member_fast(node) {
227                    Ok(self.emit(Opcode::FetchFast(path)))
228                } else {
229                    self.compile_node(n)?;
230                    self.compile_node(p)?;
231                    Ok(self.emit(Opcode::Fetch))
232                }
233            }
234            Node::TemplateString(parts) => {
235                parts.iter().try_for_each(|&n| {
236                    self.compile_node(n).map(|_| ())?;
237                    self.emit(Opcode::CallFunction {
238                        arg_count: 1,
239                        kind: FunctionKind::Internal(InternalFunction::String),
240                    });
241                    Ok(())
242                })?;
243
244                self.emit(Opcode::PushNumber(Decimal::from(parts.len())));
245                self.emit(Opcode::Array);
246                self.emit(Opcode::PushString(Arc::from("")));
247                Ok(self.emit(Opcode::Join))
248            }
249            Node::Slice { node, to, from } => {
250                self.compile_node(node)?;
251                if let Some(t) = to {
252                    self.compile_node(t)?;
253                } else {
254                    self.emit(Opcode::Len);
255                    self.emit(Opcode::PushNumber(dec!(1)));
256                    self.emit(Opcode::Subtract);
257                }
258
259                if let Some(f) = from {
260                    self.compile_node(f)?;
261                } else {
262                    self.emit(Opcode::PushNumber(dec!(0)));
263                }
264
265                Ok(self.emit(Opcode::Slice))
266            }
267            Node::Interval {
268                left,
269                right,
270                left_bracket,
271                right_bracket,
272            } => {
273                self.compile_node(left)?;
274                self.compile_node(right)?;
275                Ok(self.emit(Opcode::Interval {
276                    left_bracket: *left_bracket,
277                    right_bracket: *right_bracket,
278                }))
279            }
280            Node::Conditional {
281                condition,
282                on_true,
283                on_false,
284            } => {
285                self.compile_node(condition)?;
286                let otherwise = self.emit(Opcode::Jump(Jump::IfFalse, 0));
287
288                self.emit(Opcode::Pop);
289                self.compile_node(on_true)?;
290                let end = self.emit(Opcode::Jump(Jump::Forward, 0));
291
292                self.replace(
293                    otherwise,
294                    Opcode::Jump(Jump::IfFalse, (end - otherwise) as u32),
295                );
296                self.emit(Opcode::Pop);
297                let b = self.compile_node(on_false)?;
298                self.replace(end, Opcode::Jump(Jump::Forward, (b - end) as u32));
299
300                Ok(b)
301            }
302            Node::Unary { node, operator } => {
303                let curr = self.compile_node(node)?;
304                match *operator {
305                    Operator::Arithmetic(ArithmeticOperator::Add) => Ok(curr),
306                    Operator::Arithmetic(ArithmeticOperator::Subtract) => {
307                        Ok(self.emit(Opcode::Negate))
308                    }
309                    Operator::Logical(LogicalOperator::Not) => Ok(self.emit(Opcode::Not)),
310                    _ => Err(CompilerError::UnknownUnaryOperator {
311                        operator: operator.to_string(),
312                    }),
313                }
314            }
315            Node::Binary {
316                left,
317                right,
318                operator,
319            } => match *operator {
320                Operator::Comparison(ComparisonOperator::Equal) => {
321                    self.compile_node(left)?;
322                    self.compile_node(right)?;
323
324                    Ok(self.emit(Opcode::Equal))
325                }
326                Operator::Comparison(ComparisonOperator::NotEqual) => {
327                    self.compile_node(left)?;
328                    self.compile_node(right)?;
329
330                    self.emit(Opcode::Equal);
331                    Ok(self.emit(Opcode::Not))
332                }
333                Operator::Logical(LogicalOperator::Or) => {
334                    self.compile_node(left)?;
335                    let end = self.emit(Opcode::Jump(Jump::IfTrue, 0));
336                    self.emit(Opcode::Pop);
337                    let r = self.compile_node(right)?;
338                    self.replace(end, Opcode::Jump(Jump::IfTrue, (r - end) as u32));
339
340                    Ok(r)
341                }
342                Operator::Logical(LogicalOperator::And) => {
343                    self.compile_node(left)?;
344                    let end = self.emit(Opcode::Jump(Jump::IfFalse, 0));
345                    self.emit(Opcode::Pop);
346                    let r = self.compile_node(right)?;
347                    self.replace(end, Opcode::Jump(Jump::IfFalse, (r - end) as u32));
348
349                    Ok(r)
350                }
351                Operator::Logical(LogicalOperator::NullishCoalescing) => {
352                    self.compile_node(left)?;
353                    let end = self.emit(Opcode::Jump(Jump::IfNotNull, 0));
354                    self.emit(Opcode::Pop);
355                    let r = self.compile_node(right)?;
356                    self.replace(end, Opcode::Jump(Jump::IfNotNull, (r - end) as u32));
357
358                    Ok(r)
359                }
360                Operator::Comparison(ComparisonOperator::In) => {
361                    self.compile_node(left)?;
362                    self.compile_node(right)?;
363                    Ok(self.emit(Opcode::In))
364                }
365                Operator::Comparison(ComparisonOperator::NotIn) => {
366                    self.compile_node(left)?;
367                    self.compile_node(right)?;
368                    self.emit(Opcode::In);
369                    Ok(self.emit(Opcode::Not))
370                }
371                Operator::Comparison(ComparisonOperator::LessThan) => {
372                    self.compile_node(left)?;
373                    self.compile_node(right)?;
374                    Ok(self.emit(Opcode::Compare(Compare::Less)))
375                }
376                Operator::Comparison(ComparisonOperator::LessThanOrEqual) => {
377                    self.compile_node(left)?;
378                    self.compile_node(right)?;
379                    Ok(self.emit(Opcode::Compare(Compare::LessOrEqual)))
380                }
381                Operator::Comparison(ComparisonOperator::GreaterThan) => {
382                    self.compile_node(left)?;
383                    self.compile_node(right)?;
384                    Ok(self.emit(Opcode::Compare(Compare::More)))
385                }
386                Operator::Comparison(ComparisonOperator::GreaterThanOrEqual) => {
387                    self.compile_node(left)?;
388                    self.compile_node(right)?;
389                    Ok(self.emit(Opcode::Compare(Compare::MoreOrEqual)))
390                }
391                Operator::Arithmetic(ArithmeticOperator::Add) => {
392                    self.compile_node(left)?;
393                    self.compile_node(right)?;
394                    Ok(self.emit(Opcode::Add))
395                }
396                Operator::Arithmetic(ArithmeticOperator::Subtract) => {
397                    self.compile_node(left)?;
398                    self.compile_node(right)?;
399                    Ok(self.emit(Opcode::Subtract))
400                }
401                Operator::Arithmetic(ArithmeticOperator::Multiply) => {
402                    self.compile_node(left)?;
403                    self.compile_node(right)?;
404                    Ok(self.emit(Opcode::Multiply))
405                }
406                Operator::Arithmetic(ArithmeticOperator::Divide) => {
407                    self.compile_node(left)?;
408                    self.compile_node(right)?;
409                    Ok(self.emit(Opcode::Divide))
410                }
411                Operator::Arithmetic(ArithmeticOperator::Modulus) => {
412                    self.compile_node(left)?;
413                    self.compile_node(right)?;
414                    Ok(self.emit(Opcode::Modulo))
415                }
416                Operator::Arithmetic(ArithmeticOperator::Power) => {
417                    self.compile_node(left)?;
418                    self.compile_node(right)?;
419                    Ok(self.emit(Opcode::Exponent))
420                }
421                _ => Err(CompilerError::UnknownBinaryOperator {
422                    operator: operator.to_string(),
423                }),
424            },
425            Node::FunctionCall { kind, arguments } => match kind {
426                FunctionKind::Internal(_) | FunctionKind::Deprecated(_) => {
427                    let function = FunctionRegistry::get_definition(kind).ok_or_else(|| {
428                        CompilerError::UnknownFunction {
429                            name: kind.to_string(),
430                        }
431                    })?;
432
433                    let min_params = function.required_parameters();
434                    let max_params = min_params + function.optional_parameters();
435                    if arguments.len() < min_params || arguments.len() > max_params {
436                        return Err(CompilerError::InvalidFunctionCall {
437                            name: kind.to_string(),
438                            message: "Invalid number of arguments".to_string(),
439                        });
440                    }
441
442                    for i in 0..arguments.len() {
443                        self.compile_argument(kind, arguments, i)?;
444                    }
445
446                    Ok(self.emit(Opcode::CallFunction {
447                        kind: kind.clone(),
448                        arg_count: arguments.len() as u32,
449                    }))
450                }
451                FunctionKind::Closure(c) => match c {
452                    ClosureFunction::All => {
453                        self.compile_argument(kind, arguments, 0)?;
454                        self.emit(Opcode::Begin);
455                        let mut loop_break: usize = 0;
456                        self.emit_loop(|c| {
457                            c.compile_argument(kind, arguments, 1)?;
458                            loop_break = c.emit(Opcode::Jump(Jump::IfFalse, 0));
459                            c.emit(Opcode::Pop);
460                            Ok(())
461                        })?;
462                        let e = self.emit(Opcode::PushBool(true));
463                        self.replace(
464                            loop_break,
465                            Opcode::Jump(Jump::IfFalse, (e - loop_break) as u32),
466                        );
467                        Ok(self.emit(Opcode::End))
468                    }
469                    ClosureFunction::None => {
470                        self.compile_argument(kind, arguments, 0)?;
471                        self.emit(Opcode::Begin);
472                        let mut loop_break: usize = 0;
473                        self.emit_loop(|c| {
474                            c.compile_argument(kind, arguments, 1)?;
475                            c.emit(Opcode::Not);
476                            loop_break = c.emit(Opcode::Jump(Jump::IfFalse, 0));
477                            c.emit(Opcode::Pop);
478                            Ok(())
479                        })?;
480                        let e = self.emit(Opcode::PushBool(true));
481                        self.replace(
482                            loop_break,
483                            Opcode::Jump(Jump::IfFalse, (e - loop_break) as u32),
484                        );
485                        Ok(self.emit(Opcode::End))
486                    }
487                    ClosureFunction::Some => {
488                        self.compile_argument(kind, arguments, 0)?;
489                        self.emit(Opcode::Begin);
490                        let mut loop_break: usize = 0;
491                        self.emit_loop(|c| {
492                            c.compile_argument(kind, arguments, 1)?;
493                            loop_break = c.emit(Opcode::Jump(Jump::IfTrue, 0));
494                            c.emit(Opcode::Pop);
495                            Ok(())
496                        })?;
497                        let e = self.emit(Opcode::PushBool(false));
498                        self.replace(
499                            loop_break,
500                            Opcode::Jump(Jump::IfTrue, (e - loop_break) as u32),
501                        );
502                        Ok(self.emit(Opcode::End))
503                    }
504                    ClosureFunction::One => {
505                        self.compile_argument(kind, arguments, 0)?;
506                        self.emit(Opcode::Begin);
507                        self.emit_loop(|c| {
508                            c.compile_argument(kind, arguments, 1)?;
509                            c.emit_cond(|c| {
510                                c.emit(Opcode::IncrementCount);
511                            });
512                            Ok(())
513                        })?;
514                        self.emit(Opcode::GetCount);
515                        self.emit(Opcode::PushNumber(dec!(1)));
516                        self.emit(Opcode::Equal);
517                        Ok(self.emit(Opcode::End))
518                    }
519                    ClosureFunction::Filter => {
520                        self.compile_argument(kind, arguments, 0)?;
521                        self.emit(Opcode::Begin);
522                        self.emit_loop(|c| {
523                            c.compile_argument(kind, arguments, 1)?;
524                            c.emit_cond(|c| {
525                                c.emit(Opcode::IncrementCount);
526                                c.emit(Opcode::Pointer(0));
527                            });
528                            Ok(())
529                        })?;
530                        self.emit(Opcode::GetCount);
531                        self.emit(Opcode::End);
532                        Ok(self.emit(Opcode::Array))
533                    }
534                    ClosureFunction::Map => {
535                        self.compile_argument(kind, arguments, 0)?;
536                        self.emit(Opcode::Begin);
537                        self.emit_loop(|c| {
538                            c.compile_argument(kind, arguments, 1)?;
539                            Ok(())
540                        })?;
541                        self.emit(Opcode::GetLen);
542                        self.emit(Opcode::End);
543                        Ok(self.emit(Opcode::Array))
544                    }
545                    ClosureFunction::FlatMap => {
546                        self.compile_argument(kind, arguments, 0)?;
547                        self.emit(Opcode::Begin);
548                        self.emit_loop(|c| {
549                            c.compile_argument(kind, arguments, 1)?;
550                            Ok(())
551                        })?;
552                        self.emit(Opcode::GetLen);
553                        self.emit(Opcode::End);
554                        self.emit(Opcode::Array);
555                        Ok(self.emit(Opcode::Flatten))
556                    }
557                    ClosureFunction::Count => {
558                        self.compile_argument(kind, arguments, 0)?;
559                        self.emit(Opcode::Begin);
560                        self.emit_loop(|c| {
561                            c.compile_argument(kind, arguments, 1)?;
562                            c.emit_cond(|c| {
563                                c.emit(Opcode::IncrementCount);
564                            });
565                            Ok(())
566                        })?;
567                        self.emit(Opcode::GetCount);
568                        Ok(self.emit(Opcode::End))
569                    }
570                },
571            },
572            Node::MethodCall {
573                kind,
574                this,
575                arguments,
576            } => {
577                let method = MethodRegistry::get_definition(kind).ok_or_else(|| {
578                    CompilerError::UnknownFunction {
579                        name: kind.to_string(),
580                    }
581                })?;
582
583                self.compile_node(this)?;
584
585                let min_params = method.required_parameters() - 1;
586                let max_params = min_params + method.optional_parameters();
587                if arguments.len() < min_params || arguments.len() > max_params {
588                    return Err(CompilerError::InvalidMethodCall {
589                        name: kind.to_string(),
590                        message: "Invalid number of arguments".to_string(),
591                    });
592                }
593
594                for i in 0..arguments.len() {
595                    self.compile_argument(kind, arguments, i)?;
596                }
597
598                Ok(self.emit(Opcode::CallMethod {
599                    kind: kind.clone(),
600                    arg_count: arguments.len() as u32,
601                }))
602            }
603            Node::Error { .. } => Err(CompilerError::UnexpectedErrorNode),
604        }
605    }
606}