1use std::{cmp::Ordering, fmt::Display};
25
26use crate::compiler::grammar::expr::parser::ID_EXTERNAL;
27use crate::Event;
28use crate::{compiler::Number, runtime::Variable, Context};
29
30use crate::compiler::grammar::expr::{BinaryOperator, Constant, Expression, UnaryOperator};
31
32impl<'x> Context<'x> {
33 pub(crate) fn eval_expression(&mut self, expr: &[Expression]) -> Result<Variable, Event> {
34 let mut exprs = expr.iter().skip(self.expr_pos);
35 while let Some(expr) = exprs.next() {
36 self.expr_pos += 1;
37 match expr {
38 Expression::Variable(v) => {
39 self.expr_stack.push(self.variable(v).unwrap_or_default());
40 }
41 Expression::Constant(val) => {
42 self.expr_stack.push(Variable::from(val));
43 }
44 Expression::UnaryOperator(op) => {
45 let value = self.expr_stack.pop().unwrap_or_default();
46 self.expr_stack.push(match op {
47 UnaryOperator::Not => value.op_not(),
48 UnaryOperator::Minus => value.op_minus(),
49 });
50 }
51 Expression::BinaryOperator(op) => {
52 let right = self.expr_stack.pop().unwrap_or_default();
53 let left = self.expr_stack.pop().unwrap_or_default();
54 self.expr_stack.push(match op {
55 BinaryOperator::Add => left.op_add(right),
56 BinaryOperator::Subtract => left.op_subtract(right),
57 BinaryOperator::Multiply => left.op_multiply(right),
58 BinaryOperator::Divide => left.op_divide(right),
59 BinaryOperator::And => left.op_and(right),
60 BinaryOperator::Or => left.op_or(right),
61 BinaryOperator::Xor => left.op_xor(right),
62 BinaryOperator::Eq => left.op_eq(right),
63 BinaryOperator::Ne => left.op_ne(right),
64 BinaryOperator::Lt => left.op_lt(right),
65 BinaryOperator::Le => left.op_le(right),
66 BinaryOperator::Gt => left.op_gt(right),
67 BinaryOperator::Ge => left.op_ge(right),
68 });
69 }
70 Expression::Function { id, num_args } => {
71 let num_args = *num_args as usize;
72
73 if let Some(fnc) = self.runtime.functions.get(*id as usize) {
74 let mut arguments = vec![Variable::Integer(0); num_args];
75 for arg_num in 0..num_args {
76 arguments[num_args - arg_num - 1] =
77 self.expr_stack.pop().unwrap_or_default();
78 }
79 self.expr_stack.push((fnc)(self, arguments));
80 } else {
81 let mut arguments = vec![Variable::Integer(0); num_args];
82 for arg_num in 0..num_args {
83 arguments[num_args - arg_num - 1] =
84 self.expr_stack.pop().unwrap_or_default();
85 }
86 self.pos -= 1; return Err(Event::Function {
88 id: ID_EXTERNAL - *id,
89 arguments,
90 });
91 }
92 }
93 Expression::JmpIf { val, pos } => {
94 if self.expr_stack.last().map_or(false, |v| v.to_bool()) == *val {
95 self.expr_pos += *pos as usize;
96 for _ in 0..*pos {
97 exprs.next();
98 }
99 }
100 }
101 Expression::ArrayAccess => {
102 let index = self.expr_stack.pop().unwrap_or_default().to_usize();
103 let array = self.expr_stack.pop().unwrap_or_default().into_array();
104 self.expr_stack
105 .push(array.get(index).cloned().unwrap_or_default());
106 }
107 Expression::ArrayBuild(num_items) => {
108 let num_items = *num_items as usize;
109 let mut items = vec![Variable::Integer(0); num_items];
110 for arg_num in 0..num_items {
111 items[num_items - arg_num - 1] = self.expr_stack.pop().unwrap_or_default();
112 }
113 self.expr_stack.push(Variable::Array(items.into()));
114 }
115 }
116 }
117
118 let result = self.expr_stack.pop().unwrap_or_default();
119 self.expr_stack.clear();
120 self.expr_pos = 0;
121 Ok(result)
122 }
123}
124
125impl Variable {
126 pub fn op_add(self, other: Variable) -> Variable {
127 match (self, other) {
128 (Variable::Integer(a), Variable::Integer(b)) => Variable::Integer(a.saturating_add(b)),
129 (Variable::Float(a), Variable::Float(b)) => Variable::Float(a + b),
130 (Variable::Integer(i), Variable::Float(f))
131 | (Variable::Float(f), Variable::Integer(i)) => Variable::Float(i as f64 + f),
132 (Variable::Array(a), Variable::Array(b)) => {
133 Variable::Array(a.iter().chain(b.iter()).cloned().collect::<Vec<_>>().into())
134 }
135 (Variable::Array(a), b) => a.iter().cloned().chain([b]).collect::<Vec<_>>().into(),
136 (a, Variable::Array(b)) => [a]
137 .into_iter()
138 .chain(b.iter().cloned())
139 .collect::<Vec<_>>()
140 .into(),
141 (Variable::String(a), b) => {
142 if !a.is_empty() {
143 Variable::String(format!("{}{}", a, b).into())
144 } else {
145 b
146 }
147 }
148 (a, Variable::String(b)) => {
149 if !b.is_empty() {
150 Variable::String(format!("{}{}", a, b).into())
151 } else {
152 a
153 }
154 }
155 }
156 }
157
158 pub fn op_subtract(self, other: Variable) -> Variable {
159 match (self, other) {
160 (Variable::Integer(a), Variable::Integer(b)) => Variable::Integer(a.saturating_sub(b)),
161 (Variable::Float(a), Variable::Float(b)) => Variable::Float(a - b),
162 (Variable::Integer(a), Variable::Float(b)) => Variable::Float(a as f64 - b),
163 (Variable::Float(a), Variable::Integer(b)) => Variable::Float(a - b as f64),
164 (Variable::Array(a), b) | (b, Variable::Array(a)) => Variable::Array(
165 a.iter()
166 .filter(|v| *v != &b)
167 .cloned()
168 .collect::<Vec<_>>()
169 .into(),
170 ),
171 (a, b) => a.parse_number().op_subtract(b.parse_number()),
172 }
173 }
174
175 pub fn op_multiply(self, other: Variable) -> Variable {
176 match (self, other) {
177 (Variable::Integer(a), Variable::Integer(b)) => Variable::Integer(a.saturating_mul(b)),
178 (Variable::Float(a), Variable::Float(b)) => Variable::Float(a * b),
179 (Variable::Integer(i), Variable::Float(f))
180 | (Variable::Float(f), Variable::Integer(i)) => Variable::Float(i as f64 * f),
181 (a, b) => a.parse_number().op_multiply(b.parse_number()),
182 }
183 }
184
185 pub fn op_divide(self, other: Variable) -> Variable {
186 match (self, other) {
187 (Variable::Integer(a), Variable::Integer(b)) => {
188 Variable::Float(if b != 0 { a as f64 / b as f64 } else { 0.0 })
189 }
190 (Variable::Float(a), Variable::Float(b)) => {
191 Variable::Float(if b != 0.0 { a / b } else { 0.0 })
192 }
193 (Variable::Integer(a), Variable::Float(b)) => {
194 Variable::Float(if b != 0.0 { a as f64 / b } else { 0.0 })
195 }
196 (Variable::Float(a), Variable::Integer(b)) => {
197 Variable::Float(if b != 0 { a / b as f64 } else { 0.0 })
198 }
199 (a, b) => a.parse_number().op_divide(b.parse_number()),
200 }
201 }
202
203 pub fn op_and(self, other: Variable) -> Variable {
204 Variable::Integer(i64::from(self.to_bool() & other.to_bool()))
205 }
206
207 pub fn op_or(self, other: Variable) -> Variable {
208 Variable::Integer(i64::from(self.to_bool() | other.to_bool()))
209 }
210
211 pub fn op_xor(self, other: Variable) -> Variable {
212 Variable::Integer(i64::from(self.to_bool() ^ other.to_bool()))
213 }
214
215 pub fn op_eq(self, other: Variable) -> Variable {
216 Variable::Integer(i64::from(self == other))
217 }
218
219 pub fn op_ne(self, other: Variable) -> Variable {
220 Variable::Integer(i64::from(self != other))
221 }
222
223 pub fn op_lt(self, other: Variable) -> Variable {
224 Variable::Integer(i64::from(self < other))
225 }
226
227 pub fn op_le(self, other: Variable) -> Variable {
228 Variable::Integer(i64::from(self <= other))
229 }
230
231 pub fn op_gt(self, other: Variable) -> Variable {
232 Variable::Integer(i64::from(self > other))
233 }
234
235 pub fn op_ge(self, other: Variable) -> Variable {
236 Variable::Integer(i64::from(self >= other))
237 }
238
239 pub fn op_not(self) -> Variable {
240 Variable::Integer(i64::from(!self.to_bool()))
241 }
242
243 pub fn op_minus(self) -> Variable {
244 match self {
245 Variable::Integer(n) => Variable::Integer(-n),
246 Variable::Float(n) => Variable::Float(-n),
247 _ => self.parse_number().op_minus(),
248 }
249 }
250
251 pub fn parse_number(&self) -> Variable {
252 match self {
253 Variable::String(s) if !s.is_empty() => {
254 if let Ok(n) = s.parse::<i64>() {
255 Variable::Integer(n)
256 } else if let Ok(n) = s.parse::<f64>() {
257 Variable::Float(n)
258 } else {
259 Variable::Integer(0)
260 }
261 }
262 Variable::Integer(n) => Variable::Integer(*n),
263 Variable::Float(n) => Variable::Float(*n),
264 Variable::Array(l) => Variable::Integer(l.is_empty() as i64),
265 _ => Variable::Integer(0),
266 }
267 }
268
269 pub fn to_bool(&self) -> bool {
270 match self {
271 Variable::Float(f) => *f != 0.0,
272 Variable::Integer(n) => *n != 0,
273 Variable::String(s) => !s.is_empty(),
274 Variable::Array(a) => !a.is_empty(),
275 }
276 }
277}
278
279impl PartialEq for Variable {
280 fn eq(&self, other: &Self) -> bool {
281 match (self, other) {
282 (Self::Integer(a), Self::Integer(b)) => a == b,
283 (Self::Float(a), Self::Float(b)) => a == b,
284 (Self::Integer(a), Self::Float(b)) | (Self::Float(b), Self::Integer(a)) => {
285 *a as f64 == *b
286 }
287 (Self::String(a), Self::String(b)) => a == b,
288 (Self::String(_), Self::Integer(_) | Self::Float(_)) => &self.parse_number() == other,
289 (Self::Integer(_) | Self::Float(_), Self::String(_)) => self == &other.parse_number(),
290 (Self::Array(a), Self::Array(b)) => a == b,
291 _ => false,
292 }
293 }
294}
295
// NOTE(review): `Eq` is asserted manually even though `PartialEq` above is not a true
// equivalence relation. `Float(NaN) != Float(NaN)` breaks reflexivity. String/number
// comparisons go through `parse_number`, so `"1" == 1` and `1 == "1.0"` but `"1" != "1.0"`,
// which breaks transitivity. Keep this in mind before using `Variable` as a hash/set key.
impl Eq for Variable {}
297
298#[allow(clippy::non_canonical_partial_ord_impl)]
299impl PartialOrd for Variable {
300 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
301 match (self, other) {
302 (Self::Integer(a), Self::Integer(b)) => a.partial_cmp(b),
303 (Self::Float(a), Self::Float(b)) => a.partial_cmp(b),
304 (Self::Integer(a), Self::Float(b)) => (*a as f64).partial_cmp(b),
305 (Self::Float(a), Self::Integer(b)) => a.partial_cmp(&(*b as f64)),
306 (Self::String(a), Self::String(b)) => a.partial_cmp(b),
307 (Self::String(_), Self::Integer(_) | Self::Float(_)) => {
308 self.parse_number().partial_cmp(other)
309 }
310 (Self::Integer(_) | Self::Float(_), Self::String(_)) => {
311 self.partial_cmp(&other.parse_number())
312 }
313 (Self::Array(a), Self::Array(b)) => a.partial_cmp(b),
314 (Self::Array(_) | Self::String(_), _) => Ordering::Greater.into(),
315 (_, Self::Array(_)) => Ordering::Less.into(),
316 }
317 }
318}
319
320impl Ord for Variable {
321 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
322 self.partial_cmp(other).unwrap_or(Ordering::Greater)
323 }
324}
325
326impl Display for Variable {
327 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
328 match self {
329 Variable::String(v) => v.fmt(f),
330 Variable::Integer(v) => v.fmt(f),
331 Variable::Float(v) => v.fmt(f),
332 Variable::Array(v) => {
333 for (i, v) in v.iter().enumerate() {
334 if i > 0 {
335 f.write_str("\n")?;
336 }
337 v.fmt(f)?;
338 }
339 Ok(())
340 }
341 }
342 }
343}
344
345impl Number {
346 pub fn is_non_zero(&self) -> bool {
347 match self {
348 Number::Integer(n) => *n != 0,
349 Number::Float(n) => *n != 0.0,
350 }
351 }
352}
353
354impl Default for Number {
355 fn default() -> Self {
356 Number::Integer(0)
357 }
358}
359
/// Truthiness of a primitive number: any value that does not compare equal to zero is `true`.
/// So `NaN` is `true` and `-0.0` is `false`.
///
/// NOTE(review): no use of this trait is visible in this file — confirm it is still needed.
trait IntoBool {
    /// Consumes the number and returns its truthiness.
    fn into_bool(self) -> bool;
}

impl IntoBool for f64 {
    #[inline(always)]
    fn into_bool(self) -> bool {
        0.0 != self
    }
}

impl IntoBool for i64 {
    #[inline(always)]
    fn into_bool(self) -> bool {
        0 != self
    }
}
377
378impl From<bool> for Number {
379 #[inline(always)]
380 fn from(b: bool) -> Self {
381 Number::Integer(i64::from(b))
382 }
383}
384
385impl From<i64> for Number {
386 #[inline(always)]
387 fn from(n: i64) -> Self {
388 Number::Integer(n)
389 }
390}
391
392impl From<f64> for Number {
393 #[inline(always)]
394 fn from(n: f64) -> Self {
395 Number::Float(n)
396 }
397}
398
399impl From<i32> for Number {
400 #[inline(always)]
401 fn from(n: i32) -> Self {
402 Number::Integer(n as i64)
403 }
404}
405
406impl<'x> From<&'x Constant> for Variable {
407 fn from(value: &'x Constant) -> Self {
408 match value {
409 Constant::Integer(i) => Variable::Integer(*i),
410 Constant::Float(f) => Variable::Float(*f),
411 Constant::String(s) => Variable::String(s.clone()),
412 }
413 }
414}
415
#[cfg(test)]
mod test {
    use ahash::{HashMap, HashMapExt};

    use crate::{
        compiler::{
            grammar::expr::{
                parser::ExpressionParser, tokenizer::Tokenizer, BinaryOperator, Expression, Token,
                UnaryOperator,
            },
            VariableType,
        },
        runtime::Variable,
    };

    // `evalexpr` provides the independent reference results (`eval`, `Value`).
    use evalexpr::*;

    /// Test-only evaluator for compiled expression programs, used instead of a full runtime
    /// `Context`.
    pub trait EvalExpression {
        /// Evaluates the program against a map of global variables. Returns `None` if a
        /// referenced variable is missing or the operand stack underflows.
        fn eval(&self, variables: &HashMap<String, Variable>) -> Option<Variable>;
    }

    impl EvalExpression for Vec<Expression> {
        fn eval(&self, variables: &HashMap<String, Variable>) -> Option<Variable> {
            let mut stack = Vec::with_capacity(self.len());
            let mut exprs = self.iter();

            while let Some(expr) = exprs.next() {
                match expr {
                    Expression::Variable(VariableType::Global(v)) => {
                        stack.push(variables.get(v)?.clone());
                    }
                    Expression::Constant(val) => {
                        stack.push(Variable::from(val));
                    }
                    Expression::UnaryOperator(op) => {
                        let value = stack.pop()?;
                        stack.push(match op {
                            UnaryOperator::Not => value.op_not(),
                            UnaryOperator::Minus => value.op_minus(),
                        });
                    }
                    Expression::BinaryOperator(op) => {
                        // The right-hand operand is on top of the stack.
                        let right = stack.pop()?;
                        let left = stack.pop()?;
                        stack.push(match op {
                            BinaryOperator::Add => left.op_add(right),
                            BinaryOperator::Subtract => left.op_subtract(right),
                            BinaryOperator::Multiply => left.op_multiply(right),
                            BinaryOperator::Divide => left.op_divide(right),
                            BinaryOperator::And => left.op_and(right),
                            BinaryOperator::Or => left.op_or(right),
                            BinaryOperator::Xor => left.op_xor(right),
                            BinaryOperator::Eq => left.op_eq(right),
                            BinaryOperator::Ne => left.op_ne(right),
                            BinaryOperator::Lt => left.op_lt(right),
                            BinaryOperator::Le => left.op_le(right),
                            BinaryOperator::Gt => left.op_gt(right),
                            BinaryOperator::Ge => left.op_ge(right),
                        });
                    }
                    Expression::JmpIf { val, pos } => {
                        // Skip the next `pos` instructions when the (peeked, not popped) top
                        // of the stack has truthiness `val`.
                        if stack.last()?.to_bool() == *val {
                            for _ in 0..*pos {
                                exprs.next();
                            }
                        }
                    }
                    // Other expression kinds (non-global variables, function calls, array
                    // operations) are not supported by this test evaluator.
                    _ => unreachable!("Invalid expression"),
                }
            }
            stack.pop()
        }
    }

    /// Cross-checks expression evaluation against `evalexpr`.
    ///
    /// For each arithmetic/comparison expression, the variables `A`..`J` are first bound to
    /// the distinct values 1..=10 (in `HashMap` iteration order), then to the same values in
    /// reverse, and `assert_expr` is run for both bindings. The boolean expressions are then
    /// compiled with `true`/`false` rewritten to `1`/`0` and compared with `evalexpr`'s result
    /// on the original text.
    #[test]
    fn eval_expression() {
        let mut variables = HashMap::from_iter([
            ("A".to_string(), Variable::Integer(0)),
            ("B".to_string(), Variable::Integer(0)),
            ("C".to_string(), Variable::Integer(0)),
            ("D".to_string(), Variable::Integer(0)),
            ("E".to_string(), Variable::Integer(0)),
            ("F".to_string(), Variable::Integer(0)),
            ("G".to_string(), Variable::Integer(0)),
            ("H".to_string(), Variable::Integer(0)),
            ("I".to_string(), Variable::Integer(0)),
            ("J".to_string(), Variable::Integer(0)),
        ]);
        let num_vars = variables.len();

        for expr in [
            "A + B",
            "A * B",
            "A / B",
            "A - B",
            "-A",
            "A == B",
            "A != B",
            "A > B",
            "A < B",
            "A >= B",
            "A <= B",
            "A + B * C - D / E",
            "A + B + C - D - E",
            "(A + B) * (C - D) / E",
            "A - B + C * D / E * F - G",
            "A + B * C - D / E",
            "(A + B) * (C - D) / E",
            "A - B + C / D * E",
            "(A + B) / (C - D) + E",
            "A * (B + C) - D / E",
            "A / (B - C + D) * E",
            "(A + B) * C - D / (E + F)",
            "A * B - C + D / E",
            "A + B - C * D / E",
            "(A * B + C) / D - E",
            "A - B / C + D * E",
            "A + B * (C - D) / E",
            "A * B / C + (D - E)",
            "(A - B) * C / D + E",
            "A * (B / C) - D + E",
            "(A + B) / (C + D) * E",
            "A - B * C / D + E",
            "A + (B - C) * D / E",
            "(A + B) * (C / D) - E",
            "A - B / (C * D) + E",
            "(A + B) > (C - D) && E <= F",
            "A * B == C / D || E - F != G + H",
            "A / B >= C * D && E + F < G - H",
            "(A * B - C) != (D / E + F) && G > H",
            "A - B < C && D + E >= F * G",
            "(A * B) > C && (D / E) < F || G == H",
            "(A + B) <= (C - D) || E > F && G != H",
            "A * B != C + D || E - F == G / H",
            "A >= B * C && D < E - F || G != H + I",
            "(A / B + C) > D && E * F <= G - H",
            "A * (B - C) == D && E / F > G + H",
            "(A - B + C) != D || E * F >= G && H < I",
            "A < B / C && D + E * F == G - H",
            "(A + B * C) <= D && E > F / G",
            "(A * B - C) > D || E <= F + G && H != I",
            "A != B / C && D == E * F - G",
            "A <= B + C - D && E / F > G * H",
            "(A - B * C) < D || E >= F + G && H != I",
            "(A + B) / C == D && E - F < G * H",
            "A * B != C && D >= E + F / G || H < I",
            "!(A * B != C) && !(D >= E + F / G) || !(H < I)",
            "-A - B - (- C - D) - E - (-F)",
        ] {
            println!("Testing {}", expr);
            // First binding: values 1..=10.
            for (pos, v) in variables.values_mut().enumerate() {
                *v = Variable::Integer(pos as i64 + 1);
            }

            assert_expr(expr, &variables);

            // Second binding: the same variables with the values reversed (10..=1). The map
            // is not resized in between, so the iteration order is the same.
            for (pos, v) in variables.values_mut().enumerate() {
                *v = Variable::Integer((num_vars - pos) as i64);
            }

            assert_expr(expr, &variables);
        }

        for expr in [
            "true && false",
            "!true || false",
            "true && !false",
            "!(true && false)",
            "true || true && false",
            "!false && (true || false)",
            "!(true || !false) && true",
            "!(!true && !false)",
            "true || false && !true",
            "!(true && true) || !false",
            "!(!true || !false) && (!false) && !(!true)",
        ] {
            // Our grammar is fed numeric 1/0 in place of boolean literals.
            let pexp = parse_expression(expr.replace("true", "1").replace("false", "0").as_str());
            let result = pexp.eval(&HashMap::new()).unwrap();

            // evalexpr yields booleans or floats; our evaluator yields Integer(0|1) for
            // logical/comparison results, so compare across representations.
            match (eval(expr).expect(expr), result) {
                (Value::Float(a), Variable::Float(b)) if a == b => (),
                (Value::Float(a), Variable::Integer(b)) if a == b as f64 => (),
                (Value::Boolean(a), Variable::Integer(b)) if a == (b != 0) => (),
                (a, b) => {
                    panic!("{} => {:?} != {:?}", expr, a, b)
                }
            }
        }
    }

    /// Evaluates `expr` with `variables` and checks the result three ways:
    /// 1. re-parsing the expression with the variable values substituted as integer literals;
    /// 2. re-parsing it with the values substituted as float literals;
    /// 3. evaluating the float-literal text with `evalexpr`.
    fn assert_expr(expr: &str, variables: &HashMap<String, Variable>) {
        let e = parse_expression(expr);

        let result = e.eval(variables).unwrap();

        // Textual substitution of the single-letter names. The substituted values are plain
        // numbers, so replacing one name cannot introduce another name into the text.
        let mut str_expr = expr.to_string();
        let mut str_expr_float = expr.to_string();
        for (k, v) in variables {
            let v = v.to_string();

            if v.contains('.') {
                str_expr_float = str_expr_float.replace(k, &v);
            } else {
                str_expr_float = str_expr_float.replace(k, &format!("{}.0", v));
            }
            str_expr = str_expr.replace(k, &v);
        }

        assert_eq!(
            parse_expression(&str_expr)
                .eval(&HashMap::new())
                .unwrap()
                .to_number()
                .to_float(),
            result.to_number().to_float()
        );

        assert_eq!(
            parse_expression(&str_expr_float)
                .eval(&HashMap::new())
                .unwrap()
                .to_number()
                .to_float(),
            result.to_number().to_float()
        );

        match (
            // `op_divide` returns 0.0 on division by zero, whereas evalexpr produces an
            // infinity, so infinite reference results are normalised to 0.0 before comparing.
            eval(&str_expr_float)
                .map(|v| {
                    if matches!(&v, Value::Float(f) if f.is_infinite()) {
                        Value::Float(0.0)
                    } else {
                        v
                    }
                })
                .expect(&str_expr),
            result,
        ) {
            (Value::Float(a), Variable::Float(b)) if a == b => (),
            (Value::Float(a), Variable::Integer(b)) if a == b as f64 => (),
            (Value::Boolean(a), Variable::Integer(b)) if a == (b != 0) => (),
            (a, b) => {
                panic!("{} => {:?} != {:?}", str_expr, a, b)
            }
        }
    }

    /// Compiles `expr` into an expression program, mapping every identifier to a global
    /// variable token. Panics if the expression does not parse.
    fn parse_expression(expr: &str) -> Vec<Expression> {
        ExpressionParser::from_tokenizer(Tokenizer::new(expr, |var_name: &str, _: bool| {
            Ok::<_, String>(Token::Variable(VariableType::Global(var_name.to_string())))
        }))
        .parse()
        .unwrap()
        .output
    }
}