// smt_scope/parsers/z3/egraph.rs

1use std::{collections::hash_map::Entry, num::NonZeroUsize};
2
3#[cfg(feature = "mem_dbg")]
4use mem_dbg::{MemDbg, MemSize};
5use petgraph::{
6    graph::{DiGraph, EdgeIndex, NodeIndex},
7    visit::EdgeRef,
8};
9
10use crate::{
11    items::{
12        ENode, ENodeBlame, ENodeIdx, EqGivenIdx, EqGivenUse, EqTransIdx, Equality, EqualityExpl,
13        InstIdx, ProofIdx, TermIdx, TransitiveExpl, TransitiveExplSegment,
14        TransitiveExplSegmentKind,
15    },
16    BoxSlice, Error, FxHashMap, FxHashSet, NonMaxU32, Result, TiVec,
17};
18
19use super::{bugs::TransEqAllowed, stack::Stack, terms::Terms};
20
21#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
22#[derive(Debug, Default)]
23pub struct EGraph {
24    term_to_enode: FxHashMap<TermIdx, TermToEnode>,
25    pub(crate) enodes: TiVec<ENodeIdx, ENode>,
26    /// Which terms have we seen in a proof and might therefore get an enode
27    /// from them?
28    proofs: FxHashMap<TermIdx, ProofIdx>,
29    pub equalities: Equalities,
30}
31
32impl Equalities {
33    pub fn from_to(&self, eq: EqTransIdx) -> (ENodeIdx, ENodeIdx) {
34        let from = self.walk_trans(eq, |eq, fwd| {
35            let from = if fwd { eq.from() } else { eq.to() };
36            Err(from)
37        });
38        let equality = &self.transitive[eq];
39        let from = from.err().unwrap_or_else(|| equality.error_from().unwrap());
40        let to = equality.to;
41        (from, to)
42    }
43
44    /// Walk the given equalities of a transitive equality, returning early with
45    /// the error if the closure returns one.
46    pub fn walk_trans<E>(
47        &self,
48        eq: EqTransIdx,
49        f: impl FnMut(&EqualityExpl, bool) -> core::result::Result<(), E>,
50    ) -> core::result::Result<(), E> {
51        let mut walker = TransEqStopWalker {
52            equalities: self,
53            simple: f,
54        };
55        walker.walk_trans(eq, true)
56    }
57}
58
59impl EGraph {
60    pub fn get_blame(
61        &self,
62        tidx: TermIdx,
63        inst: Option<(InstIdx, FxHashSet<EqTransIdx>)>,
64        terms: &Terms,
65        stack: &Stack,
66    ) -> ENodeBlame {
67        if terms.is_bool_const(tidx) {
68            return ENodeBlame::BoolConst;
69        }
70        if let Some(inst) = inst {
71            return ENodeBlame::Inst(inst);
72        }
73        let proof = self.proofs.get(&tidx).copied();
74        let proof = proof.filter(|&p| stack.is_alive(terms[p].frame));
75        if let Some(proof) = proof {
76            return ENodeBlame::Proof(proof);
77        }
78        ENodeBlame::Unknown
79    }
80
81    pub fn new_enode(
82        &mut self,
83        blame: ENodeBlame,
84        term: TermIdx,
85        z3_generation: NonMaxU32,
86        stack: &Stack,
87    ) -> Result<ENodeIdx> {
88        // TODO: why does this happen sometimes?
89        // if created_by.is_none() && z3_generation.is_some() {
90        //     debug_assert_eq!(
91        //         z3_generation.unwrap(),
92        //         0,
93        //         "enode with no creator has non-zero generation"
94        //     );
95        // }
96        self.enodes.raw.try_reserve(1)?;
97        let enode = self.enodes.push_and_get_key(ENode {
98            frame: stack.active_frame(),
99            blame,
100            owner: term,
101            z3_generation,
102            equalities: Vec::new(),
103            transitive: FxHashMap::default(),
104        });
105        self.insert_tte(term, enode, stack)?;
106        Ok(enode)
107    }
108
109    pub fn get_enode(&mut self, term: TermIdx, stack: &Stack) -> Result<ENodeIdx> {
110        self.get_tte(term, stack).ok_or(Error::UnknownEnode(term))
111    }
112
113    pub fn get_enode_imm(&self, term: TermIdx, stack: &Stack) -> Option<ENodeIdx> {
114        self.get_tte_imm(term, stack)
115    }
116
117    pub(super) fn new_proof(
118        &mut self,
119        term: TermIdx,
120        proof: ProofIdx,
121        terms: &Terms,
122        stack: &Stack,
123    ) -> Result<()> {
124        if !terms[proof].kind.is_asserted() {
125            return Ok(());
126        }
127        terms.app_walk(term, |tidx, app| {
128            self.proofs.try_reserve(1)?;
129            match self.proofs.entry(tidx) {
130                Entry::Occupied(mut o) => {
131                    if stack.is_alive(terms[*o.get()].frame) {
132                        return Ok(&[]);
133                    } else {
134                        o.insert(proof);
135                    }
136                }
137                Entry::Vacant(v) => {
138                    v.insert(proof);
139                }
140            }
141            Ok(&app.child_ids)
142        })
143    }
144
145    pub fn get_given(&self, from: ENodeIdx, to: ENodeIdx) -> Option<EqGivenIdx> {
146        self.enodes[from]
147            .equalities
148            .iter()
149            .find(|eq| eq.to == to)
150            .map(|eq| eq.expl)
151    }
152
153    pub fn new_given_equality(
154        &mut self,
155        from: ENodeIdx,
156        expl: EqualityExpl,
157        stack: &Stack,
158    ) -> Result<()> {
159        debug_assert_eq!(from, expl.from());
160        let to = expl.to();
161        self.equalities.given.raw.try_reserve(1)?;
162        let expl = self.equalities.given.push_and_get_key(expl);
163        let enode = &mut self.enodes[from];
164        let eq = Equality {
165            _frame: stack.active_frame(),
166            to,
167            expl,
168        };
169        enode.equalities.try_reserve(1)?;
170        enode.equalities.push(eq);
171        // TODO: is ok to simply ignore the old equality, or should we also blame it later on?
172        // let (new, others) = enode.equalities.split_last().unwrap();
173        // if let Some(old) = others.last() {
174        //     let inactive = old.frame.map(|f| !stack.stack_frames[f].active).unwrap_or_default();
175        //     if inactive {
176        //         return;
177        //     }
178        //     let is_root = matches!(old.expl, EqualityExpl::Root { .. });
179        //     let root_unchanged = is_root || (self.path_to_root(old.to).last().unwrap() == self.path_to_root(new.to).last().unwrap());
180        //     if root_unchanged {
181        //         return;
182        //     }
183
184        //     let is_unknown = matches!(enode.get_equality(stack).unwrap().expl, EqualityExpl::Unknown { .. });
185        //     let equivalent = old.expl == enode.get_equality(stack).unwrap().expl;
186        //     // let test = old.frame.is_some_and(|f| usize::from(f) == 854);
187        //     if !equivalent && !is_unknown {
188        //         panic!();
189        //     }
190        // }
191        Ok(())
192    }
193
194    pub fn new_trans_equality(
195        &mut self,
196        from: ENodeIdx,
197        to: ENodeIdx,
198        stack: &Stack,
199        mismatch: TransEqAllowed,
200    ) -> Result<core::result::Result<EqTransIdx, ENodeIdx>> {
201        if from == to {
202            Ok(Err(from))
203        } else {
204            self.construct_trans_equality(from, to, stack, mismatch)
205                .map(Ok)
206        }
207    }
208
209    fn to_root<'a>(
210        &'a self,
211        from: ENodeIdx,
212        stack: &'a Stack,
213    ) -> impl Iterator<Item = ENodeIdx> + 'a {
214        core::iter::successors(Some(from), move |&from| {
215            self.enodes[from].get_equality(stack).map(|eq| eq.to)
216        })
217    }
218
219    fn to_root_visited<'a>(
220        &'a self,
221        from: ENodeIdx,
222        stack: &'a Stack,
223        cycle: &'a mut Option<ENodeIdx>,
224        visited: &'a mut FxHashSet<ENodeIdx>,
225    ) -> impl Iterator<Item = ENodeIdx> + 'a {
226        self.to_root(from, stack).take_while(move |&from| {
227            if visited.insert(from) {
228                true
229            } else {
230                // On rare occasions there is a cycle of more than one enode in
231                // the equality graph (EXPLAIN).
232                *cycle = Some(from);
233                false
234            }
235        })
236    }
237
238    pub fn check_eq(&self, from: ENodeIdx, to: ENodeIdx, stack: &Stack) -> bool {
239        let mut visited_from = FxHashSet::default();
240        self.to_root_visited(from, stack, &mut None, &mut visited_from)
241            .for_each(drop);
242        for to in self.to_root_visited(to, stack, &mut None, &mut FxHashSet::default()) {
243            if visited_from.contains(&to) {
244                return true;
245            }
246        }
247        false
248    }
249
250    pub fn path_to_root(
251        &self,
252        from: ENodeIdx,
253        root: Option<ENodeIdx>,
254        stack: &Stack,
255    ) -> Result<(Option<ENodeIdx>, Vec<ENodeIdx>)> {
256        let mut cycle = None;
257        let mut visited = FxHashSet::default();
258        let mut path = Vec::new();
259        for e in self.to_root_visited(from, stack, &mut cycle, &mut visited) {
260            path.try_reserve(1)?;
261            path.push(e);
262            if root.is_some_and(|root| root == e) {
263                return Ok((Some(e), path));
264            }
265        }
266        Ok((cycle, path))
267    }
268
269    fn get_simple_path(
270        &self,
271        from: ENodeIdx,
272        to: ENodeIdx,
273        stack: &Stack,
274        can_mismatch: bool,
275    ) -> Result<Option<SimplePath>> {
276        let (_, f_path) = self.path_to_root(from, None, stack)?;
277        let f_root = f_path.len() - 1;
278        let (_, t_path) = self.path_to_root(to, Some(*f_path.last().unwrap()), stack)?;
279        let t_root = t_path.len() - 1;
280
281        if f_path[f_root] != t_path[t_root] {
282            // Root may not always be the same from v4.12.3 onwards if `to` is an `ite` expression. See:
283            // https://github.com/Z3Prover/z3/commit/faf14012ba18d21c1fcddbdc321ac127f019fa03#diff-0a9ec50ded668e51578edc67ecfe32380336b9cbf12c5d297e2d3759a7a39847R2417-R2419
284            if can_mismatch {
285                // Return no path if the roots are different.
286                return Ok(None);
287            } else {
288                return Err(Error::EnodeRootMismatch(from, to));
289            }
290        }
291        let mut shared = 1;
292        while shared < f_path.len()
293            && shared < t_path.len()
294            && f_path[f_root - shared] == t_path[t_root - shared]
295        {
296            shared += 1;
297        }
298        let path = SimplePath {
299            from_to_root: f_path,
300            to_to_root: t_path,
301            shared,
302        };
303        Ok(Some(path))
304    }
305
306    fn construct_trans_equality(
307        &mut self,
308        from: ENodeIdx,
309        to: ENodeIdx,
310        stack: &Stack,
311        mismatch: TransEqAllowed,
312    ) -> Result<EqTransIdx> {
313        debug_assert_ne!(from, to);
314        let can_mismatch = mismatch.can_mismatch_initial;
315        let Some(simple_path) = self.get_simple_path(from, to, stack, can_mismatch)? else {
316            // There was a root mismatch (and `can_mismatch` was true), so we
317            // can't construct a simple path.
318            let error = TransitiveExpl::error(from, to);
319            self.enodes[from].transitive.try_reserve(1)?;
320            let trans = match self.enodes[from].transitive.entry(to) {
321                Entry::Occupied(mut o) => {
322                    let trans = *o.get();
323                    if self.equalities.transitive[trans].given_len.is_some() {
324                        // These two nodes are no longer equal (this is an old
325                        // transitive equality that is no longer valid).
326                        self.equalities.transitive.raw.try_reserve(1)?;
327                        let trans = self.equalities.transitive.push_and_get_key(error);
328                        o.insert(trans);
329                        trans
330                    } else {
331                        trans
332                    }
333                }
334                Entry::Vacant(v) => {
335                    self.equalities.transitive.raw.try_reserve(1)?;
336                    let trans = self.equalities.transitive.push_and_get_key(error);
337                    *v.insert(trans)
338                }
339            };
340            return Ok(trans);
341        };
342        // Should not fail since `from != to`
343        let edges_len = simple_path.edges_len();
344        let mut graph = simple_path.initialise_graph(self, stack);
345
346        // Add transitive edges to graph, start by trying to find either forward
347        // or backward full solutions.
348        if let Some((forward, solution)) = graph.add_trans_from(0, self) {
349            debug_assert!(forward);
350            return Ok(solution);
351        }
352        let trans = if let Some((forward, solution)) = graph.add_trans_from(edges_len.get(), self) {
353            debug_assert!(!forward);
354            let inner = &self.equalities.transitive[solution];
355            if inner.path.len() == 1 {
356                use TransitiveExplSegmentKind::*;
357                match inner.path[0] {
358                    TransitiveExplSegment {
359                        forward: false,
360                        kind: Transitive(idx),
361                    } => return Ok(idx),
362                    TransitiveExplSegment {
363                        forward: true,
364                        kind: Transitive(_),
365                    } => unreachable!(),
366                    TransitiveExplSegment {
367                        kind: Given(..), ..
368                    } => (),
369                    TransitiveExplSegment {
370                        kind: Error(..), ..
371                    } => unreachable!(),
372                }
373            }
374            let solution = TransitiveExplSegment {
375                forward,
376                kind: TransitiveExplSegmentKind::Transitive(solution),
377            };
378            TransitiveExpl::new([solution].into_iter(), NonZeroUsize::new(1).unwrap(), to)?
379        } else {
380            for idx in 1..edges_len.get() {
381                graph.add_trans_from(idx, self);
382            }
383            // Find the shortest path
384            for idx in (0..edges_len.get()).rev() {
385                let idx = NodeIndex::new(idx);
386                // Use `.rev()` here to prefer transitive edges over leaf edges,
387                // though hopefully the `min_by_key` should be unique.
388                let min = graph
389                    .graph
390                    .edges(idx)
391                    .min_by_key(|edge| graph.graph[edge.target()].0)
392                    .unwrap();
393                let (cost, id) = (graph.graph[min.target()].0, min.id());
394                let idx = &mut graph.graph[idx];
395                idx.0 = cost + 1;
396                idx.1 = Some(id);
397            }
398
399            let start = NodeIndex::new(0);
400            let mut edge = graph.graph[start].1;
401            let path_length = graph.graph[start].0;
402            TransitiveExpl::new(
403                (0..path_length)
404                    .map(|_| {
405                        let kind = &graph.graph[edge.unwrap()];
406                        edge = graph.graph[graph.graph.edge_endpoints(edge.unwrap()).unwrap().1].1;
407                        kind
408                    })
409                    .copied(),
410                edges_len,
411                to,
412            )?
413        };
414        let trans = self.insert_trans_equality(trans, stack, mismatch.can_mismatch_congr)?;
415        debug_assert_eq!(self.equalities.walk_to(from, trans), to);
416        self.enodes[from].transitive.try_reserve(1)?;
417        let old = self.enodes[from].transitive.insert(to, trans);
418        debug_assert_eq!(old, None);
419        Ok(trans)
420    }
421
422    fn insert_trans_equality(
423        &mut self,
424        mut trans: TransitiveExpl,
425        stack: &Stack,
426        can_mismatch_congr: bool,
427    ) -> Result<EqTransIdx> {
428        let mismatch = TransEqAllowed {
429            can_mismatch_initial: can_mismatch_congr,
430            can_mismatch_congr: false,
431        };
432        // Find the current congruence uses
433        for seg in trans.path.iter_mut() {
434            if let TransitiveExplSegmentKind::Given((cg, idx)) = &mut seg.kind {
435                debug_assert_eq!(*idx, None);
436                let EqualityExpl::Congruence { arg_eqs, .. } = &self.equalities.given[*cg] else {
437                    continue;
438                };
439                let mut args = Vec::new();
440                args.try_reserve_exact(arg_eqs.len())?;
441                args.extend(arg_eqs.iter().copied());
442
443                let mut use_ = Vec::new();
444                use_.try_reserve_exact(arg_eqs.len())?;
445                for (from, to) in args {
446                    let Ok(trans) = self.new_trans_equality(from, to, stack, mismatch)? else {
447                        continue;
448                    };
449                    use_.push(trans);
450                }
451                let EqualityExpl::Congruence { uses, .. } = &mut self.equalities.given[*cg] else {
452                    unreachable!()
453                };
454                let real_idx = uses.iter().position(|u| **u == use_).unwrap_or_else(|| {
455                    uses.push(BoxSlice::from(use_));
456                    uses.len() - 1
457                });
458                *idx = Some(NonMaxU32::new(real_idx as u32).unwrap());
459            }
460        }
461
462        self.equalities.transitive.raw.try_reserve(1)?;
463        let trans = self.equalities.transitive.push_and_get_key(trans);
464        Ok(trans)
465    }
466}
467
468impl std::ops::Index<ENodeIdx> for EGraph {
469    type Output = ENode;
470    fn index(&self, idx: ENodeIdx) -> &Self::Output {
471        &self.enodes[idx]
472    }
473}
474
475impl ENode {
476    pub fn get_equality(&self, _stack: &Stack) -> Option<&Equality> {
477        // TODO: why are we allowed to use equalities from popped stack frames?
478        // self.equalities.iter().rev().find(|eq| stack.is_alive(eq._frame))
479        self.equalities.last()
480    }
481}
482
483#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
484#[derive(Debug, Default)]
485pub struct Equalities {
486    pub(crate) given: TiVec<EqGivenIdx, EqualityExpl>,
487    pub(crate) transitive: TiVec<EqTransIdx, TransitiveExpl>,
488}
489
490pub trait EqualityWalker<'a> {
491    type Error;
492
493    fn equalities(&self) -> &'a Equalities;
494
495    /// Override this method to get all the given equalities from a transitive
496    /// equality. If you wish to recurse into congruences then use the
497    /// following structure:
498    ///
499    /// ```ignore
500    /// let (eq, use_) = eq_use;
501    /// match &self.equalities().given[eq] {
502    ///     EqualityExpl::Congruence { uses, .. } => self.walk_congruence(uses, use_.unwrap(), forward),
503    ///     _ => Ok(()),
504    /// }
505    /// ```
506    fn walk_given(
507        &mut self,
508        eq_use: EqGivenUse,
509        forward: bool,
510    ) -> core::result::Result<(), Self::Error>;
511
512    /// Does nothing if `CONGR` is false. Otherwise recursively walks into the
513    /// congruence uses.
514    fn walk_congruence(
515        &mut self,
516        uses: &[BoxSlice<EqTransIdx>],
517        use_: NonMaxU32,
518        forward: bool,
519    ) -> core::result::Result<(), Self::Error> {
520        let use_ = &uses[use_.get() as usize];
521        for &eq in use_.iter() {
522            self.walk_trans(eq, forward)?;
523        }
524        Ok(())
525    }
526
527    /// Return `Err` if the walk should abort, `Ok` to stop recursing, and
528    /// otherwise call `self.super_walk_trans` to recurse.
529    fn walk_trans(
530        &mut self,
531        eq: EqTransIdx,
532        forward: bool,
533    ) -> core::result::Result<(), Self::Error> {
534        self.super_walk_trans(eq, forward)
535    }
536
537    /// This method defines the walking of transitive equalities, override
538    /// `walk_trans` to intercept before walking a transitive equality.
539    fn super_walk_trans(
540        &mut self,
541        eq: EqTransIdx,
542        forward: bool,
543    ) -> core::result::Result<(), Self::Error> {
544        let all = self.equalities().transitive[eq].all(forward);
545        for next in all {
546            use TransitiveExplSegmentKind::*;
547            match next.kind {
548                Given(eq_use) => self.walk_given(eq_use, next.forward)?,
549                Transitive(eq) => self.walk_trans(eq, next.forward)?,
550                Error(..) => (),
551            }
552        }
553        Ok(())
554    }
555}
556
557struct TransEqChecker<'a, I: Iterator<Item = (EqGivenUse, bool)>> {
558    equalities: &'a Equalities,
559    simple: I,
560}
561
562impl<'a, I: Iterator<Item = (EqGivenUse, bool)>> EqualityWalker<'a> for TransEqChecker<'a, I> {
563    type Error = bool;
564    fn equalities(&self) -> &'a Equalities {
565        self.equalities
566    }
567    fn walk_given(
568        &mut self,
569        (eq, _use_): EqGivenUse,
570        eq_fwd: bool,
571    ) -> core::result::Result<(), Self::Error> {
572        // Return `Err(false)` if we're out of simple, this should never
573        // happen.
574        let ((simple, _simple_use), fwd) = self.simple.next().ok_or(false)?;
575        // Return `Ok` if equal, else `Err(true)`.
576        // We do not compare `simple_use == use_` here because `simple_use`
577        // hasn't been set yet (I think?).
578        (simple == eq && fwd == eq_fwd).then_some(()).ok_or(true)
579    }
580}
581
582impl Equalities {
583    pub fn is_equal(
584        &self,
585        eq: EqTransIdx,
586        simple: &mut impl Iterator<Item = (EqGivenUse, bool)>,
587    ) -> Option<bool> {
588        let mut checker = TransEqChecker {
589            equalities: self,
590            simple,
591        };
592        let res = checker.walk_trans(eq, true);
593        // Map `Ok` -> `Ok(true)`, `Err(true)` -> `Ok(false)`, `Err(false)` -> `Err`.
594        res.map_or_else(|e| e.then_some(false), |_| Some(true))
595    }
596}
597
598pub type TransEqSimpleWalker<'a, F> = TransEqStopWalker<'a, super::Never, F>;
599
600pub struct TransEqStopWalker<'a, E, F: FnMut(&'a EqualityExpl, bool) -> core::result::Result<(), E>>
601{
602    equalities: &'a Equalities,
603    simple: F,
604}
605
606impl<'a, E, F: FnMut(&'a EqualityExpl, bool) -> core::result::Result<(), E>> EqualityWalker<'a>
607    for TransEqStopWalker<'a, E, F>
608{
609    type Error = E;
610    fn equalities(&self) -> &'a Equalities {
611        self.equalities
612    }
613    fn walk_given(
614        &mut self,
615        (eq, _): EqGivenUse,
616        forward: bool,
617    ) -> core::result::Result<(), Self::Error> {
618        (self.simple)(&self.equalities.given[eq], forward)
619    }
620}
621
622impl Equalities {
623    pub fn walk_to(&self, mut from: ENodeIdx, eq: EqTransIdx) -> ENodeIdx {
624        let mut walker = TransEqSimpleWalker {
625            equalities: self,
626            simple: |eq, fwd| {
627                from = eq.walk(from, fwd).unwrap();
628                Ok(())
629            },
630        };
631        walker.walk_trans(eq, true).unwrap();
632        from
633    }
634    pub fn path(&self, eq: EqTransIdx) -> Vec<ENodeIdx> {
635        let equality = &self.transitive[eq];
636        if let Some(from) = equality.error_from() {
637            return vec![from, equality.to];
638        }
639
640        let mut path = Vec::new();
641        let mut walker = TransEqSimpleWalker {
642            equalities: self,
643            simple: |eq, fwd| {
644                let from = if fwd { eq.from() } else { eq.to() };
645                path.push(from);
646                Ok(())
647            },
648        };
649        walker.walk_trans(eq, true).unwrap();
650        path.push(self.transitive[eq].to);
651        path
652    }
653}
654
655#[derive(Debug)]
656pub struct SimplePath {
657    from_to_root: Vec<ENodeIdx>,
658    to_to_root: Vec<ENodeIdx>,
659    shared: usize,
660}
661impl SimplePath {
662    pub fn edges_len(&self) -> NonZeroUsize {
663        NonZeroUsize::new(self.from_to_root.len() + self.to_to_root.len() - 2 * self.shared)
664            .unwrap()
665    }
666    pub fn all_simple_edges<'a>(
667        &'a self,
668        egraph: &'a EGraph,
669        stack: &'a Stack,
670    ) -> impl DoubleEndedIterator<Item = (ENodeIdx, EqGivenIdx, bool)> + 'a {
671        let from_to_join = self.from_to_root[..self.from_to_root.len() - self.shared]
672            .iter()
673            .copied();
674        let from_to_join = from_to_join.map(|e| {
675            let eq = &egraph.enodes[e].get_equality(stack).unwrap();
676            (eq.to, eq.expl, true)
677        });
678        let join_to_to = self.to_to_root[..self.to_to_root.len() - self.shared]
679            .iter()
680            .rev()
681            .copied();
682        let join_to_to =
683            join_to_to.map(|e| (e, egraph.enodes[e].get_equality(stack).unwrap().expl, false));
684        from_to_join.chain(join_to_to)
685    }
686
687    pub fn nodes_len(&self) -> usize {
688        self.edges_len().get() + 1
689    }
690    pub fn all_nodes(&self) -> impl DoubleEndedIterator<Item = ENodeIdx> + '_ {
691        let from_to_join = self.from_to_root[..self.from_to_root.len() + 1 - self.shared].iter();
692        let join_to_to = self.to_to_root[..self.to_to_root.len() - self.shared]
693            .iter()
694            .rev();
695        from_to_join.chain(join_to_to).copied()
696    }
697    pub fn node_at(&self, idx: usize) -> ENodeIdx {
698        let from_len = self.from_to_root.len() - self.shared;
699        if idx <= from_len {
700            self.from_to_root[idx]
701        } else {
702            let to_len = self.to_to_root.len() - self.shared;
703            self.to_to_root[(to_len + from_len) - idx]
704        }
705    }
706    pub fn all_transitive<'a>(
707        &'a self,
708        egraph: &'a EGraph,
709    ) -> impl DoubleEndedIterator<Item = impl Iterator<Item = EqTransIdx> + 'a> + 'a {
710        self.all_nodes()
711            .map(move |idx| egraph.enodes[idx].transitive.values().copied())
712    }
713
714    pub fn initialise_graph<'a>(self, egraph: &'a EGraph, stack: &'a Stack) -> Graph {
715        let edges_len = self.edges_len();
716        let mut g = Graph {
717            graph: DiGraph::with_capacity(self.nodes_len(), edges_len.get()),
718            path_enodes: self.all_nodes().collect(),
719            edges_len,
720            simple_path: self,
721        };
722        let mut last = g.graph.add_node((edges_len.get() as u32, None));
723        for (idx, (_, eq, forward)) in g.simple_path.all_simple_edges(egraph, stack).enumerate() {
724            let cost = (edges_len.get() - idx - 1) as u32;
725            let next = g.graph.add_node((cost, None));
726            let kind = TransitiveExplSegmentKind::Given((eq, None));
727            g.graph
728                .add_edge(last, next, TransitiveExplSegment { forward, kind });
729            last = next;
730        }
731        g
732    }
733}
734
735pub struct Graph {
736    simple_path: SimplePath,
737    path_enodes: FxHashSet<ENodeIdx>,
738    graph: DiGraph<(u32, Option<EdgeIndex>), TransitiveExplSegment>,
739    edges_len: NonZeroUsize,
740}
741impl Graph {
742    pub fn add_trans_from(
743        &mut self,
744        idx: usize,
745        egraph: &mut EGraph,
746    ) -> Option<(bool, EqTransIdx)> {
747        let nfrom = NodeIndex::new(idx);
748        let efrom = self.simple_path.node_at(idx);
749
750        for to in 0..self.simple_path.nodes_len() {
751            let to = self.simple_path.node_at(to);
752            if let Entry::Occupied(o) = egraph.enodes[efrom].transitive.entry(to) {
753                let trans = *o.get();
754                let trans_node = &egraph.equalities.transitive[trans];
755                let Some((nto, forward)) =
756                    self.add_trans_single(idx, trans, trans_node, &egraph.equalities)
757                else {
758                    o.remove();
759                    continue;
760                };
761                let segment = TransitiveExplSegment {
762                    forward,
763                    kind: TransitiveExplSegmentKind::Transitive(trans),
764                };
765                let (from, to) = if forward { (nfrom, nto) } else { (nto, nfrom) };
766                debug_assert!(trans_node.given_len.is_none() || from.index() < to.index());
767                self.graph.add_edge(from, to, segment);
768                if trans_node
769                    .given_len
770                    .is_some_and(|len| len == self.edges_len)
771                {
772                    return Some((forward, trans));
773                }
774            }
775        }
776        None
777    }
778    fn add_trans_single(
779        &self,
780        idx: usize,
781        trans: EqTransIdx,
782        trans_node: &TransitiveExpl,
783        equalities: &Equalities,
784    ) -> Option<(NodeIndex, bool)> {
785        let given_len = trans_node.given_len?.get();
786        let edges_len = self.edges_len.get();
787        debug_assert!(self.path_enodes.contains(&trans_node.to));
788
789        let left = given_len <= idx && trans_node.to == self.simple_path.node_at(idx - given_len);
790        if left {
791            let prior_simple_edges = (0..idx).map(|idx| self.graph[EdgeIndex::new(idx)]);
792            let mut prior_simple_edges = TransitiveExplSegment::rev(prior_simple_edges)
793                .map(|seg| (seg.kind.given().unwrap(), seg.forward));
794            match equalities.is_equal(trans, &mut prior_simple_edges) {
795                None => {
796                    debug_assert!(false);
797                    None
798                }
799                Some(false) => None,
800                Some(true) => {
801                    let to = NodeIndex::new(idx - given_len);
802                    Some((to, false))
803                }
804            }
805        } else if given_len <= edges_len - idx
806            && trans_node.to == self.simple_path.node_at(idx + given_len)
807        {
808            let post_simple_edges = (idx..edges_len).map(|idx| self.graph[EdgeIndex::new(idx)]);
809            let mut post_simple_edges =
810                post_simple_edges.map(|seg| (seg.kind.given().unwrap(), seg.forward));
811            match equalities.is_equal(trans, &mut post_simple_edges) {
812                None => {
813                    debug_assert!(false);
814                    None
815                }
816                Some(false) => None,
817                Some(true) => {
818                    let to = NodeIndex::new(idx + given_len);
819                    Some((to, true))
820                }
821            }
822        } else {
823            None
824        }
825    }
826}
827
828/// The complexity of this arises from the fact that z3 will sometimes create a
829/// new enode in a higher frame when one in a lower frame already exists. This
830/// new enode might then be popped, but z3 will still want to use the old enode.
831#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
832#[derive(Debug)]
833enum TermToEnode {
834    Single(ENodeIdx),
835    Multiple(Vec<ENodeIdx>),
836}
837
838impl Default for TermToEnode {
839    fn default() -> Self {
840        Self::Multiple(Vec::default())
841    }
842}
843
844impl EGraph {
845    pub fn insert_tte(&mut self, term: TermIdx, enode: ENodeIdx, stack: &Stack) -> Result<()> {
846        let remove = |e: &ENodeIdx| !stack.is_alive(self.enodes[*e].frame);
847        self.term_to_enode.try_reserve(1)?;
848        let tte = self.term_to_enode.entry(term).or_default();
849        let mut vec = match tte {
850            TermToEnode::Single(e) => [*e].into_iter().filter(|e| !remove(e)).collect(),
851            TermToEnode::Multiple(es) => {
852                let mut es = core::mem::take(es);
853                let idx = es.iter().position(remove).unwrap_or(es.len());
854                debug_assert!(es[idx..].iter().all(remove));
855                es.drain(idx..);
856                es
857            }
858        };
859        if vec.is_empty() {
860            *tte = TermToEnode::Single(enode);
861        } else {
862            vec.push(enode);
863            *tte = TermToEnode::Multiple(vec);
864        }
865        // TODO: why does this happen sometimes?
866        // debug_assert!(!old.is_some_and(|o| stack.is_alive(self[o].frame)));
867        Ok(())
868    }
869
870    fn get_tte(&mut self, term: TermIdx, stack: &Stack) -> Option<ENodeIdx> {
871        let Entry::Occupied(mut o) = self.term_to_enode.entry(term) else {
872            return None;
873        };
874        let remove = |e: &ENodeIdx| !stack.is_alive(self.enodes[*e].frame);
875        match o.get_mut() {
876            TermToEnode::Single(e) => {
877                if remove(&*e) {
878                    o.remove();
879                    None
880                } else {
881                    Some(*e)
882                }
883            }
884            TermToEnode::Multiple(es) => {
885                let idx = es.iter().position(remove).unwrap_or(es.len());
886                debug_assert!(es[idx..].iter().all(remove));
887                es.drain(idx..);
888                if let Some(last) = es.last().copied() {
889                    if es.len() == 1 {
890                        *o.get_mut() = TermToEnode::Single(last);
891                    }
892                    Some(last)
893                } else {
894                    o.remove();
895                    None
896                }
897            }
898        }
899    }
900
901    fn get_tte_imm(&self, term: TermIdx, stack: &Stack) -> Option<ENodeIdx> {
902        let enode = self.term_to_enode.get(&term)?;
903        let keep = |e: &ENodeIdx| stack.is_alive(self.enodes[*e].frame);
904        match enode {
905            TermToEnode::Single(e) if keep(e) => Some(*e),
906            TermToEnode::Single(_) => None,
907            TermToEnode::Multiple(es) => es.iter().copied().find(keep),
908        }
909    }
910}