Skip to main content

slotted_egraphs/egraph/
find.rs

1use crate::*;
2
3impl<L: Language, N: Analysis<L>> EGraph<L, N> {
4    fn unionfind_get_impl(&self, i: Id, map: &mut [ProvenAppliedId]) -> ProvenAppliedId {
5        let entry = &mut map[i.0];
6
7        if entry.elem.id == i {
8            return entry.clone();
9        }
10
11        let entry = entry.clone();
12
13        // entry.0.m :: slots(entry.0.id) -> slots(i)
14        // entry_to_leader.0.m :: slots(leader) -> slots(entry.0.id)
15        let entry_to_leader = self.unionfind_get_impl(entry.elem.id, map);
16        let new = self.chain_pai(&entry, &entry_to_leader);
17
18        map[i.0] = new.clone();
19        new
20    }
21
22    pub(crate) fn unionfind_set(&self, i: Id, pai: ProvenAppliedId) {
23        #[cfg(feature = "explanations")]
24        if CHECKS {
25            pai.proof.check(self);
26            assert_eq!(i, pai.proof.l.id);
27            assert_eq!(pai.elem.id, pai.proof.r.id);
28        }
29
30        let mut lock = self.unionfind.borrow_mut();
31        if lock.len() == i.0 {
32            lock.push(pai);
33        } else {
34            lock[i.0] = pai;
35        }
36    }
37
38    pub(crate) fn proven_unionfind_get(&self, i: Id) -> ProvenAppliedId {
39        let mut map = self.unionfind.borrow_mut();
40        self.unionfind_get_impl(i, &mut *map)
41    }
42
43    pub(crate) fn unionfind_get(&self, i: Id) -> AppliedId {
44        self.proven_unionfind_get(i).elem
45    }
46
47    /// Returns whether an id is still alive, or whether it was merged into another class.
48    pub fn is_alive(&self, i: Id) -> bool {
49        let map = self.unionfind.borrow();
50        map[i.0].elem.id == i
51    }
52
53    pub(crate) fn unionfind_iter(&self) -> impl Iterator<Item = (Id, AppliedId)> {
54        let mut map = self.unionfind.borrow_mut();
55        let mut out = Vec::new();
56
57        for x in (0..map.len()).map(Id) {
58            let y = self.unionfind_get_impl(x, &mut *map).elem;
59            out.push((x, y));
60        }
61
62        out.into_iter()
63    }
64
65    pub(crate) fn unionfind_len(&self) -> usize {
66        self.unionfind.borrow().len()
67    }
68
69    pub(crate) fn find_enode(&self, enode: &L) -> L {
70        self.proven_find_enode(enode).elem
71    }
72
73    pub(crate) fn proven_find_enode(&self, enode: &L) -> ProvenNode<L> {
74        let pn = self.refl_pn(enode);
75        self.proven_proven_find_enode(&pn)
76    }
77
78    pub(crate) fn proven_proven_find_enode(&self, enode: &ProvenNode<L>) -> ProvenNode<L> {
79        self.chain_pn_map(enode, |_, pai| self.proven_proven_find_applied_id(&pai))
80    }
81
82    // normalize i.id
83    //
84    // Example 1:
85    // 'find(c1(s10, s11)) = c2(s11, s10)', where 'c1(s0, s1) -> c2(s1, s0)' in unionfind.
86    //
87    // Example 2:
88    // 'find(c1(s3, s7, s8)) = c2(s8, s7)', where 'c1(s0, s1, s2) -> c2(s2, s1)' in unionfind,
89    pub fn find_applied_id(&self, i: &AppliedId) -> AppliedId {
90        #[cfg(feature = "explanations")]
91        let i = &self.synify_app_id(i.clone());
92
93        self.proven_find_applied_id(i).elem
94    }
95
96    pub(crate) fn proven_find_applied_id(&self, i: &AppliedId) -> ProvenAppliedId {
97        let pai = self.refl_pai(i);
98        self.proven_proven_find_applied_id(&pai)
99    }
100
101    pub(crate) fn proven_proven_find_applied_id(&self, pai: &ProvenAppliedId) -> ProvenAppliedId {
102        if CHECKS {
103            self.check_pai(&pai);
104        }
105
106        let mut pai2 = self.proven_unionfind_get(pai.elem.id);
107
108        pai2.elem.m = pai2.elem.m.compose_partial(&pai.elem.m);
109
110        #[cfg(feature = "explanations")]
111        {
112            pai2.proof = prove_transitivity(pai.proof.clone(), pai2.proof, &self.proof_registry);
113        }
114
115        if CHECKS {
116            self.check_pai(&pai);
117        }
118
119        pai2
120    }
121
122    pub(crate) fn find_id(&self, i: Id) -> Id {
123        self.unionfind_get(i).id
124    }
125
126    pub fn ids(&self) -> Vec<Id> {
127        let map = self.unionfind.borrow();
128        (0..map.len())
129            .map(Id)
130            .filter(|x| map[x.0].elem.id == *x)
131            .collect()
132    }
133}