1use crate::error::IrError;
44use crate::term::Term;
45use serde::{Deserialize, Serialize};
46use std::collections::HashMap;
47
/// A substitution: a finite mapping from variable names to the terms that
/// replace them.
///
/// Variables that have no entry are left unchanged when the substitution is
/// applied.
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
pub struct Substitution {
    /// Variable name -> replacement term.
    bindings: HashMap<String, Term>,
}
57
58impl Substitution {
59 pub fn empty() -> Self {
61 Substitution {
62 bindings: HashMap::new(),
63 }
64 }
65
66 pub fn singleton(var: String, term: Term) -> Self {
68 let mut bindings = HashMap::new();
69 bindings.insert(var, term);
70 Substitution { bindings }
71 }
72
73 pub fn from_map(bindings: HashMap<String, Term>) -> Self {
75 Substitution { bindings }
76 }
77
78 pub fn is_empty(&self) -> bool {
80 self.bindings.is_empty()
81 }
82
83 pub fn len(&self) -> usize {
85 self.bindings.len()
86 }
87
88 pub fn get(&self, var: &str) -> Option<&Term> {
90 self.bindings.get(var)
91 }
92
93 pub fn bind(&mut self, var: String, term: Term) {
95 self.bindings.insert(var, term);
96 }
97
98 pub fn apply(&self, term: &Term) -> Term {
103 match term {
104 Term::Var(name) => {
105 self.bindings
108 .get(name)
109 .cloned()
110 .unwrap_or_else(|| term.clone())
111 }
112 Term::Const(_) => term.clone(),
113 Term::Typed {
114 value,
115 type_annotation,
116 } => Term::Typed {
117 value: Box::new(self.apply(value)),
118 type_annotation: type_annotation.clone(),
119 },
120 }
121 }
122
123 pub fn compose(&self, other: &Substitution) -> Substitution {
127 let mut result = HashMap::new();
128
129 for (var, term) in &other.bindings {
131 result.insert(var.clone(), self.apply(term));
132 }
133
134 for (var, term) in &self.bindings {
136 if !result.contains_key(var) {
137 result.insert(var.clone(), term.clone());
138 }
139 }
140
141 Substitution::from_map(result)
142 }
143
144 pub fn domain(&self) -> Vec<String> {
146 self.bindings.keys().cloned().collect()
147 }
148
149 pub fn range(&self) -> Vec<Term> {
151 self.bindings.values().cloned().collect()
152 }
153
154 pub fn extend(&mut self, var: String, term: Term) -> Result<(), IrError> {
158 if let Some(existing) = self.bindings.get(&var) {
159 if existing != &term {
160 return Err(IrError::UnificationFailure {
161 type1: format!("{:?}", existing),
162 type2: format!("{:?}", term),
163 });
164 }
165 }
166 self.bindings.insert(var, term);
167 Ok(())
168 }
169
170 pub fn try_extend(&mut self, other: &Substitution) -> Result<(), IrError> {
175 for (var, term) in &other.bindings {
176 if let Some(existing) = self.bindings.get(var) {
177 if existing != term {
178 return Err(IrError::UnificationFailure {
179 type1: format!("{:?}", existing),
180 type2: format!("{:?}", term),
181 });
182 }
183 } else {
184 self.bindings.insert(var.clone(), term.clone());
185 }
186 }
187 Ok(())
188 }
189}
190
191fn occurs_check(var: &str, term: &Term) -> bool {
196 match term {
197 Term::Var(name) => name == var,
198 Term::Const(_) => false,
199 Term::Typed { value, .. } => occurs_check(var, value),
200 }
201}
202
203pub fn unify_terms(term1: &Term, term2: &Term) -> Result<Substitution, IrError> {
223 unify_impl(term1, term2, &mut Substitution::empty())
224}
225
226fn unify_impl(
228 term1: &Term,
229 term2: &Term,
230 subst: &mut Substitution,
231) -> Result<Substitution, IrError> {
232 let t1 = subst.apply(term1);
234 let t2 = subst.apply(term2);
235
236 match (&t1, &t2) {
237 (Term::Var(n1), Term::Var(n2)) if n1 == n2 => Ok(subst.clone()),
239
240 (Term::Var(name), _) => {
242 if occurs_check(name, &t2) {
243 return Err(IrError::UnificationFailure {
244 type1: format!("{:?}", t1),
245 type2: format!("{:?}", t2),
246 });
247 }
248 subst.bind(name.clone(), t2.clone());
249 Ok(subst.clone())
250 }
251
252 (_, Term::Var(name)) => {
254 if occurs_check(name, &t1) {
255 return Err(IrError::UnificationFailure {
256 type1: format!("{:?}", t1),
257 type2: format!("{:?}", t2),
258 });
259 }
260 subst.bind(name.clone(), t1.clone());
261 Ok(subst.clone())
262 }
263
264 (Term::Const(v1), Term::Const(v2)) => {
266 if v1 == v2 {
267 Ok(subst.clone())
268 } else {
269 Err(IrError::UnificationFailure {
270 type1: format!("{:?}", t1),
271 type2: format!("{:?}", t2),
272 })
273 }
274 }
275
276 (
278 Term::Typed {
279 value: inner1,
280 type_annotation: ty1,
281 },
282 Term::Typed {
283 value: inner2,
284 type_annotation: ty2,
285 },
286 ) => {
287 if ty1 != ty2 {
289 return Err(IrError::UnificationFailure {
290 type1: format!("{:?}", t1),
291 type2: format!("{:?}", t2),
292 });
293 }
294 unify_impl(inner1, inner2, subst)
295 }
296
297 (Term::Typed { value, .. }, other) | (other, Term::Typed { value, .. }) => {
299 unify_impl(value, other, subst)
300 }
301 }
302}
303
304pub fn unify_term_list(pairs: &[(Term, Term)]) -> Result<Substitution, IrError> {
322 let mut subst = Substitution::empty();
323 for (t1, t2) in pairs {
324 subst = unify_impl(t1, t2, &mut subst)?;
325 }
326 Ok(subst)
327}
328
329pub fn are_unifiable(term1: &Term, term2: &Term) -> bool {
331 unify_terms(term1, term2).is_ok()
332}
333
334pub fn rename_vars(term: &Term, suffix: &str) -> Term {
338 match term {
339 Term::Var(name) => Term::Var(format!("{}_{}", name, suffix)),
340 Term::Const(_) => term.clone(),
341 Term::Typed {
342 value,
343 type_annotation,
344 } => Term::Typed {
345 value: Box::new(rename_vars(value, suffix)),
346 type_annotation: type_annotation.clone(),
347 },
348 }
349}
350
351pub fn anti_unify_terms(term1: &Term, term2: &Term) -> (Term, Substitution, Substitution) {
381 let mut var_counter = 0;
382 let mut subst1 = Substitution::empty();
383 let mut subst2 = Substitution::empty();
384
385 let gen = anti_unify_impl(term1, term2, &mut var_counter, &mut subst1, &mut subst2);
386 (gen, subst1, subst2)
387}
388
389fn anti_unify_impl(
391 term1: &Term,
392 term2: &Term,
393 var_counter: &mut usize,
394 subst1: &mut Substitution,
395 subst2: &mut Substitution,
396) -> Term {
397 match (term1, term2) {
398 (Term::Const(c1), Term::Const(c2)) if c1 == c2 => term1.clone(),
400
401 (Term::Var(v1), Term::Var(v2)) if v1 == v2 => term1.clone(),
403
404 (
406 Term::Typed {
407 value: inner1,
408 type_annotation: ty1,
409 },
410 Term::Typed {
411 value: inner2,
412 type_annotation: ty2,
413 },
414 ) if ty1 == ty2 => {
415 let inner_gen = anti_unify_impl(inner1, inner2, var_counter, subst1, subst2);
416 Term::Typed {
417 value: Box::new(inner_gen),
418 type_annotation: ty1.clone(),
419 }
420 }
421
422 _ => {
424 *var_counter += 1;
425 let fresh_var = Term::Var(format!("_G{}", var_counter));
426
427 subst1.bind(format!("_G{}", var_counter), term1.clone());
429 subst2.bind(format!("_G{}", var_counter), term2.clone());
430
431 fresh_var
432 }
433 }
434}
435
436pub fn lgg_terms(terms: &[Term]) -> (Term, Vec<Substitution>) {
456 if terms.is_empty() {
457 return (Term::Var("_Empty".to_string()), vec![]);
458 }
459
460 if terms.len() == 1 {
461 return (terms[0].clone(), vec![Substitution::empty()]);
462 }
463
464 let (mut gen, subst1, subst2) = anti_unify_terms(&terms[0], &terms[1]);
466 let mut substs = vec![subst1, subst2];
467
468 for term in &terms[2..] {
470 let (new_gen, gen_subst, term_subst) = anti_unify_terms(&gen, term);
471 gen = new_gen;
472
473 for s in &mut substs {
475 *s = gen_subst.compose(s);
476 }
477
478 substs.push(term_subst);
480 }
481
482 (gen, substs)
483}
484
/// Unit tests for substitutions, unification, anti-unification and
/// generalization over lists of terms.
#[cfg(test)]
mod tests {
    use super::*;

    /// The empty substitution has no bindings and leaves terms untouched.
    #[test]
    fn test_empty_substitution() {
        let subst = Substitution::empty();
        assert!(subst.is_empty());
        assert_eq!(subst.len(), 0);

        let term = Term::var("x");
        assert_eq!(subst.apply(&term), term);
    }

    /// A singleton substitution replaces exactly its one variable.
    #[test]
    fn test_singleton_substitution() {
        let subst = Substitution::singleton("x".to_string(), Term::constant("a"));
        assert_eq!(subst.len(), 1);

        let x = Term::var("x");
        let a = Term::constant("a");
        assert_eq!(subst.apply(&x), a);
    }

    /// Bound variables are replaced; unbound variables are returned as-is.
    #[test]
    fn test_substitution_application() {
        let mut subst = Substitution::empty();
        subst.bind("x".to_string(), Term::constant("a"));
        subst.bind("y".to_string(), Term::constant("b"));

        let x = Term::var("x");
        let y = Term::var("y");
        let z = Term::var("z");

        assert_eq!(subst.apply(&x), Term::constant("a"));
        assert_eq!(subst.apply(&y), Term::constant("b"));
        assert_eq!(subst.apply(&z), z);
    }

    /// A variable unifies with a constant by being bound to it.
    #[test]
    fn test_unify_var_constant() {
        let x = Term::var("x");
        let a = Term::constant("a");

        let mgu = unify_terms(&x, &a).unwrap();
        assert_eq!(mgu.apply(&x), a);
    }

    /// Unifying a variable with itself needs no binding.
    #[test]
    fn test_unify_same_variable() {
        let x = Term::var("x");
        let mgu = unify_terms(&x, &x).unwrap();
        assert!(mgu.is_empty());
    }

    /// Distinct constants do not unify.
    #[test]
    fn test_unify_different_constants() {
        let a = Term::constant("a");
        let b = Term::constant("b");

        let result = unify_terms(&a, &b);
        assert!(result.is_err());
    }

    /// Equal constants unify without any binding.
    #[test]
    fn test_unify_same_constant() {
        let a = Term::constant("a");
        let mgu = unify_terms(&a, &a).unwrap();
        assert!(mgu.is_empty());
    }

    /// The occurs check finds a variable only where it really occurs.
    #[test]
    fn test_occur_check() {
        let x = Term::var("x");
        assert!(occurs_check("x", &x));
        assert!(!occurs_check("y", &x));

        let a = Term::constant("a");
        assert!(!occurs_check("x", &a));
    }

    /// `sigma.compose(&theta)` behaves like applying `theta`, then `sigma`.
    #[test]
    fn test_substitution_composition() {
        let sigma = Substitution::singleton("x".to_string(), Term::constant("a"));
        let theta = Substitution::singleton("y".to_string(), Term::var("x"));

        let composed = sigma.compose(&theta);
        assert_eq!(composed.len(), 2);
        assert_eq!(composed.apply(&Term::var("x")), Term::constant("a"));
        assert_eq!(composed.apply(&Term::var("y")), Term::constant("a"));
    }

    /// Bindings found for earlier pairs are visible to later pairs.
    #[test]
    fn test_unify_term_list() {
        let pairs = vec![
            (Term::var("x"), Term::constant("a")),
            (Term::var("y"), Term::constant("b")),
            (Term::var("z"), Term::var("x")),
        ];

        let mgu = unify_term_list(&pairs).unwrap();
        assert_eq!(mgu.len(), 3);
        assert_eq!(mgu.apply(&Term::var("x")), Term::constant("a"));
        assert_eq!(mgu.apply(&Term::var("y")), Term::constant("b"));
        assert_eq!(mgu.apply(&Term::var("z")), Term::constant("a"));
    }

    /// `are_unifiable` agrees with the success of `unify_terms`.
    #[test]
    fn test_are_unifiable() {
        let x = Term::var("x");
        let a = Term::constant("a");
        let b = Term::constant("b");

        assert!(are_unifiable(&x, &a));
        assert!(are_unifiable(&a, &a));
        assert!(!are_unifiable(&a, &b));
    }

    /// Renaming suffixes variables and leaves constants alone.
    #[test]
    fn test_rename_vars() {
        let x = Term::var("x");
        let renamed = rename_vars(&x, "1");
        assert_eq!(renamed, Term::var("x_1"));

        let a = Term::constant("a");
        let renamed_const = rename_vars(&a, "1");
        assert_eq!(renamed_const, a);
    }

    /// `extend` accepts new and identical bindings and rejects conflicts
    /// without disturbing what is already bound.
    #[test]
    fn test_extend_substitution() {
        let mut subst = Substitution::empty();
        assert!(subst.extend("x".to_string(), Term::constant("a")).is_ok());
        assert!(subst.extend("y".to_string(), Term::constant("b")).is_ok());

        // Re-adding the same binding is fine.
        assert!(subst.extend("x".to_string(), Term::constant("a")).is_ok());

        // A different term for an already bound variable is a conflict.
        assert!(subst.extend("x".to_string(), Term::constant("b")).is_err());

        assert_eq!(subst.len(), 2);
        assert_eq!(subst.get("x"), Some(&Term::constant("a")));
        assert_eq!(subst.get("y"), Some(&Term::constant("b")));
    }

    /// Typed terms with equal annotations unify through their inner values.
    #[test]
    fn test_typed_term_unification() {
        use crate::term::TypeAnnotation;

        let x = Term::Typed {
            value: Box::new(Term::var("x")),
            type_annotation: TypeAnnotation::new("Int"),
        };
        let a = Term::Typed {
            value: Box::new(Term::constant("5")),
            type_annotation: TypeAnnotation::new("Int"),
        };

        let mgu = unify_terms(&x, &a).unwrap();
        assert_eq!(mgu.len(), 1);
        assert_eq!(mgu.apply(&Term::var("x")), Term::constant("5"));
    }

    /// Equal constants generalize to themselves with no bindings.
    #[test]
    fn test_anti_unify_same_constant() {
        let a1 = Term::constant("a");
        let a2 = Term::constant("a");

        let (generalized, subst1, subst2) = anti_unify_terms(&a1, &a2);

        assert_eq!(generalized, a1);
        assert!(subst1.is_empty());
        assert!(subst2.is_empty());
    }

    /// Different constants generalize to a fresh variable that each
    /// substitution maps back to its own constant.
    #[test]
    fn test_anti_unify_different_constants() {
        let a = Term::constant("a");
        let b = Term::constant("b");

        let (generalized, subst1, subst2) = anti_unify_terms(&a, &b);

        match &generalized {
            Term::Var(name) => assert!(name.starts_with("_G")),
            _ => panic!("Expected fresh variable"),
        }

        assert_eq!(subst1.len(), 1);
        assert_eq!(subst2.len(), 1);
        assert_eq!(subst1.apply(&generalized), a);
        assert_eq!(subst2.apply(&generalized), b);
    }

    /// A variable and a constant generalize to a fresh variable; the
    /// substitutions recover both originals.
    #[test]
    fn test_anti_unify_variable_constant() {
        let x = Term::var("x");
        let a = Term::constant("a");

        let (generalized, subst1, subst2) = anti_unify_terms(&x, &a);

        // Fail loudly when the result is not a variable at all.
        match &generalized {
            Term::Var(name) => assert!(name.starts_with("_G")),
            _ => panic!("Expected fresh variable"),
        }

        assert_eq!(subst1.apply(&generalized), x);
        assert_eq!(subst2.apply(&generalized), a);
    }

    /// Identical variables generalize to themselves with no bindings.
    #[test]
    fn test_anti_unify_same_variable() {
        let x1 = Term::var("x");
        let x2 = Term::var("x");

        let (generalized, subst1, subst2) = anti_unify_terms(&x1, &x2);

        assert_eq!(generalized, x1);
        assert!(subst1.is_empty());
        assert!(subst2.is_empty());
    }

    /// Typed terms with the same annotation keep it; only the differing
    /// inner values are abstracted.
    #[test]
    fn test_anti_unify_typed_terms() {
        use crate::term::TypeAnnotation;

        let t1 = Term::Typed {
            value: Box::new(Term::constant("5")),
            type_annotation: TypeAnnotation::new("Int"),
        };
        let t2 = Term::Typed {
            value: Box::new(Term::constant("10")),
            type_annotation: TypeAnnotation::new("Int"),
        };

        let (generalized, subst1, subst2) = anti_unify_terms(&t1, &t2);

        match &generalized {
            Term::Typed {
                value,
                type_annotation,
            } => {
                assert_eq!(type_annotation.type_name, "Int");
                match &**value {
                    Term::Var(name) => assert!(name.starts_with("_G")),
                    _ => panic!("Expected fresh variable inside typed term"),
                }
            }
            _ => panic!("Expected typed term"),
        }

        assert_eq!(subst1.apply(&generalized), t1);
        assert_eq!(subst2.apply(&generalized), t2);
    }

    /// A single term is its own generalization.
    #[test]
    fn test_lgg_single_term() {
        let terms = vec![Term::constant("a")];
        let (generalized, substs) = lgg_terms(&terms);

        assert_eq!(generalized, Term::constant("a"));
        assert_eq!(substs.len(), 1);
        assert!(substs[0].is_empty());
    }

    /// Two equal terms generalize to that term.
    #[test]
    fn test_lgg_two_same_terms() {
        let terms = vec![Term::constant("a"), Term::constant("a")];
        let (generalized, substs) = lgg_terms(&terms);

        assert_eq!(generalized, Term::constant("a"));
        assert_eq!(substs.len(), 2);
    }

    /// Two different constants generalize to a fresh variable, and every
    /// substitution maps the generalization back to its input.
    #[test]
    fn test_lgg_two_different_terms() {
        let terms = vec![Term::constant("a"), Term::constant("b")];
        let (generalized, substs) = lgg_terms(&terms);

        match &generalized {
            Term::Var(name) => assert!(name.starts_with("_G")),
            _ => panic!("Expected fresh variable"),
        }

        assert_eq!(substs.len(), 2);
        for (subst, term) in substs.iter().zip(&terms) {
            assert_eq!(&subst.apply(&generalized), term);
        }
    }

    /// Three different constants still produce one substitution per input,
    /// each recovering its own constant.
    #[test]
    fn test_lgg_three_terms() {
        let terms = vec![
            Term::constant("a"),
            Term::constant("b"),
            Term::constant("c"),
        ];
        let (generalized, substs) = lgg_terms(&terms);

        match &generalized {
            Term::Var(name) => assert!(name.starts_with("_G")),
            _ => panic!("Expected fresh variable"),
        }

        assert_eq!(substs.len(), 3);
        for (subst, term) in substs.iter().zip(&terms) {
            assert_eq!(&subst.apply(&generalized), term);
        }
    }

    /// No input terms yield the `_Empty` placeholder and no substitutions.
    #[test]
    fn test_lgg_empty() {
        let terms: Vec<Term> = vec![];
        let (generalized, substs) = lgg_terms(&terms);

        match generalized {
            Term::Var(name) => assert_eq!(name, "_Empty"),
            _ => panic!("Expected _Empty variable"),
        }

        assert_eq!(substs.len(), 0);
    }

    /// Anti-unifying equal terms returns the term itself.
    #[test]
    fn test_anti_unify_preserves_structure() {
        let a1 = Term::constant("a");
        let a2 = Term::constant("a");

        let (generalized, _, _) = anti_unify_terms(&a1, &a2);

        assert_eq!(generalized, Term::constant("a"));
    }
}