rust_rule_engine/backward/
unification.rs

//! Unification module for backward chaining
//!
//! This module provides variable binding and unification capabilities
//! for pattern matching in backward chaining queries.
//!
//! # Example
//! ```ignore
//! use rust_rule_engine::backward::unification::{Bindings, Unifier};
//! use rust_rule_engine::backward::expression::Expression;
//! use rust_rule_engine::types::Value;
//!
//! let mut bindings = Bindings::new();
//!
//! // Unify variable with value; `unify` returns whether unification succeeded
//! let var = Expression::Variable("X".to_string());
//! let val = Expression::Literal(Value::Number(42.0));
//!
//! assert!(Unifier::unify(&var, &val, &mut bindings)?);
//! assert_eq!(bindings.get("X"), Some(&Value::Number(42.0)));
//! ```
20
21use super::expression::Expression;
22use crate::errors::{Result, RuleEngineError};
23use crate::types::Value;
24use crate::Facts;
25use std::collections::HashMap;
26
27/// Variable bindings during proof
28///
29/// This type manages variable-to-value mappings during backward chaining,
30/// with support for merging, conflict detection, and binding propagation.
31#[derive(Debug, Clone)]
32pub struct Bindings {
33    /// Map from variable name to value
34    bindings: HashMap<String, Value>,
35}
36
37impl Bindings {
38    /// Create a new empty bindings set
39    pub fn new() -> Self {
40        Self {
41            bindings: HashMap::new(),
42        }
43    }
44
45    /// Bind a variable to a value
46    ///
47    /// If the variable is already bound, this checks that the new value
48    /// matches the existing binding. Returns an error if there's a conflict.
49    ///
50    /// # Example
51    /// ```ignore
52    /// let mut bindings = Bindings::new();
53    /// bindings.bind("X".to_string(), Value::Number(42.0))?;
54    ///
55    /// // This will succeed (same value)
56    /// bindings.bind("X".to_string(), Value::Number(42.0))?;
57    ///
58    /// // This will fail (different value)
59    /// bindings.bind("X".to_string(), Value::Number(100.0))?; // Error!
60    /// ```
61    pub fn bind(&mut self, var_name: String, value: Value) -> Result<()> {
62        // Check if already bound
63        if let Some(existing) = self.bindings.get(&var_name) {
64            // Must match existing binding
65            if existing != &value {
66                return Err(RuleEngineError::ExecutionError(format!(
67                    "Variable binding conflict: {} is already bound to {:?}, cannot rebind to {:?}",
68                    var_name, existing, value
69                )));
70            }
71        } else {
72            self.bindings.insert(var_name, value);
73        }
74        Ok(())
75    }
76
77    /// Get binding for a variable
78    pub fn get(&self, var_name: &str) -> Option<&Value> {
79        self.bindings.get(var_name)
80    }
81
82    /// Check if variable is bound
83    pub fn is_bound(&self, var_name: &str) -> bool {
84        self.bindings.contains_key(var_name)
85    }
86
87    /// Merge bindings from another set
88    ///
89    /// This attempts to merge all bindings from `other` into this set.
90    /// If any conflicts are detected, returns an error and leaves this set unchanged.
91    pub fn merge(&mut self, other: &Bindings) -> Result<()> {
92        for (var, val) in &other.bindings {
93            self.bind(var.clone(), val.clone())?;
94        }
95        Ok(())
96    }
97
98    /// Get all bindings as a map
99    pub fn as_map(&self) -> &HashMap<String, Value> {
100        &self.bindings
101    }
102
103    /// Get number of bindings
104    pub fn len(&self) -> usize {
105        self.bindings.len()
106    }
107
108    /// Check if bindings is empty
109    pub fn is_empty(&self) -> bool {
110        self.bindings.is_empty()
111    }
112
113    /// Clear all bindings
114    pub fn clear(&mut self) {
115        self.bindings.clear();
116    }
117
118    /// Create bindings from a HashMap
119    pub fn from_map(map: HashMap<String, Value>) -> Self {
120        Self { bindings: map }
121    }
122
123    /// Convert bindings to HashMap (for backward compatibility)
124    pub fn into_map(self) -> HashMap<String, Value> {
125        self.bindings
126    }
127
128    /// Get bindings as HashMap clone
129    pub fn to_map(&self) -> HashMap<String, Value> {
130        self.bindings.clone()
131    }
132}
133
134impl Default for Bindings {
135    fn default() -> Self {
136        Self::new()
137    }
138}
139
140/// Unification algorithm for pattern matching
141///
142/// The Unifier provides algorithms for:
143/// - Unifying two expressions with variable bindings
144/// - Matching expressions against facts
145/// - Evaluating expressions with variable substitution
146pub struct Unifier;
147
148impl Unifier {
149    /// Unify two expressions with variable bindings
150    ///
151    /// This is the core unification algorithm. It attempts to make two expressions
152    /// equal by binding variables to values.
153    ///
154    /// # Returns
155    /// - `Ok(true)` if unification succeeded
156    /// - `Ok(false)` if expressions cannot be unified
157    /// - `Err(_)` if there's a binding conflict
158    pub fn unify(left: &Expression, right: &Expression, bindings: &mut Bindings) -> Result<bool> {
159        match (left, right) {
160            // Variable on left
161            (Expression::Variable(var), expr) => {
162                if let Some(bound_value) = bindings.get(var) {
163                    // Variable already bound - check if it matches
164                    Self::unify(&Expression::Literal(bound_value.clone()), expr, bindings)
165                } else {
166                    // Bind variable to expression value
167                    if let Some(value) = Self::expression_to_value(expr, bindings)? {
168                        bindings.bind(var.clone(), value)?;
169                        Ok(true)
170                    } else {
171                        // Cannot extract value from expression yet
172                        Ok(false)
173                    }
174                }
175            }
176
177            // Variable on right (symmetric)
178            (expr, Expression::Variable(var)) => {
179                Self::unify(&Expression::Variable(var.clone()), expr, bindings)
180            }
181
182            // Two literals - must be equal
183            (Expression::Literal(v1), Expression::Literal(v2)) => Ok(v1 == v2),
184
185            // Two fields - must be same field
186            (Expression::Field(f1), Expression::Field(f2)) => Ok(f1 == f2),
187
188            // Comparison - both sides must unify
189            (
190                Expression::Comparison {
191                    left: l1,
192                    operator: op1,
193                    right: r1,
194                },
195                Expression::Comparison {
196                    left: l2,
197                    operator: op2,
198                    right: r2,
199                },
200            ) => {
201                if op1 != op2 {
202                    return Ok(false);
203                }
204
205                let left_match = Self::unify(l1, l2, bindings)?;
206                let right_match = Self::unify(r1, r2, bindings)?;
207
208                Ok(left_match && right_match)
209            }
210
211            // Logical AND - both sides must unify
212            (
213                Expression::And {
214                    left: l1,
215                    right: r1,
216                },
217                Expression::And {
218                    left: l2,
219                    right: r2,
220                },
221            ) => {
222                let left_match = Self::unify(l1, l2, bindings)?;
223                let right_match = Self::unify(r1, r2, bindings)?;
224                Ok(left_match && right_match)
225            }
226
227            // Logical OR - both sides must unify
228            (
229                Expression::Or {
230                    left: l1,
231                    right: r1,
232                },
233                Expression::Or {
234                    left: l2,
235                    right: r2,
236                },
237            ) => {
238                let left_match = Self::unify(l1, l2, bindings)?;
239                let right_match = Self::unify(r1, r2, bindings)?;
240                Ok(left_match && right_match)
241            }
242
243            // Negation - inner expression must unify
244            (Expression::Not(e1), Expression::Not(e2)) => Self::unify(e1, e2, bindings),
245
246            // Different expression types - cannot unify
247            _ => Ok(false),
248        }
249    }
250
251    /// Match expression against facts and extract bindings
252    ///
253    /// This evaluates an expression against the current facts,
254    /// binding any variables to their matched values.
255    pub fn match_expression(
256        expr: &Expression,
257        facts: &Facts,
258        bindings: &mut Bindings,
259    ) -> Result<bool> {
260        match expr {
261            Expression::Variable(var) => {
262                // Unbound variable - cannot match
263                if !bindings.is_bound(var) {
264                    return Ok(false);
265                }
266                Ok(true)
267            }
268
269            Expression::Field(field_name) => {
270                // Field must exist in facts
271                Ok(facts.get(field_name).is_some())
272            }
273
274            Expression::Literal(_) => {
275                // Literals always match
276                Ok(true)
277            }
278
279            Expression::Comparison {
280                left,
281                operator,
282                right,
283            } => {
284                // Evaluate both sides with bindings
285                let left_val = Self::evaluate_with_bindings(left, facts, bindings)?;
286                let right_val = Self::evaluate_with_bindings(right, facts, bindings)?;
287
288                // Perform comparison
289                let result = match operator {
290                    crate::types::Operator::Equal => left_val == right_val,
291                    crate::types::Operator::NotEqual => left_val != right_val,
292                    crate::types::Operator::GreaterThan => {
293                        Self::compare_values(&left_val, &right_val)? > 0
294                    }
295                    crate::types::Operator::LessThan => {
296                        Self::compare_values(&left_val, &right_val)? < 0
297                    }
298                    crate::types::Operator::GreaterThanOrEqual => {
299                        Self::compare_values(&left_val, &right_val)? >= 0
300                    }
301                    crate::types::Operator::LessThanOrEqual => {
302                        Self::compare_values(&left_val, &right_val)? <= 0
303                    }
304                    _ => {
305                        return Err(RuleEngineError::ExecutionError(format!(
306                            "Unsupported operator: {:?}",
307                            operator
308                        )));
309                    }
310                };
311
312                Ok(result)
313            }
314
315            Expression::And { left, right } => {
316                let left_match = Self::match_expression(left, facts, bindings)?;
317                if !left_match {
318                    return Ok(false);
319                }
320                Self::match_expression(right, facts, bindings)
321            }
322
323            Expression::Or { left, right } => {
324                let left_match = Self::match_expression(left, facts, bindings)?;
325                if left_match {
326                    return Ok(true);
327                }
328                Self::match_expression(right, facts, bindings)
329            }
330
331            Expression::Not(expr) => {
332                let result = Self::match_expression(expr, facts, bindings)?;
333                Ok(!result)
334            }
335        }
336    }
337
338    /// Evaluate expression with variable bindings
339    ///
340    /// This evaluates an expression to a value, substituting any bound variables.
341    pub fn evaluate_with_bindings(
342        expr: &Expression,
343        facts: &Facts,
344        bindings: &Bindings,
345    ) -> Result<Value> {
346        match expr {
347            Expression::Variable(var) => bindings.get(var).cloned().ok_or_else(|| {
348                RuleEngineError::ExecutionError(format!("Unbound variable: {}", var))
349            }),
350
351            Expression::Field(field) => facts.get(field).ok_or_else(|| {
352                RuleEngineError::ExecutionError(format!("Field not found: {}", field))
353            }),
354
355            Expression::Literal(val) => Ok(val.clone()),
356
357            Expression::Comparison {
358                left,
359                operator,
360                right,
361            } => {
362                let left_val = Self::evaluate_with_bindings(left, facts, bindings)?;
363                let right_val = Self::evaluate_with_bindings(right, facts, bindings)?;
364
365                let result = match operator {
366                    crate::types::Operator::Equal => left_val == right_val,
367                    crate::types::Operator::NotEqual => left_val != right_val,
368                    crate::types::Operator::GreaterThan => {
369                        Self::compare_values(&left_val, &right_val)? > 0
370                    }
371                    crate::types::Operator::LessThan => {
372                        Self::compare_values(&left_val, &right_val)? < 0
373                    }
374                    crate::types::Operator::GreaterThanOrEqual => {
375                        Self::compare_values(&left_val, &right_val)? >= 0
376                    }
377                    crate::types::Operator::LessThanOrEqual => {
378                        Self::compare_values(&left_val, &right_val)? <= 0
379                    }
380                    _ => {
381                        return Err(RuleEngineError::ExecutionError(format!(
382                            "Unsupported operator: {:?}",
383                            operator
384                        )));
385                    }
386                };
387
388                Ok(Value::Boolean(result))
389            }
390
391            Expression::And { left, right } => {
392                let left_val = Self::evaluate_with_bindings(left, facts, bindings)?;
393                if !left_val.to_bool() {
394                    return Ok(Value::Boolean(false));
395                }
396                let right_val = Self::evaluate_with_bindings(right, facts, bindings)?;
397                Ok(Value::Boolean(right_val.to_bool()))
398            }
399
400            Expression::Or { left, right } => {
401                let left_val = Self::evaluate_with_bindings(left, facts, bindings)?;
402                if left_val.to_bool() {
403                    return Ok(Value::Boolean(true));
404                }
405                let right_val = Self::evaluate_with_bindings(right, facts, bindings)?;
406                Ok(Value::Boolean(right_val.to_bool()))
407            }
408
409            Expression::Not(expr) => {
410                let value = Self::evaluate_with_bindings(expr, facts, bindings)?;
411                Ok(Value::Boolean(!value.to_bool()))
412            }
413        }
414    }
415
416    /// Extract a value from an expression (if possible)
417    fn expression_to_value(expr: &Expression, bindings: &Bindings) -> Result<Option<Value>> {
418        match expr {
419            Expression::Literal(val) => Ok(Some(val.clone())),
420            Expression::Variable(var) => Ok(bindings.get(var).cloned()),
421            _ => Ok(None), // Cannot extract value from complex expressions
422        }
423    }
424
425    /// Compare two values for ordering
426    fn compare_values(left: &Value, right: &Value) -> Result<i32> {
427        match (left, right) {
428            (Value::Number(a), Value::Number(b)) => {
429                if a < b {
430                    Ok(-1)
431                } else if a > b {
432                    Ok(1)
433                } else {
434                    Ok(0)
435                }
436            }
437            (Value::String(a), Value::String(b)) => Ok(a.cmp(b) as i32),
438            (Value::Boolean(a), Value::Boolean(b)) => Ok(a.cmp(b) as i32),
439            _ => Err(RuleEngineError::ExecutionError(format!(
440                "Cannot compare values: {:?} and {:?}",
441                left, right
442            ))),
443        }
444    }
445}
446
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Operator;

    /// Build a variable expression.
    fn var(name: &str) -> Expression {
        Expression::Variable(name.to_string())
    }

    /// Build a literal expression.
    fn lit(value: Value) -> Expression {
        Expression::Literal(value)
    }

    /// Build a comparison expression without the boxing noise.
    fn comparison(left: Expression, operator: Operator, right: Expression) -> Expression {
        Expression::Comparison {
            left: Box::new(left),
            operator,
            right: Box::new(right),
        }
    }

    #[test]
    fn test_bindings_basic() {
        let mut bindings = Bindings::new();

        assert!(bindings.is_empty());
        assert_eq!(bindings.len(), 0);

        bindings.bind("X".to_string(), Value::Number(42.0)).unwrap();

        assert!(!bindings.is_empty());
        assert_eq!(bindings.len(), 1);
        assert!(bindings.is_bound("X"));
        assert_eq!(bindings.get("X"), Some(&Value::Number(42.0)));
    }

    #[test]
    fn test_bindings_conflict() {
        let mut bindings = Bindings::new();

        bindings.bind("X".to_string(), Value::Number(42.0)).unwrap();

        // Same value - should succeed
        assert!(bindings.bind("X".to_string(), Value::Number(42.0)).is_ok());

        // Different value - should fail, and the original binding survives
        assert!(bindings
            .bind("X".to_string(), Value::Number(100.0))
            .is_err());
        assert_eq!(bindings.get("X"), Some(&Value::Number(42.0)));
    }

    #[test]
    fn test_bindings_merge() {
        let mut bindings1 = Bindings::new();
        let mut bindings2 = Bindings::new();

        bindings1
            .bind("X".to_string(), Value::Number(42.0))
            .unwrap();
        bindings2
            .bind("Y".to_string(), Value::String("hello".to_string()))
            .unwrap();

        bindings1.merge(&bindings2).unwrap();

        assert_eq!(bindings1.len(), 2);
        assert_eq!(bindings1.get("X"), Some(&Value::Number(42.0)));
        assert_eq!(
            bindings1.get("Y"),
            Some(&Value::String("hello".to_string()))
        );
    }

    #[test]
    fn test_bindings_merge_conflict() {
        let mut bindings1 = Bindings::new();
        let mut bindings2 = Bindings::new();

        bindings1
            .bind("X".to_string(), Value::Number(42.0))
            .unwrap();
        bindings2
            .bind("X".to_string(), Value::Number(100.0))
            .unwrap();

        // Should fail due to conflict
        assert!(bindings1.merge(&bindings2).is_err());
    }

    #[test]
    fn test_unify_variable_with_literal() {
        let mut bindings = Bindings::new();

        let result = Unifier::unify(&var("X"), &lit(Value::Number(42.0)), &mut bindings).unwrap();

        assert!(result);
        assert_eq!(bindings.get("X"), Some(&Value::Number(42.0)));
    }

    #[test]
    fn test_unify_variable_on_right() {
        let mut bindings = Bindings::new();

        let result = Unifier::unify(&lit(Value::Number(7.0)), &var("Y"), &mut bindings).unwrap();

        assert!(result);
        assert_eq!(bindings.get("Y"), Some(&Value::Number(7.0)));
    }

    #[test]
    fn test_unify_bound_variable() {
        let mut bindings = Bindings::new();
        bindings.bind("X".to_string(), Value::Number(42.0)).unwrap();

        let var = Expression::Variable("X".to_string());
        let lit = Expression::Literal(Value::Number(42.0));

        // Should succeed - same value
        let result = Unifier::unify(&var, &lit, &mut bindings).unwrap();
        assert!(result);

        // Should fail - different value
        let lit2 = Expression::Literal(Value::Number(100.0));
        let result2 = Unifier::unify(&var, &lit2, &mut bindings);
        assert!(result2.is_err() || !result2.unwrap());
    }

    #[test]
    fn test_unify_two_unbound_variables_does_not_unify() {
        let mut bindings = Bindings::new();

        assert!(!Unifier::unify(&var("X"), &var("Y"), &mut bindings).unwrap());
        assert!(bindings.is_empty());
    }

    #[test]
    fn test_unify_two_literals() {
        let mut bindings = Bindings::new();

        let lit1 = Expression::Literal(Value::Number(42.0));
        let lit2 = Expression::Literal(Value::Number(42.0));
        let lit3 = Expression::Literal(Value::Number(100.0));

        assert!(Unifier::unify(&lit1, &lit2, &mut bindings).unwrap());
        assert!(!Unifier::unify(&lit1, &lit3, &mut bindings).unwrap());
    }

    #[test]
    fn test_unify_comparison_binds_both_sides() {
        let mut bindings = Bindings::new();

        let pattern = comparison(var("X"), Operator::GreaterThan, var("Y"));
        let target = comparison(
            lit(Value::Number(5.0)),
            Operator::GreaterThan,
            lit(Value::Number(2.0)),
        );

        assert!(Unifier::unify(&pattern, &target, &mut bindings).unwrap());
        assert_eq!(bindings.get("X"), Some(&Value::Number(5.0)));
        assert_eq!(bindings.get("Y"), Some(&Value::Number(2.0)));
    }

    #[test]
    fn test_unify_comparison_operator_mismatch() {
        let mut bindings = Bindings::new();

        let pattern = comparison(var("X"), Operator::Equal, lit(Value::Number(1.0)));
        let target = comparison(
            lit(Value::Number(1.0)),
            Operator::NotEqual,
            lit(Value::Number(1.0)),
        );

        assert!(!Unifier::unify(&pattern, &target, &mut bindings).unwrap());
        assert!(bindings.is_empty());
    }

    #[test]
    fn test_match_expression_simple() {
        let facts = Facts::new();
        facts.set("User.IsVIP", Value::Boolean(true));

        let mut bindings = Bindings::new();

        let expr = comparison(
            Expression::Field("User.IsVIP".to_string()),
            Operator::Equal,
            lit(Value::Boolean(true)),
        );

        let result = Unifier::match_expression(&expr, &facts, &mut bindings).unwrap();
        assert!(result);
    }

    #[test]
    fn test_missing_field_and_unbound_variable() {
        let facts = Facts::new();
        let mut bindings = Bindings::new();

        let missing = Expression::Field("Missing.Field".to_string());
        assert!(!Unifier::match_expression(&missing, &facts, &mut bindings).unwrap());
        assert!(!Unifier::match_expression(&var("X"), &facts, &mut bindings).unwrap());

        let expr = comparison(var("X"), Operator::Equal, lit(Value::Number(1.0)));
        assert!(Unifier::evaluate_with_bindings(&expr, &facts, &bindings).is_err());
    }

    #[test]
    fn test_evaluate_with_bindings() {
        let facts = Facts::new();
        facts.set("Order.Amount", Value::Number(100.0));

        let mut bindings = Bindings::new();
        bindings.bind("X".to_string(), Value::Number(50.0)).unwrap();

        // Evaluate variable
        let result = Unifier::evaluate_with_bindings(&var("X"), &facts, &bindings).unwrap();
        assert_eq!(result, Value::Number(50.0));

        // Evaluate field
        let field_expr = Expression::Field("Order.Amount".to_string());
        let result = Unifier::evaluate_with_bindings(&field_expr, &facts, &bindings).unwrap();
        assert_eq!(result, Value::Number(100.0));
    }

    #[test]
    fn test_evaluate_comparison_with_bindings() {
        let facts = Facts::new();
        facts.set("Order.Amount", Value::Number(100.0));

        let mut bindings = Bindings::new();
        bindings.bind("X".to_string(), Value::Number(50.0)).unwrap();

        let expr = comparison(
            var("X"),
            Operator::LessThan,
            Expression::Field("Order.Amount".to_string()),
        );

        assert_eq!(
            Unifier::evaluate_with_bindings(&expr, &facts, &bindings).unwrap(),
            Value::Boolean(true)
        );
        assert!(Unifier::match_expression(&expr, &facts, &mut bindings).unwrap());
    }

    #[test]
    fn test_compare_values() {
        assert_eq!(
            Unifier::compare_values(&Value::Number(10.0), &Value::Number(20.0)).unwrap(),
            -1
        );
        assert_eq!(
            Unifier::compare_values(&Value::Number(20.0), &Value::Number(10.0)).unwrap(),
            1
        );
        assert_eq!(
            Unifier::compare_values(&Value::Number(10.0), &Value::Number(10.0)).unwrap(),
            0
        );
    }

    #[test]
    fn test_compare_values_other_types() {
        assert_eq!(
            Unifier::compare_values(
                &Value::String("a".to_string()),
                &Value::String("b".to_string())
            )
            .unwrap(),
            -1
        );
        assert_eq!(
            Unifier::compare_values(&Value::Boolean(true), &Value::Boolean(false)).unwrap(),
            1
        );
        assert!(Unifier::compare_values(
            &Value::Number(1.0),
            &Value::String("1".to_string())
        )
        .is_err());
    }
}
612        );
613    }
614}