rust_rule_engine/engine/
condition_evaluator.rs

1//! Shared condition evaluation logic for both forward and backward chaining
2//!
3//! This module provides a unified interface for evaluating rule conditions
4//! that can be used by both RustRuleEngine (forward chaining) and
5//! BackwardEngine (backward chaining).
6
7#![allow(deprecated)]
8
9use crate::engine::rule::{Condition, ConditionExpression, ConditionGroup};
10use crate::errors::{Result, RuleEngineError};
11use crate::types::{Operator, Value};
12use crate::Facts;
13use std::collections::HashMap;
14
15/// Type for custom function implementations
16pub type CustomFunction = Box<dyn Fn(&[Value], &Facts) -> Result<Value> + Send + Sync>;
17
18/// Shared condition evaluator that works for both forward and backward chaining
19pub struct ConditionEvaluator {
20    /// Custom functions registered by user (optional - for forward chaining)
21    custom_functions: Option<HashMap<String, CustomFunction>>,
22
23    /// Whether to use built-in hardcoded functions (for backward chaining)
24    use_builtin_functions: bool,
25}
26
27impl ConditionEvaluator {
28    /// Create new evaluator with custom functions (for forward chaining)
29    pub fn with_custom_functions(custom_functions: HashMap<String, CustomFunction>) -> Self {
30        Self {
31            custom_functions: Some(custom_functions),
32            use_builtin_functions: false,
33        }
34    }
35
36    /// Create new evaluator with built-in functions (for backward chaining)
37    pub fn with_builtin_functions() -> Self {
38        Self {
39            custom_functions: None,
40            use_builtin_functions: true,
41        }
42    }
43
44    /// Evaluate condition group
45    pub fn evaluate_conditions(&self, group: &ConditionGroup, facts: &Facts) -> Result<bool> {
46        match group {
47            ConditionGroup::Single(condition) => self.evaluate_condition(condition, facts),
48
49            ConditionGroup::Compound {
50                left,
51                operator,
52                right,
53            } => {
54                let left_result = self.evaluate_conditions(left, facts)?;
55
56                // Short-circuit evaluation
57                match operator {
58                    crate::types::LogicalOperator::And => {
59                        if !left_result {
60                            return Ok(false);
61                        }
62                        self.evaluate_conditions(right, facts)
63                    }
64                    crate::types::LogicalOperator::Or => {
65                        if left_result {
66                            return Ok(true);
67                        }
68                        self.evaluate_conditions(right, facts)
69                    }
70                    crate::types::LogicalOperator::Not => Err(RuleEngineError::ExecutionError(
71                        "NOT operator should not appear in compound conditions".to_string(),
72                    )),
73                }
74            }
75
76            ConditionGroup::Not(inner) => {
77                let result = self.evaluate_conditions(inner, facts)?;
78                Ok(!result)
79            }
80
81            ConditionGroup::Exists(conditions) => {
82                // Simplified exists for backward chaining
83                self.evaluate_conditions(conditions, facts)
84            }
85
86            ConditionGroup::Forall(conditions) => {
87                // Simplified forall for backward chaining
88                self.evaluate_conditions(conditions, facts)
89            }
90
91            ConditionGroup::Accumulate { .. } => {
92                // Accumulate needs special handling - not fully supported yet
93                Ok(true)
94            }
95        }
96    }
97
98    /// Evaluate a single condition
99    pub fn evaluate_condition(&self, condition: &Condition, facts: &Facts) -> Result<bool> {
100        match &condition.expression {
101            ConditionExpression::Field(field_name) => {
102                // Get field value
103                if let Some(value) = facts
104                    .get_nested(field_name)
105                    .or_else(|| facts.get(field_name))
106                {
107                    Ok(condition.operator.evaluate(&value, &condition.value))
108                } else {
109                    // Field not found
110                    // For some operators like NotEqual, this might be true
111                    match condition.operator {
112                        Operator::NotEqual => {
113                            // null != value is true
114                            Ok(true)
115                        }
116                        _ => Ok(false),
117                    }
118                }
119            }
120
121            ConditionExpression::FunctionCall { name, args } => {
122                self.evaluate_function_call(name, args, condition, facts)
123            }
124
125            ConditionExpression::Test { name, args } => {
126                self.evaluate_test_expression(name, args, facts)
127            }
128
129            ConditionExpression::MultiField {
130                field,
131                operation,
132                variable,
133            } => self.evaluate_multifield(field, operation, variable, condition, facts),
134        }
135    }
136
137    /// Evaluate function call
138    fn evaluate_function_call(
139        &self,
140        function_name: &str,
141        args: &[String],
142        condition: &Condition,
143        facts: &Facts,
144    ) -> Result<bool> {
145        // Try custom functions first (if available)
146        if let Some(custom_fns) = &self.custom_functions {
147            if let Some(function) = custom_fns.get(function_name) {
148                // Resolve arguments from facts
149                let arg_values: Vec<Value> = args
150                    .iter()
151                    .map(|arg| {
152                        facts
153                            .get_nested(arg)
154                            .or_else(|| facts.get(arg))
155                            .unwrap_or_else(|| {
156                                self.parse_literal_value(arg)
157                                    .unwrap_or(Value::String(arg.clone()))
158                            })
159                    })
160                    .collect();
161
162                // Call the function
163                match function(&arg_values, facts) {
164                    Ok(result_value) => {
165                        return Ok(condition.operator.evaluate(&result_value, &condition.value));
166                    }
167                    Err(_) => return Ok(false),
168                }
169            }
170        }
171
172        // Fall back to built-in functions if enabled
173        if self.use_builtin_functions {
174            return self.evaluate_builtin_function(function_name, args, condition, facts);
175        }
176
177        // Function not found
178        Ok(false)
179    }
180
181    /// Evaluate built-in functions (hardcoded for backward chaining)
182    fn evaluate_builtin_function(
183        &self,
184        function_name: &str,
185        args: &[String],
186        condition: &Condition,
187        facts: &Facts,
188    ) -> Result<bool> {
189        // Get function arguments
190        let mut arg_values = Vec::new();
191        for arg in args {
192            if let Some(value) = facts.get(arg).or_else(|| facts.get_nested(arg)) {
193                arg_values.push(value);
194            } else {
195                // Try to parse as literal
196                if let Ok(val) = self.parse_literal_value(arg) {
197                    arg_values.push(val);
198                } else {
199                    // Argument not available - cannot evaluate
200                    return Ok(false);
201                }
202            }
203        }
204
205        match function_name {
206            "len" | "length" | "size" => {
207                if arg_values.len() == 1 {
208                    let len = match &arg_values[0] {
209                        Value::String(s) => s.len() as f64,
210                        Value::Array(arr) => arr.len() as f64,
211                        _ => return Ok(false),
212                    };
213
214                    Ok(condition
215                        .operator
216                        .evaluate(&Value::Number(len), &condition.value))
217                } else {
218                    Ok(false)
219                }
220            }
221
222            "isEmpty" | "is_empty" => {
223                if arg_values.len() == 1 {
224                    let is_empty = match &arg_values[0] {
225                        Value::String(s) => s.is_empty(),
226                        Value::Array(arr) => arr.is_empty(),
227                        Value::Null => true,
228                        _ => false,
229                    };
230
231                    Ok(condition
232                        .operator
233                        .evaluate(&Value::Boolean(is_empty), &condition.value))
234                } else {
235                    Ok(false)
236                }
237            }
238
239            "contains" => {
240                if arg_values.len() == 2 {
241                    let contains = match (&arg_values[0], &arg_values[1]) {
242                        (Value::String(s), Value::String(substr)) => s.contains(substr.as_str()),
243                        (Value::Array(arr), val) => arr.contains(val),
244                        _ => false,
245                    };
246
247                    Ok(condition
248                        .operator
249                        .evaluate(&Value::Boolean(contains), &condition.value))
250                } else {
251                    Ok(false)
252                }
253            }
254
255            _ => {
256                // Unknown function
257                Ok(false)
258            }
259        }
260    }
261
262    /// Evaluate test expression
263    fn evaluate_test_expression(
264        &self,
265        function_name: &str,
266        args: &[String],
267        facts: &Facts,
268    ) -> Result<bool> {
269        // Try custom functions first
270        if let Some(custom_fns) = &self.custom_functions {
271            if let Some(function) = custom_fns.get(function_name) {
272                let arg_values: Vec<Value> = args
273                    .iter()
274                    .map(|arg| {
275                        facts
276                            .get_nested(arg)
277                            .or_else(|| facts.get(arg))
278                            .unwrap_or(Value::String(arg.clone()))
279                    })
280                    .collect();
281
282                match function(&arg_values, facts) {
283                    Ok(result_value) => return Ok(result_value.to_bool()),
284                    Err(_) => return Ok(false),
285                }
286            }
287        }
288
289        // Built-in test expressions
290        if self.use_builtin_functions {
291            return self.evaluate_builtin_test(function_name, args, facts);
292        }
293
294        Ok(false)
295    }
296
297    /// Evaluate built-in test expressions
298    fn evaluate_builtin_test(
299        &self,
300        function_name: &str,
301        args: &[String],
302        facts: &Facts,
303    ) -> Result<bool> {
304        match function_name {
305            "exists" => {
306                // Check if field exists
307                if args.len() == 1 {
308                    Ok(facts.get(&args[0]).is_some() || facts.get_nested(&args[0]).is_some())
309                } else {
310                    Ok(false)
311                }
312            }
313
314            "notExists" | "not_exists" => {
315                // Check if field does not exist
316                if args.len() == 1 {
317                    Ok(facts.get(&args[0]).is_none() && facts.get_nested(&args[0]).is_none())
318                } else {
319                    Ok(false)
320                }
321            }
322
323            _ => {
324                // Unknown test function
325                Ok(false)
326            }
327        }
328    }
329
330    /// Evaluate multi-field operation
331    fn evaluate_multifield(
332        &self,
333        field: &str,
334        operation: &str,
335        _variable: &Option<String>,
336        condition: &Condition,
337        facts: &Facts,
338    ) -> Result<bool> {
339        // Get field value
340        let field_value = facts.get(field).or_else(|| facts.get_nested(field));
341
342        match operation {
343            "collect" => {
344                // Collect all values - just check if field exists
345                Ok(field_value.is_some())
346            }
347
348            "count" => {
349                // Count elements
350                let count = if let Some(value) = field_value {
351                    match value {
352                        Value::Array(arr) => arr.len() as f64,
353                        _ => 1.0,
354                    }
355                } else {
356                    0.0
357                };
358
359                Ok(condition
360                    .operator
361                    .evaluate(&Value::Number(count), &condition.value))
362            }
363
364            "first" => {
365                // Get first element
366                if let Some(Value::Array(arr)) = field_value {
367                    Ok(!arr.is_empty())
368                } else {
369                    Ok(false)
370                }
371            }
372
373            "last" => {
374                // Get last element
375                if let Some(Value::Array(arr)) = field_value {
376                    Ok(!arr.is_empty())
377                } else {
378                    Ok(false)
379                }
380            }
381
382            "empty" | "isEmpty" => {
383                // Check if empty
384                let is_empty = if let Some(value) = field_value {
385                    match value {
386                        Value::Array(arr) => arr.is_empty(),
387                        Value::String(s) => s.is_empty(),
388                        Value::Null => true,
389                        _ => false,
390                    }
391                } else {
392                    true
393                };
394
395                Ok(is_empty)
396            }
397
398            "not_empty" | "notEmpty" => {
399                // Check if not empty
400                let is_not_empty = if let Some(value) = field_value {
401                    match value {
402                        Value::Array(arr) => !arr.is_empty(),
403                        Value::String(s) => !s.is_empty(),
404                        Value::Null => false,
405                        _ => true,
406                    }
407                } else {
408                    false
409                };
410
411                Ok(is_not_empty)
412            }
413
414            "contains" => {
415                // Check if array contains value
416                if let Some(Value::Array(arr)) = field_value {
417                    Ok(arr.contains(&condition.value))
418                } else {
419                    Ok(false)
420                }
421            }
422
423            _ => {
424                // Unknown operation
425                Ok(false)
426            }
427        }
428    }
429
430    /// Parse literal value from string
431    fn parse_literal_value(&self, s: &str) -> Result<Value> {
432        // Try boolean
433        if s == "true" {
434            return Ok(Value::Boolean(true));
435        }
436        if s == "false" {
437            return Ok(Value::Boolean(false));
438        }
439        if s == "null" {
440            return Ok(Value::Null);
441        }
442
443        // Try number
444        if let Ok(n) = s.parse::<f64>() {
445            return Ok(Value::Number(n));
446        }
447
448        // Try integer
449        if let Ok(i) = s.parse::<i64>() {
450            return Ok(Value::Integer(i));
451        }
452
453        // String
454        Ok(Value::String(s.to_string()))
455    }
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builtin_function_len() {
        let facts = Facts::new();
        facts.set("User.Name", Value::String("John".to_string()));

        // "John" has length 4, which is greater than 3.
        let len_gt_three = Condition::with_function(
            "len".to_string(),
            vec!["User.Name".to_string()],
            Operator::GreaterThan,
            Value::Number(3.0),
        );

        let evaluator = ConditionEvaluator::with_builtin_functions();
        assert!(evaluator.evaluate_condition(&len_gt_three, &facts).unwrap());
    }

    #[test]
    fn test_builtin_test_exists() {
        let facts = Facts::new();
        facts.set("User.Email", Value::String("test@example.com".to_string()));
        let evaluator = ConditionEvaluator::with_builtin_functions();

        let exists = |path: &str| {
            evaluator
                .evaluate_builtin_test("exists", &[path.to_string()], &facts)
                .unwrap()
        };

        assert!(exists("User.Email"));
        assert!(!exists("User.Missing"));
    }

    #[test]
    fn test_multifield_count() {
        let facts = Facts::new();
        let orders: Vec<Value> = [1.0, 2.0, 3.0].iter().map(|&n| Value::Number(n)).collect();
        facts.set("User.Orders", Value::Array(orders));

        let count_is_three = Condition {
            field: "User.Orders".to_string(),
            expression: ConditionExpression::MultiField {
                field: "User.Orders".to_string(),
                operation: "count".to_string(),
                variable: None,
            },
            operator: Operator::Equal,
            value: Value::Number(3.0),
        };

        let evaluator = ConditionEvaluator::with_builtin_functions();
        assert!(evaluator.evaluate_condition(&count_is_three, &facts).unwrap());
    }
}
523}