rust_unify/
robinsons.rs

1use std::collections::HashMap;
2use std;
3
4use Term;
5
6fn occurs_check<'a, T>(
7    x: &'a Term<T>,
8    t: &'a Term<T>,
9    subs: &mut HashMap<&'a Term<T>, &'a Term<T>>,
10) -> bool
11where
12    T: std::cmp::Eq + std::hash::Hash,
13{
14    let mut stack: Vec<&Term<T>> = vec![t];
15
16    fn get_vars<'a, T>(t: &'a Term<T>) -> Vec<&'a Term<T>> {
17        match t {
18            &Term::Variable(_) => vec![t],
19            &Term::Composite(_, ref terms) => {
20                let mut v = vec![];
21                for term in terms {
22                    v.append(&mut get_vars(&term));
23                }
24                v
25            }
26            _ => vec![],
27        }
28    }
29
30    while !stack.is_empty() {
31        let t = stack.pop();
32        for y in get_vars(t.unwrap()) {
33            if x == y {
34                return false;
35            } else if subs.contains_key(y) {
36                stack.push(subs[y]);
37            }
38        }
39    }
40
41    true
42}
43
44pub fn unify<'a, T>(s: &'a Term<T>, t: &'a Term<T>) -> Option<HashMap<&'a Term<T>, &'a Term<T>>>
45where
46    T: std::cmp::Eq + std::hash::Hash,
47{
48    let mut stack: Vec<(&Term<T>, &Term<T>)> = vec![(s, t)];
49    let mut subs: HashMap<&Term<T>, &Term<T>> = HashMap::new();
50
51    while !stack.is_empty() {
52        let (mut s, mut t) = stack.pop().unwrap();
53
54        while subs.contains_key(&s) {
55            s = subs.get(&s).unwrap()
56        }
57
58        while subs.contains_key(&t) {
59            t = subs.get(&t).unwrap()
60        }
61
62        if s != t {
63            match (s, t) {
64                (&Term::Variable(_), &Term::Variable(_)) => {
65                    subs.insert(s, t);
66                }
67                (&Term::Variable(_), _) => if occurs_check(s, t, &mut subs) {
68                    subs.insert(s, t);
69                } else {
70                    return None;
71                },
72                (_, &Term::Variable(_)) => if occurs_check(t, s, &mut subs) {
73                    subs.insert(t, s);
74                } else {
75                    return None;
76                },
77                (
78                    &Term::Composite(ref s_name, ref s_terms),
79                    &Term::Composite(ref t_name, ref t_terms),
80                ) => if s_name == t_name && s_terms.len() == t_terms.len() {
81                    for (s, t) in s_terms.iter().zip(t_terms) {
82                        stack.push((s, t));
83                    }
84                } else {
85                    return None;
86                },
87                (_, _) => return None,
88            }
89        }
90    }
91
92    Some(subs)
93}