1use std::collections::HashMap;
4
5use super::ast::*;
6use super::parser::{ExprError, parse_expr};
7use crate::kinds::{Kind, Number};
8
9pub struct ExprContext {
11 vars: HashMap<String, Kind>,
12}
13
14impl ExprContext {
15 pub fn new() -> Self {
17 Self {
18 vars: HashMap::new(),
19 }
20 }
21
22 pub fn set(&mut self, name: impl Into<String>, val: Kind) {
24 self.vars.insert(name.into(), val);
25 }
26
27 pub fn get(&self, name: &str) -> Option<&Kind> {
29 self.vars.get(name)
30 }
31}
32
33impl Default for ExprContext {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
/// A parsed expression together with the names of the variables it references.
pub struct Expr {
    /// Root node of the parsed expression tree.
    root: ExprNode,
    /// Variable names referenced by the tree; `Expr::parse` stores them sorted
    /// and de-duplicated.
    variables: Vec<String>,
}
44
45impl Expr {
46 pub fn parse(source: &str) -> Result<Self, ExprError> {
48 let root = parse_expr(source)?;
49 let mut variables = Vec::new();
50 collect_variables(&root, &mut variables);
51 variables.sort();
52 variables.dedup();
53 Ok(Self { root, variables })
54 }
55
56 pub fn eval(&self, ctx: &ExprContext) -> Kind {
58 eval_node(&self.root, ctx)
59 }
60
61 pub fn eval_number(&self, ctx: &ExprContext) -> f64 {
63 match self.eval(ctx) {
64 Kind::Number(n) => n.val,
65 _ => f64::NAN,
66 }
67 }
68
69 pub fn eval_bool(&self, ctx: &ExprContext) -> bool {
71 match self.eval(ctx) {
72 Kind::Bool(b) => b,
73 _ => false,
74 }
75 }
76
77 pub fn variables(&self) -> &[String] {
79 &self.variables
80 }
81}
82
83fn collect_variables(node: &ExprNode, out: &mut Vec<String>) {
85 match node {
86 ExprNode::Variable(name) => out.push(name.clone()),
87 ExprNode::Literal(_) => {}
88 ExprNode::BinaryOp { left, right, .. } => {
89 collect_variables(left, out);
90 collect_variables(right, out);
91 }
92 ExprNode::UnaryOp { operand, .. } => collect_variables(operand, out),
93 ExprNode::Comparison { left, right, .. } => {
94 collect_variables(left, out);
95 collect_variables(right, out);
96 }
97 ExprNode::Logical { left, right, .. } => {
98 collect_variables(left, out);
99 collect_variables(right, out);
100 }
101 ExprNode::FnCall { args, .. } => {
102 for arg in args {
103 collect_variables(arg, out);
104 }
105 }
106 ExprNode::Conditional {
107 cond,
108 then_expr,
109 else_expr,
110 } => {
111 collect_variables(cond, out);
112 collect_variables(then_expr, out);
113 collect_variables(else_expr, out);
114 }
115 }
116}
117
118fn as_f64(k: &Kind) -> Option<f64> {
120 match k {
121 Kind::Number(n) => Some(n.val),
122 _ => None,
123 }
124}
125
126fn as_bool(k: &Kind) -> Option<bool> {
128 match k {
129 Kind::Bool(b) => Some(*b),
130 _ => None,
131 }
132}
133
134fn num(val: f64) -> Kind {
136 Kind::Number(Number::unitless(val))
137}
138
139fn eval_node(node: &ExprNode, ctx: &ExprContext) -> Kind {
141 match node {
142 ExprNode::Literal(k) => k.clone(),
143
144 ExprNode::Variable(name) => ctx.get(name).cloned().unwrap_or(Kind::NA),
145
146 ExprNode::BinaryOp { left, op, right } => {
147 let lv = eval_node(left, ctx);
148 let rv = eval_node(right, ctx);
149 let (Some(l), Some(r)) = (as_f64(&lv), as_f64(&rv)) else {
150 return Kind::NA;
151 };
152 let result = match op {
153 BinOp::Add => l + r,
154 BinOp::Sub => l - r,
155 BinOp::Mul => l * r,
156 BinOp::Div => {
157 if r == 0.0 {
158 return Kind::NA;
159 }
160 l / r
161 }
162 BinOp::Mod => {
163 if r == 0.0 {
164 return Kind::NA;
165 }
166 l % r
167 }
168 };
169 num(result)
170 }
171
172 ExprNode::UnaryOp { op, operand } => {
173 let val = eval_node(operand, ctx);
174 match op {
175 UnOp::Neg => {
176 if let Some(v) = as_f64(&val) {
177 num(-v)
178 } else {
179 Kind::NA
180 }
181 }
182 UnOp::Not => {
183 if let Some(b) = as_bool(&val) {
184 Kind::Bool(!b)
185 } else {
186 Kind::NA
187 }
188 }
189 }
190 }
191
192 ExprNode::Comparison { left, op, right } => {
193 let lv = eval_node(left, ctx);
194 let rv = eval_node(right, ctx);
195 let (Some(l), Some(r)) = (as_f64(&lv), as_f64(&rv)) else {
196 return Kind::NA;
197 };
198 let result = match op {
199 CmpOp::Eq => l == r,
200 CmpOp::Ne => l != r,
201 CmpOp::Lt => l < r,
202 CmpOp::Le => l <= r,
203 CmpOp::Gt => l > r,
204 CmpOp::Ge => l >= r,
205 };
206 Kind::Bool(result)
207 }
208
209 ExprNode::Logical { left, op, right } => {
210 let lv = eval_node(left, ctx);
211 let rv = eval_node(right, ctx);
212 let (Some(l), Some(r)) = (as_bool(&lv), as_bool(&rv)) else {
213 return Kind::NA;
214 };
215 let result = match op {
216 LogicOp::And => l && r,
217 LogicOp::Or => l || r,
218 };
219 Kind::Bool(result)
220 }
221
222 ExprNode::FnCall { name, args } => {
223 let evaluated: Vec<Kind> = args.iter().map(|a| eval_node(a, ctx)).collect();
224 eval_builtin(name, &evaluated)
225 }
226
227 ExprNode::Conditional {
228 cond,
229 then_expr,
230 else_expr,
231 } => {
232 let cv = eval_node(cond, ctx);
233 match as_bool(&cv) {
234 Some(true) => eval_node(then_expr, ctx),
235 Some(false) => eval_node(else_expr, ctx),
236 None => Kind::NA,
237 }
238 }
239 }
240}
241
242fn finite_num(val: f64) -> Kind {
243 if val.is_finite() {
244 Kind::Number(Number::unitless(val))
245 } else {
246 Kind::NA
247 }
248}
249
250fn eval_builtin(name: &str, args: &[Kind]) -> Kind {
252 match name {
253 "abs" => {
254 if args.len() != 1 {
255 return Kind::NA;
256 }
257 as_f64(&args[0]).map_or(Kind::NA, |v| num(v.abs()))
258 }
259 "min" => {
260 if args.len() != 2 {
261 return Kind::NA;
262 }
263 let (Some(a), Some(b)) = (as_f64(&args[0]), as_f64(&args[1])) else {
264 return Kind::NA;
265 };
266 num(a.min(b))
267 }
268 "max" => {
269 if args.len() != 2 {
270 return Kind::NA;
271 }
272 let (Some(a), Some(b)) = (as_f64(&args[0]), as_f64(&args[1])) else {
273 return Kind::NA;
274 };
275 num(a.max(b))
276 }
277 "sqrt" => {
278 if args.len() != 1 {
279 return Kind::NA;
280 }
281 as_f64(&args[0]).map_or(Kind::NA, |v| finite_num(v.sqrt()))
282 }
283 "clamp" => {
284 if args.len() != 3 {
285 return Kind::NA;
286 }
287 let (Some(x), Some(lo), Some(hi)) =
288 (as_f64(&args[0]), as_f64(&args[1]), as_f64(&args[2]))
289 else {
290 return Kind::NA;
291 };
292 num(x.clamp(lo, hi))
293 }
294 "avg" => {
295 if args.is_empty() {
296 return Kind::NA;
297 }
298 let mut sum = 0.0;
299 for a in args {
300 match as_f64(a) {
301 Some(v) => sum += v,
302 None => return Kind::NA,
303 }
304 }
305 num(sum / args.len() as f64)
306 }
307 "between" => {
308 if args.len() != 3 {
309 return Kind::NA;
310 }
311 let (Some(x), Some(lo), Some(hi)) =
312 (as_f64(&args[0]), as_f64(&args[1]), as_f64(&args[2]))
313 else {
314 return Kind::NA;
315 };
316 Kind::Bool(x >= lo && x <= hi)
317 }
318 _ => Kind::NA,
319 }
320}
321
#[cfg(test)]
mod tests {
    //! Unit tests for expression parsing and evaluation.

    use super::*;

    /// Builds a context pre-populated with the given `(name, value)` bindings.
    fn ctx_with(pairs: &[(&str, Kind)]) -> ExprContext {
        let mut ctx = ExprContext::new();
        pairs
            .iter()
            .for_each(|(name, value)| ctx.set(*name, value.clone()));
        ctx
    }

    /// Shorthand for a unitless numeric value.
    fn unitless(v: f64) -> Kind {
        Kind::Number(Number::unitless(v))
    }

    /// Parses `src`, panicking on a parse error.
    fn parsed(src: &str) -> Expr {
        Expr::parse(src).unwrap()
    }

    /// Evaluates `src` against an empty context.
    fn kind_of(src: &str) -> Kind {
        parsed(src).eval(&ExprContext::new())
    }

    /// Evaluates `src` as a number against an empty context.
    fn number_of(src: &str) -> f64 {
        parsed(src).eval_number(&ExprContext::new())
    }

    /// Evaluates `src` as a boolean against an empty context.
    fn truth_of(src: &str) -> bool {
        parsed(src).eval_bool(&ExprContext::new())
    }

    /// True when `kind` is a number within `1e-10` of `expected`.
    fn is_close(kind: &Kind, expected: f64) -> bool {
        matches!(kind, Kind::Number(number) if (number.val - expected).abs() < 1e-10)
    }

    // --- Arithmetic ---

    #[test]
    fn eval_addition() {
        assert!(is_close(&kind_of("1 + 2"), 3.0));
    }

    #[test]
    fn eval_subtraction() {
        assert!(is_close(&kind_of("10 - 3"), 7.0));
    }

    #[test]
    fn eval_multiplication() {
        assert_eq!(number_of("4 * 5"), 20.0);
    }

    #[test]
    fn eval_division() {
        assert_eq!(number_of("10 / 4"), 2.5);
    }

    #[test]
    fn eval_modulo() {
        assert_eq!(number_of("10 % 3"), 1.0);
    }

    #[test]
    fn eval_precedence() {
        assert_eq!(number_of("2 + 3 * 4"), 14.0);
    }

    #[test]
    fn eval_parentheses() {
        assert_eq!(number_of("(2 + 3) * 4"), 20.0);
    }

    #[test]
    fn eval_negation() {
        assert_eq!(number_of("-5"), -5.0);
    }

    // --- Comparisons ---

    #[test]
    fn eval_eq_true() {
        assert!(truth_of("5 == 5"));
    }

    #[test]
    fn eval_eq_false() {
        assert!(!truth_of("5 == 6"));
    }

    #[test]
    fn eval_ne() {
        assert!(truth_of("5 != 6"));
    }

    #[test]
    fn eval_lt() {
        assert!(truth_of("3 < 5"));
    }

    #[test]
    fn eval_le() {
        assert!(truth_of("5 <= 5"));
    }

    #[test]
    fn eval_gt() {
        assert!(truth_of("10 > 5"));
    }

    #[test]
    fn eval_ge() {
        assert!(truth_of("5 >= 5"));
    }

    // --- Logical operators ---

    #[test]
    fn eval_and_true() {
        assert!(truth_of("true and true"));
    }

    #[test]
    fn eval_and_false() {
        assert!(!truth_of("true and false"));
    }

    #[test]
    fn eval_or_true() {
        assert!(truth_of("false or true"));
    }

    #[test]
    fn eval_or_false() {
        assert!(!truth_of("false or false"));
    }

    #[test]
    fn eval_not() {
        assert!(!truth_of("!true"));
    }

    #[test]
    fn eval_not_keyword() {
        assert!(truth_of("not false"));
    }

    // --- Variables ---

    #[test]
    fn eval_variable() {
        let ctx = ctx_with(&[("x", unitless(10.0))]);
        assert_eq!(parsed("$x + 1").eval_number(&ctx), 11.0);
    }

    #[test]
    fn eval_missing_variable_returns_na() {
        assert!(matches!(kind_of("$missing"), Kind::NA));
    }

    #[test]
    fn eval_missing_variable_number_returns_nan() {
        assert!(number_of("$missing").is_nan());
    }

    #[test]
    fn eval_missing_variable_bool_returns_false() {
        assert!(!truth_of("$missing"));
    }

    #[test]
    fn variables_collected() {
        let expr = parsed("$a + $b * $a");
        assert_eq!(expr.variables(), &["a", "b"]);
    }

    // --- Conditionals ---

    #[test]
    fn eval_conditional_true() {
        assert_eq!(number_of("if true then 1 else 0"), 1.0);
    }

    #[test]
    fn eval_conditional_false() {
        assert_eq!(number_of("if false then 1 else 0"), 0.0);
    }

    #[test]
    fn eval_conditional_with_variable() {
        let ctx = ctx_with(&[
            ("flag", Kind::Bool(true)),
            ("a", unitless(42.0)),
            ("b", unitless(99.0)),
        ]);
        assert_eq!(parsed("if $flag then $a else $b").eval_number(&ctx), 42.0);
    }

    // --- Built-in functions ---

    #[test]
    fn fn_abs() {
        assert_eq!(number_of("abs(-7)"), 7.0);
    }

    #[test]
    fn fn_abs_positive() {
        assert_eq!(number_of("abs(3)"), 3.0);
    }

    #[test]
    fn fn_min() {
        assert_eq!(number_of("min(3, 7)"), 3.0);
    }

    #[test]
    fn fn_max() {
        assert_eq!(number_of("max(3, 7)"), 7.0);
    }

    #[test]
    fn fn_sqrt() {
        assert_eq!(number_of("sqrt(16)"), 4.0);
    }

    #[test]
    fn fn_clamp() {
        assert_eq!(number_of("clamp(15, 0, 10)"), 10.0);
    }

    #[test]
    fn fn_clamp_within() {
        assert_eq!(number_of("clamp(5, 0, 10)"), 5.0);
    }

    #[test]
    fn fn_clamp_below() {
        assert_eq!(number_of("clamp(-5, 0, 10)"), 0.0);
    }

    #[test]
    fn fn_avg_two() {
        assert_eq!(number_of("avg(4, 6)"), 5.0);
    }

    #[test]
    fn fn_avg_three() {
        assert_eq!(number_of("avg(2, 4, 6)"), 4.0);
    }

    #[test]
    fn fn_between_inside() {
        assert!(truth_of("between(5, 0, 10)"));
    }

    #[test]
    fn fn_between_outside() {
        assert!(!truth_of("between(15, 0, 10)"));
    }

    #[test]
    fn fn_between_boundary() {
        assert!(truth_of("between(0, 0, 10)"));
    }

    // --- Type mismatches and error cases ---

    #[test]
    fn type_mismatch_add_str() {
        assert!(matches!(kind_of(r#""hello" + 1"#), Kind::NA));
    }

    #[test]
    fn type_mismatch_compare_str() {
        assert!(matches!(kind_of(r#""hello" > 1"#), Kind::NA));
    }

    #[test]
    fn type_mismatch_logical_number() {
        assert!(matches!(kind_of("1 and 2"), Kind::NA));
    }

    #[test]
    fn type_mismatch_neg_bool() {
        assert!(matches!(kind_of("-true"), Kind::NA));
    }

    #[test]
    fn type_mismatch_not_number() {
        assert!(matches!(kind_of("!5"), Kind::NA));
    }

    #[test]
    fn fn_wrong_arity() {
        assert!(matches!(kind_of("abs(1, 2)"), Kind::NA));
    }

    #[test]
    fn fn_unknown_returns_na() {
        assert!(matches!(kind_of("unknown(1)"), Kind::NA));
    }

    #[test]
    fn test_division_by_zero() {
        let ctx = ctx_with(&[("x", unitless(10.0)), ("y", unitless(0.0))]);
        assert_eq!(parsed("$x / $y").eval(&ctx), Kind::NA);
    }

    #[test]
    fn test_modulo_by_zero() {
        let ctx = ctx_with(&[("x", unitless(10.0)), ("y", unitless(0.0))]);
        assert_eq!(parsed("$x % $y").eval(&ctx), Kind::NA);
    }

    #[test]
    fn test_sqrt_negative() {
        let ctx = ctx_with(&[("x", unitless(-1.0))]);
        assert_eq!(parsed("sqrt($x)").eval(&ctx), Kind::NA);
    }

    #[test]
    fn fn_avg_no_args_returns_na() {
        assert!(matches!(kind_of("avg()"), Kind::NA));
    }

    #[test]
    fn conditional_non_bool_returns_na() {
        assert!(matches!(kind_of("if 1 then 2 else 3"), Kind::NA));
    }

    // --- Context ---

    #[test]
    fn context_set_and_get() {
        let mut ctx = ExprContext::new();
        ctx.set("x", unitless(42.0));
        assert!(matches!(ctx.get("x"), Some(kind) if is_close(kind, 42.0)));
        assert!(ctx.get("y").is_none());
    }

    #[test]
    fn context_default() {
        let ctx = ExprContext::default();
        assert!(ctx.get("anything").is_none());
    }

    // --- Compound expressions ---

    #[test]
    fn complex_expression() {
        let ctx = ctx_with(&[("temp", unitless(75.0)), ("enabled", Kind::Bool(true))]);
        assert!(parsed("clamp($temp, 0, 100) > 50 and $enabled").eval_bool(&ctx));
    }

    #[test]
    fn complex_conditional_expression() {
        let ctx = ctx_with(&[("x", unitless(15.0))]);
        assert_eq!(
            parsed("if $x > 10 then $x * 2 else $x + 1").eval_number(&ctx),
            30.0
        );
    }

    #[test]
    fn nested_function_expression() {
        assert_eq!(number_of("max(abs(-3), sqrt(4))"), 3.0);
    }
}