rust_rule_engine/backward/
expression.rs

1//! Expression AST for backward chaining queries
2//!
3//! This module provides a robust Abstract Syntax Tree (AST) implementation for parsing
4//! and evaluating goal expressions in backward chaining queries. It uses a recursive
5//! descent parser to build a typed expression tree that can be evaluated against facts.
6//!
7//! # Features
8//!
9//! - **Field references** - Access fact values using dot notation (e.g., `User.Name`, `Order.Total`)
//! - **Literal values** - Support for boolean, `null`, numeric, and string literals
11//! - **Comparison operators** - `==`, `!=`, `>`, `<`, `>=`, `<=`
12//! - **Logical operators** - `&&` (AND), `||` (OR), `!` (NOT)
13//! - **Parenthesized expressions** - Group sub-expressions with parentheses
14//! - **Variable binding** - Placeholder variables for unification (e.g., `?x`, `?name`)
15//! - **Operator precedence** - Proper precedence: `||` < `&&` < comparisons
16//!
17//! # Expression Syntax
18//!
19//! ```text
20//! Expression Grammar:
21//!   expr        ::= or_expr
22//!   or_expr     ::= and_expr ('||' and_expr)*
23//!   and_expr    ::= comparison ('&&' comparison)*
24//!   comparison  ::= primary (COMP_OP primary)?
25//!   primary     ::= '!' primary
26//!                 | '(' expr ')'
27//!                 | field
28//!                 | literal
29//!                 | variable
30//!
31//!   COMP_OP     ::= '==' | '!=' | '>' | '<' | '>=' | '<='
32//!   field       ::= IDENT ('.' IDENT)*
//!   literal     ::= boolean | null | number | string
34//!   variable    ::= '?' IDENT
35//! ```
36//!
37//! # Example
38//!
39//! ```rust
40//! use rust_rule_engine::backward::expression::{Expression, ExpressionParser};
41//! use rust_rule_engine::{Facts, Value};
42//!
43//! # fn example() -> Result<(), Box<dyn std::error::Error>> {
44//! // Parse a complex expression
45//! let expr = ExpressionParser::parse(
46//!     "User.IsVIP == true && Order.Amount > 1000"
47//! )?;
48//!
49//! // Create facts
50//! let mut facts = Facts::new();
51//! facts.set("User.IsVIP", Value::Boolean(true));
52//! facts.set("Order.Amount", Value::Number(1500.0));
53//!
54//! // Evaluate expression
55//! let satisfied = expr.is_satisfied(&facts);
56//! assert!(satisfied);
57//!
58//! // Extract referenced fields
59//! let fields = expr.extract_fields();
60//! assert!(fields.contains(&"User.IsVIP".to_string()));
61//! assert!(fields.contains(&"Order.Amount".to_string()));
62//! # Ok(())
63//! # }
64//! ```
65//!
66//! # Supported Operators
67//!
//! ## Comparison Operators (equal precedence, non-chaining: at most one per `comparison`)
69//! - `==` - Equal to
70//! - `!=` - Not equal to
71//! - `>` - Greater than
72//! - `<` - Less than
73//! - `>=` - Greater than or equal to
74//! - `<=` - Less than or equal to
75//!
76//! ## Logical Operators (precedence: low to high)
77//! - `||` - Logical OR (lowest precedence)
78//! - `&&` - Logical AND (medium precedence)
79//! - `!` - Logical NOT (highest precedence)
80//!
81//! # Error Handling
82//!
83//! The parser returns descriptive errors for invalid syntax:
84//! - Unclosed parentheses
85//! - Unterminated strings
86//! - Invalid operators
87//! - Unexpected tokens
88
89use crate::types::{Value, Operator};
90use crate::errors::{Result, RuleEngineError};
91use crate::Facts;
92
93/// Expression AST node
94#[derive(Debug, Clone, PartialEq)]
95pub enum Expression {
96    /// Field reference (e.g., "User.IsVIP", "Order.Amount")
97    Field(String),
98
99    /// Literal value (e.g., true, false, 42, "hello")
100    Literal(Value),
101
102    /// Binary comparison (e.g., "X == Y", "A > B")
103    Comparison {
104        left: Box<Expression>,
105        operator: Operator,
106        right: Box<Expression>,
107    },
108
109    /// Logical AND operation (e.g., "A && B")
110    And {
111        left: Box<Expression>,
112        right: Box<Expression>,
113    },
114
115    /// Logical OR operation (e.g., "A || B")
116    Or {
117        left: Box<Expression>,
118        right: Box<Expression>,
119    },
120
121    /// Negation (e.g., "!X")
122    Not(Box<Expression>),
123
124    /// Variable (for future unification support, e.g., "?X", "?Customer")
125    Variable(String),
126}
127
128impl Expression {
129    /// Evaluate expression against facts
130    pub fn evaluate(&self, facts: &Facts) -> Result<Value> {
131        match self {
132            Expression::Field(name) => {
133                facts.get(name)
134                    .or_else(|| facts.get_nested(name))
135                    .ok_or_else(|| RuleEngineError::ExecutionError(
136                        format!("Field not found: {}", name)
137                    ))
138            }
139
140            Expression::Literal(value) => Ok(value.clone()),
141
142            Expression::Comparison { left, operator, right } => {
143                // Special handling for NotEqual when field doesn't exist
144                // If field doesn't exist, treat as Null
145                let left_val = left.evaluate(facts).unwrap_or(Value::Null);
146                let right_val = right.evaluate(facts).unwrap_or(Value::Null);
147
148                let result = operator.evaluate(&left_val, &right_val);
149                Ok(Value::Boolean(result))
150            }
151
152            Expression::And { left, right } => {
153                let left_val = left.evaluate(facts)?;
154                if !left_val.to_bool() {
155                    return Ok(Value::Boolean(false));
156                }
157                let right_val = right.evaluate(facts)?;
158                Ok(Value::Boolean(right_val.to_bool()))
159            }
160
161            Expression::Or { left, right } => {
162                let left_val = left.evaluate(facts)?;
163                if left_val.to_bool() {
164                    return Ok(Value::Boolean(true));
165                }
166                let right_val = right.evaluate(facts)?;
167                Ok(Value::Boolean(right_val.to_bool()))
168            }
169
170            Expression::Not(expr) => {
171                let value = expr.evaluate(facts)?;
172                Ok(Value::Boolean(!value.to_bool()))
173            }
174
175            Expression::Variable(var) => {
176                Err(RuleEngineError::ExecutionError(
177                    format!("Cannot evaluate unbound variable: {}", var)
178                ))
179            }
180        }
181    }
182
183    /// Check if expression is satisfied (returns true/false)
184    pub fn is_satisfied(&self, facts: &Facts) -> bool {
185        self.evaluate(facts)
186            .map(|v| v.to_bool())
187            .unwrap_or(false)
188    }
189
190    /// Extract all field references from expression
191    pub fn extract_fields(&self) -> Vec<String> {
192        let mut fields = Vec::new();
193        self.extract_fields_recursive(&mut fields);
194        fields
195    }
196
197    fn extract_fields_recursive(&self, fields: &mut Vec<String>) {
198        match self {
199            Expression::Field(name) => {
200                if !fields.contains(name) {
201                    fields.push(name.clone());
202                }
203            }
204            Expression::Comparison { left, right, .. } => {
205                left.extract_fields_recursive(fields);
206                right.extract_fields_recursive(fields);
207            }
208            Expression::And { left, right } | Expression::Or { left, right } => {
209                left.extract_fields_recursive(fields);
210                right.extract_fields_recursive(fields);
211            }
212            Expression::Not(expr) => {
213                expr.extract_fields_recursive(fields);
214            }
215            _ => {}
216        }
217    }
218
219    /// Convert to human-readable string
220    pub fn to_string(&self) -> String {
221        match self {
222            Expression::Field(name) => name.clone(),
223            Expression::Literal(val) => format!("{:?}", val),
224            Expression::Comparison { left, operator, right } => {
225                format!("{} {:?} {}", left.to_string(), operator, right.to_string())
226            }
227            Expression::And { left, right } => {
228                format!("({} && {})", left.to_string(), right.to_string())
229            }
230            Expression::Or { left, right } => {
231                format!("({} || {})", left.to_string(), right.to_string())
232            }
233            Expression::Not(expr) => {
234                format!("!{}", expr.to_string())
235            }
236            Expression::Variable(var) => var.clone(),
237        }
238    }
239}
240
/// Recursive descent parser that turns expression text into an expression tree.
///
/// The input is held as a vector of `char`s so that `position` always counts
/// whole characters, never UTF-8 bytes.
#[derive(Debug, Clone)]
pub struct ExpressionParser {
    /// Characters of the expression being parsed.
    input: Vec<char>,
    /// Index into `input` of the next character to read.
    position: usize,
}
246
247impl ExpressionParser {
248    /// Create a new parser
249    pub fn new(input: &str) -> Self {
250        Self {
251            input: input.chars().collect(),
252            position: 0,
253        }
254    }
255
256    /// Parse expression from string
257    pub fn parse(input: &str) -> Result<Expression> {
258        let mut parser = Self::new(input.trim());
259        parser.parse_expression()
260    }
261
262    /// Parse full expression (handles ||)
263    fn parse_expression(&mut self) -> Result<Expression> {
264        let mut left = self.parse_and_expression()?;
265
266        while self.peek_operator("||") {
267            self.consume_operator("||");
268            let right = self.parse_and_expression()?;
269            left = Expression::Or {
270                left: Box::new(left),
271                right: Box::new(right),
272            };
273        }
274
275        Ok(left)
276    }
277
278    /// Parse AND expression (handles &&)
279    fn parse_and_expression(&mut self) -> Result<Expression> {
280        let mut left = self.parse_comparison()?;
281
282        while self.peek_operator("&&") {
283            self.consume_operator("&&");
284            let right = self.parse_comparison()?;
285            left = Expression::And {
286                left: Box::new(left),
287                right: Box::new(right),
288            };
289        }
290
291        Ok(left)
292    }
293
294    /// Parse comparison (e.g., "X == Y", "A > 5")
295    fn parse_comparison(&mut self) -> Result<Expression> {
296        let left = self.parse_primary()?;
297
298        // Check for comparison operators (check longer operators first)
299        let operator = if self.peek_operator("==") {
300            self.consume_operator("==");
301            Operator::Equal
302        } else if self.peek_operator("!=") {
303            self.consume_operator("!=");
304            Operator::NotEqual
305        } else if self.peek_operator(">=") {
306            self.consume_operator(">=");
307            Operator::GreaterThanOrEqual
308        } else if self.peek_operator("<=") {
309            self.consume_operator("<=");
310            Operator::LessThanOrEqual
311        } else if self.peek_operator(">") {
312            self.consume_operator(">");
313            Operator::GreaterThan
314        } else if self.peek_operator("<") {
315            self.consume_operator("<");
316            Operator::LessThan
317        } else {
318            // No comparison operator - return just the left side
319            return Ok(left);
320        };
321
322        let right = self.parse_primary()?;
323
324        Ok(Expression::Comparison {
325            left: Box::new(left),
326            operator,
327            right: Box::new(right),
328        })
329    }
330
331    /// Parse primary expression (field, literal, variable, or parenthesized)
332    fn parse_primary(&mut self) -> Result<Expression> {
333        self.skip_whitespace();
334
335        // Handle negation
336        if self.peek_char() == Some('!') {
337            self.consume_char();
338            let expr = self.parse_primary()?;
339            return Ok(Expression::Not(Box::new(expr)));
340        }
341
342        // Handle parentheses
343        if self.peek_char() == Some('(') {
344            self.consume_char();
345            let expr = self.parse_expression()?;
346            self.skip_whitespace();
347            if self.peek_char() != Some(')') {
348                return Err(RuleEngineError::ParseError {
349                    message: format!("Expected closing parenthesis at position {}", self.position),
350                });
351            }
352            self.consume_char();
353            return Ok(expr);
354        }
355
356        // Handle variables (?X, ?Customer)
357        if self.peek_char() == Some('?') {
358            self.consume_char();
359            let name = self.consume_identifier()?;
360            return Ok(Expression::Variable(format!("?{}", name)));
361        }
362
363        // Try to parse literal
364        if let Some(value) = self.try_parse_literal()? {
365            return Ok(Expression::Literal(value));
366        }
367
368        // Handle field reference
369        let field_name = self.consume_field_path()?;
370        Ok(Expression::Field(field_name))
371    }
372
373    fn consume_field_path(&mut self) -> Result<String> {
374        let mut path = String::new();
375
376        while let Some(ch) = self.peek_char() {
377            if ch.is_alphanumeric() || ch == '_' || ch == '.' {
378                path.push(ch);
379                self.consume_char();
380            } else {
381                break;
382            }
383        }
384
385        if path.is_empty() {
386            return Err(RuleEngineError::ParseError {
387                message: format!("Expected field name at position {}", self.position),
388            });
389        }
390
391        Ok(path)
392    }
393
394    fn consume_identifier(&mut self) -> Result<String> {
395        let mut ident = String::new();
396
397        while let Some(ch) = self.peek_char() {
398            if ch.is_alphanumeric() || ch == '_' {
399                ident.push(ch);
400                self.consume_char();
401            } else {
402                break;
403            }
404        }
405
406        if ident.is_empty() {
407            return Err(RuleEngineError::ParseError {
408                message: format!("Expected identifier at position {}", self.position),
409            });
410        }
411
412        Ok(ident)
413    }
414
415    fn try_parse_literal(&mut self) -> Result<Option<Value>> {
416        self.skip_whitespace();
417
418        // Boolean literals
419        if self.peek_word("true") {
420            self.consume_word("true");
421            return Ok(Some(Value::Boolean(true)));
422        }
423        if self.peek_word("false") {
424            self.consume_word("false");
425            return Ok(Some(Value::Boolean(false)));
426        }
427
428        // Null literal
429        if self.peek_word("null") {
430            self.consume_word("null");
431            return Ok(Some(Value::Null));
432        }
433
434        // String literals
435        if self.peek_char() == Some('"') {
436            self.consume_char();
437            let mut s = String::new();
438            let mut escaped = false;
439
440            while let Some(ch) = self.peek_char() {
441                if escaped {
442                    // Handle escape sequences
443                    let escaped_char = match ch {
444                        'n' => '\n',
445                        't' => '\t',
446                        'r' => '\r',
447                        '\\' => '\\',
448                        '"' => '"',
449                        _ => ch,
450                    };
451                    s.push(escaped_char);
452                    escaped = false;
453                    self.consume_char();
454                } else if ch == '\\' {
455                    escaped = true;
456                    self.consume_char();
457                } else if ch == '"' {
458                    self.consume_char();
459                    return Ok(Some(Value::String(s)));
460                } else {
461                    s.push(ch);
462                    self.consume_char();
463                }
464            }
465
466            return Err(RuleEngineError::ParseError {
467                message: format!("Unterminated string at position {}", self.position),
468            });
469        }
470
471        // Number literals
472        if let Some(ch) = self.peek_char() {
473            if ch.is_numeric() || ch == '-' {
474                let start_pos = self.position;
475                let mut num_str = String::new();
476                let mut has_dot = false;
477
478                while let Some(ch) = self.peek_char() {
479                    if ch.is_numeric() {
480                        num_str.push(ch);
481                        self.consume_char();
482                    } else if ch == '.' && !has_dot {
483                        has_dot = true;
484                        num_str.push(ch);
485                        self.consume_char();
486                    } else if ch == '-' && num_str.is_empty() {
487                        num_str.push(ch);
488                        self.consume_char();
489                    } else {
490                        break;
491                    }
492                }
493
494                if !num_str.is_empty() && num_str != "-" {
495                    if has_dot {
496                        if let Ok(n) = num_str.parse::<f64>() {
497                            return Ok(Some(Value::Number(n)));
498                        }
499                    } else if let Ok(i) = num_str.parse::<i64>() {
500                        return Ok(Some(Value::Number(i as f64)));
501                    }
502                }
503
504                // Failed to parse - reset position
505                self.position = start_pos;
506            }
507        }
508
509        Ok(None)
510    }
511
512    fn peek_char(&self) -> Option<char> {
513        if self.position < self.input.len() {
514            Some(self.input[self.position])
515        } else {
516            None
517        }
518    }
519
520    fn consume_char(&mut self) {
521        if self.position < self.input.len() {
522            self.position += 1;
523        }
524    }
525
526    fn peek_operator(&mut self, op: &str) -> bool {
527        self.skip_whitespace();
528        let remaining: String = self.input[self.position..].iter().collect();
529        remaining.starts_with(op)
530    }
531
532    fn consume_operator(&mut self, op: &str) {
533        self.skip_whitespace();
534        for _ in 0..op.len() {
535            self.consume_char();
536        }
537    }
538
539    fn peek_word(&mut self, word: &str) -> bool {
540        self.skip_whitespace();
541        let remaining: String = self.input[self.position..].iter().collect();
542
543        if remaining.starts_with(word) {
544            // Make sure it's a complete word (not prefix)
545            let next_pos = self.position + word.len();
546            if next_pos >= self.input.len() {
547                return true;
548            }
549            let next_char = self.input[next_pos];
550            !next_char.is_alphanumeric() && next_char != '_'
551        } else {
552            false
553        }
554    }
555
556    fn consume_word(&mut self, word: &str) {
557        self.skip_whitespace();
558        if self.peek_word(word) {
559            for _ in 0..word.len() {
560                self.consume_char();
561            }
562        }
563    }
564
565    fn skip_whitespace(&mut self) {
566        while let Some(ch) = self.peek_char() {
567            if ch.is_whitespace() {
568                self.consume_char();
569            } else {
570                break;
571            }
572        }
573    }
574}
575
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Parsing: leaves -------------------------------------------------

    #[test]
    fn test_parse_simple_field() {
        let expr = ExpressionParser::parse("User.IsVIP").unwrap();
        if let Expression::Field(name) = expr {
            assert_eq!(name, "User.IsVIP");
        } else {
            panic!("Expected field expression");
        }
    }

    #[test]
    fn test_parse_literal_boolean() {
        let expr = ExpressionParser::parse("true").unwrap();
        assert!(
            matches!(expr, Expression::Literal(Value::Boolean(true))),
            "Expected boolean literal"
        );
    }

    #[test]
    fn test_parse_literal_number() {
        let expr = ExpressionParser::parse("42.5").unwrap();
        if let Expression::Literal(Value::Number(n)) = expr {
            assert!((n - 42.5).abs() < 0.001);
        } else {
            panic!("Expected number literal");
        }
    }

    #[test]
    fn test_parse_literal_string() {
        let expr = ExpressionParser::parse(r#""hello world""#).unwrap();
        if let Expression::Literal(Value::String(s)) = expr {
            assert_eq!(s, "hello world");
        } else {
            panic!("Expected string literal");
        }
    }

    // ---- Parsing: operators ----------------------------------------------

    #[test]
    fn test_parse_simple_comparison() {
        let expr = ExpressionParser::parse("User.IsVIP == true").unwrap();
        if let Expression::Comparison { operator, .. } = expr {
            assert_eq!(operator, Operator::Equal);
        } else {
            panic!("Expected comparison");
        }
    }

    #[test]
    fn test_parse_all_comparison_operators() {
        let cases = vec![
            ("a == b", Operator::Equal),
            ("a != b", Operator::NotEqual),
            ("a > b", Operator::GreaterThan),
            ("a >= b", Operator::GreaterThanOrEqual),
            ("a < b", Operator::LessThan),
            ("a <= b", Operator::LessThanOrEqual),
        ];

        for (input, expected_op) in cases {
            let parsed = ExpressionParser::parse(input).unwrap();
            if let Expression::Comparison { operator, .. } = parsed {
                assert_eq!(operator, expected_op, "Failed for: {}", input);
            } else {
                panic!("Expected comparison for: {}", input);
            }
        }
    }

    #[test]
    fn test_parse_logical_and() {
        let expr = ExpressionParser::parse("User.IsVIP == true && Order.Amount > 1000").unwrap();
        assert!(
            matches!(expr, Expression::And { .. }),
            "Expected logical AND, got: {:?}",
            expr
        );
    }

    #[test]
    fn test_parse_logical_or() {
        let expr = ExpressionParser::parse("a == true || b == true").unwrap();
        assert!(matches!(expr, Expression::Or { .. }), "Expected logical OR");
    }

    #[test]
    fn test_parse_negation() {
        let expr = ExpressionParser::parse("!User.IsBanned").unwrap();
        assert!(matches!(expr, Expression::Not(_)), "Expected negation");
    }

    #[test]
    fn test_parse_parentheses() {
        let expr = ExpressionParser::parse("(a == true || b == true) && c == true").unwrap();
        if let Expression::And { left, .. } = expr {
            assert!(
                matches!(*left, Expression::Or { .. }),
                "Expected OR inside AND"
            );
        } else {
            panic!("Expected AND");
        }
    }

    #[test]
    fn test_parse_variable() {
        let expr = ExpressionParser::parse("?X == true").unwrap();
        let left = if let Expression::Comparison { left, .. } = expr {
            left
        } else {
            panic!("Expected comparison")
        };

        if let Expression::Variable(var) = *left {
            assert_eq!(var, "?X");
        } else {
            panic!("Expected variable");
        }
    }

    // ---- Evaluation ------------------------------------------------------

    #[test]
    fn test_evaluate_simple() {
        let mut facts = Facts::new();
        facts.set("User.IsVIP", Value::Boolean(true));

        let outcome = ExpressionParser::parse("User.IsVIP == true")
            .unwrap()
            .evaluate(&facts)
            .unwrap();

        assert_eq!(outcome, Value::Boolean(true));
    }

    #[test]
    fn test_evaluate_comparison() {
        let mut facts = Facts::new();
        facts.set("Order.Amount", Value::Number(1500.0));

        let outcome = ExpressionParser::parse("Order.Amount > 1000")
            .unwrap()
            .evaluate(&facts)
            .unwrap();

        assert_eq!(outcome, Value::Boolean(true));
    }

    #[test]
    fn test_evaluate_logical_and() {
        let mut facts = Facts::new();
        facts.set("User.IsVIP", Value::Boolean(true));
        facts.set("Order.Amount", Value::Number(1500.0));

        let outcome = ExpressionParser::parse("User.IsVIP == true && Order.Amount > 1000")
            .unwrap()
            .evaluate(&facts)
            .unwrap();

        assert_eq!(outcome, Value::Boolean(true));
    }

    #[test]
    fn test_evaluate_logical_or() {
        let mut facts = Facts::new();
        facts.set("a", Value::Boolean(false));
        facts.set("b", Value::Boolean(true));

        let outcome = ExpressionParser::parse("a == true || b == true")
            .unwrap()
            .evaluate(&facts)
            .unwrap();

        assert_eq!(outcome, Value::Boolean(true));
    }

    #[test]
    fn test_is_satisfied() {
        let mut facts = Facts::new();
        facts.set("User.IsVIP", Value::Boolean(true));

        let expr = ExpressionParser::parse("User.IsVIP == true").unwrap();
        assert!(expr.is_satisfied(&facts));
    }

    #[test]
    fn test_extract_fields() {
        let fields = ExpressionParser::parse("User.IsVIP == true && Order.Amount > 1000")
            .unwrap()
            .extract_fields();

        assert_eq!(fields.len(), 2);
        for expected in ["User.IsVIP", "Order.Amount"].iter() {
            assert!(fields.contains(&expected.to_string()));
        }
    }

    // ---- Error reporting and rendering -----------------------------------

    #[test]
    fn test_parse_error_unclosed_parenthesis() {
        assert!(ExpressionParser::parse("(a == true").is_err());
    }

    #[test]
    fn test_parse_error_unterminated_string() {
        assert!(ExpressionParser::parse(r#""hello"#).is_err());
    }

    #[test]
    fn test_parse_complex_expression() {
        let source = "(User.IsVIP == true && Order.Amount > 1000) || (User.Points >= 100 && Order.Discount < 0.5)";
        let expr = ExpressionParser::parse(source).unwrap();

        // Only the top-level shape is checked here.
        assert!(
            matches!(expr, Expression::Or { .. }),
            "Expected OR at top level"
        );
    }

    #[test]
    fn test_to_string() {
        let rendered = ExpressionParser::parse("User.IsVIP == true")
            .unwrap()
            .to_string();
        assert!(rendered.contains("User.IsVIP"));
        assert!(rendered.contains("true"));
    }
}