1use super::ast::*;
4use crate::kinds::{Kind, Number};
5
/// Error produced when an expression cannot be parsed.
#[derive(Debug)]
pub struct ExprError {
    /// Human-readable description of what went wrong.
    pub msg: String,
    /// Byte offset into the expression source associated with the error.
    pub pos: usize,
}

impl std::fmt::Display for ExprError {
    /// Renders the error as `expr error at <pos>: <msg>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { msg, pos } = self;
        write!(f, "expr error at {pos}: {msg}")
    }
}

impl std::error::Error for ExprError {}
22
/// Cursor-based recursive-descent parser over a borrowed expression source.
struct Parser<'a> {
    /// The complete expression text being parsed.
    source: &'a str,
    /// Byte offset of the next unread byte in `source`.
    pos: usize,
    /// Current nesting depth, maintained by `enter` / `leave`.
    depth: usize,
}
29
30impl<'a> Parser<'a> {
31 fn new(source: &'a str) -> Self {
32 Self {
33 source,
34 pos: 0,
35 depth: 0,
36 }
37 }
38
39 fn err(&self, msg: impl Into<String>) -> ExprError {
40 ExprError {
41 msg: msg.into(),
42 pos: self.pos,
43 }
44 }
45
46 fn enter(&mut self) -> Result<(), ExprError> {
47 self.depth += 1;
48 if self.depth > MAX_EXPR_DEPTH {
49 Err(self.err("expression exceeds maximum nesting depth"))
50 } else {
51 Ok(())
52 }
53 }
54
55 fn leave(&mut self) {
56 self.depth -= 1;
57 }
58
59 fn skip_ws(&mut self) {
60 while self.pos < self.source.len() {
61 let b = self.source.as_bytes()[self.pos];
62 if b == b' ' || b == b'\t' || b == b'\n' || b == b'\r' {
63 self.pos += 1;
64 } else {
65 break;
66 }
67 }
68 }
69
70 fn at_end(&self) -> bool {
71 self.pos >= self.source.len()
72 }
73
74 fn peek(&self) -> Option<u8> {
75 if self.pos < self.source.len() {
76 Some(self.source.as_bytes()[self.pos])
77 } else {
78 None
79 }
80 }
81
82 fn consume(&mut self, ch: u8) -> Result<(), ExprError> {
83 self.skip_ws();
84 if self.peek() == Some(ch) {
85 self.pos += 1;
86 Ok(())
87 } else {
88 Err(self.err(format!("expected '{}'", ch as char)))
89 }
90 }
91
92 fn starts_with(&self, s: &str) -> bool {
93 self.source[self.pos..].starts_with(s)
94 }
95
96 fn keyword(&self, kw: &str) -> bool {
99 if !self.starts_with(kw) {
100 return false;
101 }
102 let after = self.pos + kw.len();
103 if after >= self.source.len() {
104 return true;
105 }
106 let b = self.source.as_bytes()[after];
107 !b.is_ascii_alphanumeric() && b != b'_'
108 }
109
110 fn consume_keyword(&mut self, kw: &str) -> Result<(), ExprError> {
111 self.skip_ws();
112 if self.keyword(kw) {
113 self.pos += kw.len();
114 Ok(())
115 } else {
116 Err(self.err(format!("expected '{kw}'")))
117 }
118 }
119
120 fn read_ident(&mut self) -> Result<String, ExprError> {
121 self.skip_ws();
122 let start = self.pos;
123 while self.pos < self.source.len() {
124 let b = self.source.as_bytes()[self.pos];
125 if b.is_ascii_alphanumeric() || b == b'_' {
126 self.pos += 1;
127 } else {
128 break;
129 }
130 }
131 if self.pos == start {
132 return Err(self.err("expected identifier"));
133 }
134 Ok(self.source[start..self.pos].to_string())
135 }
136
137 fn parse_expr(&mut self) -> Result<ExprNode, ExprError> {
140 self.enter()?;
141 let node = self.parse_logic_or()?;
142 self.leave();
143 Ok(node)
144 }
145
146 fn parse_logic_or(&mut self) -> Result<ExprNode, ExprError> {
147 let mut left = self.parse_logic_and()?;
148 loop {
149 self.skip_ws();
150 if self.keyword("or") {
151 self.pos += 2;
152 let right = self.parse_logic_and()?;
153 left = ExprNode::Logical {
154 left: Box::new(left),
155 op: LogicOp::Or,
156 right: Box::new(right),
157 };
158 } else {
159 break;
160 }
161 }
162 Ok(left)
163 }
164
165 fn parse_logic_and(&mut self) -> Result<ExprNode, ExprError> {
166 let mut left = self.parse_comparison()?;
167 loop {
168 self.skip_ws();
169 if self.keyword("and") {
170 self.pos += 3;
171 let right = self.parse_comparison()?;
172 left = ExprNode::Logical {
173 left: Box::new(left),
174 op: LogicOp::And,
175 right: Box::new(right),
176 };
177 } else {
178 break;
179 }
180 }
181 Ok(left)
182 }
183
184 fn parse_comparison(&mut self) -> Result<ExprNode, ExprError> {
185 let left = self.parse_additive()?;
186 self.skip_ws();
187 let op = if self.starts_with("!=") {
188 self.pos += 2;
189 Some(CmpOp::Ne)
190 } else if self.starts_with("==") {
191 self.pos += 2;
192 Some(CmpOp::Eq)
193 } else if self.starts_with("<=") {
194 self.pos += 2;
195 Some(CmpOp::Le)
196 } else if self.starts_with(">=") {
197 self.pos += 2;
198 Some(CmpOp::Ge)
199 } else if self.peek() == Some(b'<') {
200 self.pos += 1;
201 Some(CmpOp::Lt)
202 } else if self.peek() == Some(b'>') {
203 self.pos += 1;
204 Some(CmpOp::Gt)
205 } else {
206 None
207 };
208 if let Some(op) = op {
209 let right = self.parse_additive()?;
210 Ok(ExprNode::Comparison {
211 left: Box::new(left),
212 op,
213 right: Box::new(right),
214 })
215 } else {
216 Ok(left)
217 }
218 }
219
220 fn parse_additive(&mut self) -> Result<ExprNode, ExprError> {
221 let mut left = self.parse_multiplicative()?;
222 loop {
223 self.skip_ws();
224 match self.peek() {
225 Some(b'+') => {
226 self.pos += 1;
227 let right = self.parse_multiplicative()?;
228 left = ExprNode::BinaryOp {
229 left: Box::new(left),
230 op: BinOp::Add,
231 right: Box::new(right),
232 };
233 }
234 Some(b'-') => {
235 self.pos += 1;
236 let right = self.parse_multiplicative()?;
237 left = ExprNode::BinaryOp {
238 left: Box::new(left),
239 op: BinOp::Sub,
240 right: Box::new(right),
241 };
242 }
243 _ => break,
244 }
245 }
246 Ok(left)
247 }
248
249 fn parse_multiplicative(&mut self) -> Result<ExprNode, ExprError> {
250 let mut left = self.parse_unary()?;
251 loop {
252 self.skip_ws();
253 match self.peek() {
254 Some(b'*') => {
255 self.pos += 1;
256 let right = self.parse_unary()?;
257 left = ExprNode::BinaryOp {
258 left: Box::new(left),
259 op: BinOp::Mul,
260 right: Box::new(right),
261 };
262 }
263 Some(b'/') => {
264 self.pos += 1;
265 let right = self.parse_unary()?;
266 left = ExprNode::BinaryOp {
267 left: Box::new(left),
268 op: BinOp::Div,
269 right: Box::new(right),
270 };
271 }
272 Some(b'%') => {
273 self.pos += 1;
274 let right = self.parse_unary()?;
275 left = ExprNode::BinaryOp {
276 left: Box::new(left),
277 op: BinOp::Mod,
278 right: Box::new(right),
279 };
280 }
281 _ => break,
282 }
283 }
284 Ok(left)
285 }
286
287 fn parse_unary(&mut self) -> Result<ExprNode, ExprError> {
288 self.skip_ws();
289 if self.peek() == Some(b'-') {
290 self.pos += 1;
291 let operand = self.parse_unary()?;
292 return Ok(ExprNode::UnaryOp {
293 op: UnOp::Neg,
294 operand: Box::new(operand),
295 });
296 }
297 if self.peek() == Some(b'!') {
298 self.pos += 1;
299 let operand = self.parse_unary()?;
300 return Ok(ExprNode::UnaryOp {
301 op: UnOp::Not,
302 operand: Box::new(operand),
303 });
304 }
305 if self.keyword("not") {
306 self.pos += 3;
307 let operand = self.parse_unary()?;
308 return Ok(ExprNode::UnaryOp {
309 op: UnOp::Not,
310 operand: Box::new(operand),
311 });
312 }
313 self.parse_call()
314 }
315
316 fn parse_call(&mut self) -> Result<ExprNode, ExprError> {
317 self.skip_ws();
318 let start = self.pos;
319
320 if self.pos < self.source.len() {
324 let b = self.source.as_bytes()[self.pos];
325 if (b.is_ascii_alphabetic() || b == b'_')
326 && !self.keyword("true")
327 && !self.keyword("false")
328 && !self.keyword("null")
329 && !self.keyword("if")
330 && !self.keyword("not")
331 {
332 let name = self.read_ident()?;
333 self.skip_ws();
334 if self.peek() == Some(b'(') {
335 self.pos += 1;
336 let mut args = Vec::new();
337 self.skip_ws();
338 if self.peek() != Some(b')') {
339 args.push(self.parse_expr()?);
340 loop {
341 self.skip_ws();
342 if self.peek() == Some(b',') {
343 self.pos += 1;
344 args.push(self.parse_expr()?);
345 } else {
346 break;
347 }
348 }
349 }
350 self.consume(b')')?;
351 return Ok(ExprNode::FnCall { name, args });
352 }
353 self.pos = start;
355 }
356 }
357
358 self.parse_primary()
359 }
360
361 fn parse_primary(&mut self) -> Result<ExprNode, ExprError> {
362 self.skip_ws();
363
364 if self.at_end() {
365 return Err(self.err("unexpected end of expression"));
366 }
367
368 let b = self.source.as_bytes()[self.pos];
370 if b.is_ascii_digit()
371 || (b == b'.'
372 && self.pos + 1 < self.source.len()
373 && self.source.as_bytes()[self.pos + 1].is_ascii_digit())
374 {
375 return self.parse_number();
376 }
377
378 if b == b'"' {
380 return self.parse_string();
381 }
382
383 if self.keyword("true") {
385 self.pos += 4;
386 return Ok(ExprNode::Literal(Kind::Bool(true)));
387 }
388 if self.keyword("false") {
389 self.pos += 5;
390 return Ok(ExprNode::Literal(Kind::Bool(false)));
391 }
392
393 if self.keyword("null") {
395 self.pos += 4;
396 return Ok(ExprNode::Literal(Kind::Null));
397 }
398
399 if b == b'$' {
401 self.pos += 1;
402 let name = self.read_ident()?;
403 return Ok(ExprNode::Variable(name));
404 }
405
406 if b == b'(' {
408 self.pos += 1;
409 let node = self.parse_expr()?;
410 self.consume(b')')?;
411 return Ok(node);
412 }
413
414 if self.keyword("if") {
416 self.pos += 2;
417 let cond = self.parse_expr()?;
418 self.consume_keyword("then")?;
419 let then_expr = self.parse_expr()?;
420 self.consume_keyword("else")?;
421 let else_expr = self.parse_expr()?;
422 return Ok(ExprNode::Conditional {
423 cond: Box::new(cond),
424 then_expr: Box::new(then_expr),
425 else_expr: Box::new(else_expr),
426 });
427 }
428
429 Err(self.err(format!("unexpected character '{}'", b as char)))
430 }
431
432 fn parse_number(&mut self) -> Result<ExprNode, ExprError> {
433 let start = self.pos;
434 while self.pos < self.source.len() && self.source.as_bytes()[self.pos].is_ascii_digit() {
435 self.pos += 1;
436 }
437 if self.pos < self.source.len() && self.source.as_bytes()[self.pos] == b'.' {
438 self.pos += 1;
439 while self.pos < self.source.len() && self.source.as_bytes()[self.pos].is_ascii_digit()
440 {
441 self.pos += 1;
442 }
443 }
444 let s = &self.source[start..self.pos];
445 let val: f64 = s
446 .parse()
447 .map_err(|_| self.err(format!("invalid number '{s}'")))?;
448 Ok(ExprNode::Literal(Kind::Number(Number::unitless(val))))
449 }
450
451 fn parse_string(&mut self) -> Result<ExprNode, ExprError> {
452 self.pos += 1; let start = self.pos;
454 while self.pos < self.source.len() && self.source.as_bytes()[self.pos] != b'"' {
455 if self.source.as_bytes()[self.pos] == b'\\' {
456 self.pos += 1;
457 if self.pos >= self.source.len() {
458 return Err(self.err("unterminated string escape"));
459 }
460 match self.source.as_bytes()[self.pos] {
461 b'"' | b'\\' | b'n' | b't' | b'r' => {}
462 ch => {
463 return Err(self.err(format!("invalid escape sequence: \\{}", ch as char)));
464 }
465 }
466 }
467 self.pos += 1;
468 }
469 if self.pos >= self.source.len() {
470 return Err(self.err("unterminated string"));
471 }
472 let s = self.source[start..self.pos].to_string();
473 self.pos += 1; Ok(ExprNode::Literal(Kind::Str(s)))
475 }
476}
477
478pub fn parse_expr(source: &str) -> Result<ExprNode, ExprError> {
480 if source.len() > MAX_EXPR_SOURCE {
481 return Err(ExprError {
482 msg: format!("expression source exceeds maximum length of {MAX_EXPR_SOURCE} bytes"),
483 pos: 0,
484 });
485 }
486 let mut parser = Parser::new(source);
487 let node = parser.parse_expr()?;
488 parser.skip_ws();
489 if !parser.at_end() {
490 return Err(parser.err("unexpected trailing input"));
491 }
492 Ok(node)
493}
494
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src`, panicking if the parser rejects it.
    fn parsed(src: &str) -> ExprNode {
        parse_expr(src).unwrap()
    }

    /// Parses `src`, panicking if the parser accepts it, and returns the error.
    fn rejected(src: &str) -> ExprError {
        parse_expr(src).unwrap_err()
    }

    /// Asserts that `src` parses to a call of `expected_name` with `expected_args` arguments.
    fn assert_call(src: &str, expected_name: &str, expected_args: usize) {
        match parsed(src) {
            ExprNode::FnCall { name, args } => {
                assert_eq!(name, expected_name);
                assert_eq!(args.len(), expected_args);
            }
            _ => panic!("expected FnCall"),
        }
    }

    // ---- literals and variables ----

    #[test]
    fn parse_number_literal() {
        assert!(matches!(
            parsed("42"),
            ExprNode::Literal(Kind::Number(_))
        ));
    }

    #[test]
    fn parse_float_literal() {
        match parsed("3.14") {
            ExprNode::Literal(Kind::Number(n)) => assert!((n.val - 3.14).abs() < 1e-10),
            _ => panic!("expected number literal"),
        }
    }

    #[test]
    fn parse_string_literal() {
        let node = parsed(r#""hello""#);
        assert!(matches!(node, ExprNode::Literal(Kind::Str(s)) if s == "hello"));
    }

    #[test]
    fn parse_bool_true() {
        assert!(matches!(
            parsed("true"),
            ExprNode::Literal(Kind::Bool(true))
        ));
    }

    #[test]
    fn parse_bool_false() {
        assert!(matches!(
            parsed("false"),
            ExprNode::Literal(Kind::Bool(false))
        ));
    }

    #[test]
    fn parse_null() {
        assert!(matches!(parsed("null"), ExprNode::Literal(Kind::Null)));
    }

    #[test]
    fn parse_variable() {
        let node = parsed("$temp");
        assert!(matches!(node, ExprNode::Variable(ref s) if s == "temp"));
    }

    // ---- arithmetic ----

    #[test]
    fn parse_addition() {
        assert!(matches!(
            parsed("1 + 2"),
            ExprNode::BinaryOp { op: BinOp::Add, .. }
        ));
    }

    #[test]
    fn parse_arithmetic_precedence() {
        // `*` binds tighter than `+`, so the product is the right child of the sum.
        match parsed("1 + 2 * 3") {
            ExprNode::BinaryOp {
                op: BinOp::Add,
                right,
                ..
            } => assert!(matches!(
                *right,
                ExprNode::BinaryOp { op: BinOp::Mul, .. }
            )),
            _ => panic!("expected Add at top"),
        }
    }

    #[test]
    fn parse_subtraction() {
        assert!(matches!(
            parsed("5 - 3"),
            ExprNode::BinaryOp { op: BinOp::Sub, .. }
        ));
    }

    #[test]
    fn parse_division() {
        assert!(matches!(
            parsed("10 / 2"),
            ExprNode::BinaryOp { op: BinOp::Div, .. }
        ));
    }

    #[test]
    fn parse_modulo() {
        assert!(matches!(
            parsed("10 % 3"),
            ExprNode::BinaryOp { op: BinOp::Mod, .. }
        ));
    }

    // ---- unary operators ----

    #[test]
    fn parse_unary_neg() {
        assert!(matches!(
            parsed("-5"),
            ExprNode::UnaryOp { op: UnOp::Neg, .. }
        ));
    }

    #[test]
    fn parse_unary_not_bang() {
        assert!(matches!(
            parsed("!true"),
            ExprNode::UnaryOp { op: UnOp::Not, .. }
        ));
    }

    #[test]
    fn parse_unary_not_keyword() {
        assert!(matches!(
            parsed("not false"),
            ExprNode::UnaryOp { op: UnOp::Not, .. }
        ));
    }

    // ---- comparisons ----

    #[test]
    fn parse_comparison_eq() {
        assert!(matches!(
            parsed("$x == 5"),
            ExprNode::Comparison { op: CmpOp::Eq, .. }
        ));
    }

    #[test]
    fn parse_comparison_ne() {
        assert!(matches!(
            parsed("$x != 5"),
            ExprNode::Comparison { op: CmpOp::Ne, .. }
        ));
    }

    #[test]
    fn parse_comparison_lt() {
        assert!(matches!(
            parsed("$x < 10"),
            ExprNode::Comparison { op: CmpOp::Lt, .. }
        ));
    }

    #[test]
    fn parse_comparison_le() {
        assert!(matches!(
            parsed("$x <= 10"),
            ExprNode::Comparison { op: CmpOp::Le, .. }
        ));
    }

    #[test]
    fn parse_comparison_gt() {
        assert!(matches!(
            parsed("$x > 0"),
            ExprNode::Comparison { op: CmpOp::Gt, .. }
        ));
    }

    #[test]
    fn parse_comparison_ge() {
        assert!(matches!(
            parsed("$x >= 0"),
            ExprNode::Comparison { op: CmpOp::Ge, .. }
        ));
    }

    // ---- logical operators ----

    #[test]
    fn parse_logical_and() {
        assert!(matches!(
            parsed("true and false"),
            ExprNode::Logical {
                op: LogicOp::And,
                ..
            }
        ));
    }

    #[test]
    fn parse_logical_or() {
        assert!(matches!(
            parsed("true or false"),
            ExprNode::Logical {
                op: LogicOp::Or,
                ..
            }
        ));
    }

    // ---- calls, conditionals, grouping ----

    #[test]
    fn parse_fn_call_one_arg() {
        assert_call("abs(-5)", "abs", 1);
    }

    #[test]
    fn parse_fn_call_two_args() {
        assert_call("min(1, 2)", "min", 2);
    }

    #[test]
    fn parse_fn_call_no_args() {
        assert_call("foo()", "foo", 0);
    }

    #[test]
    fn parse_conditional() {
        assert!(matches!(
            parsed("if true then 1 else 0"),
            ExprNode::Conditional { .. }
        ));
    }

    #[test]
    fn parse_parenthesised() {
        assert!(matches!(
            parsed("(1 + 2) * 3"),
            ExprNode::BinaryOp { op: BinOp::Mul, .. }
        ));
    }

    #[test]
    fn parse_complex_expression() {
        assert!(matches!(
            parsed("$a + $b * 2 > 10 and $c != 0"),
            ExprNode::Logical {
                op: LogicOp::And,
                ..
            }
        ));
    }

    // ---- error reporting ----

    #[test]
    fn error_empty_input() {
        assert!(rejected("").msg.contains("unexpected end"));
    }

    #[test]
    fn error_trailing_input() {
        assert!(rejected("1 2").msg.contains("trailing"));
    }

    #[test]
    fn error_unterminated_string() {
        assert!(rejected(r#""hello"#).msg.contains("unterminated"));
    }

    #[test]
    fn error_source_too_long() {
        let oversized = "1+".repeat(MAX_EXPR_SOURCE);
        assert!(rejected(&oversized).msg.contains("maximum length"));
    }

    #[test]
    fn error_depth_exceeded() {
        let levels = MAX_EXPR_DEPTH + 10;
        let src = format!("{}1{}", "(".repeat(levels), ")".repeat(levels));
        assert!(rejected(&src).msg.contains("depth"));
    }

    #[test]
    fn error_display() {
        let err = ExprError {
            msg: "bad".into(),
            pos: 5,
        };
        assert_eq!(err.to_string(), "expr error at 5: bad");
    }

    // ---- nesting and precedence ----

    #[test]
    fn parse_nested_fn_calls() {
        assert_call("max(abs(-1), min(2, 3))", "max", 2);
    }

    #[test]
    fn parse_logical_precedence() {
        // `and` binds tighter than `or`, so the conjunction is the right child.
        match parsed("true or false and true") {
            ExprNode::Logical {
                op: LogicOp::Or,
                right,
                ..
            } => assert!(matches!(
                *right,
                ExprNode::Logical {
                    op: LogicOp::And,
                    ..
                }
            )),
            _ => panic!("expected Or at top"),
        }
    }
}