1use std::{collections::hash_map::Entry, num::NonZeroUsize};
2
3#[cfg(feature = "mem_dbg")]
4use mem_dbg::{MemDbg, MemSize};
5use petgraph::{
6 graph::{DiGraph, EdgeIndex, NodeIndex},
7 visit::EdgeRef,
8};
9
10use crate::{
11 items::{
12 ENode, ENodeBlame, ENodeIdx, EqGivenIdx, EqGivenUse, EqTransIdx, Equality, EqualityExpl,
13 InstIdx, ProofIdx, TermIdx, TransitiveExpl, TransitiveExplSegment,
14 TransitiveExplSegmentKind,
15 },
16 BoxSlice, Error, FxHashMap, FxHashSet, NonMaxU32, Result, TiVec,
17};
18
19use super::{bugs::TransEqAllowed, stack::Stack, terms::Terms};
20
/// E-graph state: the e-nodes, which e-nodes each term currently maps to,
/// the asserted proof recorded per term, and the store of given and
/// transitive equality explanations.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug, Default)]
pub struct EGraph {
    /// E-nodes registered for each term (maintained by `insert_tte`/`get_tte`).
    term_to_enode: FxHashMap<TermIdx, TermToEnode>,
    /// E-nodes created by `new_enode`, indexed by `ENodeIdx`.
    pub(crate) enodes: TiVec<ENodeIdx, ENode>,
    /// Asserted proof recorded for a term (and its sub-terms) by `new_proof`.
    proofs: FxHashMap<TermIdx, ProofIdx>,
    /// Given and transitive equality explanations.
    pub equalities: Equalities,
}
31
32impl Equalities {
33 pub fn from_to(&self, eq: EqTransIdx) -> (ENodeIdx, ENodeIdx) {
34 let from = self.walk_trans(eq, |eq, fwd| {
35 let from = if fwd { eq.from() } else { eq.to() };
36 Err(from)
37 });
38 let equality = &self.transitive[eq];
39 let from = from.err().unwrap_or_else(|| equality.error_from().unwrap());
40 let to = equality.to;
41 (from, to)
42 }
43
44 pub fn walk_trans<E>(
47 &self,
48 eq: EqTransIdx,
49 f: impl FnMut(&EqualityExpl, bool) -> core::result::Result<(), E>,
50 ) -> core::result::Result<(), E> {
51 let mut walker = TransEqStopWalker {
52 equalities: self,
53 simple: f,
54 };
55 walker.walk_trans(eq, true)
56 }
57}
58
impl EGraph {
    /// Decides what justifies the e-node for `tidx`.
    ///
    /// Checked in order: boolean constants, an explicit instantiation `inst`,
    /// then a proof recorded for the term whose frame is still alive on
    /// `stack`. Falls back to [`ENodeBlame::Unknown`].
    pub fn get_blame(
        &self,
        tidx: TermIdx,
        inst: Option<(InstIdx, FxHashSet<EqTransIdx>)>,
        terms: &Terms,
        stack: &Stack,
    ) -> ENodeBlame {
        if terms.is_bool_const(tidx) {
            return ENodeBlame::BoolConst;
        }
        if let Some(inst) = inst {
            return ENodeBlame::Inst(inst);
        }
        let proof = self.proofs.get(&tidx).copied();
        // Only proofs from frames that are still alive count.
        let proof = proof.filter(|&p| stack.is_alive(terms[p].frame));
        if let Some(proof) = proof {
            return ENodeBlame::Proof(proof);
        }
        ENodeBlame::Unknown
    }

    /// Creates a new e-node owned by `term` in the stack's active frame and
    /// registers it as the term's current e-node.
    ///
    /// # Errors
    /// Propagates allocation failures from reserving the e-node slot and
    /// errors from `insert_tte`.
    pub fn new_enode(
        &mut self,
        blame: ENodeBlame,
        term: TermIdx,
        z3_generation: NonMaxU32,
        stack: &Stack,
    ) -> Result<ENodeIdx> {
        self.enodes.raw.try_reserve(1)?;
        let enode = self.enodes.push_and_get_key(ENode {
            frame: stack.active_frame(),
            blame,
            owner: term,
            z3_generation,
            equalities: Vec::new(),
            transitive: FxHashMap::default(),
        });
        self.insert_tte(term, enode, stack)?;
        Ok(enode)
    }

    /// Returns the e-node currently registered for `term` (via `get_tte`,
    /// which may prune dead entries, hence `&mut self`).
    ///
    /// # Errors
    /// [`Error::UnknownEnode`] if `term` has no live e-node.
    pub fn get_enode(&mut self, term: TermIdx, stack: &Stack) -> Result<ENodeIdx> {
        self.get_tte(term, stack).ok_or(Error::UnknownEnode(term))
    }

    /// Non-mutating lookup of the e-node registered for `term` (via
    /// `get_tte_imm`); `None` if there is no live one.
    pub fn get_enode_imm(&self, term: TermIdx, stack: &Stack) -> Option<ENodeIdx> {
        self.get_tte_imm(term, stack)
    }

    /// Records `proof` as the justification of `term` and of its sub-terms.
    ///
    /// Non-asserted proofs are ignored. A sub-term that already has a proof
    /// from a live frame keeps it, and an empty child list is returned to
    /// `app_walk` for it (presumably so its children are not visited).
    /// Proofs from dead frames are overwritten.
    pub(super) fn new_proof(
        &mut self,
        term: TermIdx,
        proof: ProofIdx,
        terms: &Terms,
        stack: &Stack,
    ) -> Result<()> {
        if !terms[proof].kind.is_asserted() {
            return Ok(());
        }
        terms.app_walk(term, |tidx, app| {
            self.proofs.try_reserve(1)?;
            match self.proofs.entry(tidx) {
                Entry::Occupied(mut o) => {
                    if stack.is_alive(terms[*o.get()].frame) {
                        return Ok(&[]);
                    } else {
                        o.insert(proof);
                    }
                }
                Entry::Vacant(v) => {
                    v.insert(proof);
                }
            }
            Ok(&app.child_ids)
        })
    }

    /// Finds the given equality recorded on `from` whose target is `to`.
    pub fn get_given(&self, from: ENodeIdx, to: ENodeIdx) -> Option<EqGivenIdx> {
        self.enodes[from]
            .equalities
            .iter()
            .find(|eq| eq.to == to)
            .map(|eq| eq.expl)
    }

    /// Stores the given equality `expl` and appends it to the equality list
    /// of its source e-node `from` (which must equal `expl.from()`, checked
    /// in debug builds), tagged with the active frame.
    ///
    /// # Errors
    /// Propagates allocation failures.
    pub fn new_given_equality(
        &mut self,
        from: ENodeIdx,
        expl: EqualityExpl,
        stack: &Stack,
    ) -> Result<()> {
        debug_assert_eq!(from, expl.from());
        let to = expl.to();
        self.equalities.given.raw.try_reserve(1)?;
        let expl = self.equalities.given.push_and_get_key(expl);
        let enode = &mut self.enodes[from];
        let eq = Equality {
            _frame: stack.active_frame(),
            to,
            expl,
        };
        enode.equalities.try_reserve(1)?;
        enode.equalities.push(eq);
        Ok(())
    }

    /// Returns a transitive equality explaining `from = to`.
    ///
    /// The inner result is `Err(from)` when `from == to`, since no
    /// equality is needed. Otherwise it is the index built or reused by
    /// `construct_trans_equality`.
    pub fn new_trans_equality(
        &mut self,
        from: ENodeIdx,
        to: ENodeIdx,
        stack: &Stack,
        mismatch: TransEqAllowed,
    ) -> Result<core::result::Result<EqTransIdx, ENodeIdx>> {
        if from == to {
            Ok(Err(from))
        } else {
            self.construct_trans_equality(from, to, stack, mismatch)
                .map(Ok)
        }
    }

    /// Follows each e-node's current equality edge (`get_equality`),
    /// starting at `from` itself, until an e-node without an equality. Not
    /// protected against cycles; see `to_root_visited`.
    fn to_root<'a>(
        &'a self,
        from: ENodeIdx,
        stack: &'a Stack,
    ) -> impl Iterator<Item = ENodeIdx> + 'a {
        core::iter::successors(Some(from), move |&from| {
            self.enodes[from].get_equality(stack).map(|eq| eq.to)
        })
    }

    /// Like `to_root`, but stops before yielding an e-node already in
    /// `visited`, recording that repeated e-node in `cycle`.
    fn to_root_visited<'a>(
        &'a self,
        from: ENodeIdx,
        stack: &'a Stack,
        cycle: &'a mut Option<ENodeIdx>,
        visited: &'a mut FxHashSet<ENodeIdx>,
    ) -> impl Iterator<Item = ENodeIdx> + 'a {
        self.to_root(from, stack).take_while(move |&from| {
            if visited.insert(from) {
                true
            } else {
                *cycle = Some(from);
                false
            }
        })
    }

    /// Returns whether the equality chains starting at `from` and at `to`
    /// have an e-node in common.
    pub fn check_eq(&self, from: ENodeIdx, to: ENodeIdx, stack: &Stack) -> bool {
        let mut visited_from = FxHashSet::default();
        // Drain the iterator only for its side effect of filling `visited_from`.
        self.to_root_visited(from, stack, &mut None, &mut visited_from)
            .for_each(drop);
        for to in self.to_root_visited(to, stack, &mut None, &mut FxHashSet::default()) {
            if visited_from.contains(&to) {
                return true;
            }
        }
        false
    }

    /// Collects the equality chain starting at `from` (inclusive).
    ///
    /// If `root` is reached, stops there and returns `Some(root)` with the
    /// path ending at it. Otherwise returns the e-node at which a cycle was
    /// detected (or `None`) together with the full chain.
    ///
    /// # Errors
    /// Propagates allocation failures.
    pub fn path_to_root(
        &self,
        from: ENodeIdx,
        root: Option<ENodeIdx>,
        stack: &Stack,
    ) -> Result<(Option<ENodeIdx>, Vec<ENodeIdx>)> {
        let mut cycle = None;
        let mut visited = FxHashSet::default();
        let mut path = Vec::new();
        for e in self.to_root_visited(from, stack, &mut cycle, &mut visited) {
            path.try_reserve(1)?;
            path.push(e);
            if root.is_some_and(|root| root == e) {
                return Ok((Some(e), path));
            }
        }
        Ok((cycle, path))
    }

    /// Computes the chain from `from` to its end, and the chain from `to`
    /// until it meets that end.
    ///
    /// If the two chains end at different e-nodes, the result is `Ok(None)`
    /// when `can_mismatch` is set and [`Error::EnodeRootMismatch`] otherwise.
    /// `shared` counts the trailing e-nodes common to both chains, and is
    /// at least 1 because the common end is always shared.
    fn get_simple_path(
        &self,
        from: ENodeIdx,
        to: ENodeIdx,
        stack: &Stack,
        can_mismatch: bool,
    ) -> Result<Option<SimplePath>> {
        // `f_path` is never empty: the chain always starts with `from` itself.
        let (_, f_path) = self.path_to_root(from, None, stack)?;
        let f_root = f_path.len() - 1;
        let (_, t_path) = self.path_to_root(to, Some(*f_path.last().unwrap()), stack)?;
        let t_root = t_path.len() - 1;

        if f_path[f_root] != t_path[t_root] {
            if can_mismatch {
                return Ok(None);
            } else {
                return Err(Error::EnodeRootMismatch(from, to));
            }
        }
        // Walk back from the common end while both chains agree.
        let mut shared = 1;
        while shared < f_path.len()
            && shared < t_path.len()
            && f_path[f_root - shared] == t_path[t_root - shared]
        {
            shared += 1;
        }
        let path = SimplePath {
            from_to_root: f_path,
            to_to_root: t_path,
            shared,
        };
        Ok(Some(path))
    }

    /// Builds or reuses a transitive equality explaining `from = to`, where
    /// `from != to`.
    ///
    /// Case 1: the chains of the two e-nodes end at different e-nodes and
    /// `mismatch.can_mismatch_initial` is set. An error explanation is then
    /// cached on `from` under `to`, replacing a cached entry that has a
    /// `given_len`.
    ///
    /// Case 2: otherwise, the simple path between the two e-nodes is combined
    /// with transitive equalities already cached on e-nodes along it. The
    /// combination with the fewest segments is chosen, stored through
    /// `insert_trans_equality`, and cached in `self.enodes[from].transitive`.
    /// The early returns that reuse an existing equality skip this caching.
    fn construct_trans_equality(
        &mut self,
        from: ENodeIdx,
        to: ENodeIdx,
        stack: &Stack,
        mismatch: TransEqAllowed,
    ) -> Result<EqTransIdx> {
        debug_assert_ne!(from, to);
        let can_mismatch = mismatch.can_mismatch_initial;
        let Some(simple_path) = self.get_simple_path(from, to, stack, can_mismatch)? else {
            // Chains end at different e-nodes and that is allowed: use an error explanation.
            let error = TransitiveExpl::error(from, to);
            self.enodes[from].transitive.try_reserve(1)?;
            let trans = match self.enodes[from].transitive.entry(to) {
                Entry::Occupied(mut o) => {
                    let trans = *o.get();
                    // Reuse the cached entry only if it has no `given_len` (as error
                    // explanations presumably do); otherwise replace it with a fresh error.
                    if self.equalities.transitive[trans].given_len.is_some() {
                        self.equalities.transitive.raw.try_reserve(1)?;
                        let trans = self.equalities.transitive.push_and_get_key(error);
                        o.insert(trans);
                        trans
                    } else {
                        trans
                    }
                }
                Entry::Vacant(v) => {
                    self.equalities.transitive.raw.try_reserve(1)?;
                    let trans = self.equalities.transitive.push_and_get_key(error);
                    *v.insert(trans)
                }
            };
            return Ok(trans);
        };
        let edges_len = simple_path.edges_len();
        // Node i of the graph is the i-th e-node on the path; node 0 is `from`,
        // node `edges_len` is `to`.
        let mut graph = simple_path.initialise_graph(self, stack);

        // A usable equality cached on `from` that covers the whole path is the answer.
        if let Some((forward, solution)) = graph.add_trans_from(0, self) {
            debug_assert!(forward);
            return Ok(solution);
        }
        // Otherwise look for a complete one cached on `to`; it runs backwards (to -> from).
        let trans = if let Some((forward, solution)) = graph.add_trans_from(edges_len.get(), self) {
            debug_assert!(!forward);
            let inner = &self.equalities.transitive[solution];
            if inner.path.len() == 1 {
                use TransitiveExplSegmentKind::*;
                match inner.path[0] {
                    // `solution` is just `idx` reversed, so `idx` already runs from -> to.
                    // NOTE(review): `idx` is returned without being cached on `from`; confirm intended.
                    TransitiveExplSegment {
                        forward: false,
                        kind: Transitive(idx),
                    } => return Ok(idx),
                    TransitiveExplSegment {
                        forward: true,
                        kind: Transitive(_),
                    } => unreachable!(),
                    TransitiveExplSegment {
                        kind: Given(..), ..
                    } => (),
                    TransitiveExplSegment {
                        kind: Error(..), ..
                    } => unreachable!(),
                }
            }
            // Wrap the backwards solution as a single reversed segment.
            let solution = TransitiveExplSegment {
                forward,
                kind: TransitiveExplSegmentKind::Transitive(solution),
            };
            // NOTE(review): the construction in the `else` branch passes `edges_len`
            // as this argument rather than 1; confirm what it means for `TransitiveExpl::new`.
            TransitiveExpl::new([solution].into_iter(), NonZeroUsize::new(1).unwrap(), to)?
        } else {
            // Add edges for the equalities cached on every intermediate e-node.
            for idx in 1..edges_len.get() {
                graph.add_trans_from(idx, self);
            }
            // Every edge goes from a lower to a higher node index (simple edges i -> i+1,
            // transitive ones per the debug_assert in `add_trans_from`), so a single backwards
            // sweep computes, per node, the fewest segments to reach the end (whose cost stays 0)
            // and the outgoing edge that achieves it.
            for idx in (0..edges_len.get()).rev() {
                let idx = NodeIndex::new(idx);
                let min = graph
                    .graph
                    .edges(idx)
                    .min_by_key(|edge| graph.graph[edge.target()].0)
                    .unwrap();
                let (cost, id) = (graph.graph[min.target()].0, min.id());
                let idx = &mut graph.graph[idx];
                idx.0 = cost + 1;
                idx.1 = Some(id);
            }

            // Follow the chosen edges from `from` and emit their segments in order.
            let start = NodeIndex::new(0);
            let mut edge = graph.graph[start].1;
            let path_length = graph.graph[start].0;
            TransitiveExpl::new(
                (0..path_length)
                    .map(|_| {
                        let kind = &graph.graph[edge.unwrap()];
                        edge = graph.graph[graph.graph.edge_endpoints(edge.unwrap()).unwrap().1].1;
                        kind
                    })
                    .copied(),
                edges_len,
                to,
            )?
        };
        // Resolve congruence arguments, store the equality and cache it on `from`.
        let trans = self.insert_trans_equality(trans, stack, mismatch.can_mismatch_congr)?;
        debug_assert_eq!(self.equalities.walk_to(from, trans), to);
        self.enodes[from].transitive.try_reserve(1)?;
        let old = self.enodes[from].transitive.insert(to, trans);
        debug_assert_eq!(old, None);
        Ok(trans)
    }

    /// Stores `trans` after resolving the arguments of each congruence among
    /// its given segments.
    ///
    /// For each such segment, a transitive equality is built for every
    /// argument pair whose two sides differ. Mismatched chains are tolerated
    /// for these argument equalities if `can_mismatch_congr` is set, but never
    /// for congruences nested inside them. The resulting list is deduplicated
    /// into the congruence's `uses`, and its index is written into the
    /// segment.
    ///
    /// # Errors
    /// Propagates allocation failures and errors from `new_trans_equality`.
    fn insert_trans_equality(
        &mut self,
        mut trans: TransitiveExpl,
        stack: &Stack,
        can_mismatch_congr: bool,
    ) -> Result<EqTransIdx> {
        let mismatch = TransEqAllowed {
            can_mismatch_initial: can_mismatch_congr,
            can_mismatch_congr: false,
        };
        for seg in trans.path.iter_mut() {
            if let TransitiveExplSegmentKind::Given((cg, idx)) = &mut seg.kind {
                // Segments arrive without a resolved use index.
                debug_assert_eq!(*idx, None);
                let EqualityExpl::Congruence { arg_eqs, .. } = &self.equalities.given[*cg] else {
                    continue;
                };
                // Copy the argument pairs out so `self` can be mutably borrowed below.
                let mut args = Vec::new();
                args.try_reserve_exact(arg_eqs.len())?;
                args.extend(arg_eqs.iter().copied());

                let mut use_ = Vec::new();
                use_.try_reserve_exact(arg_eqs.len())?;
                for (from, to) in args {
                    // `Err(_)` here means both sides are the same e-node: nothing to record.
                    let Ok(trans) = self.new_trans_equality(from, to, stack, mismatch)? else {
                        continue;
                    };
                    use_.push(trans);
                }
                let EqualityExpl::Congruence { uses, .. } = &mut self.equalities.given[*cg] else {
                    unreachable!()
                };
                // Reuse an identical argument list if one was recorded before.
                let real_idx = uses.iter().position(|u| **u == use_).unwrap_or_else(|| {
                    uses.push(BoxSlice::from(use_));
                    uses.len() - 1
                });
                *idx = Some(NonMaxU32::new(real_idx as u32).unwrap());
            }
        }

        self.equalities.transitive.raw.try_reserve(1)?;
        let trans = self.equalities.transitive.push_and_get_key(trans);
        Ok(trans)
    }
}
467
468impl std::ops::Index<ENodeIdx> for EGraph {
469 type Output = ENode;
470 fn index(&self, idx: ENodeIdx) -> &Self::Output {
471 &self.enodes[idx]
472 }
473}
474
475impl ENode {
476 pub fn get_equality(&self, _stack: &Stack) -> Option<&Equality> {
477 self.equalities.last()
480 }
481}
482
/// Storage for equality explanations.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug, Default)]
pub struct Equalities {
    /// Given equalities, appended by `EGraph::new_given_equality`.
    pub(crate) given: TiVec<EqGivenIdx, EqualityExpl>,
    /// Transitive equalities built by chaining given and other transitive ones.
    pub(crate) transitive: TiVec<EqTransIdx, TransitiveExpl>,
}
489
490pub trait EqualityWalker<'a> {
491 type Error;
492
493 fn equalities(&self) -> &'a Equalities;
494
495 fn walk_given(
507 &mut self,
508 eq_use: EqGivenUse,
509 forward: bool,
510 ) -> core::result::Result<(), Self::Error>;
511
512 fn walk_congruence(
515 &mut self,
516 uses: &[BoxSlice<EqTransIdx>],
517 use_: NonMaxU32,
518 forward: bool,
519 ) -> core::result::Result<(), Self::Error> {
520 let use_ = &uses[use_.get() as usize];
521 for &eq in use_.iter() {
522 self.walk_trans(eq, forward)?;
523 }
524 Ok(())
525 }
526
527 fn walk_trans(
530 &mut self,
531 eq: EqTransIdx,
532 forward: bool,
533 ) -> core::result::Result<(), Self::Error> {
534 self.super_walk_trans(eq, forward)
535 }
536
537 fn super_walk_trans(
540 &mut self,
541 eq: EqTransIdx,
542 forward: bool,
543 ) -> core::result::Result<(), Self::Error> {
544 let all = self.equalities().transitive[eq].all(forward);
545 for next in all {
546 use TransitiveExplSegmentKind::*;
547 match next.kind {
548 Given(eq_use) => self.walk_given(eq_use, next.forward)?,
549 Transitive(eq) => self.walk_trans(eq, next.forward)?,
550 Error(..) => (),
551 }
552 }
553 Ok(())
554 }
555}
556
557struct TransEqChecker<'a, I: Iterator<Item = (EqGivenUse, bool)>> {
558 equalities: &'a Equalities,
559 simple: I,
560}
561
562impl<'a, I: Iterator<Item = (EqGivenUse, bool)>> EqualityWalker<'a> for TransEqChecker<'a, I> {
563 type Error = bool;
564 fn equalities(&self) -> &'a Equalities {
565 self.equalities
566 }
567 fn walk_given(
568 &mut self,
569 (eq, _use_): EqGivenUse,
570 eq_fwd: bool,
571 ) -> core::result::Result<(), Self::Error> {
572 let ((simple, _simple_use), fwd) = self.simple.next().ok_or(false)?;
575 (simple == eq && fwd == eq_fwd).then_some(()).ok_or(true)
579 }
580}
581
582impl Equalities {
583 pub fn is_equal(
584 &self,
585 eq: EqTransIdx,
586 simple: &mut impl Iterator<Item = (EqGivenUse, bool)>,
587 ) -> Option<bool> {
588 let mut checker = TransEqChecker {
589 equalities: self,
590 simple,
591 };
592 let res = checker.walk_trans(eq, true);
593 res.map_or_else(|e| e.then_some(false), |_| Some(true))
595 }
596}
597
/// A [`TransEqStopWalker`] whose callback can never fail (its error type is
/// the uninhabited [`Never`]).
pub type TransEqSimpleWalker<'a, F> = TransEqStopWalker<'a, Never, F>;
599
600pub struct TransEqStopWalker<'a, E, F: FnMut(&'a EqualityExpl, bool) -> core::result::Result<(), E>>
601{
602 equalities: &'a Equalities,
603 simple: F,
604}
605
/// Uninhabited error type for walks that cannot stop early.
#[derive(Debug)]
pub enum Never {}
608
609impl<'a, E, F: FnMut(&'a EqualityExpl, bool) -> core::result::Result<(), E>> EqualityWalker<'a>
610 for TransEqStopWalker<'a, E, F>
611{
612 type Error = E;
613 fn equalities(&self) -> &'a Equalities {
614 self.equalities
615 }
616 fn walk_given(
617 &mut self,
618 (eq, _): EqGivenUse,
619 forward: bool,
620 ) -> core::result::Result<(), Self::Error> {
621 (self.simple)(&self.equalities.given[eq], forward)
622 }
623}
624
625impl Equalities {
626 pub fn walk_to(&self, mut from: ENodeIdx, eq: EqTransIdx) -> ENodeIdx {
627 let mut walker = TransEqSimpleWalker {
628 equalities: self,
629 simple: |eq, fwd| {
630 from = eq.walk(from, fwd).unwrap();
631 Ok(())
632 },
633 };
634 walker.walk_trans(eq, true).unwrap();
635 from
636 }
637 pub fn path(&self, eq: EqTransIdx) -> Vec<ENodeIdx> {
638 let equality = &self.transitive[eq];
639 if let Some(from) = equality.error_from() {
640 return vec![from, equality.to];
641 }
642
643 let mut path = Vec::new();
644 let mut walker = TransEqSimpleWalker {
645 equalities: self,
646 simple: |eq, fwd| {
647 let from = if fwd { eq.from() } else { eq.to() };
648 path.push(from);
649 Ok(())
650 },
651 };
652 walker.walk_trans(eq, true).unwrap();
653 path.push(self.transitive[eq].to);
654 path
655 }
656}
657
/// Path between two e-nodes through the end of their equality chains.
#[derive(Debug)]
pub struct SimplePath {
    /// Chain from the source e-node (inclusive) to its end.
    from_to_root: Vec<ENodeIdx>,
    /// Chain from the target e-node (inclusive) until it meets the source's end.
    to_to_root: Vec<ENodeIdx>,
    /// Number of trailing e-nodes common to both chains (at least 1).
    shared: usize,
}
664impl SimplePath {
665 pub fn edges_len(&self) -> NonZeroUsize {
666 NonZeroUsize::new(self.from_to_root.len() + self.to_to_root.len() - 2 * self.shared)
667 .unwrap()
668 }
669 pub fn all_simple_edges<'a>(
670 &'a self,
671 egraph: &'a EGraph,
672 stack: &'a Stack,
673 ) -> impl DoubleEndedIterator<Item = (ENodeIdx, EqGivenIdx, bool)> + 'a {
674 let from_to_join = self.from_to_root[..self.from_to_root.len() - self.shared]
675 .iter()
676 .copied();
677 let from_to_join = from_to_join.map(|e| {
678 let eq = &egraph.enodes[e].get_equality(stack).unwrap();
679 (eq.to, eq.expl, true)
680 });
681 let join_to_to = self.to_to_root[..self.to_to_root.len() - self.shared]
682 .iter()
683 .rev()
684 .copied();
685 let join_to_to =
686 join_to_to.map(|e| (e, egraph.enodes[e].get_equality(stack).unwrap().expl, false));
687 from_to_join.chain(join_to_to)
688 }
689
690 pub fn nodes_len(&self) -> usize {
691 self.edges_len().get() + 1
692 }
693 pub fn all_nodes(&self) -> impl DoubleEndedIterator<Item = ENodeIdx> + '_ {
694 let from_to_join = self.from_to_root[..self.from_to_root.len() + 1 - self.shared].iter();
695 let join_to_to = self.to_to_root[..self.to_to_root.len() - self.shared]
696 .iter()
697 .rev();
698 from_to_join.chain(join_to_to).copied()
699 }
700 pub fn node_at(&self, idx: usize) -> ENodeIdx {
701 let from_len = self.from_to_root.len() - self.shared;
702 if idx <= from_len {
703 self.from_to_root[idx]
704 } else {
705 let to_len = self.to_to_root.len() - self.shared;
706 self.to_to_root[(to_len + from_len) - idx]
707 }
708 }
709 pub fn all_transitive<'a>(
710 &'a self,
711 egraph: &'a EGraph,
712 ) -> impl DoubleEndedIterator<Item = impl Iterator<Item = EqTransIdx> + 'a> + 'a {
713 self.all_nodes()
714 .map(move |idx| egraph.enodes[idx].transitive.values().copied())
715 }
716
717 pub fn initialise_graph<'a>(self, egraph: &'a EGraph, stack: &'a Stack) -> Graph {
718 let edges_len = self.edges_len();
719 let mut g = Graph {
720 graph: DiGraph::with_capacity(self.nodes_len(), edges_len.get()),
721 path_enodes: self.all_nodes().collect(),
722 edges_len,
723 simple_path: self,
724 };
725 let mut last = g.graph.add_node((edges_len.get() as u32, None));
726 for (idx, (_, eq, forward)) in g.simple_path.all_simple_edges(egraph, stack).enumerate() {
727 let cost = (edges_len.get() - idx - 1) as u32;
728 let next = g.graph.add_node((cost, None));
729 let kind = TransitiveExplSegmentKind::Given((eq, None));
730 g.graph
731 .add_edge(last, next, TransitiveExplSegment { forward, kind });
732 last = next;
733 }
734 g
735 }
736}
737
/// Search graph over a [`SimplePath`], used to pick the shortest
/// combination of given and cached transitive equalities.
pub struct Graph {
    /// The path this graph was built from.
    simple_path: SimplePath,
    /// E-nodes on the path (used in debug checks).
    path_enodes: FxHashSet<ENodeIdx>,
    /// Node `i` is the `i`-th e-node on the path, weighted by (cost to reach the
    /// end, chosen outgoing edge). The first `edges_len` edges are the path's
    /// given equalities in order; later edges are cached transitive equalities.
    graph: DiGraph<(u32, Option<EdgeIndex>), TransitiveExplSegment>,
    /// Number of given equalities on the path.
    edges_len: NonZeroUsize,
}
744impl Graph {
745 pub fn add_trans_from(
746 &mut self,
747 idx: usize,
748 egraph: &mut EGraph,
749 ) -> Option<(bool, EqTransIdx)> {
750 let nfrom = NodeIndex::new(idx);
751 let efrom = self.simple_path.node_at(idx);
752
753 for to in 0..self.simple_path.nodes_len() {
754 let to = self.simple_path.node_at(to);
755 if let Entry::Occupied(o) = egraph.enodes[efrom].transitive.entry(to) {
756 let trans = *o.get();
757 let trans_node = &egraph.equalities.transitive[trans];
758 let Some((nto, forward)) =
759 self.add_trans_single(idx, trans, trans_node, &egraph.equalities)
760 else {
761 o.remove();
762 continue;
763 };
764 let segment = TransitiveExplSegment {
765 forward,
766 kind: TransitiveExplSegmentKind::Transitive(trans),
767 };
768 let (from, to) = if forward { (nfrom, nto) } else { (nto, nfrom) };
769 debug_assert!(trans_node.given_len.is_none() || from.index() < to.index());
770 self.graph.add_edge(from, to, segment);
771 if trans_node
772 .given_len
773 .is_some_and(|len| len == self.edges_len)
774 {
775 return Some((forward, trans));
776 }
777 }
778 }
779 None
780 }
781 fn add_trans_single(
782 &self,
783 idx: usize,
784 trans: EqTransIdx,
785 trans_node: &TransitiveExpl,
786 equalities: &Equalities,
787 ) -> Option<(NodeIndex, bool)> {
788 let given_len = trans_node.given_len?.get();
789 let edges_len = self.edges_len.get();
790 debug_assert!(self.path_enodes.contains(&trans_node.to));
791
792 let left = given_len <= idx && trans_node.to == self.simple_path.node_at(idx - given_len);
793 if left {
794 let prior_simple_edges = (0..idx).map(|idx| self.graph[EdgeIndex::new(idx)]);
795 let mut prior_simple_edges = TransitiveExplSegment::rev(prior_simple_edges)
796 .map(|seg| (seg.kind.given().unwrap(), seg.forward));
797 match equalities.is_equal(trans, &mut prior_simple_edges) {
798 None => {
799 debug_assert!(false);
800 None
801 }
802 Some(false) => None,
803 Some(true) => {
804 let to = NodeIndex::new(idx - given_len);
805 Some((to, false))
806 }
807 }
808 } else if given_len <= edges_len - idx
809 && trans_node.to == self.simple_path.node_at(idx + given_len)
810 {
811 let post_simple_edges = (idx..edges_len).map(|idx| self.graph[EdgeIndex::new(idx)]);
812 let mut post_simple_edges =
813 post_simple_edges.map(|seg| (seg.kind.given().unwrap(), seg.forward));
814 match equalities.is_equal(trans, &mut post_simple_edges) {
815 None => {
816 debug_assert!(false);
817 None
818 }
819 Some(false) => None,
820 Some(true) => {
821 let to = NodeIndex::new(idx + given_len);
822 Some((to, true))
823 }
824 }
825 } else {
826 None
827 }
828 }
829}
830
/// E-nodes registered for one term, oldest first and newest last.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug)]
enum TermToEnode {
    /// Exactly one e-node.
    Single(ENodeIdx),
    /// A list of e-nodes (empty for a freshly defaulted entry). Live entries
    /// are expected to form a prefix, which `EGraph` checks in debug builds.
    Multiple(Vec<ENodeIdx>),
}
840
841impl Default for TermToEnode {
842 fn default() -> Self {
843 Self::Multiple(Vec::default())
844 }
845}
846
847impl EGraph {
848 pub fn insert_tte(&mut self, term: TermIdx, enode: ENodeIdx, stack: &Stack) -> Result<()> {
849 let remove = |e: &ENodeIdx| !stack.is_alive(self.enodes[*e].frame);
850 self.term_to_enode.try_reserve(1)?;
851 let tte = self.term_to_enode.entry(term).or_default();
852 let mut vec = match tte {
853 TermToEnode::Single(e) => [*e].into_iter().filter(|e| !remove(e)).collect(),
854 TermToEnode::Multiple(es) => {
855 let mut es = core::mem::take(es);
856 let idx = es.iter().position(remove).unwrap_or(es.len());
857 debug_assert!(es[idx..].iter().all(remove));
858 es.drain(idx..);
859 es
860 }
861 };
862 if vec.is_empty() {
863 *tte = TermToEnode::Single(enode);
864 } else {
865 vec.push(enode);
866 *tte = TermToEnode::Multiple(vec);
867 }
868 Ok(())
871 }
872
873 fn get_tte(&mut self, term: TermIdx, stack: &Stack) -> Option<ENodeIdx> {
874 let Entry::Occupied(mut o) = self.term_to_enode.entry(term) else {
875 return None;
876 };
877 let remove = |e: &ENodeIdx| !stack.is_alive(self.enodes[*e].frame);
878 match o.get_mut() {
879 TermToEnode::Single(e) => {
880 if remove(&*e) {
881 o.remove();
882 None
883 } else {
884 Some(*e)
885 }
886 }
887 TermToEnode::Multiple(es) => {
888 let idx = es.iter().position(remove).unwrap_or(es.len());
889 debug_assert!(es[idx..].iter().all(remove));
890 es.drain(idx..);
891 if let Some(last) = es.last().copied() {
892 if es.len() == 1 {
893 *o.get_mut() = TermToEnode::Single(last);
894 }
895 Some(last)
896 } else {
897 o.remove();
898 None
899 }
900 }
901 }
902 }
903
904 fn get_tte_imm(&self, term: TermIdx, stack: &Stack) -> Option<ENodeIdx> {
905 let enode = self.term_to_enode.get(&term)?;
906 let keep = |e: &ENodeIdx| stack.is_alive(self.enodes[*e].frame);
907 match enode {
908 TermToEnode::Single(e) if keep(e) => Some(*e),
909 TermToEnode::Single(_) => None,
910 TermToEnode::Multiple(es) => es.iter().copied().find(keep),
911 }
912 }
913}