// rust_rule_engine/backward/unification.rs

//! Unification module for backward chaining.
//!
//! This module provides variable binding and unification capabilities
//! for pattern matching in backward chaining queries.
//!
//! # Example
//! ```ignore
//! use rust_rule_engine::backward::unification::{Bindings, Unifier};
//! use rust_rule_engine::backward::expression::Expression;
//! use rust_rule_engine::types::Value;
//!
//! let mut bindings = Bindings::new();
//!
//! // Unify a variable with a value
//! let var = Expression::Variable("X".to_string());
//! let val = Expression::Literal(Value::Number(42.0));
//!
//! Unifier::unify(&var, &val, &mut bindings)?;
//! assert_eq!(bindings.get("X"), Some(&Value::Number(42.0)));
//! ```

use std::collections::hash_map::Entry;
use std::collections::HashMap;

use crate::errors::{Result, RuleEngineError};
use crate::types::Value;
use crate::Facts;

use super::expression::Expression;

27/// Variable bindings during proof
28///
29/// This type manages variable-to-value mappings during backward chaining,
30/// with support for merging, conflict detection, and binding propagation.
31#[derive(Debug, Clone)]
32pub struct Bindings {
33    /// Map from variable name to value
34    bindings: HashMap<String, Value>,
35}
36
37impl Bindings {
38    /// Create a new empty bindings set
39    pub fn new() -> Self {
40        Self {
41            bindings: HashMap::new(),
42        }
43    }
44
45    /// Bind a variable to a value
46    ///
47    /// If the variable is already bound, this checks that the new value
48    /// matches the existing binding. Returns an error if there's a conflict.
49    ///
50    /// # Example
51    /// ```ignore
52    /// let mut bindings = Bindings::new();
53    /// bindings.bind("X".to_string(), Value::Number(42.0))?;
54    ///
55    /// // This will succeed (same value)
56    /// bindings.bind("X".to_string(), Value::Number(42.0))?;
57    ///
58    /// // This will fail (different value)
59    /// bindings.bind("X".to_string(), Value::Number(100.0))?; // Error!
60    /// ```
61    pub fn bind(&mut self, var_name: String, value: Value) -> Result<()> {
62        // Check if already bound
63        if let Some(existing) = self.bindings.get(&var_name) {
64            // Must match existing binding
65            if existing != &value {
66                return Err(RuleEngineError::ExecutionError(
67                    format!(
68                        "Variable binding conflict: {} is already bound to {:?}, cannot rebind to {:?}",
69                        var_name, existing, value
70                    )
71                ));
72            }
73        } else {
74            self.bindings.insert(var_name, value);
75        }
76        Ok(())
77    }
78
79    /// Get binding for a variable
80    pub fn get(&self, var_name: &str) -> Option<&Value> {
81        self.bindings.get(var_name)
82    }
83
84    /// Check if variable is bound
85    pub fn is_bound(&self, var_name: &str) -> bool {
86        self.bindings.contains_key(var_name)
87    }
88
89    /// Merge bindings from another set
90    ///
91    /// This attempts to merge all bindings from `other` into this set.
92    /// If any conflicts are detected, returns an error and leaves this set unchanged.
93    pub fn merge(&mut self, other: &Bindings) -> Result<()> {
94        for (var, val) in &other.bindings {
95            self.bind(var.clone(), val.clone())?;
96        }
97        Ok(())
98    }
99
100    /// Get all bindings as a map
101    pub fn as_map(&self) -> &HashMap<String, Value> {
102        &self.bindings
103    }
104
105    /// Get number of bindings
106    pub fn len(&self) -> usize {
107        self.bindings.len()
108    }
109
110    /// Check if bindings is empty
111    pub fn is_empty(&self) -> bool {
112        self.bindings.is_empty()
113    }
114
115    /// Clear all bindings
116    pub fn clear(&mut self) {
117        self.bindings.clear();
118    }
119
120    /// Create bindings from a HashMap
121    pub fn from_map(map: HashMap<String, Value>) -> Self {
122        Self { bindings: map }
123    }
124
125    /// Convert bindings to HashMap (for backward compatibility)
126    pub fn into_map(self) -> HashMap<String, Value> {
127        self.bindings
128    }
129
130    /// Get bindings as HashMap clone
131    pub fn to_map(&self) -> HashMap<String, Value> {
132        self.bindings.clone()
133    }
134}
135
136impl Default for Bindings {
137    fn default() -> Self {
138        Self::new()
139    }
140}
141
/// Unification algorithm for pattern matching.
///
/// A stateless unit type; all functionality is exposed as associated
/// functions that:
/// - unify two expressions, extending a set of variable bindings,
/// - match expressions against facts,
/// - evaluate expressions with bound variables substituted.
#[derive(Debug, Clone, Copy, Default)]
pub struct Unifier;

150impl Unifier {
151    /// Unify two expressions with variable bindings
152    ///
153    /// This is the core unification algorithm. It attempts to make two expressions
154    /// equal by binding variables to values.
155    ///
156    /// # Returns
157    /// - `Ok(true)` if unification succeeded
158    /// - `Ok(false)` if expressions cannot be unified
159    /// - `Err(_)` if there's a binding conflict
160    pub fn unify(
161        left: &Expression,
162        right: &Expression,
163        bindings: &mut Bindings,
164    ) -> Result<bool> {
165        match (left, right) {
166            // Variable on left
167            (Expression::Variable(var), expr) => {
168                if let Some(bound_value) = bindings.get(var) {
169                    // Variable already bound - check if it matches
170                    Self::unify(
171                        &Expression::Literal(bound_value.clone()),
172                        expr,
173                        bindings
174                    )
175                } else {
176                    // Bind variable to expression value
177                    if let Some(value) = Self::expression_to_value(expr, bindings)? {
178                        bindings.bind(var.clone(), value)?;
179                        Ok(true)
180                    } else {
181                        // Cannot extract value from expression yet
182                        Ok(false)
183                    }
184                }
185            }
186
187            // Variable on right (symmetric)
188            (expr, Expression::Variable(var)) => {
189                Self::unify(&Expression::Variable(var.clone()), expr, bindings)
190            }
191
192            // Two literals - must be equal
193            (Expression::Literal(v1), Expression::Literal(v2)) => {
194                Ok(v1 == v2)
195            }
196
197            // Two fields - must be same field
198            (Expression::Field(f1), Expression::Field(f2)) => {
199                Ok(f1 == f2)
200            }
201
202            // Comparison - both sides must unify
203            (
204                Expression::Comparison { left: l1, operator: op1, right: r1 },
205                Expression::Comparison { left: l2, operator: op2, right: r2 }
206            ) => {
207                if op1 != op2 {
208                    return Ok(false);
209                }
210
211                let left_match = Self::unify(l1, l2, bindings)?;
212                let right_match = Self::unify(r1, r2, bindings)?;
213
214                Ok(left_match && right_match)
215            }
216
217            // Logical AND - both sides must unify
218            (
219                Expression::And { left: l1, right: r1 },
220                Expression::And { left: l2, right: r2 }
221            ) => {
222                let left_match = Self::unify(l1, l2, bindings)?;
223                let right_match = Self::unify(r1, r2, bindings)?;
224                Ok(left_match && right_match)
225            }
226
227            // Logical OR - both sides must unify
228            (
229                Expression::Or { left: l1, right: r1 },
230                Expression::Or { left: l2, right: r2 }
231            ) => {
232                let left_match = Self::unify(l1, l2, bindings)?;
233                let right_match = Self::unify(r1, r2, bindings)?;
234                Ok(left_match && right_match)
235            }
236
237            // Negation - inner expression must unify
238            (Expression::Not(e1), Expression::Not(e2)) => {
239                Self::unify(e1, e2, bindings)
240            }
241
242            // Different expression types - cannot unify
243            _ => Ok(false),
244        }
245    }
246
247    /// Match expression against facts and extract bindings
248    ///
249    /// This evaluates an expression against the current facts,
250    /// binding any variables to their matched values.
251    pub fn match_expression(
252        expr: &Expression,
253        facts: &Facts,
254        bindings: &mut Bindings,
255    ) -> Result<bool> {
256        match expr {
257            Expression::Variable(var) => {
258                // Unbound variable - cannot match
259                if !bindings.is_bound(var) {
260                    return Ok(false);
261                }
262                Ok(true)
263            }
264
265            Expression::Field(field_name) => {
266                // Field must exist in facts
267                Ok(facts.get(field_name).is_some())
268            }
269
270            Expression::Literal(_) => {
271                // Literals always match
272                Ok(true)
273            }
274
275            Expression::Comparison { left, operator, right } => {
276                // Evaluate both sides with bindings
277                let left_val = Self::evaluate_with_bindings(left, facts, bindings)?;
278                let right_val = Self::evaluate_with_bindings(right, facts, bindings)?;
279
280                // Perform comparison
281                let result = match operator {
282                    crate::types::Operator::Equal => left_val == right_val,
283                    crate::types::Operator::NotEqual => left_val != right_val,
284                    crate::types::Operator::GreaterThan => {
285                        Self::compare_values(&left_val, &right_val)? > 0
286                    }
287                    crate::types::Operator::LessThan => {
288                        Self::compare_values(&left_val, &right_val)? < 0
289                    }
290                    crate::types::Operator::GreaterThanOrEqual => {
291                        Self::compare_values(&left_val, &right_val)? >= 0
292                    }
293                    crate::types::Operator::LessThanOrEqual => {
294                        Self::compare_values(&left_val, &right_val)? <= 0
295                    }
296                    _ => {
297                        return Err(RuleEngineError::ExecutionError(
298                            format!("Unsupported operator: {:?}", operator)
299                        ));
300                    }
301                };
302
303                Ok(result)
304            }
305
306            Expression::And { left, right } => {
307                let left_match = Self::match_expression(left, facts, bindings)?;
308                if !left_match {
309                    return Ok(false);
310                }
311                Self::match_expression(right, facts, bindings)
312            }
313
314            Expression::Or { left, right } => {
315                let left_match = Self::match_expression(left, facts, bindings)?;
316                if left_match {
317                    return Ok(true);
318                }
319                Self::match_expression(right, facts, bindings)
320            }
321
322            Expression::Not(expr) => {
323                let result = Self::match_expression(expr, facts, bindings)?;
324                Ok(!result)
325            }
326        }
327    }
328
329    /// Evaluate expression with variable bindings
330    ///
331    /// This evaluates an expression to a value, substituting any bound variables.
332    pub fn evaluate_with_bindings(
333        expr: &Expression,
334        facts: &Facts,
335        bindings: &Bindings,
336    ) -> Result<Value> {
337        match expr {
338            Expression::Variable(var) => {
339                bindings.get(var)
340                    .cloned()
341                    .ok_or_else(|| RuleEngineError::ExecutionError(
342                        format!("Unbound variable: {}", var)
343                    ))
344            }
345
346            Expression::Field(field) => {
347                facts.get(field)
348                    .ok_or_else(|| RuleEngineError::ExecutionError(
349                        format!("Field not found: {}", field)
350                    ))
351            }
352
353            Expression::Literal(val) => Ok(val.clone()),
354
355            Expression::Comparison { left, operator, right } => {
356                let left_val = Self::evaluate_with_bindings(left, facts, bindings)?;
357                let right_val = Self::evaluate_with_bindings(right, facts, bindings)?;
358
359                let result = match operator {
360                    crate::types::Operator::Equal => left_val == right_val,
361                    crate::types::Operator::NotEqual => left_val != right_val,
362                    crate::types::Operator::GreaterThan => {
363                        Self::compare_values(&left_val, &right_val)? > 0
364                    }
365                    crate::types::Operator::LessThan => {
366                        Self::compare_values(&left_val, &right_val)? < 0
367                    }
368                    crate::types::Operator::GreaterThanOrEqual => {
369                        Self::compare_values(&left_val, &right_val)? >= 0
370                    }
371                    crate::types::Operator::LessThanOrEqual => {
372                        Self::compare_values(&left_val, &right_val)? <= 0
373                    }
374                    _ => {
375                        return Err(RuleEngineError::ExecutionError(
376                            format!("Unsupported operator: {:?}", operator)
377                        ));
378                    }
379                };
380
381                Ok(Value::Boolean(result))
382            }
383
384            Expression::And { left, right } => {
385                let left_val = Self::evaluate_with_bindings(left, facts, bindings)?;
386                if !left_val.to_bool() {
387                    return Ok(Value::Boolean(false));
388                }
389                let right_val = Self::evaluate_with_bindings(right, facts, bindings)?;
390                Ok(Value::Boolean(right_val.to_bool()))
391            }
392
393            Expression::Or { left, right } => {
394                let left_val = Self::evaluate_with_bindings(left, facts, bindings)?;
395                if left_val.to_bool() {
396                    return Ok(Value::Boolean(true));
397                }
398                let right_val = Self::evaluate_with_bindings(right, facts, bindings)?;
399                Ok(Value::Boolean(right_val.to_bool()))
400            }
401
402            Expression::Not(expr) => {
403                let value = Self::evaluate_with_bindings(expr, facts, bindings)?;
404                Ok(Value::Boolean(!value.to_bool()))
405            }
406        }
407    }
408
409    /// Extract a value from an expression (if possible)
410    fn expression_to_value(expr: &Expression, bindings: &Bindings) -> Result<Option<Value>> {
411        match expr {
412            Expression::Literal(val) => Ok(Some(val.clone())),
413            Expression::Variable(var) => Ok(bindings.get(var).cloned()),
414            _ => Ok(None), // Cannot extract value from complex expressions
415        }
416    }
417
418    /// Compare two values for ordering
419    fn compare_values(left: &Value, right: &Value) -> Result<i32> {
420        match (left, right) {
421            (Value::Number(a), Value::Number(b)) => {
422                if a < b {
423                    Ok(-1)
424                } else if a > b {
425                    Ok(1)
426                } else {
427                    Ok(0)
428                }
429            }
430            (Value::String(a), Value::String(b)) => {
431                Ok(a.cmp(b) as i32)
432            }
433            (Value::Boolean(a), Value::Boolean(b)) => {
434                Ok(a.cmp(b) as i32)
435            }
436            _ => Err(RuleEngineError::ExecutionError(
437                format!("Cannot compare values: {:?} and {:?}", left, right)
438            ))
439        }
440    }
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Operator;

    #[test]
    fn test_bindings_basic() {
        let mut env = Bindings::new();
        assert!(env.is_empty());
        assert_eq!(env.len(), 0);

        env.bind("X".into(), Value::Number(42.0)).unwrap();

        assert!(!env.is_empty());
        assert_eq!(env.len(), 1);
        assert!(env.is_bound("X"));
        assert_eq!(env.get("X"), Some(&Value::Number(42.0)));
    }

    #[test]
    fn test_bindings_conflict() {
        let mut env = Bindings::new();
        env.bind("X".into(), Value::Number(42.0)).unwrap();

        // Rebinding to an equal value is accepted...
        assert!(env.bind("X".into(), Value::Number(42.0)).is_ok());
        // ...but a different value is a conflict.
        assert!(env.bind("X".into(), Value::Number(100.0)).is_err());
    }

    #[test]
    fn test_bindings_merge() {
        let mut target = Bindings::new();
        target.bind("X".into(), Value::Number(42.0)).unwrap();

        let mut source = Bindings::new();
        source
            .bind("Y".into(), Value::String("hello".into()))
            .unwrap();

        target.merge(&source).unwrap();

        assert_eq!(target.len(), 2);
        assert_eq!(target.get("X"), Some(&Value::Number(42.0)));
        assert_eq!(target.get("Y"), Some(&Value::String("hello".to_string())));
    }

    #[test]
    fn test_bindings_merge_conflict() {
        let mut target = Bindings::new();
        target.bind("X".into(), Value::Number(42.0)).unwrap();

        let mut source = Bindings::new();
        source.bind("X".into(), Value::Number(100.0)).unwrap();

        // Same variable, different values: merge must be rejected.
        assert!(target.merge(&source).is_err());
    }

    #[test]
    fn test_unify_variable_with_literal() {
        let mut env = Bindings::new();

        let unified = Unifier::unify(
            &Expression::Variable("X".into()),
            &Expression::Literal(Value::Number(42.0)),
            &mut env,
        )
        .unwrap();

        assert!(unified);
        assert_eq!(env.get("X"), Some(&Value::Number(42.0)));
    }

    #[test]
    fn test_unify_bound_variable() {
        let mut env = Bindings::new();
        env.bind("X".into(), Value::Number(42.0)).unwrap();
        let x = Expression::Variable("X".into());

        // Same value as the existing binding: unifies.
        let same = Expression::Literal(Value::Number(42.0));
        assert!(Unifier::unify(&x, &same, &mut env).unwrap());

        // Different value: must not report success (Ok(false) or Err are both fine).
        let different = Expression::Literal(Value::Number(100.0));
        let outcome = Unifier::unify(&x, &different, &mut env);
        assert!(!matches!(outcome, Ok(true)));
    }

    #[test]
    fn test_unify_two_literals() {
        let mut env = Bindings::new();
        let forty_two = Expression::Literal(Value::Number(42.0));

        let equal = Expression::Literal(Value::Number(42.0));
        assert!(Unifier::unify(&forty_two, &equal, &mut env).unwrap());

        let unequal = Expression::Literal(Value::Number(100.0));
        assert!(!Unifier::unify(&forty_two, &unequal, &mut env).unwrap());
    }

    #[test]
    fn test_match_expression_simple() {
        let mut facts = Facts::new();
        facts.set("User.IsVIP", Value::Boolean(true));

        let is_vip = Expression::Comparison {
            left: Box::new(Expression::Field("User.IsVIP".to_string())),
            operator: Operator::Equal,
            right: Box::new(Expression::Literal(Value::Boolean(true))),
        };

        let matched = Unifier::match_expression(&is_vip, &facts, &mut Bindings::new()).unwrap();
        assert!(matched);
    }

    #[test]
    fn test_evaluate_with_bindings() {
        let mut facts = Facts::new();
        facts.set("Order.Amount", Value::Number(100.0));

        let mut env = Bindings::new();
        env.bind("X".into(), Value::Number(50.0)).unwrap();

        // A bound variable evaluates to its binding.
        let from_var =
            Unifier::evaluate_with_bindings(&Expression::Variable("X".into()), &facts, &env)
                .unwrap();
        assert_eq!(from_var, Value::Number(50.0));

        // A field evaluates to the fact's value.
        let from_field = Unifier::evaluate_with_bindings(
            &Expression::Field("Order.Amount".into()),
            &facts,
            &env,
        )
        .unwrap();
        assert_eq!(from_field, Value::Number(100.0));
    }

    #[test]
    fn test_compare_values() {
        let cases = [(10.0, 20.0, -1), (20.0, 10.0, 1), (10.0, 10.0, 0)];
        for &(lhs, rhs, expected) in &cases {
            let ordering =
                Unifier::compare_values(&Value::Number(lhs), &Value::Number(rhs)).unwrap();
            assert_eq!(ordering, expected);
        }
    }
}