rust_rule_engine/rete/
alpha.rs

//! AlphaNode: checks a single condition against a fact.
2
3use super::facts::{FactValue, TypedFacts};
4
/// Alpha node: tests a single condition against a fact.
///
/// See `AlphaNode::matches` (string facts) and `AlphaNode::matches_typed`
/// (typed facts) for how the three fields are interpreted.
#[derive(Clone, Debug)]
pub struct AlphaNode {
    /// Name of the fact field to test, or a Test CE of the form `"test(<expr>)"`.
    pub field: String,
    /// Operator to apply, e.g. `"=="`, `">="`, `"contains"`, `"matches"`.
    pub operator: String,
    /// Expected value: a literal, or (for typed matching) the name of another
    /// fact field whose value is compared against.
    pub value: String,
}
11
12
13impl AlphaNode {
14    /// Match with string-based facts (backward compatible)
15    pub fn matches(&self, fact_field: &str, fact_value: &str) -> bool {
16        if self.field != fact_field {
17            return false;
18        }
19        match self.operator.as_str() {
20            "==" => fact_value == self.value,
21            "!=" => fact_value != self.value,
22            ">" => parse_num(fact_value) > parse_num(&self.value),
23            "<" => parse_num(fact_value) < parse_num(&self.value),
24            ">=" => parse_num(fact_value) >= parse_num(&self.value),
25            "<=" => parse_num(fact_value) <= parse_num(&self.value),
26            "contains" => fact_value.contains(&self.value),
27            "startsWith" => fact_value.starts_with(&self.value),
28            "endsWith" => fact_value.ends_with(&self.value),
29            "matches" => wildcard_match(fact_value, &self.value),
30            _ => false,
31        }
32    }
33
34    /// Match with typed facts (new!)
35    pub fn matches_typed(&self, facts: &TypedFacts) -> bool {
36        // Check if this is a Test CE (arithmetic expression)
37        // Test CE fields look like "test(User.Age % 3 == 0)"
38        if self.field.starts_with("test(") && self.field.ends_with(')') {
39            // Extract the expression: "test(expr)" -> "expr"
40            let expr = &self.field[5..self.field.len() - 1];
41            
42            // Try to evaluate as arithmetic expression
43            if let Some(result) = self.evaluate_arithmetic_rete(expr, facts) {
44                // Compare result with expected value
45                let expected_value = self.parse_value_string(&self.value);
46                return match (&result, &expected_value) {
47                    (FactValue::Boolean(r), FactValue::Boolean(e)) => r == e,
48                    _ => false,
49                };
50            }
51            return false;
52        }
53        
54        // Check if the value is a variable reference (field name in facts)
55        // This enables variable-to-variable comparison like "L1 > L1Min"
56        let expected_value = if let Some(var_value) = facts.get(&self.value) {
57            // Value is a field name - use the field's value for comparison
58            var_value.clone()
59        } else {
60            // Value is a literal - parse it
61            self.parse_value_string(&self.value)
62        };
63
64        facts.evaluate_condition(&self.field, &self.operator, &expected_value)
65    }
66
67    /// Parse value string into FactValue
68    fn parse_value_string(&self, s: &str) -> FactValue {
69        // Try to parse as different types
70        if let Ok(i) = s.parse::<i64>() {
71            FactValue::Integer(i)
72        } else if let Ok(f) = s.parse::<f64>() {
73            FactValue::Float(f)
74        } else if let Ok(b) = s.parse::<bool>() {
75            FactValue::Boolean(b)
76        } else if s == "null" {
77            FactValue::Null
78        } else {
79            FactValue::String(s.to_string())
80        }
81    }
82
83    /// Evaluate arithmetic expression for RETE
84    /// Handles expressions like "User.Age % 3 == 0", "Price * 2 > 100"
85    fn evaluate_arithmetic_rete(&self, expr: &str, facts: &TypedFacts) -> Option<FactValue> {
86        // Split by comparison operators
87        let ops = ["==", "!=", ">=", "<=", ">", "<"];
88        for op in &ops {
89            if let Some(pos) = expr.find(op) {
90                let left = expr[..pos].trim();
91                let right = expr[pos + op.len()..].trim();
92                
93                // Evaluate left side (arithmetic expression)
94                let left_val = self.evaluate_arithmetic_expr(left, facts)?;
95                
96                // Evaluate right side
97                let right_val = if let Some(val) = facts.get(right) {
98                    val.clone()
99                } else if let Ok(i) = right.parse::<i64>() {
100                    FactValue::Integer(i)
101                } else if let Ok(f) = right.parse::<f64>() {
102                    FactValue::Float(f)
103                } else {
104                    return None;
105                };
106                
107                // Compare values
108                let result = left_val.compare(op, &right_val);
109                return Some(FactValue::Boolean(result));
110            }
111        }
112        None
113    }
114
115    /// Evaluate arithmetic expression (handles +, -, *, /, %)
116    fn evaluate_arithmetic_expr(&self, expr: &str, facts: &TypedFacts) -> Option<FactValue> {
117        let expr = expr.trim();
118        
119        // Try arithmetic operators in order of precedence (reverse)
120        let ops = ["+", "-", "*", "/", "%"];
121        
122        for op in &ops {
123            if let Some(pos) = expr.rfind(op) {
124                // Skip if operator is at the start (negative number)
125                if pos == 0 {
126                    continue;
127                }
128                
129                let left = expr[..pos].trim();
130                let right = expr[pos + 1..].trim();
131                
132                let left_val = if let Some(val) = facts.get(left) {
133                    val.as_number()?
134                } else if let Ok(f) = left.parse::<f64>() {
135                    f
136                } else {
137                    // Recursive evaluation
138                    self.evaluate_arithmetic_expr(left, facts)?.as_number()?
139                };
140                
141                let right_val = if let Some(val) = facts.get(right) {
142                    val.as_number()?
143                } else if let Ok(f) = right.parse::<f64>() {
144                    f
145                } else {
146                    self.evaluate_arithmetic_expr(right, facts)?.as_number()?
147                };
148                
149                let result = match *op {
150                    "+" => left_val + right_val,
151                    "-" => left_val - right_val,
152                    "*" => left_val * right_val,
153                    "/" => if right_val != 0.0 { left_val / right_val } else { return None; },
154                    "%" => left_val % right_val,
155                    _ => return None,
156                };
157                
158                // Return Integer if result is whole number, otherwise Float
159                if result.fract() == 0.0 {
160                    return Some(FactValue::Integer(result as i64));
161                } else {
162                    return Some(FactValue::Float(result));
163                }
164            }
165        }
166        
167        // Base case: just a field reference or literal
168        if let Some(val) = facts.get(expr) {
169            Some(val.clone())
170        } else if let Ok(i) = expr.parse::<i64>() {
171            Some(FactValue::Integer(i))
172        } else if let Ok(f) = expr.parse::<f64>() {
173            Some(FactValue::Float(f))
174        } else {
175            None
176        }
177    }
178
179    /// Create with typed value
180    pub fn with_typed_value(field: String, operator: String, value: FactValue) -> Self {
181        Self {
182            field,
183            operator,
184            value: value.as_string(),
185        }
186    }
187}
188
/// Parse `s` as an `f64`, falling back to `0.0` when it is not a valid number.
fn parse_num(s: &str) -> f64 {
    match s.parse::<f64>() {
        Ok(n) => n,
        Err(_) => 0.0,
    }
}
192
193/// Simple wildcard pattern matching (for backward compatibility)
194fn wildcard_match(text: &str, pattern: &str) -> bool {
195    let text_chars: Vec<char> = text.chars().collect();
196    let pattern_chars: Vec<char> = pattern.chars().collect();
197    wildcard_match_impl(&text_chars, &pattern_chars, 0, 0)
198}
199
/// Return whether `text[ti..]` matches `pattern[pi..]`. In the pattern, `*`
/// matches any (possibly empty) run of characters and `?` matches exactly one
/// character.
///
/// This is an iterative two-pointer matcher. On a mismatch it backtracks only
/// to the most recent `*`, letting that star absorb one more character. The
/// worst case is O(text.len() * pattern.len()), whereas naive recursion over
/// every split point is exponential for patterns like `"a*a*a*b"`.
/// Out-of-range start indices never match and never panic.
fn wildcard_match_impl(text: &[char], pattern: &[char], ti: usize, pi: usize) -> bool {
    if ti > text.len() || pi > pattern.len() {
        return false;
    }

    let (mut t, mut p) = (ti, pi);
    // Pattern index of the last `*` seen, and the text index it currently extends to.
    let mut backtrack: Option<(usize, usize)> = None;

    while t < text.len() {
        match pattern.get(p).copied() {
            Some('*') => {
                backtrack = Some((p, t));
                p += 1;
            }
            Some(c) if c == '?' || c == text[t] => {
                p += 1;
                t += 1;
            }
            _ => match backtrack {
                Some((star_p, star_t)) => {
                    // Let the last star swallow one more character and retry.
                    p = star_p + 1;
                    t = star_t + 1;
                    backtrack = Some((star_p, star_t + 1));
                }
                None => return false,
            },
        }
    }

    // Text exhausted: only trailing stars may remain in the pattern.
    pattern[p..].iter().all(|&c| c == '*')
}