1use crate::compiler::error::{CompilerError, CompilerResult};
2use crate::compiler::opcode::{FetchFastTarget, Jump};
3use crate::compiler::{Compare, Opcode};
4use crate::functions::registry::FunctionRegistry;
5use crate::functions::{ClosureFunction, FunctionKind, InternalFunction, MethodRegistry};
6use crate::lexer::{ArithmeticOperator, ComparisonOperator, LogicalOperator, Operator};
7use crate::parser::Node;
8use rust_decimal::prelude::ToPrimitive;
9use rust_decimal::Decimal;
10use rust_decimal_macros::dec;
11use std::sync::Arc;
12
13#[derive(Debug)]
14pub struct Compiler {
15 bytecode: Vec<Opcode>,
16}
17
18impl Compiler {
19 pub fn new() -> Self {
20 Self {
21 bytecode: Default::default(),
22 }
23 }
24
25 pub fn compile(&mut self, root: &Node) -> CompilerResult<&[Opcode]> {
26 self.bytecode.clear();
27
28 CompilerInner::new(&mut self.bytecode, root).compile()?;
29 Ok(self.bytecode.as_slice())
30 }
31
32 pub fn get_bytecode(&self) -> &[Opcode] {
33 self.bytecode.as_slice()
34 }
35}
36
37#[derive(Debug)]
38struct CompilerInner<'arena, 'bytecode_ref> {
39 root: &'arena Node<'arena>,
40 bytecode: &'bytecode_ref mut Vec<Opcode>,
41}
42
43impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> {
44 pub fn new(bytecode: &'bytecode_ref mut Vec<Opcode>, root: &'arena Node<'arena>) -> Self {
45 Self { root, bytecode }
46 }
47
48 pub fn compile(&mut self) -> CompilerResult<()> {
49 self.compile_node(self.root)?;
50 Ok(())
51 }
52
53 fn emit(&mut self, op: Opcode) -> usize {
54 self.bytecode.push(op);
55 self.bytecode.len()
56 }
57
58 fn emit_loop<F>(&mut self, body: F) -> CompilerResult<()>
59 where
60 F: FnOnce(&mut Self) -> CompilerResult<()>,
61 {
62 let begin = self.bytecode.len();
63 let end = self.emit(Opcode::Jump(Jump::IfEnd, 0));
64
65 body(self)?;
66
67 self.emit(Opcode::IncrementIt);
68 let e = self.emit(Opcode::Jump(
69 Jump::Backward,
70 self.calc_backward_jump(begin) as u32,
71 ));
72 self.replace(end, Opcode::Jump(Jump::IfEnd, (e - end) as u32));
73 Ok(())
74 }
75
76 fn emit_cond<F>(&mut self, mut body: F)
77 where
78 F: FnMut(&mut Self),
79 {
80 let noop = self.emit(Opcode::Jump(Jump::IfFalse, 0));
81 self.emit(Opcode::Pop);
82
83 body(self);
84
85 let jmp = self.emit(Opcode::Jump(Jump::Forward, 0));
86 self.replace(noop, Opcode::Jump(Jump::IfFalse, (jmp - noop) as u32));
87 let e = self.emit(Opcode::Pop);
88 self.replace(jmp, Opcode::Jump(Jump::Forward, (e - jmp) as u32));
89 }
90
91 fn replace(&mut self, at: usize, op: Opcode) {
92 let _ = std::mem::replace(&mut self.bytecode[at - 1], op);
93 }
94
95 fn calc_backward_jump(&self, to: usize) -> usize {
96 self.bytecode.len() + 1 - to
97 }
98
99 fn compile_argument<T: ToString>(
100 &mut self,
101 function_kind: T,
102 arguments: &[&'arena Node<'arena>],
103 index: usize,
104 ) -> CompilerResult<usize> {
105 let arg = arguments
106 .get(index)
107 .ok_or_else(|| CompilerError::ArgumentNotFound {
108 index,
109 function: function_kind.to_string(),
110 })?;
111
112 self.compile_node(arg)
113 }
114
115 #[cfg_attr(feature = "stack-protection", recursive::recursive)]
116 fn compile_member_fast(&mut self, node: &'arena Node<'arena>) -> Option<Vec<FetchFastTarget>> {
117 match node {
118 Node::Root => Some(vec![FetchFastTarget::Root]),
119 Node::Identifier(v) => Some(vec![
120 FetchFastTarget::Begin,
121 FetchFastTarget::String(Arc::from(*v)),
122 ]),
123 Node::Member { node, property } => {
124 let mut path = self.compile_member_fast(node)?;
125 match property {
126 Node::String(v) => {
127 path.push(FetchFastTarget::String(Arc::from(*v)));
128 Some(path)
129 }
130 Node::Number(v) => {
131 if let Some(idx) = v.to_u32() {
132 path.push(FetchFastTarget::Number(idx));
133 Some(path)
134 } else {
135 None
136 }
137 }
138 _ => None,
139 }
140 }
141 _ => None,
142 }
143 }
144
145 #[cfg_attr(feature = "stack-protection", recursive::recursive)]
146 fn compile_node(&mut self, node: &'arena Node<'arena>) -> CompilerResult<usize> {
147 match node {
148 Node::Null => Ok(self.emit(Opcode::PushNull)),
149 Node::Bool(v) => Ok(self.emit(Opcode::PushBool(*v))),
150 Node::Number(v) => Ok(self.emit(Opcode::PushNumber(*v))),
151 Node::String(v) => Ok(self.emit(Opcode::PushString(Arc::from(*v)))),
152 Node::Pointer => Ok(self.emit(Opcode::Pointer)),
153 Node::Root => Ok(self.emit(Opcode::FetchRootEnv)),
154 Node::Array(v) => {
155 v.iter()
156 .try_for_each(|&n| self.compile_node(n).map(|_| ()))?;
157 self.emit(Opcode::PushNumber(Decimal::from(v.len())));
158 Ok(self.emit(Opcode::Array))
159 }
160 Node::Object(v) => {
161 v.iter().try_for_each(|&(key, value)| {
162 self.compile_node(key).map(|_| ())?;
163 self.emit(Opcode::CallFunction {
164 arg_count: 1,
165 kind: FunctionKind::Internal(InternalFunction::String),
166 });
167 self.compile_node(value).map(|_| ())?;
168 Ok(())
169 })?;
170
171 self.emit(Opcode::PushNumber(Decimal::from(v.len())));
172 Ok(self.emit(Opcode::Object))
173 }
174 Node::Assignments { list, output } => {
175 self.emit(Opcode::AssignedObjectBegin);
176 list.iter().try_for_each(|&(key, value)| {
177 self.compile_node(key).map(|_| ())?;
178 self.compile_node(value).map(|_| ())?;
179 self.emit(Opcode::AssignedObjectStep);
180
181 Ok(())
182 })?;
183
184 if let Some(output) = output {
185 self.compile_node(output).map(|_| ())?;
186 }
187
188 Ok(self.emit(Opcode::AssignedObjectEnd {
189 with_return: output.is_some(),
190 }))
191 }
192 Node::Identifier(v) => Ok(self.emit(Opcode::FetchEnv(Arc::from(*v)))),
193 Node::Closure(v) => self.compile_node(v),
194 Node::Parenthesized(v) => self.compile_node(v),
195 Node::Member {
196 node: n,
197 property: p,
198 } => {
199 if let Some(path) = self.compile_member_fast(node) {
200 Ok(self.emit(Opcode::FetchFast(path)))
201 } else {
202 self.compile_node(n)?;
203 self.compile_node(p)?;
204 Ok(self.emit(Opcode::Fetch))
205 }
206 }
207 Node::TemplateString(parts) => {
208 parts.iter().try_for_each(|&n| {
209 self.compile_node(n).map(|_| ())?;
210 self.emit(Opcode::CallFunction {
211 arg_count: 1,
212 kind: FunctionKind::Internal(InternalFunction::String),
213 });
214 Ok(())
215 })?;
216
217 self.emit(Opcode::PushNumber(Decimal::from(parts.len())));
218 self.emit(Opcode::Array);
219 self.emit(Opcode::PushString(Arc::from("")));
220 Ok(self.emit(Opcode::Join))
221 }
222 Node::Slice { node, to, from } => {
223 self.compile_node(node)?;
224 if let Some(t) = to {
225 self.compile_node(t)?;
226 } else {
227 self.emit(Opcode::Len);
228 self.emit(Opcode::PushNumber(dec!(1)));
229 self.emit(Opcode::Subtract);
230 }
231
232 if let Some(f) = from {
233 self.compile_node(f)?;
234 } else {
235 self.emit(Opcode::PushNumber(dec!(0)));
236 }
237
238 Ok(self.emit(Opcode::Slice))
239 }
240 Node::Interval {
241 left,
242 right,
243 left_bracket,
244 right_bracket,
245 } => {
246 self.compile_node(left)?;
247 self.compile_node(right)?;
248 Ok(self.emit(Opcode::Interval {
249 left_bracket: *left_bracket,
250 right_bracket: *right_bracket,
251 }))
252 }
253 Node::Conditional {
254 condition,
255 on_true,
256 on_false,
257 } => {
258 self.compile_node(condition)?;
259 let otherwise = self.emit(Opcode::Jump(Jump::IfFalse, 0));
260
261 self.emit(Opcode::Pop);
262 self.compile_node(on_true)?;
263 let end = self.emit(Opcode::Jump(Jump::Forward, 0));
264
265 self.replace(
266 otherwise,
267 Opcode::Jump(Jump::IfFalse, (end - otherwise) as u32),
268 );
269 self.emit(Opcode::Pop);
270 let b = self.compile_node(on_false)?;
271 self.replace(end, Opcode::Jump(Jump::Forward, (b - end) as u32));
272
273 Ok(b)
274 }
275 Node::Unary { node, operator } => {
276 let curr = self.compile_node(node)?;
277 match *operator {
278 Operator::Arithmetic(ArithmeticOperator::Add) => Ok(curr),
279 Operator::Arithmetic(ArithmeticOperator::Subtract) => {
280 Ok(self.emit(Opcode::Negate))
281 }
282 Operator::Logical(LogicalOperator::Not) => Ok(self.emit(Opcode::Not)),
283 _ => Err(CompilerError::UnknownUnaryOperator {
284 operator: operator.to_string(),
285 }),
286 }
287 }
288 Node::Binary {
289 left,
290 right,
291 operator,
292 } => match *operator {
293 Operator::Comparison(ComparisonOperator::Equal) => {
294 self.compile_node(left)?;
295 self.compile_node(right)?;
296
297 Ok(self.emit(Opcode::Equal))
298 }
299 Operator::Comparison(ComparisonOperator::NotEqual) => {
300 self.compile_node(left)?;
301 self.compile_node(right)?;
302
303 self.emit(Opcode::Equal);
304 Ok(self.emit(Opcode::Not))
305 }
306 Operator::Logical(LogicalOperator::Or) => {
307 self.compile_node(left)?;
308 let end = self.emit(Opcode::Jump(Jump::IfTrue, 0));
309 self.emit(Opcode::Pop);
310 let r = self.compile_node(right)?;
311 self.replace(end, Opcode::Jump(Jump::IfTrue, (r - end) as u32));
312
313 Ok(r)
314 }
315 Operator::Logical(LogicalOperator::And) => {
316 self.compile_node(left)?;
317 let end = self.emit(Opcode::Jump(Jump::IfFalse, 0));
318 self.emit(Opcode::Pop);
319 let r = self.compile_node(right)?;
320 self.replace(end, Opcode::Jump(Jump::IfFalse, (r - end) as u32));
321
322 Ok(r)
323 }
324 Operator::Logical(LogicalOperator::NullishCoalescing) => {
325 self.compile_node(left)?;
326 let end = self.emit(Opcode::Jump(Jump::IfNotNull, 0));
327 self.emit(Opcode::Pop);
328 let r = self.compile_node(right)?;
329 self.replace(end, Opcode::Jump(Jump::IfNotNull, (r - end) as u32));
330
331 Ok(r)
332 }
333 Operator::Comparison(ComparisonOperator::In) => {
334 self.compile_node(left)?;
335 self.compile_node(right)?;
336 Ok(self.emit(Opcode::In))
337 }
338 Operator::Comparison(ComparisonOperator::NotIn) => {
339 self.compile_node(left)?;
340 self.compile_node(right)?;
341 self.emit(Opcode::In);
342 Ok(self.emit(Opcode::Not))
343 }
344 Operator::Comparison(ComparisonOperator::LessThan) => {
345 self.compile_node(left)?;
346 self.compile_node(right)?;
347 Ok(self.emit(Opcode::Compare(Compare::Less)))
348 }
349 Operator::Comparison(ComparisonOperator::LessThanOrEqual) => {
350 self.compile_node(left)?;
351 self.compile_node(right)?;
352 Ok(self.emit(Opcode::Compare(Compare::LessOrEqual)))
353 }
354 Operator::Comparison(ComparisonOperator::GreaterThan) => {
355 self.compile_node(left)?;
356 self.compile_node(right)?;
357 Ok(self.emit(Opcode::Compare(Compare::More)))
358 }
359 Operator::Comparison(ComparisonOperator::GreaterThanOrEqual) => {
360 self.compile_node(left)?;
361 self.compile_node(right)?;
362 Ok(self.emit(Opcode::Compare(Compare::MoreOrEqual)))
363 }
364 Operator::Arithmetic(ArithmeticOperator::Add) => {
365 self.compile_node(left)?;
366 self.compile_node(right)?;
367 Ok(self.emit(Opcode::Add))
368 }
369 Operator::Arithmetic(ArithmeticOperator::Subtract) => {
370 self.compile_node(left)?;
371 self.compile_node(right)?;
372 Ok(self.emit(Opcode::Subtract))
373 }
374 Operator::Arithmetic(ArithmeticOperator::Multiply) => {
375 self.compile_node(left)?;
376 self.compile_node(right)?;
377 Ok(self.emit(Opcode::Multiply))
378 }
379 Operator::Arithmetic(ArithmeticOperator::Divide) => {
380 self.compile_node(left)?;
381 self.compile_node(right)?;
382 Ok(self.emit(Opcode::Divide))
383 }
384 Operator::Arithmetic(ArithmeticOperator::Modulus) => {
385 self.compile_node(left)?;
386 self.compile_node(right)?;
387 Ok(self.emit(Opcode::Modulo))
388 }
389 Operator::Arithmetic(ArithmeticOperator::Power) => {
390 self.compile_node(left)?;
391 self.compile_node(right)?;
392 Ok(self.emit(Opcode::Exponent))
393 }
394 _ => Err(CompilerError::UnknownBinaryOperator {
395 operator: operator.to_string(),
396 }),
397 },
398 Node::FunctionCall { kind, arguments } => match kind {
399 FunctionKind::Internal(_) | FunctionKind::Deprecated(_) => {
400 let function = FunctionRegistry::get_definition(kind).ok_or_else(|| {
401 CompilerError::UnknownFunction {
402 name: kind.to_string(),
403 }
404 })?;
405
406 let min_params = function.required_parameters();
407 let max_params = min_params + function.optional_parameters();
408 if arguments.len() < min_params || arguments.len() > max_params {
409 return Err(CompilerError::InvalidFunctionCall {
410 name: kind.to_string(),
411 message: "Invalid number of arguments".to_string(),
412 });
413 }
414
415 for i in 0..arguments.len() {
416 self.compile_argument(kind, arguments, i)?;
417 }
418
419 Ok(self.emit(Opcode::CallFunction {
420 kind: kind.clone(),
421 arg_count: arguments.len() as u32,
422 }))
423 }
424 FunctionKind::Closure(c) => match c {
425 ClosureFunction::All => {
426 self.compile_argument(kind, arguments, 0)?;
427 self.emit(Opcode::Begin);
428 let mut loop_break: usize = 0;
429 self.emit_loop(|c| {
430 c.compile_argument(kind, arguments, 1)?;
431 loop_break = c.emit(Opcode::Jump(Jump::IfFalse, 0));
432 c.emit(Opcode::Pop);
433 Ok(())
434 })?;
435 let e = self.emit(Opcode::PushBool(true));
436 self.replace(
437 loop_break,
438 Opcode::Jump(Jump::IfFalse, (e - loop_break) as u32),
439 );
440 Ok(self.emit(Opcode::End))
441 }
442 ClosureFunction::None => {
443 self.compile_argument(kind, arguments, 0)?;
444 self.emit(Opcode::Begin);
445 let mut loop_break: usize = 0;
446 self.emit_loop(|c| {
447 c.compile_argument(kind, arguments, 1)?;
448 c.emit(Opcode::Not);
449 loop_break = c.emit(Opcode::Jump(Jump::IfFalse, 0));
450 c.emit(Opcode::Pop);
451 Ok(())
452 })?;
453 let e = self.emit(Opcode::PushBool(true));
454 self.replace(
455 loop_break,
456 Opcode::Jump(Jump::IfFalse, (e - loop_break) as u32),
457 );
458 Ok(self.emit(Opcode::End))
459 }
460 ClosureFunction::Some => {
461 self.compile_argument(kind, arguments, 0)?;
462 self.emit(Opcode::Begin);
463 let mut loop_break: usize = 0;
464 self.emit_loop(|c| {
465 c.compile_argument(kind, arguments, 1)?;
466 loop_break = c.emit(Opcode::Jump(Jump::IfTrue, 0));
467 c.emit(Opcode::Pop);
468 Ok(())
469 })?;
470 let e = self.emit(Opcode::PushBool(false));
471 self.replace(
472 loop_break,
473 Opcode::Jump(Jump::IfTrue, (e - loop_break) as u32),
474 );
475 Ok(self.emit(Opcode::End))
476 }
477 ClosureFunction::One => {
478 self.compile_argument(kind, arguments, 0)?;
479 self.emit(Opcode::Begin);
480 self.emit_loop(|c| {
481 c.compile_argument(kind, arguments, 1)?;
482 c.emit_cond(|c| {
483 c.emit(Opcode::IncrementCount);
484 });
485 Ok(())
486 })?;
487 self.emit(Opcode::GetCount);
488 self.emit(Opcode::PushNumber(dec!(1)));
489 self.emit(Opcode::Equal);
490 Ok(self.emit(Opcode::End))
491 }
492 ClosureFunction::Filter => {
493 self.compile_argument(kind, arguments, 0)?;
494 self.emit(Opcode::Begin);
495 self.emit_loop(|c| {
496 c.compile_argument(kind, arguments, 1)?;
497 c.emit_cond(|c| {
498 c.emit(Opcode::IncrementCount);
499 c.emit(Opcode::Pointer);
500 });
501 Ok(())
502 })?;
503 self.emit(Opcode::GetCount);
504 self.emit(Opcode::End);
505 Ok(self.emit(Opcode::Array))
506 }
507 ClosureFunction::Map => {
508 self.compile_argument(kind, arguments, 0)?;
509 self.emit(Opcode::Begin);
510 self.emit_loop(|c| {
511 c.compile_argument(kind, arguments, 1)?;
512 Ok(())
513 })?;
514 self.emit(Opcode::GetLen);
515 self.emit(Opcode::End);
516 Ok(self.emit(Opcode::Array))
517 }
518 ClosureFunction::FlatMap => {
519 self.compile_argument(kind, arguments, 0)?;
520 self.emit(Opcode::Begin);
521 self.emit_loop(|c| {
522 c.compile_argument(kind, arguments, 1)?;
523 Ok(())
524 })?;
525 self.emit(Opcode::GetLen);
526 self.emit(Opcode::End);
527 self.emit(Opcode::Array);
528 Ok(self.emit(Opcode::Flatten))
529 }
530 ClosureFunction::Count => {
531 self.compile_argument(kind, arguments, 0)?;
532 self.emit(Opcode::Begin);
533 self.emit_loop(|c| {
534 c.compile_argument(kind, arguments, 1)?;
535 c.emit_cond(|c| {
536 c.emit(Opcode::IncrementCount);
537 });
538 Ok(())
539 })?;
540 self.emit(Opcode::GetCount);
541 Ok(self.emit(Opcode::End))
542 }
543 },
544 },
545 Node::MethodCall {
546 kind,
547 this,
548 arguments,
549 } => {
550 let method = MethodRegistry::get_definition(kind).ok_or_else(|| {
551 CompilerError::UnknownFunction {
552 name: kind.to_string(),
553 }
554 })?;
555
556 self.compile_node(this)?;
557
558 let min_params = method.required_parameters() - 1;
559 let max_params = min_params + method.optional_parameters();
560 if arguments.len() < min_params || arguments.len() > max_params {
561 return Err(CompilerError::InvalidMethodCall {
562 name: kind.to_string(),
563 message: "Invalid number of arguments".to_string(),
564 });
565 }
566
567 for i in 0..arguments.len() {
568 self.compile_argument(kind, arguments, i)?;
569 }
570
571 Ok(self.emit(Opcode::CallMethod {
572 kind: kind.clone(),
573 arg_count: arguments.len() as u32,
574 }))
575 }
576 Node::Error { .. } => Err(CompilerError::UnexpectedErrorNode),
577 }
578 }
579}