smt_scope/parsers/z3/
egraph.rs

1use std::{collections::hash_map::Entry, num::NonZeroUsize};
2
3#[cfg(feature = "mem_dbg")]
4use mem_dbg::{MemDbg, MemSize};
5use petgraph::{
6    graph::{DiGraph, EdgeIndex, NodeIndex},
7    visit::EdgeRef,
8};
9
10use crate::{
11    items::{
12        ENode, ENodeBlame, ENodeIdx, EqGivenIdx, EqGivenUse, EqTransIdx, Equality, EqualityExpl,
13        InstIdx, ProofIdx, TermIdx, TransitiveExpl, TransitiveExplSegment,
14        TransitiveExplSegmentKind,
15    },
16    BoxSlice, Error, FxHashMap, FxHashSet, NonMaxU32, Result, TiVec,
17};
18
19use super::{bugs::TransEqAllowed, stack::Stack, terms::Terms};
20
21#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
22#[derive(Debug, Default)]
23pub struct EGraph {
24    term_to_enode: FxHashMap<TermIdx, TermToEnode>,
25    pub(crate) enodes: TiVec<ENodeIdx, ENode>,
26    /// Which terms have we seen in a proof and might therefore get an enode
27    /// from them?
28    proofs: FxHashMap<TermIdx, ProofIdx>,
29    pub equalities: Equalities,
30}
31
32impl Equalities {
33    pub fn from_to(&self, eq: EqTransIdx) -> (ENodeIdx, ENodeIdx) {
34        let from = self.walk_trans(eq, |eq, fwd| {
35            let from = if fwd { eq.from() } else { eq.to() };
36            Err(from)
37        });
38        let equality = &self.transitive[eq];
39        let from = from.err().unwrap_or_else(|| equality.error_from().unwrap());
40        let to = equality.to;
41        (from, to)
42    }
43
44    /// Walk the given equalities of a transitive equality, returning early with
45    /// the error if the closure returns one.
46    pub fn walk_trans<E>(
47        &self,
48        eq: EqTransIdx,
49        f: impl FnMut(&EqualityExpl, bool) -> core::result::Result<(), E>,
50    ) -> core::result::Result<(), E> {
51        let mut walker = TransEqStopWalker {
52            equalities: self,
53            simple: f,
54        };
55        walker.walk_trans(eq, true)
56    }
57}
58
59impl EGraph {
60    pub fn get_blame(
61        &self,
62        tidx: TermIdx,
63        inst: Option<(InstIdx, FxHashSet<EqTransIdx>)>,
64        terms: &Terms,
65        stack: &Stack,
66    ) -> ENodeBlame {
67        if terms.is_bool_const(tidx) {
68            return ENodeBlame::BoolConst;
69        }
70        if let Some(inst) = inst {
71            return ENodeBlame::Inst(inst);
72        }
73        let proof = self.proofs.get(&tidx).copied();
74        let proof = proof.filter(|&p| stack.is_alive(terms[p].frame));
75        if let Some(proof) = proof {
76            return ENodeBlame::Proof(proof);
77        }
78        ENodeBlame::Unknown
79    }
80
81    pub fn new_enode(
82        &mut self,
83        blame: ENodeBlame,
84        term: TermIdx,
85        z3_generation: NonMaxU32,
86        stack: &Stack,
87    ) -> Result<ENodeIdx> {
88        // TODO: why does this happen sometimes?
89        // if created_by.is_none() && z3_generation.is_some() {
90        //     debug_assert_eq!(
91        //         z3_generation.unwrap(),
92        //         0,
93        //         "enode with no creator has non-zero generation"
94        //     );
95        // }
96        self.enodes.raw.try_reserve(1)?;
97        let enode = self.enodes.push_and_get_key(ENode {
98            frame: stack.active_frame(),
99            blame,
100            owner: term,
101            z3_generation,
102            equalities: Vec::new(),
103            transitive: FxHashMap::default(),
104        });
105        self.insert_tte(term, enode, stack)?;
106        Ok(enode)
107    }
108
109    pub fn get_enode(&mut self, term: TermIdx, stack: &Stack) -> Result<ENodeIdx> {
110        self.get_tte(term, stack).ok_or(Error::UnknownEnode(term))
111    }
112
113    pub fn get_enode_imm(&self, term: TermIdx, stack: &Stack) -> Option<ENodeIdx> {
114        self.get_tte_imm(term, stack)
115    }
116
117    pub(super) fn new_proof(
118        &mut self,
119        term: TermIdx,
120        proof: ProofIdx,
121        terms: &Terms,
122        stack: &Stack,
123    ) -> Result<()> {
124        if !terms[proof].kind.is_asserted() {
125            return Ok(());
126        }
127        terms.app_walk(term, |tidx, app| {
128            self.proofs.try_reserve(1)?;
129            match self.proofs.entry(tidx) {
130                Entry::Occupied(mut o) => {
131                    if stack.is_alive(terms[*o.get()].frame) {
132                        return Ok(&[]);
133                    } else {
134                        o.insert(proof);
135                    }
136                }
137                Entry::Vacant(v) => {
138                    v.insert(proof);
139                }
140            }
141            Ok(&app.child_ids)
142        })
143    }
144
145    pub fn get_given(&self, from: ENodeIdx, to: ENodeIdx) -> Option<EqGivenIdx> {
146        self.enodes[from]
147            .equalities
148            .iter()
149            .find(|eq| eq.to == to)
150            .map(|eq| eq.expl)
151    }
152
153    pub fn new_given_equality(
154        &mut self,
155        from: ENodeIdx,
156        expl: EqualityExpl,
157        stack: &Stack,
158    ) -> Result<()> {
159        debug_assert_eq!(from, expl.from());
160        let to = expl.to();
161        self.equalities.given.raw.try_reserve(1)?;
162        let expl = self.equalities.given.push_and_get_key(expl);
163        let enode = &mut self.enodes[from];
164        let eq = Equality {
165            _frame: stack.active_frame(),
166            to,
167            expl,
168        };
169        enode.equalities.try_reserve(1)?;
170        enode.equalities.push(eq);
171        // TODO: is ok to simply ignore the old equality, or should we also blame it later on?
172        // let (new, others) = enode.equalities.split_last().unwrap();
173        // if let Some(old) = others.last() {
174        //     let inactive = old.frame.map(|f| !stack.stack_frames[f].active).unwrap_or_default();
175        //     if inactive {
176        //         return;
177        //     }
178        //     let is_root = matches!(old.expl, EqualityExpl::Root { .. });
179        //     let root_unchanged = is_root || (self.path_to_root(old.to).last().unwrap() == self.path_to_root(new.to).last().unwrap());
180        //     if root_unchanged {
181        //         return;
182        //     }
183
184        //     let is_unknown = matches!(enode.get_equality(stack).unwrap().expl, EqualityExpl::Unknown { .. });
185        //     let equivalent = old.expl == enode.get_equality(stack).unwrap().expl;
186        //     // let test = old.frame.is_some_and(|f| usize::from(f) == 854);
187        //     if !equivalent && !is_unknown {
188        //         panic!();
189        //     }
190        // }
191        Ok(())
192    }
193
194    pub fn new_trans_equality(
195        &mut self,
196        from: ENodeIdx,
197        to: ENodeIdx,
198        stack: &Stack,
199        mismatch: TransEqAllowed,
200    ) -> Result<core::result::Result<EqTransIdx, ENodeIdx>> {
201        if from == to {
202            Ok(Err(from))
203        } else {
204            self.construct_trans_equality(from, to, stack, mismatch)
205                .map(Ok)
206        }
207    }
208
209    fn to_root<'a>(
210        &'a self,
211        from: ENodeIdx,
212        stack: &'a Stack,
213    ) -> impl Iterator<Item = ENodeIdx> + 'a {
214        core::iter::successors(Some(from), move |&from| {
215            self.enodes[from].get_equality(stack).map(|eq| eq.to)
216        })
217    }
218
219    fn to_root_visited<'a>(
220        &'a self,
221        from: ENodeIdx,
222        stack: &'a Stack,
223        cycle: &'a mut Option<ENodeIdx>,
224        visited: &'a mut FxHashSet<ENodeIdx>,
225    ) -> impl Iterator<Item = ENodeIdx> + 'a {
226        self.to_root(from, stack).take_while(move |&from| {
227            if visited.insert(from) {
228                true
229            } else {
230                // On rare occasions there is a cycle of more than one enode in
231                // the equality graph (EXPLAIN).
232                *cycle = Some(from);
233                false
234            }
235        })
236    }
237
238    pub fn check_eq(&self, from: ENodeIdx, to: ENodeIdx, stack: &Stack) -> bool {
239        let mut visited_from = FxHashSet::default();
240        self.to_root_visited(from, stack, &mut None, &mut visited_from)
241            .for_each(drop);
242        for to in self.to_root_visited(to, stack, &mut None, &mut FxHashSet::default()) {
243            if visited_from.contains(&to) {
244                return true;
245            }
246        }
247        false
248    }
249
250    pub fn path_to_root(
251        &self,
252        from: ENodeIdx,
253        root: Option<ENodeIdx>,
254        stack: &Stack,
255    ) -> Result<(Option<ENodeIdx>, Vec<ENodeIdx>)> {
256        let mut cycle = None;
257        let mut visited = FxHashSet::default();
258        let mut path = Vec::new();
259        for e in self.to_root_visited(from, stack, &mut cycle, &mut visited) {
260            path.try_reserve(1)?;
261            path.push(e);
262            if root.is_some_and(|root| root == e) {
263                return Ok((Some(e), path));
264            }
265        }
266        Ok((cycle, path))
267    }
268
269    fn get_simple_path(
270        &self,
271        from: ENodeIdx,
272        to: ENodeIdx,
273        stack: &Stack,
274        can_mismatch: bool,
275    ) -> Result<Option<SimplePath>> {
276        let (_, f_path) = self.path_to_root(from, None, stack)?;
277        let f_root = f_path.len() - 1;
278        let (_, t_path) = self.path_to_root(to, Some(*f_path.last().unwrap()), stack)?;
279        let t_root = t_path.len() - 1;
280
281        if f_path[f_root] != t_path[t_root] {
282            // Root may not always be the same from v4.12.3 onwards if `to` is an `ite` expression. See:
283            // https://github.com/Z3Prover/z3/commit/faf14012ba18d21c1fcddbdc321ac127f019fa03#diff-0a9ec50ded668e51578edc67ecfe32380336b9cbf12c5d297e2d3759a7a39847R2417-R2419
284            if can_mismatch {
285                // Return no path if the roots are different.
286                return Ok(None);
287            } else {
288                return Err(Error::EnodeRootMismatch(from, to));
289            }
290        }
291        let mut shared = 1;
292        while shared < f_path.len()
293            && shared < t_path.len()
294            && f_path[f_root - shared] == t_path[t_root - shared]
295        {
296            shared += 1;
297        }
298        let path = SimplePath {
299            from_to_root: f_path,
300            to_to_root: t_path,
301            shared,
302        };
303        Ok(Some(path))
304    }
305
306    fn construct_trans_equality(
307        &mut self,
308        from: ENodeIdx,
309        to: ENodeIdx,
310        stack: &Stack,
311        mismatch: TransEqAllowed,
312    ) -> Result<EqTransIdx> {
313        debug_assert_ne!(from, to);
314        let can_mismatch = mismatch.can_mismatch_initial;
315        let Some(simple_path) = self.get_simple_path(from, to, stack, can_mismatch)? else {
316            // There was a root mismatch (and `can_mismatch` was true), so we
317            // can't construct a simple path.
318            let error = TransitiveExpl::error(from, to);
319            self.enodes[from].transitive.try_reserve(1)?;
320            let trans = match self.enodes[from].transitive.entry(to) {
321                Entry::Occupied(mut o) => {
322                    let trans = *o.get();
323                    if self.equalities.transitive[trans].given_len.is_some() {
324                        // These two nodes are no longer equal (this is an old
325                        // transitive equality that is no longer valid).
326                        self.equalities.transitive.raw.try_reserve(1)?;
327                        let trans = self.equalities.transitive.push_and_get_key(error);
328                        o.insert(trans);
329                        trans
330                    } else {
331                        trans
332                    }
333                }
334                Entry::Vacant(v) => {
335                    self.equalities.transitive.raw.try_reserve(1)?;
336                    let trans = self.equalities.transitive.push_and_get_key(error);
337                    *v.insert(trans)
338                }
339            };
340            return Ok(trans);
341        };
342        // Should not fail since `from != to`
343        let edges_len = simple_path.edges_len();
344        let mut graph = simple_path.initialise_graph(self, stack);
345
346        // Add transitive edges to graph, start by trying to find either forward
347        // or backward full solutions.
348        if let Some((forward, solution)) = graph.add_trans_from(0, self) {
349            debug_assert!(forward);
350            return Ok(solution);
351        }
352        let trans = if let Some((forward, solution)) = graph.add_trans_from(edges_len.get(), self) {
353            debug_assert!(!forward);
354            let inner = &self.equalities.transitive[solution];
355            if inner.path.len() == 1 {
356                use TransitiveExplSegmentKind::*;
357                match inner.path[0] {
358                    TransitiveExplSegment {
359                        forward: false,
360                        kind: Transitive(idx),
361                    } => return Ok(idx),
362                    TransitiveExplSegment {
363                        forward: true,
364                        kind: Transitive(_),
365                    } => unreachable!(),
366                    TransitiveExplSegment {
367                        kind: Given(..), ..
368                    } => (),
369                    TransitiveExplSegment {
370                        kind: Error(..), ..
371                    } => unreachable!(),
372                }
373            }
374            let solution = TransitiveExplSegment {
375                forward,
376                kind: TransitiveExplSegmentKind::Transitive(solution),
377            };
378            TransitiveExpl::new([solution].into_iter(), NonZeroUsize::new(1).unwrap(), to)?
379        } else {
380            for idx in 1..edges_len.get() {
381                graph.add_trans_from(idx, self);
382            }
383            // Find the shortest path
384            for idx in (0..edges_len.get()).rev() {
385                let idx = NodeIndex::new(idx);
386                // Use `.rev()` here to prefer transitive edges over leaf edges,
387                // though hopefully the `min_by_key` should be unique.
388                let min = graph
389                    .graph
390                    .edges(idx)
391                    .min_by_key(|edge| graph.graph[edge.target()].0)
392                    .unwrap();
393                let (cost, id) = (graph.graph[min.target()].0, min.id());
394                let idx = &mut graph.graph[idx];
395                idx.0 = cost + 1;
396                idx.1 = Some(id);
397            }
398
399            let start = NodeIndex::new(0);
400            let mut edge = graph.graph[start].1;
401            let path_length = graph.graph[start].0;
402            TransitiveExpl::new(
403                (0..path_length)
404                    .map(|_| {
405                        let kind = &graph.graph[edge.unwrap()];
406                        edge = graph.graph[graph.graph.edge_endpoints(edge.unwrap()).unwrap().1].1;
407                        kind
408                    })
409                    .copied(),
410                edges_len,
411                to,
412            )?
413        };
414        let trans = self.insert_trans_equality(trans, stack, mismatch.can_mismatch_congr)?;
415        debug_assert_eq!(self.equalities.walk_to(from, trans), to);
416        self.enodes[from].transitive.try_reserve(1)?;
417        let old = self.enodes[from].transitive.insert(to, trans);
418        debug_assert_eq!(old, None);
419        Ok(trans)
420    }
421
422    fn insert_trans_equality(
423        &mut self,
424        mut trans: TransitiveExpl,
425        stack: &Stack,
426        can_mismatch_congr: bool,
427    ) -> Result<EqTransIdx> {
428        let mismatch = TransEqAllowed {
429            can_mismatch_initial: can_mismatch_congr,
430            can_mismatch_congr: false,
431        };
432        // Find the current congruence uses
433        for seg in trans.path.iter_mut() {
434            if let TransitiveExplSegmentKind::Given((cg, idx)) = &mut seg.kind {
435                debug_assert_eq!(*idx, None);
436                let EqualityExpl::Congruence { arg_eqs, .. } = &self.equalities.given[*cg] else {
437                    continue;
438                };
439                let mut args = Vec::new();
440                args.try_reserve_exact(arg_eqs.len())?;
441                args.extend(arg_eqs.iter().copied());
442
443                let mut use_ = Vec::new();
444                use_.try_reserve_exact(arg_eqs.len())?;
445                for (from, to) in args {
446                    let Ok(trans) = self.new_trans_equality(from, to, stack, mismatch)? else {
447                        continue;
448                    };
449                    use_.push(trans);
450                }
451                let EqualityExpl::Congruence { uses, .. } = &mut self.equalities.given[*cg] else {
452                    unreachable!()
453                };
454                let real_idx = uses.iter().position(|u| **u == use_).unwrap_or_else(|| {
455                    uses.push(BoxSlice::from(use_));
456                    uses.len() - 1
457                });
458                *idx = Some(NonMaxU32::new(real_idx as u32).unwrap());
459            }
460        }
461
462        self.equalities.transitive.raw.try_reserve(1)?;
463        let trans = self.equalities.transitive.push_and_get_key(trans);
464        Ok(trans)
465    }
466}
467
468impl std::ops::Index<ENodeIdx> for EGraph {
469    type Output = ENode;
470    fn index(&self, idx: ENodeIdx) -> &Self::Output {
471        &self.enodes[idx]
472    }
473}
474
475impl ENode {
476    pub fn get_equality(&self, _stack: &Stack) -> Option<&Equality> {
477        // TODO: why are we allowed to use equalities from popped stack frames?
478        // self.equalities.iter().rev().find(|eq| stack.is_alive(eq._frame))
479        self.equalities.last()
480    }
481}
482
483#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
484#[derive(Debug, Default)]
485pub struct Equalities {
486    pub(crate) given: TiVec<EqGivenIdx, EqualityExpl>,
487    pub(crate) transitive: TiVec<EqTransIdx, TransitiveExpl>,
488}
489
490pub trait EqualityWalker<'a> {
491    type Error;
492
493    fn equalities(&self) -> &'a Equalities;
494
495    /// Override this method to get all the given equalities from a transitive
496    /// equality. If you wish to recurse into congruences then use the
497    /// following structure:
498    ///
499    /// ```ignore
500    /// let (eq, use_) = eq_use;
501    /// match &self.equalities().given[eq] {
502    ///     EqualityExpl::Congruence { uses, .. } => self.walk_congruence(uses, use_.unwrap(), forward),
503    ///     _ => Ok(()),
504    /// }
505    /// ```
506    fn walk_given(
507        &mut self,
508        eq_use: EqGivenUse,
509        forward: bool,
510    ) -> core::result::Result<(), Self::Error>;
511
512    /// Does nothing if `CONGR` is false. Otherwise recursively walks into the
513    /// congruence uses.
514    fn walk_congruence(
515        &mut self,
516        uses: &[BoxSlice<EqTransIdx>],
517        use_: NonMaxU32,
518        forward: bool,
519    ) -> core::result::Result<(), Self::Error> {
520        let use_ = &uses[use_.get() as usize];
521        for &eq in use_.iter() {
522            self.walk_trans(eq, forward)?;
523        }
524        Ok(())
525    }
526
527    /// Return `Err` if the walk should abort, `Ok` to stop recursing, and
528    /// otherwise call `self.super_walk_trans` to recurse.
529    fn walk_trans(
530        &mut self,
531        eq: EqTransIdx,
532        forward: bool,
533    ) -> core::result::Result<(), Self::Error> {
534        self.super_walk_trans(eq, forward)
535    }
536
537    /// This method defines the walking of transitive equalities, override
538    /// `walk_trans` to intercept before walking a transitive equality.
539    fn super_walk_trans(
540        &mut self,
541        eq: EqTransIdx,
542        forward: bool,
543    ) -> core::result::Result<(), Self::Error> {
544        let all = self.equalities().transitive[eq].all(forward);
545        for next in all {
546            use TransitiveExplSegmentKind::*;
547            match next.kind {
548                Given(eq_use) => self.walk_given(eq_use, next.forward)?,
549                Transitive(eq) => self.walk_trans(eq, next.forward)?,
550                Error(..) => (),
551            }
552        }
553        Ok(())
554    }
555}
556
557struct TransEqChecker<'a, I: Iterator<Item = (EqGivenUse, bool)>> {
558    equalities: &'a Equalities,
559    simple: I,
560}
561
562impl<'a, I: Iterator<Item = (EqGivenUse, bool)>> EqualityWalker<'a> for TransEqChecker<'a, I> {
563    type Error = bool;
564    fn equalities(&self) -> &'a Equalities {
565        self.equalities
566    }
567    fn walk_given(
568        &mut self,
569        (eq, _use_): EqGivenUse,
570        eq_fwd: bool,
571    ) -> core::result::Result<(), Self::Error> {
572        // Return `Err(false)` if we're out of simple, this should never
573        // happen.
574        let ((simple, _simple_use), fwd) = self.simple.next().ok_or(false)?;
575        // Return `Ok` if equal, else `Err(true)`.
576        // We do not compare `simple_use == use_` here because `simple_use`
577        // hasn't been set yet (I think?).
578        (simple == eq && fwd == eq_fwd).then_some(()).ok_or(true)
579    }
580}
581
582impl Equalities {
583    pub fn is_equal(
584        &self,
585        eq: EqTransIdx,
586        simple: &mut impl Iterator<Item = (EqGivenUse, bool)>,
587    ) -> Option<bool> {
588        let mut checker = TransEqChecker {
589            equalities: self,
590            simple,
591        };
592        let res = checker.walk_trans(eq, true);
593        // Map `Ok` -> `Ok(true)`, `Err(true)` -> `Ok(false)`, `Err(false)` -> `Err`.
594        res.map_or_else(|e| e.then_some(false), |_| Some(true))
595    }
596}
597
598pub type TransEqSimpleWalker<'a, F> = TransEqStopWalker<'a, Never, F>;
599
600pub struct TransEqStopWalker<'a, E, F: FnMut(&'a EqualityExpl, bool) -> core::result::Result<(), E>>
601{
602    equalities: &'a Equalities,
603    simple: F,
604}
605
606#[derive(Debug)]
607pub enum Never {}
608
609impl<'a, E, F: FnMut(&'a EqualityExpl, bool) -> core::result::Result<(), E>> EqualityWalker<'a>
610    for TransEqStopWalker<'a, E, F>
611{
612    type Error = E;
613    fn equalities(&self) -> &'a Equalities {
614        self.equalities
615    }
616    fn walk_given(
617        &mut self,
618        (eq, _): EqGivenUse,
619        forward: bool,
620    ) -> core::result::Result<(), Self::Error> {
621        (self.simple)(&self.equalities.given[eq], forward)
622    }
623}
624
625impl Equalities {
626    pub fn walk_to(&self, mut from: ENodeIdx, eq: EqTransIdx) -> ENodeIdx {
627        let mut walker = TransEqSimpleWalker {
628            equalities: self,
629            simple: |eq, fwd| {
630                from = eq.walk(from, fwd).unwrap();
631                Ok(())
632            },
633        };
634        walker.walk_trans(eq, true).unwrap();
635        from
636    }
637    pub fn path(&self, eq: EqTransIdx) -> Vec<ENodeIdx> {
638        let equality = &self.transitive[eq];
639        if let Some(from) = equality.error_from() {
640            return vec![from, equality.to];
641        }
642
643        let mut path = Vec::new();
644        let mut walker = TransEqSimpleWalker {
645            equalities: self,
646            simple: |eq, fwd| {
647                let from = if fwd { eq.from() } else { eq.to() };
648                path.push(from);
649                Ok(())
650            },
651        };
652        walker.walk_trans(eq, true).unwrap();
653        path.push(self.transitive[eq].to);
654        path
655    }
656}
657
658#[derive(Debug)]
659pub struct SimplePath {
660    from_to_root: Vec<ENodeIdx>,
661    to_to_root: Vec<ENodeIdx>,
662    shared: usize,
663}
664impl SimplePath {
665    pub fn edges_len(&self) -> NonZeroUsize {
666        NonZeroUsize::new(self.from_to_root.len() + self.to_to_root.len() - 2 * self.shared)
667            .unwrap()
668    }
669    pub fn all_simple_edges<'a>(
670        &'a self,
671        egraph: &'a EGraph,
672        stack: &'a Stack,
673    ) -> impl DoubleEndedIterator<Item = (ENodeIdx, EqGivenIdx, bool)> + 'a {
674        let from_to_join = self.from_to_root[..self.from_to_root.len() - self.shared]
675            .iter()
676            .copied();
677        let from_to_join = from_to_join.map(|e| {
678            let eq = &egraph.enodes[e].get_equality(stack).unwrap();
679            (eq.to, eq.expl, true)
680        });
681        let join_to_to = self.to_to_root[..self.to_to_root.len() - self.shared]
682            .iter()
683            .rev()
684            .copied();
685        let join_to_to =
686            join_to_to.map(|e| (e, egraph.enodes[e].get_equality(stack).unwrap().expl, false));
687        from_to_join.chain(join_to_to)
688    }
689
690    pub fn nodes_len(&self) -> usize {
691        self.edges_len().get() + 1
692    }
693    pub fn all_nodes(&self) -> impl DoubleEndedIterator<Item = ENodeIdx> + '_ {
694        let from_to_join = self.from_to_root[..self.from_to_root.len() + 1 - self.shared].iter();
695        let join_to_to = self.to_to_root[..self.to_to_root.len() - self.shared]
696            .iter()
697            .rev();
698        from_to_join.chain(join_to_to).copied()
699    }
700    pub fn node_at(&self, idx: usize) -> ENodeIdx {
701        let from_len = self.from_to_root.len() - self.shared;
702        if idx <= from_len {
703            self.from_to_root[idx]
704        } else {
705            let to_len = self.to_to_root.len() - self.shared;
706            self.to_to_root[(to_len + from_len) - idx]
707        }
708    }
709    pub fn all_transitive<'a>(
710        &'a self,
711        egraph: &'a EGraph,
712    ) -> impl DoubleEndedIterator<Item = impl Iterator<Item = EqTransIdx> + 'a> + 'a {
713        self.all_nodes()
714            .map(move |idx| egraph.enodes[idx].transitive.values().copied())
715    }
716
717    pub fn initialise_graph<'a>(self, egraph: &'a EGraph, stack: &'a Stack) -> Graph {
718        let edges_len = self.edges_len();
719        let mut g = Graph {
720            graph: DiGraph::with_capacity(self.nodes_len(), edges_len.get()),
721            path_enodes: self.all_nodes().collect(),
722            edges_len,
723            simple_path: self,
724        };
725        let mut last = g.graph.add_node((edges_len.get() as u32, None));
726        for (idx, (_, eq, forward)) in g.simple_path.all_simple_edges(egraph, stack).enumerate() {
727            let cost = (edges_len.get() - idx - 1) as u32;
728            let next = g.graph.add_node((cost, None));
729            let kind = TransitiveExplSegmentKind::Given((eq, None));
730            g.graph
731                .add_edge(last, next, TransitiveExplSegment { forward, kind });
732            last = next;
733        }
734        g
735    }
736}
737
738pub struct Graph {
739    simple_path: SimplePath,
740    path_enodes: FxHashSet<ENodeIdx>,
741    graph: DiGraph<(u32, Option<EdgeIndex>), TransitiveExplSegment>,
742    edges_len: NonZeroUsize,
743}
744impl Graph {
745    pub fn add_trans_from(
746        &mut self,
747        idx: usize,
748        egraph: &mut EGraph,
749    ) -> Option<(bool, EqTransIdx)> {
750        let nfrom = NodeIndex::new(idx);
751        let efrom = self.simple_path.node_at(idx);
752
753        for to in 0..self.simple_path.nodes_len() {
754            let to = self.simple_path.node_at(to);
755            if let Entry::Occupied(o) = egraph.enodes[efrom].transitive.entry(to) {
756                let trans = *o.get();
757                let trans_node = &egraph.equalities.transitive[trans];
758                let Some((nto, forward)) =
759                    self.add_trans_single(idx, trans, trans_node, &egraph.equalities)
760                else {
761                    o.remove();
762                    continue;
763                };
764                let segment = TransitiveExplSegment {
765                    forward,
766                    kind: TransitiveExplSegmentKind::Transitive(trans),
767                };
768                let (from, to) = if forward { (nfrom, nto) } else { (nto, nfrom) };
769                debug_assert!(trans_node.given_len.is_none() || from.index() < to.index());
770                self.graph.add_edge(from, to, segment);
771                if trans_node
772                    .given_len
773                    .is_some_and(|len| len == self.edges_len)
774                {
775                    return Some((forward, trans));
776                }
777            }
778        }
779        None
780    }
781    fn add_trans_single(
782        &self,
783        idx: usize,
784        trans: EqTransIdx,
785        trans_node: &TransitiveExpl,
786        equalities: &Equalities,
787    ) -> Option<(NodeIndex, bool)> {
788        let given_len = trans_node.given_len?.get();
789        let edges_len = self.edges_len.get();
790        debug_assert!(self.path_enodes.contains(&trans_node.to));
791
792        let left = given_len <= idx && trans_node.to == self.simple_path.node_at(idx - given_len);
793        if left {
794            let prior_simple_edges = (0..idx).map(|idx| self.graph[EdgeIndex::new(idx)]);
795            let mut prior_simple_edges = TransitiveExplSegment::rev(prior_simple_edges)
796                .map(|seg| (seg.kind.given().unwrap(), seg.forward));
797            match equalities.is_equal(trans, &mut prior_simple_edges) {
798                None => {
799                    debug_assert!(false);
800                    None
801                }
802                Some(false) => None,
803                Some(true) => {
804                    let to = NodeIndex::new(idx - given_len);
805                    Some((to, false))
806                }
807            }
808        } else if given_len <= edges_len - idx
809            && trans_node.to == self.simple_path.node_at(idx + given_len)
810        {
811            let post_simple_edges = (idx..edges_len).map(|idx| self.graph[EdgeIndex::new(idx)]);
812            let mut post_simple_edges =
813                post_simple_edges.map(|seg| (seg.kind.given().unwrap(), seg.forward));
814            match equalities.is_equal(trans, &mut post_simple_edges) {
815                None => {
816                    debug_assert!(false);
817                    None
818                }
819                Some(false) => None,
820                Some(true) => {
821                    let to = NodeIndex::new(idx + given_len);
822                    Some((to, true))
823                }
824            }
825        } else {
826            None
827        }
828    }
829}
830
831/// The complexity of this arises from the fact that z3 will sometimes create a
832/// new enode in a higher frame when one in a lower frame already exists. This
833/// new enode might then be popped, but z3 will still want to use the old enode.
834#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
835#[derive(Debug)]
836enum TermToEnode {
837    Single(ENodeIdx),
838    Multiple(Vec<ENodeIdx>),
839}
840
841impl Default for TermToEnode {
842    fn default() -> Self {
843        Self::Multiple(Vec::default())
844    }
845}
846
847impl EGraph {
848    pub fn insert_tte(&mut self, term: TermIdx, enode: ENodeIdx, stack: &Stack) -> Result<()> {
849        let remove = |e: &ENodeIdx| !stack.is_alive(self.enodes[*e].frame);
850        self.term_to_enode.try_reserve(1)?;
851        let tte = self.term_to_enode.entry(term).or_default();
852        let mut vec = match tte {
853            TermToEnode::Single(e) => [*e].into_iter().filter(|e| !remove(e)).collect(),
854            TermToEnode::Multiple(es) => {
855                let mut es = core::mem::take(es);
856                let idx = es.iter().position(remove).unwrap_or(es.len());
857                debug_assert!(es[idx..].iter().all(remove));
858                es.drain(idx..);
859                es
860            }
861        };
862        if vec.is_empty() {
863            *tte = TermToEnode::Single(enode);
864        } else {
865            vec.push(enode);
866            *tte = TermToEnode::Multiple(vec);
867        }
868        // TODO: why does this happen sometimes?
869        // debug_assert!(!old.is_some_and(|o| stack.is_alive(self[o].frame)));
870        Ok(())
871    }
872
873    fn get_tte(&mut self, term: TermIdx, stack: &Stack) -> Option<ENodeIdx> {
874        let Entry::Occupied(mut o) = self.term_to_enode.entry(term) else {
875            return None;
876        };
877        let remove = |e: &ENodeIdx| !stack.is_alive(self.enodes[*e].frame);
878        match o.get_mut() {
879            TermToEnode::Single(e) => {
880                if remove(&*e) {
881                    o.remove();
882                    None
883                } else {
884                    Some(*e)
885                }
886            }
887            TermToEnode::Multiple(es) => {
888                let idx = es.iter().position(remove).unwrap_or(es.len());
889                debug_assert!(es[idx..].iter().all(remove));
890                es.drain(idx..);
891                if let Some(last) = es.last().copied() {
892                    if es.len() == 1 {
893                        *o.get_mut() = TermToEnode::Single(last);
894                    }
895                    Some(last)
896                } else {
897                    o.remove();
898                    None
899                }
900            }
901        }
902    }
903
904    fn get_tte_imm(&self, term: TermIdx, stack: &Stack) -> Option<ENodeIdx> {
905        let enode = self.term_to_enode.get(&term)?;
906        let keep = |e: &ENodeIdx| stack.is_alive(self.enodes[*e].frame);
907        match enode {
908            TermToEnode::Single(e) if keep(e) => Some(*e),
909            TermToEnode::Single(_) => None,
910            TermToEnode::Multiple(es) => es.iter().copied().find(keep),
911        }
912    }
913}