Skip to main content

haystack_core/expr/
parser.rs

1//! Hand-written recursive descent parser for expressions.
2
3use super::ast::*;
4use crate::kinds::{Kind, Number};
5
/// Error produced when parsing an expression fails.
///
/// Derives `Clone`, `PartialEq` and `Eq` so callers can keep copies of
/// diagnostics and compare them directly (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExprError {
    /// Human-readable error message.
    pub msg: String,
    /// Byte position in the source where the error occurred.
    pub pos: usize,
}

impl std::fmt::Display for ExprError {
    /// Renders the error as `expr error at <pos>: <msg>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "expr error at {}: {}", self.pos, self.msg)
    }
}

impl std::error::Error for ExprError {}
22
23/// Internal parser state.
24struct Parser<'a> {
25    source: &'a str,
26    pos: usize,
27    depth: usize,
28}
29
30impl<'a> Parser<'a> {
31    fn new(source: &'a str) -> Self {
32        Self {
33            source,
34            pos: 0,
35            depth: 0,
36        }
37    }
38
39    fn err(&self, msg: impl Into<String>) -> ExprError {
40        ExprError {
41            msg: msg.into(),
42            pos: self.pos,
43        }
44    }
45
46    fn enter(&mut self) -> Result<(), ExprError> {
47        self.depth += 1;
48        if self.depth > MAX_EXPR_DEPTH {
49            Err(self.err("expression exceeds maximum nesting depth"))
50        } else {
51            Ok(())
52        }
53    }
54
55    fn leave(&mut self) {
56        self.depth -= 1;
57    }
58
59    fn skip_ws(&mut self) {
60        while self.pos < self.source.len() {
61            let b = self.source.as_bytes()[self.pos];
62            if b == b' ' || b == b'\t' || b == b'\n' || b == b'\r' {
63                self.pos += 1;
64            } else {
65                break;
66            }
67        }
68    }
69
70    fn at_end(&self) -> bool {
71        self.pos >= self.source.len()
72    }
73
74    fn peek(&self) -> Option<u8> {
75        if self.pos < self.source.len() {
76            Some(self.source.as_bytes()[self.pos])
77        } else {
78            None
79        }
80    }
81
82    fn consume(&mut self, ch: u8) -> Result<(), ExprError> {
83        self.skip_ws();
84        if self.peek() == Some(ch) {
85            self.pos += 1;
86            Ok(())
87        } else {
88            Err(self.err(format!("expected '{}'", ch as char)))
89        }
90    }
91
92    fn starts_with(&self, s: &str) -> bool {
93        self.source[self.pos..].starts_with(s)
94    }
95
96    /// Check if identifier keyword `kw` starts at current position and is
97    /// followed by a non-identifier character.
98    fn keyword(&self, kw: &str) -> bool {
99        if !self.starts_with(kw) {
100            return false;
101        }
102        let after = self.pos + kw.len();
103        if after >= self.source.len() {
104            return true;
105        }
106        let b = self.source.as_bytes()[after];
107        !b.is_ascii_alphanumeric() && b != b'_'
108    }
109
110    fn consume_keyword(&mut self, kw: &str) -> Result<(), ExprError> {
111        self.skip_ws();
112        if self.keyword(kw) {
113            self.pos += kw.len();
114            Ok(())
115        } else {
116            Err(self.err(format!("expected '{kw}'")))
117        }
118    }
119
120    fn read_ident(&mut self) -> Result<String, ExprError> {
121        self.skip_ws();
122        let start = self.pos;
123        while self.pos < self.source.len() {
124            let b = self.source.as_bytes()[self.pos];
125            if b.is_ascii_alphanumeric() || b == b'_' {
126                self.pos += 1;
127            } else {
128                break;
129            }
130        }
131        if self.pos == start {
132            return Err(self.err("expected identifier"));
133        }
134        Ok(self.source[start..self.pos].to_string())
135    }
136
137    // ── Grammar rules ──────────────────────────────────────────────
138
139    fn parse_expr(&mut self) -> Result<ExprNode, ExprError> {
140        self.enter()?;
141        let node = self.parse_logic_or()?;
142        self.leave();
143        Ok(node)
144    }
145
146    fn parse_logic_or(&mut self) -> Result<ExprNode, ExprError> {
147        let mut left = self.parse_logic_and()?;
148        loop {
149            self.skip_ws();
150            if self.keyword("or") {
151                self.pos += 2;
152                let right = self.parse_logic_and()?;
153                left = ExprNode::Logical {
154                    left: Box::new(left),
155                    op: LogicOp::Or,
156                    right: Box::new(right),
157                };
158            } else {
159                break;
160            }
161        }
162        Ok(left)
163    }
164
165    fn parse_logic_and(&mut self) -> Result<ExprNode, ExprError> {
166        let mut left = self.parse_comparison()?;
167        loop {
168            self.skip_ws();
169            if self.keyword("and") {
170                self.pos += 3;
171                let right = self.parse_comparison()?;
172                left = ExprNode::Logical {
173                    left: Box::new(left),
174                    op: LogicOp::And,
175                    right: Box::new(right),
176                };
177            } else {
178                break;
179            }
180        }
181        Ok(left)
182    }
183
184    fn parse_comparison(&mut self) -> Result<ExprNode, ExprError> {
185        let left = self.parse_additive()?;
186        self.skip_ws();
187        let op = if self.starts_with("!=") {
188            self.pos += 2;
189            Some(CmpOp::Ne)
190        } else if self.starts_with("==") {
191            self.pos += 2;
192            Some(CmpOp::Eq)
193        } else if self.starts_with("<=") {
194            self.pos += 2;
195            Some(CmpOp::Le)
196        } else if self.starts_with(">=") {
197            self.pos += 2;
198            Some(CmpOp::Ge)
199        } else if self.peek() == Some(b'<') {
200            self.pos += 1;
201            Some(CmpOp::Lt)
202        } else if self.peek() == Some(b'>') {
203            self.pos += 1;
204            Some(CmpOp::Gt)
205        } else {
206            None
207        };
208        if let Some(op) = op {
209            let right = self.parse_additive()?;
210            Ok(ExprNode::Comparison {
211                left: Box::new(left),
212                op,
213                right: Box::new(right),
214            })
215        } else {
216            Ok(left)
217        }
218    }
219
220    fn parse_additive(&mut self) -> Result<ExprNode, ExprError> {
221        let mut left = self.parse_multiplicative()?;
222        loop {
223            self.skip_ws();
224            match self.peek() {
225                Some(b'+') => {
226                    self.pos += 1;
227                    let right = self.parse_multiplicative()?;
228                    left = ExprNode::BinaryOp {
229                        left: Box::new(left),
230                        op: BinOp::Add,
231                        right: Box::new(right),
232                    };
233                }
234                Some(b'-') => {
235                    self.pos += 1;
236                    let right = self.parse_multiplicative()?;
237                    left = ExprNode::BinaryOp {
238                        left: Box::new(left),
239                        op: BinOp::Sub,
240                        right: Box::new(right),
241                    };
242                }
243                _ => break,
244            }
245        }
246        Ok(left)
247    }
248
249    fn parse_multiplicative(&mut self) -> Result<ExprNode, ExprError> {
250        let mut left = self.parse_unary()?;
251        loop {
252            self.skip_ws();
253            match self.peek() {
254                Some(b'*') => {
255                    self.pos += 1;
256                    let right = self.parse_unary()?;
257                    left = ExprNode::BinaryOp {
258                        left: Box::new(left),
259                        op: BinOp::Mul,
260                        right: Box::new(right),
261                    };
262                }
263                Some(b'/') => {
264                    self.pos += 1;
265                    let right = self.parse_unary()?;
266                    left = ExprNode::BinaryOp {
267                        left: Box::new(left),
268                        op: BinOp::Div,
269                        right: Box::new(right),
270                    };
271                }
272                Some(b'%') => {
273                    self.pos += 1;
274                    let right = self.parse_unary()?;
275                    left = ExprNode::BinaryOp {
276                        left: Box::new(left),
277                        op: BinOp::Mod,
278                        right: Box::new(right),
279                    };
280                }
281                _ => break,
282            }
283        }
284        Ok(left)
285    }
286
287    fn parse_unary(&mut self) -> Result<ExprNode, ExprError> {
288        self.skip_ws();
289        if self.peek() == Some(b'-') {
290            self.pos += 1;
291            let operand = self.parse_unary()?;
292            return Ok(ExprNode::UnaryOp {
293                op: UnOp::Neg,
294                operand: Box::new(operand),
295            });
296        }
297        if self.peek() == Some(b'!') {
298            self.pos += 1;
299            let operand = self.parse_unary()?;
300            return Ok(ExprNode::UnaryOp {
301                op: UnOp::Not,
302                operand: Box::new(operand),
303            });
304        }
305        if self.keyword("not") {
306            self.pos += 3;
307            let operand = self.parse_unary()?;
308            return Ok(ExprNode::UnaryOp {
309                op: UnOp::Not,
310                operand: Box::new(operand),
311            });
312        }
313        self.parse_call()
314    }
315
316    fn parse_call(&mut self) -> Result<ExprNode, ExprError> {
317        self.skip_ws();
318        let start = self.pos;
319
320        // Try to parse an identifier that could be a function name.
321        // We need to check if the next char starts an identifier (alpha/underscore)
322        // and is NOT a keyword (true/false/null/if/not) followed by '('.
323        if self.pos < self.source.len() {
324            let b = self.source.as_bytes()[self.pos];
325            if (b.is_ascii_alphabetic() || b == b'_')
326                && !self.keyword("true")
327                && !self.keyword("false")
328                && !self.keyword("null")
329                && !self.keyword("if")
330                && !self.keyword("not")
331            {
332                let name = self.read_ident()?;
333                self.skip_ws();
334                if self.peek() == Some(b'(') {
335                    self.pos += 1;
336                    let mut args = Vec::new();
337                    self.skip_ws();
338                    if self.peek() != Some(b')') {
339                        args.push(self.parse_expr()?);
340                        loop {
341                            self.skip_ws();
342                            if self.peek() == Some(b',') {
343                                self.pos += 1;
344                                args.push(self.parse_expr()?);
345                            } else {
346                                break;
347                            }
348                        }
349                    }
350                    self.consume(b')')?;
351                    return Ok(ExprNode::FnCall { name, args });
352                }
353                // Not a function call — backtrack.
354                self.pos = start;
355            }
356        }
357
358        self.parse_primary()
359    }
360
361    fn parse_primary(&mut self) -> Result<ExprNode, ExprError> {
362        self.skip_ws();
363
364        if self.at_end() {
365            return Err(self.err("unexpected end of expression"));
366        }
367
368        // Number literal
369        let b = self.source.as_bytes()[self.pos];
370        if b.is_ascii_digit()
371            || (b == b'.'
372                && self.pos + 1 < self.source.len()
373                && self.source.as_bytes()[self.pos + 1].is_ascii_digit())
374        {
375            return self.parse_number();
376        }
377
378        // String literal
379        if b == b'"' {
380            return self.parse_string();
381        }
382
383        // true / false
384        if self.keyword("true") {
385            self.pos += 4;
386            return Ok(ExprNode::Literal(Kind::Bool(true)));
387        }
388        if self.keyword("false") {
389            self.pos += 5;
390            return Ok(ExprNode::Literal(Kind::Bool(false)));
391        }
392
393        // null
394        if self.keyword("null") {
395            self.pos += 4;
396            return Ok(ExprNode::Literal(Kind::Null));
397        }
398
399        // Variable: $ident
400        if b == b'$' {
401            self.pos += 1;
402            let name = self.read_ident()?;
403            return Ok(ExprNode::Variable(name));
404        }
405
406        // Parenthesised expression
407        if b == b'(' {
408            self.pos += 1;
409            let node = self.parse_expr()?;
410            self.consume(b')')?;
411            return Ok(node);
412        }
413
414        // Conditional: if expr then expr else expr
415        if self.keyword("if") {
416            self.pos += 2;
417            let cond = self.parse_expr()?;
418            self.consume_keyword("then")?;
419            let then_expr = self.parse_expr()?;
420            self.consume_keyword("else")?;
421            let else_expr = self.parse_expr()?;
422            return Ok(ExprNode::Conditional {
423                cond: Box::new(cond),
424                then_expr: Box::new(then_expr),
425                else_expr: Box::new(else_expr),
426            });
427        }
428
429        Err(self.err(format!("unexpected character '{}'", b as char)))
430    }
431
432    fn parse_number(&mut self) -> Result<ExprNode, ExprError> {
433        let start = self.pos;
434        while self.pos < self.source.len() && self.source.as_bytes()[self.pos].is_ascii_digit() {
435            self.pos += 1;
436        }
437        if self.pos < self.source.len() && self.source.as_bytes()[self.pos] == b'.' {
438            self.pos += 1;
439            while self.pos < self.source.len() && self.source.as_bytes()[self.pos].is_ascii_digit()
440            {
441                self.pos += 1;
442            }
443        }
444        let s = &self.source[start..self.pos];
445        let val: f64 = s
446            .parse()
447            .map_err(|_| self.err(format!("invalid number '{s}'")))?;
448        Ok(ExprNode::Literal(Kind::Number(Number::unitless(val))))
449    }
450
451    fn parse_string(&mut self) -> Result<ExprNode, ExprError> {
452        self.pos += 1; // skip opening "
453        let start = self.pos;
454        while self.pos < self.source.len() && self.source.as_bytes()[self.pos] != b'"' {
455            if self.source.as_bytes()[self.pos] == b'\\' {
456                self.pos += 1;
457                if self.pos >= self.source.len() {
458                    return Err(self.err("unterminated string escape"));
459                }
460                match self.source.as_bytes()[self.pos] {
461                    b'"' | b'\\' | b'n' | b't' | b'r' => {}
462                    ch => {
463                        return Err(self.err(format!("invalid escape sequence: \\{}", ch as char)));
464                    }
465                }
466            }
467            self.pos += 1;
468        }
469        if self.pos >= self.source.len() {
470            return Err(self.err("unterminated string"));
471        }
472        let s = self.source[start..self.pos].to_string();
473        self.pos += 1; // skip closing "
474        Ok(ExprNode::Literal(Kind::Str(s)))
475    }
476}
477
478/// Parse an expression source string into an AST.
479pub fn parse_expr(source: &str) -> Result<ExprNode, ExprError> {
480    if source.len() > MAX_EXPR_SOURCE {
481        return Err(ExprError {
482            msg: format!("expression source exceeds maximum length of {MAX_EXPR_SOURCE} bytes"),
483            pos: 0,
484        });
485    }
486    let mut parser = Parser::new(source);
487    let node = parser.parse_expr()?;
488    parser.skip_ws();
489    if !parser.at_end() {
490        return Err(parser.err("unexpected trailing input"));
491    }
492    Ok(node)
493}
494
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src`, panicking with the parser error if it is rejected.
    fn parse_ok(src: &str) -> ExprNode {
        match parse_expr(src) {
            Ok(node) => node,
            Err(e) => panic!("failed to parse {src:?}: {e}"),
        }
    }

    /// Parses `src`, panicking if it is unexpectedly accepted.
    fn parse_fail(src: &str) -> ExprError {
        match parse_expr(src) {
            Ok(node) => panic!("expected {src:?} to be rejected, got {node:?}"),
            Err(e) => e,
        }
    }

    /// Asserts that `node` is a call to `name` with exactly `arity` arguments.
    fn assert_call(node: &ExprNode, name: &str, arity: usize) {
        match node {
            ExprNode::FnCall { name: got, args } => {
                assert_eq!(got, name);
                assert_eq!(args.len(), arity);
            }
            _ => panic!("expected FnCall"),
        }
    }

    // ── Literals ──────────────────────────────────────────────────

    #[test]
    fn parse_number_literal() {
        assert!(matches!(parse_ok("42"), ExprNode::Literal(Kind::Number(_))));
    }

    #[test]
    fn parse_float_literal() {
        match parse_ok("3.14") {
            ExprNode::Literal(Kind::Number(n)) => assert!((n.val - 3.14).abs() < 1e-10),
            _ => panic!("expected number literal"),
        }
    }

    #[test]
    fn parse_string_literal() {
        let node = parse_ok(r#""hello""#);
        assert!(matches!(node, ExprNode::Literal(Kind::Str(s)) if s == "hello"));
    }

    #[test]
    fn parse_bool_true() {
        assert!(matches!(parse_ok("true"), ExprNode::Literal(Kind::Bool(true))));
    }

    #[test]
    fn parse_bool_false() {
        assert!(matches!(parse_ok("false"), ExprNode::Literal(Kind::Bool(false))));
    }

    #[test]
    fn parse_null() {
        assert!(matches!(parse_ok("null"), ExprNode::Literal(Kind::Null)));
    }

    #[test]
    fn parse_variable() {
        let node = parse_ok("$temp");
        assert!(matches!(node, ExprNode::Variable(ref s) if s == "temp"));
    }

    // ── Arithmetic ────────────────────────────────────────────────

    #[test]
    fn parse_addition() {
        assert!(matches!(parse_ok("1 + 2"), ExprNode::BinaryOp { op: BinOp::Add, .. }));
    }

    #[test]
    fn parse_arithmetic_precedence() {
        // `*` binds tighter than `+`, so this is 1 + (2 * 3).
        match parse_ok("1 + 2 * 3") {
            ExprNode::BinaryOp { op: BinOp::Add, right, .. } => {
                assert!(matches!(*right, ExprNode::BinaryOp { op: BinOp::Mul, .. }));
            }
            _ => panic!("expected Add at top"),
        }
    }

    #[test]
    fn parse_subtraction() {
        assert!(matches!(parse_ok("5 - 3"), ExprNode::BinaryOp { op: BinOp::Sub, .. }));
    }

    #[test]
    fn parse_division() {
        assert!(matches!(parse_ok("10 / 2"), ExprNode::BinaryOp { op: BinOp::Div, .. }));
    }

    #[test]
    fn parse_modulo() {
        assert!(matches!(parse_ok("10 % 3"), ExprNode::BinaryOp { op: BinOp::Mod, .. }));
    }

    // ── Unary operators ───────────────────────────────────────────

    #[test]
    fn parse_unary_neg() {
        assert!(matches!(parse_ok("-5"), ExprNode::UnaryOp { op: UnOp::Neg, .. }));
    }

    #[test]
    fn parse_unary_not_bang() {
        assert!(matches!(parse_ok("!true"), ExprNode::UnaryOp { op: UnOp::Not, .. }));
    }

    #[test]
    fn parse_unary_not_keyword() {
        assert!(matches!(parse_ok("not false"), ExprNode::UnaryOp { op: UnOp::Not, .. }));
    }

    // ── Comparisons ───────────────────────────────────────────────

    #[test]
    fn parse_comparison_eq() {
        assert!(matches!(parse_ok("$x == 5"), ExprNode::Comparison { op: CmpOp::Eq, .. }));
    }

    #[test]
    fn parse_comparison_ne() {
        assert!(matches!(parse_ok("$x != 5"), ExprNode::Comparison { op: CmpOp::Ne, .. }));
    }

    #[test]
    fn parse_comparison_lt() {
        assert!(matches!(parse_ok("$x < 10"), ExprNode::Comparison { op: CmpOp::Lt, .. }));
    }

    #[test]
    fn parse_comparison_le() {
        assert!(matches!(parse_ok("$x <= 10"), ExprNode::Comparison { op: CmpOp::Le, .. }));
    }

    #[test]
    fn parse_comparison_gt() {
        assert!(matches!(parse_ok("$x > 0"), ExprNode::Comparison { op: CmpOp::Gt, .. }));
    }

    #[test]
    fn parse_comparison_ge() {
        assert!(matches!(parse_ok("$x >= 0"), ExprNode::Comparison { op: CmpOp::Ge, .. }));
    }

    // ── Logical operators ─────────────────────────────────────────

    #[test]
    fn parse_logical_and() {
        let node = parse_ok("true and false");
        assert!(matches!(node, ExprNode::Logical { op: LogicOp::And, .. }));
    }

    #[test]
    fn parse_logical_or() {
        let node = parse_ok("true or false");
        assert!(matches!(node, ExprNode::Logical { op: LogicOp::Or, .. }));
    }

    // ── Function calls, conditionals, grouping ────────────────────

    #[test]
    fn parse_fn_call_one_arg() {
        assert_call(&parse_ok("abs(-5)"), "abs", 1);
    }

    #[test]
    fn parse_fn_call_two_args() {
        assert_call(&parse_ok("min(1, 2)"), "min", 2);
    }

    #[test]
    fn parse_fn_call_no_args() {
        assert_call(&parse_ok("foo()"), "foo", 0);
    }

    #[test]
    fn parse_conditional() {
        let node = parse_ok("if true then 1 else 0");
        assert!(matches!(node, ExprNode::Conditional { .. }));
    }

    #[test]
    fn parse_parenthesised() {
        let node = parse_ok("(1 + 2) * 3");
        assert!(matches!(node, ExprNode::BinaryOp { op: BinOp::Mul, .. }));
    }

    #[test]
    fn parse_complex_expression() {
        let node = parse_ok("$a + $b * 2 > 10 and $c != 0");
        assert!(matches!(node, ExprNode::Logical { op: LogicOp::And, .. }));
    }

    // ── Errors ────────────────────────────────────────────────────

    #[test]
    fn error_empty_input() {
        assert!(parse_fail("").msg.contains("unexpected end"));
    }

    #[test]
    fn error_trailing_input() {
        assert!(parse_fail("1 2").msg.contains("trailing"));
    }

    #[test]
    fn error_unterminated_string() {
        assert!(parse_fail(r#""hello"#).msg.contains("unterminated"));
    }

    #[test]
    fn error_source_too_long() {
        let long = "1+".repeat(MAX_EXPR_SOURCE);
        assert!(parse_fail(&long).msg.contains("maximum length"));
    }

    #[test]
    fn error_depth_exceeded() {
        // Build deeply nested parens: ((((...1))))
        let levels = MAX_EXPR_DEPTH + 10;
        let src = format!("{}1{}", "(".repeat(levels), ")".repeat(levels));
        assert!(parse_fail(&src).msg.contains("depth"));
    }

    #[test]
    fn error_display() {
        let err = ExprError {
            msg: String::from("bad"),
            pos: 5,
        };
        assert_eq!(format!("{err}"), "expr error at 5: bad");
    }

    // ── Nesting and precedence ────────────────────────────────────

    #[test]
    fn parse_nested_fn_calls() {
        assert_call(&parse_ok("max(abs(-1), min(2, 3))"), "max", 2);
    }

    #[test]
    fn parse_logical_precedence() {
        // `and` binds tighter than `or`, so this is `true or (false and true)`.
        match parse_ok("true or false and true") {
            ExprNode::Logical { op: LogicOp::Or, right, .. } => {
                assert!(matches!(*right, ExprNode::Logical { op: LogicOp::And, .. }));
            }
            _ => panic!("expected Or at top"),
        }
    }
}