Skip to main content

haystack_core/expr/
eval.rs

1//! Expression evaluator with built-in functions.
2
3use std::collections::HashMap;
4
5use super::ast::*;
6use super::parser::{ExprError, parse_expr};
7use crate::kinds::{Kind, Number};
8
9/// Context providing variable values for expression evaluation.
10pub struct ExprContext {
11    vars: HashMap<String, Kind>,
12}
13
14impl ExprContext {
15    /// Create an empty context.
16    pub fn new() -> Self {
17        Self {
18            vars: HashMap::new(),
19        }
20    }
21
22    /// Set a variable value.
23    pub fn set(&mut self, name: impl Into<String>, val: Kind) {
24        self.vars.insert(name.into(), val);
25    }
26
27    /// Look up a variable by name.
28    pub fn get(&self, name: &str) -> Option<&Kind> {
29        self.vars.get(name)
30    }
31}
32
33impl Default for ExprContext {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39/// A parsed, ready-to-evaluate expression.
40pub struct Expr {
41    root: ExprNode,
42    variables: Vec<String>,
43}
44
45impl Expr {
46    /// Parse a source string into a ready-to-evaluate expression.
47    pub fn parse(source: &str) -> Result<Self, ExprError> {
48        let root = parse_expr(source)?;
49        let mut variables = Vec::new();
50        collect_variables(&root, &mut variables);
51        variables.sort();
52        variables.dedup();
53        Ok(Self { root, variables })
54    }
55
56    /// Evaluate the expression, returning `Kind::NA` on any failure.
57    pub fn eval(&self, ctx: &ExprContext) -> Kind {
58        eval_node(&self.root, ctx)
59    }
60
61    /// Evaluate and extract an f64, returning `NaN` on failure.
62    pub fn eval_number(&self, ctx: &ExprContext) -> f64 {
63        match self.eval(ctx) {
64            Kind::Number(n) => n.val,
65            _ => f64::NAN,
66        }
67    }
68
69    /// Evaluate and extract a bool, returning `false` on failure.
70    pub fn eval_bool(&self, ctx: &ExprContext) -> bool {
71        match self.eval(ctx) {
72            Kind::Bool(b) => b,
73            _ => false,
74        }
75    }
76
77    /// Return the sorted, deduplicated list of variables referenced in the expression.
78    pub fn variables(&self) -> &[String] {
79        &self.variables
80    }
81}
82
83/// Walk the AST and collect all `Variable` names.
84fn collect_variables(node: &ExprNode, out: &mut Vec<String>) {
85    match node {
86        ExprNode::Variable(name) => out.push(name.clone()),
87        ExprNode::Literal(_) => {}
88        ExprNode::BinaryOp { left, right, .. } => {
89            collect_variables(left, out);
90            collect_variables(right, out);
91        }
92        ExprNode::UnaryOp { operand, .. } => collect_variables(operand, out),
93        ExprNode::Comparison { left, right, .. } => {
94            collect_variables(left, out);
95            collect_variables(right, out);
96        }
97        ExprNode::Logical { left, right, .. } => {
98            collect_variables(left, out);
99            collect_variables(right, out);
100        }
101        ExprNode::FnCall { args, .. } => {
102            for arg in args {
103                collect_variables(arg, out);
104            }
105        }
106        ExprNode::Conditional {
107            cond,
108            then_expr,
109            else_expr,
110        } => {
111            collect_variables(cond, out);
112            collect_variables(then_expr, out);
113            collect_variables(else_expr, out);
114        }
115    }
116}
117
118/// Extract an f64 from a Kind, or return None.
119fn as_f64(k: &Kind) -> Option<f64> {
120    match k {
121        Kind::Number(n) => Some(n.val),
122        _ => None,
123    }
124}
125
126/// Extract a bool from a Kind, or return None.
127fn as_bool(k: &Kind) -> Option<bool> {
128    match k {
129        Kind::Bool(b) => Some(*b),
130        _ => None,
131    }
132}
133
134/// Wrap an f64 into a `Kind::Number`.
135fn num(val: f64) -> Kind {
136    Kind::Number(Number::unitless(val))
137}
138
139/// Evaluate a single AST node.
140fn eval_node(node: &ExprNode, ctx: &ExprContext) -> Kind {
141    match node {
142        ExprNode::Literal(k) => k.clone(),
143
144        ExprNode::Variable(name) => ctx.get(name).cloned().unwrap_or(Kind::NA),
145
146        ExprNode::BinaryOp { left, op, right } => {
147            let lv = eval_node(left, ctx);
148            let rv = eval_node(right, ctx);
149            let (Some(l), Some(r)) = (as_f64(&lv), as_f64(&rv)) else {
150                return Kind::NA;
151            };
152            let result = match op {
153                BinOp::Add => l + r,
154                BinOp::Sub => l - r,
155                BinOp::Mul => l * r,
156                BinOp::Div => {
157                    if r == 0.0 {
158                        return Kind::NA;
159                    }
160                    l / r
161                }
162                BinOp::Mod => {
163                    if r == 0.0 {
164                        return Kind::NA;
165                    }
166                    l % r
167                }
168            };
169            num(result)
170        }
171
172        ExprNode::UnaryOp { op, operand } => {
173            let val = eval_node(operand, ctx);
174            match op {
175                UnOp::Neg => {
176                    if let Some(v) = as_f64(&val) {
177                        num(-v)
178                    } else {
179                        Kind::NA
180                    }
181                }
182                UnOp::Not => {
183                    if let Some(b) = as_bool(&val) {
184                        Kind::Bool(!b)
185                    } else {
186                        Kind::NA
187                    }
188                }
189            }
190        }
191
192        ExprNode::Comparison { left, op, right } => {
193            let lv = eval_node(left, ctx);
194            let rv = eval_node(right, ctx);
195            let (Some(l), Some(r)) = (as_f64(&lv), as_f64(&rv)) else {
196                return Kind::NA;
197            };
198            let result = match op {
199                CmpOp::Eq => l == r,
200                CmpOp::Ne => l != r,
201                CmpOp::Lt => l < r,
202                CmpOp::Le => l <= r,
203                CmpOp::Gt => l > r,
204                CmpOp::Ge => l >= r,
205            };
206            Kind::Bool(result)
207        }
208
209        ExprNode::Logical { left, op, right } => {
210            let lv = eval_node(left, ctx);
211            let rv = eval_node(right, ctx);
212            let (Some(l), Some(r)) = (as_bool(&lv), as_bool(&rv)) else {
213                return Kind::NA;
214            };
215            let result = match op {
216                LogicOp::And => l && r,
217                LogicOp::Or => l || r,
218            };
219            Kind::Bool(result)
220        }
221
222        ExprNode::FnCall { name, args } => {
223            let evaluated: Vec<Kind> = args.iter().map(|a| eval_node(a, ctx)).collect();
224            eval_builtin(name, &evaluated)
225        }
226
227        ExprNode::Conditional {
228            cond,
229            then_expr,
230            else_expr,
231        } => {
232            let cv = eval_node(cond, ctx);
233            match as_bool(&cv) {
234                Some(true) => eval_node(then_expr, ctx),
235                Some(false) => eval_node(else_expr, ctx),
236                None => Kind::NA,
237            }
238        }
239    }
240}
241
242fn finite_num(val: f64) -> Kind {
243    if val.is_finite() {
244        Kind::Number(Number::unitless(val))
245    } else {
246        Kind::NA
247    }
248}
249
250/// Evaluate one of the 7 built-in functions.
251fn eval_builtin(name: &str, args: &[Kind]) -> Kind {
252    match name {
253        "abs" => {
254            if args.len() != 1 {
255                return Kind::NA;
256            }
257            as_f64(&args[0]).map_or(Kind::NA, |v| num(v.abs()))
258        }
259        "min" => {
260            if args.len() != 2 {
261                return Kind::NA;
262            }
263            let (Some(a), Some(b)) = (as_f64(&args[0]), as_f64(&args[1])) else {
264                return Kind::NA;
265            };
266            num(a.min(b))
267        }
268        "max" => {
269            if args.len() != 2 {
270                return Kind::NA;
271            }
272            let (Some(a), Some(b)) = (as_f64(&args[0]), as_f64(&args[1])) else {
273                return Kind::NA;
274            };
275            num(a.max(b))
276        }
277        "sqrt" => {
278            if args.len() != 1 {
279                return Kind::NA;
280            }
281            as_f64(&args[0]).map_or(Kind::NA, |v| finite_num(v.sqrt()))
282        }
283        "clamp" => {
284            if args.len() != 3 {
285                return Kind::NA;
286            }
287            let (Some(x), Some(lo), Some(hi)) =
288                (as_f64(&args[0]), as_f64(&args[1]), as_f64(&args[2]))
289            else {
290                return Kind::NA;
291            };
292            num(x.clamp(lo, hi))
293        }
294        "avg" => {
295            if args.is_empty() {
296                return Kind::NA;
297            }
298            let mut sum = 0.0;
299            for a in args {
300                match as_f64(a) {
301                    Some(v) => sum += v,
302                    None => return Kind::NA,
303                }
304            }
305            num(sum / args.len() as f64)
306        }
307        "between" => {
308            if args.len() != 3 {
309                return Kind::NA;
310            }
311            let (Some(x), Some(lo), Some(hi)) =
312                (as_f64(&args[0]), as_f64(&args[1]), as_f64(&args[2]))
313            else {
314                return Kind::NA;
315            };
316            Kind::Bool(x >= lo && x <= hi)
317        }
318        _ => Kind::NA,
319    }
320}
321
#[cfg(test)]
mod tests {
    // Unit tests for the evaluator: arithmetic, comparison, logic, variables,
    // conditionals, built-ins, and the "any failure becomes NA" contract.
    use super::*;

    fn ctx_with(pairs: &[(&str, Kind)]) -> ExprContext {
        let mut ctx = ExprContext::new();
        for (k, v) in pairs {
            ctx.set(*k, v.clone());
        }
        ctx
    }

    fn n(v: f64) -> Kind {
        Kind::Number(Number::unitless(v))
    }

    // ── Arithmetic ──────────────────────────────────────────

    #[test]
    fn eval_addition() {
        let expr = Expr::parse("1 + 2").unwrap();
        let result = expr.eval(&ExprContext::new());
        assert!(matches!(result, Kind::Number(ref n) if (n.val - 3.0).abs() < 1e-10));
    }

    #[test]
    fn eval_subtraction() {
        let expr = Expr::parse("10 - 3").unwrap();
        assert!(
            matches!(expr.eval(&ExprContext::new()), Kind::Number(ref n) if (n.val - 7.0).abs() < 1e-10)
        );
    }

    #[test]
    fn eval_multiplication() {
        let expr = Expr::parse("4 * 5").unwrap();
        assert_eq!(expr.eval_number(&ExprContext::new()), 20.0);
    }

    #[test]
    fn eval_division() {
        let expr = Expr::parse("10 / 4").unwrap();
        assert_eq!(expr.eval_number(&ExprContext::new()), 2.5);
    }

    #[test]
    fn eval_modulo() {
        let expr = Expr::parse("10 % 3").unwrap();
        assert_eq!(expr.eval_number(&ExprContext::new()), 1.0);
    }

    #[test]
    fn eval_precedence() {
        let expr = Expr::parse("2 + 3 * 4").unwrap();
        assert_eq!(expr.eval_number(&ExprContext::new()), 14.0);
    }

    #[test]
    fn eval_parentheses() {
        let expr = Expr::parse("(2 + 3) * 4").unwrap();
        assert_eq!(expr.eval_number(&ExprContext::new()), 20.0);
    }

    #[test]
    fn eval_negation() {
        let expr = Expr::parse("-5").unwrap();
        assert_eq!(expr.eval_number(&ExprContext::new()), -5.0);
    }

    // ── Comparison ──────────────────────────────────────────

    #[test]
    fn eval_eq_true() {
        let expr = Expr::parse("5 == 5").unwrap();
        assert!(expr.eval_bool(&ExprContext::new()));
    }

    #[test]
    fn eval_eq_false() {
        let expr = Expr::parse("5 == 6").unwrap();
        assert!(!expr.eval_bool(&ExprContext::new()));
    }

    #[test]
    fn eval_ne() {
        let expr = Expr::parse("5 != 6").unwrap();
        assert!(expr.eval_bool(&ExprContext::new()));
    }

    #[test]
    fn eval_lt() {
        let expr = Expr::parse("3 < 5").unwrap();
        assert!(expr.eval_bool(&ExprContext::new()));
    }

    #[test]
    fn eval_le() {
        let expr = Expr::parse("5 <= 5").unwrap();
        assert!(expr.eval_bool(&ExprContext::new()));
    }

    #[test]
    fn eval_gt() {
        let expr = Expr::parse("10 > 5").unwrap();
        assert!(expr.eval_bool(&ExprContext::new()));
    }

    #[test]
    fn eval_ge() {
        let expr = Expr::parse("5 >= 5").unwrap();
        assert!(expr.eval_bool(&ExprContext::new()));
    }

    // ── Logical ─────────────────────────────────────────────

    #[test]
    fn eval_and_true() {
        let expr = Expr::parse("true and true").unwrap();
        assert!(expr.eval_bool(&ExprContext::new()));
    }

    #[test]
    fn eval_and_false() {
        let expr = Expr::parse("true and false").unwrap();
        assert!(!expr.eval_bool(&ExprContext::new()));
    }

    #[test]
    fn eval_or_true() {
        let expr = Expr::parse("false or true").unwrap();
        assert!(expr.eval_bool(&ExprContext::new()));
    }

    #[test]
    fn eval_or_false() {
        let expr = Expr::parse("false or false").unwrap();
        assert!(!expr.eval_bool(&ExprContext::new()));
    }

    #[test]
    fn eval_not() {
        let expr = Expr::parse("!true").unwrap();
        assert!(!expr.eval_bool(&ExprContext::new()));
    }

    #[test]
    fn eval_not_keyword() {
        let expr = Expr::parse("not false").unwrap();
        assert!(expr.eval_bool(&ExprContext::new()));
    }

    // ── Variables ───────────────────────────────────────────

    #[test]
    fn eval_variable() {
        let expr = Expr::parse("$x + 1").unwrap();
        let ctx = ctx_with(&[("x", n(10.0))]);
        assert_eq!(expr.eval_number(&ctx), 11.0);
    }

    #[test]
    fn eval_missing_variable_returns_na() {
        let expr = Expr::parse("$missing").unwrap();
        assert!(matches!(expr.eval(&ExprContext::new()), Kind::NA));
    }

    #[test]
    fn eval_missing_variable_number_returns_nan() {
        let expr = Expr::parse("$missing").unwrap();
        assert!(expr.eval_number(&ExprContext::new()).is_nan());
    }

    #[test]
    fn eval_missing_variable_bool_returns_false() {
        let expr = Expr::parse("$missing").unwrap();
        assert!(!expr.eval_bool(&ExprContext::new()));
    }

    #[test]
    fn missing_variable_propagates_na_through_arithmetic() {
        let expr = Expr::parse("$missing + 1").unwrap();
        assert_eq!(expr.eval(&ExprContext::new()), Kind::NA);
    }

    #[test]
    fn variables_collected() {
        let expr = Expr::parse("$a + $b * $a").unwrap();
        assert_eq!(expr.variables(), &["a", "b"]);
    }

    #[test]
    fn variables_collected_from_nested_nodes() {
        // Exercises the Conditional, Comparison and FnCall branches of the walker.
        let expr = Expr::parse("if $c > 0 then max($a, $b) else $d").unwrap();
        assert_eq!(expr.variables(), &["a", "b", "c", "d"]);
    }

    #[test]
    fn variables_empty_without_references() {
        let expr = Expr::parse("1 + 2").unwrap();
        assert!(expr.variables().is_empty());
    }

    // ── Result extraction ───────────────────────────────────

    #[test]
    fn eval_number_on_bool_result_returns_nan() {
        let expr = Expr::parse("1 < 2").unwrap();
        assert!(expr.eval_number(&ExprContext::new()).is_nan());
    }

    #[test]
    fn eval_bool_on_number_result_returns_false() {
        let expr = Expr::parse("1 + 1").unwrap();
        assert!(!expr.eval_bool(&ExprContext::new()));
    }

    // ── Conditional ─────────────────────────────────────────

    #[test]
    fn eval_conditional_true() {
        let expr = Expr::parse("if true then 1 else 0").unwrap();
        assert_eq!(expr.eval_number(&ExprContext::new()), 1.0);
    }

    #[test]
    fn eval_conditional_false() {
        let expr = Expr::parse("if false then 1 else 0").unwrap();
        assert_eq!(expr.eval_number(&ExprContext::new()), 0.0);
    }

    #[test]
    fn eval_conditional_with_variable() {
        let expr = Expr::parse("if $flag then $a else $b").unwrap();
        let ctx = ctx_with(&[("flag", Kind::Bool(true)), ("a", n(42.0)), ("b", n(99.0))]);
        assert_eq!(expr.eval_number(&ctx), 42.0);
    }

    // ── Built-in functions ──────────────────────────────────

    #[test]
    fn fn_abs() {
        let expr = Expr::parse("abs(-7)").unwrap();
        assert_eq!(expr.eval_number(&ExprContext::new()), 7.0);
    }

    #[test]
    fn fn_abs_positive() {
        let expr = Expr::parse("abs(3)").unwrap();
        assert_eq!(expr.eval_number(&ExprContext::new()), 3.0);
    }

    #[test]
    fn fn_min() {
        let expr = Expr::parse("min(3, 7)").unwrap();
        assert_eq!(expr.eval_number(&ExprContext::new()), 3.0);
    }

    #[test]
    fn fn_max() {
        let expr = Expr::parse("max(3, 7)").unwrap();
        assert_eq!(expr.eval_number(&ExprContext::new()), 7.0);
    }

    #[test]
    fn fn_sqrt() {
        let expr = Expr::parse("sqrt(16)").unwrap();
        assert_eq!(expr.eval_number(&ExprContext::new()), 4.0);
    }

    #[test]
    fn fn_clamp() {
        let expr = Expr::parse("clamp(15, 0, 10)").unwrap();
        assert_eq!(expr.eval_number(&ExprContext::new()), 10.0);
    }

    #[test]
    fn fn_clamp_within() {
        let expr = Expr::parse("clamp(5, 0, 10)").unwrap();
        assert_eq!(expr.eval_number(&ExprContext::new()), 5.0);
    }

    #[test]
    fn fn_clamp_below() {
        let expr = Expr::parse("clamp(-5, 0, 10)").unwrap();
        assert_eq!(expr.eval_number(&ExprContext::new()), 0.0);
    }

    #[test]
    fn fn_avg_two() {
        let expr = Expr::parse("avg(4, 6)").unwrap();
        assert_eq!(expr.eval_number(&ExprContext::new()), 5.0);
    }

    #[test]
    fn fn_avg_three() {
        let expr = Expr::parse("avg(2, 4, 6)").unwrap();
        assert_eq!(expr.eval_number(&ExprContext::new()), 4.0);
    }

    #[test]
    fn fn_between_inside() {
        let expr = Expr::parse("between(5, 0, 10)").unwrap();
        assert!(expr.eval_bool(&ExprContext::new()));
    }

    #[test]
    fn fn_between_outside() {
        let expr = Expr::parse("between(15, 0, 10)").unwrap();
        assert!(!expr.eval_bool(&ExprContext::new()));
    }

    #[test]
    fn fn_between_boundary() {
        let expr = Expr::parse("between(0, 0, 10)").unwrap();
        assert!(expr.eval_bool(&ExprContext::new()));
    }

    #[test]
    fn fn_between_upper_boundary() {
        let expr = Expr::parse("between(10, 0, 10)").unwrap();
        assert!(expr.eval_bool(&ExprContext::new()));
    }

    // ── Type mismatch → NA ──────────────────────────────────

    #[test]
    fn type_mismatch_add_str() {
        let expr = Expr::parse(r#""hello" + 1"#).unwrap();
        assert!(matches!(expr.eval(&ExprContext::new()), Kind::NA));
    }

    #[test]
    fn type_mismatch_compare_str() {
        let expr = Expr::parse(r#""hello" > 1"#).unwrap();
        assert!(matches!(expr.eval(&ExprContext::new()), Kind::NA));
    }

    #[test]
    fn type_mismatch_logical_number() {
        let expr = Expr::parse("1 and 2").unwrap();
        assert!(matches!(expr.eval(&ExprContext::new()), Kind::NA));
    }

    #[test]
    fn type_mismatch_neg_bool() {
        let expr = Expr::parse("-true").unwrap();
        assert!(matches!(expr.eval(&ExprContext::new()), Kind::NA));
    }

    #[test]
    fn type_mismatch_not_number() {
        let expr = Expr::parse("!5").unwrap();
        assert!(matches!(expr.eval(&ExprContext::new()), Kind::NA));
    }

    #[test]
    fn fn_wrong_arity() {
        let expr = Expr::parse("abs(1, 2)").unwrap();
        assert!(matches!(expr.eval(&ExprContext::new()), Kind::NA));
    }

    #[test]
    fn fn_min_max_wrong_arity() {
        let min_expr = Expr::parse("min(1)").unwrap();
        assert!(matches!(min_expr.eval(&ExprContext::new()), Kind::NA));
        let max_expr = Expr::parse("max(1, 2, 3)").unwrap();
        assert!(matches!(max_expr.eval(&ExprContext::new()), Kind::NA));
    }

    #[test]
    fn fn_non_number_argument_returns_na() {
        let expr = Expr::parse("abs(true)").unwrap();
        assert!(matches!(expr.eval(&ExprContext::new()), Kind::NA));
    }

    #[test]
    fn fn_unknown_returns_na() {
        let expr = Expr::parse("unknown(1)").unwrap();
        assert!(matches!(expr.eval(&ExprContext::new()), Kind::NA));
    }

    #[test]
    fn test_division_by_zero() {
        let expr = Expr::parse("$x / $y").unwrap();
        let mut ctx = ExprContext::new();
        ctx.set("x", Kind::Number(Number::unitless(10.0)));
        ctx.set("y", Kind::Number(Number::unitless(0.0)));
        assert_eq!(expr.eval(&ctx), Kind::NA);
    }

    #[test]
    fn test_modulo_by_zero() {
        let expr = Expr::parse("$x % $y").unwrap();
        let mut ctx = ExprContext::new();
        ctx.set("x", Kind::Number(Number::unitless(10.0)));
        ctx.set("y", Kind::Number(Number::unitless(0.0)));
        assert_eq!(expr.eval(&ctx), Kind::NA);
    }

    #[test]
    fn test_sqrt_negative() {
        let expr = Expr::parse("sqrt($x)").unwrap();
        let mut ctx = ExprContext::new();
        ctx.set("x", Kind::Number(Number::unitless(-1.0)));
        assert_eq!(expr.eval(&ctx), Kind::NA);
    }

    #[test]
    fn fn_avg_no_args_returns_na() {
        let expr = Expr::parse("avg()").unwrap();
        assert!(matches!(expr.eval(&ExprContext::new()), Kind::NA));
    }

    #[test]
    fn conditional_non_bool_returns_na() {
        let expr = Expr::parse("if 1 then 2 else 3").unwrap();
        assert!(matches!(expr.eval(&ExprContext::new()), Kind::NA));
    }

    // ── ExprContext ─────────────────────────────────────────

    #[test]
    fn context_set_and_get() {
        let mut ctx = ExprContext::new();
        ctx.set("x", n(42.0));
        assert!(matches!(ctx.get("x"), Some(Kind::Number(v)) if (v.val - 42.0).abs() < 1e-10));
        assert!(ctx.get("y").is_none());
    }

    #[test]
    fn context_default() {
        let ctx = ExprContext::default();
        assert!(ctx.get("anything").is_none());
    }

    // ── Complex expressions ─────────────────────────────────

    #[test]
    fn complex_expression() {
        // clamp($temp, 0, 100) > 50 and $enabled
        let expr = Expr::parse("clamp($temp, 0, 100) > 50 and $enabled").unwrap();
        let ctx = ctx_with(&[("temp", n(75.0)), ("enabled", Kind::Bool(true))]);
        assert!(expr.eval_bool(&ctx));
    }

    #[test]
    fn complex_conditional_expression() {
        let expr = Expr::parse("if $x > 10 then $x * 2 else $x + 1").unwrap();
        let ctx = ctx_with(&[("x", n(15.0))]);
        assert_eq!(expr.eval_number(&ctx), 30.0);
    }

    #[test]
    fn nested_function_expression() {
        let expr = Expr::parse("max(abs(-3), sqrt(4))").unwrap();
        assert_eq!(expr.eval_number(&ExprContext::new()), 3.0);
    }
}
727}