Skip to main content

rust_rule_engine/backward/
expression.rs

1//! Expression AST for backward chaining queries
2//!
3//! This module provides a robust Abstract Syntax Tree (AST) implementation for parsing
4//! and evaluating goal expressions in backward chaining queries. It uses a recursive
5//! descent parser to build a typed expression tree that can be evaluated against facts.
6//!
7//! # Features
8//!
9//! - **Field references** - Access fact values using dot notation (e.g., `User.Name`, `Order.Total`)
//! - **Literal values** - Support for boolean, numeric, string, and `null` literals
11//! - **Comparison operators** - `==`, `!=`, `>`, `<`, `>=`, `<=`
12//! - **Logical operators** - `&&` (AND), `||` (OR), `!` (NOT)
13//! - **Parenthesized expressions** - Group sub-expressions with parentheses
14//! - **Variable binding** - Placeholder variables for unification (e.g., `?x`, `?name`)
15//! - **Operator precedence** - Proper precedence: `||` < `&&` < comparisons
16//!
17//! # Expression Syntax
18//!
19//! ```text
20//! Expression Grammar:
21//!   expr        ::= or_expr
22//!   or_expr     ::= and_expr ('||' and_expr)*
23//!   and_expr    ::= comparison ('&&' comparison)*
24//!   comparison  ::= primary (COMP_OP primary)?
25//!   primary     ::= '!' primary
26//!                 | '(' expr ')'
27//!                 | field
28//!                 | literal
29//!                 | variable
30//!
31//!   COMP_OP     ::= '==' | '!=' | '>' | '<' | '>=' | '<='
32//!   field       ::= IDENT ('.' IDENT)*
//!   literal     ::= boolean | number | string | 'null'
34//!   variable    ::= '?' IDENT
35//! ```
36//!
37//! # Example
38//!
39//! ```rust
40//! use rust_rule_engine::backward::expression::{Expression, ExpressionParser};
41//! use rust_rule_engine::{Facts, Value};
42//!
43//! # fn example() -> Result<(), Box<dyn std::error::Error>> {
44//! // Parse a complex expression
45//! let expr = ExpressionParser::parse(
46//!     "User.IsVIP == true && Order.Amount > 1000"
47//! )?;
48//!
49//! // Create facts
//! let facts = Facts::new();
51//! facts.set("User.IsVIP", Value::Boolean(true));
52//! facts.set("Order.Amount", Value::Number(1500.0));
53//!
54//! // Evaluate expression
55//! let satisfied = expr.is_satisfied(&facts);
56//! assert!(satisfied);
57//!
58//! // Extract referenced fields
59//! let fields = expr.extract_fields();
60//! assert!(fields.contains(&"User.IsVIP".to_string()));
61//! assert!(fields.contains(&"Order.Amount".to_string()));
62//! # Ok(())
63//! # }
64//! ```
65//!
66//! # Supported Operators
67//!
//! ## Comparison Operators (equal precedence, non-associative: one operator per comparison)
69//! - `==` - Equal to
70//! - `!=` - Not equal to
71//! - `>` - Greater than
72//! - `<` - Less than
73//! - `>=` - Greater than or equal to
74//! - `<=` - Less than or equal to
75//!
76//! ## Logical Operators (precedence: low to high)
77//! - `||` - Logical OR (lowest precedence)
78//! - `&&` - Logical AND (medium precedence)
79//! - `!` - Logical NOT (highest precedence)
80//!
81//! # Error Handling
82//!
83//! The parser returns descriptive errors for invalid syntax:
84//! - Unclosed parentheses
85//! - Unterminated strings
86//! - Invalid operators
87//! - Unexpected tokens
88
89use crate::errors::{Result, RuleEngineError};
90use crate::types::{Operator, Value};
91use crate::Facts;
92
93/// Expression AST node
94#[derive(Debug, Clone, PartialEq)]
95pub enum Expression {
96    /// Field reference (e.g., "User.IsVIP", "Order.Amount")
97    Field(String),
98
99    /// Literal value (e.g., true, false, 42, "hello")
100    Literal(Value),
101
102    /// Binary comparison (e.g., "X == Y", "A > B")
103    Comparison {
104        left: Box<Expression>,
105        operator: Operator,
106        right: Box<Expression>,
107    },
108
109    /// Logical AND operation (e.g., "A && B")
110    And {
111        left: Box<Expression>,
112        right: Box<Expression>,
113    },
114
115    /// Logical OR operation (e.g., "A || B")
116    Or {
117        left: Box<Expression>,
118        right: Box<Expression>,
119    },
120
121    /// Negation (e.g., "!X")
122    Not(Box<Expression>),
123
124    /// Variable (for future unification support, e.g., "?X", "?Customer")
125    Variable(String),
126}
127
128impl Expression {
129    /// Evaluate expression against facts
130    pub fn evaluate(&self, facts: &Facts) -> Result<Value> {
131        match self {
132            Expression::Field(name) => facts
133                .get(name)
134                .or_else(|| facts.get_nested(name))
135                .ok_or_else(|| {
136                    RuleEngineError::ExecutionError(format!("Field not found: {}", name))
137                }),
138
139            Expression::Literal(value) => Ok(value.clone()),
140
141            Expression::Comparison {
142                left,
143                operator,
144                right,
145            } => {
146                // Special handling for NotEqual when field doesn't exist
147                // If field doesn't exist, treat as Null
148                let left_val = left.evaluate(facts).unwrap_or(Value::Null);
149                let right_val = right.evaluate(facts).unwrap_or(Value::Null);
150
151                let result = operator.evaluate(&left_val, &right_val);
152                Ok(Value::Boolean(result))
153            }
154
155            Expression::And { left, right } => {
156                let left_val = left.evaluate(facts)?;
157                if !left_val.to_bool() {
158                    return Ok(Value::Boolean(false));
159                }
160                let right_val = right.evaluate(facts)?;
161                Ok(Value::Boolean(right_val.to_bool()))
162            }
163
164            Expression::Or { left, right } => {
165                let left_val = left.evaluate(facts)?;
166                if left_val.to_bool() {
167                    return Ok(Value::Boolean(true));
168                }
169                let right_val = right.evaluate(facts)?;
170                Ok(Value::Boolean(right_val.to_bool()))
171            }
172
173            Expression::Not(expr) => {
174                let value = expr.evaluate(facts)?;
175                Ok(Value::Boolean(!value.to_bool()))
176            }
177
178            Expression::Variable(var) => Err(RuleEngineError::ExecutionError(format!(
179                "Cannot evaluate unbound variable: {}",
180                var
181            ))),
182        }
183    }
184
185    /// Check if expression is satisfied (returns true/false)
186    pub fn is_satisfied(&self, facts: &Facts) -> bool {
187        self.evaluate(facts).map(|v| v.to_bool()).unwrap_or(false)
188    }
189
190    /// Extract all field references from expression
191    pub fn extract_fields(&self) -> Vec<String> {
192        let mut fields = Vec::new();
193        self.extract_fields_recursive(&mut fields);
194        fields
195    }
196
197    fn extract_fields_recursive(&self, fields: &mut Vec<String>) {
198        match self {
199            Expression::Field(name) if !fields.contains(name) => {
200                fields.push(name.clone());
201            }
202            Expression::Comparison { left, right, .. } => {
203                left.extract_fields_recursive(fields);
204                right.extract_fields_recursive(fields);
205            }
206            Expression::And { left, right } | Expression::Or { left, right } => {
207                left.extract_fields_recursive(fields);
208                right.extract_fields_recursive(fields);
209            }
210            Expression::Not(expr) => {
211                expr.extract_fields_recursive(fields);
212            }
213            _ => {}
214        }
215    }
216
217    /// Convert to human-readable string
218    #[allow(clippy::inherent_to_string)]
219    pub fn to_string(&self) -> String {
220        match self {
221            Expression::Field(name) => name.clone(),
222            Expression::Literal(val) => format!("{:?}", val),
223            Expression::Comparison {
224                left,
225                operator,
226                right,
227            } => {
228                format!("{} {:?} {}", left.to_string(), operator, right.to_string())
229            }
230            Expression::And { left, right } => {
231                format!("({} && {})", left.to_string(), right.to_string())
232            }
233            Expression::Or { left, right } => {
234                format!("({} || {})", left.to_string(), right.to_string())
235            }
236            Expression::Not(expr) => {
237                format!("!{}", expr.to_string())
238            }
239            Expression::Variable(var) => var.clone(),
240        }
241    }
242}
243
/// Recursive descent parser that turns expression text into an [`Expression`] tree.
pub struct ExpressionParser {
    /// The text being parsed, split into `char`s so it can be indexed by position.
    input: Vec<char>,
    /// Index into `input` of the next character to be read.
    position: usize,
}
249
250impl ExpressionParser {
251    /// Create a new parser
252    pub fn new(input: &str) -> Self {
253        Self {
254            input: input.chars().collect(),
255            position: 0,
256        }
257    }
258
259    /// Parse expression from string
260    pub fn parse(input: &str) -> Result<Expression> {
261        let mut parser = Self::new(input.trim());
262        parser.parse_expression()
263    }
264
265    /// Parse full expression (handles ||)
266    fn parse_expression(&mut self) -> Result<Expression> {
267        let mut left = self.parse_and_expression()?;
268
269        while self.peek_operator("||") {
270            self.consume_operator("||");
271            let right = self.parse_and_expression()?;
272            left = Expression::Or {
273                left: Box::new(left),
274                right: Box::new(right),
275            };
276        }
277
278        Ok(left)
279    }
280
281    /// Parse AND expression (handles &&)
282    fn parse_and_expression(&mut self) -> Result<Expression> {
283        let mut left = self.parse_comparison()?;
284
285        while self.peek_operator("&&") {
286            self.consume_operator("&&");
287            let right = self.parse_comparison()?;
288            left = Expression::And {
289                left: Box::new(left),
290                right: Box::new(right),
291            };
292        }
293
294        Ok(left)
295    }
296
297    /// Parse comparison (e.g., "X == Y", "A > 5")
298    fn parse_comparison(&mut self) -> Result<Expression> {
299        let left = self.parse_primary()?;
300
301        // Check for comparison operators (check longer operators first)
302        let operator = if self.peek_operator("==") {
303            self.consume_operator("==");
304            Operator::Equal
305        } else if self.peek_operator("!=") {
306            self.consume_operator("!=");
307            Operator::NotEqual
308        } else if self.peek_operator(">=") {
309            self.consume_operator(">=");
310            Operator::GreaterThanOrEqual
311        } else if self.peek_operator("<=") {
312            self.consume_operator("<=");
313            Operator::LessThanOrEqual
314        } else if self.peek_operator(">") {
315            self.consume_operator(">");
316            Operator::GreaterThan
317        } else if self.peek_operator("<") {
318            self.consume_operator("<");
319            Operator::LessThan
320        } else {
321            // No comparison operator - return just the left side
322            return Ok(left);
323        };
324
325        let right = self.parse_primary()?;
326
327        Ok(Expression::Comparison {
328            left: Box::new(left),
329            operator,
330            right: Box::new(right),
331        })
332    }
333
334    /// Parse primary expression (field, literal, variable, or parenthesized)
335    fn parse_primary(&mut self) -> Result<Expression> {
336        self.skip_whitespace();
337
338        // Handle negation
339        if self.peek_char() == Some('!') {
340            self.consume_char();
341            let expr = self.parse_primary()?;
342            return Ok(Expression::Not(Box::new(expr)));
343        }
344
345        // Handle parentheses
346        if self.peek_char() == Some('(') {
347            self.consume_char();
348            let expr = self.parse_expression()?;
349            self.skip_whitespace();
350            if self.peek_char() != Some(')') {
351                return Err(RuleEngineError::ParseError {
352                    message: format!("Expected closing parenthesis at position {}", self.position),
353                });
354            }
355            self.consume_char();
356            return Ok(expr);
357        }
358
359        // Handle variables (?X, ?Customer)
360        if self.peek_char() == Some('?') {
361            self.consume_char();
362            let name = self.consume_identifier()?;
363            return Ok(Expression::Variable(format!("?{}", name)));
364        }
365
366        // Try to parse literal
367        if let Some(value) = self.try_parse_literal()? {
368            return Ok(Expression::Literal(value));
369        }
370
371        // Handle field reference
372        let field_name = self.consume_field_path()?;
373        Ok(Expression::Field(field_name))
374    }
375
376    fn consume_field_path(&mut self) -> Result<String> {
377        let mut path = String::new();
378
379        while let Some(ch) = self.peek_char() {
380            if ch.is_alphanumeric() || ch == '_' || ch == '.' {
381                path.push(ch);
382                self.consume_char();
383            } else {
384                break;
385            }
386        }
387
388        if path.is_empty() {
389            return Err(RuleEngineError::ParseError {
390                message: format!("Expected field name at position {}", self.position),
391            });
392        }
393
394        Ok(path)
395    }
396
397    fn consume_identifier(&mut self) -> Result<String> {
398        let mut ident = String::new();
399
400        while let Some(ch) = self.peek_char() {
401            if ch.is_alphanumeric() || ch == '_' {
402                ident.push(ch);
403                self.consume_char();
404            } else {
405                break;
406            }
407        }
408
409        if ident.is_empty() {
410            return Err(RuleEngineError::ParseError {
411                message: format!("Expected identifier at position {}", self.position),
412            });
413        }
414
415        Ok(ident)
416    }
417
418    fn try_parse_literal(&mut self) -> Result<Option<Value>> {
419        self.skip_whitespace();
420
421        // Boolean literals
422        if self.peek_word("true") {
423            self.consume_word("true");
424            return Ok(Some(Value::Boolean(true)));
425        }
426        if self.peek_word("false") {
427            self.consume_word("false");
428            return Ok(Some(Value::Boolean(false)));
429        }
430
431        // Null literal
432        if self.peek_word("null") {
433            self.consume_word("null");
434            return Ok(Some(Value::Null));
435        }
436
437        // String literals
438        if self.peek_char() == Some('"') {
439            self.consume_char();
440            let mut s = String::new();
441            let mut escaped = false;
442
443            while let Some(ch) = self.peek_char() {
444                if escaped {
445                    // Handle escape sequences
446                    let escaped_char = match ch {
447                        'n' => '\n',
448                        't' => '\t',
449                        'r' => '\r',
450                        '\\' => '\\',
451                        '"' => '"',
452                        _ => ch,
453                    };
454                    s.push(escaped_char);
455                    escaped = false;
456                    self.consume_char();
457                } else if ch == '\\' {
458                    escaped = true;
459                    self.consume_char();
460                } else if ch == '"' {
461                    self.consume_char();
462                    return Ok(Some(Value::String(s)));
463                } else {
464                    s.push(ch);
465                    self.consume_char();
466                }
467            }
468
469            return Err(RuleEngineError::ParseError {
470                message: format!("Unterminated string at position {}", self.position),
471            });
472        }
473
474        // Number literals
475        if let Some(ch) = self.peek_char() {
476            if ch.is_numeric() || ch == '-' {
477                let start_pos = self.position;
478                let mut num_str = String::new();
479                let mut has_dot = false;
480
481                while let Some(ch) = self.peek_char() {
482                    if ch.is_numeric() {
483                        num_str.push(ch);
484                        self.consume_char();
485                    } else if ch == '.' && !has_dot {
486                        has_dot = true;
487                        num_str.push(ch);
488                        self.consume_char();
489                    } else if ch == '-' && num_str.is_empty() {
490                        num_str.push(ch);
491                        self.consume_char();
492                    } else {
493                        break;
494                    }
495                }
496
497                if !num_str.is_empty() && num_str != "-" {
498                    if has_dot {
499                        if let Ok(n) = num_str.parse::<f64>() {
500                            return Ok(Some(Value::Number(n)));
501                        }
502                    } else if let Ok(i) = num_str.parse::<i64>() {
503                        return Ok(Some(Value::Number(i as f64)));
504                    }
505                }
506
507                // Failed to parse - reset position
508                self.position = start_pos;
509            }
510        }
511
512        Ok(None)
513    }
514
515    fn peek_char(&self) -> Option<char> {
516        if self.position < self.input.len() {
517            Some(self.input[self.position])
518        } else {
519            None
520        }
521    }
522
523    fn consume_char(&mut self) {
524        if self.position < self.input.len() {
525            self.position += 1;
526        }
527    }
528
529    fn peek_operator(&mut self, op: &str) -> bool {
530        self.skip_whitespace();
531        let remaining: String = self.input[self.position..].iter().collect();
532        remaining.starts_with(op)
533    }
534
535    fn consume_operator(&mut self, op: &str) {
536        self.skip_whitespace();
537        for _ in 0..op.len() {
538            self.consume_char();
539        }
540    }
541
542    fn peek_word(&mut self, word: &str) -> bool {
543        self.skip_whitespace();
544        let remaining: String = self.input[self.position..].iter().collect();
545
546        if remaining.starts_with(word) {
547            // Make sure it's a complete word (not prefix)
548            let next_pos = self.position + word.len();
549            if next_pos >= self.input.len() {
550                return true;
551            }
552            let next_char = self.input[next_pos];
553            !next_char.is_alphanumeric() && next_char != '_'
554        } else {
555            false
556        }
557    }
558
559    fn consume_word(&mut self, word: &str) {
560        self.skip_whitespace();
561        if self.peek_word(word) {
562            for _ in 0..word.len() {
563                self.consume_char();
564            }
565        }
566    }
567
568    fn skip_whitespace(&mut self) {
569        while let Some(ch) = self.peek_char() {
570            if ch.is_whitespace() {
571                self.consume_char();
572            } else {
573                break;
574            }
575        }
576    }
577}
578
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `input`, panicking if the parser rejects it.
    fn parse_ok(input: &str) -> Expression {
        ExpressionParser::parse(input).unwrap()
    }

    /// Build a `Facts` store pre-populated with `entries`.
    fn facts_with(entries: Vec<(&str, Value)>) -> Facts {
        let facts = Facts::new();
        for (key, value) in entries {
            facts.set(key, value);
        }
        facts
    }

    /// Parse `input` and evaluate it against `facts`, panicking on any error.
    fn eval_ok(input: &str, facts: &Facts) -> Value {
        parse_ok(input).evaluate(facts).unwrap()
    }

    // --- Parsing: primaries ---

    #[test]
    fn test_parse_simple_field() {
        if let Expression::Field(name) = parse_ok("User.IsVIP") {
            assert_eq!(name, "User.IsVIP");
        } else {
            panic!("Expected field expression");
        }
    }

    #[test]
    fn test_parse_literal_boolean() {
        let expr = parse_ok("true");
        assert!(
            matches!(expr, Expression::Literal(Value::Boolean(true))),
            "Expected boolean literal"
        );
    }

    #[test]
    fn test_parse_literal_number() {
        if let Expression::Literal(Value::Number(n)) = parse_ok("42.5") {
            assert!((n - 42.5).abs() < 0.001);
        } else {
            panic!("Expected number literal");
        }
    }

    #[test]
    fn test_parse_literal_string() {
        if let Expression::Literal(Value::String(s)) = parse_ok(r#""hello world""#) {
            assert_eq!(s, "hello world");
        } else {
            panic!("Expected string literal");
        }
    }

    // --- Parsing: operators ---

    #[test]
    fn test_parse_simple_comparison() {
        if let Expression::Comparison { operator, .. } = parse_ok("User.IsVIP == true") {
            assert_eq!(operator, Operator::Equal);
        } else {
            panic!("Expected comparison");
        }
    }

    #[test]
    fn test_parse_all_comparison_operators() {
        let cases = vec![
            ("a == b", Operator::Equal),
            ("a != b", Operator::NotEqual),
            ("a > b", Operator::GreaterThan),
            ("a >= b", Operator::GreaterThanOrEqual),
            ("a < b", Operator::LessThan),
            ("a <= b", Operator::LessThanOrEqual),
        ];

        for (input, expected_op) in cases {
            if let Expression::Comparison { operator, .. } = parse_ok(input) {
                assert_eq!(operator, expected_op, "Failed for: {}", input);
            } else {
                panic!("Expected comparison for: {}", input);
            }
        }
    }

    #[test]
    fn test_parse_logical_and() {
        let expr = parse_ok("User.IsVIP == true && Order.Amount > 1000");
        assert!(
            matches!(expr, Expression::And { .. }),
            "Expected logical AND, got: {:?}",
            expr
        );
    }

    #[test]
    fn test_parse_logical_or() {
        let expr = parse_ok("a == true || b == true");
        assert!(matches!(expr, Expression::Or { .. }), "Expected logical OR");
    }

    #[test]
    fn test_parse_negation() {
        let expr = parse_ok("!User.IsBanned");
        assert!(matches!(expr, Expression::Not(_)), "Expected negation");
    }

    #[test]
    fn test_parse_parentheses() {
        // Parentheses must override the default `&&`-over-`||` precedence.
        match parse_ok("(a == true || b == true) && c == true") {
            Expression::And { left, .. } => {
                assert!(
                    matches!(*left, Expression::Or { .. }),
                    "Expected OR inside AND"
                );
            }
            _ => panic!("Expected AND"),
        }
    }

    #[test]
    fn test_parse_variable() {
        let left = match parse_ok("?X == true") {
            Expression::Comparison { left, .. } => left,
            _ => panic!("Expected comparison"),
        };

        if let Expression::Variable(var) = *left {
            assert_eq!(var, "?X");
        } else {
            panic!("Expected variable");
        }
    }

    // --- Evaluation ---

    #[test]
    fn test_evaluate_simple() {
        let facts = facts_with(vec![("User.IsVIP", Value::Boolean(true))]);

        assert_eq!(eval_ok("User.IsVIP == true", &facts), Value::Boolean(true));
    }

    #[test]
    fn test_evaluate_comparison() {
        let facts = facts_with(vec![("Order.Amount", Value::Number(1500.0))]);

        assert_eq!(eval_ok("Order.Amount > 1000", &facts), Value::Boolean(true));
    }

    #[test]
    fn test_evaluate_logical_and() {
        let facts = facts_with(vec![
            ("User.IsVIP", Value::Boolean(true)),
            ("Order.Amount", Value::Number(1500.0)),
        ]);

        assert_eq!(
            eval_ok("User.IsVIP == true && Order.Amount > 1000", &facts),
            Value::Boolean(true)
        );
    }

    #[test]
    fn test_evaluate_logical_or() {
        let facts = facts_with(vec![
            ("a", Value::Boolean(false)),
            ("b", Value::Boolean(true)),
        ]);

        assert_eq!(
            eval_ok("a == true || b == true", &facts),
            Value::Boolean(true)
        );
    }

    #[test]
    fn test_is_satisfied() {
        let facts = facts_with(vec![("User.IsVIP", Value::Boolean(true))]);

        assert!(parse_ok("User.IsVIP == true").is_satisfied(&facts));
    }

    // --- Introspection ---

    #[test]
    fn test_extract_fields() {
        let fields = parse_ok("User.IsVIP == true && Order.Amount > 1000").extract_fields();

        assert_eq!(fields.len(), 2);
        for expected in vec!["User.IsVIP", "Order.Amount"] {
            assert!(fields.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_to_string() {
        let rendered = parse_ok("User.IsVIP == true").to_string();
        assert!(rendered.contains("User.IsVIP"));
        assert!(rendered.contains("true"));
    }

    // --- Error reporting ---

    #[test]
    fn test_parse_error_unclosed_parenthesis() {
        assert!(ExpressionParser::parse("(a == true").is_err());
    }

    #[test]
    fn test_parse_error_unterminated_string() {
        assert!(ExpressionParser::parse(r#""hello"#).is_err());
    }

    // --- Larger inputs ---

    #[test]
    fn test_parse_complex_expression() {
        let expr = parse_ok(
            "(User.IsVIP == true && Order.Amount > 1000) || (User.Points >= 100 && Order.Discount < 0.5)",
        );

        // Only the top-level shape is checked; `||` must end up at the root.
        assert!(
            matches!(expr, Expression::Or { .. }),
            "Expected OR at top level"
        );
    }
}