rust_rule_engine/engine/
condition_evaluator.rs

//! Shared condition evaluation logic for both forward and backward chaining.
//!
//! This module provides a unified interface for evaluating rule conditions
//! that can be used by both `RustRuleEngine` (forward chaining) and
//! `BackwardEngine` (backward chaining).
6
7use crate::engine::rule::{Condition, ConditionExpression, ConditionGroup};
8use crate::errors::{Result, RuleEngineError};
9use crate::types::{Operator, Value};
10use crate::Facts;
11use std::collections::HashMap;
12
13/// Type for custom function implementations
14pub type CustomFunction = Box<dyn Fn(&[Value], &Facts) -> Result<Value> + Send + Sync>;
15
16/// Shared condition evaluator that works for both forward and backward chaining
17pub struct ConditionEvaluator {
18    /// Custom functions registered by user (optional - for forward chaining)
19    custom_functions: Option<HashMap<String, CustomFunction>>,
20
21    /// Whether to use built-in hardcoded functions (for backward chaining)
22    use_builtin_functions: bool,
23}
24
25impl ConditionEvaluator {
26    /// Create new evaluator with custom functions (for forward chaining)
27    pub fn with_custom_functions(custom_functions: HashMap<String, CustomFunction>) -> Self {
28        Self {
29            custom_functions: Some(custom_functions),
30            use_builtin_functions: false,
31        }
32    }
33
34    /// Create new evaluator with built-in functions (for backward chaining)
35    pub fn with_builtin_functions() -> Self {
36        Self {
37            custom_functions: None,
38            use_builtin_functions: true,
39        }
40    }
41
42    /// Evaluate condition group
43    pub fn evaluate_conditions(
44        &self,
45        group: &ConditionGroup,
46        facts: &Facts,
47    ) -> Result<bool> {
48        match group {
49            ConditionGroup::Single(condition) => {
50                self.evaluate_condition(condition, facts)
51            }
52
53            ConditionGroup::Compound { left, operator, right } => {
54                let left_result = self.evaluate_conditions(left, facts)?;
55
56                // Short-circuit evaluation
57                match operator {
58                    crate::types::LogicalOperator::And => {
59                        if !left_result {
60                            return Ok(false);
61                        }
62                        self.evaluate_conditions(right, facts)
63                    }
64                    crate::types::LogicalOperator::Or => {
65                        if left_result {
66                            return Ok(true);
67                        }
68                        self.evaluate_conditions(right, facts)
69                    }
70                    crate::types::LogicalOperator::Not => {
71                        Err(RuleEngineError::ExecutionError(
72                            "NOT operator should not appear in compound conditions".to_string()
73                        ))
74                    }
75                }
76            }
77
78            ConditionGroup::Not(inner) => {
79                let result = self.evaluate_conditions(inner, facts)?;
80                Ok(!result)
81            }
82
83            ConditionGroup::Exists(conditions) => {
84                // Simplified exists for backward chaining
85                self.evaluate_conditions(conditions, facts)
86            }
87
88            ConditionGroup::Forall(conditions) => {
89                // Simplified forall for backward chaining
90                self.evaluate_conditions(conditions, facts)
91            }
92
93            ConditionGroup::Accumulate { .. } => {
94                // Accumulate needs special handling - not fully supported yet
95                Ok(true)
96            }
97        }
98    }
99
100    /// Evaluate a single condition
101    pub fn evaluate_condition(&self, condition: &Condition, facts: &Facts) -> Result<bool> {
102        match &condition.expression {
103            ConditionExpression::Field(field_name) => {
104                // Get field value
105                if let Some(value) = facts.get_nested(field_name).or_else(|| facts.get(field_name)) {
106                    Ok(condition.operator.evaluate(&value, &condition.value))
107                } else {
108                    // Field not found
109                    // For some operators like NotEqual, this might be true
110                    match condition.operator {
111                        Operator::NotEqual => {
112                            // null != value is true
113                            Ok(true)
114                        }
115                        _ => Ok(false),
116                    }
117                }
118            }
119
120            ConditionExpression::FunctionCall { name, args } => {
121                self.evaluate_function_call(name, args, condition, facts)
122            }
123
124            ConditionExpression::Test { name, args } => {
125                self.evaluate_test_expression(name, args, facts)
126            }
127
128            ConditionExpression::MultiField { field, operation, variable } => {
129                self.evaluate_multifield(field, operation, variable, condition, facts)
130            }
131        }
132    }
133
134    /// Evaluate function call
135    fn evaluate_function_call(
136        &self,
137        function_name: &str,
138        args: &[String],
139        condition: &Condition,
140        facts: &Facts,
141    ) -> Result<bool> {
142        // Try custom functions first (if available)
143        if let Some(custom_fns) = &self.custom_functions {
144            if let Some(function) = custom_fns.get(function_name) {
145                // Resolve arguments from facts
146                let arg_values: Vec<Value> = args
147                    .iter()
148                    .map(|arg| {
149                        facts
150                            .get_nested(arg)
151                            .or_else(|| facts.get(arg))
152                            .unwrap_or_else(|| self.parse_literal_value(arg).unwrap_or(Value::String(arg.clone())))
153                    })
154                    .collect();
155
156                // Call the function
157                match function(&arg_values, facts) {
158                    Ok(result_value) => {
159                        return Ok(condition.operator.evaluate(&result_value, &condition.value));
160                    }
161                    Err(_) => return Ok(false),
162                }
163            }
164        }
165
166        // Fall back to built-in functions if enabled
167        if self.use_builtin_functions {
168            return self.evaluate_builtin_function(function_name, args, condition, facts);
169        }
170
171        // Function not found
172        Ok(false)
173    }
174
175    /// Evaluate built-in functions (hardcoded for backward chaining)
176    fn evaluate_builtin_function(
177        &self,
178        function_name: &str,
179        args: &[String],
180        condition: &Condition,
181        facts: &Facts,
182    ) -> Result<bool> {
183        // Get function arguments
184        let mut arg_values = Vec::new();
185        for arg in args {
186            if let Some(value) = facts.get(arg).or_else(|| facts.get_nested(arg)) {
187                arg_values.push(value);
188            } else {
189                // Try to parse as literal
190                if let Ok(val) = self.parse_literal_value(arg) {
191                    arg_values.push(val);
192                } else {
193                    // Argument not available - cannot evaluate
194                    return Ok(false);
195                }
196            }
197        }
198
199        match function_name {
200            "len" | "length" | "size" => {
201                if arg_values.len() == 1 {
202                    let len = match &arg_values[0] {
203                        Value::String(s) => s.len() as f64,
204                        Value::Array(arr) => arr.len() as f64,
205                        _ => return Ok(false),
206                    };
207
208                    Ok(condition.operator.evaluate(&Value::Number(len), &condition.value))
209                } else {
210                    Ok(false)
211                }
212            }
213
214            "isEmpty" | "is_empty" => {
215                if arg_values.len() == 1 {
216                    let is_empty = match &arg_values[0] {
217                        Value::String(s) => s.is_empty(),
218                        Value::Array(arr) => arr.is_empty(),
219                        Value::Null => true,
220                        _ => false,
221                    };
222
223                    Ok(condition.operator.evaluate(&Value::Boolean(is_empty), &condition.value))
224                } else {
225                    Ok(false)
226                }
227            }
228
229            "contains" => {
230                if arg_values.len() == 2 {
231                    let contains = match (&arg_values[0], &arg_values[1]) {
232                        (Value::String(s), Value::String(substr)) => s.contains(substr.as_str()),
233                        (Value::Array(arr), val) => arr.contains(val),
234                        _ => false,
235                    };
236
237                    Ok(condition.operator.evaluate(&Value::Boolean(contains), &condition.value))
238                } else {
239                    Ok(false)
240                }
241            }
242
243            _ => {
244                // Unknown function
245                Ok(false)
246            }
247        }
248    }
249
250    /// Evaluate test expression
251    fn evaluate_test_expression(
252        &self,
253        function_name: &str,
254        args: &[String],
255        facts: &Facts,
256    ) -> Result<bool> {
257        // Try custom functions first
258        if let Some(custom_fns) = &self.custom_functions {
259            if let Some(function) = custom_fns.get(function_name) {
260                let arg_values: Vec<Value> = args
261                    .iter()
262                    .map(|arg| {
263                        facts
264                            .get_nested(arg)
265                            .or_else(|| facts.get(arg))
266                            .unwrap_or(Value::String(arg.clone()))
267                    })
268                    .collect();
269
270                match function(&arg_values, facts) {
271                    Ok(result_value) => return Ok(result_value.to_bool()),
272                    Err(_) => return Ok(false),
273                }
274            }
275        }
276
277        // Built-in test expressions
278        if self.use_builtin_functions {
279            return self.evaluate_builtin_test(function_name, args, facts);
280        }
281
282        Ok(false)
283    }
284
285    /// Evaluate built-in test expressions
286    fn evaluate_builtin_test(
287        &self,
288        function_name: &str,
289        args: &[String],
290        facts: &Facts,
291    ) -> Result<bool> {
292        match function_name {
293            "exists" => {
294                // Check if field exists
295                if args.len() == 1 {
296                    Ok(facts.get(&args[0]).is_some() || facts.get_nested(&args[0]).is_some())
297                } else {
298                    Ok(false)
299                }
300            }
301
302            "notExists" | "not_exists" => {
303                // Check if field does not exist
304                if args.len() == 1 {
305                    Ok(facts.get(&args[0]).is_none() && facts.get_nested(&args[0]).is_none())
306                } else {
307                    Ok(false)
308                }
309            }
310
311            _ => {
312                // Unknown test function
313                Ok(false)
314            }
315        }
316    }
317
318    /// Evaluate multi-field operation
319    fn evaluate_multifield(
320        &self,
321        field: &str,
322        operation: &str,
323        _variable: &Option<String>,
324        condition: &Condition,
325        facts: &Facts,
326    ) -> Result<bool> {
327        // Get field value
328        let field_value = facts.get(field).or_else(|| facts.get_nested(field));
329
330        match operation {
331            "collect" => {
332                // Collect all values - just check if field exists
333                Ok(field_value.is_some())
334            }
335
336            "count" => {
337                // Count elements
338                let count = if let Some(value) = field_value {
339                    match value {
340                        Value::Array(arr) => arr.len() as f64,
341                        _ => 1.0,
342                    }
343                } else {
344                    0.0
345                };
346
347                Ok(condition.operator.evaluate(&Value::Number(count), &condition.value))
348            }
349
350            "first" => {
351                // Get first element
352                if let Some(Value::Array(arr)) = field_value {
353                    Ok(!arr.is_empty())
354                } else {
355                    Ok(false)
356                }
357            }
358
359            "last" => {
360                // Get last element
361                if let Some(Value::Array(arr)) = field_value {
362                    Ok(!arr.is_empty())
363                } else {
364                    Ok(false)
365                }
366            }
367
368            "empty" | "isEmpty" => {
369                // Check if empty
370                let is_empty = if let Some(value) = field_value {
371                    match value {
372                        Value::Array(arr) => arr.is_empty(),
373                        Value::String(s) => s.is_empty(),
374                        Value::Null => true,
375                        _ => false,
376                    }
377                } else {
378                    true
379                };
380
381                Ok(is_empty)
382            }
383
384            "not_empty" | "notEmpty" => {
385                // Check if not empty
386                let is_not_empty = if let Some(value) = field_value {
387                    match value {
388                        Value::Array(arr) => !arr.is_empty(),
389                        Value::String(s) => !s.is_empty(),
390                        Value::Null => false,
391                        _ => true,
392                    }
393                } else {
394                    false
395                };
396
397                Ok(is_not_empty)
398            }
399
400            "contains" => {
401                // Check if array contains value
402                if let Some(Value::Array(arr)) = field_value {
403                    Ok(arr.contains(&condition.value))
404                } else {
405                    Ok(false)
406                }
407            }
408
409            _ => {
410                // Unknown operation
411                Ok(false)
412            }
413        }
414    }
415
416    /// Parse literal value from string
417    fn parse_literal_value(&self, s: &str) -> Result<Value> {
418        // Try boolean
419        if s == "true" {
420            return Ok(Value::Boolean(true));
421        }
422        if s == "false" {
423            return Ok(Value::Boolean(false));
424        }
425        if s == "null" {
426            return Ok(Value::Null);
427        }
428
429        // Try number
430        if let Ok(n) = s.parse::<f64>() {
431            return Ok(Value::Number(n));
432        }
433
434        // Try integer
435        if let Ok(i) = s.parse::<i64>() {
436            return Ok(Value::Integer(i));
437        }
438
439        // String
440        Ok(Value::String(s.to_string()))
441    }
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a condition that applies the multi-field `operation` to `field`.
    fn multifield_condition(
        field: &str,
        operation: &str,
        operator: Operator,
        value: Value,
    ) -> Condition {
        Condition {
            field: field.to_string(),
            expression: ConditionExpression::MultiField {
                field: field.to_string(),
                operation: operation.to_string(),
                variable: None,
            },
            operator,
            value,
        }
    }

    #[test]
    fn test_builtin_function_len() {
        let evaluator = ConditionEvaluator::with_builtin_functions();
        let mut facts = Facts::new();
        facts.set("User.Name", Value::String("John".to_string()));

        let condition = Condition::with_function(
            "len".to_string(),
            vec!["User.Name".to_string()],
            Operator::GreaterThan,
            Value::Number(3.0),
        );

        let result = evaluator.evaluate_condition(&condition, &facts).unwrap();
        assert!(result); // "John" has length 4 > 3
    }

    #[test]
    fn test_builtin_test_exists() {
        let evaluator = ConditionEvaluator::with_builtin_functions();
        let mut facts = Facts::new();
        facts.set("User.Email", Value::String("test@example.com".to_string()));

        let result = evaluator.evaluate_builtin_test("exists", &["User.Email".to_string()], &facts).unwrap();
        assert!(result);

        let result2 = evaluator.evaluate_builtin_test("exists", &["User.Missing".to_string()], &facts).unwrap();
        assert!(!result2);

        // Wrong arity never matches.
        assert!(!evaluator.evaluate_builtin_test("exists", &[], &facts).unwrap());
    }

    #[test]
    fn test_builtin_test_not_exists() {
        let evaluator = ConditionEvaluator::with_builtin_functions();
        let mut facts = Facts::new();
        facts.set("User.Email", Value::String("test@example.com".to_string()));

        assert!(!evaluator
            .evaluate_builtin_test("notExists", &["User.Email".to_string()], &facts)
            .unwrap());
        assert!(evaluator
            .evaluate_builtin_test("not_exists", &["User.Missing".to_string()], &facts)
            .unwrap());
    }

    #[test]
    fn test_missing_field_only_satisfies_not_equal() {
        let evaluator = ConditionEvaluator::with_builtin_functions();
        let facts = Facts::new();
        let condition_for = |operator: Operator| Condition {
            field: "User.Missing".to_string(),
            expression: ConditionExpression::Field("User.Missing".to_string()),
            operator,
            value: Value::Number(1.0),
        };

        assert!(evaluator.evaluate_condition(&condition_for(Operator::NotEqual), &facts).unwrap());
        assert!(!evaluator.evaluate_condition(&condition_for(Operator::Equal), &facts).unwrap());
    }

    #[test]
    fn test_multifield_count() {
        let evaluator = ConditionEvaluator::with_builtin_functions();
        let mut facts = Facts::new();
        facts.set("User.Orders", Value::Array(vec![
            Value::Number(1.0),
            Value::Number(2.0),
            Value::Number(3.0),
        ]));

        let condition = multifield_condition("User.Orders", "count", Operator::Equal, Value::Number(3.0));

        let result = evaluator.evaluate_condition(&condition, &facts).unwrap();
        assert!(result);
    }

    #[test]
    fn test_multifield_empty_and_not_empty() {
        let evaluator = ConditionEvaluator::with_builtin_functions();
        let mut facts = Facts::new();
        facts.set("User.Tags", Value::Array(vec![]));

        // `empty` / `not_empty` ignore the operator and value.
        let empty = multifield_condition("User.Tags", "empty", Operator::Equal, Value::Boolean(true));
        let not_empty = multifield_condition("User.Tags", "not_empty", Operator::Equal, Value::Boolean(true));
        let missing = multifield_condition("User.Missing", "isEmpty", Operator::Equal, Value::Boolean(true));

        assert!(evaluator.evaluate_condition(&empty, &facts).unwrap());
        assert!(!evaluator.evaluate_condition(&not_empty, &facts).unwrap());
        assert!(evaluator.evaluate_condition(&missing, &facts).unwrap());
    }

    #[test]
    fn test_unknown_function_is_false() {
        let evaluator = ConditionEvaluator::with_builtin_functions();
        let mut facts = Facts::new();
        facts.set("User.Name", Value::String("John".to_string()));

        let condition = Condition::with_function(
            "noSuchFunction".to_string(),
            vec!["User.Name".to_string()],
            Operator::Equal,
            Value::Boolean(true),
        );

        assert!(!evaluator.evaluate_condition(&condition, &facts).unwrap());
    }

    #[test]
    fn test_custom_functions_do_not_fall_back_to_builtins() {
        let evaluator = ConditionEvaluator::with_custom_functions(HashMap::new());
        let mut facts = Facts::new();
        facts.set("User.Name", Value::String("John".to_string()));

        let condition = Condition::with_function(
            "len".to_string(),
            vec!["User.Name".to_string()],
            Operator::GreaterThan,
            Value::Number(3.0),
        );

        assert!(!evaluator.evaluate_condition(&condition, &facts).unwrap());
    }

    #[test]
    fn test_custom_function_resolves_facts_and_literals() {
        let mut functions: HashMap<String, CustomFunction> = HashMap::new();
        functions.insert(
            "double".to_string(),
            Box::new(|args: &[Value], _facts: &Facts| -> Result<Value> {
                match args {
                    [Value::Number(n)] => Ok(Value::Number(n * 2.0)),
                    _ => Ok(Value::Null),
                }
            }),
        );
        let evaluator = ConditionEvaluator::with_custom_functions(functions);

        let mut facts = Facts::new();
        facts.set("User.Age", Value::Number(30.0));

        // Argument resolved from facts.
        let from_fact = Condition::with_function(
            "double".to_string(),
            vec!["User.Age".to_string()],
            Operator::Equal,
            Value::Number(60.0),
        );
        assert!(evaluator.evaluate_condition(&from_fact, &facts).unwrap());

        // Argument parsed as a numeric literal.
        let from_literal = Condition::with_function(
            "double".to_string(),
            vec!["21".to_string()],
            Operator::Equal,
            Value::Number(42.0),
        );
        assert!(evaluator.evaluate_condition(&from_literal, &facts).unwrap());
    }
}