rust_rule_engine/backward/
expression.rs

1//! Expression AST for backward chaining queries
2//!
3//! This module provides a robust Abstract Syntax Tree (AST) implementation for parsing
4//! and evaluating goal expressions in backward chaining queries. It uses a recursive
5//! descent parser to build a typed expression tree that can be evaluated against facts.
6//!
7//! # Features
8//!
9//! - **Field references** - Access fact values using dot notation (e.g., `User.Name`, `Order.Total`)
//! - **Literal values** - Support for boolean, numeric, string, and `null` literals
11//! - **Comparison operators** - `==`, `!=`, `>`, `<`, `>=`, `<=`
12//! - **Logical operators** - `&&` (AND), `||` (OR), `!` (NOT)
13//! - **Parenthesized expressions** - Group sub-expressions with parentheses
14//! - **Variable binding** - Placeholder variables for unification (e.g., `?x`, `?name`)
//! - **Operator precedence** - From lowest to highest: `||` < `&&` < comparisons < `!`
16//!
17//! # Expression Syntax
18//!
19//! ```text
20//! Expression Grammar:
21//!   expr        ::= or_expr
22//!   or_expr     ::= and_expr ('||' and_expr)*
23//!   and_expr    ::= comparison ('&&' comparison)*
24//!   comparison  ::= primary (COMP_OP primary)?
25//!   primary     ::= '!' primary
26//!                 | '(' expr ')'
27//!                 | field
28//!                 | literal
29//!                 | variable
30//!
31//!   COMP_OP     ::= '==' | '!=' | '>' | '<' | '>=' | '<='
32//!   field       ::= IDENT ('.' IDENT)*
//!   literal     ::= boolean | number | string | 'null'
34//!   variable    ::= '?' IDENT
35//! ```
36//!
37//! # Example
38//!
39//! ```rust
40//! use rust_rule_engine::backward::expression::{Expression, ExpressionParser};
41//! use rust_rule_engine::{Facts, Value};
42//!
43//! # fn example() -> Result<(), Box<dyn std::error::Error>> {
44//! // Parse a complex expression
45//! let expr = ExpressionParser::parse(
46//!     "User.IsVIP == true && Order.Amount > 1000"
47//! )?;
48//!
49//! // Create facts
50//! let mut facts = Facts::new();
51//! facts.set("User.IsVIP", Value::Boolean(true));
52//! facts.set("Order.Amount", Value::Number(1500.0));
53//!
54//! // Evaluate expression
55//! let satisfied = expr.is_satisfied(&facts);
56//! assert!(satisfied);
57//!
58//! // Extract referenced fields
59//! let fields = expr.extract_fields();
60//! assert!(fields.contains(&"User.IsVIP".to_string()));
61//! assert!(fields.contains(&"Order.Amount".to_string()));
62//! # Ok(())
63//! # }
64//! ```
65//!
66//! # Supported Operators
67//!
//! ## Comparison Operators (non-associative: at most one per comparison; parenthesize to chain)
69//! - `==` - Equal to
70//! - `!=` - Not equal to
71//! - `>` - Greater than
72//! - `<` - Less than
73//! - `>=` - Greater than or equal to
74//! - `<=` - Less than or equal to
75//!
76//! ## Logical Operators (precedence: low to high)
77//! - `||` - Logical OR (lowest precedence)
78//! - `&&` - Logical AND (medium precedence)
79//! - `!` - Logical NOT (highest precedence)
80//!
81//! # Error Handling
82//!
83//! The parser returns descriptive errors for invalid syntax:
84//! - Unclosed parentheses
85//! - Unterminated strings
86//! - Invalid operators
87//! - Unexpected tokens
88
89use crate::errors::{Result, RuleEngineError};
90use crate::types::{Operator, Value};
91use crate::Facts;
92
/// Expression AST node
///
/// Produced by [`ExpressionParser::parse`] and evaluated against [`Facts`]
/// with [`Expression::evaluate`] / [`Expression::is_satisfied`].
#[derive(Debug, Clone, PartialEq)]
pub enum Expression {
    /// Field reference (e.g., "User.IsVIP", "Order.Amount")
    ///
    /// The whole dotted path is kept as a single string. `evaluate` looks it
    /// up with `Facts::get` and falls back to `Facts::get_nested`; a missing
    /// field is an error (except inside a `Comparison`, see below).
    Field(String),

    /// Literal value (e.g., true, false, 42, "hello")
    ///
    /// The parser also produces `Value::Null` for `null`. Numeric literals
    /// are stored as `Value::Number` (an f64).
    Literal(Value),

    /// Binary comparison (e.g., "X == Y", "A > B")
    ///
    /// When evaluated, an operand that fails to evaluate (missing field,
    /// unbound variable, ...) is compared as `Value::Null` instead of
    /// propagating the error.
    Comparison {
        left: Box<Expression>,
        operator: Operator,
        right: Box<Expression>,
    },

    /// Logical AND operation (e.g., "A && B")
    ///
    /// Short-circuits: `right` is only evaluated when `left` is truthy.
    And {
        left: Box<Expression>,
        right: Box<Expression>,
    },

    /// Logical OR operation (e.g., "A || B")
    ///
    /// Short-circuits: `right` is only evaluated when `left` is falsy.
    Or {
        left: Box<Expression>,
        right: Box<Expression>,
    },

    /// Negation (e.g., "!X"): the boolean inverse of the operand's `to_bool()`.
    Not(Box<Expression>),

    /// Variable (for future unification support, e.g., "?X", "?Customer")
    ///
    /// The stored name includes the leading `?`. `evaluate` does not bind
    /// variables and returns an error for them.
    Variable(String),
}
127
128impl Expression {
129    /// Evaluate expression against facts
130    pub fn evaluate(&self, facts: &Facts) -> Result<Value> {
131        match self {
132            Expression::Field(name) => facts
133                .get(name)
134                .or_else(|| facts.get_nested(name))
135                .ok_or_else(|| {
136                    RuleEngineError::ExecutionError(format!("Field not found: {}", name))
137                }),
138
139            Expression::Literal(value) => Ok(value.clone()),
140
141            Expression::Comparison {
142                left,
143                operator,
144                right,
145            } => {
146                // Special handling for NotEqual when field doesn't exist
147                // If field doesn't exist, treat as Null
148                let left_val = left.evaluate(facts).unwrap_or(Value::Null);
149                let right_val = right.evaluate(facts).unwrap_or(Value::Null);
150
151                let result = operator.evaluate(&left_val, &right_val);
152                Ok(Value::Boolean(result))
153            }
154
155            Expression::And { left, right } => {
156                let left_val = left.evaluate(facts)?;
157                if !left_val.to_bool() {
158                    return Ok(Value::Boolean(false));
159                }
160                let right_val = right.evaluate(facts)?;
161                Ok(Value::Boolean(right_val.to_bool()))
162            }
163
164            Expression::Or { left, right } => {
165                let left_val = left.evaluate(facts)?;
166                if left_val.to_bool() {
167                    return Ok(Value::Boolean(true));
168                }
169                let right_val = right.evaluate(facts)?;
170                Ok(Value::Boolean(right_val.to_bool()))
171            }
172
173            Expression::Not(expr) => {
174                let value = expr.evaluate(facts)?;
175                Ok(Value::Boolean(!value.to_bool()))
176            }
177
178            Expression::Variable(var) => Err(RuleEngineError::ExecutionError(format!(
179                "Cannot evaluate unbound variable: {}",
180                var
181            ))),
182        }
183    }
184
185    /// Check if expression is satisfied (returns true/false)
186    pub fn is_satisfied(&self, facts: &Facts) -> bool {
187        self.evaluate(facts).map(|v| v.to_bool()).unwrap_or(false)
188    }
189
190    /// Extract all field references from expression
191    pub fn extract_fields(&self) -> Vec<String> {
192        let mut fields = Vec::new();
193        self.extract_fields_recursive(&mut fields);
194        fields
195    }
196
197    fn extract_fields_recursive(&self, fields: &mut Vec<String>) {
198        match self {
199            Expression::Field(name) => {
200                if !fields.contains(name) {
201                    fields.push(name.clone());
202                }
203            }
204            Expression::Comparison { left, right, .. } => {
205                left.extract_fields_recursive(fields);
206                right.extract_fields_recursive(fields);
207            }
208            Expression::And { left, right } | Expression::Or { left, right } => {
209                left.extract_fields_recursive(fields);
210                right.extract_fields_recursive(fields);
211            }
212            Expression::Not(expr) => {
213                expr.extract_fields_recursive(fields);
214            }
215            _ => {}
216        }
217    }
218
219    /// Convert to human-readable string
220    #[allow(clippy::inherent_to_string)]
221    pub fn to_string(&self) -> String {
222        match self {
223            Expression::Field(name) => name.clone(),
224            Expression::Literal(val) => format!("{:?}", val),
225            Expression::Comparison {
226                left,
227                operator,
228                right,
229            } => {
230                format!("{} {:?} {}", left.to_string(), operator, right.to_string())
231            }
232            Expression::And { left, right } => {
233                format!("({} && {})", left.to_string(), right.to_string())
234            }
235            Expression::Or { left, right } => {
236                format!("({} || {})", left.to_string(), right.to_string())
237            }
238            Expression::Not(expr) => {
239                format!("!{}", expr.to_string())
240            }
241            Expression::Variable(var) => var.clone(),
242        }
243    }
244}
245
/// Expression parser using recursive descent parsing
///
/// The public entry point is [`ExpressionParser::parse`]; the individual
/// grammar-rule methods are private.
pub struct ExpressionParser {
    /// Input text split into `char`s, so `position` is a character index
    /// (not a UTF-8 byte offset).
    input: Vec<char>,
    /// Index into `input` of the next unconsumed character.
    position: usize,
}
251
252impl ExpressionParser {
253    /// Create a new parser
254    pub fn new(input: &str) -> Self {
255        Self {
256            input: input.chars().collect(),
257            position: 0,
258        }
259    }
260
261    /// Parse expression from string
262    pub fn parse(input: &str) -> Result<Expression> {
263        let mut parser = Self::new(input.trim());
264        parser.parse_expression()
265    }
266
267    /// Parse full expression (handles ||)
268    fn parse_expression(&mut self) -> Result<Expression> {
269        let mut left = self.parse_and_expression()?;
270
271        while self.peek_operator("||") {
272            self.consume_operator("||");
273            let right = self.parse_and_expression()?;
274            left = Expression::Or {
275                left: Box::new(left),
276                right: Box::new(right),
277            };
278        }
279
280        Ok(left)
281    }
282
283    /// Parse AND expression (handles &&)
284    fn parse_and_expression(&mut self) -> Result<Expression> {
285        let mut left = self.parse_comparison()?;
286
287        while self.peek_operator("&&") {
288            self.consume_operator("&&");
289            let right = self.parse_comparison()?;
290            left = Expression::And {
291                left: Box::new(left),
292                right: Box::new(right),
293            };
294        }
295
296        Ok(left)
297    }
298
299    /// Parse comparison (e.g., "X == Y", "A > 5")
300    fn parse_comparison(&mut self) -> Result<Expression> {
301        let left = self.parse_primary()?;
302
303        // Check for comparison operators (check longer operators first)
304        let operator = if self.peek_operator("==") {
305            self.consume_operator("==");
306            Operator::Equal
307        } else if self.peek_operator("!=") {
308            self.consume_operator("!=");
309            Operator::NotEqual
310        } else if self.peek_operator(">=") {
311            self.consume_operator(">=");
312            Operator::GreaterThanOrEqual
313        } else if self.peek_operator("<=") {
314            self.consume_operator("<=");
315            Operator::LessThanOrEqual
316        } else if self.peek_operator(">") {
317            self.consume_operator(">");
318            Operator::GreaterThan
319        } else if self.peek_operator("<") {
320            self.consume_operator("<");
321            Operator::LessThan
322        } else {
323            // No comparison operator - return just the left side
324            return Ok(left);
325        };
326
327        let right = self.parse_primary()?;
328
329        Ok(Expression::Comparison {
330            left: Box::new(left),
331            operator,
332            right: Box::new(right),
333        })
334    }
335
336    /// Parse primary expression (field, literal, variable, or parenthesized)
337    fn parse_primary(&mut self) -> Result<Expression> {
338        self.skip_whitespace();
339
340        // Handle negation
341        if self.peek_char() == Some('!') {
342            self.consume_char();
343            let expr = self.parse_primary()?;
344            return Ok(Expression::Not(Box::new(expr)));
345        }
346
347        // Handle parentheses
348        if self.peek_char() == Some('(') {
349            self.consume_char();
350            let expr = self.parse_expression()?;
351            self.skip_whitespace();
352            if self.peek_char() != Some(')') {
353                return Err(RuleEngineError::ParseError {
354                    message: format!("Expected closing parenthesis at position {}", self.position),
355                });
356            }
357            self.consume_char();
358            return Ok(expr);
359        }
360
361        // Handle variables (?X, ?Customer)
362        if self.peek_char() == Some('?') {
363            self.consume_char();
364            let name = self.consume_identifier()?;
365            return Ok(Expression::Variable(format!("?{}", name)));
366        }
367
368        // Try to parse literal
369        if let Some(value) = self.try_parse_literal()? {
370            return Ok(Expression::Literal(value));
371        }
372
373        // Handle field reference
374        let field_name = self.consume_field_path()?;
375        Ok(Expression::Field(field_name))
376    }
377
378    fn consume_field_path(&mut self) -> Result<String> {
379        let mut path = String::new();
380
381        while let Some(ch) = self.peek_char() {
382            if ch.is_alphanumeric() || ch == '_' || ch == '.' {
383                path.push(ch);
384                self.consume_char();
385            } else {
386                break;
387            }
388        }
389
390        if path.is_empty() {
391            return Err(RuleEngineError::ParseError {
392                message: format!("Expected field name at position {}", self.position),
393            });
394        }
395
396        Ok(path)
397    }
398
399    fn consume_identifier(&mut self) -> Result<String> {
400        let mut ident = String::new();
401
402        while let Some(ch) = self.peek_char() {
403            if ch.is_alphanumeric() || ch == '_' {
404                ident.push(ch);
405                self.consume_char();
406            } else {
407                break;
408            }
409        }
410
411        if ident.is_empty() {
412            return Err(RuleEngineError::ParseError {
413                message: format!("Expected identifier at position {}", self.position),
414            });
415        }
416
417        Ok(ident)
418    }
419
420    fn try_parse_literal(&mut self) -> Result<Option<Value>> {
421        self.skip_whitespace();
422
423        // Boolean literals
424        if self.peek_word("true") {
425            self.consume_word("true");
426            return Ok(Some(Value::Boolean(true)));
427        }
428        if self.peek_word("false") {
429            self.consume_word("false");
430            return Ok(Some(Value::Boolean(false)));
431        }
432
433        // Null literal
434        if self.peek_word("null") {
435            self.consume_word("null");
436            return Ok(Some(Value::Null));
437        }
438
439        // String literals
440        if self.peek_char() == Some('"') {
441            self.consume_char();
442            let mut s = String::new();
443            let mut escaped = false;
444
445            while let Some(ch) = self.peek_char() {
446                if escaped {
447                    // Handle escape sequences
448                    let escaped_char = match ch {
449                        'n' => '\n',
450                        't' => '\t',
451                        'r' => '\r',
452                        '\\' => '\\',
453                        '"' => '"',
454                        _ => ch,
455                    };
456                    s.push(escaped_char);
457                    escaped = false;
458                    self.consume_char();
459                } else if ch == '\\' {
460                    escaped = true;
461                    self.consume_char();
462                } else if ch == '"' {
463                    self.consume_char();
464                    return Ok(Some(Value::String(s)));
465                } else {
466                    s.push(ch);
467                    self.consume_char();
468                }
469            }
470
471            return Err(RuleEngineError::ParseError {
472                message: format!("Unterminated string at position {}", self.position),
473            });
474        }
475
476        // Number literals
477        if let Some(ch) = self.peek_char() {
478            if ch.is_numeric() || ch == '-' {
479                let start_pos = self.position;
480                let mut num_str = String::new();
481                let mut has_dot = false;
482
483                while let Some(ch) = self.peek_char() {
484                    if ch.is_numeric() {
485                        num_str.push(ch);
486                        self.consume_char();
487                    } else if ch == '.' && !has_dot {
488                        has_dot = true;
489                        num_str.push(ch);
490                        self.consume_char();
491                    } else if ch == '-' && num_str.is_empty() {
492                        num_str.push(ch);
493                        self.consume_char();
494                    } else {
495                        break;
496                    }
497                }
498
499                if !num_str.is_empty() && num_str != "-" {
500                    if has_dot {
501                        if let Ok(n) = num_str.parse::<f64>() {
502                            return Ok(Some(Value::Number(n)));
503                        }
504                    } else if let Ok(i) = num_str.parse::<i64>() {
505                        return Ok(Some(Value::Number(i as f64)));
506                    }
507                }
508
509                // Failed to parse - reset position
510                self.position = start_pos;
511            }
512        }
513
514        Ok(None)
515    }
516
517    fn peek_char(&self) -> Option<char> {
518        if self.position < self.input.len() {
519            Some(self.input[self.position])
520        } else {
521            None
522        }
523    }
524
525    fn consume_char(&mut self) {
526        if self.position < self.input.len() {
527            self.position += 1;
528        }
529    }
530
531    fn peek_operator(&mut self, op: &str) -> bool {
532        self.skip_whitespace();
533        let remaining: String = self.input[self.position..].iter().collect();
534        remaining.starts_with(op)
535    }
536
537    fn consume_operator(&mut self, op: &str) {
538        self.skip_whitespace();
539        for _ in 0..op.len() {
540            self.consume_char();
541        }
542    }
543
544    fn peek_word(&mut self, word: &str) -> bool {
545        self.skip_whitespace();
546        let remaining: String = self.input[self.position..].iter().collect();
547
548        if remaining.starts_with(word) {
549            // Make sure it's a complete word (not prefix)
550            let next_pos = self.position + word.len();
551            if next_pos >= self.input.len() {
552                return true;
553            }
554            let next_char = self.input[next_pos];
555            !next_char.is_alphanumeric() && next_char != '_'
556        } else {
557            false
558        }
559    }
560
561    fn consume_word(&mut self, word: &str) {
562        self.skip_whitespace();
563        if self.peek_word(word) {
564            for _ in 0..word.len() {
565                self.consume_char();
566            }
567        }
568    }
569
570    fn skip_whitespace(&mut self) {
571        while let Some(ch) = self.peek_char() {
572            if ch.is_whitespace() {
573                self.consume_char();
574            } else {
575                break;
576            }
577        }
578    }
579}
580
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `input`, failing the test if the parser rejects it.
    fn parse_ok(input: &str) -> Expression {
        ExpressionParser::parse(input).unwrap()
    }

    /// Build a fact store pre-populated with `entries`.
    fn facts_with(entries: &[(&str, Value)]) -> Facts {
        let facts = Facts::new();
        for (key, value) in entries {
            facts.set(*key, value.clone());
        }
        facts
    }

    #[test]
    fn test_parse_simple_field() {
        if let Expression::Field(name) = parse_ok("User.IsVIP") {
            assert_eq!(name, "User.IsVIP");
        } else {
            panic!("Expected field expression");
        }
    }

    #[test]
    fn test_parse_literal_boolean() {
        assert!(
            matches!(parse_ok("true"), Expression::Literal(Value::Boolean(true))),
            "Expected boolean literal"
        );
    }

    #[test]
    fn test_parse_literal_number() {
        if let Expression::Literal(Value::Number(n)) = parse_ok("42.5") {
            assert!((n - 42.5).abs() < 0.001);
        } else {
            panic!("Expected number literal");
        }
    }

    #[test]
    fn test_parse_literal_string() {
        if let Expression::Literal(Value::String(s)) = parse_ok(r#""hello world""#) {
            assert_eq!(s, "hello world");
        } else {
            panic!("Expected string literal");
        }
    }

    #[test]
    fn test_parse_simple_comparison() {
        if let Expression::Comparison { operator, .. } = parse_ok("User.IsVIP == true") {
            assert_eq!(operator, Operator::Equal);
        } else {
            panic!("Expected comparison");
        }
    }

    #[test]
    fn test_parse_all_comparison_operators() {
        let cases = [
            ("a == b", Operator::Equal),
            ("a != b", Operator::NotEqual),
            ("a > b", Operator::GreaterThan),
            ("a >= b", Operator::GreaterThanOrEqual),
            ("a < b", Operator::LessThan),
            ("a <= b", Operator::LessThanOrEqual),
        ];

        for (input, expected_op) in cases.iter() {
            match parse_ok(input) {
                Expression::Comparison { operator, .. } => {
                    assert_eq!(&operator, expected_op, "Failed for: {}", input);
                }
                _ => panic!("Expected comparison for: {}", input),
            }
        }
    }

    #[test]
    fn test_parse_logical_and() {
        let expr = parse_ok("User.IsVIP == true && Order.Amount > 1000");
        assert!(
            matches!(expr, Expression::And { .. }),
            "Expected logical AND, got: {:?}",
            expr
        );
    }

    #[test]
    fn test_parse_logical_or() {
        assert!(
            matches!(parse_ok("a == true || b == true"), Expression::Or { .. }),
            "Expected logical OR"
        );
    }

    #[test]
    fn test_parse_negation() {
        assert!(
            matches!(parse_ok("!User.IsBanned"), Expression::Not(_)),
            "Expected negation"
        );
    }

    #[test]
    fn test_parse_parentheses() {
        match parse_ok("(a == true || b == true) && c == true") {
            Expression::And { left, .. } => {
                assert!(matches!(*left, Expression::Or { .. }), "Expected OR inside AND");
            }
            _ => panic!("Expected AND"),
        }
    }

    #[test]
    fn test_parse_variable() {
        match parse_ok("?X == true") {
            Expression::Comparison { left, .. } => {
                if let Expression::Variable(var) = *left {
                    assert_eq!(var, "?X");
                } else {
                    panic!("Expected variable");
                }
            }
            _ => panic!("Expected comparison"),
        }
    }

    #[test]
    fn test_evaluate_simple() {
        let facts = facts_with(&[("User.IsVIP", Value::Boolean(true))]);
        let result = parse_ok("User.IsVIP == true").evaluate(&facts).unwrap();
        assert_eq!(result, Value::Boolean(true));
    }

    #[test]
    fn test_evaluate_comparison() {
        let facts = facts_with(&[("Order.Amount", Value::Number(1500.0))]);
        let result = parse_ok("Order.Amount > 1000").evaluate(&facts).unwrap();
        assert_eq!(result, Value::Boolean(true));
    }

    #[test]
    fn test_evaluate_logical_and() {
        let facts = facts_with(&[
            ("User.IsVIP", Value::Boolean(true)),
            ("Order.Amount", Value::Number(1500.0)),
        ]);
        let result = parse_ok("User.IsVIP == true && Order.Amount > 1000")
            .evaluate(&facts)
            .unwrap();
        assert_eq!(result, Value::Boolean(true));
    }

    #[test]
    fn test_evaluate_logical_or() {
        let facts = facts_with(&[("a", Value::Boolean(false)), ("b", Value::Boolean(true))]);
        let result = parse_ok("a == true || b == true").evaluate(&facts).unwrap();
        assert_eq!(result, Value::Boolean(true));
    }

    #[test]
    fn test_is_satisfied() {
        let facts = facts_with(&[("User.IsVIP", Value::Boolean(true))]);
        assert!(parse_ok("User.IsVIP == true").is_satisfied(&facts));
    }

    #[test]
    fn test_extract_fields() {
        let fields = parse_ok("User.IsVIP == true && Order.Amount > 1000").extract_fields();

        assert_eq!(fields.len(), 2);
        assert!(fields.contains(&"User.IsVIP".to_string()));
        assert!(fields.contains(&"Order.Amount".to_string()));
    }

    #[test]
    fn test_parse_error_unclosed_parenthesis() {
        assert!(ExpressionParser::parse("(a == true").is_err());
    }

    #[test]
    fn test_parse_error_unterminated_string() {
        assert!(ExpressionParser::parse(r#""hello"#).is_err());
    }

    #[test]
    fn test_parse_complex_expression() {
        let expr = parse_ok(
            "(User.IsVIP == true && Order.Amount > 1000) || (User.Points >= 100 && Order.Discount < 0.5)",
        );

        // Only the top-level shape is checked: both groups joined by OR.
        assert!(matches!(expr, Expression::Or { .. }), "Expected OR at top level");
    }

    #[test]
    fn test_to_string() {
        let rendered = parse_ok("User.IsVIP == true").to_string();
        assert!(rendered.contains("User.IsVIP"));
        assert!(rendered.contains("true"));
    }
}