1use std::{cmp::Ordering, fmt::Display};
25
26use crate::compiler::grammar::expr::parser::ID_EXTERNAL;
27use crate::Event;
28use crate::{compiler::Number, runtime::Variable, Context};
29
30use crate::compiler::grammar::expr::{BinaryOperator, Constant, Expression, UnaryOperator};
31
32impl<'x> Context<'x> {
33    pub(crate) fn eval_expression(&mut self, expr: &[Expression]) -> Result<Variable, Event> {
34        let mut exprs = expr.iter().skip(self.expr_pos);
35        while let Some(expr) = exprs.next() {
36            self.expr_pos += 1;
37            match expr {
38                Expression::Variable(v) => {
39                    self.expr_stack.push(self.variable(v).unwrap_or_default());
40                }
41                Expression::Constant(val) => {
42                    self.expr_stack.push(Variable::from(val));
43                }
44                Expression::UnaryOperator(op) => {
45                    let value = self.expr_stack.pop().unwrap_or_default();
46                    self.expr_stack.push(match op {
47                        UnaryOperator::Not => value.op_not(),
48                        UnaryOperator::Minus => value.op_minus(),
49                    });
50                }
51                Expression::BinaryOperator(op) => {
52                    let right = self.expr_stack.pop().unwrap_or_default();
53                    let left = self.expr_stack.pop().unwrap_or_default();
54                    self.expr_stack.push(match op {
55                        BinaryOperator::Add => left.op_add(right),
56                        BinaryOperator::Subtract => left.op_subtract(right),
57                        BinaryOperator::Multiply => left.op_multiply(right),
58                        BinaryOperator::Divide => left.op_divide(right),
59                        BinaryOperator::And => left.op_and(right),
60                        BinaryOperator::Or => left.op_or(right),
61                        BinaryOperator::Xor => left.op_xor(right),
62                        BinaryOperator::Eq => left.op_eq(right),
63                        BinaryOperator::Ne => left.op_ne(right),
64                        BinaryOperator::Lt => left.op_lt(right),
65                        BinaryOperator::Le => left.op_le(right),
66                        BinaryOperator::Gt => left.op_gt(right),
67                        BinaryOperator::Ge => left.op_ge(right),
68                    });
69                }
70                Expression::Function { id, num_args } => {
71                    let num_args = *num_args as usize;
72
73                    if let Some(fnc) = self.runtime.functions.get(*id as usize) {
74                        let mut arguments = vec![Variable::Integer(0); num_args];
75                        for arg_num in 0..num_args {
76                            arguments[num_args - arg_num - 1] =
77                                self.expr_stack.pop().unwrap_or_default();
78                        }
79                        self.expr_stack.push((fnc)(self, arguments));
80                    } else {
81                        let mut arguments = vec![Variable::Integer(0); num_args];
82                        for arg_num in 0..num_args {
83                            arguments[num_args - arg_num - 1] =
84                                self.expr_stack.pop().unwrap_or_default();
85                        }
86                        self.pos -= 1; return Err(Event::Function {
88                            id: ID_EXTERNAL - *id,
89                            arguments,
90                        });
91                    }
92                }
93                Expression::JmpIf { val, pos } => {
94                    if self.expr_stack.last().map_or(false, |v| v.to_bool()) == *val {
95                        self.expr_pos += *pos as usize;
96                        for _ in 0..*pos {
97                            exprs.next();
98                        }
99                    }
100                }
101                Expression::ArrayAccess => {
102                    let index = self.expr_stack.pop().unwrap_or_default().to_usize();
103                    let array = self.expr_stack.pop().unwrap_or_default().into_array();
104                    self.expr_stack
105                        .push(array.get(index).cloned().unwrap_or_default());
106                }
107                Expression::ArrayBuild(num_items) => {
108                    let num_items = *num_items as usize;
109                    let mut items = vec![Variable::Integer(0); num_items];
110                    for arg_num in 0..num_items {
111                        items[num_items - arg_num - 1] = self.expr_stack.pop().unwrap_or_default();
112                    }
113                    self.expr_stack.push(Variable::Array(items.into()));
114                }
115            }
116        }
117
118        let result = self.expr_stack.pop().unwrap_or_default();
119        self.expr_stack.clear();
120        self.expr_pos = 0;
121        Ok(result)
122    }
123}
124
125impl Variable {
126    pub fn op_add(self, other: Variable) -> Variable {
127        match (self, other) {
128            (Variable::Integer(a), Variable::Integer(b)) => Variable::Integer(a.saturating_add(b)),
129            (Variable::Float(a), Variable::Float(b)) => Variable::Float(a + b),
130            (Variable::Integer(i), Variable::Float(f))
131            | (Variable::Float(f), Variable::Integer(i)) => Variable::Float(i as f64 + f),
132            (Variable::Array(a), Variable::Array(b)) => {
133                Variable::Array(a.iter().chain(b.iter()).cloned().collect::<Vec<_>>().into())
134            }
135            (Variable::Array(a), b) => a.iter().cloned().chain([b]).collect::<Vec<_>>().into(),
136            (a, Variable::Array(b)) => [a]
137                .into_iter()
138                .chain(b.iter().cloned())
139                .collect::<Vec<_>>()
140                .into(),
141            (Variable::String(a), b) => {
142                if !a.is_empty() {
143                    Variable::String(format!("{}{}", a, b).into())
144                } else {
145                    b
146                }
147            }
148            (a, Variable::String(b)) => {
149                if !b.is_empty() {
150                    Variable::String(format!("{}{}", a, b).into())
151                } else {
152                    a
153                }
154            }
155        }
156    }
157
158    pub fn op_subtract(self, other: Variable) -> Variable {
159        match (self, other) {
160            (Variable::Integer(a), Variable::Integer(b)) => Variable::Integer(a.saturating_sub(b)),
161            (Variable::Float(a), Variable::Float(b)) => Variable::Float(a - b),
162            (Variable::Integer(a), Variable::Float(b)) => Variable::Float(a as f64 - b),
163            (Variable::Float(a), Variable::Integer(b)) => Variable::Float(a - b as f64),
164            (Variable::Array(a), b) | (b, Variable::Array(a)) => Variable::Array(
165                a.iter()
166                    .filter(|v| *v != &b)
167                    .cloned()
168                    .collect::<Vec<_>>()
169                    .into(),
170            ),
171            (a, b) => a.parse_number().op_subtract(b.parse_number()),
172        }
173    }
174
175    pub fn op_multiply(self, other: Variable) -> Variable {
176        match (self, other) {
177            (Variable::Integer(a), Variable::Integer(b)) => Variable::Integer(a.saturating_mul(b)),
178            (Variable::Float(a), Variable::Float(b)) => Variable::Float(a * b),
179            (Variable::Integer(i), Variable::Float(f))
180            | (Variable::Float(f), Variable::Integer(i)) => Variable::Float(i as f64 * f),
181            (a, b) => a.parse_number().op_multiply(b.parse_number()),
182        }
183    }
184
185    pub fn op_divide(self, other: Variable) -> Variable {
186        match (self, other) {
187            (Variable::Integer(a), Variable::Integer(b)) => {
188                Variable::Float(if b != 0 { a as f64 / b as f64 } else { 0.0 })
189            }
190            (Variable::Float(a), Variable::Float(b)) => {
191                Variable::Float(if b != 0.0 { a / b } else { 0.0 })
192            }
193            (Variable::Integer(a), Variable::Float(b)) => {
194                Variable::Float(if b != 0.0 { a as f64 / b } else { 0.0 })
195            }
196            (Variable::Float(a), Variable::Integer(b)) => {
197                Variable::Float(if b != 0 { a / b as f64 } else { 0.0 })
198            }
199            (a, b) => a.parse_number().op_divide(b.parse_number()),
200        }
201    }
202
203    pub fn op_and(self, other: Variable) -> Variable {
204        Variable::Integer(i64::from(self.to_bool() & other.to_bool()))
205    }
206
207    pub fn op_or(self, other: Variable) -> Variable {
208        Variable::Integer(i64::from(self.to_bool() | other.to_bool()))
209    }
210
211    pub fn op_xor(self, other: Variable) -> Variable {
212        Variable::Integer(i64::from(self.to_bool() ^ other.to_bool()))
213    }
214
215    pub fn op_eq(self, other: Variable) -> Variable {
216        Variable::Integer(i64::from(self == other))
217    }
218
219    pub fn op_ne(self, other: Variable) -> Variable {
220        Variable::Integer(i64::from(self != other))
221    }
222
223    pub fn op_lt(self, other: Variable) -> Variable {
224        Variable::Integer(i64::from(self < other))
225    }
226
227    pub fn op_le(self, other: Variable) -> Variable {
228        Variable::Integer(i64::from(self <= other))
229    }
230
231    pub fn op_gt(self, other: Variable) -> Variable {
232        Variable::Integer(i64::from(self > other))
233    }
234
235    pub fn op_ge(self, other: Variable) -> Variable {
236        Variable::Integer(i64::from(self >= other))
237    }
238
239    pub fn op_not(self) -> Variable {
240        Variable::Integer(i64::from(!self.to_bool()))
241    }
242
243    pub fn op_minus(self) -> Variable {
244        match self {
245            Variable::Integer(n) => Variable::Integer(-n),
246            Variable::Float(n) => Variable::Float(-n),
247            _ => self.parse_number().op_minus(),
248        }
249    }
250
251    pub fn parse_number(&self) -> Variable {
252        match self {
253            Variable::String(s) if !s.is_empty() => {
254                if let Ok(n) = s.parse::<i64>() {
255                    Variable::Integer(n)
256                } else if let Ok(n) = s.parse::<f64>() {
257                    Variable::Float(n)
258                } else {
259                    Variable::Integer(0)
260                }
261            }
262            Variable::Integer(n) => Variable::Integer(*n),
263            Variable::Float(n) => Variable::Float(*n),
264            Variable::Array(l) => Variable::Integer(l.is_empty() as i64),
265            _ => Variable::Integer(0),
266        }
267    }
268
269    pub fn to_bool(&self) -> bool {
270        match self {
271            Variable::Float(f) => *f != 0.0,
272            Variable::Integer(n) => *n != 0,
273            Variable::String(s) => !s.is_empty(),
274            Variable::Array(a) => !a.is_empty(),
275        }
276    }
277}
278
279impl PartialEq for Variable {
280    fn eq(&self, other: &Self) -> bool {
281        match (self, other) {
282            (Self::Integer(a), Self::Integer(b)) => a == b,
283            (Self::Float(a), Self::Float(b)) => a == b,
284            (Self::Integer(a), Self::Float(b)) | (Self::Float(b), Self::Integer(a)) => {
285                *a as f64 == *b
286            }
287            (Self::String(a), Self::String(b)) => a == b,
288            (Self::String(_), Self::Integer(_) | Self::Float(_)) => &self.parse_number() == other,
289            (Self::Integer(_) | Self::Float(_), Self::String(_)) => self == &other.parse_number(),
290            (Self::Array(a), Self::Array(b)) => a == b,
291            _ => false,
292        }
293    }
294}
295
// NOTE(review): `Eq` promises reflexivity, but the `PartialEq` impl compares
// floats with `==`, so `Variable::Float(f64::NAN)` is not equal to itself.
// The marker is still needed because `Ord` (implemented below) requires `Eq`.
impl Eq for Variable {}
297
298#[allow(clippy::non_canonical_partial_ord_impl)]
299impl PartialOrd for Variable {
300    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
301        match (self, other) {
302            (Self::Integer(a), Self::Integer(b)) => a.partial_cmp(b),
303            (Self::Float(a), Self::Float(b)) => a.partial_cmp(b),
304            (Self::Integer(a), Self::Float(b)) => (*a as f64).partial_cmp(b),
305            (Self::Float(a), Self::Integer(b)) => a.partial_cmp(&(*b as f64)),
306            (Self::String(a), Self::String(b)) => a.partial_cmp(b),
307            (Self::String(_), Self::Integer(_) | Self::Float(_)) => {
308                self.parse_number().partial_cmp(other)
309            }
310            (Self::Integer(_) | Self::Float(_), Self::String(_)) => {
311                self.partial_cmp(&other.parse_number())
312            }
313            (Self::Array(a), Self::Array(b)) => a.partial_cmp(b),
314            (Self::Array(_) | Self::String(_), _) => Ordering::Greater.into(),
315            (_, Self::Array(_)) => Ordering::Less.into(),
316        }
317    }
318}
319
320impl Ord for Variable {
321    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
322        self.partial_cmp(other).unwrap_or(Ordering::Greater)
323    }
324}
325
326impl Display for Variable {
327    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
328        match self {
329            Variable::String(v) => v.fmt(f),
330            Variable::Integer(v) => v.fmt(f),
331            Variable::Float(v) => v.fmt(f),
332            Variable::Array(v) => {
333                for (i, v) in v.iter().enumerate() {
334                    if i > 0 {
335                        f.write_str("\n")?;
336                    }
337                    v.fmt(f)?;
338                }
339                Ok(())
340            }
341        }
342    }
343}
344
345impl Number {
346    pub fn is_non_zero(&self) -> bool {
347        match self {
348            Number::Integer(n) => *n != 0,
349            Number::Float(n) => *n != 0.0,
350        }
351    }
352}
353
354impl Default for Number {
355    fn default() -> Self {
356        Number::Integer(0)
357    }
358}
359
/// Numeric truthiness: zero is `false`, every other value is `true`.
trait IntoBool {
    fn into_bool(self) -> bool;
}

impl IntoBool for f64 {
    /// `-0.0` is treated as zero; NaN is treated as non-zero (truthy).
    #[inline(always)]
    fn into_bool(self) -> bool {
        0.0 != self
    }
}

impl IntoBool for i64 {
    #[inline(always)]
    fn into_bool(self) -> bool {
        !matches!(self, 0)
    }
}
377
378impl From<bool> for Number {
379    #[inline(always)]
380    fn from(b: bool) -> Self {
381        Number::Integer(i64::from(b))
382    }
383}
384
385impl From<i64> for Number {
386    #[inline(always)]
387    fn from(n: i64) -> Self {
388        Number::Integer(n)
389    }
390}
391
392impl From<f64> for Number {
393    #[inline(always)]
394    fn from(n: f64) -> Self {
395        Number::Float(n)
396    }
397}
398
399impl From<i32> for Number {
400    #[inline(always)]
401    fn from(n: i32) -> Self {
402        Number::Integer(n as i64)
403    }
404}
405
406impl<'x> From<&'x Constant> for Variable {
407    fn from(value: &'x Constant) -> Self {
408        match value {
409            Constant::Integer(i) => Variable::Integer(*i),
410            Constant::Float(f) => Variable::Float(*f),
411            Constant::String(s) => Variable::String(s.clone()),
412        }
413    }
414}
415
#[cfg(test)]
mod test {
    use ahash::{HashMap, HashMapExt};

    use crate::{
        compiler::{
            grammar::expr::{
                parser::ExpressionParser, tokenizer::Tokenizer, BinaryOperator, Expression, Token,
                UnaryOperator,
            },
            VariableType,
        },
        runtime::Variable,
    };

    // `evalexpr` is used as an independent reference evaluator (`eval`, `Value`).
    use evalexpr::*;

    /// Minimal, runtime-free evaluator used by the tests below.
    pub trait EvalExpression {
        fn eval(&self, variables: &HashMap<String, Variable>) -> Option<Variable>;
    }

    // Mirrors the stack machine in `Context::eval_expression`, but:
    // - it resolves only global variables from `variables`;
    // - it returns `None` on an unknown variable or a stack underflow
    //   instead of substituting defaults;
    // - it panics on node kinds it does not support (functions, array nodes).
    impl EvalExpression for Vec<Expression> {
        fn eval(&self, variables: &HashMap<String, Variable>) -> Option<Variable> {
            let mut stack = Vec::with_capacity(self.len());
            let mut exprs = self.iter();

            while let Some(expr) = exprs.next() {
                match expr {
                    Expression::Variable(VariableType::Global(v)) => {
                        stack.push(variables.get(v)?.clone());
                    }
                    Expression::Constant(val) => {
                        stack.push(Variable::from(val));
                    }
                    Expression::UnaryOperator(op) => {
                        let value = stack.pop()?;
                        stack.push(match op {
                            UnaryOperator::Not => value.op_not(),
                            UnaryOperator::Minus => value.op_minus(),
                        });
                    }
                    Expression::BinaryOperator(op) => {
                        let right = stack.pop()?;
                        let left = stack.pop()?;
                        stack.push(match op {
                            BinaryOperator::Add => left.op_add(right),
                            BinaryOperator::Subtract => left.op_subtract(right),
                            BinaryOperator::Multiply => left.op_multiply(right),
                            BinaryOperator::Divide => left.op_divide(right),
                            BinaryOperator::And => left.op_and(right),
                            BinaryOperator::Or => left.op_or(right),
                            BinaryOperator::Xor => left.op_xor(right),
                            BinaryOperator::Eq => left.op_eq(right),
                            BinaryOperator::Ne => left.op_ne(right),
                            BinaryOperator::Lt => left.op_lt(right),
                            BinaryOperator::Le => left.op_le(right),
                            BinaryOperator::Gt => left.op_gt(right),
                            BinaryOperator::Ge => left.op_ge(right),
                        });
                    }
                    Expression::JmpIf { val, pos } => {
                        // Short-circuit: the tested value stays on the stack and the
                        // next `pos` items are skipped when it matches `val`.
                        if stack.last()?.to_bool() == *val {
                            for _ in 0..*pos {
                                exprs.next();
                            }
                        }
                    }
                    _ => unreachable!("Invalid expression"),
                }
            }
            stack.pop()
        }
    }

    /// Cross-checks the parser + evaluator against `evalexpr` on arithmetic,
    /// comparison and boolean expressions.
    #[test]
    fn eval_expression() {
        let mut variables = HashMap::from_iter([
            ("A".to_string(), Variable::Integer(0)),
            ("B".to_string(), Variable::Integer(0)),
            ("C".to_string(), Variable::Integer(0)),
            ("D".to_string(), Variable::Integer(0)),
            ("E".to_string(), Variable::Integer(0)),
            ("F".to_string(), Variable::Integer(0)),
            ("G".to_string(), Variable::Integer(0)),
            ("H".to_string(), Variable::Integer(0)),
            ("I".to_string(), Variable::Integer(0)),
            ("J".to_string(), Variable::Integer(0)),
        ]);
        let num_vars = variables.len();

        for expr in [
            "A + B",
            "A * B",
            "A / B",
            "A - B",
            "-A",
            "A == B",
            "A != B",
            "A > B",
            "A < B",
            "A >= B",
            "A <= B",
            "A + B * C - D / E",
            "A + B + C - D - E",
            "(A + B) * (C - D) / E",
            "A - B + C * D / E * F - G",
            "A + B * C - D / E",
            "(A + B) * (C - D) / E",
            "A - B + C / D * E",
            "(A + B) / (C - D) + E",
            "A * (B + C) - D / E",
            "A / (B - C + D) * E",
            "(A + B) * C - D / (E + F)",
            "A * B - C + D / E",
            "A + B - C * D / E",
            "(A * B + C) / D - E",
            "A - B / C + D * E",
            "A + B * (C - D) / E",
            "A * B / C + (D - E)",
            "(A - B) * C / D + E",
            "A * (B / C) - D + E",
            "(A + B) / (C + D) * E",
            "A - B * C / D + E",
            "A + (B - C) * D / E",
            "(A + B) * (C / D) - E",
            "A - B / (C * D) + E",
            "(A + B) > (C - D) && E <= F",
            "A * B == C / D || E - F != G + H",
            "A / B >= C * D && E + F < G - H",
            "(A * B - C) != (D / E + F) && G > H",
            "A - B < C && D + E >= F * G",
            "(A * B) > C && (D / E) < F || G == H",
            "(A + B) <= (C - D) || E > F && G != H",
            "A * B != C + D || E - F == G / H",
            "A >= B * C && D < E - F || G != H + I",
            "(A / B + C) > D && E * F <= G - H",
            "A * (B - C) == D && E / F > G + H",
            "(A - B + C) != D || E * F >= G && H < I",
            "A < B / C && D + E * F == G - H",
            "(A + B * C) <= D && E > F / G",
            "(A * B - C) > D || E <= F + G && H != I",
            "A != B / C && D == E * F - G",
            "A <= B + C - D && E / F > G * H",
            "(A - B * C) < D || E >= F + G && H != I",
            "(A + B) / C == D && E - F < G * H",
            "A * B != C && D >= E + F / G || H < I",
            "!(A * B != C) && !(D >= E + F / G) || !(H < I)",
            "-A - B - (- C - D) - E - (-F)",
        ] {
            println!("Testing {}", expr);
            // Assign 1..=N, then N..=1, to the variables. Which name receives
            // which value follows the map's iteration order, which is not
            // guaranteed to be stable across runs.
            for (pos, v) in variables.values_mut().enumerate() {
                *v = Variable::Integer(pos as i64 + 1);
            }

            assert_expr(expr, &variables);

            for (pos, v) in variables.values_mut().enumerate() {
                *v = Variable::Integer((num_vars - pos) as i64);
            }

            assert_expr(expr, &variables);
        }

        for expr in [
            "true && false",
            "!true || false",
            "true && !false",
            "!(true && false)",
            "true || true && false",
            "!false && (true || false)",
            "!(true || !false) && true",
            "!(!true && !false)",
            "true || false && !true",
            "!(true && true) || !false",
            "!(!true || !false) && (!false) && !(!true)",
        ] {
            // Boolean literals are rewritten to 1/0 for our parser, while
            // `evalexpr` evaluates the original text with real booleans.
            let pexp = parse_expression(expr.replace("true", "1").replace("false", "0").as_str());
            let result = pexp.eval(&HashMap::new()).unwrap();

            match (eval(expr).expect(expr), result) {
                (Value::Float(a), Variable::Float(b)) if a == b => (),
                (Value::Float(a), Variable::Integer(b)) if a == b as f64 => (),
                (Value::Boolean(a), Variable::Integer(b)) if a == (b != 0) => (),
                (a, b) => {
                    panic!("{} => {:?} != {:?}", expr, a, b)
                }
            }
        }
    }

    /// Evaluates `expr` with `variables` and checks the result against:
    /// 1. the same expression with every variable substituted by its integer text;
    /// 2. the same expression with every variable substituted as a float literal;
    /// 3. `evalexpr` evaluating the float-substituted text.
    fn assert_expr(expr: &str, variables: &HashMap<String, Variable>) {
        let e = parse_expression(expr);

        let result = e.eval(variables).unwrap();

        let mut str_expr = expr.to_string();
        let mut str_expr_float = expr.to_string();
        for (k, v) in variables {
            let v = v.to_string();

            if v.contains('.') {
                str_expr_float = str_expr_float.replace(k, &v);
            } else {
                str_expr_float = str_expr_float.replace(k, &format!("{}.0", v));
            }
            str_expr = str_expr.replace(k, &v);
        }

        assert_eq!(
            parse_expression(&str_expr)
                .eval(&HashMap::new())
                .unwrap()
                .to_number()
                .to_float(),
            result.to_number().to_float()
        );

        assert_eq!(
            parse_expression(&str_expr_float)
                .eval(&HashMap::new())
                .unwrap()
                .to_number()
                .to_float(),
            result.to_number().to_float()
        );

        match (
            eval(&str_expr_float)
                .map(|v| {
                    // Presumably mirrors `Variable::op_divide`, which yields 0.0
                    // for a zero divisor where `evalexpr` yields an infinity.
                    if matches!(&v, Value::Float(f) if f.is_infinite()) {
                        Value::Float(0.0)
                    } else {
                        v
                    }
                })
                .expect(&str_expr),
            result,
        ) {
            (Value::Float(a), Variable::Float(b)) if a == b => (),
            (Value::Float(a), Variable::Integer(b)) if a == b as f64 => (),
            (Value::Boolean(a), Variable::Integer(b)) if a == (b != 0) => (),
            (a, b) => {
                panic!("{} => {:?} != {:?}", str_expr, a, b)
            }
        }
    }

    /// Parses `expr` into postfix `Expression`s, treating every identifier as
    /// a global variable reference.
    fn parse_expression(expr: &str) -> Vec<Expression> {
        ExpressionParser::from_tokenizer(Tokenizer::new(expr, |var_name: &str, _: bool| {
            Ok::<_, String>(Token::Variable(VariableType::Global(var_name.to_string())))
        }))
        .parse()
        .unwrap()
        .output
    }
}
676}