smt_scope/analysis/graph/raw.rs

1use std::{
2    fmt,
3    ops::{Index, IndexMut},
4};
5
6use bitmask_enum::bitmask;
7#[cfg(feature = "mem_dbg")]
8use mem_dbg::{MemDbg, MemSize};
9use petgraph::{
10    graph::NodeIndex,
11    visit::{Dfs, NodeFiltered, Reversed, Walker},
12    Direction::{self},
13};
14
15use crate::{
16    graph_idx,
17    items::{
18        CdclIdx, ENodeBlame, ENodeIdx, EqGivenIdx, EqTransIdx, EqualityExpl, GraphIdx, InstIdx,
19        ProofIdx, StackIdx, TransitiveExplSegmentKind,
20    },
21    DiGraph, FxHashMap, FxHashSet, NonMaxU32, Result, TiVec, Z3Parser,
22};
23
24use super::analysis::reconnect::{ReachKind, ReachNonDisabled};
25
// Declares the index types used by the raw graph: `RawNodeIndex` and
// `RawEdgeIndex` (in this file they wrap a petgraph node/edge index reachable
// via `.0`) and the integer index type `RawIx` used as the graph's `Ix`.
// NOTE(review): presumably generated inside a module named `raw_idx`; see the
// `graph_idx!` macro definition for the exact items.
graph_idx!(raw_idx, RawNodeIndex, RawEdgeIndex, RawIx);
27
/// Graph over every item the parser produced (instantiations, e-nodes,
/// given/transitive equalities, proof steps and CDCL nodes), built by
/// [`RawInstGraph::new`].
///
/// The nodes of each kind occupy one contiguous run of indices. The `*_idx`
/// fields record where each run starts, which lets the `IndexesInstGraph`
/// impls turn a parser index into a node by offset. Given equalities may own
/// several nodes each, so they are resolved through `eq_given_idx` instead.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug)]
pub struct RawInstGraph {
    /// The underlying directed graph; edge labels are described by `EdgeKind`.
    pub graph: DiGraph<Node, EdgeKind, RawIx>,
    /// First node of the `NodeKind::Instantiation` run.
    inst_idx: RawNodeIndex,
    /// First node of the `NodeKind::ENode` run.
    enode_idx: RawNodeIndex,
    /// First node of the `NodeKind::TransEquality` run.
    eq_trans_idx: RawNodeIndex,
    /// Node of each given equality. A congruence equality has one entry per
    /// use (`Some(use index)`); every other equality has a single `None` entry.
    eq_given_idx: FxHashMap<(EqGivenIdx, Option<NonMaxU32>), RawNodeIndex>,
    /// First node of the `NodeKind::Proof` run.
    proofs_idx: RawNodeIndex,
    /// First node of the `NodeKind::Cdcl` run.
    cdcl_idx: RawNodeIndex,

    /// Visibility counters for the nodes of `graph`.
    pub(crate) stats: GraphStats,
}
41
42struct GraphTryReserve;
43impl GraphTryReserve {
44    fn try_reserve_exact(nodes: usize, edges: usize) -> Result<()> {
45        type Nodes = Vec<petgraph::graph::Node<Node>>;
46        type Edges = Vec<petgraph::graph::Edge<EdgeKind>>;
47        let mut n = Nodes::new();
48        n.try_reserve_exact(nodes)?;
49        let mut e = Edges::new();
50        e.try_reserve_exact(edges)?;
51        Ok(())
52    }
53}
54
55impl RawInstGraph {
56    pub fn new(parser: &Z3Parser) -> Result<Self> {
57        let total_nodes = parser.insts.insts.len()
58            + parser.egraph.enodes.len()
59            + parser.egraph.equalities.given.len()
60            + parser.egraph.equalities.transitive.len()
61            + parser.proofs().len()
62            + parser.cdcls().len();
63        let edges_estimate = parser.insts.insts.len()
64            + parser.egraph.equalities.transitive.len()
65            + parser.proofs().len();
66        GraphTryReserve::try_reserve_exact(total_nodes, edges_estimate)?;
67
68        let mut graph = DiGraph::with_capacity(total_nodes, edges_estimate);
69        let inst_idx = RawNodeIndex(NodeIndex::new(graph.node_count()));
70        for inst in parser.insts.insts.keys() {
71            graph.add_node(Node::new(NodeKind::Instantiation(inst)));
72        }
73        let enode_idx = RawNodeIndex(NodeIndex::new(graph.node_count()));
74        for enode in parser.egraph.enodes.keys() {
75            graph.add_node(Node::new(NodeKind::ENode(enode)));
76        }
77        let eq_trans_idx = RawNodeIndex(NodeIndex::new(graph.node_count()));
78        for eq_trans in parser.egraph.equalities.transitive.keys() {
79            graph.add_node(Node::new(NodeKind::TransEquality(eq_trans)));
80        }
81        let mut eq_given_idx = FxHashMap::default();
82        eq_given_idx.try_reserve(parser.egraph.equalities.given.len())?;
83        for (eq_given, eq) in parser.egraph.equalities.given.iter_enumerated() {
84            match eq {
85                EqualityExpl::Congruence { uses, .. } => {
86                    for i in 0..uses.len() {
87                        let use_ = Some(NonMaxU32::new(i as u32).unwrap());
88                        let node =
89                            graph.add_node(Node::new(NodeKind::GivenEquality(eq_given, use_)));
90                        eq_given_idx.insert((eq_given, use_), RawNodeIndex(node));
91                    }
92                }
93                _ => {
94                    let node = graph.add_node(Node::new(NodeKind::GivenEquality(eq_given, None)));
95                    eq_given_idx.insert((eq_given, None), RawNodeIndex(node));
96                }
97            }
98        }
99        let proofs_idx = RawNodeIndex(NodeIndex::new(graph.node_count()));
100        for ps_idx in parser.proofs().keys() {
101            graph.add_node(Node::new(NodeKind::Proof(ps_idx)));
102        }
103        let cdcl_idx = RawNodeIndex(NodeIndex::new(graph.node_count()));
104        for cdcl in parser.cdcls().keys() {
105            graph.add_node(Node::new(NodeKind::Cdcl(cdcl)));
106        }
107
108        let stats = GraphStats {
109            hidden: graph.node_count() as u32,
110            disabled: 0,
111            generation: 0,
112        };
113        let mut self_ = RawInstGraph {
114            graph,
115            inst_idx,
116            enode_idx,
117            eq_given_idx,
118            eq_trans_idx,
119            proofs_idx,
120            cdcl_idx,
121            stats,
122        };
123
124        // Add instantiation blamed edges
125        for (idx, inst) in parser.insts.insts.iter_enumerated() {
126            for (i, blame) in parser.insts.matches[inst.match_]
127                .pattern_matches()
128                .enumerate()
129            {
130                let pattern_term = i as u16;
131                self_.add_edge(blame.enode, idx, EdgeKind::Blame { pattern_term });
132                for (i, eq) in blame.equalities.iter().enumerate() {
133                    self_.add_edge(
134                        *eq,
135                        idx,
136                        EdgeKind::BlameEq {
137                            pattern_term,
138                            eq_order: i as u16,
139                        },
140                    );
141                }
142            }
143        }
144
145        // Add enode blamed edges
146        for (idx, enode) in parser.egraph.enodes.iter_enumerated() {
147            match &enode.blame {
148                ENodeBlame::Inst((iidx, eqs)) => {
149                    self_.add_edge(*iidx, idx, EdgeKind::Yield);
150                    for &eq in eqs.iter() {
151                        self_.add_edge(eq, idx, EdgeKind::YieldEq);
152                    }
153                }
154                ENodeBlame::Proof(pidx) => self_.add_edge(*pidx, idx, EdgeKind::Asserted),
155                ENodeBlame::BoolConst | ENodeBlame::Unknown => (),
156            }
157        }
158
159        // Add given equality created edges
160        for (idx, eq) in parser.egraph.equalities.given.iter_enumerated() {
161            match eq {
162                EqualityExpl::Root { .. } => (),
163                EqualityExpl::Literal { eq, .. } => {
164                    let literal = &parser[*eq];
165                    if let Some(enode) = literal.enode {
166                        self_.add_edge(enode, (idx, None), EdgeKind::EqualityFact)
167                    }
168                    if let Some(iidx) = literal.iblame {
169                        let enode_blame =
170                            literal.enode.and_then(|enode| parser[enode].blame.inst());
171                        if enode_blame.is_none_or(|enode_blame| enode_blame != iidx) {
172                            self_.add_edge(iidx, (idx, None), EdgeKind::EqualityProof)
173                        }
174                    }
175                }
176                EqualityExpl::Congruence { uses, .. } => {
177                    for (use_, arg_eqs) in uses.iter().enumerate() {
178                        let use_ = Some(NonMaxU32::new(use_ as u32).unwrap());
179                        for arg_eq in arg_eqs.iter() {
180                            self_.add_edge(*arg_eq, (idx, use_), EdgeKind::EqualityCongruence);
181                        }
182                    }
183                }
184                EqualityExpl::Theory { .. } => (),
185                EqualityExpl::Axiom { .. } => (),
186                EqualityExpl::Unknown { .. } => (),
187            }
188        }
189
190        // Add transitive equality created edges
191        for (idx, eq) in parser.egraph.equalities.transitive.iter_enumerated() {
192            let all = eq.all(true);
193            for parent in all {
194                use TransitiveExplSegmentKind::*;
195                match parent.kind {
196                    Given((eq, use_)) => self_.add_edge(
197                        (eq, use_),
198                        idx,
199                        EdgeKind::TEqualitySimple {
200                            forward: parent.forward,
201                        },
202                    ),
203                    Transitive(eq) => self_.add_edge(
204                        eq,
205                        idx,
206                        EdgeKind::TEqualityTransitive {
207                            forward: parent.forward,
208                        },
209                    ),
210                    Error(..) => (),
211                }
212            }
213        }
214
215        // Add proof step edges
216        for (idx, ps) in parser.proofs().iter_enumerated() {
217            for pre in ps.prerequisites.iter() {
218                self_.add_edge(*pre, idx, EdgeKind::ProofStep)
219            }
220        }
221
222        for (iidx, inst) in parser.insts.insts.iter_enumerated() {
223            let Some(proof) = inst.kind.proof() else {
224                continue;
225            };
226            self_.add_edge(iidx, proof, EdgeKind::YieldProof);
227        }
228
229        // Add cdcl edges
230        for cidx in parser.cdcls().keys() {
231            let backlink = parser.lits.cdcl.backlink(cidx);
232            match (backlink.previous, backlink.backtrack) {
233                (Some(previous), Some(backtrack)) => {
234                    self_.add_edge(backtrack, cidx, EdgeKind::Cdcl(CdclEdge::RetryFrom));
235                    self_.add_edge(previous, cidx, EdgeKind::Cdcl(CdclEdge::Backtrack));
236                }
237                (Some(previous), None) => {
238                    self_.add_edge(previous, cidx, EdgeKind::Cdcl(CdclEdge::Decide))
239                }
240                (None, Some(sidetrack)) => {
241                    self_.add_edge(sidetrack, cidx, EdgeKind::Cdcl(CdclEdge::Sidetrack))
242                }
243                (None, None) => (),
244            }
245        }
246
247        debug_assert!(
248            !petgraph::algo::is_cyclic_directed(&*self_.graph),
249            "Graph is cyclic, this should not happen by construction!"
250        );
251        Ok(self_)
252    }
253    fn add_edge(
254        &mut self,
255        source: impl IndexesInstGraph,
256        target: impl IndexesInstGraph,
257        kind: EdgeKind,
258    ) {
259        let a = source.index(self).0;
260        let b = target.index(self).0;
261        self.graph.add_edge(a, b, kind);
262    }
263
264    pub fn index(&self, kind: NodeKind) -> RawNodeIndex {
265        match kind {
266            NodeKind::ENode(enode) => enode.index(self),
267            NodeKind::GivenEquality(eq, use_) => (eq, use_).index(self),
268            NodeKind::TransEquality(eq) => eq.index(self),
269            NodeKind::Instantiation(inst) => inst.index(self),
270            NodeKind::Proof(ps) => ps.index(self),
271            NodeKind::Cdcl(cdcl) => cdcl.index(self),
272        }
273    }
274
275    pub fn rev(&self) -> Reversed<&petgraph::graph::DiGraph<Node, EdgeKind, RawIx>> {
276        Reversed(&*self.graph)
277    }
278
279    pub fn visible_nodes(&self) -> usize {
280        self.graph.node_count() - self.stats.hidden as usize - self.stats.disabled as usize
281    }
282    pub fn node_indices(&self) -> impl Iterator<Item = RawNodeIndex> {
283        self.graph.node_indices().map(RawNodeIndex)
284    }
285
286    /// Similar to `self.graph.neighbors` but will walk through disabled nodes.
287    ///
288    /// Note: Iterating the neighbors is **not** a O(1) operation.
289    pub fn neighbors<'a>(
290        &'a self,
291        node: RawNodeIndex,
292        reach: &'a TiVec<RawNodeIndex, ReachNonDisabled>,
293    ) -> Neighbors<'a> {
294        self.neighbors_directed(node, Direction::Outgoing, reach)
295    }
296
297    /// Similar to `self.graph.neighbors_directed` but will walk through
298    /// disabled nodes.
299    ///
300    /// Note: Iterating the neighbors is **not** a O(1) operation.
301    pub fn neighbors_directed<'a>(
302        &'a self,
303        node: RawNodeIndex,
304        dir: Direction,
305        reach: &'a TiVec<RawNodeIndex, ReachNonDisabled>,
306    ) -> Neighbors<'a> {
307        let direct = self.graph.neighbors_directed(node.0, dir).detach();
308        let walk = WalkNeighbors {
309            dir,
310            visited: FxHashSet::default(),
311            stack: Vec::new(),
312            direct,
313        };
314        Neighbors {
315            raw: self,
316            reach,
317            walk,
318        }
319    }
320
321    pub fn inst_to_raw_idx(&self) -> impl Fn(InstIdx) -> RawNodeIndex {
322        let inst_idx = self.inst_idx;
323        move |inst| RawNodeIndex(NodeIndex::new(inst_idx.0.index() + usize::from(inst)))
324    }
325
326    pub fn hypotheses(&self, parser: &Z3Parser, proof: ProofIdx) -> Vec<ProofIdx> {
327        let proof = proof.index(self);
328        let node = &self[proof];
329        if !node.proof.under_hypothesis() {
330            return Default::default();
331        }
332        let mut hypotheses = Vec::new();
333        let graph = NodeFiltered::from_fn(&*self.graph, |n| self.graph[n].proof.under_hypothesis());
334        let dfs = Dfs::new(Reversed(&graph), proof.0);
335        for n in dfs.iter(Reversed(&graph)).map(RawNodeIndex) {
336            let Some(n) = self[n].kind().proof() else {
337                debug_assert!(false, "Expected proof node");
338                continue;
339            };
340            if parser[n].kind.is_hypothesis() {
341                hypotheses.push(n);
342            }
343        }
344        hypotheses.sort_unstable();
345        hypotheses
346    }
347}
348
/// Counters describing the visibility of the nodes of a [`RawInstGraph`].
/// `GraphStats::set_state` keeps them in sync when a node's state changes.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug)]
pub struct GraphStats {
    /// Number of nodes in `NodeState::Hidden`; all nodes start out hidden.
    pub hidden: u32,
    /// Number of nodes in `NodeState::Disabled`.
    pub disabled: u32,
    /// How many times has the visibility of nodes been changed?
    /// Used to keep track of if the hidden graph needs to be recalculated.
    pub generation: u32,
}
358
359impl GraphStats {
360    pub fn set_state(&mut self, node: &mut Node, state: NodeState) -> bool {
361        if node.state == state {
362            return false;
363        }
364        self.generation = self.generation.wrapping_add(1);
365        match node.state {
366            NodeState::Disabled => self.disabled -= 1,
367            NodeState::Hidden => self.hidden -= 1,
368            _ => (),
369        }
370        match state {
371            NodeState::Disabled => self.disabled += 1,
372            NodeState::Hidden => self.hidden += 1,
373            _ => (),
374        }
375        node.state = state;
376        true
377    }
378}
379
/// A node of the [`RawInstGraph`]: the parser item it represents (`kind`),
/// its visibility (`state`) and per-node analysis data.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug, Clone)]
pub struct Node {
    /// Visibility; starts as `Hidden` (`Node::new`) and is changed through
    /// `GraphStats::set_state` so the counters stay correct.
    state: NodeState,
    /// The parser item this node stands for.
    kind: NodeKind,
    /// NOTE(review): computed outside this file; meaning not visible here.
    pub cost: f64,
    /// Path-length bounds, see `Depth` (presumably in the forward direction
    /// — TODO confirm).
    pub fwd_depth: Depth,
    /// Path-length bounds, see `Depth` (presumably in the backward direction
    /// — TODO confirm).
    pub bwd_depth: Depth,
    /// NOTE(review): presumably the subgraph this node belongs to plus a
    /// position/id within it — confirm against the code that sets it.
    pub subgraph: Option<(GraphIdx, u32)>,
    /// Summary of this node's parents, see `NextNodes`.
    pub parents: NextNodes,
    /// Summary of this node's children, see `NextNodes`.
    pub children: NextNodes,
    /// Proof-related flags, see `ProofReach`.
    pub proof: ProofReach,
}
393
/// Visibility of a [`Node`]. Every node starts out `Hidden` (see
/// `Node::new`). Change it with `GraphStats::set_state` so that the counters
/// in `GraphStats` stay correct.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeState {
    /// Counted in `GraphStats::disabled`. `RawInstGraph::neighbors` walks
    /// through disabled nodes instead of yielding them.
    Disabled,
    /// Counted in `GraphStats::hidden`.
    Hidden,
    /// Neither hidden nor disabled, see `RawInstGraph::visible_nodes`.
    Visible,
}
401
/// Path-length bounds of a node, stored in `Node::fwd_depth` and
/// `Node::bwd_depth` (computed outside this file). Both fields default to 0.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug, Clone, Copy, Default)]
pub struct Depth {
    /// What is the shortest path to a root/leaf
    pub min: u16,
    /// What is the longest path to a root/leaf
    pub max: u16,
}
410
/// Summary of a node's parents or children (see `Node::parents` and
/// `Node::children`).
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug, Clone, Default)]
pub struct NextNodes {
    // Issue 4: storing the instantiation children in every node has a huge
    // memory overhead, so this field is compiled out (`cfg(any())` is never
    // true).
    #[cfg(any())]
    /// What are the immediate next instantiation nodes
    pub insts: FxHashSet<InstIdx>,
    /// How many parents/children does this node have (not-necessarily
    /// instantiation nodes), walking through disabled nodes.
    pub count: u32,
}
422
#[bitmask(u8)]
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[cfg_attr(feature = "mem_dbg", copy_type)]
#[derive(Default)]
/// Proof-related flags of a node, stored as a `u8` bit set generated by
/// `#[bitmask(u8)]`. Query individual flags with the helper methods on
/// `ProofReach` (e.g. `ProofReach::proves_false`).
///
/// NOTE(review): the flags are computed outside this file. Their exact
/// meanings are only suggested by the variant names; confirm against the
/// analysis that sets them.
pub enum ProofReach {
    ProvesFalse,
    UnderHypothesis,
    ReachesProof,
    ReachesNonTrivialProof,
    ReachesFalse,

    /// Is this a CDCL dead branch (i.e. all children lead to a contradiction)
    CdclDeadBranch,
}
437
438impl ProofReach {
439    pub fn if_(self, cond: bool) -> Self {
440        if cond {
441            self
442        } else {
443            Default::default()
444        }
445    }
446
447    pub fn proves_false(self) -> bool {
448        self.contains(ProofReach::ProvesFalse)
449    }
450    pub fn under_hypothesis(self) -> bool {
451        self.contains(ProofReach::UnderHypothesis)
452    }
453    pub fn reaches_proof(self) -> bool {
454        self.contains(ProofReach::ReachesProof)
455    }
456    pub fn reaches_non_trivial_proof(self) -> bool {
457        self.contains(ProofReach::ReachesNonTrivialProof)
458    }
459    pub fn reaches_false(self) -> bool {
460        self.contains(ProofReach::ReachesFalse)
461    }
462
463    pub fn cdcl_dead_branch(self) -> bool {
464        self.contains(ProofReach::CdclDeadBranch)
465    }
466}
467
468impl Node {
469    fn new(kind: NodeKind) -> Self {
470        Self {
471            state: NodeState::Hidden,
472            cost: 0.0,
473            fwd_depth: Depth::default(),
474            bwd_depth: Depth::default(),
475            subgraph: None,
476            kind,
477            parents: NextNodes::default(),
478            children: NextNodes::default(),
479            proof: ProofReach::default(),
480        }
481    }
482    pub fn kind(&self) -> &NodeKind {
483        &self.kind
484    }
485
486    pub fn disabled(&self) -> bool {
487        matches!(self.state, NodeState::Disabled)
488    }
489    pub fn hidden(&self) -> bool {
490        matches!(self.state, NodeState::Hidden)
491    }
492    pub fn visible(&self) -> bool {
493        matches!(self.state, NodeState::Visible)
494    }
495    pub fn hidden_inst(&self) -> bool {
496        matches!(
497            (self.state, self.kind),
498            (NodeState::Hidden, NodeKind::Instantiation(_))
499        )
500    }
501
502    pub fn frame(&self, parser: &Z3Parser) -> Option<StackIdx> {
503        match *self.kind() {
504            NodeKind::Instantiation(iidx) => Some(parser[iidx].frame),
505            NodeKind::ENode(eidx) => Some(parser[eidx].frame),
506            NodeKind::GivenEquality(..) | NodeKind::TransEquality(_) => None,
507            NodeKind::Proof(psidx) => Some(parser[psidx].frame),
508            NodeKind::Cdcl(cdcl) => Some(parser[cdcl].frame),
509        }
510    }
511}
512
/// What a [`Node`] of the [`RawInstGraph`] stands for.
///
/// The parent/child descriptions below follow the edges added in
/// `RawInstGraph::new`; the names in parentheses are the `EdgeKind` labels
/// (whose docs give the source/target kinds).
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug, Clone, Copy)]
pub enum NodeKind {
    /// Corresponds to `InstIdx`.
    ///
    /// **Parents:** (small) arbitrary count: the matched `ENode`s (`Blame`)
    /// and the equalities used by the match (`BlameEq`).\
    /// **Children:** (small) arbitrary count: the yielded `ENode`s (`Yield`),
    /// `GivenEquality` literals it proves (`EqualityProof`, only when it did
    /// not also yield the literal's e-node) and its `Proof` (`YieldProof`).
    Instantiation(InstIdx),
    /// Corresponds to `ENodeIdx`.
    ///
    /// **Parents:** depends on the e-node's blame. There are either none, a
    /// single asserting `Proof` (`Asserted`), or the yielding `Instantiation`
    /// (`Yield`) together with the equalities it used (`YieldEq`).\
    /// **Children:** (large) arbitrary count, will always be either
    /// `Instantiation` (`Blame`) or `GivenEquality` of type
    /// `EqualityExpl::Literal` (`EqualityFact`).
    ENode(ENodeIdx),
    /// Corresponds to `EqGivenIdx`. For `EqualityExpl::Congruence` there is
    /// one node per use (`Some(use index)`); other kinds use `None`.
    ///
    /// **Parents:** a `Literal` has up to two: its `ENode` (`EqualityFact`)
    /// and the `Instantiation` that proved it (`EqualityProof`). A
    /// `Congruence` node has the argument equalities of its use
    /// (`EqualityCongruence`). Other kinds have none.\
    /// **Children:** (small) arbitrary count: `TransEquality` nodes
    /// (`TEqualitySimple`) and, per the `YieldEq` docs, `ENode`s.
    GivenEquality(EqGivenIdx, Option<NonMaxU32>),
    /// Corresponds to `EqTransIdx`.
    ///
    /// **Parents:** (small) arbitrary count, will always be `GivenEquality`
    /// (`TEqualitySimple`) or `TransEquality` (`TEqualityTransitive`). The
    /// number of immediately reachable `GivenEquality` can be found in
    /// `TransitiveExpl::given_len`.\
    /// **Children:** (large) arbitrary count, can be `GivenEquality`
    /// (`EqualityCongruence`), `TransEquality` (`TEqualityTransitive`) or
    /// `Instantiation` (`BlameEq`).
    TransEquality(EqTransIdx),
    /// Corresponds to `ProofIdx`.
    ///
    /// **Parents:** (large) arbitrary count, will always be `Proof`
    /// (`ProofStep`) or `Instantiation` (`YieldProof`).\
    /// **Children:** (small?) arbitrary count: `Proof` (`ProofStep`) or the
    /// `ENode`s it asserts (`Asserted`).
    Proof(ProofIdx),
    /// Corresponds to `CdclIdx`. Only connected to other `Cdcl` nodes.
    ///
    /// **Parents:** will always have between 0 and 2 parents. With one
    /// parent, it is a `Decide` or `Sidetrack` edge. With two, they are a
    /// `Backtrack` and a `RetryFrom` edge (see `CdclEdge`).\
    /// **Children:** (generally small) arbitrary count, depends on how many
    /// times we backtracked here.
    Cdcl(CdclIdx),
}
558
559impl fmt::Display for NodeKind {
560    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
561        match self {
562            NodeKind::ENode(enode) => write!(f, "{enode}"),
563            NodeKind::GivenEquality(eq, use_) => write!(
564                f,
565                "{eq}{}",
566                use_.filter(|u| *u != NonMaxU32::ZERO)
567                    .map(|u| format!("[{u}]"))
568                    .unwrap_or_default()
569            ),
570            NodeKind::TransEquality(eq) => write!(f, "{eq}"),
571            NodeKind::Instantiation(inst) => write!(f, "{inst}"),
572            NodeKind::Proof(ps) => write!(f, "{ps}"),
573            NodeKind::Cdcl(cdcl) => write!(f, "{cdcl}"),
574        }
575    }
576}
577
578impl NodeKind {
579    pub fn inst(&self) -> Option<InstIdx> {
580        match self {
581            Self::Instantiation(inst) => Some(*inst),
582            _ => None,
583        }
584    }
585    pub fn enode(&self) -> Option<ENodeIdx> {
586        match self {
587            Self::ENode(enode) => Some(*enode),
588            _ => None,
589        }
590    }
591    pub fn eq_given(&self) -> Option<(EqGivenIdx, Option<NonMaxU32>)> {
592        match self {
593            Self::GivenEquality(eq, use_) => Some((*eq, *use_)),
594            _ => None,
595        }
596    }
597    pub fn eq_trans(&self) -> Option<EqTransIdx> {
598        match self {
599            Self::TransEquality(eq) => Some(*eq),
600            _ => None,
601        }
602    }
603    pub fn proof(&self) -> Option<ProofIdx> {
604        match self {
605            Self::Proof(ps) => Some(*ps),
606            _ => None,
607        }
608    }
609    pub fn cdcl(&self) -> Option<CdclIdx> {
610        match self {
611            Self::Cdcl(cdcl) => Some(*cdcl),
612            _ => None,
613        }
614    }
615
616    /// Same as `reconnect_parents` but for children. Do we reconnect hidden
617    /// children of this visible node or just this node itself?
618    pub fn reconnect_child(&self, _child: &Self) -> bool {
619        // TODO: what behavior do we want here?
620        #[cfg(any())]
621        !matches!(
622            (self, child),
623            (
624                Self::ENode(..) | Self::TransEquality(..),
625                Self::Instantiation(..)
626            ) | (Self::Proof(..), Self::Proof(..))
627        );
628        false
629    }
630}
631
/// Label of an edge in `RawInstGraph::graph`. Each variant documents its
/// `source -> target` node kinds; see `RawInstGraph::new` for where it is
/// added.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[cfg_attr(feature = "mem_dbg", copy_type)]
#[derive(Debug, Clone, Copy)]
pub enum EdgeKind {
    /// Instantiation -> ENode
    Yield,
    /// GivenEquality -> ENode
    YieldEq,
    /// Proof (asserted) -> ENode
    Asserted,
    /// ENode -> Instantiation. `pattern_term` is the position of the matched
    /// pattern term in the match's `pattern_matches()`.
    Blame { pattern_term: u16 },
    /// TransEquality -> Instantiation. `pattern_term` as for `Blame`;
    /// `eq_order` is the position of the equality in that blame's
    /// `equalities`.
    BlameEq { pattern_term: u16, eq_order: u16 },
    /// ENode -> GivenEquality (`EqualityExpl::Literal`)
    EqualityFact,
    /// Instantiation -> GivenEquality (`EqualityExpl::Literal`). Only added
    /// when the instantiation is not the one that yielded the literal's
    /// e-node.
    EqualityProof,
    /// TransEquality -> GivenEquality (`EqualityExpl::Congruence`)
    EqualityCongruence,
    /// GivenEquality -> TransEquality (`TransitiveExplSegmentKind::Given`)
    TEqualitySimple { forward: bool },
    /// TransEquality -> TransEquality (`TransitiveExplSegmentKind::Transitive`)
    TEqualityTransitive { forward: bool },
    /// Proof -> Proof
    ProofStep,
    /// Instantiation -> Proof
    YieldProof,
    /// Cdcl -> Cdcl, see `CdclEdge`
    Cdcl(CdclEdge),
}
663
/// Label of a `Cdcl -> Cdcl` edge. `RawInstGraph::new` chooses it from the
/// target node's `previous`/`backtrack` backlinks.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[cfg_attr(feature = "mem_dbg", copy_type)]
#[derive(Debug, Clone, Copy)]
pub enum CdclEdge {
    /// Edge deeper into the CDCL tree. Added from `previous` when there is
    /// no `backtrack` backlink.
    Decide,
    /// Edge back to a higher level in the tree. Added from `previous` when
    /// both backlinks exist.
    Backtrack,
    /// Edge to a side branch which may later be popped by the user. Added
    /// from `backtrack` when there is no `previous` backlink.
    Sidetrack,
    /// Edge linking a backtracked node to the correct place in the tree.
    /// Added from `backtrack` when both backlinks exist.
    RetryFrom,
}
677
678pub trait IndexesInstGraph {
679    fn index(&self, graph: &RawInstGraph) -> RawNodeIndex;
680}
681impl IndexesInstGraph for InstIdx {
682    fn index(&self, graph: &RawInstGraph) -> RawNodeIndex {
683        graph.inst_to_raw_idx()(*self)
684    }
685}
686impl IndexesInstGraph for ENodeIdx {
687    fn index(&self, graph: &RawInstGraph) -> RawNodeIndex {
688        RawNodeIndex(NodeIndex::new(
689            graph.enode_idx.0.index() + usize::from(*self),
690        ))
691    }
692}
693impl IndexesInstGraph for EqTransIdx {
694    fn index(&self, graph: &RawInstGraph) -> RawNodeIndex {
695        RawNodeIndex(NodeIndex::new(
696            graph.eq_trans_idx.0.index() + usize::from(*self),
697        ))
698    }
699}
700impl IndexesInstGraph for (EqGivenIdx, Option<NonMaxU32>) {
701    fn index(&self, graph: &RawInstGraph) -> RawNodeIndex {
702        graph.eq_given_idx[self]
703    }
704}
705impl IndexesInstGraph for ProofIdx {
706    fn index(&self, graph: &RawInstGraph) -> RawNodeIndex {
707        RawNodeIndex(NodeIndex::new(
708            graph.proofs_idx.0.index() + usize::from(*self),
709        ))
710    }
711}
712impl IndexesInstGraph for CdclIdx {
713    fn index(&self, graph: &RawInstGraph) -> RawNodeIndex {
714        RawNodeIndex(NodeIndex::new(
715            graph.cdcl_idx.0.index() + usize::from(*self),
716        ))
717    }
718}
719impl IndexesInstGraph for RawNodeIndex {
720    fn index(&self, _graph: &RawInstGraph) -> RawNodeIndex {
721        *self
722    }
723}
724
725impl<T: IndexesInstGraph> Index<T> for RawInstGraph {
726    type Output = Node;
727    fn index(&self, index: T) -> &Self::Output {
728        let index = index.index(self);
729        &self.graph[index.0]
730    }
731}
732impl<T: IndexesInstGraph> IndexMut<T> for RawInstGraph {
733    fn index_mut(&mut self, index: T) -> &mut Self::Output {
734        let index = index.index(self);
735        &mut self.graph[index.0]
736    }
737}
738
739impl Index<RawEdgeIndex> for RawInstGraph {
740    type Output = EdgeKind;
741    fn index(&self, index: RawEdgeIndex) -> &Self::Output {
742        &self.graph[index.0]
743    }
744}
745
746#[derive(Clone)]
747pub struct Neighbors<'a> {
748    pub raw: &'a RawInstGraph,
749    pub reach: &'a TiVec<RawNodeIndex, ReachNonDisabled>,
750    pub walk: WalkNeighbors,
751}
752
753impl<'a> Neighbors<'a> {
754    pub fn detach(self) -> WalkNeighbors {
755        self.walk
756    }
757
758    pub fn count_hidden(self) -> usize {
759        let raw = self.raw;
760        self.filter(|&ix| raw[ix].hidden()).count()
761    }
762}
763
764impl Iterator for Neighbors<'_> {
765    type Item = RawNodeIndex;
766    fn next(&mut self) -> Option<Self::Item> {
767        self.walk.next(self.raw, self.reach)
768    }
769}
770
/// Detached traversal state behind [`Neighbors`]. It yields the neighbours of
/// a node in direction `dir`, but replaces every disabled node by its own
/// neighbours (repeatedly), so only non-disabled nodes are returned.
#[derive(Clone)]
pub struct WalkNeighbors {
    /// Direction of the edges being followed.
    pub dir: Direction,
    /// Nodes already queued on `stack` while expanding disabled nodes. Direct
    /// neighbours are not recorded here (see the TODO in `next`).
    pub visited: FxHashSet<RawNodeIndex>,
    /// Pending nodes discovered by expanding disabled nodes (popped LIFO).
    pub stack: Vec<RawNodeIndex>,
    /// Cursor over the direct neighbours of the starting node.
    pub direct: petgraph::graph::WalkNeighbors<RawIx>,
}

impl WalkNeighbors {
    /// Returns the next direct neighbour. Neighbours whose entry in `reach`
    /// exists and has a false `value()` are skipped; nodes without an entry
    /// are always returned.
    /// NOTE(review): presumably "false" means the node cannot reach any
    /// non-disabled node — confirm against `ReachNonDisabled`.
    fn next_direct(
        &mut self,
        raw: &RawInstGraph,
        reach: &TiVec<RawNodeIndex, ReachNonDisabled>,
    ) -> Option<RawNodeIndex> {
        while let Some((_, n)) = self.direct.next(&raw.graph) {
            let n = RawNodeIndex(n);
            let skip = reach.get(n).is_some_and(|v| !v.value());
            if !skip {
                return Some(n);
            }
        }
        None
    }

    /// Returns the next non-disabled neighbour, or `None` when exhausted.
    ///
    /// Direct neighbours are always tried before nodes on `stack`. A disabled
    /// node is never returned. Instead, its neighbours in `dir` that were not
    /// queued before are pushed onto `stack`. Nodes on `stack` are not
    /// filtered by `reach`. The same node may be returned more than once, e.g.
    /// through parallel direct edges, or as both a direct neighbour and a
    /// neighbour of a disabled node.
    pub fn next(
        &mut self,
        raw: &RawInstGraph,
        reach: &TiVec<RawNodeIndex, ReachNonDisabled>,
    ) -> Option<RawNodeIndex> {
        // TODO: decide if we want to prevent direct neighbors from being
        // visited multiple times if there are multiple direct edges.
        loop {
            // let mut idx = None;
            // while let Some((_, direct)) = self.direct.next(&raw.graph) {
            //     let direct = RawNodeIndex(direct);
            //     if self.visited.insert(direct) {
            //         idx = Some(direct);
            //         break;
            //     }
            // }
            // let idx = idx.or_else(|| self.stack.pop())?;
            let idx = self.next_direct(raw, reach).or_else(|| self.stack.pop())?;
            let node = &raw[idx];
            if !node.disabled() {
                return Some(idx);
            }
            // Disabled: look through it by queueing its own neighbours.
            for n in raw.graph.neighbors_directed(idx.0, self.dir) {
                let n = RawNodeIndex(n);
                if self.visited.insert(n) {
                    self.stack.push(n);
                }
            }
        }
    }
}