rust_rule_engine/backward/
expression.rs1use crate::errors::{Result, RuleEngineError};
90use crate::types::{Operator, Value};
91use crate::Facts;
92
93#[derive(Debug, Clone, PartialEq)]
95pub enum Expression {
96 Field(String),
98
99 Literal(Value),
101
102 Comparison {
104 left: Box<Expression>,
105 operator: Operator,
106 right: Box<Expression>,
107 },
108
109 And {
111 left: Box<Expression>,
112 right: Box<Expression>,
113 },
114
115 Or {
117 left: Box<Expression>,
118 right: Box<Expression>,
119 },
120
121 Not(Box<Expression>),
123
124 Variable(String),
126}
127
128impl Expression {
129 pub fn evaluate(&self, facts: &Facts) -> Result<Value> {
131 match self {
132 Expression::Field(name) => facts
133 .get(name)
134 .or_else(|| facts.get_nested(name))
135 .ok_or_else(|| {
136 RuleEngineError::ExecutionError(format!("Field not found: {}", name))
137 }),
138
139 Expression::Literal(value) => Ok(value.clone()),
140
141 Expression::Comparison {
142 left,
143 operator,
144 right,
145 } => {
146 let left_val = left.evaluate(facts).unwrap_or(Value::Null);
149 let right_val = right.evaluate(facts).unwrap_or(Value::Null);
150
151 let result = operator.evaluate(&left_val, &right_val);
152 Ok(Value::Boolean(result))
153 }
154
155 Expression::And { left, right } => {
156 let left_val = left.evaluate(facts)?;
157 if !left_val.to_bool() {
158 return Ok(Value::Boolean(false));
159 }
160 let right_val = right.evaluate(facts)?;
161 Ok(Value::Boolean(right_val.to_bool()))
162 }
163
164 Expression::Or { left, right } => {
165 let left_val = left.evaluate(facts)?;
166 if left_val.to_bool() {
167 return Ok(Value::Boolean(true));
168 }
169 let right_val = right.evaluate(facts)?;
170 Ok(Value::Boolean(right_val.to_bool()))
171 }
172
173 Expression::Not(expr) => {
174 let value = expr.evaluate(facts)?;
175 Ok(Value::Boolean(!value.to_bool()))
176 }
177
178 Expression::Variable(var) => Err(RuleEngineError::ExecutionError(format!(
179 "Cannot evaluate unbound variable: {}",
180 var
181 ))),
182 }
183 }
184
185 pub fn is_satisfied(&self, facts: &Facts) -> bool {
187 self.evaluate(facts).map(|v| v.to_bool()).unwrap_or(false)
188 }
189
190 pub fn extract_fields(&self) -> Vec<String> {
192 let mut fields = Vec::new();
193 self.extract_fields_recursive(&mut fields);
194 fields
195 }
196
197 fn extract_fields_recursive(&self, fields: &mut Vec<String>) {
198 match self {
199 Expression::Field(name) => {
200 if !fields.contains(name) {
201 fields.push(name.clone());
202 }
203 }
204 Expression::Comparison { left, right, .. } => {
205 left.extract_fields_recursive(fields);
206 right.extract_fields_recursive(fields);
207 }
208 Expression::And { left, right } | Expression::Or { left, right } => {
209 left.extract_fields_recursive(fields);
210 right.extract_fields_recursive(fields);
211 }
212 Expression::Not(expr) => {
213 expr.extract_fields_recursive(fields);
214 }
215 _ => {}
216 }
217 }
218
219 #[allow(clippy::inherent_to_string)]
221 pub fn to_string(&self) -> String {
222 match self {
223 Expression::Field(name) => name.clone(),
224 Expression::Literal(val) => format!("{:?}", val),
225 Expression::Comparison {
226 left,
227 operator,
228 right,
229 } => {
230 format!("{} {:?} {}", left.to_string(), operator, right.to_string())
231 }
232 Expression::And { left, right } => {
233 format!("({} && {})", left.to_string(), right.to_string())
234 }
235 Expression::Or { left, right } => {
236 format!("({} || {})", left.to_string(), right.to_string())
237 }
238 Expression::Not(expr) => {
239 format!("!{}", expr.to_string())
240 }
241 Expression::Variable(var) => var.clone(),
242 }
243 }
244}
245
/// Recursive-descent parser that turns expression text into an `Expression`.
pub struct ExpressionParser {
    /// The input text split into characters so it can be indexed by position.
    input: Vec<char>,
    /// Index into `input` of the next character to read.
    position: usize,
}
251
252impl ExpressionParser {
253 pub fn new(input: &str) -> Self {
255 Self {
256 input: input.chars().collect(),
257 position: 0,
258 }
259 }
260
261 pub fn parse(input: &str) -> Result<Expression> {
263 let mut parser = Self::new(input.trim());
264 parser.parse_expression()
265 }
266
267 fn parse_expression(&mut self) -> Result<Expression> {
269 let mut left = self.parse_and_expression()?;
270
271 while self.peek_operator("||") {
272 self.consume_operator("||");
273 let right = self.parse_and_expression()?;
274 left = Expression::Or {
275 left: Box::new(left),
276 right: Box::new(right),
277 };
278 }
279
280 Ok(left)
281 }
282
283 fn parse_and_expression(&mut self) -> Result<Expression> {
285 let mut left = self.parse_comparison()?;
286
287 while self.peek_operator("&&") {
288 self.consume_operator("&&");
289 let right = self.parse_comparison()?;
290 left = Expression::And {
291 left: Box::new(left),
292 right: Box::new(right),
293 };
294 }
295
296 Ok(left)
297 }
298
299 fn parse_comparison(&mut self) -> Result<Expression> {
301 let left = self.parse_primary()?;
302
303 let operator = if self.peek_operator("==") {
305 self.consume_operator("==");
306 Operator::Equal
307 } else if self.peek_operator("!=") {
308 self.consume_operator("!=");
309 Operator::NotEqual
310 } else if self.peek_operator(">=") {
311 self.consume_operator(">=");
312 Operator::GreaterThanOrEqual
313 } else if self.peek_operator("<=") {
314 self.consume_operator("<=");
315 Operator::LessThanOrEqual
316 } else if self.peek_operator(">") {
317 self.consume_operator(">");
318 Operator::GreaterThan
319 } else if self.peek_operator("<") {
320 self.consume_operator("<");
321 Operator::LessThan
322 } else {
323 return Ok(left);
325 };
326
327 let right = self.parse_primary()?;
328
329 Ok(Expression::Comparison {
330 left: Box::new(left),
331 operator,
332 right: Box::new(right),
333 })
334 }
335
336 fn parse_primary(&mut self) -> Result<Expression> {
338 self.skip_whitespace();
339
340 if self.peek_char() == Some('!') {
342 self.consume_char();
343 let expr = self.parse_primary()?;
344 return Ok(Expression::Not(Box::new(expr)));
345 }
346
347 if self.peek_char() == Some('(') {
349 self.consume_char();
350 let expr = self.parse_expression()?;
351 self.skip_whitespace();
352 if self.peek_char() != Some(')') {
353 return Err(RuleEngineError::ParseError {
354 message: format!("Expected closing parenthesis at position {}", self.position),
355 });
356 }
357 self.consume_char();
358 return Ok(expr);
359 }
360
361 if self.peek_char() == Some('?') {
363 self.consume_char();
364 let name = self.consume_identifier()?;
365 return Ok(Expression::Variable(format!("?{}", name)));
366 }
367
368 if let Some(value) = self.try_parse_literal()? {
370 return Ok(Expression::Literal(value));
371 }
372
373 let field_name = self.consume_field_path()?;
375 Ok(Expression::Field(field_name))
376 }
377
378 fn consume_field_path(&mut self) -> Result<String> {
379 let mut path = String::new();
380
381 while let Some(ch) = self.peek_char() {
382 if ch.is_alphanumeric() || ch == '_' || ch == '.' {
383 path.push(ch);
384 self.consume_char();
385 } else {
386 break;
387 }
388 }
389
390 if path.is_empty() {
391 return Err(RuleEngineError::ParseError {
392 message: format!("Expected field name at position {}", self.position),
393 });
394 }
395
396 Ok(path)
397 }
398
399 fn consume_identifier(&mut self) -> Result<String> {
400 let mut ident = String::new();
401
402 while let Some(ch) = self.peek_char() {
403 if ch.is_alphanumeric() || ch == '_' {
404 ident.push(ch);
405 self.consume_char();
406 } else {
407 break;
408 }
409 }
410
411 if ident.is_empty() {
412 return Err(RuleEngineError::ParseError {
413 message: format!("Expected identifier at position {}", self.position),
414 });
415 }
416
417 Ok(ident)
418 }
419
420 fn try_parse_literal(&mut self) -> Result<Option<Value>> {
421 self.skip_whitespace();
422
423 if self.peek_word("true") {
425 self.consume_word("true");
426 return Ok(Some(Value::Boolean(true)));
427 }
428 if self.peek_word("false") {
429 self.consume_word("false");
430 return Ok(Some(Value::Boolean(false)));
431 }
432
433 if self.peek_word("null") {
435 self.consume_word("null");
436 return Ok(Some(Value::Null));
437 }
438
439 if self.peek_char() == Some('"') {
441 self.consume_char();
442 let mut s = String::new();
443 let mut escaped = false;
444
445 while let Some(ch) = self.peek_char() {
446 if escaped {
447 let escaped_char = match ch {
449 'n' => '\n',
450 't' => '\t',
451 'r' => '\r',
452 '\\' => '\\',
453 '"' => '"',
454 _ => ch,
455 };
456 s.push(escaped_char);
457 escaped = false;
458 self.consume_char();
459 } else if ch == '\\' {
460 escaped = true;
461 self.consume_char();
462 } else if ch == '"' {
463 self.consume_char();
464 return Ok(Some(Value::String(s)));
465 } else {
466 s.push(ch);
467 self.consume_char();
468 }
469 }
470
471 return Err(RuleEngineError::ParseError {
472 message: format!("Unterminated string at position {}", self.position),
473 });
474 }
475
476 if let Some(ch) = self.peek_char() {
478 if ch.is_numeric() || ch == '-' {
479 let start_pos = self.position;
480 let mut num_str = String::new();
481 let mut has_dot = false;
482
483 while let Some(ch) = self.peek_char() {
484 if ch.is_numeric() {
485 num_str.push(ch);
486 self.consume_char();
487 } else if ch == '.' && !has_dot {
488 has_dot = true;
489 num_str.push(ch);
490 self.consume_char();
491 } else if ch == '-' && num_str.is_empty() {
492 num_str.push(ch);
493 self.consume_char();
494 } else {
495 break;
496 }
497 }
498
499 if !num_str.is_empty() && num_str != "-" {
500 if has_dot {
501 if let Ok(n) = num_str.parse::<f64>() {
502 return Ok(Some(Value::Number(n)));
503 }
504 } else if let Ok(i) = num_str.parse::<i64>() {
505 return Ok(Some(Value::Number(i as f64)));
506 }
507 }
508
509 self.position = start_pos;
511 }
512 }
513
514 Ok(None)
515 }
516
517 fn peek_char(&self) -> Option<char> {
518 if self.position < self.input.len() {
519 Some(self.input[self.position])
520 } else {
521 None
522 }
523 }
524
525 fn consume_char(&mut self) {
526 if self.position < self.input.len() {
527 self.position += 1;
528 }
529 }
530
531 fn peek_operator(&mut self, op: &str) -> bool {
532 self.skip_whitespace();
533 let remaining: String = self.input[self.position..].iter().collect();
534 remaining.starts_with(op)
535 }
536
537 fn consume_operator(&mut self, op: &str) {
538 self.skip_whitespace();
539 for _ in 0..op.len() {
540 self.consume_char();
541 }
542 }
543
544 fn peek_word(&mut self, word: &str) -> bool {
545 self.skip_whitespace();
546 let remaining: String = self.input[self.position..].iter().collect();
547
548 if remaining.starts_with(word) {
549 let next_pos = self.position + word.len();
551 if next_pos >= self.input.len() {
552 return true;
553 }
554 let next_char = self.input[next_pos];
555 !next_char.is_alphanumeric() && next_char != '_'
556 } else {
557 false
558 }
559 }
560
561 fn consume_word(&mut self, word: &str) {
562 self.skip_whitespace();
563 if self.peek_word(word) {
564 for _ in 0..word.len() {
565 self.consume_char();
566 }
567 }
568 }
569
570 fn skip_whitespace(&mut self) {
571 while let Some(ch) = self.peek_char() {
572 if ch.is_whitespace() {
573 self.consume_char();
574 } else {
575 break;
576 }
577 }
578 }
579}
580
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, panicking if the parser rejects it.
    fn parsed(input: &str) -> Expression {
        ExpressionParser::parse(input).unwrap()
    }

    // ---- parsing: primaries -------------------------------------------------

    #[test]
    fn test_parse_simple_field() {
        match parsed("User.IsVIP") {
            Expression::Field(name) => assert_eq!(name, "User.IsVIP"),
            _ => panic!("Expected field expression"),
        }
    }

    #[test]
    fn test_parse_literal_boolean() {
        let expr = parsed("true");
        assert!(
            matches!(expr, Expression::Literal(Value::Boolean(true))),
            "Expected boolean literal"
        );
    }

    #[test]
    fn test_parse_literal_number() {
        match parsed("42.5") {
            Expression::Literal(Value::Number(n)) => assert!((n - 42.5).abs() < 0.001),
            _ => panic!("Expected number literal"),
        }
    }

    #[test]
    fn test_parse_literal_string() {
        match parsed(r#""hello world""#) {
            Expression::Literal(Value::String(s)) => assert_eq!(s, "hello world"),
            _ => panic!("Expected string literal"),
        }
    }

    // ---- parsing: operators -------------------------------------------------

    #[test]
    fn test_parse_simple_comparison() {
        match parsed("User.IsVIP == true") {
            Expression::Comparison { operator, .. } => assert_eq!(operator, Operator::Equal),
            _ => panic!("Expected comparison"),
        }
    }

    #[test]
    fn test_parse_all_comparison_operators() {
        /// Checks that `input` parses to a comparison using `expected_op`.
        fn assert_operator(input: &str, expected_op: Operator) {
            match parsed(input) {
                Expression::Comparison { operator, .. } => {
                    assert_eq!(operator, expected_op, "Failed for: {}", input)
                }
                _ => panic!("Expected comparison for: {}", input),
            }
        }

        assert_operator("a == b", Operator::Equal);
        assert_operator("a != b", Operator::NotEqual);
        assert_operator("a > b", Operator::GreaterThan);
        assert_operator("a >= b", Operator::GreaterThanOrEqual);
        assert_operator("a < b", Operator::LessThan);
        assert_operator("a <= b", Operator::LessThanOrEqual);
    }

    #[test]
    fn test_parse_logical_and() {
        let expr = parsed("User.IsVIP == true && Order.Amount > 1000");
        assert!(
            matches!(expr, Expression::And { .. }),
            "Expected logical AND, got: {:?}",
            expr
        );
    }

    #[test]
    fn test_parse_logical_or() {
        let expr = parsed("a == true || b == true");
        assert!(matches!(expr, Expression::Or { .. }), "Expected logical OR");
    }

    #[test]
    fn test_parse_negation() {
        let expr = parsed("!User.IsBanned");
        assert!(matches!(expr, Expression::Not(_)), "Expected negation");
    }

    #[test]
    fn test_parse_parentheses() {
        // Parentheses must override the default `&&`-before-`||` precedence.
        match parsed("(a == true || b == true) && c == true") {
            Expression::And { left, .. } => {
                assert!(
                    matches!(*left, Expression::Or { .. }),
                    "Expected OR inside AND"
                );
            }
            _ => panic!("Expected AND"),
        }
    }

    #[test]
    fn test_parse_variable() {
        match parsed("?X == true") {
            Expression::Comparison { left, .. } => match *left {
                Expression::Variable(var) => assert_eq!(var, "?X"),
                _ => panic!("Expected variable"),
            },
            _ => panic!("Expected comparison"),
        }
    }

    // ---- evaluation ---------------------------------------------------------

    #[test]
    fn test_evaluate_simple() {
        let facts = Facts::new();
        facts.set("User.IsVIP", Value::Boolean(true));

        let outcome = parsed("User.IsVIP == true").evaluate(&facts).unwrap();
        assert_eq!(outcome, Value::Boolean(true));
    }

    #[test]
    fn test_evaluate_comparison() {
        let facts = Facts::new();
        facts.set("Order.Amount", Value::Number(1500.0));

        let outcome = parsed("Order.Amount > 1000").evaluate(&facts).unwrap();
        assert_eq!(outcome, Value::Boolean(true));
    }

    #[test]
    fn test_evaluate_logical_and() {
        let facts = Facts::new();
        facts.set("User.IsVIP", Value::Boolean(true));
        facts.set("Order.Amount", Value::Number(1500.0));

        let outcome = parsed("User.IsVIP == true && Order.Amount > 1000")
            .evaluate(&facts)
            .unwrap();
        assert_eq!(outcome, Value::Boolean(true));
    }

    #[test]
    fn test_evaluate_logical_or() {
        let facts = Facts::new();
        facts.set("a", Value::Boolean(false));
        facts.set("b", Value::Boolean(true));

        let outcome = parsed("a == true || b == true").evaluate(&facts).unwrap();
        assert_eq!(outcome, Value::Boolean(true));
    }

    #[test]
    fn test_is_satisfied() {
        let facts = Facts::new();
        facts.set("User.IsVIP", Value::Boolean(true));

        assert!(parsed("User.IsVIP == true").is_satisfied(&facts));
    }

    // ---- helpers on Expression ----------------------------------------------

    #[test]
    fn test_extract_fields() {
        let fields = parsed("User.IsVIP == true && Order.Amount > 1000").extract_fields();

        assert_eq!(fields.len(), 2);
        assert!(fields.iter().any(|field| field == "User.IsVIP"));
        assert!(fields.iter().any(|field| field == "Order.Amount"));
    }

    #[test]
    fn test_to_string() {
        let rendered = parsed("User.IsVIP == true").to_string();
        assert!(rendered.contains("User.IsVIP"));
        assert!(rendered.contains("true"));
    }

    // ---- error handling and larger inputs -----------------------------------

    #[test]
    fn test_parse_error_unclosed_parenthesis() {
        assert!(ExpressionParser::parse("(a == true").is_err());
    }

    #[test]
    fn test_parse_error_unterminated_string() {
        assert!(ExpressionParser::parse(r#""hello"#).is_err());
    }

    #[test]
    fn test_parse_complex_expression() {
        let expr = parsed(
            "(User.IsVIP == true && Order.Amount > 1000) || (User.Points >= 100 && Order.Discount < 0.5)",
        );

        assert!(
            matches!(expr, Expression::Or { .. }),
            "Expected OR at top level"
        );
    }
}