1use std::{cmp::Ordering, fmt::Display};
8
9use crate::compiler::grammar::expr::parser::ID_EXTERNAL;
10use crate::Event;
11use crate::{compiler::Number, runtime::Variable, Context};
12
13use crate::compiler::grammar::expr::{BinaryOperator, Constant, Expression, UnaryOperator};
14
15impl Context<'_> {
16 pub(crate) fn eval_expression(&mut self, expr: &[Expression]) -> Result<Variable, Event> {
17 let mut exprs = expr.iter().skip(self.expr_pos);
18 while let Some(expr) = exprs.next() {
19 self.expr_pos += 1;
20 match expr {
21 Expression::Variable(v) => {
22 self.expr_stack.push(self.variable(v).unwrap_or_default());
23 }
24 Expression::Constant(val) => {
25 self.expr_stack.push(Variable::from(val));
26 }
27 Expression::UnaryOperator(op) => {
28 let value = self.expr_stack.pop().unwrap_or_default();
29 self.expr_stack.push(match op {
30 UnaryOperator::Not => value.op_not(),
31 UnaryOperator::Minus => value.op_minus(),
32 });
33 }
34 Expression::BinaryOperator(op) => {
35 let right = self.expr_stack.pop().unwrap_or_default();
36 let left = self.expr_stack.pop().unwrap_or_default();
37 self.expr_stack.push(match op {
38 BinaryOperator::Add => left.op_add(right),
39 BinaryOperator::Subtract => left.op_subtract(right),
40 BinaryOperator::Multiply => left.op_multiply(right),
41 BinaryOperator::Divide => left.op_divide(right),
42 BinaryOperator::And => left.op_and(right),
43 BinaryOperator::Or => left.op_or(right),
44 BinaryOperator::Xor => left.op_xor(right),
45 BinaryOperator::Eq => left.op_eq(right),
46 BinaryOperator::Ne => left.op_ne(right),
47 BinaryOperator::Lt => left.op_lt(right),
48 BinaryOperator::Le => left.op_le(right),
49 BinaryOperator::Gt => left.op_gt(right),
50 BinaryOperator::Ge => left.op_ge(right),
51 });
52 }
53 Expression::Function { id, num_args } => {
54 let num_args = *num_args as usize;
55
56 if let Some(fnc) = self.runtime.functions.get(*id as usize) {
57 let mut arguments = vec![Variable::Integer(0); num_args];
58 for arg_num in 0..num_args {
59 arguments[num_args - arg_num - 1] =
60 self.expr_stack.pop().unwrap_or_default();
61 }
62 self.expr_stack.push((fnc)(self, arguments));
63 } else {
64 let mut arguments = vec![Variable::Integer(0); num_args];
65 for arg_num in 0..num_args {
66 arguments[num_args - arg_num - 1] =
67 self.expr_stack.pop().unwrap_or_default();
68 }
69 self.pos -= 1; return Err(Event::Function {
71 id: ID_EXTERNAL - *id,
72 arguments,
73 });
74 }
75 }
76 Expression::JmpIf { val, pos } => {
77 if self.expr_stack.last().is_some_and(|v| v.to_bool()) == *val {
78 self.expr_pos += *pos as usize;
79 for _ in 0..*pos {
80 exprs.next();
81 }
82 }
83 }
84 Expression::ArrayAccess => {
85 let index = self.expr_stack.pop().unwrap_or_default().to_usize();
86 let array = self.expr_stack.pop().unwrap_or_default().into_array();
87 self.expr_stack
88 .push(array.get(index).cloned().unwrap_or_default());
89 }
90 Expression::ArrayBuild(num_items) => {
91 let num_items = *num_items as usize;
92 let mut items = vec![Variable::Integer(0); num_items];
93 for arg_num in 0..num_items {
94 items[num_items - arg_num - 1] = self.expr_stack.pop().unwrap_or_default();
95 }
96 self.expr_stack.push(Variable::Array(items.into()));
97 }
98 }
99 }
100
101 let result = self.expr_stack.pop().unwrap_or_default();
102 self.expr_stack.clear();
103 self.expr_pos = 0;
104 Ok(result)
105 }
106}
107
108impl Variable {
109 pub fn op_add(self, other: Variable) -> Variable {
110 match (self, other) {
111 (Variable::Integer(a), Variable::Integer(b)) => Variable::Integer(a.saturating_add(b)),
112 (Variable::Float(a), Variable::Float(b)) => Variable::Float(a + b),
113 (Variable::Integer(i), Variable::Float(f))
114 | (Variable::Float(f), Variable::Integer(i)) => Variable::Float(i as f64 + f),
115 (Variable::Array(a), Variable::Array(b)) => {
116 Variable::Array(a.iter().chain(b.iter()).cloned().collect::<Vec<_>>().into())
117 }
118 (Variable::Array(a), b) => a.iter().cloned().chain([b]).collect::<Vec<_>>().into(),
119 (a, Variable::Array(b)) => [a]
120 .into_iter()
121 .chain(b.iter().cloned())
122 .collect::<Vec<_>>()
123 .into(),
124 (Variable::String(a), b) => {
125 if !a.is_empty() {
126 Variable::String(format!("{}{}", a, b).into())
127 } else {
128 b
129 }
130 }
131 (a, Variable::String(b)) => {
132 if !b.is_empty() {
133 Variable::String(format!("{}{}", a, b).into())
134 } else {
135 a
136 }
137 }
138 }
139 }
140
141 pub fn op_subtract(self, other: Variable) -> Variable {
142 match (self, other) {
143 (Variable::Integer(a), Variable::Integer(b)) => Variable::Integer(a.saturating_sub(b)),
144 (Variable::Float(a), Variable::Float(b)) => Variable::Float(a - b),
145 (Variable::Integer(a), Variable::Float(b)) => Variable::Float(a as f64 - b),
146 (Variable::Float(a), Variable::Integer(b)) => Variable::Float(a - b as f64),
147 (Variable::Array(a), b) | (b, Variable::Array(a)) => Variable::Array(
148 a.iter()
149 .filter(|v| *v != &b)
150 .cloned()
151 .collect::<Vec<_>>()
152 .into(),
153 ),
154 (a, b) => a.parse_number().op_subtract(b.parse_number()),
155 }
156 }
157
158 pub fn op_multiply(self, other: Variable) -> Variable {
159 match (self, other) {
160 (Variable::Integer(a), Variable::Integer(b)) => Variable::Integer(a.saturating_mul(b)),
161 (Variable::Float(a), Variable::Float(b)) => Variable::Float(a * b),
162 (Variable::Integer(i), Variable::Float(f))
163 | (Variable::Float(f), Variable::Integer(i)) => Variable::Float(i as f64 * f),
164 (a, b) => a.parse_number().op_multiply(b.parse_number()),
165 }
166 }
167
168 pub fn op_divide(self, other: Variable) -> Variable {
169 match (self, other) {
170 (Variable::Integer(a), Variable::Integer(b)) => {
171 Variable::Float(if b != 0 { a as f64 / b as f64 } else { 0.0 })
172 }
173 (Variable::Float(a), Variable::Float(b)) => {
174 Variable::Float(if b != 0.0 { a / b } else { 0.0 })
175 }
176 (Variable::Integer(a), Variable::Float(b)) => {
177 Variable::Float(if b != 0.0 { a as f64 / b } else { 0.0 })
178 }
179 (Variable::Float(a), Variable::Integer(b)) => {
180 Variable::Float(if b != 0 { a / b as f64 } else { 0.0 })
181 }
182 (a, b) => a.parse_number().op_divide(b.parse_number()),
183 }
184 }
185
186 pub fn op_and(self, other: Variable) -> Variable {
187 Variable::Integer(i64::from(self.to_bool() & other.to_bool()))
188 }
189
190 pub fn op_or(self, other: Variable) -> Variable {
191 Variable::Integer(i64::from(self.to_bool() | other.to_bool()))
192 }
193
194 pub fn op_xor(self, other: Variable) -> Variable {
195 Variable::Integer(i64::from(self.to_bool() ^ other.to_bool()))
196 }
197
198 pub fn op_eq(self, other: Variable) -> Variable {
199 Variable::Integer(i64::from(self == other))
200 }
201
202 pub fn op_ne(self, other: Variable) -> Variable {
203 Variable::Integer(i64::from(self != other))
204 }
205
206 pub fn op_lt(self, other: Variable) -> Variable {
207 Variable::Integer(i64::from(self < other))
208 }
209
210 pub fn op_le(self, other: Variable) -> Variable {
211 Variable::Integer(i64::from(self <= other))
212 }
213
214 pub fn op_gt(self, other: Variable) -> Variable {
215 Variable::Integer(i64::from(self > other))
216 }
217
218 pub fn op_ge(self, other: Variable) -> Variable {
219 Variable::Integer(i64::from(self >= other))
220 }
221
222 pub fn op_not(self) -> Variable {
223 Variable::Integer(i64::from(!self.to_bool()))
224 }
225
226 pub fn op_minus(self) -> Variable {
227 match self {
228 Variable::Integer(n) => Variable::Integer(-n),
229 Variable::Float(n) => Variable::Float(-n),
230 _ => self.parse_number().op_minus(),
231 }
232 }
233
234 pub fn parse_number(&self) -> Variable {
235 match self {
236 Variable::String(s) if !s.is_empty() => {
237 if let Ok(n) = s.parse::<i64>() {
238 Variable::Integer(n)
239 } else if let Ok(n) = s.parse::<f64>() {
240 Variable::Float(n)
241 } else {
242 Variable::Integer(0)
243 }
244 }
245 Variable::Integer(n) => Variable::Integer(*n),
246 Variable::Float(n) => Variable::Float(*n),
247 Variable::Array(l) => Variable::Integer(l.is_empty() as i64),
248 _ => Variable::Integer(0),
249 }
250 }
251
252 pub fn to_bool(&self) -> bool {
253 match self {
254 Variable::Float(f) => *f != 0.0,
255 Variable::Integer(n) => *n != 0,
256 Variable::String(s) => !s.is_empty(),
257 Variable::Array(a) => !a.is_empty(),
258 }
259 }
260}
261
262impl PartialEq for Variable {
263 fn eq(&self, other: &Self) -> bool {
264 match (self, other) {
265 (Self::Integer(a), Self::Integer(b)) => a == b,
266 (Self::Float(a), Self::Float(b)) => a == b,
267 (Self::Integer(a), Self::Float(b)) | (Self::Float(b), Self::Integer(a)) => {
268 *a as f64 == *b
269 }
270 (Self::String(a), Self::String(b)) => a == b,
271 (Self::String(_), Self::Integer(_) | Self::Float(_)) => &self.parse_number() == other,
272 (Self::Integer(_) | Self::Float(_), Self::String(_)) => self == &other.parse_number(),
273 (Self::Array(a), Self::Array(b)) => a == b,
274 _ => false,
275 }
276 }
277}
278
279impl Eq for Variable {}
280
281#[allow(clippy::non_canonical_partial_ord_impl)]
282impl PartialOrd for Variable {
283 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
284 match (self, other) {
285 (Self::Integer(a), Self::Integer(b)) => a.partial_cmp(b),
286 (Self::Float(a), Self::Float(b)) => a.partial_cmp(b),
287 (Self::Integer(a), Self::Float(b)) => (*a as f64).partial_cmp(b),
288 (Self::Float(a), Self::Integer(b)) => a.partial_cmp(&(*b as f64)),
289 (Self::String(a), Self::String(b)) => a.partial_cmp(b),
290 (Self::String(_), Self::Integer(_) | Self::Float(_)) => {
291 self.parse_number().partial_cmp(other)
292 }
293 (Self::Integer(_) | Self::Float(_), Self::String(_)) => {
294 self.partial_cmp(&other.parse_number())
295 }
296 (Self::Array(a), Self::Array(b)) => a.partial_cmp(b),
297 (Self::Array(_) | Self::String(_), _) => Ordering::Greater.into(),
298 (_, Self::Array(_)) => Ordering::Less.into(),
299 }
300 }
301}
302
303impl Ord for Variable {
304 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
305 self.partial_cmp(other).unwrap_or(Ordering::Greater)
306 }
307}
308
309impl Display for Variable {
310 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
311 match self {
312 Variable::String(v) => v.fmt(f),
313 Variable::Integer(v) => v.fmt(f),
314 Variable::Float(v) => v.fmt(f),
315 Variable::Array(v) => {
316 for (i, v) in v.iter().enumerate() {
317 if i > 0 {
318 f.write_str("\n")?;
319 }
320 v.fmt(f)?;
321 }
322 Ok(())
323 }
324 }
325 }
326}
327
328impl Number {
329 pub fn is_non_zero(&self) -> bool {
330 match self {
331 Number::Integer(n) => *n != 0,
332 Number::Float(n) => *n != 0.0,
333 }
334 }
335}
336
337impl Default for Number {
338 fn default() -> Self {
339 Number::Integer(0)
340 }
341}
342
343impl From<bool> for Number {
344 #[inline(always)]
345 fn from(b: bool) -> Self {
346 Number::Integer(i64::from(b))
347 }
348}
349
350impl From<i64> for Number {
351 #[inline(always)]
352 fn from(n: i64) -> Self {
353 Number::Integer(n)
354 }
355}
356
357impl From<f64> for Number {
358 #[inline(always)]
359 fn from(n: f64) -> Self {
360 Number::Float(n)
361 }
362}
363
364impl From<i32> for Number {
365 #[inline(always)]
366 fn from(n: i32) -> Self {
367 Number::Integer(n as i64)
368 }
369}
370
371impl<'x> From<&'x Constant> for Variable {
372 fn from(value: &'x Constant) -> Self {
373 match value {
374 Constant::Integer(i) => Variable::Integer(*i),
375 Constant::Float(f) => Variable::Float(*f),
376 Constant::String(s) => Variable::String(s.clone()),
377 }
378 }
379}
380
#[cfg(test)]
mod test {
    use ahash::{HashMap, HashMapExt};

    use crate::{
        compiler::{
            grammar::expr::{
                parser::ExpressionParser, tokenizer::Tokenizer, BinaryOperator, Expression, Token,
                UnaryOperator,
            },
            VariableType,
        },
        runtime::Variable,
    };

    // `evalexpr` serves as an independent reference implementation to cross-check results.
    use evalexpr::*;

    /// Test-only evaluator for a parsed expression.
    pub trait EvalExpression {
        /// Evaluates the expression against `variables`. Returns `None` if a variable is
        /// missing or an operand cannot be popped from the stack.
        fn eval(&self, variables: &HashMap<String, Variable>) -> Option<Variable>;
    }

    /// Stack-machine evaluator that mirrors `Context::eval_expression`.
    ///
    /// Unlike the runtime, it returns `None` on stack underflow or an unknown variable
    /// instead of substituting default values. Only global variables, constants, operators
    /// and `JmpIf` are handled.
    impl EvalExpression for Vec<Expression> {
        fn eval(&self, variables: &HashMap<String, Variable>) -> Option<Variable> {
            let mut stack = Vec::with_capacity(self.len());
            let mut exprs = self.iter();

            // `while let` rather than `for`: `JmpIf` advances `exprs` inside the loop body.
            while let Some(expr) = exprs.next() {
                match expr {
                    Expression::Variable(VariableType::Global(v)) => {
                        stack.push(variables.get(v)?.clone());
                    }
                    Expression::Constant(val) => {
                        stack.push(Variable::from(val));
                    }
                    Expression::UnaryOperator(op) => {
                        let value = stack.pop()?;
                        stack.push(match op {
                            UnaryOperator::Not => value.op_not(),
                            UnaryOperator::Minus => value.op_minus(),
                        });
                    }
                    Expression::BinaryOperator(op) => {
                        // Right operand is on top of the stack.
                        let right = stack.pop()?;
                        let left = stack.pop()?;
                        stack.push(match op {
                            BinaryOperator::Add => left.op_add(right),
                            BinaryOperator::Subtract => left.op_subtract(right),
                            BinaryOperator::Multiply => left.op_multiply(right),
                            BinaryOperator::Divide => left.op_divide(right),
                            BinaryOperator::And => left.op_and(right),
                            BinaryOperator::Or => left.op_or(right),
                            BinaryOperator::Xor => left.op_xor(right),
                            BinaryOperator::Eq => left.op_eq(right),
                            BinaryOperator::Ne => left.op_ne(right),
                            BinaryOperator::Lt => left.op_lt(right),
                            BinaryOperator::Le => left.op_le(right),
                            BinaryOperator::Gt => left.op_gt(right),
                            BinaryOperator::Ge => left.op_ge(right),
                        });
                    }
                    Expression::JmpIf { val, pos } => {
                        // Skip the next `pos` entries when the top of the stack has the
                        // requested truthiness. The value itself stays on the stack.
                        if stack.last()?.to_bool() == *val {
                            for _ in 0..*pos {
                                exprs.next();
                            }
                        }
                    }
                    // The test expressions below contain no function calls or array syntax,
                    // so no other entries are expected here.
                    _ => unreachable!("Invalid expression"),
                }
            }
            stack.pop()
        }
    }

    /// Cross-checks arithmetic, comparison and boolean expressions against `evalexpr`.
    #[test]
    fn eval_expression() {
        let mut variables = HashMap::from_iter([
            ("A".to_string(), Variable::Integer(0)),
            ("B".to_string(), Variable::Integer(0)),
            ("C".to_string(), Variable::Integer(0)),
            ("D".to_string(), Variable::Integer(0)),
            ("E".to_string(), Variable::Integer(0)),
            ("F".to_string(), Variable::Integer(0)),
            ("G".to_string(), Variable::Integer(0)),
            ("H".to_string(), Variable::Integer(0)),
            ("I".to_string(), Variable::Integer(0)),
            ("J".to_string(), Variable::Integer(0)),
        ]);
        let num_vars = variables.len();

        // NOTE(review): "A + B * C - D / E" and "(A + B) * (C - D) / E" each appear twice in
        // this list.
        for expr in [
            "A + B",
            "A * B",
            "A / B",
            "A - B",
            "-A",
            "A == B",
            "A != B",
            "A > B",
            "A < B",
            "A >= B",
            "A <= B",
            "A + B * C - D / E",
            "A + B + C - D - E",
            "(A + B) * (C - D) / E",
            "A - B + C * D / E * F - G",
            "A + B * C - D / E",
            "(A + B) * (C - D) / E",
            "A - B + C / D * E",
            "(A + B) / (C - D) + E",
            "A * (B + C) - D / E",
            "A / (B - C + D) * E",
            "(A + B) * C - D / (E + F)",
            "A * B - C + D / E",
            "A + B - C * D / E",
            "(A * B + C) / D - E",
            "A - B / C + D * E",
            "A + B * (C - D) / E",
            "A * B / C + (D - E)",
            "(A - B) * C / D + E",
            "A * (B / C) - D + E",
            "(A + B) / (C + D) * E",
            "A - B * C / D + E",
            "A + (B - C) * D / E",
            "(A + B) * (C / D) - E",
            "A - B / (C * D) + E",
            "(A + B) > (C - D) && E <= F",
            "A * B == C / D || E - F != G + H",
            "A / B >= C * D && E + F < G - H",
            "(A * B - C) != (D / E + F) && G > H",
            "A - B < C && D + E >= F * G",
            "(A * B) > C && (D / E) < F || G == H",
            "(A + B) <= (C - D) || E > F && G != H",
            "A * B != C + D || E - F == G / H",
            "A >= B * C && D < E - F || G != H + I",
            "(A / B + C) > D && E * F <= G - H",
            "A * (B - C) == D && E / F > G + H",
            "(A - B + C) != D || E * F >= G && H < I",
            "A < B / C && D + E * F == G - H",
            "(A + B * C) <= D && E > F / G",
            "(A * B - C) > D || E <= F + G && H != I",
            "A != B / C && D == E * F - G",
            "A <= B + C - D && E / F > G * H",
            "(A - B * C) < D || E >= F + G && H != I",
            "(A + B) / C == D && E - F < G * H",
            "A * B != C && D >= E + F / G || H < I",
            "!(A * B != C) && !(D >= E + F / G) || !(H < I)",
            "-A - B - (- C - D) - E - (-F)",
        ] {
            println!("Testing {}", expr);
            // First pass: assign 1..=N to the variables, in HashMap iteration order.
            for (pos, v) in variables.values_mut().enumerate() {
                *v = Variable::Integer(pos as i64 + 1);
            }

            assert_expr(expr, &variables);

            // Second pass: assign N..=1 in the same iteration order.
            for (pos, v) in variables.values_mut().enumerate() {
                *v = Variable::Integer((num_vars - pos) as i64);
            }

            assert_expr(expr, &variables);
        }

        // Boolean expressions. `true`/`false` are rewritten to `1`/`0` for our parser, while
        // `evalexpr` evaluates the original text.
        for expr in [
            "true && false",
            "!true || false",
            "true && !false",
            "!(true && false)",
            "true || true && false",
            "!false && (true || false)",
            "!(true || !false) && true",
            "!(!true && !false)",
            "true || false && !true",
            "!(true && true) || !false",
            "!(!true || !false) && (!false) && !(!true)",
        ] {
            let pexp = parse_expression(expr.replace("true", "1").replace("false", "0").as_str());
            let result = pexp.eval(&HashMap::new()).unwrap();

            match (eval(expr).expect(expr), result) {
                (Value::Float(a), Variable::Float(b)) if a == b => (),
                (Value::Float(a), Variable::Integer(b)) if a == b as f64 => (),
                (Value::Boolean(a), Variable::Integer(b)) if a == (b != 0) => (),
                (a, b) => {
                    panic!("{} => {:?} != {:?}", expr, a, b)
                }
            }
        }
    }

    /// Evaluates `expr` against `variables`, then checks that the same result comes out of:
    ///
    /// * our evaluator, with the variable names textually replaced by their values;
    /// * our evaluator, with the values written as floats (`N.0`);
    /// * `evalexpr`, on the float-substituted text.
    ///
    /// Substitution is a plain `str::replace` per variable name. It relies on the names
    /// never occurring inside a substituted value, which holds here because the values are
    /// digits.
    fn assert_expr(expr: &str, variables: &HashMap<String, Variable>) {
        let e = parse_expression(expr);

        let result = e.eval(variables).unwrap();

        let mut str_expr = expr.to_string();
        let mut str_expr_float = expr.to_string();
        for (k, v) in variables {
            let v = v.to_string();

            if v.contains('.') {
                str_expr_float = str_expr_float.replace(k, &v);
            } else {
                str_expr_float = str_expr_float.replace(k, &format!("{}.0", v));
            }
            str_expr = str_expr.replace(k, &v);
        }

        assert_eq!(
            parse_expression(&str_expr)
                .eval(&HashMap::new())
                .unwrap()
                .to_number()
                .to_float(),
            result.to_number().to_float()
        );

        assert_eq!(
            parse_expression(&str_expr_float)
                .eval(&HashMap::new())
                .unwrap()
                .to_number()
                .to_float(),
            result.to_number().to_float()
        );

        match (
            // `op_divide` returns 0.0 on division by zero, whereas `evalexpr` yields an
            // infinity, so infinities are normalised to 0.0 before comparing.
            eval(&str_expr_float)
                .map(|v| {
                    if matches!(&v, Value::Float(f) if f.is_infinite()) {
                        Value::Float(0.0)
                    } else {
                        v
                    }
                })
                .expect(&str_expr),
            result,
        ) {
            (Value::Float(a), Variable::Float(b)) if a == b => (),
            (Value::Float(a), Variable::Integer(b)) if a == b as f64 => (),
            (Value::Boolean(a), Variable::Integer(b)) if a == (b != 0) => (),
            (a, b) => {
                panic!("{} => {:?} != {:?}", str_expr, a, b)
            }
        }
    }

    /// Parses `expr` into postfix form, treating every identifier as a global variable.
    /// Panics if parsing fails.
    fn parse_expression(expr: &str) -> Vec<Expression> {
        ExpressionParser::from_tokenizer(Tokenizer::new(expr, |var_name: &str, _: bool| {
            Ok::<_, String>(Token::Variable(VariableType::Global(var_name.to_string())))
        }))
        .parse()
        .unwrap()
        .output
    }
}