// rust_rule_engine/backward/expression.rs

//! Expression AST for backward-chaining queries.
//!
//! This module defines an abstract syntax tree ([`Expression`]) for backward-chaining
//! query conditions, a recursive-descent parser ([`ExpressionParser`]) that builds the
//! tree from text, and evaluation of the tree against a set of facts.
//!
//! # Example
//! ```ignore
//! use rust_rule_engine::backward::expression::{Expression, ExpressionParser};
//!
//! let expr = ExpressionParser::parse("User.IsVIP == true && Order.Amount > 1000")?;
//! let result = expr.evaluate(&facts)?;
//! ```

use std::fmt;

use crate::errors::{Result, RuleEngineError};
use crate::types::{Operator, Value};
use crate::Facts;

18/// Expression AST node
19#[derive(Debug, Clone, PartialEq)]
20pub enum Expression {
21    /// Field reference (e.g., "User.IsVIP", "Order.Amount")
22    Field(String),
23
24    /// Literal value (e.g., true, false, 42, "hello")
25    Literal(Value),
26
27    /// Binary comparison (e.g., "X == Y", "A > B")
28    Comparison {
29        left: Box<Expression>,
30        operator: Operator,
31        right: Box<Expression>,
32    },
33
34    /// Logical AND operation (e.g., "A && B")
35    And {
36        left: Box<Expression>,
37        right: Box<Expression>,
38    },
39
40    /// Logical OR operation (e.g., "A || B")
41    Or {
42        left: Box<Expression>,
43        right: Box<Expression>,
44    },
45
46    /// Negation (e.g., "!X")
47    Not(Box<Expression>),
48
49    /// Variable (for future unification support, e.g., "?X", "?Customer")
50    Variable(String),
51}
52
53impl Expression {
54    /// Evaluate expression against facts
55    pub fn evaluate(&self, facts: &Facts) -> Result<Value> {
56        match self {
57            Expression::Field(name) => {
58                facts.get(name)
59                    .or_else(|| facts.get_nested(name))
60                    .ok_or_else(|| RuleEngineError::ExecutionError(
61                        format!("Field not found: {}", name)
62                    ))
63            }
64
65            Expression::Literal(value) => Ok(value.clone()),
66
67            Expression::Comparison { left, operator, right } => {
68                // Special handling for NotEqual when field doesn't exist
69                // If field doesn't exist, treat as Null
70                let left_val = left.evaluate(facts).unwrap_or(Value::Null);
71                let right_val = right.evaluate(facts).unwrap_or(Value::Null);
72
73                let result = operator.evaluate(&left_val, &right_val);
74                Ok(Value::Boolean(result))
75            }
76
77            Expression::And { left, right } => {
78                let left_val = left.evaluate(facts)?;
79                if !left_val.to_bool() {
80                    return Ok(Value::Boolean(false));
81                }
82                let right_val = right.evaluate(facts)?;
83                Ok(Value::Boolean(right_val.to_bool()))
84            }
85
86            Expression::Or { left, right } => {
87                let left_val = left.evaluate(facts)?;
88                if left_val.to_bool() {
89                    return Ok(Value::Boolean(true));
90                }
91                let right_val = right.evaluate(facts)?;
92                Ok(Value::Boolean(right_val.to_bool()))
93            }
94
95            Expression::Not(expr) => {
96                let value = expr.evaluate(facts)?;
97                Ok(Value::Boolean(!value.to_bool()))
98            }
99
100            Expression::Variable(var) => {
101                Err(RuleEngineError::ExecutionError(
102                    format!("Cannot evaluate unbound variable: {}", var)
103                ))
104            }
105        }
106    }
107
108    /// Check if expression is satisfied (returns true/false)
109    pub fn is_satisfied(&self, facts: &Facts) -> bool {
110        self.evaluate(facts)
111            .map(|v| v.to_bool())
112            .unwrap_or(false)
113    }
114
115    /// Extract all field references from expression
116    pub fn extract_fields(&self) -> Vec<String> {
117        let mut fields = Vec::new();
118        self.extract_fields_recursive(&mut fields);
119        fields
120    }
121
122    fn extract_fields_recursive(&self, fields: &mut Vec<String>) {
123        match self {
124            Expression::Field(name) => {
125                if !fields.contains(name) {
126                    fields.push(name.clone());
127                }
128            }
129            Expression::Comparison { left, right, .. } => {
130                left.extract_fields_recursive(fields);
131                right.extract_fields_recursive(fields);
132            }
133            Expression::And { left, right } | Expression::Or { left, right } => {
134                left.extract_fields_recursive(fields);
135                right.extract_fields_recursive(fields);
136            }
137            Expression::Not(expr) => {
138                expr.extract_fields_recursive(fields);
139            }
140            _ => {}
141        }
142    }
143
144    /// Convert to human-readable string
145    pub fn to_string(&self) -> String {
146        match self {
147            Expression::Field(name) => name.clone(),
148            Expression::Literal(val) => format!("{:?}", val),
149            Expression::Comparison { left, operator, right } => {
150                format!("{} {:?} {}", left.to_string(), operator, right.to_string())
151            }
152            Expression::And { left, right } => {
153                format!("({} && {})", left.to_string(), right.to_string())
154            }
155            Expression::Or { left, right } => {
156                format!("({} || {})", left.to_string(), right.to_string())
157            }
158            Expression::Not(expr) => {
159                format!("!{}", expr.to_string())
160            }
161            Expression::Variable(var) => var.clone(),
162        }
163    }
164}
165
/// Recursive-descent parser that turns query text into an [`Expression`].
///
/// The text is held as a `Vec<char>` so `position` indexes characters rather than
/// bytes; the positions quoted in parse errors are therefore character offsets.
pub struct ExpressionParser {
    /// Index of the next unconsumed character in `input`.
    position: usize,
    /// The query text, one entry per `char`.
    input: Vec<char>,
}

172impl ExpressionParser {
173    /// Create a new parser
174    pub fn new(input: &str) -> Self {
175        Self {
176            input: input.chars().collect(),
177            position: 0,
178        }
179    }
180
181    /// Parse expression from string
182    pub fn parse(input: &str) -> Result<Expression> {
183        let mut parser = Self::new(input.trim());
184        parser.parse_expression()
185    }
186
187    /// Parse full expression (handles ||)
188    fn parse_expression(&mut self) -> Result<Expression> {
189        let mut left = self.parse_and_expression()?;
190
191        while self.peek_operator("||") {
192            self.consume_operator("||");
193            let right = self.parse_and_expression()?;
194            left = Expression::Or {
195                left: Box::new(left),
196                right: Box::new(right),
197            };
198        }
199
200        Ok(left)
201    }
202
203    /// Parse AND expression (handles &&)
204    fn parse_and_expression(&mut self) -> Result<Expression> {
205        let mut left = self.parse_comparison()?;
206
207        while self.peek_operator("&&") {
208            self.consume_operator("&&");
209            let right = self.parse_comparison()?;
210            left = Expression::And {
211                left: Box::new(left),
212                right: Box::new(right),
213            };
214        }
215
216        Ok(left)
217    }
218
219    /// Parse comparison (e.g., "X == Y", "A > 5")
220    fn parse_comparison(&mut self) -> Result<Expression> {
221        let left = self.parse_primary()?;
222
223        // Check for comparison operators (check longer operators first)
224        let operator = if self.peek_operator("==") {
225            self.consume_operator("==");
226            Operator::Equal
227        } else if self.peek_operator("!=") {
228            self.consume_operator("!=");
229            Operator::NotEqual
230        } else if self.peek_operator(">=") {
231            self.consume_operator(">=");
232            Operator::GreaterThanOrEqual
233        } else if self.peek_operator("<=") {
234            self.consume_operator("<=");
235            Operator::LessThanOrEqual
236        } else if self.peek_operator(">") {
237            self.consume_operator(">");
238            Operator::GreaterThan
239        } else if self.peek_operator("<") {
240            self.consume_operator("<");
241            Operator::LessThan
242        } else {
243            // No comparison operator - return just the left side
244            return Ok(left);
245        };
246
247        let right = self.parse_primary()?;
248
249        Ok(Expression::Comparison {
250            left: Box::new(left),
251            operator,
252            right: Box::new(right),
253        })
254    }
255
256    /// Parse primary expression (field, literal, variable, or parenthesized)
257    fn parse_primary(&mut self) -> Result<Expression> {
258        self.skip_whitespace();
259
260        // Handle negation
261        if self.peek_char() == Some('!') {
262            self.consume_char();
263            let expr = self.parse_primary()?;
264            return Ok(Expression::Not(Box::new(expr)));
265        }
266
267        // Handle parentheses
268        if self.peek_char() == Some('(') {
269            self.consume_char();
270            let expr = self.parse_expression()?;
271            self.skip_whitespace();
272            if self.peek_char() != Some(')') {
273                return Err(RuleEngineError::ParseError {
274                    message: format!("Expected closing parenthesis at position {}", self.position),
275                });
276            }
277            self.consume_char();
278            return Ok(expr);
279        }
280
281        // Handle variables (?X, ?Customer)
282        if self.peek_char() == Some('?') {
283            self.consume_char();
284            let name = self.consume_identifier()?;
285            return Ok(Expression::Variable(format!("?{}", name)));
286        }
287
288        // Try to parse literal
289        if let Some(value) = self.try_parse_literal()? {
290            return Ok(Expression::Literal(value));
291        }
292
293        // Handle field reference
294        let field_name = self.consume_field_path()?;
295        Ok(Expression::Field(field_name))
296    }
297
298    fn consume_field_path(&mut self) -> Result<String> {
299        let mut path = String::new();
300
301        while let Some(ch) = self.peek_char() {
302            if ch.is_alphanumeric() || ch == '_' || ch == '.' {
303                path.push(ch);
304                self.consume_char();
305            } else {
306                break;
307            }
308        }
309
310        if path.is_empty() {
311            return Err(RuleEngineError::ParseError {
312                message: format!("Expected field name at position {}", self.position),
313            });
314        }
315
316        Ok(path)
317    }
318
319    fn consume_identifier(&mut self) -> Result<String> {
320        let mut ident = String::new();
321
322        while let Some(ch) = self.peek_char() {
323            if ch.is_alphanumeric() || ch == '_' {
324                ident.push(ch);
325                self.consume_char();
326            } else {
327                break;
328            }
329        }
330
331        if ident.is_empty() {
332            return Err(RuleEngineError::ParseError {
333                message: format!("Expected identifier at position {}", self.position),
334            });
335        }
336
337        Ok(ident)
338    }
339
340    fn try_parse_literal(&mut self) -> Result<Option<Value>> {
341        self.skip_whitespace();
342
343        // Boolean literals
344        if self.peek_word("true") {
345            self.consume_word("true");
346            return Ok(Some(Value::Boolean(true)));
347        }
348        if self.peek_word("false") {
349            self.consume_word("false");
350            return Ok(Some(Value::Boolean(false)));
351        }
352
353        // Null literal
354        if self.peek_word("null") {
355            self.consume_word("null");
356            return Ok(Some(Value::Null));
357        }
358
359        // String literals
360        if self.peek_char() == Some('"') {
361            self.consume_char();
362            let mut s = String::new();
363            let mut escaped = false;
364
365            while let Some(ch) = self.peek_char() {
366                if escaped {
367                    // Handle escape sequences
368                    let escaped_char = match ch {
369                        'n' => '\n',
370                        't' => '\t',
371                        'r' => '\r',
372                        '\\' => '\\',
373                        '"' => '"',
374                        _ => ch,
375                    };
376                    s.push(escaped_char);
377                    escaped = false;
378                    self.consume_char();
379                } else if ch == '\\' {
380                    escaped = true;
381                    self.consume_char();
382                } else if ch == '"' {
383                    self.consume_char();
384                    return Ok(Some(Value::String(s)));
385                } else {
386                    s.push(ch);
387                    self.consume_char();
388                }
389            }
390
391            return Err(RuleEngineError::ParseError {
392                message: format!("Unterminated string at position {}", self.position),
393            });
394        }
395
396        // Number literals
397        if let Some(ch) = self.peek_char() {
398            if ch.is_numeric() || ch == '-' {
399                let start_pos = self.position;
400                let mut num_str = String::new();
401                let mut has_dot = false;
402
403                while let Some(ch) = self.peek_char() {
404                    if ch.is_numeric() {
405                        num_str.push(ch);
406                        self.consume_char();
407                    } else if ch == '.' && !has_dot {
408                        has_dot = true;
409                        num_str.push(ch);
410                        self.consume_char();
411                    } else if ch == '-' && num_str.is_empty() {
412                        num_str.push(ch);
413                        self.consume_char();
414                    } else {
415                        break;
416                    }
417                }
418
419                if !num_str.is_empty() && num_str != "-" {
420                    if has_dot {
421                        if let Ok(n) = num_str.parse::<f64>() {
422                            return Ok(Some(Value::Number(n)));
423                        }
424                    } else if let Ok(i) = num_str.parse::<i64>() {
425                        return Ok(Some(Value::Integer(i)));
426                    }
427                }
428
429                // Failed to parse - reset position
430                self.position = start_pos;
431            }
432        }
433
434        Ok(None)
435    }
436
437    fn peek_char(&self) -> Option<char> {
438        if self.position < self.input.len() {
439            Some(self.input[self.position])
440        } else {
441            None
442        }
443    }
444
445    fn consume_char(&mut self) {
446        if self.position < self.input.len() {
447            self.position += 1;
448        }
449    }
450
451    fn peek_operator(&mut self, op: &str) -> bool {
452        self.skip_whitespace();
453        let remaining: String = self.input[self.position..].iter().collect();
454        remaining.starts_with(op)
455    }
456
457    fn consume_operator(&mut self, op: &str) {
458        self.skip_whitespace();
459        for _ in 0..op.len() {
460            self.consume_char();
461        }
462    }
463
464    fn peek_word(&mut self, word: &str) -> bool {
465        self.skip_whitespace();
466        let remaining: String = self.input[self.position..].iter().collect();
467
468        if remaining.starts_with(word) {
469            // Make sure it's a complete word (not prefix)
470            let next_pos = self.position + word.len();
471            if next_pos >= self.input.len() {
472                return true;
473            }
474            let next_char = self.input[next_pos];
475            !next_char.is_alphanumeric() && next_char != '_'
476        } else {
477            false
478        }
479    }
480
481    fn consume_word(&mut self, word: &str) {
482        self.skip_whitespace();
483        if self.peek_word(word) {
484            for _ in 0..word.len() {
485                self.consume_char();
486            }
487        }
488    }
489
490    fn skip_whitespace(&mut self) {
491        while let Some(ch) = self.peek_char() {
492            if ch.is_whitespace() {
493                self.consume_char();
494            } else {
495                break;
496            }
497        }
498    }
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src`, failing the calling test if the parser rejects it.
    fn parse_ok(src: &str) -> Expression {
        ExpressionParser::parse(src).unwrap()
    }

    #[test]
    fn test_parse_simple_field() {
        if let Expression::Field(name) = parse_ok("User.IsVIP") {
            assert_eq!(name, "User.IsVIP");
        } else {
            panic!("Expected field expression");
        }
    }

    #[test]
    fn test_parse_literal_boolean() {
        let parsed = parse_ok("true");
        assert!(
            matches!(parsed, Expression::Literal(Value::Boolean(true))),
            "Expected boolean literal"
        );
    }

    #[test]
    fn test_parse_literal_number() {
        if let Expression::Literal(Value::Number(n)) = parse_ok("42.5") {
            assert!((n - 42.5).abs() < 0.001);
        } else {
            panic!("Expected number literal");
        }
    }

    #[test]
    fn test_parse_literal_string() {
        if let Expression::Literal(Value::String(s)) = parse_ok(r#""hello world""#) {
            assert_eq!(s, "hello world");
        } else {
            panic!("Expected string literal");
        }
    }

    #[test]
    fn test_parse_simple_comparison() {
        if let Expression::Comparison { operator, .. } = parse_ok("User.IsVIP == true") {
            assert_eq!(operator, Operator::Equal);
        } else {
            panic!("Expected comparison");
        }
    }

    #[test]
    fn test_parse_all_comparison_operators() {
        let cases = [
            ("a == b", Operator::Equal),
            ("a != b", Operator::NotEqual),
            ("a > b", Operator::GreaterThan),
            ("a >= b", Operator::GreaterThanOrEqual),
            ("a < b", Operator::LessThan),
            ("a <= b", Operator::LessThanOrEqual),
        ];

        for (input, expected_op) in cases.iter() {
            if let Expression::Comparison { operator, .. } = parse_ok(input) {
                assert_eq!(&operator, expected_op, "Failed for: {}", input);
            } else {
                panic!("Expected comparison for: {}", input);
            }
        }
    }

    #[test]
    fn test_parse_logical_and() {
        let expr = parse_ok("User.IsVIP == true && Order.Amount > 1000");
        assert!(
            matches!(expr, Expression::And { .. }),
            "Expected logical AND, got: {:?}",
            expr
        );
    }

    #[test]
    fn test_parse_logical_or() {
        assert!(
            matches!(parse_ok("a == true || b == true"), Expression::Or { .. }),
            "Expected logical OR"
        );
    }

    #[test]
    fn test_parse_negation() {
        assert!(
            matches!(parse_ok("!User.IsBanned"), Expression::Not(_)),
            "Expected negation"
        );
    }

    #[test]
    fn test_parse_parentheses() {
        if let Expression::And { left, .. } = parse_ok("(a == true || b == true) && c == true") {
            assert!(matches!(*left, Expression::Or { .. }), "Expected OR inside AND");
        } else {
            panic!("Expected AND");
        }
    }

    #[test]
    fn test_parse_variable() {
        let left = match parse_ok("?X == true") {
            Expression::Comparison { left, .. } => left,
            _ => panic!("Expected comparison"),
        };
        if let Expression::Variable(var) = *left {
            assert_eq!(var, "?X");
        } else {
            panic!("Expected variable");
        }
    }

    #[test]
    fn test_evaluate_simple() {
        let mut facts = Facts::new();
        facts.set("User.IsVIP", Value::Boolean(true));

        let outcome = parse_ok("User.IsVIP == true").evaluate(&facts).unwrap();
        assert_eq!(outcome, Value::Boolean(true));
    }

    #[test]
    fn test_evaluate_comparison() {
        let mut facts = Facts::new();
        facts.set("Order.Amount", Value::Number(1500.0));

        let outcome = parse_ok("Order.Amount > 1000").evaluate(&facts).unwrap();
        assert_eq!(outcome, Value::Boolean(true));
    }

    #[test]
    fn test_evaluate_logical_and() {
        let mut facts = Facts::new();
        facts.set("User.IsVIP", Value::Boolean(true));
        facts.set("Order.Amount", Value::Number(1500.0));

        let outcome = parse_ok("User.IsVIP == true && Order.Amount > 1000")
            .evaluate(&facts)
            .unwrap();
        assert_eq!(outcome, Value::Boolean(true));
    }

    #[test]
    fn test_evaluate_logical_or() {
        let mut facts = Facts::new();
        facts.set("a", Value::Boolean(false));
        facts.set("b", Value::Boolean(true));

        let outcome = parse_ok("a == true || b == true").evaluate(&facts).unwrap();
        assert_eq!(outcome, Value::Boolean(true));
    }

    #[test]
    fn test_is_satisfied() {
        let mut facts = Facts::new();
        facts.set("User.IsVIP", Value::Boolean(true));

        assert!(parse_ok("User.IsVIP == true").is_satisfied(&facts));
    }

    #[test]
    fn test_extract_fields() {
        let fields = parse_ok("User.IsVIP == true && Order.Amount > 1000").extract_fields();

        assert_eq!(fields.len(), 2);
        for expected in ["User.IsVIP", "Order.Amount"].iter() {
            assert!(fields.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_parse_error_unclosed_parenthesis() {
        assert!(ExpressionParser::parse("(a == true").is_err());
    }

    #[test]
    fn test_parse_error_unterminated_string() {
        assert!(ExpressionParser::parse(r#""hello"#).is_err());
    }

    #[test]
    fn test_parse_complex_expression() {
        let expr = parse_ok(
            "(User.IsVIP == true && Order.Amount > 1000) || (User.Points >= 100 && Order.Discount < 0.5)",
        );

        // The two parenthesised groups must be joined by a top-level OR.
        assert!(matches!(expr, Expression::Or { .. }), "Expected OR at top level");
    }

    #[test]
    fn test_to_string() {
        let rendered = parse_ok("User.IsVIP == true").to_string();
        assert!(rendered.contains("User.IsVIP"));
        assert!(rendered.contains("true"));
    }
}