1use std::{collections::hash_map::Entry, num::NonZeroUsize};
2
3#[cfg(feature = "mem_dbg")]
4use mem_dbg::{MemDbg, MemSize};
5use petgraph::{
6 graph::{DiGraph, EdgeIndex, NodeIndex},
7 visit::EdgeRef,
8};
9
10use crate::{
11 items::{
12 ENode, ENodeBlame, ENodeIdx, EqGivenIdx, EqGivenUse, EqTransIdx, Equality, EqualityExpl,
13 InstIdx, ProofIdx, TermIdx, TransitiveExpl, TransitiveExplSegment,
14 TransitiveExplSegmentKind,
15 },
16 BoxSlice, Error, FxHashMap, FxHashSet, NonMaxU32, Result, TiVec,
17};
18
19use super::{bugs::TransEqAllowed, stack::Stack, terms::Terms};
20
/// E-graph bookkeeping: every e-node created, which e-node currently
/// represents each term, the asserted proof recorded for each term, and all
/// equality explanations between e-nodes.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug, Default)]
pub struct EGraph {
    /// Per-term list of e-nodes created for that term (maintained by
    /// `insert_tte` / `get_tte`).
    term_to_enode: FxHashMap<TermIdx, TermToEnode>,
    /// All e-nodes created so far, indexed by `ENodeIdx` (only ever appended
    /// to in this module).
    pub(crate) enodes: TiVec<ENodeIdx, ENode>,
    /// Asserted proof recorded for a term and its sub-terms (see `new_proof`).
    proofs: FxHashMap<TermIdx, ProofIdx>,
    /// Given and transitive equality explanations.
    pub equalities: Equalities,
}
31
32impl Equalities {
33 pub fn from_to(&self, eq: EqTransIdx) -> (ENodeIdx, ENodeIdx) {
34 let from = self.walk_trans(eq, |eq, fwd| {
35 let from = if fwd { eq.from() } else { eq.to() };
36 Err(from)
37 });
38 let equality = &self.transitive[eq];
39 let from = from.err().unwrap_or_else(|| equality.error_from().unwrap());
40 let to = equality.to;
41 (from, to)
42 }
43
44 pub fn walk_trans<E>(
47 &self,
48 eq: EqTransIdx,
49 f: impl FnMut(&EqualityExpl, bool) -> core::result::Result<(), E>,
50 ) -> core::result::Result<(), E> {
51 let mut walker = TransEqStopWalker {
52 equalities: self,
53 simple: f,
54 };
55 walker.walk_trans(eq, true)
56 }
57}
58
59impl EGraph {
60 pub fn get_blame(
61 &self,
62 tidx: TermIdx,
63 inst: Option<(InstIdx, FxHashSet<EqTransIdx>)>,
64 terms: &Terms,
65 stack: &Stack,
66 ) -> ENodeBlame {
67 if terms.is_bool_const(tidx) {
68 return ENodeBlame::BoolConst;
69 }
70 if let Some(inst) = inst {
71 return ENodeBlame::Inst(inst);
72 }
73 let proof = self.proofs.get(&tidx).copied();
74 let proof = proof.filter(|&p| stack.is_alive(terms[p].frame));
75 if let Some(proof) = proof {
76 return ENodeBlame::Proof(proof);
77 }
78 ENodeBlame::Unknown
79 }
80
81 pub fn new_enode(
82 &mut self,
83 blame: ENodeBlame,
84 term: TermIdx,
85 z3_generation: NonMaxU32,
86 stack: &Stack,
87 ) -> Result<ENodeIdx> {
88 self.enodes.raw.try_reserve(1)?;
97 let enode = self.enodes.push_and_get_key(ENode {
98 frame: stack.active_frame(),
99 blame,
100 owner: term,
101 z3_generation,
102 equalities: Vec::new(),
103 transitive: FxHashMap::default(),
104 });
105 self.insert_tte(term, enode, stack)?;
106 Ok(enode)
107 }
108
109 pub fn get_enode(&mut self, term: TermIdx, stack: &Stack) -> Result<ENodeIdx> {
110 self.get_tte(term, stack).ok_or(Error::UnknownEnode(term))
111 }
112
113 pub fn get_enode_imm(&self, term: TermIdx, stack: &Stack) -> Option<ENodeIdx> {
114 self.get_tte_imm(term, stack)
115 }
116
117 pub(super) fn new_proof(
118 &mut self,
119 term: TermIdx,
120 proof: ProofIdx,
121 terms: &Terms,
122 stack: &Stack,
123 ) -> Result<()> {
124 if !terms[proof].kind.is_asserted() {
125 return Ok(());
126 }
127 terms.app_walk(term, |tidx, app| {
128 self.proofs.try_reserve(1)?;
129 match self.proofs.entry(tidx) {
130 Entry::Occupied(mut o) => {
131 if stack.is_alive(terms[*o.get()].frame) {
132 return Ok(&[]);
133 } else {
134 o.insert(proof);
135 }
136 }
137 Entry::Vacant(v) => {
138 v.insert(proof);
139 }
140 }
141 Ok(&app.child_ids)
142 })
143 }
144
145 pub fn get_given(&self, from: ENodeIdx, to: ENodeIdx) -> Option<EqGivenIdx> {
146 self.enodes[from]
147 .equalities
148 .iter()
149 .find(|eq| eq.to == to)
150 .map(|eq| eq.expl)
151 }
152
153 pub fn new_given_equality(
154 &mut self,
155 from: ENodeIdx,
156 expl: EqualityExpl,
157 stack: &Stack,
158 ) -> Result<()> {
159 debug_assert_eq!(from, expl.from());
160 let to = expl.to();
161 self.equalities.given.raw.try_reserve(1)?;
162 let expl = self.equalities.given.push_and_get_key(expl);
163 let enode = &mut self.enodes[from];
164 let eq = Equality {
165 _frame: stack.active_frame(),
166 to,
167 expl,
168 };
169 enode.equalities.try_reserve(1)?;
170 enode.equalities.push(eq);
171 Ok(())
192 }
193
194 pub fn new_trans_equality(
195 &mut self,
196 from: ENodeIdx,
197 to: ENodeIdx,
198 stack: &Stack,
199 mismatch: TransEqAllowed,
200 ) -> Result<core::result::Result<EqTransIdx, ENodeIdx>> {
201 if from == to {
202 Ok(Err(from))
203 } else {
204 self.construct_trans_equality(from, to, stack, mismatch)
205 .map(Ok)
206 }
207 }
208
209 fn to_root<'a>(
210 &'a self,
211 from: ENodeIdx,
212 stack: &'a Stack,
213 ) -> impl Iterator<Item = ENodeIdx> + 'a {
214 core::iter::successors(Some(from), move |&from| {
215 self.enodes[from].get_equality(stack).map(|eq| eq.to)
216 })
217 }
218
219 fn to_root_visited<'a>(
220 &'a self,
221 from: ENodeIdx,
222 stack: &'a Stack,
223 cycle: &'a mut Option<ENodeIdx>,
224 visited: &'a mut FxHashSet<ENodeIdx>,
225 ) -> impl Iterator<Item = ENodeIdx> + 'a {
226 self.to_root(from, stack).take_while(move |&from| {
227 if visited.insert(from) {
228 true
229 } else {
230 *cycle = Some(from);
233 false
234 }
235 })
236 }
237
238 pub fn check_eq(&self, from: ENodeIdx, to: ENodeIdx, stack: &Stack) -> bool {
239 let mut visited_from = FxHashSet::default();
240 self.to_root_visited(from, stack, &mut None, &mut visited_from)
241 .for_each(drop);
242 for to in self.to_root_visited(to, stack, &mut None, &mut FxHashSet::default()) {
243 if visited_from.contains(&to) {
244 return true;
245 }
246 }
247 false
248 }
249
250 pub fn path_to_root(
251 &self,
252 from: ENodeIdx,
253 root: Option<ENodeIdx>,
254 stack: &Stack,
255 ) -> Result<(Option<ENodeIdx>, Vec<ENodeIdx>)> {
256 let mut cycle = None;
257 let mut visited = FxHashSet::default();
258 let mut path = Vec::new();
259 for e in self.to_root_visited(from, stack, &mut cycle, &mut visited) {
260 path.try_reserve(1)?;
261 path.push(e);
262 if root.is_some_and(|root| root == e) {
263 return Ok((Some(e), path));
264 }
265 }
266 Ok((cycle, path))
267 }
268
269 fn get_simple_path(
270 &self,
271 from: ENodeIdx,
272 to: ENodeIdx,
273 stack: &Stack,
274 can_mismatch: bool,
275 ) -> Result<Option<SimplePath>> {
276 let (_, f_path) = self.path_to_root(from, None, stack)?;
277 let f_root = f_path.len() - 1;
278 let (_, t_path) = self.path_to_root(to, Some(*f_path.last().unwrap()), stack)?;
279 let t_root = t_path.len() - 1;
280
281 if f_path[f_root] != t_path[t_root] {
282 if can_mismatch {
285 return Ok(None);
287 } else {
288 return Err(Error::EnodeRootMismatch(from, to));
289 }
290 }
291 let mut shared = 1;
292 while shared < f_path.len()
293 && shared < t_path.len()
294 && f_path[f_root - shared] == t_path[t_root - shared]
295 {
296 shared += 1;
297 }
298 let path = SimplePath {
299 from_to_root: f_path,
300 to_to_root: t_path,
301 shared,
302 };
303 Ok(Some(path))
304 }
305
306 fn construct_trans_equality(
307 &mut self,
308 from: ENodeIdx,
309 to: ENodeIdx,
310 stack: &Stack,
311 mismatch: TransEqAllowed,
312 ) -> Result<EqTransIdx> {
313 debug_assert_ne!(from, to);
314 let can_mismatch = mismatch.can_mismatch_initial;
315 let Some(simple_path) = self.get_simple_path(from, to, stack, can_mismatch)? else {
316 let error = TransitiveExpl::error(from, to);
319 self.enodes[from].transitive.try_reserve(1)?;
320 let trans = match self.enodes[from].transitive.entry(to) {
321 Entry::Occupied(mut o) => {
322 let trans = *o.get();
323 if self.equalities.transitive[trans].given_len.is_some() {
324 self.equalities.transitive.raw.try_reserve(1)?;
327 let trans = self.equalities.transitive.push_and_get_key(error);
328 o.insert(trans);
329 trans
330 } else {
331 trans
332 }
333 }
334 Entry::Vacant(v) => {
335 self.equalities.transitive.raw.try_reserve(1)?;
336 let trans = self.equalities.transitive.push_and_get_key(error);
337 *v.insert(trans)
338 }
339 };
340 return Ok(trans);
341 };
342 let edges_len = simple_path.edges_len();
344 let mut graph = simple_path.initialise_graph(self, stack);
345
346 if let Some((forward, solution)) = graph.add_trans_from(0, self) {
349 debug_assert!(forward);
350 return Ok(solution);
351 }
352 let trans = if let Some((forward, solution)) = graph.add_trans_from(edges_len.get(), self) {
353 debug_assert!(!forward);
354 let inner = &self.equalities.transitive[solution];
355 if inner.path.len() == 1 {
356 use TransitiveExplSegmentKind::*;
357 match inner.path[0] {
358 TransitiveExplSegment {
359 forward: false,
360 kind: Transitive(idx),
361 } => return Ok(idx),
362 TransitiveExplSegment {
363 forward: true,
364 kind: Transitive(_),
365 } => unreachable!(),
366 TransitiveExplSegment {
367 kind: Given(..), ..
368 } => (),
369 TransitiveExplSegment {
370 kind: Error(..), ..
371 } => unreachable!(),
372 }
373 }
374 let solution = TransitiveExplSegment {
375 forward,
376 kind: TransitiveExplSegmentKind::Transitive(solution),
377 };
378 TransitiveExpl::new([solution].into_iter(), NonZeroUsize::new(1).unwrap(), to)?
379 } else {
380 for idx in 1..edges_len.get() {
381 graph.add_trans_from(idx, self);
382 }
383 for idx in (0..edges_len.get()).rev() {
385 let idx = NodeIndex::new(idx);
386 let min = graph
389 .graph
390 .edges(idx)
391 .min_by_key(|edge| graph.graph[edge.target()].0)
392 .unwrap();
393 let (cost, id) = (graph.graph[min.target()].0, min.id());
394 let idx = &mut graph.graph[idx];
395 idx.0 = cost + 1;
396 idx.1 = Some(id);
397 }
398
399 let start = NodeIndex::new(0);
400 let mut edge = graph.graph[start].1;
401 let path_length = graph.graph[start].0;
402 TransitiveExpl::new(
403 (0..path_length)
404 .map(|_| {
405 let kind = &graph.graph[edge.unwrap()];
406 edge = graph.graph[graph.graph.edge_endpoints(edge.unwrap()).unwrap().1].1;
407 kind
408 })
409 .copied(),
410 edges_len,
411 to,
412 )?
413 };
414 let trans = self.insert_trans_equality(trans, stack, mismatch.can_mismatch_congr)?;
415 debug_assert_eq!(self.equalities.walk_to(from, trans), to);
416 self.enodes[from].transitive.try_reserve(1)?;
417 let old = self.enodes[from].transitive.insert(to, trans);
418 debug_assert_eq!(old, None);
419 Ok(trans)
420 }
421
422 fn insert_trans_equality(
423 &mut self,
424 mut trans: TransitiveExpl,
425 stack: &Stack,
426 can_mismatch_congr: bool,
427 ) -> Result<EqTransIdx> {
428 let mismatch = TransEqAllowed {
429 can_mismatch_initial: can_mismatch_congr,
430 can_mismatch_congr: false,
431 };
432 for seg in trans.path.iter_mut() {
434 if let TransitiveExplSegmentKind::Given((cg, idx)) = &mut seg.kind {
435 debug_assert_eq!(*idx, None);
436 let EqualityExpl::Congruence { arg_eqs, .. } = &self.equalities.given[*cg] else {
437 continue;
438 };
439 let mut args = Vec::new();
440 args.try_reserve_exact(arg_eqs.len())?;
441 args.extend(arg_eqs.iter().copied());
442
443 let mut use_ = Vec::new();
444 use_.try_reserve_exact(arg_eqs.len())?;
445 for (from, to) in args {
446 let Ok(trans) = self.new_trans_equality(from, to, stack, mismatch)? else {
447 continue;
448 };
449 use_.push(trans);
450 }
451 let EqualityExpl::Congruence { uses, .. } = &mut self.equalities.given[*cg] else {
452 unreachable!()
453 };
454 let real_idx = uses.iter().position(|u| **u == use_).unwrap_or_else(|| {
455 uses.push(BoxSlice::from(use_));
456 uses.len() - 1
457 });
458 *idx = Some(NonMaxU32::new(real_idx as u32).unwrap());
459 }
460 }
461
462 self.equalities.transitive.raw.try_reserve(1)?;
463 let trans = self.equalities.transitive.push_and_get_key(trans);
464 Ok(trans)
465 }
466}
467
468impl std::ops::Index<ENodeIdx> for EGraph {
469 type Output = ENode;
470 fn index(&self, idx: ENodeIdx) -> &Self::Output {
471 &self.enodes[idx]
472 }
473}
474
475impl ENode {
476 pub fn get_equality(&self, _stack: &Stack) -> Option<&Equality> {
477 self.equalities.last()
480 }
481}
482
/// Storage for all equality explanations.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug, Default)]
pub struct Equalities {
    /// Given (single-step) equality explanations, indexed by `EqGivenIdx`.
    pub(crate) given: TiVec<EqGivenIdx, EqualityExpl>,
    /// Transitive explanations built from segments that reference given or
    /// other transitive equalities, indexed by `EqTransIdx`.
    pub(crate) transitive: TiVec<EqTransIdx, TransitiveExpl>,
}
489
490pub trait EqualityWalker<'a> {
491 type Error;
492
493 fn equalities(&self) -> &'a Equalities;
494
495 fn walk_given(
507 &mut self,
508 eq_use: EqGivenUse,
509 forward: bool,
510 ) -> core::result::Result<(), Self::Error>;
511
512 fn walk_congruence(
515 &mut self,
516 uses: &[BoxSlice<EqTransIdx>],
517 use_: NonMaxU32,
518 forward: bool,
519 ) -> core::result::Result<(), Self::Error> {
520 let use_ = &uses[use_.get() as usize];
521 for &eq in use_.iter() {
522 self.walk_trans(eq, forward)?;
523 }
524 Ok(())
525 }
526
527 fn walk_trans(
530 &mut self,
531 eq: EqTransIdx,
532 forward: bool,
533 ) -> core::result::Result<(), Self::Error> {
534 self.super_walk_trans(eq, forward)
535 }
536
537 fn super_walk_trans(
540 &mut self,
541 eq: EqTransIdx,
542 forward: bool,
543 ) -> core::result::Result<(), Self::Error> {
544 let all = self.equalities().transitive[eq].all(forward);
545 for next in all {
546 use TransitiveExplSegmentKind::*;
547 match next.kind {
548 Given(eq_use) => self.walk_given(eq_use, next.forward)?,
549 Transitive(eq) => self.walk_trans(eq, next.forward)?,
550 Error(..) => (),
551 }
552 }
553 Ok(())
554 }
555}
556
/// Walker that checks, step by step, that the given equalities visited by a
/// transitive explanation match an expected sequence `simple` of
/// `(given use, direction)` pairs. Used by [`Equalities::is_equal`].
struct TransEqChecker<'a, I: Iterator<Item = (EqGivenUse, bool)>> {
    equalities: &'a Equalities,
    /// Remaining expected steps; one is consumed per visited given equality.
    simple: I,
}
561
562impl<'a, I: Iterator<Item = (EqGivenUse, bool)>> EqualityWalker<'a> for TransEqChecker<'a, I> {
563 type Error = bool;
564 fn equalities(&self) -> &'a Equalities {
565 self.equalities
566 }
567 fn walk_given(
568 &mut self,
569 (eq, _use_): EqGivenUse,
570 eq_fwd: bool,
571 ) -> core::result::Result<(), Self::Error> {
572 let ((simple, _simple_use), fwd) = self.simple.next().ok_or(false)?;
575 (simple == eq && fwd == eq_fwd).then_some(()).ok_or(true)
579 }
580}
581
582impl Equalities {
583 pub fn is_equal(
584 &self,
585 eq: EqTransIdx,
586 simple: &mut impl Iterator<Item = (EqGivenUse, bool)>,
587 ) -> Option<bool> {
588 let mut checker = TransEqChecker {
589 equalities: self,
590 simple,
591 };
592 let res = checker.walk_trans(eq, true);
593 res.map_or_else(|e| e.then_some(false), |_| Some(true))
595 }
596}
597
/// A [`TransEqStopWalker`] whose error type is `super::Never` (presumably
/// uninhabited, so the callback cannot abort the walk — confirm against the
/// definition of `Never`).
pub type TransEqSimpleWalker<'a, F> = TransEqStopWalker<'a, super::Never, F>;
599
/// Walker that calls `simple` on every given equality visited by a
/// transitive explanation, together with its traversal direction.
///
/// An `Err` returned by `simple` stops the walk and is propagated to the
/// caller.
pub struct TransEqStopWalker<'a, E, F: FnMut(&'a EqualityExpl, bool) -> core::result::Result<(), E>>
{
    equalities: &'a Equalities,
    /// Callback invoked once per visited given equality.
    simple: F,
}
605
606impl<'a, E, F: FnMut(&'a EqualityExpl, bool) -> core::result::Result<(), E>> EqualityWalker<'a>
607 for TransEqStopWalker<'a, E, F>
608{
609 type Error = E;
610 fn equalities(&self) -> &'a Equalities {
611 self.equalities
612 }
613 fn walk_given(
614 &mut self,
615 (eq, _): EqGivenUse,
616 forward: bool,
617 ) -> core::result::Result<(), Self::Error> {
618 (self.simple)(&self.equalities.given[eq], forward)
619 }
620}
621
622impl Equalities {
623 pub fn walk_to(&self, mut from: ENodeIdx, eq: EqTransIdx) -> ENodeIdx {
624 let mut walker = TransEqSimpleWalker {
625 equalities: self,
626 simple: |eq, fwd| {
627 from = eq.walk(from, fwd).unwrap();
628 Ok(())
629 },
630 };
631 walker.walk_trans(eq, true).unwrap();
632 from
633 }
634 pub fn path(&self, eq: EqTransIdx) -> Vec<ENodeIdx> {
635 let equality = &self.transitive[eq];
636 if let Some(from) = equality.error_from() {
637 return vec![from, equality.to];
638 }
639
640 let mut path = Vec::new();
641 let mut walker = TransEqSimpleWalker {
642 equalities: self,
643 simple: |eq, fwd| {
644 let from = if fwd { eq.from() } else { eq.to() };
645 path.push(from);
646 Ok(())
647 },
648 };
649 walker.walk_trans(eq, true).unwrap();
650 path.push(self.transitive[eq].to);
651 path
652 }
653}
654
/// The path between two e-nodes through their nearest common node on the
/// way to the shared root (built by `EGraph::get_simple_path`).
#[derive(Debug)]
pub struct SimplePath {
    /// The first e-node followed by the e-nodes reached from it towards the
    /// root (from `EGraph::path_to_root`).
    from_to_root: Vec<ENodeIdx>,
    /// Same for the second e-node, stopping at the first path's root.
    to_to_root: Vec<ENodeIdx>,
    /// Number of trailing nodes the two paths have in common (at least 1,
    /// the shared root).
    shared: usize,
}
661impl SimplePath {
662 pub fn edges_len(&self) -> NonZeroUsize {
663 NonZeroUsize::new(self.from_to_root.len() + self.to_to_root.len() - 2 * self.shared)
664 .unwrap()
665 }
666 pub fn all_simple_edges<'a>(
667 &'a self,
668 egraph: &'a EGraph,
669 stack: &'a Stack,
670 ) -> impl DoubleEndedIterator<Item = (ENodeIdx, EqGivenIdx, bool)> + 'a {
671 let from_to_join = self.from_to_root[..self.from_to_root.len() - self.shared]
672 .iter()
673 .copied();
674 let from_to_join = from_to_join.map(|e| {
675 let eq = &egraph.enodes[e].get_equality(stack).unwrap();
676 (eq.to, eq.expl, true)
677 });
678 let join_to_to = self.to_to_root[..self.to_to_root.len() - self.shared]
679 .iter()
680 .rev()
681 .copied();
682 let join_to_to =
683 join_to_to.map(|e| (e, egraph.enodes[e].get_equality(stack).unwrap().expl, false));
684 from_to_join.chain(join_to_to)
685 }
686
687 pub fn nodes_len(&self) -> usize {
688 self.edges_len().get() + 1
689 }
690 pub fn all_nodes(&self) -> impl DoubleEndedIterator<Item = ENodeIdx> + '_ {
691 let from_to_join = self.from_to_root[..self.from_to_root.len() + 1 - self.shared].iter();
692 let join_to_to = self.to_to_root[..self.to_to_root.len() - self.shared]
693 .iter()
694 .rev();
695 from_to_join.chain(join_to_to).copied()
696 }
697 pub fn node_at(&self, idx: usize) -> ENodeIdx {
698 let from_len = self.from_to_root.len() - self.shared;
699 if idx <= from_len {
700 self.from_to_root[idx]
701 } else {
702 let to_len = self.to_to_root.len() - self.shared;
703 self.to_to_root[(to_len + from_len) - idx]
704 }
705 }
706 pub fn all_transitive<'a>(
707 &'a self,
708 egraph: &'a EGraph,
709 ) -> impl DoubleEndedIterator<Item = impl Iterator<Item = EqTransIdx> + 'a> + 'a {
710 self.all_nodes()
711 .map(move |idx| egraph.enodes[idx].transitive.values().copied())
712 }
713
714 pub fn initialise_graph<'a>(self, egraph: &'a EGraph, stack: &'a Stack) -> Graph {
715 let edges_len = self.edges_len();
716 let mut g = Graph {
717 graph: DiGraph::with_capacity(self.nodes_len(), edges_len.get()),
718 path_enodes: self.all_nodes().collect(),
719 edges_len,
720 simple_path: self,
721 };
722 let mut last = g.graph.add_node((edges_len.get() as u32, None));
723 for (idx, (_, eq, forward)) in g.simple_path.all_simple_edges(egraph, stack).enumerate() {
724 let cost = (edges_len.get() - idx - 1) as u32;
725 let next = g.graph.add_node((cost, None));
726 let kind = TransitiveExplSegmentKind::Given((eq, None));
727 g.graph
728 .add_edge(last, next, TransitiveExplSegment { forward, kind });
729 last = next;
730 }
731 g
732 }
733}
734
/// Search graph used by `EGraph::construct_trans_equality` to pick the
/// shortest explanation along a [`SimplePath`].
///
/// Node `i` stands for the `i`-th e-node of the path. The first `edges_len`
/// edges are the given equalities between consecutive nodes; later edges are
/// cached transitive equalities spanning several of them. A node weight is
/// `(cost, best_edge)`: the number of segments needed to reach the end of
/// the path, and the outgoing edge that achieves it.
pub struct Graph {
    simple_path: SimplePath,
    /// All e-nodes on `simple_path`, for debug checks.
    path_enodes: FxHashSet<ENodeIdx>,
    graph: DiGraph<(u32, Option<EdgeIndex>), TransitiveExplSegment>,
    /// Number of given-equality edges on the path.
    edges_len: NonZeroUsize,
}
impl Graph {
    /// Adds a graph edge for every transitive equality cached on the path
    /// node at index `idx` whose target is also on the path.
    ///
    /// Cached entries that do not exactly cover a contiguous stretch of the
    /// path (including error explanations) are removed from the e-node's
    /// cache. If an entry spans the whole path, returns `Some((forward,
    /// trans))` immediately, leaving any remaining entries unprocessed.
    /// Otherwise returns `None`.
    pub fn add_trans_from(
        &mut self,
        idx: usize,
        egraph: &mut EGraph,
    ) -> Option<(bool, EqTransIdx)> {
        let nfrom = NodeIndex::new(idx);
        let efrom = self.simple_path.node_at(idx);

        for to in 0..self.simple_path.nodes_len() {
            let to = self.simple_path.node_at(to);
            if let Entry::Occupied(o) = egraph.enodes[efrom].transitive.entry(to) {
                let trans = *o.get();
                let trans_node = &egraph.equalities.transitive[trans];
                let Some((nto, forward)) =
                    self.add_trans_single(idx, trans, trans_node, &egraph.equalities)
                else {
                    // Does not fit the current path: evict it from the cache.
                    o.remove();
                    continue;
                };
                let segment = TransitiveExplSegment {
                    forward,
                    kind: TransitiveExplSegmentKind::Transitive(trans),
                };
                // Edges always point towards the end of the path; backward
                // explanations are stored as reversed segments.
                let (from, to) = if forward { (nfrom, nto) } else { (nto, nfrom) };
                debug_assert!(trans_node.given_len.is_none() || from.index() < to.index());
                self.graph.add_edge(from, to, segment);
                if trans_node
                    .given_len
                    .is_some_and(|len| len == self.edges_len)
                {
                    return Some((forward, trans));
                }
            }
        }
        None
    }
    /// Checks whether the cached transitive equality `trans` (starting at the
    /// path node at `idx`) matches the path, and if so where it ends.
    ///
    /// It matches if it ends `given_len` nodes before `idx` (compared against
    /// the preceding simple edges, walked backwards from `idx`) or
    /// `given_len` nodes after `idx` (compared against the following simple
    /// edges).
    ///
    /// Returns the end node and whether the equality runs forward along the
    /// path. Returns `None` for error explanations (no `given_len`) or on any
    /// mismatch.
    fn add_trans_single(
        &self,
        idx: usize,
        trans: EqTransIdx,
        trans_node: &TransitiveExpl,
        equalities: &Equalities,
    ) -> Option<(NodeIndex, bool)> {
        let given_len = trans_node.given_len?.get();
        let edges_len = self.edges_len.get();
        debug_assert!(self.path_enodes.contains(&trans_node.to));

        // `given_len <= idx` is checked first so `idx - given_len` cannot underflow.
        let left = given_len <= idx && trans_node.to == self.simple_path.node_at(idx - given_len);
        if left {
            // Edge indices below `edges_len` are the simple given edges
            // (inserted first by `initialise_graph`). Walk them from `idx`
            // back towards the start, presumably reversed with flipped
            // directions by `TransitiveExplSegment::rev`.
            let prior_simple_edges = (0..idx).map(|idx| self.graph[EdgeIndex::new(idx)]);
            let mut prior_simple_edges = TransitiveExplSegment::rev(prior_simple_edges)
                .map(|seg| (seg.kind.given().unwrap(), seg.forward));
            match equalities.is_equal(trans, &mut prior_simple_edges) {
                None => {
                    // Ran out of simple edges: unexpected, as `given_len`
                    // of them are available.
                    debug_assert!(false);
                    None
                }
                Some(false) => None,
                Some(true) => {
                    let to = NodeIndex::new(idx - given_len);
                    Some((to, false))
                }
            }
        } else if given_len <= edges_len - idx
            && trans_node.to == self.simple_path.node_at(idx + given_len)
        {
            let post_simple_edges = (idx..edges_len).map(|idx| self.graph[EdgeIndex::new(idx)]);
            let mut post_simple_edges =
                post_simple_edges.map(|seg| (seg.kind.given().unwrap(), seg.forward));
            match equalities.is_equal(trans, &mut post_simple_edges) {
                None => {
                    // Ran out of simple edges: unexpected, as `given_len`
                    // of them are available.
                    debug_assert!(false);
                    None
                }
                Some(false) => None,
                Some(true) => {
                    let to = NodeIndex::new(idx + given_len);
                    Some((to, true))
                }
            }
        } else {
            None
        }
    }
}
827
/// The e-nodes created for a single term, oldest first (`insert_tte`
/// appends new ones).
///
/// `Single` is the compact form for exactly one entry. `Multiple` may
/// briefly be empty (it is the `Default`).
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug)]
enum TermToEnode {
    Single(ENodeIdx),
    Multiple(Vec<ENodeIdx>),
}
837
838impl Default for TermToEnode {
839 fn default() -> Self {
840 Self::Multiple(Vec::default())
841 }
842}
843
844impl EGraph {
845 pub fn insert_tte(&mut self, term: TermIdx, enode: ENodeIdx, stack: &Stack) -> Result<()> {
846 let remove = |e: &ENodeIdx| !stack.is_alive(self.enodes[*e].frame);
847 self.term_to_enode.try_reserve(1)?;
848 let tte = self.term_to_enode.entry(term).or_default();
849 let mut vec = match tte {
850 TermToEnode::Single(e) => [*e].into_iter().filter(|e| !remove(e)).collect(),
851 TermToEnode::Multiple(es) => {
852 let mut es = core::mem::take(es);
853 let idx = es.iter().position(remove).unwrap_or(es.len());
854 debug_assert!(es[idx..].iter().all(remove));
855 es.drain(idx..);
856 es
857 }
858 };
859 if vec.is_empty() {
860 *tte = TermToEnode::Single(enode);
861 } else {
862 vec.push(enode);
863 *tte = TermToEnode::Multiple(vec);
864 }
865 Ok(())
868 }
869
870 fn get_tte(&mut self, term: TermIdx, stack: &Stack) -> Option<ENodeIdx> {
871 let Entry::Occupied(mut o) = self.term_to_enode.entry(term) else {
872 return None;
873 };
874 let remove = |e: &ENodeIdx| !stack.is_alive(self.enodes[*e].frame);
875 match o.get_mut() {
876 TermToEnode::Single(e) => {
877 if remove(&*e) {
878 o.remove();
879 None
880 } else {
881 Some(*e)
882 }
883 }
884 TermToEnode::Multiple(es) => {
885 let idx = es.iter().position(remove).unwrap_or(es.len());
886 debug_assert!(es[idx..].iter().all(remove));
887 es.drain(idx..);
888 if let Some(last) = es.last().copied() {
889 if es.len() == 1 {
890 *o.get_mut() = TermToEnode::Single(last);
891 }
892 Some(last)
893 } else {
894 o.remove();
895 None
896 }
897 }
898 }
899 }
900
901 fn get_tte_imm(&self, term: TermIdx, stack: &Stack) -> Option<ENodeIdx> {
902 let enode = self.term_to_enode.get(&term)?;
903 let keep = |e: &ENodeIdx| stack.is_alive(self.enodes[*e].frame);
904 match enode {
905 TermToEnode::Single(e) if keep(e) => Some(*e),
906 TermToEnode::Single(_) => None,
907 TermToEnode::Multiple(es) => es.iter().copied().find(keep),
908 }
909 }
910}