1use super::expression::Expression;
22use crate::errors::{Result, RuleEngineError};
23use crate::types::Value;
24use crate::Facts;
25use std::collections::HashMap;
26
27#[derive(Debug, Clone)]
32pub struct Bindings {
33 bindings: HashMap<String, Value>,
35}
36
37impl Bindings {
38 pub fn new() -> Self {
40 Self {
41 bindings: HashMap::new(),
42 }
43 }
44
45 pub fn bind(&mut self, var_name: String, value: Value) -> Result<()> {
62 if let Some(existing) = self.bindings.get(&var_name) {
64 if existing != &value {
66 return Err(RuleEngineError::ExecutionError(format!(
67 "Variable binding conflict: {} is already bound to {:?}, cannot rebind to {:?}",
68 var_name, existing, value
69 )));
70 }
71 } else {
72 self.bindings.insert(var_name, value);
73 }
74 Ok(())
75 }
76
77 pub fn get(&self, var_name: &str) -> Option<&Value> {
79 self.bindings.get(var_name)
80 }
81
82 pub fn is_bound(&self, var_name: &str) -> bool {
84 self.bindings.contains_key(var_name)
85 }
86
87 pub fn merge(&mut self, other: &Bindings) -> Result<()> {
92 for (var, val) in &other.bindings {
93 self.bind(var.clone(), val.clone())?;
94 }
95 Ok(())
96 }
97
98 pub fn as_map(&self) -> &HashMap<String, Value> {
100 &self.bindings
101 }
102
103 pub fn len(&self) -> usize {
105 self.bindings.len()
106 }
107
108 pub fn is_empty(&self) -> bool {
110 self.bindings.is_empty()
111 }
112
113 pub fn clear(&mut self) {
115 self.bindings.clear();
116 }
117
118 pub fn from_map(map: HashMap<String, Value>) -> Self {
120 Self { bindings: map }
121 }
122
123 pub fn into_map(self) -> HashMap<String, Value> {
125 self.bindings
126 }
127
128 pub fn to_map(&self) -> HashMap<String, Value> {
130 self.bindings.clone()
131 }
132}
133
134impl Default for Bindings {
135 fn default() -> Self {
136 Self::new()
137 }
138}
139
140pub struct Unifier;
147
148impl Unifier {
149 pub fn unify(left: &Expression, right: &Expression, bindings: &mut Bindings) -> Result<bool> {
159 match (left, right) {
160 (Expression::Variable(var), expr) => {
162 if let Some(bound_value) = bindings.get(var) {
163 Self::unify(&Expression::Literal(bound_value.clone()), expr, bindings)
165 } else {
166 if let Some(value) = Self::expression_to_value(expr, bindings)? {
168 bindings.bind(var.clone(), value)?;
169 Ok(true)
170 } else {
171 Ok(false)
173 }
174 }
175 }
176
177 (expr, Expression::Variable(var)) => {
179 Self::unify(&Expression::Variable(var.clone()), expr, bindings)
180 }
181
182 (Expression::Literal(v1), Expression::Literal(v2)) => Ok(v1 == v2),
184
185 (Expression::Field(f1), Expression::Field(f2)) => Ok(f1 == f2),
187
188 (
190 Expression::Comparison {
191 left: l1,
192 operator: op1,
193 right: r1,
194 },
195 Expression::Comparison {
196 left: l2,
197 operator: op2,
198 right: r2,
199 },
200 ) => {
201 if op1 != op2 {
202 return Ok(false);
203 }
204
205 let left_match = Self::unify(l1, l2, bindings)?;
206 let right_match = Self::unify(r1, r2, bindings)?;
207
208 Ok(left_match && right_match)
209 }
210
211 (
213 Expression::And {
214 left: l1,
215 right: r1,
216 },
217 Expression::And {
218 left: l2,
219 right: r2,
220 },
221 ) => {
222 let left_match = Self::unify(l1, l2, bindings)?;
223 let right_match = Self::unify(r1, r2, bindings)?;
224 Ok(left_match && right_match)
225 }
226
227 (
229 Expression::Or {
230 left: l1,
231 right: r1,
232 },
233 Expression::Or {
234 left: l2,
235 right: r2,
236 },
237 ) => {
238 let left_match = Self::unify(l1, l2, bindings)?;
239 let right_match = Self::unify(r1, r2, bindings)?;
240 Ok(left_match && right_match)
241 }
242
243 (Expression::Not(e1), Expression::Not(e2)) => Self::unify(e1, e2, bindings),
245
246 _ => Ok(false),
248 }
249 }
250
251 pub fn match_expression(
256 expr: &Expression,
257 facts: &Facts,
258 bindings: &mut Bindings,
259 ) -> Result<bool> {
260 match expr {
261 Expression::Variable(var) => {
262 if !bindings.is_bound(var) {
264 return Ok(false);
265 }
266 Ok(true)
267 }
268
269 Expression::Field(field_name) => {
270 Ok(facts.get(field_name).is_some())
272 }
273
274 Expression::Literal(_) => {
275 Ok(true)
277 }
278
279 Expression::Comparison {
280 left,
281 operator,
282 right,
283 } => {
284 let left_val = Self::evaluate_with_bindings(left, facts, bindings)?;
286 let right_val = Self::evaluate_with_bindings(right, facts, bindings)?;
287
288 let result = match operator {
290 crate::types::Operator::Equal => left_val == right_val,
291 crate::types::Operator::NotEqual => left_val != right_val,
292 crate::types::Operator::GreaterThan => {
293 Self::compare_values(&left_val, &right_val)? > 0
294 }
295 crate::types::Operator::LessThan => {
296 Self::compare_values(&left_val, &right_val)? < 0
297 }
298 crate::types::Operator::GreaterThanOrEqual => {
299 Self::compare_values(&left_val, &right_val)? >= 0
300 }
301 crate::types::Operator::LessThanOrEqual => {
302 Self::compare_values(&left_val, &right_val)? <= 0
303 }
304 _ => {
305 return Err(RuleEngineError::ExecutionError(format!(
306 "Unsupported operator: {:?}",
307 operator
308 )));
309 }
310 };
311
312 Ok(result)
313 }
314
315 Expression::And { left, right } => {
316 let left_match = Self::match_expression(left, facts, bindings)?;
317 if !left_match {
318 return Ok(false);
319 }
320 Self::match_expression(right, facts, bindings)
321 }
322
323 Expression::Or { left, right } => {
324 let left_match = Self::match_expression(left, facts, bindings)?;
325 if left_match {
326 return Ok(true);
327 }
328 Self::match_expression(right, facts, bindings)
329 }
330
331 Expression::Not(expr) => {
332 let result = Self::match_expression(expr, facts, bindings)?;
333 Ok(!result)
334 }
335 }
336 }
337
338 pub fn evaluate_with_bindings(
342 expr: &Expression,
343 facts: &Facts,
344 bindings: &Bindings,
345 ) -> Result<Value> {
346 match expr {
347 Expression::Variable(var) => bindings.get(var).cloned().ok_or_else(|| {
348 RuleEngineError::ExecutionError(format!("Unbound variable: {}", var))
349 }),
350
351 Expression::Field(field) => facts.get(field).ok_or_else(|| {
352 RuleEngineError::ExecutionError(format!("Field not found: {}", field))
353 }),
354
355 Expression::Literal(val) => Ok(val.clone()),
356
357 Expression::Comparison {
358 left,
359 operator,
360 right,
361 } => {
362 let left_val = Self::evaluate_with_bindings(left, facts, bindings)?;
363 let right_val = Self::evaluate_with_bindings(right, facts, bindings)?;
364
365 let result = match operator {
366 crate::types::Operator::Equal => left_val == right_val,
367 crate::types::Operator::NotEqual => left_val != right_val,
368 crate::types::Operator::GreaterThan => {
369 Self::compare_values(&left_val, &right_val)? > 0
370 }
371 crate::types::Operator::LessThan => {
372 Self::compare_values(&left_val, &right_val)? < 0
373 }
374 crate::types::Operator::GreaterThanOrEqual => {
375 Self::compare_values(&left_val, &right_val)? >= 0
376 }
377 crate::types::Operator::LessThanOrEqual => {
378 Self::compare_values(&left_val, &right_val)? <= 0
379 }
380 _ => {
381 return Err(RuleEngineError::ExecutionError(format!(
382 "Unsupported operator: {:?}",
383 operator
384 )));
385 }
386 };
387
388 Ok(Value::Boolean(result))
389 }
390
391 Expression::And { left, right } => {
392 let left_val = Self::evaluate_with_bindings(left, facts, bindings)?;
393 if !left_val.to_bool() {
394 return Ok(Value::Boolean(false));
395 }
396 let right_val = Self::evaluate_with_bindings(right, facts, bindings)?;
397 Ok(Value::Boolean(right_val.to_bool()))
398 }
399
400 Expression::Or { left, right } => {
401 let left_val = Self::evaluate_with_bindings(left, facts, bindings)?;
402 if left_val.to_bool() {
403 return Ok(Value::Boolean(true));
404 }
405 let right_val = Self::evaluate_with_bindings(right, facts, bindings)?;
406 Ok(Value::Boolean(right_val.to_bool()))
407 }
408
409 Expression::Not(expr) => {
410 let value = Self::evaluate_with_bindings(expr, facts, bindings)?;
411 Ok(Value::Boolean(!value.to_bool()))
412 }
413 }
414 }
415
416 fn expression_to_value(expr: &Expression, bindings: &Bindings) -> Result<Option<Value>> {
418 match expr {
419 Expression::Literal(val) => Ok(Some(val.clone())),
420 Expression::Variable(var) => Ok(bindings.get(var).cloned()),
421 _ => Ok(None), }
423 }
424
425 fn compare_values(left: &Value, right: &Value) -> Result<i32> {
427 match (left, right) {
428 (Value::Number(a), Value::Number(b)) => {
429 if a < b {
430 Ok(-1)
431 } else if a > b {
432 Ok(1)
433 } else {
434 Ok(0)
435 }
436 }
437 (Value::String(a), Value::String(b)) => Ok(a.cmp(b) as i32),
438 (Value::Boolean(a), Value::Boolean(b)) => Ok(a.cmp(b) as i32),
439 _ => Err(RuleEngineError::ExecutionError(format!(
440 "Cannot compare values: {:?} and {:?}",
441 left, right
442 ))),
443 }
444 }
445}
446
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Operator;

    #[test]
    fn test_bindings_basic() {
        let mut env = Bindings::new();
        assert!(env.is_empty());
        assert_eq!(env.len(), 0);

        env.bind(String::from("X"), Value::Number(42.0)).unwrap();

        assert!(!env.is_empty());
        assert_eq!(env.len(), 1);
        assert!(env.is_bound("X"));
        assert_eq!(env.get("X"), Some(&Value::Number(42.0)));
    }

    #[test]
    fn test_bindings_conflict() {
        let mut env = Bindings::new();
        env.bind(String::from("X"), Value::Number(42.0)).unwrap();

        // Re-binding to an equal value is accepted...
        let same = env.bind(String::from("X"), Value::Number(42.0));
        assert!(same.is_ok());

        // ...but a different value is a conflict.
        let different = env.bind(String::from("X"), Value::Number(100.0));
        assert!(different.is_err());
    }

    #[test]
    fn test_bindings_merge() {
        let mut target = Bindings::new();
        let mut source = Bindings::new();
        target.bind(String::from("X"), Value::Number(42.0)).unwrap();
        source
            .bind(String::from("Y"), Value::String(String::from("hello")))
            .unwrap();

        target.merge(&source).unwrap();

        assert_eq!(target.len(), 2);
        assert_eq!(target.get("X"), Some(&Value::Number(42.0)));
        assert_eq!(
            target.get("Y"),
            Some(&Value::String(String::from("hello")))
        );
    }

    #[test]
    fn test_bindings_merge_conflict() {
        let mut target = Bindings::new();
        let mut source = Bindings::new();
        target.bind(String::from("X"), Value::Number(42.0)).unwrap();
        source.bind(String::from("X"), Value::Number(100.0)).unwrap();

        let merged = target.merge(&source);
        assert!(merged.is_err());
    }

    #[test]
    fn test_unify_variable_with_literal() {
        let mut env = Bindings::new();
        let x = Expression::Variable(String::from("X"));
        let forty_two = Expression::Literal(Value::Number(42.0));

        assert!(Unifier::unify(&x, &forty_two, &mut env).unwrap());
        assert_eq!(env.get("X"), Some(&Value::Number(42.0)));
    }

    #[test]
    fn test_unify_bound_variable() {
        let mut env = Bindings::new();
        env.bind(String::from("X"), Value::Number(42.0)).unwrap();
        let x = Expression::Variable(String::from("X"));

        let same = Expression::Literal(Value::Number(42.0));
        assert!(Unifier::unify(&x, &same, &mut env).unwrap());

        // A bound variable must not unify with a different literal.
        let other = Expression::Literal(Value::Number(100.0));
        let outcome = Unifier::unify(&x, &other, &mut env);
        assert!(matches!(outcome, Err(_) | Ok(false)));
    }

    #[test]
    fn test_unify_two_literals() {
        let mut env = Bindings::new();
        let first = Expression::Literal(Value::Number(42.0));
        let twin = Expression::Literal(Value::Number(42.0));
        let stranger = Expression::Literal(Value::Number(100.0));

        assert!(Unifier::unify(&first, &twin, &mut env).unwrap());
        assert!(!Unifier::unify(&first, &stranger, &mut env).unwrap());
    }

    #[test]
    fn test_match_expression_simple() {
        let facts = Facts::new();
        facts.set("User.IsVIP", Value::Boolean(true));

        let is_vip = Expression::Comparison {
            left: Box::new(Expression::Field(String::from("User.IsVIP"))),
            operator: Operator::Equal,
            right: Box::new(Expression::Literal(Value::Boolean(true))),
        };

        let mut env = Bindings::new();
        let matched = Unifier::match_expression(&is_vip, &facts, &mut env).unwrap();
        assert!(matched);
    }

    #[test]
    fn test_evaluate_with_bindings() {
        let facts = Facts::new();
        facts.set("Order.Amount", Value::Number(100.0));

        let mut env = Bindings::new();
        env.bind(String::from("X"), Value::Number(50.0)).unwrap();

        // (expression, expected value) pairs: one variable, one field.
        let cases = [
            (Expression::Variable(String::from("X")), Value::Number(50.0)),
            (
                Expression::Field(String::from("Order.Amount")),
                Value::Number(100.0),
            ),
        ];
        for (expr, expected) in &cases {
            let actual = Unifier::evaluate_with_bindings(expr, &facts, &env).unwrap();
            assert_eq!(&actual, expected);
        }
    }

    #[test]
    fn test_compare_values() {
        let cmp = |a, b| Unifier::compare_values(&Value::Number(a), &Value::Number(b)).unwrap();

        assert_eq!(cmp(10.0, 20.0), -1);
        assert_eq!(cmp(20.0, 10.0), 1);
        assert_eq!(cmp(10.0, 10.0), 0);
    }
}