1use crate::types::Value;
22use crate::errors::{Result, RuleEngineError};
23use crate::Facts;
24use super::expression::Expression;
25use std::collections::HashMap;
26
27#[derive(Debug, Clone)]
32pub struct Bindings {
33 bindings: HashMap<String, Value>,
35}
36
37impl Bindings {
38 pub fn new() -> Self {
40 Self {
41 bindings: HashMap::new(),
42 }
43 }
44
45 pub fn bind(&mut self, var_name: String, value: Value) -> Result<()> {
62 if let Some(existing) = self.bindings.get(&var_name) {
64 if existing != &value {
66 return Err(RuleEngineError::ExecutionError(
67 format!(
68 "Variable binding conflict: {} is already bound to {:?}, cannot rebind to {:?}",
69 var_name, existing, value
70 )
71 ));
72 }
73 } else {
74 self.bindings.insert(var_name, value);
75 }
76 Ok(())
77 }
78
79 pub fn get(&self, var_name: &str) -> Option<&Value> {
81 self.bindings.get(var_name)
82 }
83
84 pub fn is_bound(&self, var_name: &str) -> bool {
86 self.bindings.contains_key(var_name)
87 }
88
89 pub fn merge(&mut self, other: &Bindings) -> Result<()> {
94 for (var, val) in &other.bindings {
95 self.bind(var.clone(), val.clone())?;
96 }
97 Ok(())
98 }
99
100 pub fn as_map(&self) -> &HashMap<String, Value> {
102 &self.bindings
103 }
104
105 pub fn len(&self) -> usize {
107 self.bindings.len()
108 }
109
110 pub fn is_empty(&self) -> bool {
112 self.bindings.is_empty()
113 }
114
115 pub fn clear(&mut self) {
117 self.bindings.clear();
118 }
119
120 pub fn from_map(map: HashMap<String, Value>) -> Self {
122 Self { bindings: map }
123 }
124
125 pub fn into_map(self) -> HashMap<String, Value> {
127 self.bindings
128 }
129
130 pub fn to_map(&self) -> HashMap<String, Value> {
132 self.bindings.clone()
133 }
134}
135
136impl Default for Bindings {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
142pub struct Unifier;
149
150impl Unifier {
151 pub fn unify(
161 left: &Expression,
162 right: &Expression,
163 bindings: &mut Bindings,
164 ) -> Result<bool> {
165 match (left, right) {
166 (Expression::Variable(var), expr) => {
168 if let Some(bound_value) = bindings.get(var) {
169 Self::unify(
171 &Expression::Literal(bound_value.clone()),
172 expr,
173 bindings
174 )
175 } else {
176 if let Some(value) = Self::expression_to_value(expr, bindings)? {
178 bindings.bind(var.clone(), value)?;
179 Ok(true)
180 } else {
181 Ok(false)
183 }
184 }
185 }
186
187 (expr, Expression::Variable(var)) => {
189 Self::unify(&Expression::Variable(var.clone()), expr, bindings)
190 }
191
192 (Expression::Literal(v1), Expression::Literal(v2)) => {
194 Ok(v1 == v2)
195 }
196
197 (Expression::Field(f1), Expression::Field(f2)) => {
199 Ok(f1 == f2)
200 }
201
202 (
204 Expression::Comparison { left: l1, operator: op1, right: r1 },
205 Expression::Comparison { left: l2, operator: op2, right: r2 }
206 ) => {
207 if op1 != op2 {
208 return Ok(false);
209 }
210
211 let left_match = Self::unify(l1, l2, bindings)?;
212 let right_match = Self::unify(r1, r2, bindings)?;
213
214 Ok(left_match && right_match)
215 }
216
217 (
219 Expression::And { left: l1, right: r1 },
220 Expression::And { left: l2, right: r2 }
221 ) => {
222 let left_match = Self::unify(l1, l2, bindings)?;
223 let right_match = Self::unify(r1, r2, bindings)?;
224 Ok(left_match && right_match)
225 }
226
227 (
229 Expression::Or { left: l1, right: r1 },
230 Expression::Or { left: l2, right: r2 }
231 ) => {
232 let left_match = Self::unify(l1, l2, bindings)?;
233 let right_match = Self::unify(r1, r2, bindings)?;
234 Ok(left_match && right_match)
235 }
236
237 (Expression::Not(e1), Expression::Not(e2)) => {
239 Self::unify(e1, e2, bindings)
240 }
241
242 _ => Ok(false),
244 }
245 }
246
247 pub fn match_expression(
252 expr: &Expression,
253 facts: &Facts,
254 bindings: &mut Bindings,
255 ) -> Result<bool> {
256 match expr {
257 Expression::Variable(var) => {
258 if !bindings.is_bound(var) {
260 return Ok(false);
261 }
262 Ok(true)
263 }
264
265 Expression::Field(field_name) => {
266 Ok(facts.get(field_name).is_some())
268 }
269
270 Expression::Literal(_) => {
271 Ok(true)
273 }
274
275 Expression::Comparison { left, operator, right } => {
276 let left_val = Self::evaluate_with_bindings(left, facts, bindings)?;
278 let right_val = Self::evaluate_with_bindings(right, facts, bindings)?;
279
280 let result = match operator {
282 crate::types::Operator::Equal => left_val == right_val,
283 crate::types::Operator::NotEqual => left_val != right_val,
284 crate::types::Operator::GreaterThan => {
285 Self::compare_values(&left_val, &right_val)? > 0
286 }
287 crate::types::Operator::LessThan => {
288 Self::compare_values(&left_val, &right_val)? < 0
289 }
290 crate::types::Operator::GreaterThanOrEqual => {
291 Self::compare_values(&left_val, &right_val)? >= 0
292 }
293 crate::types::Operator::LessThanOrEqual => {
294 Self::compare_values(&left_val, &right_val)? <= 0
295 }
296 _ => {
297 return Err(RuleEngineError::ExecutionError(
298 format!("Unsupported operator: {:?}", operator)
299 ));
300 }
301 };
302
303 Ok(result)
304 }
305
306 Expression::And { left, right } => {
307 let left_match = Self::match_expression(left, facts, bindings)?;
308 if !left_match {
309 return Ok(false);
310 }
311 Self::match_expression(right, facts, bindings)
312 }
313
314 Expression::Or { left, right } => {
315 let left_match = Self::match_expression(left, facts, bindings)?;
316 if left_match {
317 return Ok(true);
318 }
319 Self::match_expression(right, facts, bindings)
320 }
321
322 Expression::Not(expr) => {
323 let result = Self::match_expression(expr, facts, bindings)?;
324 Ok(!result)
325 }
326 }
327 }
328
329 pub fn evaluate_with_bindings(
333 expr: &Expression,
334 facts: &Facts,
335 bindings: &Bindings,
336 ) -> Result<Value> {
337 match expr {
338 Expression::Variable(var) => {
339 bindings.get(var)
340 .cloned()
341 .ok_or_else(|| RuleEngineError::ExecutionError(
342 format!("Unbound variable: {}", var)
343 ))
344 }
345
346 Expression::Field(field) => {
347 facts.get(field)
348 .ok_or_else(|| RuleEngineError::ExecutionError(
349 format!("Field not found: {}", field)
350 ))
351 }
352
353 Expression::Literal(val) => Ok(val.clone()),
354
355 Expression::Comparison { left, operator, right } => {
356 let left_val = Self::evaluate_with_bindings(left, facts, bindings)?;
357 let right_val = Self::evaluate_with_bindings(right, facts, bindings)?;
358
359 let result = match operator {
360 crate::types::Operator::Equal => left_val == right_val,
361 crate::types::Operator::NotEqual => left_val != right_val,
362 crate::types::Operator::GreaterThan => {
363 Self::compare_values(&left_val, &right_val)? > 0
364 }
365 crate::types::Operator::LessThan => {
366 Self::compare_values(&left_val, &right_val)? < 0
367 }
368 crate::types::Operator::GreaterThanOrEqual => {
369 Self::compare_values(&left_val, &right_val)? >= 0
370 }
371 crate::types::Operator::LessThanOrEqual => {
372 Self::compare_values(&left_val, &right_val)? <= 0
373 }
374 _ => {
375 return Err(RuleEngineError::ExecutionError(
376 format!("Unsupported operator: {:?}", operator)
377 ));
378 }
379 };
380
381 Ok(Value::Boolean(result))
382 }
383
384 Expression::And { left, right } => {
385 let left_val = Self::evaluate_with_bindings(left, facts, bindings)?;
386 if !left_val.to_bool() {
387 return Ok(Value::Boolean(false));
388 }
389 let right_val = Self::evaluate_with_bindings(right, facts, bindings)?;
390 Ok(Value::Boolean(right_val.to_bool()))
391 }
392
393 Expression::Or { left, right } => {
394 let left_val = Self::evaluate_with_bindings(left, facts, bindings)?;
395 if left_val.to_bool() {
396 return Ok(Value::Boolean(true));
397 }
398 let right_val = Self::evaluate_with_bindings(right, facts, bindings)?;
399 Ok(Value::Boolean(right_val.to_bool()))
400 }
401
402 Expression::Not(expr) => {
403 let value = Self::evaluate_with_bindings(expr, facts, bindings)?;
404 Ok(Value::Boolean(!value.to_bool()))
405 }
406 }
407 }
408
409 fn expression_to_value(expr: &Expression, bindings: &Bindings) -> Result<Option<Value>> {
411 match expr {
412 Expression::Literal(val) => Ok(Some(val.clone())),
413 Expression::Variable(var) => Ok(bindings.get(var).cloned()),
414 _ => Ok(None), }
416 }
417
418 fn compare_values(left: &Value, right: &Value) -> Result<i32> {
420 match (left, right) {
421 (Value::Number(a), Value::Number(b)) => {
422 if a < b {
423 Ok(-1)
424 } else if a > b {
425 Ok(1)
426 } else {
427 Ok(0)
428 }
429 }
430 (Value::String(a), Value::String(b)) => {
431 Ok(a.cmp(b) as i32)
432 }
433 (Value::Boolean(a), Value::Boolean(b)) => {
434 Ok(a.cmp(b) as i32)
435 }
436 _ => Err(RuleEngineError::ExecutionError(
437 format!("Cannot compare values: {:?} and {:?}", left, right)
438 ))
439 }
440 }
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Operator;

    /// Builds a boxed comparison expression.
    fn cmp(left: Expression, operator: Operator, right: Expression) -> Expression {
        Expression::Comparison {
            left: Box::new(left),
            operator,
            right: Box::new(right),
        }
    }

    /// Shorthand for a numeric literal expression.
    fn num(n: f64) -> Expression {
        Expression::Literal(Value::Number(n))
    }

    /// Shorthand for a variable expression.
    fn var(name: &str) -> Expression {
        Expression::Variable(name.to_string())
    }

    #[test]
    fn test_bindings_basic() {
        let mut bindings = Bindings::new();

        assert!(bindings.is_empty());
        assert_eq!(bindings.len(), 0);

        bindings.bind("X".to_string(), Value::Number(42.0)).unwrap();

        assert!(!bindings.is_empty());
        assert_eq!(bindings.len(), 1);
        assert!(bindings.is_bound("X"));
        assert_eq!(bindings.get("X"), Some(&Value::Number(42.0)));
    }

    #[test]
    fn test_bindings_conflict() {
        let mut bindings = Bindings::new();

        bindings.bind("X".to_string(), Value::Number(42.0)).unwrap();

        // Rebinding to an equal value is allowed.
        assert!(bindings.bind("X".to_string(), Value::Number(42.0)).is_ok());

        // Rebinding to a different value is a conflict.
        assert!(bindings.bind("X".to_string(), Value::Number(100.0)).is_err());
    }

    #[test]
    fn test_bindings_merge() {
        let mut bindings1 = Bindings::new();
        let mut bindings2 = Bindings::new();

        bindings1.bind("X".to_string(), Value::Number(42.0)).unwrap();
        bindings2.bind("Y".to_string(), Value::String("hello".to_string())).unwrap();

        bindings1.merge(&bindings2).unwrap();

        assert_eq!(bindings1.len(), 2);
        assert_eq!(bindings1.get("X"), Some(&Value::Number(42.0)));
        assert_eq!(bindings1.get("Y"), Some(&Value::String("hello".to_string())));
    }

    #[test]
    fn test_bindings_merge_conflict() {
        let mut bindings1 = Bindings::new();
        let mut bindings2 = Bindings::new();

        bindings1.bind("X".to_string(), Value::Number(42.0)).unwrap();
        bindings2.bind("X".to_string(), Value::Number(100.0)).unwrap();

        assert!(bindings1.merge(&bindings2).is_err());
    }

    #[test]
    fn test_unify_variable_with_literal() {
        let mut bindings = Bindings::new();

        let var = Expression::Variable("X".to_string());
        let lit = Expression::Literal(Value::Number(42.0));

        let result = Unifier::unify(&var, &lit, &mut bindings).unwrap();

        assert!(result);
        assert_eq!(bindings.get("X"), Some(&Value::Number(42.0)));
    }

    #[test]
    fn test_unify_bound_variable() {
        let mut bindings = Bindings::new();
        bindings.bind("X".to_string(), Value::Number(42.0)).unwrap();

        let var = Expression::Variable("X".to_string());
        let lit = Expression::Literal(Value::Number(42.0));

        // Same value: unifies.
        let result = Unifier::unify(&var, &lit, &mut bindings).unwrap();
        assert!(result);

        let lit2 = Expression::Literal(Value::Number(100.0));
        // Different value: must not unify (either an error or `false`).
        let result2 = Unifier::unify(&var, &lit2, &mut bindings);
        assert!(result2.is_err() || !result2.unwrap());
    }

    #[test]
    fn test_unify_two_literals() {
        let mut bindings = Bindings::new();

        let lit1 = Expression::Literal(Value::Number(42.0));
        let lit2 = Expression::Literal(Value::Number(42.0));
        let lit3 = Expression::Literal(Value::Number(100.0));

        assert!(Unifier::unify(&lit1, &lit2, &mut bindings).unwrap());
        assert!(!Unifier::unify(&lit1, &lit3, &mut bindings).unwrap());
    }

    #[test]
    fn test_unify_literal_with_variable_on_right() {
        let mut bindings = Bindings::new();

        assert!(Unifier::unify(&num(3.0), &var("X"), &mut bindings).unwrap());
        assert_eq!(bindings.get("X"), Some(&Value::Number(3.0)));
    }

    #[test]
    fn test_unify_two_unbound_variables_fails() {
        let mut bindings = Bindings::new();

        assert!(!Unifier::unify(&var("X"), &var("Y"), &mut bindings).unwrap());
        assert!(bindings.is_empty());
    }

    #[test]
    fn test_unify_comparisons_binds_variable() {
        let mut bindings = Bindings::new();

        let pattern = cmp(var("X"), Operator::Equal, num(1.0));
        let target = cmp(num(7.0), Operator::Equal, num(1.0));

        assert!(Unifier::unify(&pattern, &target, &mut bindings).unwrap());
        assert_eq!(bindings.get("X"), Some(&Value::Number(7.0)));
    }

    #[test]
    fn test_unify_comparisons_operator_mismatch() {
        let mut bindings = Bindings::new();

        let pattern = cmp(var("X"), Operator::Equal, num(1.0));
        let target = cmp(num(7.0), Operator::NotEqual, num(1.0));

        assert!(!Unifier::unify(&pattern, &target, &mut bindings).unwrap());
        assert!(bindings.is_empty());
    }

    #[test]
    fn test_match_expression_simple() {
        let mut facts = Facts::new();
        facts.set("User.IsVIP", Value::Boolean(true));

        let mut bindings = Bindings::new();

        let expr = Expression::Comparison {
            left: Box::new(Expression::Field("User.IsVIP".to_string())),
            operator: Operator::Equal,
            right: Box::new(Expression::Literal(Value::Boolean(true))),
        };

        let result = Unifier::match_expression(&expr, &facts, &mut bindings).unwrap();
        assert!(result);
    }

    #[test]
    fn test_match_expression_variable_and_field() {
        let mut facts = Facts::new();
        facts.set("User.Age", Value::Number(30.0));

        let mut bindings = Bindings::new();

        assert!(!Unifier::match_expression(&var("X"), &facts, &mut bindings).unwrap());
        bindings.bind("X".to_string(), Value::Number(1.0)).unwrap();
        assert!(Unifier::match_expression(&var("X"), &facts, &mut bindings).unwrap());

        let present = Expression::Field("User.Age".to_string());
        let missing = Expression::Field("User.Missing".to_string());
        assert!(Unifier::match_expression(&present, &facts, &mut bindings).unwrap());
        assert!(!Unifier::match_expression(&missing, &facts, &mut bindings).unwrap());
    }

    #[test]
    fn test_evaluate_with_bindings() {
        let mut facts = Facts::new();
        facts.set("Order.Amount", Value::Number(100.0));

        let mut bindings = Bindings::new();
        bindings.bind("X".to_string(), Value::Number(50.0)).unwrap();

        // A variable evaluates to its bound value.
        let var_expr = Expression::Variable("X".to_string());
        let result = Unifier::evaluate_with_bindings(&var_expr, &facts, &bindings).unwrap();
        assert_eq!(result, Value::Number(50.0));

        // A field evaluates to the fact's value.
        let field_expr = Expression::Field("Order.Amount".to_string());
        let result = Unifier::evaluate_with_bindings(&field_expr, &facts, &bindings).unwrap();
        assert_eq!(result, Value::Number(100.0));
    }

    #[test]
    fn test_evaluate_logical_operators() {
        let mut facts = Facts::new();
        facts.set("Order.Amount", Value::Number(100.0));
        let bindings = Bindings::new();
        let amount = || Expression::Field("Order.Amount".to_string());

        // (amount > 50) AND NOT (amount == 0)  ->  true
        let and_expr = Expression::And {
            left: Box::new(cmp(amount(), Operator::GreaterThan, num(50.0))),
            right: Box::new(Expression::Not(Box::new(cmp(
                amount(),
                Operator::Equal,
                num(0.0),
            )))),
        };
        assert_eq!(
            Unifier::evaluate_with_bindings(&and_expr, &facts, &bindings).unwrap(),
            Value::Boolean(true)
        );

        // (amount == 0) OR (amount < 10)  ->  false
        let or_expr = Expression::Or {
            left: Box::new(cmp(amount(), Operator::Equal, num(0.0))),
            right: Box::new(cmp(amount(), Operator::LessThan, num(10.0))),
        };
        assert_eq!(
            Unifier::evaluate_with_bindings(&or_expr, &facts, &bindings).unwrap(),
            Value::Boolean(false)
        );
    }

    #[test]
    fn test_evaluate_errors() {
        let facts = Facts::new();
        let bindings = Bindings::new();

        assert!(Unifier::evaluate_with_bindings(&var("Unbound"), &facts, &bindings).is_err());

        let missing = Expression::Field("No.Such.Field".to_string());
        assert!(Unifier::evaluate_with_bindings(&missing, &facts, &bindings).is_err());
    }

    #[test]
    fn test_compare_values() {
        assert_eq!(
            Unifier::compare_values(&Value::Number(10.0), &Value::Number(20.0)).unwrap(),
            -1
        );
        assert_eq!(
            Unifier::compare_values(&Value::Number(20.0), &Value::Number(10.0)).unwrap(),
            1
        );
        assert_eq!(
            Unifier::compare_values(&Value::Number(10.0), &Value::Number(10.0)).unwrap(),
            0
        );
    }

    #[test]
    fn test_compare_values_strings_and_mismatch() {
        assert_eq!(
            Unifier::compare_values(
                &Value::String("apple".to_string()),
                &Value::String("banana".to_string())
            )
            .unwrap(),
            -1
        );
        assert!(Unifier::compare_values(
            &Value::Number(1.0),
            &Value::String("1".to_string())
        )
        .is_err());
    }
}