// rust_rule_engine/engine/condition_evaluator.rs

//! Shared condition evaluation logic for both forward and backward chaining.
//!
//! This module provides a unified interface for evaluating rule conditions
//! that can be used by both `RustRuleEngine` (forward chaining) and
//! `BackwardEngine` (backward chaining).
6
7#![allow(deprecated)]
8
9use crate::engine::rule::{Condition, ConditionExpression, ConditionGroup};
10use crate::errors::{Result, RuleEngineError};
11use crate::types::{Operator, Value};
12use crate::Facts;
13use std::collections::HashMap;
14
15/// Type for custom function implementations
16pub type CustomFunction = Box<dyn Fn(&[Value], &Facts) -> Result<Value> + Send + Sync>;
17
18/// Shared condition evaluator that works for both forward and backward chaining
19pub struct ConditionEvaluator {
20    /// Custom functions registered by user (optional - for forward chaining)
21    custom_functions: Option<HashMap<String, CustomFunction>>,
22
23    /// Whether to use built-in hardcoded functions (for backward chaining)
24    use_builtin_functions: bool,
25}
26
27impl ConditionEvaluator {
28    /// Create new evaluator with custom functions (for forward chaining)
29    pub fn with_custom_functions(custom_functions: HashMap<String, CustomFunction>) -> Self {
30        Self {
31            custom_functions: Some(custom_functions),
32            use_builtin_functions: false,
33        }
34    }
35
36    /// Create new evaluator with built-in functions (for backward chaining)
37    pub fn with_builtin_functions() -> Self {
38        Self {
39            custom_functions: None,
40            use_builtin_functions: true,
41        }
42    }
43
44    /// Evaluate condition group
45    pub fn evaluate_conditions(&self, group: &ConditionGroup, facts: &Facts) -> Result<bool> {
46        match group {
47            ConditionGroup::Single(condition) => self.evaluate_condition(condition, facts),
48
49            ConditionGroup::Compound {
50                left,
51                operator,
52                right,
53            } => {
54                let left_result = self.evaluate_conditions(left, facts)?;
55
56                // Short-circuit evaluation
57                match operator {
58                    crate::types::LogicalOperator::And => {
59                        if !left_result {
60                            return Ok(false);
61                        }
62                        self.evaluate_conditions(right, facts)
63                    }
64                    crate::types::LogicalOperator::Or => {
65                        if left_result {
66                            return Ok(true);
67                        }
68                        self.evaluate_conditions(right, facts)
69                    }
70                    crate::types::LogicalOperator::Not => Err(RuleEngineError::ExecutionError(
71                        "NOT operator should not appear in compound conditions".to_string(),
72                    )),
73                }
74            }
75
76            ConditionGroup::Not(inner) => {
77                let result = self.evaluate_conditions(inner, facts)?;
78                Ok(!result)
79            }
80
81            ConditionGroup::Exists(conditions) => {
82                // Simplified exists for backward chaining
83                self.evaluate_conditions(conditions, facts)
84            }
85
86            ConditionGroup::Forall(conditions) => {
87                // Simplified forall for backward chaining
88                self.evaluate_conditions(conditions, facts)
89            }
90
91            ConditionGroup::Accumulate { .. } => {
92                // Accumulate needs special handling - not fully supported yet
93                Ok(true)
94            }
95
96            #[cfg(feature = "streaming")]
97            ConditionGroup::StreamPattern { .. } => {
98                // Stream patterns need special handling in streaming engine
99                // Not fully supported in backward chaining context
100                Ok(true)
101            }
102        }
103    }
104
105    /// Evaluate a single condition
106    pub fn evaluate_condition(&self, condition: &Condition, facts: &Facts) -> Result<bool> {
107        match &condition.expression {
108            ConditionExpression::Field(field_name) => {
109                // Get field value
110                if let Some(value) = facts
111                    .get_nested(field_name)
112                    .or_else(|| facts.get(field_name))
113                {
114                    Ok(condition.operator.evaluate(&value, &condition.value))
115                } else {
116                    // Field not found
117                    // For some operators like NotEqual, this might be true
118                    match condition.operator {
119                        Operator::NotEqual => {
120                            // null != value is true
121                            Ok(true)
122                        }
123                        _ => Ok(false),
124                    }
125                }
126            }
127
128            ConditionExpression::FunctionCall { name, args } => {
129                self.evaluate_function_call(name, args, condition, facts)
130            }
131
132            ConditionExpression::Test { name, args } => {
133                self.evaluate_test_expression(name, args, facts)
134            }
135
136            ConditionExpression::MultiField {
137                field,
138                operation,
139                variable,
140            } => self.evaluate_multifield(field, operation, variable, condition, facts),
141        }
142    }
143
144    /// Evaluate function call
145    fn evaluate_function_call(
146        &self,
147        function_name: &str,
148        args: &[String],
149        condition: &Condition,
150        facts: &Facts,
151    ) -> Result<bool> {
152        // Try custom functions first (if available)
153        if let Some(custom_fns) = &self.custom_functions {
154            if let Some(function) = custom_fns.get(function_name) {
155                // Resolve arguments from facts
156                let arg_values: Vec<Value> = args
157                    .iter()
158                    .map(|arg| {
159                        facts
160                            .get_nested(arg)
161                            .or_else(|| facts.get(arg))
162                            .unwrap_or_else(|| {
163                                self.parse_literal_value(arg)
164                                    .unwrap_or(Value::String(arg.clone()))
165                            })
166                    })
167                    .collect();
168
169                // Call the function
170                match function(&arg_values, facts) {
171                    Ok(result_value) => {
172                        return Ok(condition.operator.evaluate(&result_value, &condition.value));
173                    }
174                    Err(_) => return Ok(false),
175                }
176            }
177        }
178
179        // Fall back to built-in functions if enabled
180        if self.use_builtin_functions {
181            return self.evaluate_builtin_function(function_name, args, condition, facts);
182        }
183
184        // Function not found
185        Ok(false)
186    }
187
188    /// Evaluate built-in functions (hardcoded for backward chaining)
189    fn evaluate_builtin_function(
190        &self,
191        function_name: &str,
192        args: &[String],
193        condition: &Condition,
194        facts: &Facts,
195    ) -> Result<bool> {
196        // Get function arguments
197        let mut arg_values = Vec::new();
198        for arg in args {
199            if let Some(value) = facts.get(arg).or_else(|| facts.get_nested(arg)) {
200                arg_values.push(value);
201            } else {
202                // Try to parse as literal
203                if let Ok(val) = self.parse_literal_value(arg) {
204                    arg_values.push(val);
205                } else {
206                    // Argument not available - cannot evaluate
207                    return Ok(false);
208                }
209            }
210        }
211
212        match function_name {
213            "len" | "length" | "size" => {
214                if arg_values.len() == 1 {
215                    let len = match &arg_values[0] {
216                        Value::String(s) => s.len() as f64,
217                        Value::Array(arr) => arr.len() as f64,
218                        _ => return Ok(false),
219                    };
220
221                    Ok(condition
222                        .operator
223                        .evaluate(&Value::Number(len), &condition.value))
224                } else {
225                    Ok(false)
226                }
227            }
228
229            "isEmpty" | "is_empty" => {
230                if arg_values.len() == 1 {
231                    let is_empty = match &arg_values[0] {
232                        Value::String(s) => s.is_empty(),
233                        Value::Array(arr) => arr.is_empty(),
234                        Value::Null => true,
235                        _ => false,
236                    };
237
238                    Ok(condition
239                        .operator
240                        .evaluate(&Value::Boolean(is_empty), &condition.value))
241                } else {
242                    Ok(false)
243                }
244            }
245
246            "contains" => {
247                if arg_values.len() == 2 {
248                    let contains = match (&arg_values[0], &arg_values[1]) {
249                        (Value::String(s), Value::String(substr)) => s.contains(substr.as_str()),
250                        (Value::Array(arr), val) => arr.contains(val),
251                        _ => false,
252                    };
253
254                    Ok(condition
255                        .operator
256                        .evaluate(&Value::Boolean(contains), &condition.value))
257                } else {
258                    Ok(false)
259                }
260            }
261
262            _ => {
263                // Unknown function
264                Ok(false)
265            }
266        }
267    }
268
269    /// Evaluate test expression
270    fn evaluate_test_expression(
271        &self,
272        function_name: &str,
273        args: &[String],
274        facts: &Facts,
275    ) -> Result<bool> {
276        // Try custom functions first
277        if let Some(custom_fns) = &self.custom_functions {
278            if let Some(function) = custom_fns.get(function_name) {
279                let arg_values: Vec<Value> = args
280                    .iter()
281                    .map(|arg| {
282                        facts
283                            .get_nested(arg)
284                            .or_else(|| facts.get(arg))
285                            .unwrap_or(Value::String(arg.clone()))
286                    })
287                    .collect();
288
289                match function(&arg_values, facts) {
290                    Ok(result_value) => return Ok(result_value.to_bool()),
291                    Err(_) => return Ok(false),
292                }
293            }
294        }
295
296        // Built-in test expressions
297        if self.use_builtin_functions {
298            return self.evaluate_builtin_test(function_name, args, facts);
299        }
300
301        Ok(false)
302    }
303
304    /// Evaluate built-in test expressions
305    fn evaluate_builtin_test(
306        &self,
307        function_name: &str,
308        args: &[String],
309        facts: &Facts,
310    ) -> Result<bool> {
311        match function_name {
312            "exists" => {
313                // Check if field exists
314                if args.len() == 1 {
315                    Ok(facts.get(&args[0]).is_some() || facts.get_nested(&args[0]).is_some())
316                } else {
317                    Ok(false)
318                }
319            }
320
321            "notExists" | "not_exists" => {
322                // Check if field does not exist
323                if args.len() == 1 {
324                    Ok(facts.get(&args[0]).is_none() && facts.get_nested(&args[0]).is_none())
325                } else {
326                    Ok(false)
327                }
328            }
329
330            _ => {
331                // Unknown test function
332                Ok(false)
333            }
334        }
335    }
336
337    /// Evaluate multi-field operation
338    fn evaluate_multifield(
339        &self,
340        field: &str,
341        operation: &str,
342        _variable: &Option<String>,
343        condition: &Condition,
344        facts: &Facts,
345    ) -> Result<bool> {
346        // Get field value
347        let field_value = facts.get(field).or_else(|| facts.get_nested(field));
348
349        match operation {
350            "collect" => {
351                // Collect all values - just check if field exists
352                Ok(field_value.is_some())
353            }
354
355            "count" => {
356                // Count elements
357                let count = if let Some(value) = field_value {
358                    match value {
359                        Value::Array(arr) => arr.len() as f64,
360                        _ => 1.0,
361                    }
362                } else {
363                    0.0
364                };
365
366                Ok(condition
367                    .operator
368                    .evaluate(&Value::Number(count), &condition.value))
369            }
370
371            "first" => {
372                // Get first element
373                if let Some(Value::Array(arr)) = field_value {
374                    Ok(!arr.is_empty())
375                } else {
376                    Ok(false)
377                }
378            }
379
380            "last" => {
381                // Get last element
382                if let Some(Value::Array(arr)) = field_value {
383                    Ok(!arr.is_empty())
384                } else {
385                    Ok(false)
386                }
387            }
388
389            "empty" | "isEmpty" => {
390                // Check if empty
391                let is_empty = if let Some(value) = field_value {
392                    match value {
393                        Value::Array(arr) => arr.is_empty(),
394                        Value::String(s) => s.is_empty(),
395                        Value::Null => true,
396                        _ => false,
397                    }
398                } else {
399                    true
400                };
401
402                Ok(is_empty)
403            }
404
405            "not_empty" | "notEmpty" => {
406                // Check if not empty
407                let is_not_empty = if let Some(value) = field_value {
408                    match value {
409                        Value::Array(arr) => !arr.is_empty(),
410                        Value::String(s) => !s.is_empty(),
411                        Value::Null => false,
412                        _ => true,
413                    }
414                } else {
415                    false
416                };
417
418                Ok(is_not_empty)
419            }
420
421            "contains" => {
422                // Check if array contains value
423                if let Some(Value::Array(arr)) = field_value {
424                    Ok(arr.contains(&condition.value))
425                } else {
426                    Ok(false)
427                }
428            }
429
430            _ => {
431                // Unknown operation
432                Ok(false)
433            }
434        }
435    }
436
437    /// Parse literal value from string
438    fn parse_literal_value(&self, s: &str) -> Result<Value> {
439        // Try boolean
440        if s == "true" {
441            return Ok(Value::Boolean(true));
442        }
443        if s == "false" {
444            return Ok(Value::Boolean(false));
445        }
446        if s == "null" {
447            return Ok(Value::Null);
448        }
449
450        // Try number
451        if let Ok(n) = s.parse::<f64>() {
452            return Ok(Value::Number(n));
453        }
454
455        // Try integer
456        if let Ok(i) = s.parse::<i64>() {
457            return Ok(Value::Integer(i));
458        }
459
460        // String
461        Ok(Value::String(s.to_string()))
462    }
463}
464
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a fact set holding a single `key = value` entry.
    fn facts_with(key: &str, value: Value) -> Facts {
        let facts = Facts::new();
        facts.set(key, value);
        facts
    }

    #[test]
    fn test_builtin_function_len() {
        let facts = facts_with("User.Name", Value::String("John".to_string()));
        let name_longer_than_three = Condition::with_function(
            "len".to_string(),
            vec!["User.Name".to_string()],
            Operator::GreaterThan,
            Value::Number(3.0),
        );

        // "John" has length 4, which is greater than 3.
        let matched = ConditionEvaluator::with_builtin_functions()
            .evaluate_condition(&name_longer_than_three, &facts)
            .unwrap();
        assert!(matched);
    }

    #[test]
    fn test_builtin_test_exists() {
        let evaluator = ConditionEvaluator::with_builtin_functions();
        let facts = facts_with(
            "User.Email",
            Value::String("test@example.com".to_string()),
        );

        let exists = |path: &str| {
            evaluator
                .evaluate_builtin_test("exists", &[path.to_string()], &facts)
                .unwrap()
        };

        assert!(exists("User.Email"));
        assert!(!exists("User.Missing"));
    }

    #[test]
    fn test_multifield_count() {
        let orders: Vec<Value> = [1.0_f64, 2.0, 3.0]
            .iter()
            .copied()
            .map(Value::Number)
            .collect();
        let facts = facts_with("User.Orders", Value::Array(orders));

        let three_orders = Condition {
            field: "User.Orders".to_string(),
            expression: ConditionExpression::MultiField {
                field: "User.Orders".to_string(),
                operation: "count".to_string(),
                variable: None,
            },
            operator: Operator::Equal,
            value: Value::Number(3.0),
        };

        let evaluator = ConditionEvaluator::with_builtin_functions();
        assert!(evaluator.evaluate_condition(&three_orders, &facts).unwrap());
    }
}