simple_expressions/
parser.rs

1use crate::types::error::{Error, Result};
2use crate::types::expression::{BinaryOp, Expr, UnaryOp};
3use crate::types::value::Primitive;
4use chumsky::prelude::*;
5
6// Postfix operators: call, index, member
7#[derive(Debug, Clone)]
8enum Postfix {
9    Call(Vec<Expr>),
10    Index(Expr),
11    Member(String),
12}
13
14fn expr_and_spacer<'src>() -> (impl Parser<'src, &'src str, ()> + Clone, impl Parser<'src, &'src str, Expr> + Clone) {
15    // Whitespace and comments
16    let line_comment = just("//").ignore_then(any().filter(|c: &char| *c != '\n').repeated()).ignored();
17    let ws = one_of(" \t\r\n").repeated().at_least(1).ignored();
18    let spacer = choice((ws, line_comment)).repeated().ignored();
19
20    // Identifiers
21    let ident = text::ident().map(|s: &str| s.to_string());
22
23    // Strings: support single or double quotes with escapes \n, \\, \", \'
24    let escape = just('\\').ignore_then(choice((
25        just('n').to('\n'),
26        just('r').to('\r'),
27        just('t').to('\t'),
28        just('\\').to('\\'),
29        just('"').to('"'),
30        just('\'').to('\''),
31        // Allow escaping newline directly
32        just('\n').to('\n'),
33    )));
34
35    let string_sq = just('\'')
36        .ignore_then(choice((escape, any().filter(|c: &char| *c != '\\' && *c != '\'' && *c != '\n'))).repeated().collect::<String>())
37        .then_ignore(just('\''))
38        .map(Primitive::Str);
39
40    let string_dq = just('"')
41        .ignore_then(choice((escape, any().filter(|c: &char| *c != '\\' && *c != '"' && *c != '\n'))).repeated().collect::<String>())
42        .then_ignore(just('"'))
43        .map(Primitive::Str);
44
45    // Numbers: optional leading '-', digits, optional fractional part
46    let digits = text::digits(10);
47    let number = just('-')
48        .or_not()
49        .then(digits)
50        .then(just('.').then(digits).or_not())
51        .to_slice()
52        .map(|s: &str| if s.contains('.') { Primitive::Float(s.parse::<f64>().unwrap()) } else { Primitive::Int(s.parse::<i64>().unwrap()) });
53
54    let boolean = choice((text::keyword("true").to(Primitive::Bool(true)), text::keyword("false").to(Primitive::Bool(false))));
55
56    // Parentheses grouping will be handled via recursive expression parser
57    let expr = recursive(|expr| {
58        // Arguments list for calls
59        let args = expr
60            .clone()
61            .separated_by(just(',').padded_by(spacer))
62            .allow_trailing()
63            .collect::<Vec<_>>()
64            .delimited_by(just('(').padded_by(spacer), just(')').padded_by(spacer));
65
66        // List literal: [expr, expr, ...]
67        let list_lit = expr
68            .clone()
69            .separated_by(just(',').padded_by(spacer))
70            .allow_trailing()
71            .collect::<Vec<_>>()
72            .delimited_by(just('[').padded_by(spacer), just(']').padded_by(spacer))
73            .map(Expr::ListLiteral);
74
75        // Dict literal: {key_expr: value_expr, ...} where key_expr can be any expression; runtime enforces string keys
76        let dict_pair = expr.clone().then_ignore(just(':').padded_by(spacer)).then(expr.clone());
77        let dict_lit = dict_pair
78            .separated_by(just(',').padded_by(spacer))
79            .allow_trailing()
80            .collect::<Vec<_>>()
81            .delimited_by(just('{').padded_by(spacer), just('}').padded_by(spacer))
82            .map(Expr::DictLiteral);
83
84        // Primary: literals, identifiers as Var, parenthesized, list/dict literals
85        let primary = choice((
86            dict_lit,
87            list_lit,
88            choice((string_sq, string_dq, number, boolean.clone())).map(Expr::Literal),
89            ident.map(Expr::Var),
90            expr.clone().delimited_by(just('(').padded_by(spacer), just(')').padded_by(spacer)),
91        ))
92        .padded_by(spacer);
93
94        let index = expr.clone().delimited_by(just('[').padded_by(spacer), just(']').padded_by(spacer)).map(Postfix::Index);
95
96        let member = just('.').ignore_then(text::ident().map(|s: &str| s.to_string())).map(Postfix::Member);
97
98        let call = args.clone().map(Postfix::Call);
99
100        let postfix_chain = choice((call, index, member)).repeated().collect::<Vec<_>>();
101
102        let postfix = primary.then(postfix_chain).map(|(base, posts)| {
103            let mut acc = base;
104            for p in posts {
105                acc = match p {
106                    Postfix::Call(a) => Expr::Call { callee: Box::new(acc), args: a },
107                    Postfix::Index(i) => Expr::Index { object: Box::new(acc), index: Box::new(i) },
108                    Postfix::Member(f) => Expr::Member { object: Box::new(acc), field: f },
109                };
110            }
111            acc
112        });
113
114        // Unary '!'
115        let unary = just('!').repeated().foldr(postfix, |_bang, rhs| Expr::Unary { op: UnaryOp::Not, expr: Box::new(rhs) });
116
117        // Exponentiation '^' (right-assoc) using recursion
118        let pow = recursive(|pow| {
119            unary.clone().then(just('^').padded_by(spacer).ignore_then(pow).or_not()).map(|(lhs, rhs)| match rhs {
120                Some(r) => Expr::Binary {
121                    op: BinaryOp::Pow,
122                    left: Box::new(lhs),
123                    right: Box::new(r),
124                },
125                None => lhs,
126            })
127        });
128
129        let mul_div_mod = pow.clone().foldl(
130            choice((just('*').to(BinaryOp::Mul), just('/').to(BinaryOp::Div), just('%').to(BinaryOp::Mod))).padded_by(spacer).then(pow).repeated(),
131            |lhs, (op, rhs)| Expr::Binary {
132                op,
133                left: Box::new(lhs),
134                right: Box::new(rhs),
135            },
136        );
137
138        let add_sub = mul_div_mod
139            .clone()
140            .foldl(choice((just('+').to(BinaryOp::Add), just('-').to(BinaryOp::Sub))).padded_by(spacer).then(mul_div_mod).repeated(), |lhs, (op, rhs)| {
141                Expr::Binary {
142                    op,
143                    left: Box::new(lhs),
144                    right: Box::new(rhs),
145                }
146            });
147
148        let cmp = add_sub.clone().foldl(
149            choice((just("<=").to(BinaryOp::Le), just(">=").to(BinaryOp::Ge), just('<').to(BinaryOp::Lt), just('>').to(BinaryOp::Gt)))
150                .padded_by(spacer)
151                .then(add_sub)
152                .repeated(),
153            |lhs, (op, rhs)| Expr::Binary {
154                op,
155                left: Box::new(lhs),
156                right: Box::new(rhs),
157            },
158        );
159
160        let eq = cmp
161            .clone()
162            .foldl(choice((just("==").to(BinaryOp::Eq), just("!=").to(BinaryOp::Ne))).padded_by(spacer).then(cmp).repeated(), |lhs, (op, rhs)| Expr::Binary {
163                op,
164                left: Box::new(lhs),
165                right: Box::new(rhs),
166            });
167
168        let and = eq.clone().foldl(just("&&").to(BinaryOp::And).padded_by(spacer).then(eq).repeated(), |lhs, (op, rhs)| Expr::Binary {
169            op,
170            left: Box::new(lhs),
171            right: Box::new(rhs),
172        });
173
174        let or = and.clone().foldl(just("||").to(BinaryOp::Or).padded_by(spacer).then(and).repeated(), |lhs, (op, rhs)| Expr::Binary {
175            op,
176            left: Box::new(lhs),
177            right: Box::new(rhs),
178        });
179
180        or.padded_by(spacer)
181    });
182
183    (spacer, expr)
184}
185
186pub fn parser<'src>() -> impl Parser<'src, &'src str, Expr> {
187    let (spacer, expr) = expr_and_spacer();
188
189    // Allow multiple expressions separated by whitespace/comments and take the last one
190    let program = spacer
191        .clone()
192        .ignore_then(expr.clone())
193        .then_ignore(spacer.clone())
194        .repeated()
195        .at_least(1)
196        .collect::<Vec<_>>()
197        .map(|mut v: Vec<Expr>| v.pop().unwrap());
198
199    program.then_ignore(end())
200}
201
202// Main entry point for parsing an expression, returns the AST (Expr) or Error
203pub fn parse(input: &str) -> Result<Expr> {
204    match parser().parse(input).into_result() {
205        Ok(ast) => Ok(ast),
206        Err(errs) => {
207            let joined = errs.into_iter().map(|e| e.to_string()).collect::<Vec<_>>().join("\n");
208            let snippet: String = input.chars().take(80).collect();
209            let msg = if joined.trim().is_empty() { "parse error".to_string() } else { joined };
210            Err(Error::ParseFailed(msg, snippet))
211        }
212    }
213}
214
215// Parse an expression that must be terminated by a closing '}' and return
216// the parsed Expr along with the number of bytes consumed (including the '}').
217pub(crate) fn parse_in_braces(input: &str) -> Result<(Expr, usize)> {
218    // Use the existing expression parser and require a trailing '}' using parser combinators.
219    // This leverages the parser's own handling of strings, escapes, and nesting instead of manual scanning.
220    let (_spacer, expr) = expr_and_spacer();
221    let p = expr
222        .clone()
223        .then_ignore(just('}'))
224        .map_with(|e, extra| {
225            let span = extra.span();
226            (e, span.end)
227        })
228        .then_ignore(any().repeated()); // allow trailing content after '}' so callers can continue scanning
229
230    match p.parse(input).into_result() {
231        Ok((ast, consumed)) => Ok((ast, consumed)),
232        Err(errs) => {
233            let joined = errs.into_iter().map(|e| e.to_string()).collect::<Vec<_>>().join("\n");
234            let snippet: String = input.chars().take(80).collect();
235            let msg = if joined.trim().is_empty() { "parse error".to_string() } else { joined };
236            Err(Error::ParseFailed(msg, snippet))
237        }
238    }
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_literals() {
        assert_eq!(parse("123").unwrap(), Expr::Literal(Primitive::Int(123)));
        assert_eq!(parse("-42").unwrap(), Expr::Literal(Primitive::Int(-42)));
        assert_eq!(parse("3.14").unwrap(), Expr::Literal(Primitive::Float(3.14)));
        assert_eq!(parse("true").unwrap(), Expr::Literal(Primitive::Bool(true)));
        assert_eq!(parse("false").unwrap(), Expr::Literal(Primitive::Bool(false)));
        assert_eq!(parse("'hi'").unwrap(), Expr::Literal(Primitive::Str("hi".into())));
        // Multiple expressions: the last one is the program's value.
        assert_eq!(parse("'hi'\n// comment\n\"ok\"").unwrap(), Expr::Literal(Primitive::Str("ok".into())));
    }

    #[test]
    fn parse_binary() {
        use Expr::*;
        // `*` binds tighter than `+`: 1 + (2 * 3)
        let ast = parse("1 + 2 * 3").unwrap();
        let expected = Binary {
            op: BinaryOp::Add,
            left: Box::new(Literal(Primitive::Int(1))),
            right: Box::new(Binary {
                op: BinaryOp::Mul,
                left: Box::new(Literal(Primitive::Int(2))),
                right: Box::new(Literal(Primitive::Int(3))),
            }),
        };
        assert_eq!(ast, expected);
    }

    #[test]
    fn parse_boolean() {
        use Expr::*;
        // `!` binds tighter than `&&`, which binds tighter than `||`:
        // ((!true) && false) || true
        let ast = parse("!true && false || true").unwrap();
        let expected = Binary {
            op: BinaryOp::Or,
            left: Box::new(Binary {
                op: BinaryOp::And,
                left: Box::new(Unary {
                    op: UnaryOp::Not,
                    expr: Box::new(Literal(Primitive::Bool(true))),
                }),
                right: Box::new(Literal(Primitive::Bool(false))),
            }),
            right: Box::new(Literal(Primitive::Bool(true))),
        };
        assert_eq!(ast, expected);

        // Keywords only match whole identifiers.
        assert_eq!(parse("falsey").unwrap(), Var("falsey".into()));
    }

    #[test]
    fn parse_calls_and_paths() {
        let ast = parse("foo.bar(baz, 1+2).qux").unwrap();
        // Just check it parses
        let _ = ast;
    }

    #[test]
    fn list_literals() {
        use Expr::*;
        assert_eq!(parse("[1, 2, 3]").unwrap(), ListLiteral(vec![Literal(Primitive::Int(1)), Literal(Primitive::Int(2)), Literal(Primitive::Int(3))]));
        assert_eq!(parse("[1,2,]").unwrap(), ListLiteral(vec![Literal(Primitive::Int(1)), Literal(Primitive::Int(2))]));
        assert_eq!(
            parse("[1, [2], 3]").unwrap(),
            ListLiteral(vec![Literal(Primitive::Int(1)), ListLiteral(vec![Literal(Primitive::Int(2))]), Literal(Primitive::Int(3)),])
        );
    }

    #[test]
    fn dict_literals() {
        use Expr::*;
        let d = parse("{\"a\": 1, \"b\": 2}").unwrap();
        assert_eq!(
            d,
            DictLiteral(vec![(Literal(Primitive::Str("a".into())), Literal(Primitive::Int(1))), (Literal(Primitive::Str("b".into())), Literal(Primitive::Int(2))),])
        );
        let d2 = parse("{\"a\": 1,}").unwrap();
        assert_eq!(d2, DictLiteral(vec![(Literal(Primitive::Str("a".into())), Literal(Primitive::Int(1)))]));

        // allow non-string keys at parse-time; runtime will enforce type
        assert_eq!(parse("{a: 1}").unwrap(), DictLiteral(vec![(Var("a".into()), Literal(Primitive::Int(1)))]));
        assert_eq!(parse("{1: 2}").unwrap(), DictLiteral(vec![(Literal(Primitive::Int(1)), Literal(Primitive::Int(2)))]));

        // still reject malformed missing colon
        match parse("{\"a\" 1}") {
            Err(Error::ParseFailed(_, _)) => (),
            other => panic!("expected parse failed, got {:?}", other),
        }
    }

    #[test]
    fn postfix_combinations() {
        use Expr::*;
        let ast = parse("{\"xs\": [10, 20]}[\"xs\"][1]").unwrap();
        // Expect Index(Index(DictLiteral, "xs"), 1)
        match ast {
            Index { object, index } => {
                assert_eq!(*index, Literal(Primitive::Int(1)));
                match *object {
                    Index { object: inner_obj, index: inner_idx } => {
                        match *inner_idx {
                            Literal(Primitive::Str(ref s)) => assert_eq!(s, "xs"),
                            _ => panic!("inner index should be string literal"),
                        }
                        match *inner_obj {
                            DictLiteral(_) => (),
                            _ => panic!("inner object should be dict literal"),
                        }
                    }
                    _ => panic!("outer object should be index"),
                }
            }
            _ => panic!("bad ast shape: {:?}", ast),
        }

        let ast2 = parse("{\"a\": 1}.a").unwrap();
        match ast2 {
            Member { ref object, ref field } => {
                assert_eq!(field, "a");
                match **object {
                    DictLiteral(_) => (),
                    _ => panic!("member base should be dict literal"),
                }
            }
            _ => panic!("bad ast shape"),
        }

        let ast3 = parse("{\"a\": [1,2,3]}.a[0]").unwrap();
        match ast3 {
            Index { object, index } => {
                assert_eq!(*index, Literal(Primitive::Int(0)));
                match *object {
                    Member { object: base, field } => {
                        assert_eq!(field, "a");
                        match *base {
                            DictLiteral(_) => (),
                            _ => panic!("member base should be dict literal"),
                        }
                    }
                    _ => panic!("expected member then index"),
                }
            }
            _ => panic!("bad ast shape"),
        }
    }

    #[test]
    fn postfix_member_chain_var() {
        let ast = parse("a.b.c").unwrap();
        let expected = Expr::Member {
            object: Box::new(Expr::Member {
                object: Box::new(Expr::Var("a".into())),
                field: "b".into(),
            }),
            field: "c".into(),
        };
        assert_eq!(ast, expected);
    }

    #[test]
    fn postfix_mixed_chain_shapes() {
        let ast = parse("a.b(1, 2).c[0].d(e)").unwrap();
        // a.b(1,2).c[0].d(e)
        // Verify outermost is Call(...)
        match ast {
            Expr::Call { ref callee, ref args } => {
                assert_eq!(args.len(), 1);
                // callee should be Member(..., field: "d")
                match **callee {
                    Expr::Member { ref object, ref field } => {
                        assert_eq!(field, "d");
                        // object should be Index(Member(Call(Member(Var("a"),"b"), [1,2]), field:"c"), index:0)
                        match **object {
                            Expr::Index { ref object, ref index } => {
                                // index == 0
                                assert_eq!(**index, Expr::Literal(Primitive::Int(0)));
                                // object is Member(..., "c")
                                match **object {
                                    Expr::Member { ref object, ref field } => {
                                        assert_eq!(field, "c");
                                        // object is Call(Member(Var("a"),"b"), [1,2])
                                        match **object {
                                            Expr::Call { ref callee, ref args } => {
                                                assert_eq!(args.len(), 2);
                                                assert_eq!(args[0], Expr::Literal(Primitive::Int(1)));
                                                // second arg is 2
                                                assert_eq!(args[1], Expr::Literal(Primitive::Int(2)));
                                                // callee: Member(Var("a"),"b")
                                                let expected_callee = Expr::Member {
                                                    object: Box::new(Expr::Var("a".into())),
                                                    field: "b".into(),
                                                };
                                                assert_eq!(**callee, expected_callee);
                                            }
                                            _ => panic!("expected call before .c"),
                                        }
                                    }
                                    _ => panic!("expected member .c"),
                                }
                            }
                            _ => panic!("expected index [0]"),
                        }
                    }
                    _ => panic!("expected outer member .d"),
                }
            }
            _ => panic!("expected outer call (e)"),
        }
    }

    #[test]
    fn postfix_parenthesized_base() {
        // (a + b).c(d)
        let ast = parse("(a + b).c(d)").unwrap();
        match ast {
            Expr::Call { callee, args } => {
                assert_eq!(args.len(), 1);
                match *callee {
                    Expr::Member { object, field } => {
                        assert_eq!(field, "c");
                        match *object {
                            Expr::Binary { op: BinaryOp::Add, .. } => (),
                            _ => panic!("member should be applied to parenthesized binary expr"),
                        }
                    }
                    _ => panic!("expected member callee"),
                }
            }
            _ => panic!("expected call"),
        }
    }

    #[test]
    fn postfix_call_chains() {
        let ast = parse("foo(1)(2)(3)").unwrap();
        // foo(1)(2)(3) => Call(Call(Call(Var("foo"),1),2),3)
        fn is_int(e: &Expr, v: i64) -> bool {
            *e == Expr::Literal(Primitive::Int(v))
        }
        match ast {
            Expr::Call { callee: c3, args: a3 } => {
                assert!(a3.len() == 1 && is_int(&a3[0], 3));
                match *c3 {
                    Expr::Call { callee: c2, args: a2 } => {
                        assert!(a2.len() == 1 && is_int(&a2[0], 2));
                        match *c2 {
                            Expr::Call { callee: c1, args: a1 } => {
                                assert!(a1.len() == 1 && is_int(&a1[0], 1));
                                assert_eq!(*c1, Expr::Var("foo".into()));
                            }
                            _ => panic!("expected second call"),
                        }
                    }
                    _ => panic!("expected first call"),
                }
            }
            _ => panic!("expected outer call"),
        }
    }

    #[test]
    fn postfix_index_chains() {
        let ast = parse("arr[1+2][0]").unwrap();
        match ast {
            Expr::Index { object, index } => {
                assert_eq!(*index, Expr::Literal(Primitive::Int(0)));
                match *object {
                    Expr::Index { object, index } => {
                        match *index {
                            Expr::Binary { op: BinaryOp::Add, .. } => (),
                            _ => panic!("expected 1+2 as index expr"),
                        }
                        assert_eq!(*object, Expr::Var("arr".into()));
                    }
                    _ => panic!("expected inner index"),
                }
            }
            _ => panic!("expected outer index"),
        }
    }

    #[test]
    fn precedence_with_postfix_vs_add() {
        let ast = parse("a.b + c.d").unwrap();
        match ast {
            Expr::Binary { op: BinaryOp::Add, left, right } => {
                // both sides should be postfix chains
                match *left {
                    Expr::Member { object, field } => {
                        assert_eq!(field, "b");
                        assert_eq!(*object, Expr::Var("a".into()));
                    }
                    _ => panic!("left not a member"),
                }
                match *right {
                    Expr::Member { object, field } => {
                        assert_eq!(field, "d");
                        assert_eq!(*object, Expr::Var("c".into()));
                    }
                    _ => panic!("right not a member"),
                }
            }
            _ => panic!("top not add"),
        }
    }

    #[test]
    fn precedence_with_postfix_and_logical() {
        let ast = parse("a.b(c) && d.e").unwrap();
        match ast {
            Expr::Binary { op: BinaryOp::And, left, right } => {
                match *left {
                    Expr::Call { callee, args } => {
                        assert_eq!(args.len(), 1);
                        assert_eq!(args[0], Expr::Var("c".into()));
                        match *callee {
                            Expr::Member { object, field } => {
                                assert_eq!(field, "b");
                                assert_eq!(*object, Expr::Var("a".into()));
                            }
                            _ => panic!("left not call(member)"),
                        }
                    }
                    _ => panic!("left not call"),
                }
                match *right {
                    Expr::Member { object, field } => {
                        assert_eq!(field, "e");
                        assert_eq!(*object, Expr::Var("d".into()));
                    }
                    _ => panic!("right not member"),
                }
            }
            _ => panic!("top not &&"),
        }
    }

    #[test]
    fn parse_in_braces_allows_suffix() {
        let res = parse_in_braces("'A'}-");
        assert!(res.is_ok(), "parse_in_braces failed: {:?}", res);
        let (expr, consumed) = res.unwrap();
        assert_eq!(consumed, 4, "consumed should include string and '}}'");
        assert_eq!(expr, Expr::Literal(Primitive::Str("A".into())));
    }
}