Skip to main content

rust_rule_engine/rete/
alpha.rs

//! AlphaNode: checks a single condition on a fact
2
3use super::facts::{FactValue, TypedFacts};
4
/// A single-condition test node: one field, one operator, one comparison value.
///
/// All three parts are stored as plain strings; the matching methods interpret
/// them at evaluation time.
#[derive(Clone, Debug)]
pub struct AlphaNode {
    /// Name of the fact field this node tests.
    pub field: String,
    /// Operator symbol, e.g. `"=="`, `">="` or `"contains"`.
    pub operator: String,
    /// Right-hand side of the condition, kept in its textual form.
    pub value: String,
}
11
12impl AlphaNode {
13    /// Match with string-based facts (backward compatible)
14    pub fn matches(&self, fact_field: &str, fact_value: &str) -> bool {
15        if self.field != fact_field {
16            return false;
17        }
18        match self.operator.as_str() {
19            "==" => fact_value == self.value,
20            "!=" => fact_value != self.value,
21            ">" => parse_num(fact_value) > parse_num(&self.value),
22            "<" => parse_num(fact_value) < parse_num(&self.value),
23            ">=" => parse_num(fact_value) >= parse_num(&self.value),
24            "<=" => parse_num(fact_value) <= parse_num(&self.value),
25            "contains" => fact_value.contains(&self.value),
26            "startsWith" => fact_value.starts_with(&self.value),
27            "endsWith" => fact_value.ends_with(&self.value),
28            "matches" => wildcard_match(fact_value, &self.value),
29            _ => false,
30        }
31    }
32
33    /// Match with typed facts (new!)
34    pub fn matches_typed(&self, facts: &TypedFacts) -> bool {
35        // Check if this is a Test CE (arithmetic expression)
36        // Test CE fields look like "test(User.Age % 3 == 0)"
37        if self.field.starts_with("test(") && self.field.ends_with(')') {
38            // Extract the expression: "test(expr)" -> "expr"
39            let expr = &self.field[5..self.field.len() - 1];
40
41            // Try to evaluate as arithmetic expression
42            if let Some(result) = self.evaluate_arithmetic_rete(expr, facts) {
43                // Compare result with expected value
44                let expected_value = self.parse_value_string(&self.value);
45                return match (&result, &expected_value) {
46                    (FactValue::Boolean(r), FactValue::Boolean(e)) => r == e,
47                    _ => false,
48                };
49            }
50            return false;
51        }
52
53        // Check if the value is a variable reference (field name in facts)
54        // This enables variable-to-variable comparison like "L1 > L1Min"
55        let expected_value = if let Some(var_value) = facts.get(&self.value) {
56            // Value is a field name - use the field's value for comparison
57            var_value.clone()
58        } else {
59            // Value is a literal - parse it
60            self.parse_value_string(&self.value)
61        };
62
63        facts.evaluate_condition(&self.field, &self.operator, &expected_value)
64    }
65
66    /// Parse value string into FactValue
67    fn parse_value_string(&self, s: &str) -> FactValue {
68        // Check for array literal: [value1,value2,value3]
69        if s.starts_with('[') && s.ends_with(']') {
70            let inner = &s[1..s.len() - 1];
71            if inner.is_empty() {
72                return FactValue::Array(vec![]);
73            }
74
75            let elements: Vec<FactValue> = inner
76                .split(',')
77                .map(|elem| {
78                    let trimmed = elem.trim();
79                    // Parse each element recursively
80                    if let Ok(i) = trimmed.parse::<i64>() {
81                        FactValue::Integer(i)
82                    } else if let Ok(f) = trimmed.parse::<f64>() {
83                        FactValue::Float(f)
84                    } else if let Ok(b) = trimmed.parse::<bool>() {
85                        FactValue::Boolean(b)
86                    } else if trimmed == "null" {
87                        FactValue::Null
88                    } else {
89                        FactValue::String(trimmed.to_string())
90                    }
91                })
92                .collect();
93            return FactValue::Array(elements);
94        }
95
96        // Try to parse as different types
97        if let Ok(i) = s.parse::<i64>() {
98            FactValue::Integer(i)
99        } else if let Ok(f) = s.parse::<f64>() {
100            FactValue::Float(f)
101        } else if let Ok(b) = s.parse::<bool>() {
102            FactValue::Boolean(b)
103        } else if s == "null" {
104            FactValue::Null
105        } else {
106            FactValue::String(s.to_string())
107        }
108    }
109
110    /// Evaluate arithmetic expression for RETE
111    /// Handles expressions like "User.Age % 3 == 0", "Price * 2 > 100"
112    fn evaluate_arithmetic_rete(&self, expr: &str, facts: &TypedFacts) -> Option<FactValue> {
113        // Split by comparison operators
114        let ops = ["==", "!=", ">=", "<=", ">", "<"];
115        for op in &ops {
116            if let Some(pos) = expr.find(op) {
117                let left = expr[..pos].trim();
118                let right = expr[pos + op.len()..].trim();
119
120                // Evaluate left side (arithmetic expression)
121                let left_val = Self::evaluate_arithmetic_expr(left, facts)?;
122
123                // Evaluate right side
124                let right_val = if let Some(val) = facts.get(right) {
125                    val.clone()
126                } else if let Ok(i) = right.parse::<i64>() {
127                    FactValue::Integer(i)
128                } else if let Ok(f) = right.parse::<f64>() {
129                    FactValue::Float(f)
130                } else {
131                    return None;
132                };
133
134                // Compare values
135                let result = left_val.compare(op, &right_val);
136                return Some(FactValue::Boolean(result));
137            }
138        }
139        None
140    }
141
142    /// Evaluate arithmetic expression (handles +, -, *, /, %)
143    fn evaluate_arithmetic_expr(expr: &str, facts: &TypedFacts) -> Option<FactValue> {
144        let expr = expr.trim();
145
146        // Try arithmetic operators in order of precedence (reverse)
147        let ops = ["+", "-", "*", "/", "%"];
148
149        for op in &ops {
150            if let Some(pos) = expr.rfind(op) {
151                // Skip if operator is at the start (negative number)
152                if pos == 0 {
153                    continue;
154                }
155
156                let left = expr[..pos].trim();
157                let right = expr[pos + 1..].trim();
158
159                let left_val = if let Some(val) = facts.get(left) {
160                    val.as_number()?
161                } else if let Ok(f) = left.parse::<f64>() {
162                    f
163                } else {
164                    // Recursive evaluation
165                    Self::evaluate_arithmetic_expr(left, facts)?.as_number()?
166                };
167
168                let right_val = if let Some(val) = facts.get(right) {
169                    val.as_number()?
170                } else if let Ok(f) = right.parse::<f64>() {
171                    f
172                } else {
173                    Self::evaluate_arithmetic_expr(right, facts)?.as_number()?
174                };
175
176                let result = match *op {
177                    "+" => left_val + right_val,
178                    "-" => left_val - right_val,
179                    "*" => left_val * right_val,
180                    "/" => {
181                        if right_val != 0.0 {
182                            left_val / right_val
183                        } else {
184                            return None;
185                        }
186                    }
187                    "%" => left_val % right_val,
188                    _ => return None,
189                };
190
191                // Return Integer if result is whole number, otherwise Float
192                if result.fract() == 0.0 {
193                    return Some(FactValue::Integer(result as i64));
194                } else {
195                    return Some(FactValue::Float(result));
196                }
197            }
198        }
199
200        // Base case: just a field reference or literal
201        if let Some(val) = facts.get(expr) {
202            Some(val.clone())
203        } else if let Ok(i) = expr.parse::<i64>() {
204            Some(FactValue::Integer(i))
205        } else if let Ok(f) = expr.parse::<f64>() {
206            Some(FactValue::Float(f))
207        } else {
208            None
209        }
210    }
211
212    /// Create with typed value
213    pub fn with_typed_value(field: String, operator: String, value: FactValue) -> Self {
214        Self {
215            field,
216            operator,
217            value: value.as_string(),
218        }
219    }
220}
221
/// Parses `s` as an `f64`, substituting `0.0` when it is not a valid float
/// literal.
fn parse_num(s: &str) -> f64 {
    s.parse().unwrap_or_default()
}
225
/// Simple wildcard pattern matching (for backward compatibility).
///
/// `*` matches any run of characters (including none), `?` matches exactly
/// one character, and every other pattern character must match literally.
/// The pattern has to cover the whole of `text`. Matching is done per `char`,
/// so multi-byte characters count as one.
fn wildcard_match(text: &str, pattern: &str) -> bool {
    let text_chars: Vec<char> = text.chars().collect();
    let pattern_chars: Vec<char> = pattern.chars().collect();
    wildcard_match_impl(&text_chars, &pattern_chars, 0, 0)
}

/// Matches `text[ti..]` against `pattern[pi..]`.
///
/// Uses an iterative greedy scan instead of recursive backtracking: only the
/// most recent `*` is ever revisited, so the work is bounded by
/// `text.len() * pattern.len()` rather than growing exponentially with the
/// number of `*` in the pattern. Out-of-range start indices yield `false`.
fn wildcard_match_impl(text: &[char], pattern: &[char], ti: usize, pi: usize) -> bool {
    if ti > text.len() || pi > pattern.len() {
        return false;
    }

    let (mut t, mut p) = (ti, pi);
    // (pattern index right after the latest `*`, text index where that `*`
    // currently stops absorbing characters)
    let mut resume: Option<(usize, usize)> = None;

    while t < text.len() {
        match pattern.get(p).copied() {
            Some('*') => {
                // Start by letting the star match nothing.
                resume = Some((p + 1, t));
                p += 1;
            }
            Some(c) if c == '?' || c == text[t] => {
                t += 1;
                p += 1;
            }
            _ => match resume {
                // Mismatch: let the latest star swallow one more character
                // and retry from just after it.
                Some((after_star, absorbed_until)) => {
                    let next = absorbed_until + 1;
                    resume = Some((after_star, next));
                    p = after_star;
                    t = next;
                }
                None => return false,
            },
        }
    }

    // Text is exhausted; only trailing stars may remain in the pattern.
    pattern[p..].iter().all(|&c| c == '*')
}