unifier_set/
lib.rs

1use rpds::HashTrieMap;
2use std::{
3    cell::RefCell,
4    collections::{BTreeMap, BTreeSet},
5    fmt,
6    hash::Hash,
7};
8
9pub use crate::traits::{ClassifyTerm, DirectChildren, TermKind};
10
11mod traits;
12
13#[cfg(test)]
14mod unit_tests;
15
/// The data stored for the representative (root) of an equivalence class.
///
/// Both variants carry a `usize` that `unify` treats as the size of the class when
/// applying its weighting heuristic; a freshly registered variable starts at `1`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum Root<Term> {
    /// The class has been bound to a non-variable term; the `usize` is the class size.
    NonVar(Term, usize),
    /// The class is still an unbound variable; the `usize` is the class size.
    Var(usize),
}
21
/// The value a variable is associated with in the `UnifierSet` map.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum Node<Var, Term> {
    /// The variable is the representative of its class and carries the class's `Root`.
    Root(Root<Term>),
    /// The variable points at another variable, which is followed to reach the root.
    Child(Var),
}
27
/// A set of variable bindings maintained as a forest of equivalence classes.
///
/// Each variable maps to a `Node`: either the root data of its class or a pointer to
/// another variable in the same class. `unify` returns a new `UnifierSet` rather than
/// modifying the receiver.
#[derive(Clone)]
pub struct UnifierSet<Var, Term>
where
    Var: Eq + Hash,
{
    /// This persistent map is wrapped in a `RefCell` so that path compression can be
    /// performed behind the scenes. Path compression only speeds up access, it doesn't
    /// change externally observable behavior.
    ///
    /// Note that `find_root_and_root_child` also uses this hidden mutation to register
    /// variables it has not seen before, so a lookup can add keys to the map.
    map: RefCell<HashTrieMap<Var, Node<Var, Term>>>,
}
38
39impl<Var, Term> Default for UnifierSet<Var, Term>
40where
41    Var: Eq + Hash,
42{
43    fn default() -> Self {
44        Self {
45            map: RefCell::new(HashTrieMap::new()),
46        }
47    }
48}
49
50impl<Var, Term> fmt::Debug for UnifierSet<Var, Term>
51where
52    Var: Clone + Eq + Hash + Into<Term>,
53    Term: Clone + Eq + Hash + ClassifyTerm<Var> + DirectChildren<Var>,
54    Var: Ord + fmt::Debug,
55    Term: Ord + fmt::Debug,
56{
57    /// Outputs the `UnifierSet` as a forest (a set of sets).
58    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59        let forest = self.reified_forest();
60
61        write!(f, "{{")?;
62        for (term, equiv_set) in forest {
63            f.debug_set().entry(&term).entries(equiv_set).finish()?;
64        }
65        write!(f, "}}")?;
66
67        Ok(())
68    }
69}
70
71impl<Var, Term> From<HashTrieMap<Var, Node<Var, Term>>> for UnifierSet<Var, Term>
72where
73    Var: Eq + Hash,
74{
75    fn from(map: HashTrieMap<Var, Node<Var, Term>>) -> Self {
76        Self {
77            map: RefCell::new(map),
78        }
79    }
80}
81
82impl<Var, Term> UnifierSet<Var, Term>
83where
84    Var: Clone + Eq + Hash + Into<Term>,
85    Term: Clone + Eq + Hash + ClassifyTerm<Var> + DirectChildren<Var>,
86{
87    pub fn new() -> Self {
88        HashTrieMap::new().into()
89    }
90
91    pub fn is_empty(&self) -> bool {
92        self.map.borrow().is_empty()
93    }
94
95    #[track_caller]
96    fn insert(&self, var: Var, node: Node<Var, Term>) -> Self {
97        self.map.borrow().insert(var, node).into()
98    }
99
100    #[track_caller]
101    fn hidden_update(&self, var: Var, node: Node<Var, Term>) {
102        self.map.borrow_mut().insert_mut(var, node);
103    }
104
105    pub fn unify(&self, x: &Term, y: &Term) -> Option<Self> {
106        use TermKind::*;
107        match (x.classify_term(), y.classify_term()) {
108            // If they're both compound terms, let's defer to a helper function.
109            (NonVar, NonVar) => self.unify_non_vars(x, y),
110
111            // This is an easy case. Just join the var's root to the term's root!
112            (Var(x), NonVar) => match self.find_root_and_root_child(x) {
113                // We now know `x` and `y` are both non-var terms. Let's reuse the `unify`
114                // function! (We could also call `unify_non_vars`, but whatever.)
115                (Root::NonVar(x, _size), _root_child) => self.unify(&x, y),
116
117                // We're sure `x` is an unsolved variable, and `y` is a non-var.
118                (Root::Var(x_root_size), root_child) => {
119                    // Every var who pointed to `x` now points to the non-var `y`.
120                    let new_root = Root::NonVar(y.clone(), x_root_size);
121                    let new_node = Node::Root(new_root);
122                    Some(self.insert(root_child, new_node))
123                }
124            },
125
126            // Hmm this case looks familiar... Swap 'em and try again!
127            (NonVar, Var(y)) => self.unify(&y.clone().into(), x),
128
129            // At first glance, `x` and `y` both look like variables.
130            (Var(x), Var(y)) => {
131                // But what do their roots have to say on the matter?
132                let (x, x_root_size) = self.find_root_term_and_size(x);
133                let (y, y_root_size) = self.find_root_term_and_size(y);
134
135                // Lets classify their roots.
136                match (x.classify_term(), y.classify_term()) {
137                    // We already know how to unify when either `x` or `y` is a nonvar:
138                    // Use the `unify` function! Don't worry, we HAVE made progress!
139                    (Var(_), NonVar) | (NonVar, Var(_)) | (NonVar, NonVar) => self.unify(&x, &y),
140
141                    // Ok, `x` and `y` are FOR REAL unknowns.
142                    (Var(x), Var(y)) => {
143                        // Weighting heuristic.
144                        let (small, large) = if x_root_size <= y_root_size {
145                            (x, y)
146                        } else {
147                            (y, x)
148                        };
149
150                        let new_size = x_root_size + y_root_size;
151
152                        // Make the little guy (and all its siblings) point to the bigger
153                        // guy. This is good because if the little guy has a TINY family
154                        // and the big guy has a HUGE family, we wanna inconvenience
155                        // the fewest number of people.
156                        //
157                        // (Making `small` point to `large` adds one more step in the
158                        // lookup chain. Every time we do `find_root(small)`, it will take
159                        // a little longer now.)
160                        let new_self = self
161                            .insert(small.clone(), Node::Child(large.clone()))
162                            .insert(large.clone(), Node::Root(Root::Var(new_size)));
163
164                        Some(new_self)
165                    }
166                }
167            }
168        }
169    }
170
171    fn unify_non_vars(&self, x: &Term, y: &Term) -> Option<Self> {
172        debug_assert!(x.is_non_var() && y.is_non_var());
173
174        // No way to unify two terms which are this different.
175        if !x.superficially_unifiable(y) {
176            return None;
177        }
178
179        let mut u = self.clone();
180
181        for (x_child, y_child) in x.direct_children().zip(y.direct_children()) {
182            u = u.unify(x_child, y_child)?;
183        }
184
185        Some(u)
186    }
187
188    fn find(&self, var: &Var) -> Term {
189        let (root_term, _) = self.find_root_term_and_size(var);
190        root_term
191    }
192
193    fn find_root_term_and_size(&self, var: &Var) -> (Term, usize) {
194        match self.find_root_and_root_child(var) {
195            (Root::NonVar(root_term, size), _) => (root_term, size),
196            (Root::Var(size), root_child) => (root_child.into(), size),
197        }
198    }
199
200    #[track_caller]
201    fn get_associated(&self, var: &Var) -> Option<Node<Var, Term>> {
202        self.map.borrow().get(var).cloned()
203    }
204
205    fn find_root_and_root_child(&self, var: &Var) -> (Root<Term>, Var) {
206        // NOTE: This is a load-bearing "extract to function" situation. Inlining the call
207        // to `self.get_associated` causes a `RefCell` `BorrowMutError`.
208        match self.get_associated(var) {
209            None => {
210                // Var has not been registered in the map yet, so put it in.
211                let root = Root::Var(1);
212                let node = Node::Root(root.clone());
213                self.hidden_update(var.clone(), node);
214                (root, var.clone())
215            }
216            Some(Node::Root(root)) => (root.clone(), var.clone()),
217            Some(Node::Child(parent_var)) => {
218                let (root, root_child) = self.find_root_and_root_child(&parent_var);
219
220                // Path compression heuristic:
221                // Point `var` directly to the root's first child var.
222                let var_parent_node = Node::Child(root_child.clone());
223                self.hidden_update(var.clone(), var_parent_node);
224
225                (root, root_child)
226            }
227        }
228    }
229
230    pub fn reify_term(&self, term: &Term) -> Term {
231        match term.classify_term() {
232            // If the term is a variable, return the reification of what it maps to.
233            TermKind::Var(var) => {
234                let root_term = self.find(var);
235
236                // An unbound variable will `find` itself. Don't reify it again.
237                if &root_term == term {
238                    term.clone()
239                } else {
240                    self.reify_term(&root_term)
241                }
242            }
243
244            // Otherwise, reify all direct children.
245            TermKind::NonVar => term.map_direct_children(|child| self.reify_term(child)),
246        }
247    }
248}
249
250impl<Var, Term> UnifierSet<Var, Term>
251where
252    Var: Clone + Eq + Hash + Into<Term>,
253    Term: Clone + Eq + Hash + ClassifyTerm<Var> + DirectChildren<Var>,
254    Var: Ord,
255    Term: Ord,
256{
257    pub fn reified_forest(&self) -> BTreeMap<Term, BTreeSet<Var>> {
258        let mut sets = BTreeMap::new();
259
260        // Clone here so borrow of `RefCell` can be dropped immediately.
261        let map = self.map.borrow().clone();
262        for var in map.keys() {
263            let root_term = self.find(var);
264            let reified_root_term = self.reify_term(&root_term);
265
266            let entry = sets
267                .entry(reified_root_term.clone())
268                .or_insert_with(BTreeSet::new);
269
270            // If `reified_root_term` is `var`, we don't need to include it.
271            if var.clone().into() != reified_root_term {
272                entry.insert(var.clone());
273            }
274        }
275
276        sets
277    }
278}