1use rpds::HashTrieMap;
2use std::{
3 cell::RefCell,
4 collections::{BTreeMap, BTreeSet},
5 fmt,
6 hash::Hash,
7};
8
9pub use crate::traits::{ClassifyTerm, DirectChildren, TermKind};
10
11mod traits;
12
13#[cfg(test)]
14mod unit_tests;
15
/// Root of one equivalence class in the union-find forest.
///
/// In both variants the `usize` is the class size that `unify` consults when
/// deciding which of two unbound classes is hung under the other.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
enum Root<Term> {
    /// The class is bound to this non-variable term.
    NonVar(Term, usize),
    /// The class is unbound; its representative is the variable storing this root.
    Var(usize),
}
21
22#[derive(Debug, Clone, PartialEq, Eq, Hash)]
23enum Node<Var, Term> {
24 Root(Root<Term>),
25 Child(Var),
26}
27
28#[derive(Clone)]
29pub struct UnifierSet<Var, Term>
30where
31 Var: Eq + Hash,
32{
33 map: RefCell<HashTrieMap<Var, Node<Var, Term>>>,
37}
38
39impl<Var, Term> Default for UnifierSet<Var, Term>
40where
41 Var: Eq + Hash,
42{
43 fn default() -> Self {
44 Self {
45 map: RefCell::new(HashTrieMap::new()),
46 }
47 }
48}
49
50impl<Var, Term> fmt::Debug for UnifierSet<Var, Term>
51where
52 Var: Clone + Eq + Hash + Into<Term>,
53 Term: Clone + Eq + Hash + ClassifyTerm<Var> + DirectChildren<Var>,
54 Var: Ord + fmt::Debug,
55 Term: Ord + fmt::Debug,
56{
57 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
59 let forest = self.reified_forest();
60
61 write!(f, "{{")?;
62 for (term, equiv_set) in forest {
63 f.debug_set().entry(&term).entries(equiv_set).finish()?;
64 }
65 write!(f, "}}")?;
66
67 Ok(())
68 }
69}
70
71impl<Var, Term> From<HashTrieMap<Var, Node<Var, Term>>> for UnifierSet<Var, Term>
72where
73 Var: Eq + Hash,
74{
75 fn from(map: HashTrieMap<Var, Node<Var, Term>>) -> Self {
76 Self {
77 map: RefCell::new(map),
78 }
79 }
80}
81
82impl<Var, Term> UnifierSet<Var, Term>
83where
84 Var: Clone + Eq + Hash + Into<Term>,
85 Term: Clone + Eq + Hash + ClassifyTerm<Var> + DirectChildren<Var>,
86{
87 pub fn new() -> Self {
88 HashTrieMap::new().into()
89 }
90
91 pub fn is_empty(&self) -> bool {
92 self.map.borrow().is_empty()
93 }
94
95 #[track_caller]
96 fn insert(&self, var: Var, node: Node<Var, Term>) -> Self {
97 self.map.borrow().insert(var, node).into()
98 }
99
100 #[track_caller]
101 fn hidden_update(&self, var: Var, node: Node<Var, Term>) {
102 self.map.borrow_mut().insert_mut(var, node);
103 }
104
105 pub fn unify(&self, x: &Term, y: &Term) -> Option<Self> {
106 use TermKind::*;
107 match (x.classify_term(), y.classify_term()) {
108 (NonVar, NonVar) => self.unify_non_vars(x, y),
110
111 (Var(x), NonVar) => match self.find_root_and_root_child(x) {
113 (Root::NonVar(x, _size), _root_child) => self.unify(&x, y),
116
117 (Root::Var(x_root_size), root_child) => {
119 let new_root = Root::NonVar(y.clone(), x_root_size);
121 let new_node = Node::Root(new_root);
122 Some(self.insert(root_child, new_node))
123 }
124 },
125
126 (NonVar, Var(y)) => self.unify(&y.clone().into(), x),
128
129 (Var(x), Var(y)) => {
131 let (x, x_root_size) = self.find_root_term_and_size(x);
133 let (y, y_root_size) = self.find_root_term_and_size(y);
134
135 match (x.classify_term(), y.classify_term()) {
137 (Var(_), NonVar) | (NonVar, Var(_)) | (NonVar, NonVar) => self.unify(&x, &y),
140
141 (Var(x), Var(y)) => {
143 let (small, large) = if x_root_size <= y_root_size {
145 (x, y)
146 } else {
147 (y, x)
148 };
149
150 let new_size = x_root_size + y_root_size;
151
152 let new_self = self
161 .insert(small.clone(), Node::Child(large.clone()))
162 .insert(large.clone(), Node::Root(Root::Var(new_size)));
163
164 Some(new_self)
165 }
166 }
167 }
168 }
169 }
170
171 fn unify_non_vars(&self, x: &Term, y: &Term) -> Option<Self> {
172 debug_assert!(x.is_non_var() && y.is_non_var());
173
174 if !x.superficially_unifiable(y) {
176 return None;
177 }
178
179 let mut u = self.clone();
180
181 for (x_child, y_child) in x.direct_children().zip(y.direct_children()) {
182 u = u.unify(x_child, y_child)?;
183 }
184
185 Some(u)
186 }
187
188 fn find(&self, var: &Var) -> Term {
189 let (root_term, _) = self.find_root_term_and_size(var);
190 root_term
191 }
192
193 fn find_root_term_and_size(&self, var: &Var) -> (Term, usize) {
194 match self.find_root_and_root_child(var) {
195 (Root::NonVar(root_term, size), _) => (root_term, size),
196 (Root::Var(size), root_child) => (root_child.into(), size),
197 }
198 }
199
200 #[track_caller]
201 fn get_associated(&self, var: &Var) -> Option<Node<Var, Term>> {
202 self.map.borrow().get(var).cloned()
203 }
204
205 fn find_root_and_root_child(&self, var: &Var) -> (Root<Term>, Var) {
206 match self.get_associated(var) {
209 None => {
210 let root = Root::Var(1);
212 let node = Node::Root(root.clone());
213 self.hidden_update(var.clone(), node);
214 (root, var.clone())
215 }
216 Some(Node::Root(root)) => (root.clone(), var.clone()),
217 Some(Node::Child(parent_var)) => {
218 let (root, root_child) = self.find_root_and_root_child(&parent_var);
219
220 let var_parent_node = Node::Child(root_child.clone());
223 self.hidden_update(var.clone(), var_parent_node);
224
225 (root, root_child)
226 }
227 }
228 }
229
230 pub fn reify_term(&self, term: &Term) -> Term {
231 match term.classify_term() {
232 TermKind::Var(var) => {
234 let root_term = self.find(var);
235
236 if &root_term == term {
238 term.clone()
239 } else {
240 self.reify_term(&root_term)
241 }
242 }
243
244 TermKind::NonVar => term.map_direct_children(|child| self.reify_term(child)),
246 }
247 }
248}
249
250impl<Var, Term> UnifierSet<Var, Term>
251where
252 Var: Clone + Eq + Hash + Into<Term>,
253 Term: Clone + Eq + Hash + ClassifyTerm<Var> + DirectChildren<Var>,
254 Var: Ord,
255 Term: Ord,
256{
257 pub fn reified_forest(&self) -> BTreeMap<Term, BTreeSet<Var>> {
258 let mut sets = BTreeMap::new();
259
260 let map = self.map.borrow().clone();
262 for var in map.keys() {
263 let root_term = self.find(var);
264 let reified_root_term = self.reify_term(&root_term);
265
266 let entry = sets
267 .entry(reified_root_term.clone())
268 .or_insert_with(BTreeSet::new);
269
270 if var.clone().into() != reified_root_term {
272 entry.insert(var.clone());
273 }
274 }
275
276 sets
277 }
278}