Skip to main content

haystack_core/expr/
parser.rs

//! Hand-written recursive descent parser for expressions.
2
3use super::ast::*;
4use crate::kinds::{Kind, Number};
5
/// Error produced when parsing an expression fails.
///
/// Pairs a human-readable message with the byte offset in the source at
/// which the parser gave up. The type is cheap to clone and comparable so
/// callers and tests can assert on exact failures.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExprError {
    /// Human-readable error message.
    pub msg: String,
    /// Byte position in the source where the error occurred.
    pub pos: usize,
}

impl std::fmt::Display for ExprError {
    /// Renders the error as `expr error at <pos>: <msg>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "expr error at {}: {}", self.pos, self.msg)
    }
}

impl std::error::Error for ExprError {}
22
/// Cursor-based state for a single parse run.
struct Parser<'a> {
    /// Full expression text being parsed.
    source: &'a str,
    /// Byte offset of the next unread byte in `source`.
    pos: usize,
    /// Number of nesting levels currently open.
    depth: usize,
}
29
30impl<'a> Parser<'a> {
31    fn new(source: &'a str) -> Self {
32        Self {
33            source,
34            pos: 0,
35            depth: 0,
36        }
37    }
38
39    fn err(&self, msg: impl Into<String>) -> ExprError {
40        ExprError {
41            msg: msg.into(),
42            pos: self.pos,
43        }
44    }
45
46    fn enter(&mut self) -> Result<(), ExprError> {
47        self.depth += 1;
48        if self.depth > MAX_EXPR_DEPTH {
49            Err(self.err("expression exceeds maximum nesting depth"))
50        } else {
51            Ok(())
52        }
53    }
54
55    fn leave(&mut self) {
56        self.depth -= 1;
57    }
58
59    fn skip_ws(&mut self) {
60        while self.pos < self.source.len() {
61            let b = self.source.as_bytes()[self.pos];
62            if b == b' ' || b == b'\t' || b == b'\n' || b == b'\r' {
63                self.pos += 1;
64            } else {
65                break;
66            }
67        }
68    }
69
70    fn at_end(&self) -> bool {
71        self.pos >= self.source.len()
72    }
73
74    fn peek(&self) -> Option<u8> {
75        if self.pos < self.source.len() {
76            Some(self.source.as_bytes()[self.pos])
77        } else {
78            None
79        }
80    }
81
82    fn consume(&mut self, ch: u8) -> Result<(), ExprError> {
83        self.skip_ws();
84        if self.peek() == Some(ch) {
85            self.pos += 1;
86            Ok(())
87        } else {
88            Err(self.err(format!("expected '{}'", ch as char)))
89        }
90    }
91
92    fn starts_with(&self, s: &str) -> bool {
93        self.source[self.pos..].starts_with(s)
94    }
95
96    /// Check if identifier keyword `kw` starts at current position and is
97    /// followed by a non-identifier character.
98    fn keyword(&self, kw: &str) -> bool {
99        if !self.starts_with(kw) {
100            return false;
101        }
102        let after = self.pos + kw.len();
103        if after >= self.source.len() {
104            return true;
105        }
106        let b = self.source.as_bytes()[after];
107        !b.is_ascii_alphanumeric() && b != b'_'
108    }
109
110    fn consume_keyword(&mut self, kw: &str) -> Result<(), ExprError> {
111        self.skip_ws();
112        if self.keyword(kw) {
113            self.pos += kw.len();
114            Ok(())
115        } else {
116            Err(self.err(format!("expected '{kw}'")))
117        }
118    }
119
120    fn read_ident(&mut self) -> Result<String, ExprError> {
121        self.skip_ws();
122        let start = self.pos;
123        while self.pos < self.source.len() {
124            let b = self.source.as_bytes()[self.pos];
125            if b.is_ascii_alphanumeric() || b == b'_' {
126                self.pos += 1;
127            } else {
128                break;
129            }
130        }
131        if self.pos == start {
132            return Err(self.err("expected identifier"));
133        }
134        Ok(self.source[start..self.pos].to_string())
135    }
136
    // ── Grammar rules ──────────────────────────────────────────────
138
139    fn parse_expr(&mut self) -> Result<ExprNode, ExprError> {
140        self.enter()?;
141        let node = self.parse_logic_or()?;
142        self.leave();
143        Ok(node)
144    }
145
146    fn parse_logic_or(&mut self) -> Result<ExprNode, ExprError> {
147        let mut left = self.parse_logic_and()?;
148        loop {
149            self.skip_ws();
150            if self.keyword("or") {
151                self.pos += 2;
152                let right = self.parse_logic_and()?;
153                left = ExprNode::Logical {
154                    left: Box::new(left),
155                    op: LogicOp::Or,
156                    right: Box::new(right),
157                };
158            } else {
159                break;
160            }
161        }
162        Ok(left)
163    }
164
165    fn parse_logic_and(&mut self) -> Result<ExprNode, ExprError> {
166        let mut left = self.parse_comparison()?;
167        loop {
168            self.skip_ws();
169            if self.keyword("and") {
170                self.pos += 3;
171                let right = self.parse_comparison()?;
172                left = ExprNode::Logical {
173                    left: Box::new(left),
174                    op: LogicOp::And,
175                    right: Box::new(right),
176                };
177            } else {
178                break;
179            }
180        }
181        Ok(left)
182    }
183
184    fn parse_comparison(&mut self) -> Result<ExprNode, ExprError> {
185        let left = self.parse_additive()?;
186        self.skip_ws();
187        let op = if self.starts_with("!=") {
188            self.pos += 2;
189            Some(CmpOp::Ne)
190        } else if self.starts_with("==") {
191            self.pos += 2;
192            Some(CmpOp::Eq)
193        } else if self.starts_with("<=") {
194            self.pos += 2;
195            Some(CmpOp::Le)
196        } else if self.starts_with(">=") {
197            self.pos += 2;
198            Some(CmpOp::Ge)
199        } else if self.peek() == Some(b'<') {
200            self.pos += 1;
201            Some(CmpOp::Lt)
202        } else if self.peek() == Some(b'>') {
203            self.pos += 1;
204            Some(CmpOp::Gt)
205        } else {
206            None
207        };
208        if let Some(op) = op {
209            let right = self.parse_additive()?;
210            Ok(ExprNode::Comparison {
211                left: Box::new(left),
212                op,
213                right: Box::new(right),
214            })
215        } else {
216            Ok(left)
217        }
218    }
219
220    fn parse_additive(&mut self) -> Result<ExprNode, ExprError> {
221        let mut left = self.parse_multiplicative()?;
222        loop {
223            self.skip_ws();
224            match self.peek() {
225                Some(b'+') => {
226                    self.pos += 1;
227                    let right = self.parse_multiplicative()?;
228                    left = ExprNode::BinaryOp {
229                        left: Box::new(left),
230                        op: BinOp::Add,
231                        right: Box::new(right),
232                    };
233                }
234                Some(b'-') => {
235                    self.pos += 1;
236                    let right = self.parse_multiplicative()?;
237                    left = ExprNode::BinaryOp {
238                        left: Box::new(left),
239                        op: BinOp::Sub,
240                        right: Box::new(right),
241                    };
242                }
243                _ => break,
244            }
245        }
246        Ok(left)
247    }
248
249    fn parse_multiplicative(&mut self) -> Result<ExprNode, ExprError> {
250        let mut left = self.parse_unary()?;
251        loop {
252            self.skip_ws();
253            match self.peek() {
254                Some(b'*') => {
255                    self.pos += 1;
256                    let right = self.parse_unary()?;
257                    left = ExprNode::BinaryOp {
258                        left: Box::new(left),
259                        op: BinOp::Mul,
260                        right: Box::new(right),
261                    };
262                }
263                Some(b'/') => {
264                    self.pos += 1;
265                    let right = self.parse_unary()?;
266                    left = ExprNode::BinaryOp {
267                        left: Box::new(left),
268                        op: BinOp::Div,
269                        right: Box::new(right),
270                    };
271                }
272                Some(b'%') => {
273                    self.pos += 1;
274                    let right = self.parse_unary()?;
275                    left = ExprNode::BinaryOp {
276                        left: Box::new(left),
277                        op: BinOp::Mod,
278                        right: Box::new(right),
279                    };
280                }
281                _ => break,
282            }
283        }
284        Ok(left)
285    }
286
287    fn parse_unary(&mut self) -> Result<ExprNode, ExprError> {
288        self.skip_ws();
289        if self.peek() == Some(b'-') {
290            self.pos += 1;
291            let operand = self.parse_unary()?;
292            return Ok(ExprNode::UnaryOp {
293                op: UnOp::Neg,
294                operand: Box::new(operand),
295            });
296        }
297        if self.peek() == Some(b'!') {
298            self.pos += 1;
299            let operand = self.parse_unary()?;
300            return Ok(ExprNode::UnaryOp {
301                op: UnOp::Not,
302                operand: Box::new(operand),
303            });
304        }
305        if self.keyword("not") {
306            self.pos += 3;
307            let operand = self.parse_unary()?;
308            return Ok(ExprNode::UnaryOp {
309                op: UnOp::Not,
310                operand: Box::new(operand),
311            });
312        }
313        self.parse_call()
314    }
315
316    fn parse_call(&mut self) -> Result<ExprNode, ExprError> {
317        self.skip_ws();
318        let start = self.pos;
319
320        // Try to parse an identifier that could be a function name.
321        // We need to check if the next char starts an identifier (alpha/underscore)
322        // and is NOT a keyword (true/false/null/if/not) followed by '('.
323        if self.pos < self.source.len() {
324            let b = self.source.as_bytes()[self.pos];
325            if (b.is_ascii_alphabetic() || b == b'_')
326                && !self.keyword("true")
327                && !self.keyword("false")
328                && !self.keyword("null")
329                && !self.keyword("if")
330                && !self.keyword("not")
331            {
332                let name = self.read_ident()?;
333                self.skip_ws();
334                if self.peek() == Some(b'(') {
335                    self.pos += 1;
336                    let mut args = Vec::new();
337                    self.skip_ws();
338                    if self.peek() != Some(b')') {
339                        args.push(self.parse_expr()?);
340                        loop {
341                            self.skip_ws();
342                            if self.peek() == Some(b',') {
343                                self.pos += 1;
344                                args.push(self.parse_expr()?);
345                            } else {
346                                break;
347                            }
348                        }
349                    }
350                    self.consume(b')')?;
351                    return Ok(ExprNode::FnCall { name, args });
352                }
353                // Not a function call — backtrack.
354                self.pos = start;
355            }
356        }
357
358        self.parse_primary()
359    }
360
361    fn parse_primary(&mut self) -> Result<ExprNode, ExprError> {
362        self.skip_ws();
363
364        if self.at_end() {
365            return Err(self.err("unexpected end of expression"));
366        }
367
368        // Number literal
369        let b = self.source.as_bytes()[self.pos];
370        if b.is_ascii_digit()
371            || (b == b'.'
372                && self.pos + 1 < self.source.len()
373                && self.source.as_bytes()[self.pos + 1].is_ascii_digit())
374        {
375            return self.parse_number();
376        }
377
378        // String literal
379        if b == b'"' {
380            return self.parse_string();
381        }
382
383        // true / false
384        if self.keyword("true") {
385            self.pos += 4;
386            return Ok(ExprNode::Literal(Kind::Bool(true)));
387        }
388        if self.keyword("false") {
389            self.pos += 5;
390            return Ok(ExprNode::Literal(Kind::Bool(false)));
391        }
392
393        // null
394        if self.keyword("null") {
395            self.pos += 4;
396            return Ok(ExprNode::Literal(Kind::Null));
397        }
398
399        // Variable: $ident
400        if b == b'$' {
401            self.pos += 1;
402            let name = self.read_ident()?;
403            return Ok(ExprNode::Variable(name));
404        }
405
406        // Parenthesised expression
407        if b == b'(' {
408            self.pos += 1;
409            let node = self.parse_expr()?;
410            self.consume(b')')?;
411            return Ok(node);
412        }
413
414        // Conditional: if expr then expr else expr
415        if self.keyword("if") {
416            self.pos += 2;
417            let cond = self.parse_expr()?;
418            self.consume_keyword("then")?;
419            let then_expr = self.parse_expr()?;
420            self.consume_keyword("else")?;
421            let else_expr = self.parse_expr()?;
422            return Ok(ExprNode::Conditional {
423                cond: Box::new(cond),
424                then_expr: Box::new(then_expr),
425                else_expr: Box::new(else_expr),
426            });
427        }
428
429        Err(self.err(format!("unexpected character '{}'", b as char)))
430    }
431
432    fn parse_number(&mut self) -> Result<ExprNode, ExprError> {
433        let start = self.pos;
434        while self.pos < self.source.len() && self.source.as_bytes()[self.pos].is_ascii_digit() {
435            self.pos += 1;
436        }
437        if self.pos < self.source.len() && self.source.as_bytes()[self.pos] == b'.' {
438            self.pos += 1;
439            while self.pos < self.source.len() && self.source.as_bytes()[self.pos].is_ascii_digit()
440            {
441                self.pos += 1;
442            }
443        }
444        let s = &self.source[start..self.pos];
445        let val: f64 = s
446            .parse()
447            .map_err(|_| self.err(format!("invalid number '{s}'")))?;
448        Ok(ExprNode::Literal(Kind::Number(Number::unitless(val))))
449    }
450
451    fn parse_string(&mut self) -> Result<ExprNode, ExprError> {
452        self.pos += 1; // skip opening "
453        let start = self.pos;
454        while self.pos < self.source.len() && self.source.as_bytes()[self.pos] != b'"' {
455            if self.source.as_bytes()[self.pos] == b'\\' {
456                self.pos += 1;
457                if self.pos >= self.source.len() {
458                    return Err(self.err("unterminated string escape"));
459                }
460                match self.source.as_bytes()[self.pos] {
461                    b'"' | b'\\' | b'n' | b't' | b'r' => {}
462                    ch => {
463                        return Err(self.err(format!("invalid escape sequence: \\{}", ch as char)));
464                    }
465                }
466            }
467            self.pos += 1;
468        }
469        if self.pos >= self.source.len() {
470            return Err(self.err("unterminated string"));
471        }
472        let s = self.source[start..self.pos].to_string();
473        self.pos += 1; // skip closing "
474        Ok(ExprNode::Literal(Kind::Str(s)))
475    }
476}
477
478/// Parse an expression source string into an AST.
479pub fn parse_expr(source: &str) -> Result<ExprNode, ExprError> {
480    if source.len() > MAX_EXPR_SOURCE {
481        return Err(ExprError {
482            msg: format!("expression source exceeds maximum length of {MAX_EXPR_SOURCE} bytes"),
483            pos: 0,
484        });
485    }
486    let mut parser = Parser::new(source);
487    let node = parser.parse_expr()?;
488    parser.skip_ws();
489    if !parser.at_end() {
490        return Err(parser.err("unexpected trailing input"));
491    }
492    Ok(node)
493}
494
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src`, panicking if the parser rejects it.
    fn ast(src: &str) -> ExprNode {
        parse_expr(src).unwrap()
    }

    /// Parses `src`, panicking if the parser accepts it.
    fn failure(src: &str) -> ExprError {
        parse_expr(src).unwrap_err()
    }

    // ── Literals and variables ─────────────────────────────────────

    #[test]
    fn parse_number_literal() {
        assert!(matches!(ast("42"), ExprNode::Literal(Kind::Number(_))));
    }

    #[test]
    fn parse_float_literal() {
        let node = ast("3.14");
        match &node {
            // The parsed value only needs to be close to 3.14.
            ExprNode::Literal(Kind::Number(n)) => assert!(n.val > 3.13 && n.val < 3.15),
            _ => panic!("expected number literal"),
        }
    }

    #[test]
    fn parse_string_literal() {
        assert!(matches!(ast(r#""hello""#), ExprNode::Literal(Kind::Str(s)) if s == "hello"));
    }

    #[test]
    fn parse_bool_true() {
        assert!(matches!(ast("true"), ExprNode::Literal(Kind::Bool(true))));
    }

    #[test]
    fn parse_bool_false() {
        assert!(matches!(ast("false"), ExprNode::Literal(Kind::Bool(false))));
    }

    #[test]
    fn parse_null() {
        assert!(matches!(ast("null"), ExprNode::Literal(Kind::Null)));
    }

    #[test]
    fn parse_variable() {
        assert!(matches!(ast("$temp"), ExprNode::Variable(ref s) if s == "temp"));
    }

    // ── Arithmetic ─────────────────────────────────────────────────

    #[test]
    fn parse_addition() {
        assert!(matches!(ast("1 + 2"), ExprNode::BinaryOp { op: BinOp::Add, .. }));
    }

    #[test]
    fn parse_arithmetic_precedence() {
        // 1 + 2 * 3 should be 1 + (2 * 3)
        let node = ast("1 + 2 * 3");
        match &node {
            ExprNode::BinaryOp {
                op: BinOp::Add,
                right,
                ..
            } => assert!(matches!(
                right.as_ref(),
                ExprNode::BinaryOp { op: BinOp::Mul, .. }
            )),
            _ => panic!("expected Add at top"),
        }
    }

    #[test]
    fn parse_subtraction() {
        assert!(matches!(ast("5 - 3"), ExprNode::BinaryOp { op: BinOp::Sub, .. }));
    }

    #[test]
    fn parse_division() {
        assert!(matches!(ast("10 / 2"), ExprNode::BinaryOp { op: BinOp::Div, .. }));
    }

    #[test]
    fn parse_modulo() {
        assert!(matches!(ast("10 % 3"), ExprNode::BinaryOp { op: BinOp::Mod, .. }));
    }

    // ── Unary operators ────────────────────────────────────────────

    #[test]
    fn parse_unary_neg() {
        assert!(matches!(ast("-5"), ExprNode::UnaryOp { op: UnOp::Neg, .. }));
    }

    #[test]
    fn parse_unary_not_bang() {
        assert!(matches!(ast("!true"), ExprNode::UnaryOp { op: UnOp::Not, .. }));
    }

    #[test]
    fn parse_unary_not_keyword() {
        assert!(matches!(ast("not false"), ExprNode::UnaryOp { op: UnOp::Not, .. }));
    }

    // ── Comparisons ────────────────────────────────────────────────

    #[test]
    fn parse_comparison_eq() {
        assert!(matches!(ast("$x == 5"), ExprNode::Comparison { op: CmpOp::Eq, .. }));
    }

    #[test]
    fn parse_comparison_ne() {
        assert!(matches!(ast("$x != 5"), ExprNode::Comparison { op: CmpOp::Ne, .. }));
    }

    #[test]
    fn parse_comparison_lt() {
        assert!(matches!(ast("$x < 10"), ExprNode::Comparison { op: CmpOp::Lt, .. }));
    }

    #[test]
    fn parse_comparison_le() {
        assert!(matches!(ast("$x <= 10"), ExprNode::Comparison { op: CmpOp::Le, .. }));
    }

    #[test]
    fn parse_comparison_gt() {
        assert!(matches!(ast("$x > 0"), ExprNode::Comparison { op: CmpOp::Gt, .. }));
    }

    #[test]
    fn parse_comparison_ge() {
        assert!(matches!(ast("$x >= 0"), ExprNode::Comparison { op: CmpOp::Ge, .. }));
    }

    // ── Logical operators ──────────────────────────────────────────

    #[test]
    fn parse_logical_and() {
        assert!(matches!(
            ast("true and false"),
            ExprNode::Logical {
                op: LogicOp::And,
                ..
            }
        ));
    }

    #[test]
    fn parse_logical_or() {
        assert!(matches!(
            ast("true or false"),
            ExprNode::Logical {
                op: LogicOp::Or,
                ..
            }
        ));
    }

    // ── Function calls ─────────────────────────────────────────────

    #[test]
    fn parse_fn_call_one_arg() {
        let node = ast("abs(-5)");
        match &node {
            ExprNode::FnCall { name, args } => {
                assert_eq!(name, "abs");
                assert_eq!(args.len(), 1);
            }
            _ => panic!("expected FnCall"),
        }
    }

    #[test]
    fn parse_fn_call_two_args() {
        let node = ast("min(1, 2)");
        match &node {
            ExprNode::FnCall { name, args } => {
                assert_eq!(name, "min");
                assert_eq!(args.len(), 2);
            }
            _ => panic!("expected FnCall"),
        }
    }

    #[test]
    fn parse_fn_call_no_args() {
        let node = ast("foo()");
        match &node {
            ExprNode::FnCall { name, args } => {
                assert_eq!(name, "foo");
                assert!(args.is_empty());
            }
            _ => panic!("expected FnCall"),
        }
    }

    // ── Compound forms ─────────────────────────────────────────────

    #[test]
    fn parse_conditional() {
        assert!(matches!(
            ast("if true then 1 else 0"),
            ExprNode::Conditional { .. }
        ));
    }

    #[test]
    fn parse_parenthesised() {
        assert!(matches!(
            ast("(1 + 2) * 3"),
            ExprNode::BinaryOp { op: BinOp::Mul, .. }
        ));
    }

    #[test]
    fn parse_complex_expression() {
        assert!(matches!(
            ast("$a + $b * 2 > 10 and $c != 0"),
            ExprNode::Logical {
                op: LogicOp::And,
                ..
            }
        ));
    }

    // ── Error reporting ────────────────────────────────────────────

    #[test]
    fn error_empty_input() {
        assert!(failure("").msg.contains("unexpected end"));
    }

    #[test]
    fn error_trailing_input() {
        assert!(failure("1 2").msg.contains("trailing"));
    }

    #[test]
    fn error_unterminated_string() {
        assert!(failure(r#""hello"#).msg.contains("unterminated"));
    }

    #[test]
    fn error_source_too_long() {
        let long = "1+".repeat(MAX_EXPR_SOURCE);
        assert!(failure(&long).msg.contains("maximum length"));
    }

    #[test]
    fn error_depth_exceeded() {
        // Build deeply nested parens: ((((...))))
        let levels = MAX_EXPR_DEPTH + 10;
        let src = format!("{}1{}", "(".repeat(levels), ")".repeat(levels));
        assert!(failure(&src).msg.contains("depth"));
    }

    #[test]
    fn error_display() {
        let err = ExprError {
            msg: "bad".into(),
            pos: 5,
        };
        assert_eq!(err.to_string(), "expr error at 5: bad");
    }

    // ── Nesting and precedence ─────────────────────────────────────

    #[test]
    fn parse_nested_fn_calls() {
        let node = ast("max(abs(-1), min(2, 3))");
        match &node {
            ExprNode::FnCall { name, args } => {
                assert_eq!(name, "max");
                assert_eq!(args.len(), 2);
            }
            _ => panic!("expected FnCall"),
        }
    }

    #[test]
    fn parse_logical_precedence() {
        // `a or b and c` should be `a or (b and c)` since and binds tighter
        let node = ast("true or false and true");
        match &node {
            ExprNode::Logical {
                op: LogicOp::Or,
                right,
                ..
            } => assert!(matches!(
                right.as_ref(),
                ExprNode::Logical {
                    op: LogicOp::And,
                    ..
                }
            )),
            _ => panic!("expected Or at top"),
        }
    }
}