rust_rule_engine/rete/
alpha.rs

1//! AlphaNode: checks single condition on a fact
2
3use super::facts::{FactValue, TypedFacts};
4
/// A RETE alpha node: a single `field <operator> value` test against facts.
///
/// All three parts are kept as strings; `AlphaNode::matches` and
/// `AlphaNode::matches_typed` decide how they are interpreted.
#[derive(Clone, Debug)]
pub struct AlphaNode {
    /// Fact field to test, or a Test CE of the form `test(<expr>)`.
    pub field: String,
    /// Comparison operator, e.g. `==`, `>=`, `contains`, `matches`.
    pub operator: String,
    /// Right-hand side: a literal, or (for typed matching) possibly the name of another fact.
    pub value: String,
}
11
12impl AlphaNode {
13    /// Match with string-based facts (backward compatible)
14    pub fn matches(&self, fact_field: &str, fact_value: &str) -> bool {
15        if self.field != fact_field {
16            return false;
17        }
18        match self.operator.as_str() {
19            "==" => fact_value == self.value,
20            "!=" => fact_value != self.value,
21            ">" => parse_num(fact_value) > parse_num(&self.value),
22            "<" => parse_num(fact_value) < parse_num(&self.value),
23            ">=" => parse_num(fact_value) >= parse_num(&self.value),
24            "<=" => parse_num(fact_value) <= parse_num(&self.value),
25            "contains" => fact_value.contains(&self.value),
26            "startsWith" => fact_value.starts_with(&self.value),
27            "endsWith" => fact_value.ends_with(&self.value),
28            "matches" => wildcard_match(fact_value, &self.value),
29            _ => false,
30        }
31    }
32
33    /// Match with typed facts (new!)
34    pub fn matches_typed(&self, facts: &TypedFacts) -> bool {
35        // Check if this is a Test CE (arithmetic expression)
36        // Test CE fields look like "test(User.Age % 3 == 0)"
37        if self.field.starts_with("test(") && self.field.ends_with(')') {
38            // Extract the expression: "test(expr)" -> "expr"
39            let expr = &self.field[5..self.field.len() - 1];
40
41            // Try to evaluate as arithmetic expression
42            if let Some(result) = self.evaluate_arithmetic_rete(expr, facts) {
43                // Compare result with expected value
44                let expected_value = self.parse_value_string(&self.value);
45                return match (&result, &expected_value) {
46                    (FactValue::Boolean(r), FactValue::Boolean(e)) => r == e,
47                    _ => false,
48                };
49            }
50            return false;
51        }
52
53        // Check if the value is a variable reference (field name in facts)
54        // This enables variable-to-variable comparison like "L1 > L1Min"
55        let expected_value = if let Some(var_value) = facts.get(&self.value) {
56            // Value is a field name - use the field's value for comparison
57            var_value.clone()
58        } else {
59            // Value is a literal - parse it
60            self.parse_value_string(&self.value)
61        };
62
63        facts.evaluate_condition(&self.field, &self.operator, &expected_value)
64    }
65
66    /// Parse value string into FactValue
67    fn parse_value_string(&self, s: &str) -> FactValue {
68        // Try to parse as different types
69        if let Ok(i) = s.parse::<i64>() {
70            FactValue::Integer(i)
71        } else if let Ok(f) = s.parse::<f64>() {
72            FactValue::Float(f)
73        } else if let Ok(b) = s.parse::<bool>() {
74            FactValue::Boolean(b)
75        } else if s == "null" {
76            FactValue::Null
77        } else {
78            FactValue::String(s.to_string())
79        }
80    }
81
82    /// Evaluate arithmetic expression for RETE
83    /// Handles expressions like "User.Age % 3 == 0", "Price * 2 > 100"
84    fn evaluate_arithmetic_rete(&self, expr: &str, facts: &TypedFacts) -> Option<FactValue> {
85        // Split by comparison operators
86        let ops = ["==", "!=", ">=", "<=", ">", "<"];
87        for op in &ops {
88            if let Some(pos) = expr.find(op) {
89                let left = expr[..pos].trim();
90                let right = expr[pos + op.len()..].trim();
91
92                // Evaluate left side (arithmetic expression)
93                let left_val = Self::evaluate_arithmetic_expr(left, facts)?;
94
95                // Evaluate right side
96                let right_val = if let Some(val) = facts.get(right) {
97                    val.clone()
98                } else if let Ok(i) = right.parse::<i64>() {
99                    FactValue::Integer(i)
100                } else if let Ok(f) = right.parse::<f64>() {
101                    FactValue::Float(f)
102                } else {
103                    return None;
104                };
105
106                // Compare values
107                let result = left_val.compare(op, &right_val);
108                return Some(FactValue::Boolean(result));
109            }
110        }
111        None
112    }
113
114    /// Evaluate arithmetic expression (handles +, -, *, /, %)
115    fn evaluate_arithmetic_expr(expr: &str, facts: &TypedFacts) -> Option<FactValue> {
116        let expr = expr.trim();
117
118        // Try arithmetic operators in order of precedence (reverse)
119        let ops = ["+", "-", "*", "/", "%"];
120
121        for op in &ops {
122            if let Some(pos) = expr.rfind(op) {
123                // Skip if operator is at the start (negative number)
124                if pos == 0 {
125                    continue;
126                }
127
128                let left = expr[..pos].trim();
129                let right = expr[pos + 1..].trim();
130
131                let left_val = if let Some(val) = facts.get(left) {
132                    val.as_number()?
133                } else if let Ok(f) = left.parse::<f64>() {
134                    f
135                } else {
136                    // Recursive evaluation
137                    Self::evaluate_arithmetic_expr(left, facts)?.as_number()?
138                };
139
140                let right_val = if let Some(val) = facts.get(right) {
141                    val.as_number()?
142                } else if let Ok(f) = right.parse::<f64>() {
143                    f
144                } else {
145                    Self::evaluate_arithmetic_expr(right, facts)?.as_number()?
146                };
147
148                let result = match *op {
149                    "+" => left_val + right_val,
150                    "-" => left_val - right_val,
151                    "*" => left_val * right_val,
152                    "/" => {
153                        if right_val != 0.0 {
154                            left_val / right_val
155                        } else {
156                            return None;
157                        }
158                    }
159                    "%" => left_val % right_val,
160                    _ => return None,
161                };
162
163                // Return Integer if result is whole number, otherwise Float
164                if result.fract() == 0.0 {
165                    return Some(FactValue::Integer(result as i64));
166                } else {
167                    return Some(FactValue::Float(result));
168                }
169            }
170        }
171
172        // Base case: just a field reference or literal
173        if let Some(val) = facts.get(expr) {
174            Some(val.clone())
175        } else if let Ok(i) = expr.parse::<i64>() {
176            Some(FactValue::Integer(i))
177        } else if let Ok(f) = expr.parse::<f64>() {
178            Some(FactValue::Float(f))
179        } else {
180            None
181        }
182    }
183
184    /// Create with typed value
185    pub fn with_typed_value(field: String, operator: String, value: FactValue) -> Self {
186        Self {
187            field,
188            operator,
189            value: value.as_string(),
190        }
191    }
192}
193
/// Parse `s` as an `f64` for the string-based ordering operators.
///
/// Anything that does not parse, including input with surrounding
/// whitespace, silently becomes `0.0`.
fn parse_num(s: &str) -> f64 {
    match s.parse::<f64>() {
        Ok(number) => number,
        Err(_) => 0.0,
    }
}
197
198/// Simple wildcard pattern matching (for backward compatibility)
199fn wildcard_match(text: &str, pattern: &str) -> bool {
200    let text_chars: Vec<char> = text.chars().collect();
201    let pattern_chars: Vec<char> = pattern.chars().collect();
202    wildcard_match_impl(&text_chars, &pattern_chars, 0, 0)
203}
204
/// Match `text[ti..]` against `pattern[pi..]`.
///
/// `*` matches any run of characters (possibly empty), and `?` matches exactly
/// one character. Both remainders must be consumed completely.
///
/// Uses iterative greedy matching with single-point backtracking: on a
/// mismatch, only the most recent `*` is extended by one character. This is
/// sufficient for `*`/`?` globs and runs in O(text.len() * pattern.len()) in
/// the worst case. Recursively retrying every split point for every `*` would
/// be exponential on patterns like `*a*a*a*b`.
///
/// Expects `ti <= text.len()`; a larger `ti` never matches.
fn wildcard_match_impl(text: &[char], pattern: &[char], ti: usize, pi: usize) -> bool {
    if ti > text.len() {
        return false;
    }
    let (mut t, mut p) = (ti, pi);
    // Tracks two indices for the most recent `*`: its pattern index, and the
    // text index where the characters it absorbs currently end.
    let mut backtrack: Option<(usize, usize)> = None;

    while t < text.len() {
        match pattern.get(p) {
            Some(&'*') => {
                // Start by letting `*` match nothing.
                backtrack = Some((p, t));
                p += 1;
            }
            Some(&c) if c == '?' || c == text[t] => {
                t += 1;
                p += 1;
            }
            _ => match backtrack {
                Some((star_p, star_t)) => {
                    // Mismatch: let the last `*` absorb one more character.
                    backtrack = Some((star_p, star_t + 1));
                    p = star_p + 1;
                    t = star_t + 1;
                }
                None => return false,
            },
        }
    }

    // All text consumed: only `*`s may remain in the pattern.
    matches!(pattern.get(p..), Some(rest) if rest.iter().all(|&c| c == '*'))
}