1use crate::*;
2
3mod find;
4pub use find::*;
5
6mod add;
7pub use add::*;
8
9mod union;
10pub use union::*;
11
12mod rebuild;
13pub use rebuild::*;
14
15mod check;
16pub use check::*;
17
18mod analysis;
19pub use analysis::*;
20use vec_collections::AbstractVecSet;
21
22use std::cell::RefCell;
23
/// An e-graph over the e-node language `L` with analysis `N` (defaults to `()`).
///
/// E-classes carry a set of slots and a group of slot permutations under which
/// they are symmetric (see `EClass`).
pub struct EGraph<L: Language, N: Analysis<L> = ()> {
    /// Union-find state. Kept in a `RefCell` so it can be mutated through
    /// `&self`.
    /// NOTE(review): presumably indexed by `Id`, each entry pointing to the
    /// (proven) representative, so that `find_id(&self)` can update it — confirm.
    unionfind: RefCell<Vec<ProvenAppliedId>>,

    /// All e-classes keyed by id. Entries whose `nodes` map is empty can exist
    /// (`dump` skips them); presumably these are classes that were merged away.
    pub(crate) classes: HashMap<Id, EClass<L, N>>,

    /// Maps each stored e-node to the id of its class. Its length is what
    /// `total_number_of_nodes` reports.
    /// NOTE(review): keys are presumably e-node shapes — confirm.
    hashcons: HashMap<L, Id>,

    /// Hash-cons for syntactic e-nodes, mapping each to an `AppliedId`.
    syn_hashcons: HashMap<L, AppliedId>,

    /// E-nodes awaiting further processing, each tagged with a `PendingType`.
    /// Two requests for the same node can be combined with `PendingType::merge`.
    /// NOTE(review): presumably drained by the rebuild code — confirm.
    pending: HashMap<L, PendingType>,

    /// Proof registry.
    /// NOTE(review): presumably backs the explanation machinery — confirm.
    pub(crate) proof_registry: ProofRegistry,

    /// Substitution strategy. `with_subst_method` sets it to `Some`.
    /// NOTE(review): the `Option` presumably lets it be taken out temporarily
    /// while the e-graph is borrowed mutably — confirm.
    pub(crate) subst_method: Option<Box<dyn SubstMethod<L, N>>>,

    /// The user-supplied analysis.
    pub analysis: N,

    /// Queue of e-class ids.
    /// NOTE(review): presumably the classes whose analysis `modify` hook still
    /// has to run — confirm.
    modify_queue: Vec<Id>,
}
66
/// How much processing a pending e-node still needs.
///
/// `Full` is the stronger request: `PendingType::merge` yields `Full` unless
/// both sides are `OnlyAnalysis`.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub(crate) enum PendingType {
    /// Weaker request.
    /// NOTE(review): presumably only the analysis data has to be recomputed — confirm.
    OnlyAnalysis,
    /// Stronger request.
    /// NOTE(review): presumably the e-node needs full reprocessing — confirm.
    Full,
}
72
/// One e-class: its e-nodes, its slots, its slot-symmetry group, a syntactic
/// representative, and its analysis data.
#[derive(Clone)]
pub(crate) struct EClass<L: Language, N: Analysis<L>> {
    /// The class's e-nodes. `EGraph::enodes` recovers each e-node by applying
    /// the stored `ProvenSourceNode`'s `elem` slot map to the key.
    /// NOTE(review): keys are presumably e-node shapes — confirm.
    nodes: HashMap<L, ProvenSourceNode>,

    /// The class's slots. `EGraph::slots` returns a copy of this set.
    slots: SmallHashSet<Slot>,

    /// E-nodes recorded as usages of this class. `EGraph::usages` looks each
    /// one up and finds it as a key of the owning class's `nodes`.
    /// NOTE(review): presumably the e-nodes that reference this class — confirm.
    usages: HashSet<L>,

    /// Group of slot permutations under which the class is symmetric.
    /// `EGraph::eq` tests membership in it.
    pub(crate) group: Group<ProvenPerm>,

    /// Syntactic representative e-node. `EGraph::get_syn_node` renames it
    /// through an `AppliedId`'s slot map.
    syn_enode: L,

    /// Analysis data attached to this class.
    analysis_data: N::Data,
}
97
98impl<L: Language, N: Analysis<L> + Default> Default for EGraph<L, N> {
99    fn default() -> Self {
100        EGraph::new(N::default())
101    }
102}
103
104impl<L: Language, N: Analysis<L>> EGraph<L, N> {
105    pub fn new(analysis: N) -> Self {
107        Self::with_subst_method::<SynExprSubst>(analysis)
108    }
109
110    pub fn with_subst_method<S: SubstMethod<L, N>>(analysis: N) -> Self {
112        EGraph {
113            unionfind: Default::default(),
114            classes: Default::default(),
115            hashcons: Default::default(),
116            syn_hashcons: Default::default(),
117            pending: Default::default(),
118            proof_registry: ProofRegistry::default(),
119            subst_method: Some(S::new_boxed()),
120            analysis,
121            modify_queue: Vec::new(),
122        }
123    }
124
125    pub fn slots(&self, id: Id) -> SmallHashSet<Slot> {
126        self.classes[&id].slots.clone()
127    }
128
129    pub(crate) fn syn_slots(&self, id: Id) -> SmallHashSet<Slot> {
130        self.classes[&id].syn_enode.slots()
131    }
132
133    pub fn analysis_data(&self, i: Id) -> &N::Data {
134        &self.classes[&self.find_id(i)].analysis_data
135    }
136
137    pub fn analysis_data_mut(&mut self, i: Id) -> &mut N::Data {
138        &mut self
139            .classes
140            .get_mut(&self.find_id(i))
141            .unwrap()
142            .analysis_data
143    }
144
145    pub fn enodes(&self, i: Id) -> HashSet<L> {
146        assert!(self.is_alive(i), "Can't access e-nodes of dead class");
148
149        self.classes[&i]
150            .nodes
151            .iter()
152            .map(|(x, psn)| x.apply_slotmap(&psn.elem))
153            .collect()
154    }
155
156    pub fn enodes_applied(&self, i: &AppliedId) -> Vec<L> {
158        let class = &self.classes[&i.id];
159        let class_slots = &class.slots;
160
161        let mut result = Vec::with_capacity(class.nodes.len());
162
163        for (x, psn) in &class.nodes {
164            let mut x = x.apply_slotmap(&psn.elem);
165
166            let mut map: SmallHashMap<Slot, Slot> = SmallHashMap::default();
167            for slot in x.all_slot_occurrences_mut() {
168                if !class_slots.contains(&slot) {
169                    if let Some(v) = map.get(slot) {
170                        *slot = *v;
171                    } else {
172                        let v = Slot::fresh();
173                        map.insert(slot.clone(), v.clone());
174                        *slot = v;
175                    }
176                }
177            }
178
179            let mut m = SlotMap::new();
180            for slot in x.slots() {
181                if !i.m.contains_key(slot) {
182                    m.insert(slot, Slot::fresh());
183                }
184            }
185
186            for (x, y) in i.m.iter() {
187                m.insert(x, y);
188            }
189
190            x = x.apply_slotmap(&m);
191            result.push(x);
192        }
193
194        result
195    }
196
197    pub fn total_number_of_nodes(&self) -> usize {
199        self.hashcons.len()
200    }
201
202    pub fn eq(&self, a: &AppliedId, b: &AppliedId) -> bool {
204        let a = self.find_applied_id(a);
205        let b = self.find_applied_id(b);
206
207        if CHECKS {
208            self.check_sem_applied_id(&a);
209            self.check_sem_applied_id(&b);
210        }
211
212        if a.id != b.id {
213            return false;
214        }
215        if a.m.values() != b.m.values() {
216            return false;
217        }
218        let id = a.id;
219
220        let perm = a.m.compose(&b.m.inverse());
221        if CHECKS {
222            assert!(perm.is_perm());
223            assert_eq!(&perm.values(), &self.classes[&id].slots);
224        }
225
226        self.classes[&id].group.contains(&perm)
227    }
228
229    pub(crate) fn refresh_internals(&self, l: &L) -> L {
231        let i = self.lookup(l).unwrap();
232        l.refresh_internals(i.slots())
233    }
234
235    pub(crate) fn class_nf(&self, l: &L) -> L {
237        let l = self.refresh_internals(l);
238        let i = self.lookup(&l).unwrap();
239
240        let l = l.apply_slotmap_fresh(&i.m);
242
243        if CHECKS {
244            let identity = self.mk_sem_identity_applied_id(i.id);
245            assert!(self.eq(&i, &identity));
246        }
247
248        l
249    }
250
251    pub fn dump(&self) {
253        println!("");
254        let mut v: Vec<(&Id, &EClass<L, N>)> = self.classes.iter().collect();
255        v.sort_by_key(|(x, _)| *x);
256
257        for (i, c) in v {
258            if c.nodes.len() == 0 {
259                continue;
260            }
261
262            let mut slot_order: Vec<Slot> = c.slots.iter().cloned().collect();
263            slot_order.sort();
264            let slot_str = slot_order
265                .iter()
266                .map(|x| x.to_string())
267                .collect::<Vec<_>>()
268                .join(", ");
269            println!("\n{:?}({}):", i, &slot_str);
270
271            println!(">> {:?}", &c.syn_enode);
272
273            for (sh, psn) in &c.nodes {
274                let n = sh.apply_slotmap(&psn.elem);
275
276                #[cfg(feature = "explanations")]
277                println!(" - {n:?}    [originally {:?}]", psn.src_id);
278
279                #[cfg(not(feature = "explanations"))]
280                println!(" - {n:?}");
281            }
282            for pp in &c.group.generators() {
283                println!(" -- {:?}", pp.elem);
284            }
285        }
286        println!("");
287    }
288
289    pub(crate) fn usages(&self, i: Id) -> Vec<L> {
291        let mut out = Vec::new();
292        for x in &self.classes[&i].usages {
293            let j = self.lookup(x).unwrap().id;
294            let bij = &self.classes[&j].nodes[&x].elem;
295            let x = x.apply_slotmap(bij);
296            out.push(x);
297        }
298        out
299    }
300
301    pub(crate) fn shape(&self, e: &L) -> (L, Bijection) {
302        let (pnode, bij) = self.proven_shape(e);
303        (pnode.elem, bij)
304    }
305
306    pub(crate) fn proven_shape(&self, e: &L) -> (ProvenNode<L>, Bijection) {
307        self.proven_proven_shape(&self.refl_pn(e))
308    }
309
310    pub(crate) fn proven_proven_shape(&self, e: &ProvenNode<L>) -> (ProvenNode<L>, Bijection) {
311        self.proven_proven_pre_shape(&e).weak_shape()
312    }
313
314    pub(crate) fn proven_proven_pre_shape(&self, e: &ProvenNode<L>) -> ProvenNode<L> {
315        let e = self.proven_proven_find_enode(e);
316        self.proven_proven_get_group_compatible_variants(&e)
317            .into_iter()
318            .min_by_key(|pn| pn.weak_shape().0.elem.all_slot_occurrences())
319            .unwrap()
320    }
321
322    pub(crate) fn proven_proven_get_group_compatible_variants(
340        &self,
341        enode: &ProvenNode<L>,
342    ) -> Vec<ProvenNode<L>> {
343        if CHECKS {
345            for x in enode.elem.applied_id_occurrences() {
346                assert!(self.is_alive(x.id));
347            }
348        }
349
350        let mut out = Vec::new();
351
352        if enode
354            .elem
355            .ids()
356            .iter()
357            .all(|i| self.classes[i].group.is_trivial())
358        {
359            out.push(enode.clone());
360            return out;
361        }
362
363        let groups: Vec<Vec<ProvenPerm>> = enode
364            .elem
365            .applied_id_occurrences()
366            .iter()
367            .map(|x| self.classes[&x.id].group.all_perms().into_iter().collect())
368            .collect();
369
370        for l in cartesian(&groups) {
371            let pn = enode.clone();
372            let pn = self.chain_pn_map(&pn, |i, pai| self.chain_pai_pp(&pai, l[i]));
373            out.push(pn);
376        }
377
378        out
379    }
380
381    pub(crate) fn proven_get_group_compatible_variants(&self, enode: &L) -> Vec<ProvenNode<L>> {
384        self.proven_proven_get_group_compatible_variants(&self.refl_pn(enode))
385    }
386
387    pub(crate) fn get_group_compatible_variants(&self, enode: &L) -> Vec<L> {
388        self.proven_get_group_compatible_variants(enode)
389            .into_iter()
390            .map(|pnode| pnode.elem)
391            .collect()
392    }
393
394    pub(crate) fn get_group_compatible_weak_variants(&self, enode: &L) -> Vec<L> {
395        let set = self.get_group_compatible_variants(enode);
396        let mut shapes = SmallHashSet::empty();
397        let mut out = Vec::new();
398
399        for x in set {
400            let (sh, _) = x.weak_shape();
401            if shapes.contains(&sh) {
402                continue;
403            }
404            shapes.insert(sh);
405            out.push(x);
406        }
407
408        out
409    }
410
411    pub(crate) fn synify_app_id(&self, app: AppliedId) -> AppliedId {
412        let mut app = app;
413        for s in self.syn_slots(app.id) {
414            if !app.m.contains_key(s) {
415                app.m.insert(s, Slot::fresh());
416            }
417        }
418        app
419    }
420
421    pub(crate) fn synify_enode(&self, enode: L) -> L {
422        enode.map_applied_ids(|app| self.synify_app_id(app))
423    }
424
425    pub(crate) fn semify_app_id(&self, app: AppliedId) -> AppliedId {
426        let slots = self.slots(app.id);
427
428        let mut app = app;
429        for k in app.m.keys() {
430            if !slots.contains(&k) {
431                app.m.remove(k);
432            }
433        }
434        app
435    }
436
437    #[cfg(feature = "explanations")]
438    pub(crate) fn semify_enode(&self, enode: L) -> L {
439        enode.map_applied_ids(|app| self.semify_app_id(app))
440    }
441
442    pub fn get_syn_expr(&self, i: &AppliedId) -> RecExpr<L> {
446        let enode = self.get_syn_node(i);
447        let cs = enode
448            .applied_id_occurrences()
449            .iter()
450            .map(|x| self.get_syn_expr(x))
451            .collect();
452        RecExpr {
453            node: nullify_app_ids(&enode),
454            children: cs,
455        }
456    }
457
458    pub fn get_syn_node(&self, i: &AppliedId) -> L {
460        let syn = &self.classes[&i.id].syn_enode;
461        syn.apply_slotmap(&i.m)
462    }
463}
464
465impl PendingType {
466    pub(crate) fn merge(self, other: PendingType) -> PendingType {
467        match (self, other) {
468            (PendingType::Full, _) => PendingType::Full,
469            (_, PendingType::Full) => PendingType::Full,
470            (PendingType::OnlyAnalysis, PendingType::OnlyAnalysis) => PendingType::OnlyAnalysis,
471        }
472    }
473}
474
/// Lazily enumerates the Cartesian product of the lists in `input`.
///
/// Each yielded item holds one reference per factor: element `i` comes from
/// `input[i]`. The first factor's index varies fastest.
///
/// Edge cases:
/// * with no factors, exactly one empty tuple is yielded;
/// * if any factor is empty, the product is empty and nothing is yielded.
fn cartesian<'a, T>(input: &'a [Vec<T>]) -> impl Iterator<Item = Vec<&'a T>> + use<'a, T> {
    let mut indices = vec![0; input.len()];
    // An empty factor makes the whole product empty. Starting out `done`
    // avoids indexing element 0 of that empty factor below.
    let mut done = input.iter().any(|factor| factor.is_empty());
    std::iter::from_fn(move || {
        if done {
            return None;
        }
        let out: Vec<&T> = input
            .iter()
            .zip(&indices)
            .map(|(factor, &j)| &factor[j])
            .collect();
        // Odometer step: bump the first index and carry into the next factor
        // whenever one wraps around.
        for (factor, idx) in input.iter().zip(indices.iter_mut()) {
            *idx += 1;
            if *idx < factor.len() {
                return Some(out);
            }
            *idx = 0;
        }
        // Every index wrapped, so `out` is the last tuple.
        done = true;
        Some(out)
    })
}
498
/// `cartesian` must yield every combination exactly once.
///
/// Tuples are sorted before comparison so the test does not depend on
/// enumeration order.
#[test]
fn cartesian1() {
    let v = [vec![1, 2], vec![3], vec![4, 5]];
    let mut tuples: Vec<Vec<i32>> = cartesian(&v)
        .map(|t| t.into_iter().copied().collect())
        .collect();
    assert_eq!(tuples.len(), 4);
    tuples.sort();
    assert_eq!(
        tuples,
        vec![vec![1, 3, 4], vec![1, 3, 5], vec![2, 3, 4], vec![2, 3, 5]]
    );
}