1use super::ast::*;
4use crate::kinds::{Kind, Number};
5
/// Error produced when an expression fails to parse.
///
/// `pos` is a byte offset into the expression source at which the problem was
/// detected; whole-input problems (such as exceeding the length limit) use 0.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExprError {
    /// Human-readable description of what went wrong.
    pub msg: String,
    /// Byte offset into the expression source where the error was detected.
    pub pos: usize,
}

impl std::fmt::Display for ExprError {
    /// Formats as `expr error at <pos>: <msg>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "expr error at {}: {}", self.pos, self.msg)
    }
}

impl std::error::Error for ExprError {}
22
/// Cursor state for the recursive-descent expression parser.
///
/// `pos` is a byte offset into `source`; `depth` tracks how many nesting
/// levels are currently open so that overly deep input can be rejected.
struct Parser<'src> {
    /// Full expression text being parsed.
    source: &'src str,
    /// Current byte offset into `source`.
    pos: usize,
    /// Number of nesting levels currently open.
    depth: usize,
}
29
30impl<'a> Parser<'a> {
31 fn new(source: &'a str) -> Self {
32 Self {
33 source,
34 pos: 0,
35 depth: 0,
36 }
37 }
38
39 fn err(&self, msg: impl Into<String>) -> ExprError {
40 ExprError {
41 msg: msg.into(),
42 pos: self.pos,
43 }
44 }
45
46 fn enter(&mut self) -> Result<(), ExprError> {
47 self.depth += 1;
48 if self.depth > MAX_EXPR_DEPTH {
49 Err(self.err("expression exceeds maximum nesting depth"))
50 } else {
51 Ok(())
52 }
53 }
54
55 fn leave(&mut self) {
56 self.depth -= 1;
57 }
58
59 fn skip_ws(&mut self) {
60 while self.pos < self.source.len() {
61 let b = self.source.as_bytes()[self.pos];
62 if b == b' ' || b == b'\t' || b == b'\n' || b == b'\r' {
63 self.pos += 1;
64 } else {
65 break;
66 }
67 }
68 }
69
70 fn at_end(&self) -> bool {
71 self.pos >= self.source.len()
72 }
73
74 fn peek(&self) -> Option<u8> {
75 if self.pos < self.source.len() {
76 Some(self.source.as_bytes()[self.pos])
77 } else {
78 None
79 }
80 }
81
82 fn consume(&mut self, ch: u8) -> Result<(), ExprError> {
83 self.skip_ws();
84 if self.peek() == Some(ch) {
85 self.pos += 1;
86 Ok(())
87 } else {
88 Err(self.err(format!("expected '{}'", ch as char)))
89 }
90 }
91
92 fn starts_with(&self, s: &str) -> bool {
93 self.source[self.pos..].starts_with(s)
94 }
95
96 fn keyword(&self, kw: &str) -> bool {
99 if !self.starts_with(kw) {
100 return false;
101 }
102 let after = self.pos + kw.len();
103 if after >= self.source.len() {
104 return true;
105 }
106 let b = self.source.as_bytes()[after];
107 !b.is_ascii_alphanumeric() && b != b'_'
108 }
109
110 fn consume_keyword(&mut self, kw: &str) -> Result<(), ExprError> {
111 self.skip_ws();
112 if self.keyword(kw) {
113 self.pos += kw.len();
114 Ok(())
115 } else {
116 Err(self.err(format!("expected '{kw}'")))
117 }
118 }
119
120 fn read_ident(&mut self) -> Result<String, ExprError> {
121 self.skip_ws();
122 let start = self.pos;
123 while self.pos < self.source.len() {
124 let b = self.source.as_bytes()[self.pos];
125 if b.is_ascii_alphanumeric() || b == b'_' {
126 self.pos += 1;
127 } else {
128 break;
129 }
130 }
131 if self.pos == start {
132 return Err(self.err("expected identifier"));
133 }
134 Ok(self.source[start..self.pos].to_string())
135 }
136
137 fn parse_expr(&mut self) -> Result<ExprNode, ExprError> {
140 self.enter()?;
141 let node = self.parse_logic_or()?;
142 self.leave();
143 Ok(node)
144 }
145
146 fn parse_logic_or(&mut self) -> Result<ExprNode, ExprError> {
147 let mut left = self.parse_logic_and()?;
148 loop {
149 self.skip_ws();
150 if self.keyword("or") {
151 self.pos += 2;
152 let right = self.parse_logic_and()?;
153 left = ExprNode::Logical {
154 left: Box::new(left),
155 op: LogicOp::Or,
156 right: Box::new(right),
157 };
158 } else {
159 break;
160 }
161 }
162 Ok(left)
163 }
164
165 fn parse_logic_and(&mut self) -> Result<ExprNode, ExprError> {
166 let mut left = self.parse_comparison()?;
167 loop {
168 self.skip_ws();
169 if self.keyword("and") {
170 self.pos += 3;
171 let right = self.parse_comparison()?;
172 left = ExprNode::Logical {
173 left: Box::new(left),
174 op: LogicOp::And,
175 right: Box::new(right),
176 };
177 } else {
178 break;
179 }
180 }
181 Ok(left)
182 }
183
184 fn parse_comparison(&mut self) -> Result<ExprNode, ExprError> {
185 let left = self.parse_additive()?;
186 self.skip_ws();
187 let op = if self.starts_with("!=") {
188 self.pos += 2;
189 Some(CmpOp::Ne)
190 } else if self.starts_with("==") {
191 self.pos += 2;
192 Some(CmpOp::Eq)
193 } else if self.starts_with("<=") {
194 self.pos += 2;
195 Some(CmpOp::Le)
196 } else if self.starts_with(">=") {
197 self.pos += 2;
198 Some(CmpOp::Ge)
199 } else if self.peek() == Some(b'<') {
200 self.pos += 1;
201 Some(CmpOp::Lt)
202 } else if self.peek() == Some(b'>') {
203 self.pos += 1;
204 Some(CmpOp::Gt)
205 } else {
206 None
207 };
208 if let Some(op) = op {
209 let right = self.parse_additive()?;
210 Ok(ExprNode::Comparison {
211 left: Box::new(left),
212 op,
213 right: Box::new(right),
214 })
215 } else {
216 Ok(left)
217 }
218 }
219
220 fn parse_additive(&mut self) -> Result<ExprNode, ExprError> {
221 let mut left = self.parse_multiplicative()?;
222 loop {
223 self.skip_ws();
224 match self.peek() {
225 Some(b'+') => {
226 self.pos += 1;
227 let right = self.parse_multiplicative()?;
228 left = ExprNode::BinaryOp {
229 left: Box::new(left),
230 op: BinOp::Add,
231 right: Box::new(right),
232 };
233 }
234 Some(b'-') => {
235 self.pos += 1;
236 let right = self.parse_multiplicative()?;
237 left = ExprNode::BinaryOp {
238 left: Box::new(left),
239 op: BinOp::Sub,
240 right: Box::new(right),
241 };
242 }
243 _ => break,
244 }
245 }
246 Ok(left)
247 }
248
249 fn parse_multiplicative(&mut self) -> Result<ExprNode, ExprError> {
250 let mut left = self.parse_unary()?;
251 loop {
252 self.skip_ws();
253 match self.peek() {
254 Some(b'*') => {
255 self.pos += 1;
256 let right = self.parse_unary()?;
257 left = ExprNode::BinaryOp {
258 left: Box::new(left),
259 op: BinOp::Mul,
260 right: Box::new(right),
261 };
262 }
263 Some(b'/') => {
264 self.pos += 1;
265 let right = self.parse_unary()?;
266 left = ExprNode::BinaryOp {
267 left: Box::new(left),
268 op: BinOp::Div,
269 right: Box::new(right),
270 };
271 }
272 Some(b'%') => {
273 self.pos += 1;
274 let right = self.parse_unary()?;
275 left = ExprNode::BinaryOp {
276 left: Box::new(left),
277 op: BinOp::Mod,
278 right: Box::new(right),
279 };
280 }
281 _ => break,
282 }
283 }
284 Ok(left)
285 }
286
287 fn parse_unary(&mut self) -> Result<ExprNode, ExprError> {
288 self.skip_ws();
289 if self.peek() == Some(b'-') {
290 self.pos += 1;
291 let operand = self.parse_unary()?;
292 return Ok(ExprNode::UnaryOp {
293 op: UnOp::Neg,
294 operand: Box::new(operand),
295 });
296 }
297 if self.peek() == Some(b'!') {
298 self.pos += 1;
299 let operand = self.parse_unary()?;
300 return Ok(ExprNode::UnaryOp {
301 op: UnOp::Not,
302 operand: Box::new(operand),
303 });
304 }
305 if self.keyword("not") {
306 self.pos += 3;
307 let operand = self.parse_unary()?;
308 return Ok(ExprNode::UnaryOp {
309 op: UnOp::Not,
310 operand: Box::new(operand),
311 });
312 }
313 self.parse_call()
314 }
315
316 fn parse_call(&mut self) -> Result<ExprNode, ExprError> {
317 self.skip_ws();
318 let start = self.pos;
319
320 if self.pos < self.source.len() {
324 let b = self.source.as_bytes()[self.pos];
325 if (b.is_ascii_alphabetic() || b == b'_')
326 && !self.keyword("true")
327 && !self.keyword("false")
328 && !self.keyword("null")
329 && !self.keyword("if")
330 && !self.keyword("not")
331 {
332 let name = self.read_ident()?;
333 self.skip_ws();
334 if self.peek() == Some(b'(') {
335 self.pos += 1;
336 let mut args = Vec::new();
337 self.skip_ws();
338 if self.peek() != Some(b')') {
339 args.push(self.parse_expr()?);
340 loop {
341 self.skip_ws();
342 if self.peek() == Some(b',') {
343 self.pos += 1;
344 args.push(self.parse_expr()?);
345 } else {
346 break;
347 }
348 }
349 }
350 self.consume(b')')?;
351 return Ok(ExprNode::FnCall { name, args });
352 }
353 self.pos = start;
355 }
356 }
357
358 self.parse_primary()
359 }
360
361 fn parse_primary(&mut self) -> Result<ExprNode, ExprError> {
362 self.skip_ws();
363
364 if self.at_end() {
365 return Err(self.err("unexpected end of expression"));
366 }
367
368 let b = self.source.as_bytes()[self.pos];
370 if b.is_ascii_digit()
371 || (b == b'.'
372 && self.pos + 1 < self.source.len()
373 && self.source.as_bytes()[self.pos + 1].is_ascii_digit())
374 {
375 return self.parse_number();
376 }
377
378 if b == b'"' {
380 return self.parse_string();
381 }
382
383 if self.keyword("true") {
385 self.pos += 4;
386 return Ok(ExprNode::Literal(Kind::Bool(true)));
387 }
388 if self.keyword("false") {
389 self.pos += 5;
390 return Ok(ExprNode::Literal(Kind::Bool(false)));
391 }
392
393 if self.keyword("null") {
395 self.pos += 4;
396 return Ok(ExprNode::Literal(Kind::Null));
397 }
398
399 if b == b'$' {
401 self.pos += 1;
402 let name = self.read_ident()?;
403 return Ok(ExprNode::Variable(name));
404 }
405
406 if b == b'(' {
408 self.pos += 1;
409 let node = self.parse_expr()?;
410 self.consume(b')')?;
411 return Ok(node);
412 }
413
414 if self.keyword("if") {
416 self.pos += 2;
417 let cond = self.parse_expr()?;
418 self.consume_keyword("then")?;
419 let then_expr = self.parse_expr()?;
420 self.consume_keyword("else")?;
421 let else_expr = self.parse_expr()?;
422 return Ok(ExprNode::Conditional {
423 cond: Box::new(cond),
424 then_expr: Box::new(then_expr),
425 else_expr: Box::new(else_expr),
426 });
427 }
428
429 Err(self.err(format!("unexpected character '{}'", b as char)))
430 }
431
432 fn parse_number(&mut self) -> Result<ExprNode, ExprError> {
433 let start = self.pos;
434 while self.pos < self.source.len() && self.source.as_bytes()[self.pos].is_ascii_digit() {
435 self.pos += 1;
436 }
437 if self.pos < self.source.len() && self.source.as_bytes()[self.pos] == b'.' {
438 self.pos += 1;
439 while self.pos < self.source.len() && self.source.as_bytes()[self.pos].is_ascii_digit()
440 {
441 self.pos += 1;
442 }
443 }
444 let s = &self.source[start..self.pos];
445 let val: f64 = s
446 .parse()
447 .map_err(|_| self.err(format!("invalid number '{s}'")))?;
448 Ok(ExprNode::Literal(Kind::Number(Number::unitless(val))))
449 }
450
451 fn parse_string(&mut self) -> Result<ExprNode, ExprError> {
452 self.pos += 1; let start = self.pos;
454 while self.pos < self.source.len() && self.source.as_bytes()[self.pos] != b'"' {
455 if self.source.as_bytes()[self.pos] == b'\\' {
456 self.pos += 1;
457 if self.pos >= self.source.len() {
458 return Err(self.err("unterminated string escape"));
459 }
460 match self.source.as_bytes()[self.pos] {
461 b'"' | b'\\' | b'n' | b't' | b'r' => {}
462 ch => {
463 return Err(self.err(format!("invalid escape sequence: \\{}", ch as char)));
464 }
465 }
466 }
467 self.pos += 1;
468 }
469 if self.pos >= self.source.len() {
470 return Err(self.err("unterminated string"));
471 }
472 let s = self.source[start..self.pos].to_string();
473 self.pos += 1; Ok(ExprNode::Literal(Kind::Str(s)))
475 }
476}
477
478pub fn parse_expr(source: &str) -> Result<ExprNode, ExprError> {
480 if source.len() > MAX_EXPR_SOURCE {
481 return Err(ExprError {
482 msg: format!("expression source exceeds maximum length of {MAX_EXPR_SOURCE} bytes"),
483 pos: 0,
484 });
485 }
486 let mut parser = Parser::new(source);
487 let node = parser.parse_expr()?;
488 parser.skip_ws();
489 if !parser.at_end() {
490 return Err(parser.err("unexpected trailing input"));
491 }
492 Ok(node)
493}
494
#[cfg(test)]
mod tests {
    use super::*;

    // ---- literals ----------------------------------------------------------

    #[test]
    fn parse_number_literal() {
        let node = parse_expr("42").unwrap();
        assert!(matches!(node, ExprNode::Literal(Kind::Number(_))));
    }

    #[test]
    fn parse_float_literal() {
        let node = parse_expr("3.14").unwrap();
        if let ExprNode::Literal(Kind::Number(n)) = &node {
            assert!(n.val > 3.13 && n.val < 3.15);
        } else {
            panic!("expected number literal");
        }
    }

    #[test]
    fn parse_leading_dot_number() {
        let node = parse_expr(".5").unwrap();
        if let ExprNode::Literal(Kind::Number(n)) = &node {
            assert!(n.val > 0.49 && n.val < 0.51);
        } else {
            panic!("expected number literal");
        }
    }

    #[test]
    fn parse_string_literal() {
        let node = parse_expr(r#""hello""#).unwrap();
        assert!(matches!(node, ExprNode::Literal(Kind::Str(s)) if s == "hello"));
    }

    #[test]
    fn parse_string_with_non_ascii() {
        let node = parse_expr("\"héllo wörld\"").unwrap();
        assert!(matches!(node, ExprNode::Literal(Kind::Str(ref s)) if s == "héllo wörld"));
    }

    #[test]
    fn parse_bool_true() {
        let node = parse_expr("true").unwrap();
        assert!(matches!(node, ExprNode::Literal(Kind::Bool(true))));
    }

    #[test]
    fn parse_bool_false() {
        let node = parse_expr("false").unwrap();
        assert!(matches!(node, ExprNode::Literal(Kind::Bool(false))));
    }

    #[test]
    fn parse_null() {
        let node = parse_expr("null").unwrap();
        assert!(matches!(node, ExprNode::Literal(Kind::Null)));
    }

    #[test]
    fn parse_variable() {
        let node = parse_expr("$temp").unwrap();
        assert!(matches!(node, ExprNode::Variable(ref s) if s == "temp"));
    }

    // ---- arithmetic --------------------------------------------------------

    #[test]
    fn parse_addition() {
        let node = parse_expr("1 + 2").unwrap();
        assert!(matches!(node, ExprNode::BinaryOp { op: BinOp::Add, .. }));
    }

    #[test]
    fn parse_arithmetic_precedence() {
        let node = parse_expr("1 + 2 * 3").unwrap();
        if let ExprNode::BinaryOp {
            op: BinOp::Add,
            right,
            ..
        } = &node
        {
            assert!(matches!(
                right.as_ref(),
                ExprNode::BinaryOp { op: BinOp::Mul, .. }
            ));
        } else {
            panic!("expected Add at top");
        }
    }

    #[test]
    fn parse_division_is_left_associative() {
        let node = parse_expr("8 / 4 / 2").unwrap();
        if let ExprNode::BinaryOp {
            op: BinOp::Div,
            left,
            ..
        } = &node
        {
            assert!(matches!(
                left.as_ref(),
                ExprNode::BinaryOp { op: BinOp::Div, .. }
            ));
        } else {
            panic!("expected Div at top");
        }
    }

    #[test]
    fn parse_subtraction() {
        let node = parse_expr("5 - 3").unwrap();
        assert!(matches!(node, ExprNode::BinaryOp { op: BinOp::Sub, .. }));
    }

    #[test]
    fn parse_division() {
        let node = parse_expr("10 / 2").unwrap();
        assert!(matches!(node, ExprNode::BinaryOp { op: BinOp::Div, .. }));
    }

    #[test]
    fn parse_modulo() {
        let node = parse_expr("10 % 3").unwrap();
        assert!(matches!(node, ExprNode::BinaryOp { op: BinOp::Mod, .. }));
    }

    #[test]
    fn parse_whitespace_is_ignored() {
        let node = parse_expr(" \t1\n+\r\n2 ").unwrap();
        assert!(matches!(node, ExprNode::BinaryOp { op: BinOp::Add, .. }));
    }

    // ---- unary -------------------------------------------------------------

    #[test]
    fn parse_unary_neg() {
        let node = parse_expr("-5").unwrap();
        assert!(matches!(node, ExprNode::UnaryOp { op: UnOp::Neg, .. }));
    }

    #[test]
    fn parse_unary_not_bang() {
        let node = parse_expr("!true").unwrap();
        assert!(matches!(node, ExprNode::UnaryOp { op: UnOp::Not, .. }));
    }

    #[test]
    fn parse_unary_not_keyword() {
        let node = parse_expr("not false").unwrap();
        assert!(matches!(node, ExprNode::UnaryOp { op: UnOp::Not, .. }));
    }

    // ---- comparisons -------------------------------------------------------

    #[test]
    fn parse_comparison_eq() {
        let node = parse_expr("$x == 5").unwrap();
        assert!(matches!(node, ExprNode::Comparison { op: CmpOp::Eq, .. }));
    }

    #[test]
    fn parse_comparison_ne() {
        let node = parse_expr("$x != 5").unwrap();
        assert!(matches!(node, ExprNode::Comparison { op: CmpOp::Ne, .. }));
    }

    #[test]
    fn parse_comparison_lt() {
        let node = parse_expr("$x < 10").unwrap();
        assert!(matches!(node, ExprNode::Comparison { op: CmpOp::Lt, .. }));
    }

    #[test]
    fn parse_comparison_le() {
        let node = parse_expr("$x <= 10").unwrap();
        assert!(matches!(node, ExprNode::Comparison { op: CmpOp::Le, .. }));
    }

    #[test]
    fn parse_comparison_gt() {
        let node = parse_expr("$x > 0").unwrap();
        assert!(matches!(node, ExprNode::Comparison { op: CmpOp::Gt, .. }));
    }

    #[test]
    fn parse_comparison_ge() {
        let node = parse_expr("$x >= 0").unwrap();
        assert!(matches!(node, ExprNode::Comparison { op: CmpOp::Ge, .. }));
    }

    // ---- logical -----------------------------------------------------------

    #[test]
    fn parse_logical_and() {
        let node = parse_expr("true and false").unwrap();
        assert!(matches!(
            node,
            ExprNode::Logical {
                op: LogicOp::And,
                ..
            }
        ));
    }

    #[test]
    fn parse_logical_or() {
        let node = parse_expr("true or false").unwrap();
        assert!(matches!(
            node,
            ExprNode::Logical {
                op: LogicOp::Or,
                ..
            }
        ));
    }

    #[test]
    fn parse_logical_precedence() {
        let node = parse_expr("true or false and true").unwrap();
        if let ExprNode::Logical {
            op: LogicOp::Or,
            right,
            ..
        } = &node
        {
            assert!(matches!(
                right.as_ref(),
                ExprNode::Logical {
                    op: LogicOp::And,
                    ..
                }
            ));
        } else {
            panic!("expected Or at top");
        }
    }

    #[test]
    fn keyword_followed_by_paren() {
        assert!(matches!(
            parse_expr("not(true)").unwrap(),
            ExprNode::UnaryOp { op: UnOp::Not, .. }
        ));
        assert!(matches!(
            parse_expr("true or(false)").unwrap(),
            ExprNode::Logical {
                op: LogicOp::Or,
                ..
            }
        ));
    }

    #[test]
    fn keyword_prefix_is_not_keyword() {
        // Words that merely start with a keyword are not treated as that keyword.
        assert!(parse_expr("$x orange").is_err());
        assert!(parse_expr("nothing").is_err());
        assert!(parse_expr("truex").is_err());
    }

    // ---- calls, conditionals, grouping -------------------------------------

    #[test]
    fn parse_fn_call_one_arg() {
        let node = parse_expr("abs(-5)").unwrap();
        if let ExprNode::FnCall { name, args } = &node {
            assert_eq!(name, "abs");
            assert_eq!(args.len(), 1);
        } else {
            panic!("expected FnCall");
        }
    }

    #[test]
    fn parse_fn_call_two_args() {
        let node = parse_expr("min(1, 2)").unwrap();
        if let ExprNode::FnCall { name, args } = &node {
            assert_eq!(name, "min");
            assert_eq!(args.len(), 2);
        } else {
            panic!("expected FnCall");
        }
    }

    #[test]
    fn parse_fn_call_no_args() {
        let node = parse_expr("foo()").unwrap();
        if let ExprNode::FnCall { name, args } = &node {
            assert_eq!(name, "foo");
            assert!(args.is_empty());
        } else {
            panic!("expected FnCall");
        }
    }

    #[test]
    fn parse_nested_fn_calls() {
        let node = parse_expr("max(abs(-1), min(2, 3))").unwrap();
        if let ExprNode::FnCall { name, args } = &node {
            assert_eq!(name, "max");
            assert_eq!(args.len(), 2);
        } else {
            panic!("expected FnCall");
        }
    }

    #[test]
    fn parse_conditional() {
        let node = parse_expr("if true then 1 else 0").unwrap();
        assert!(matches!(node, ExprNode::Conditional { .. }));
    }

    #[test]
    fn parse_conditional_branches() {
        let node = parse_expr("if $x > 0 then 1 else -1").unwrap();
        if let ExprNode::Conditional {
            cond, else_expr, ..
        } = &node
        {
            assert!(matches!(
                cond.as_ref(),
                ExprNode::Comparison { op: CmpOp::Gt, .. }
            ));
            assert!(matches!(
                else_expr.as_ref(),
                ExprNode::UnaryOp { op: UnOp::Neg, .. }
            ));
        } else {
            panic!("expected Conditional");
        }
    }

    #[test]
    fn parse_parenthesised() {
        let node = parse_expr("(1 + 2) * 3").unwrap();
        assert!(matches!(node, ExprNode::BinaryOp { op: BinOp::Mul, .. }));
    }

    #[test]
    fn parse_nested_parens() {
        let node = parse_expr("((1))").unwrap();
        assert!(matches!(node, ExprNode::Literal(Kind::Number(_))));
    }

    #[test]
    fn parse_complex_expression() {
        let node = parse_expr("$a + $b * 2 > 10 and $c != 0").unwrap();
        assert!(matches!(
            node,
            ExprNode::Logical {
                op: LogicOp::And,
                ..
            }
        ));
    }

    // ---- errors ------------------------------------------------------------

    #[test]
    fn error_empty_input() {
        let err = parse_expr("").unwrap_err();
        assert!(err.msg.contains("unexpected end"));
    }

    #[test]
    fn error_trailing_input() {
        let err = parse_expr("1 2").unwrap_err();
        assert!(err.msg.contains("trailing"));
    }

    #[test]
    fn error_unterminated_string() {
        let err = parse_expr(r#""hello"#).unwrap_err();
        assert!(err.msg.contains("unterminated"));
    }

    #[test]
    fn error_unterminated_escape() {
        let err = parse_expr("\"abc\\").unwrap_err();
        assert!(err.msg.contains("unterminated string escape"));
    }

    #[test]
    fn error_invalid_escape() {
        let err = parse_expr(r#""a\q""#).unwrap_err();
        assert!(err.msg.contains("invalid escape"));
        assert_eq!(err.pos, 3);
    }

    #[test]
    fn error_missing_close_paren() {
        let err = parse_expr("(1 + 2").unwrap_err();
        assert!(err.msg.contains("expected ')'"));
    }

    #[test]
    fn error_conditional_missing_then() {
        let err = parse_expr("if true 1 else 0").unwrap_err();
        assert!(err.msg.contains("expected 'then'"));
    }

    #[test]
    fn error_trailing_comma_in_call() {
        assert!(parse_expr("min(1,)").is_err());
    }

    #[test]
    fn error_source_too_long() {
        let long = "1+".repeat(MAX_EXPR_SOURCE);
        let err = parse_expr(&long).unwrap_err();
        assert!(err.msg.contains("maximum length"));
    }

    #[test]
    fn error_depth_exceeded() {
        let open: String = "(".repeat(MAX_EXPR_DEPTH + 10);
        let close: String = ")".repeat(MAX_EXPR_DEPTH + 10);
        let src = format!("{open}1{close}");
        let err = parse_expr(&src).unwrap_err();
        assert!(err.msg.contains("depth"));
    }

    #[test]
    fn error_display() {
        let err = ExprError {
            msg: "bad".into(),
            pos: 5,
        };
        assert_eq!(err.to_string(), "expr error at 5: bad");
    }
}