1use crate::compiler::error::{CompilerError, CompilerResult};
2use crate::compiler::opcode::{FetchFastTarget, Jump};
3use crate::compiler::{Compare, Opcode};
4use crate::functions::registry::FunctionRegistry;
5use crate::functions::{ClosureFunction, FunctionKind, InternalFunction, MethodRegistry};
6use crate::lexer::{ArithmeticOperator, ComparisonOperator, LogicalOperator, Operator};
7use crate::parser::Node;
8use rust_decimal::prelude::ToPrimitive;
9use rust_decimal::Decimal;
10use rust_decimal_macros::dec;
11use std::sync::Arc;
12
13#[derive(Debug)]
14pub struct Compiler {
15 bytecode: Vec<Opcode>,
16}
17
18impl Compiler {
19 pub fn new() -> Self {
20 Self {
21 bytecode: Default::default(),
22 }
23 }
24
25 pub fn compile(&mut self, root: &Node) -> CompilerResult<&[Opcode]> {
26 self.bytecode.clear();
27
28 CompilerInner::new(&mut self.bytecode, root).compile()?;
29 Ok(self.bytecode.as_slice())
30 }
31
32 pub fn get_bytecode(&self) -> &[Opcode] {
33 self.bytecode.as_slice()
34 }
35}
36
37#[derive(Debug)]
38struct CompilerInner<'arena, 'bytecode_ref> {
39 root: &'arena Node<'arena>,
40 bytecode: &'bytecode_ref mut Vec<Opcode>,
41}
42
43impl<'arena, 'bytecode_ref> CompilerInner<'arena, 'bytecode_ref> {
44 pub fn new(bytecode: &'bytecode_ref mut Vec<Opcode>, root: &'arena Node<'arena>) -> Self {
45 Self { root, bytecode }
46 }
47
48 pub fn compile(&mut self) -> CompilerResult<()> {
49 self.compile_node(self.root)?;
50 Ok(())
51 }
52
53 fn emit(&mut self, op: Opcode) -> usize {
54 self.bytecode.push(op);
55 self.bytecode.len()
56 }
57
58 fn emit_loop<F>(&mut self, body: F) -> CompilerResult<()>
59 where
60 F: FnOnce(&mut Self) -> CompilerResult<()>,
61 {
62 let begin = self.bytecode.len();
63 let end = self.emit(Opcode::Jump(Jump::IfEnd, 0));
64
65 body(self)?;
66
67 self.emit(Opcode::IncrementIt);
68 let e = self.emit(Opcode::Jump(
69 Jump::Backward,
70 self.calc_backward_jump(begin) as u32,
71 ));
72 self.replace(end, Opcode::Jump(Jump::IfEnd, (e - end) as u32));
73 Ok(())
74 }
75
76 fn emit_cond<F>(&mut self, mut body: F)
77 where
78 F: FnMut(&mut Self),
79 {
80 let noop = self.emit(Opcode::Jump(Jump::IfFalse, 0));
81 self.emit(Opcode::Pop);
82
83 body(self);
84
85 let jmp = self.emit(Opcode::Jump(Jump::Forward, 0));
86 self.replace(noop, Opcode::Jump(Jump::IfFalse, (jmp - noop) as u32));
87 let e = self.emit(Opcode::Pop);
88 self.replace(jmp, Opcode::Jump(Jump::Forward, (e - jmp) as u32));
89 }
90
91 fn replace(&mut self, at: usize, op: Opcode) {
92 let _ = std::mem::replace(&mut self.bytecode[at - 1], op);
93 }
94
95 fn calc_backward_jump(&self, to: usize) -> usize {
96 self.bytecode.len() + 1 - to
97 }
98
99 fn compile_argument<T: ToString>(
100 &mut self,
101 function_kind: T,
102 arguments: &[&'arena Node<'arena>],
103 index: usize,
104 ) -> CompilerResult<usize> {
105 let arg = arguments
106 .get(index)
107 .ok_or_else(|| CompilerError::ArgumentNotFound {
108 index,
109 function: function_kind.to_string(),
110 })?;
111
112 self.compile_node(arg)
113 }
114
115 #[cfg_attr(feature = "stack-protection", recursive::recursive)]
116 fn compile_member_fast(&mut self, node: &'arena Node<'arena>) -> Option<Vec<FetchFastTarget>> {
117 match node {
118 Node::Root => Some(vec![FetchFastTarget::Root]),
119 Node::Identifier(v) => Some(vec![
120 FetchFastTarget::Root,
121 FetchFastTarget::String(Arc::from(*v)),
122 ]),
123 Node::Member { node, property } => {
124 let mut path = self.compile_member_fast(node)?;
125 match property {
126 Node::String(v) => {
127 path.push(FetchFastTarget::String(Arc::from(*v)));
128 Some(path)
129 }
130 Node::Number(v) => {
131 if let Some(idx) = v.to_u32() {
132 path.push(FetchFastTarget::Number(idx));
133 Some(path)
134 } else {
135 None
136 }
137 }
138 _ => None,
139 }
140 }
141 _ => None,
142 }
143 }
144
145 #[cfg_attr(feature = "stack-protection", recursive::recursive)]
146 fn compile_node(&mut self, node: &'arena Node<'arena>) -> CompilerResult<usize> {
147 match node {
148 Node::Null => Ok(self.emit(Opcode::PushNull)),
149 Node::Bool(v) => Ok(self.emit(Opcode::PushBool(*v))),
150 Node::Number(v) => Ok(self.emit(Opcode::PushNumber(*v))),
151 Node::String(v) => Ok(self.emit(Opcode::PushString(Arc::from(*v)))),
152 Node::Pointer => Ok(self.emit(Opcode::Pointer)),
153 Node::Root => Ok(self.emit(Opcode::FetchRootEnv)),
154 Node::Array(v) => {
155 v.iter()
156 .try_for_each(|&n| self.compile_node(n).map(|_| ()))?;
157 self.emit(Opcode::PushNumber(Decimal::from(v.len())));
158 Ok(self.emit(Opcode::Array))
159 }
160 Node::Object(v) => {
161 v.iter().try_for_each(|&(key, value)| {
162 self.compile_node(key).map(|_| ())?;
163 self.emit(Opcode::CallFunction {
164 arg_count: 1,
165 kind: FunctionKind::Internal(InternalFunction::String),
166 });
167 self.compile_node(value).map(|_| ())?;
168 Ok(())
169 })?;
170
171 self.emit(Opcode::PushNumber(Decimal::from(v.len())));
172 Ok(self.emit(Opcode::Object))
173 }
174 Node::Identifier(v) => Ok(self.emit(Opcode::FetchEnv(Arc::from(*v)))),
175 Node::Closure(v) => self.compile_node(v),
176 Node::Parenthesized(v) => self.compile_node(v),
177 Node::Member {
178 node: n,
179 property: p,
180 } => {
181 if let Some(path) = self.compile_member_fast(node) {
182 Ok(self.emit(Opcode::FetchFast(path)))
183 } else {
184 self.compile_node(n)?;
185 self.compile_node(p)?;
186 Ok(self.emit(Opcode::Fetch))
187 }
188 }
189 Node::TemplateString(parts) => {
190 parts.iter().try_for_each(|&n| {
191 self.compile_node(n).map(|_| ())?;
192 self.emit(Opcode::CallFunction {
193 arg_count: 1,
194 kind: FunctionKind::Internal(InternalFunction::String),
195 });
196 Ok(())
197 })?;
198
199 self.emit(Opcode::PushNumber(Decimal::from(parts.len())));
200 self.emit(Opcode::Array);
201 self.emit(Opcode::PushString(Arc::from("")));
202 Ok(self.emit(Opcode::Join))
203 }
204 Node::Slice { node, to, from } => {
205 self.compile_node(node)?;
206 if let Some(t) = to {
207 self.compile_node(t)?;
208 } else {
209 self.emit(Opcode::Len);
210 self.emit(Opcode::PushNumber(dec!(1)));
211 self.emit(Opcode::Subtract);
212 }
213
214 if let Some(f) = from {
215 self.compile_node(f)?;
216 } else {
217 self.emit(Opcode::PushNumber(dec!(0)));
218 }
219
220 Ok(self.emit(Opcode::Slice))
221 }
222 Node::Interval {
223 left,
224 right,
225 left_bracket,
226 right_bracket,
227 } => {
228 self.compile_node(left)?;
229 self.compile_node(right)?;
230 Ok(self.emit(Opcode::Interval {
231 left_bracket: *left_bracket,
232 right_bracket: *right_bracket,
233 }))
234 }
235 Node::Conditional {
236 condition,
237 on_true,
238 on_false,
239 } => {
240 self.compile_node(condition)?;
241 let otherwise = self.emit(Opcode::Jump(Jump::IfFalse, 0));
242
243 self.emit(Opcode::Pop);
244 self.compile_node(on_true)?;
245 let end = self.emit(Opcode::Jump(Jump::Forward, 0));
246
247 self.replace(
248 otherwise,
249 Opcode::Jump(Jump::IfFalse, (end - otherwise) as u32),
250 );
251 self.emit(Opcode::Pop);
252 let b = self.compile_node(on_false)?;
253 self.replace(end, Opcode::Jump(Jump::Forward, (b - end) as u32));
254
255 Ok(b)
256 }
257 Node::Unary { node, operator } => {
258 let curr = self.compile_node(node)?;
259 match *operator {
260 Operator::Arithmetic(ArithmeticOperator::Add) => Ok(curr),
261 Operator::Arithmetic(ArithmeticOperator::Subtract) => {
262 Ok(self.emit(Opcode::Negate))
263 }
264 Operator::Logical(LogicalOperator::Not) => Ok(self.emit(Opcode::Not)),
265 _ => Err(CompilerError::UnknownUnaryOperator {
266 operator: operator.to_string(),
267 }),
268 }
269 }
270 Node::Binary {
271 left,
272 right,
273 operator,
274 } => match *operator {
275 Operator::Comparison(ComparisonOperator::Equal) => {
276 self.compile_node(left)?;
277 self.compile_node(right)?;
278
279 Ok(self.emit(Opcode::Equal))
280 }
281 Operator::Comparison(ComparisonOperator::NotEqual) => {
282 self.compile_node(left)?;
283 self.compile_node(right)?;
284
285 self.emit(Opcode::Equal);
286 Ok(self.emit(Opcode::Not))
287 }
288 Operator::Logical(LogicalOperator::Or) => {
289 self.compile_node(left)?;
290 let end = self.emit(Opcode::Jump(Jump::IfTrue, 0));
291 self.emit(Opcode::Pop);
292 let r = self.compile_node(right)?;
293 self.replace(end, Opcode::Jump(Jump::IfTrue, (r - end) as u32));
294
295 Ok(r)
296 }
297 Operator::Logical(LogicalOperator::And) => {
298 self.compile_node(left)?;
299 let end = self.emit(Opcode::Jump(Jump::IfFalse, 0));
300 self.emit(Opcode::Pop);
301 let r = self.compile_node(right)?;
302 self.replace(end, Opcode::Jump(Jump::IfFalse, (r - end) as u32));
303
304 Ok(r)
305 }
306 Operator::Logical(LogicalOperator::NullishCoalescing) => {
307 self.compile_node(left)?;
308 let end = self.emit(Opcode::Jump(Jump::IfNotNull, 0));
309 self.emit(Opcode::Pop);
310 let r = self.compile_node(right)?;
311 self.replace(end, Opcode::Jump(Jump::IfNotNull, (r - end) as u32));
312
313 Ok(r)
314 }
315 Operator::Comparison(ComparisonOperator::In) => {
316 self.compile_node(left)?;
317 self.compile_node(right)?;
318 Ok(self.emit(Opcode::In))
319 }
320 Operator::Comparison(ComparisonOperator::NotIn) => {
321 self.compile_node(left)?;
322 self.compile_node(right)?;
323 self.emit(Opcode::In);
324 Ok(self.emit(Opcode::Not))
325 }
326 Operator::Comparison(ComparisonOperator::LessThan) => {
327 self.compile_node(left)?;
328 self.compile_node(right)?;
329 Ok(self.emit(Opcode::Compare(Compare::Less)))
330 }
331 Operator::Comparison(ComparisonOperator::LessThanOrEqual) => {
332 self.compile_node(left)?;
333 self.compile_node(right)?;
334 Ok(self.emit(Opcode::Compare(Compare::LessOrEqual)))
335 }
336 Operator::Comparison(ComparisonOperator::GreaterThan) => {
337 self.compile_node(left)?;
338 self.compile_node(right)?;
339 Ok(self.emit(Opcode::Compare(Compare::More)))
340 }
341 Operator::Comparison(ComparisonOperator::GreaterThanOrEqual) => {
342 self.compile_node(left)?;
343 self.compile_node(right)?;
344 Ok(self.emit(Opcode::Compare(Compare::MoreOrEqual)))
345 }
346 Operator::Arithmetic(ArithmeticOperator::Add) => {
347 self.compile_node(left)?;
348 self.compile_node(right)?;
349 Ok(self.emit(Opcode::Add))
350 }
351 Operator::Arithmetic(ArithmeticOperator::Subtract) => {
352 self.compile_node(left)?;
353 self.compile_node(right)?;
354 Ok(self.emit(Opcode::Subtract))
355 }
356 Operator::Arithmetic(ArithmeticOperator::Multiply) => {
357 self.compile_node(left)?;
358 self.compile_node(right)?;
359 Ok(self.emit(Opcode::Multiply))
360 }
361 Operator::Arithmetic(ArithmeticOperator::Divide) => {
362 self.compile_node(left)?;
363 self.compile_node(right)?;
364 Ok(self.emit(Opcode::Divide))
365 }
366 Operator::Arithmetic(ArithmeticOperator::Modulus) => {
367 self.compile_node(left)?;
368 self.compile_node(right)?;
369 Ok(self.emit(Opcode::Modulo))
370 }
371 Operator::Arithmetic(ArithmeticOperator::Power) => {
372 self.compile_node(left)?;
373 self.compile_node(right)?;
374 Ok(self.emit(Opcode::Exponent))
375 }
376 _ => Err(CompilerError::UnknownBinaryOperator {
377 operator: operator.to_string(),
378 }),
379 },
380 Node::FunctionCall { kind, arguments } => match kind {
381 FunctionKind::Internal(_) | FunctionKind::Deprecated(_) => {
382 let function = FunctionRegistry::get_definition(kind).ok_or_else(|| {
383 CompilerError::UnknownFunction {
384 name: kind.to_string(),
385 }
386 })?;
387
388 let min_params = function.required_parameters();
389 let max_params = min_params + function.optional_parameters();
390 if arguments.len() < min_params || arguments.len() > max_params {
391 return Err(CompilerError::InvalidFunctionCall {
392 name: kind.to_string(),
393 message: "Invalid number of arguments".to_string(),
394 });
395 }
396
397 for i in 0..arguments.len() {
398 self.compile_argument(kind, arguments, i)?;
399 }
400
401 Ok(self.emit(Opcode::CallFunction {
402 kind: kind.clone(),
403 arg_count: arguments.len() as u32,
404 }))
405 }
406 FunctionKind::Closure(c) => match c {
407 ClosureFunction::All => {
408 self.compile_argument(kind, arguments, 0)?;
409 self.emit(Opcode::Begin);
410 let mut loop_break: usize = 0;
411 self.emit_loop(|c| {
412 c.compile_argument(kind, arguments, 1)?;
413 loop_break = c.emit(Opcode::Jump(Jump::IfFalse, 0));
414 c.emit(Opcode::Pop);
415 Ok(())
416 })?;
417 let e = self.emit(Opcode::PushBool(true));
418 self.replace(
419 loop_break,
420 Opcode::Jump(Jump::IfFalse, (e - loop_break) as u32),
421 );
422 Ok(self.emit(Opcode::End))
423 }
424 ClosureFunction::None => {
425 self.compile_argument(kind, arguments, 0)?;
426 self.emit(Opcode::Begin);
427 let mut loop_break: usize = 0;
428 self.emit_loop(|c| {
429 c.compile_argument(kind, arguments, 1)?;
430 c.emit(Opcode::Not);
431 loop_break = c.emit(Opcode::Jump(Jump::IfFalse, 0));
432 c.emit(Opcode::Pop);
433 Ok(())
434 })?;
435 let e = self.emit(Opcode::PushBool(true));
436 self.replace(
437 loop_break,
438 Opcode::Jump(Jump::IfFalse, (e - loop_break) as u32),
439 );
440 Ok(self.emit(Opcode::End))
441 }
442 ClosureFunction::Some => {
443 self.compile_argument(kind, arguments, 0)?;
444 self.emit(Opcode::Begin);
445 let mut loop_break: usize = 0;
446 self.emit_loop(|c| {
447 c.compile_argument(kind, arguments, 1)?;
448 loop_break = c.emit(Opcode::Jump(Jump::IfTrue, 0));
449 c.emit(Opcode::Pop);
450 Ok(())
451 })?;
452 let e = self.emit(Opcode::PushBool(false));
453 self.replace(
454 loop_break,
455 Opcode::Jump(Jump::IfTrue, (e - loop_break) as u32),
456 );
457 Ok(self.emit(Opcode::End))
458 }
459 ClosureFunction::One => {
460 self.compile_argument(kind, arguments, 0)?;
461 self.emit(Opcode::Begin);
462 self.emit_loop(|c| {
463 c.compile_argument(kind, arguments, 1)?;
464 c.emit_cond(|c| {
465 c.emit(Opcode::IncrementCount);
466 });
467 Ok(())
468 })?;
469 self.emit(Opcode::GetCount);
470 self.emit(Opcode::PushNumber(dec!(1)));
471 self.emit(Opcode::Equal);
472 Ok(self.emit(Opcode::End))
473 }
474 ClosureFunction::Filter => {
475 self.compile_argument(kind, arguments, 0)?;
476 self.emit(Opcode::Begin);
477 self.emit_loop(|c| {
478 c.compile_argument(kind, arguments, 1)?;
479 c.emit_cond(|c| {
480 c.emit(Opcode::IncrementCount);
481 c.emit(Opcode::Pointer);
482 });
483 Ok(())
484 })?;
485 self.emit(Opcode::GetCount);
486 self.emit(Opcode::End);
487 Ok(self.emit(Opcode::Array))
488 }
489 ClosureFunction::Map => {
490 self.compile_argument(kind, arguments, 0)?;
491 self.emit(Opcode::Begin);
492 self.emit_loop(|c| {
493 c.compile_argument(kind, arguments, 1)?;
494 Ok(())
495 })?;
496 self.emit(Opcode::GetLen);
497 self.emit(Opcode::End);
498 Ok(self.emit(Opcode::Array))
499 }
500 ClosureFunction::FlatMap => {
501 self.compile_argument(kind, arguments, 0)?;
502 self.emit(Opcode::Begin);
503 self.emit_loop(|c| {
504 c.compile_argument(kind, arguments, 1)?;
505 Ok(())
506 })?;
507 self.emit(Opcode::GetLen);
508 self.emit(Opcode::End);
509 self.emit(Opcode::Array);
510 Ok(self.emit(Opcode::Flatten))
511 }
512 ClosureFunction::Count => {
513 self.compile_argument(kind, arguments, 0)?;
514 self.emit(Opcode::Begin);
515 self.emit_loop(|c| {
516 c.compile_argument(kind, arguments, 1)?;
517 c.emit_cond(|c| {
518 c.emit(Opcode::IncrementCount);
519 });
520 Ok(())
521 })?;
522 self.emit(Opcode::GetCount);
523 Ok(self.emit(Opcode::End))
524 }
525 },
526 },
527 Node::MethodCall {
528 kind,
529 this,
530 arguments,
531 } => {
532 let method = MethodRegistry::get_definition(kind).ok_or_else(|| {
533 CompilerError::UnknownFunction {
534 name: kind.to_string(),
535 }
536 })?;
537
538 self.compile_node(this)?;
539
540 let min_params = method.required_parameters() - 1;
541 let max_params = min_params + method.optional_parameters();
542 if arguments.len() < min_params || arguments.len() > max_params {
543 return Err(CompilerError::InvalidMethodCall {
544 name: kind.to_string(),
545 message: "Invalid number of arguments".to_string(),
546 });
547 }
548
549 for i in 0..arguments.len() {
550 self.compile_argument(kind, arguments, i)?;
551 }
552
553 Ok(self.emit(Opcode::CallMethod {
554 kind: kind.clone(),
555 arg_count: arguments.len() as u32,
556 }))
557 }
558 Node::Error { .. } => Err(CompilerError::UnexpectedErrorNode),
559 }
560 }
561}