smt_scope/analysis/graph/raw.rs

1use std::{
2    fmt,
3    ops::{Index, IndexMut},
4};
5
6use bitmask_enum::bitmask;
7#[cfg(feature = "mem_dbg")]
8use mem_dbg::{MemDbg, MemSize};
9use petgraph::{
10    graph::NodeIndex,
11    visit::{Dfs, NodeFiltered, Reversed, Walker},
12    Direction::{self},
13};
14
15use crate::{
16    graph_idx,
17    items::{
18        CdclIdx, ENodeBlame, ENodeIdx, EqGivenIdx, EqTransIdx, EqualityExpl, GraphIdx, InstIdx,
19        ProofIdx, StackIdx, TransitiveExplSegmentKind,
20    },
21    DiGraph, FxHashMap, FxHashSet, NonMaxU32, Result, TiVec, Z3Parser,
22};
23
24use super::analysis::reconnect::{ReachKind, ReachNonDisabled};
25
// Declares the raw graph's index types: `RawNodeIndex` (a newtype over a
// petgraph `NodeIndex`), `RawEdgeIndex` (newtype over an edge index) and the
// petgraph index type `RawIx`. NOTE(review): the role of `raw_idx` depends on
// the `graph_idx!` macro definition — presumably a module name; confirm.
graph_idx!(raw_idx, RawNodeIndex, RawEdgeIndex, RawIx);
27
/// The full, unfiltered dependency graph built from a `Z3Parser`.
///
/// Nodes of every kind except given equalities occupy one contiguous range of
/// node indices, starting at the corresponding `*_idx` field. An item index
/// therefore converts to a `RawNodeIndex` by adding that start offset.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug)]
pub struct RawInstGraph {
    pub graph: DiGraph<Node, EdgeKind, RawIx>,
    /// First node index of the instantiation range.
    inst_idx: RawNodeIndex,
    /// First node index of the enode range.
    enode_idx: RawNodeIndex,
    /// First node index of the transitive-equality range.
    eq_trans_idx: RawNodeIndex,
    /// Given equalities cannot be located by offset. An
    /// `EqualityExpl::Congruence` equality gets one node per use, keyed by
    /// `Some(use index)`; every other given equality gets a single node keyed
    /// by `None`.
    eq_given_idx: FxHashMap<(EqGivenIdx, Option<NonMaxU32>), RawNodeIndex>,
    /// First node index of the proof-step range.
    proofs_idx: RawNodeIndex,
    /// First node index of the CDCL node range.
    cdcl_idx: RawNodeIndex,

    /// Hidden/disabled node counters and the visibility change counter.
    pub(crate) stats: GraphStats,
}
41
42struct GraphTryReserve;
43impl GraphTryReserve {
44    fn try_reserve_exact(nodes: usize, edges: usize) -> Result<()> {
45        type Nodes = Vec<petgraph::graph::Node<Node>>;
46        type Edges = Vec<petgraph::graph::Edge<EdgeKind>>;
47        let mut n = Nodes::new();
48        n.try_reserve_exact(nodes)?;
49        let mut e = Edges::new();
50        e.try_reserve_exact(edges)?;
51        Ok(())
52    }
53}
54
55impl RawInstGraph {
56    pub fn new(parser: &Z3Parser) -> Result<Self> {
57        let total_nodes = parser.insts.insts.len()
58            + parser.egraph.enodes.len()
59            + parser.egraph.equalities.given.len()
60            + parser.egraph.equalities.transitive.len()
61            + parser.proofs().len()
62            + parser.cdcls().len();
63        let edges_estimate = parser.insts.insts.len()
64            + parser.egraph.equalities.transitive.len()
65            + parser.proofs().len();
66        GraphTryReserve::try_reserve_exact(total_nodes, edges_estimate)?;
67
68        let mut graph = DiGraph::with_capacity(total_nodes, edges_estimate);
69        let inst_idx = RawNodeIndex(NodeIndex::new(graph.node_count()));
70        for inst in parser.insts.insts.keys() {
71            graph.add_node(Node::new(NodeKind::Instantiation(inst)));
72        }
73        let enode_idx = RawNodeIndex(NodeIndex::new(graph.node_count()));
74        for enode in parser.egraph.enodes.keys() {
75            graph.add_node(Node::new(NodeKind::ENode(enode)));
76        }
77        let eq_trans_idx = RawNodeIndex(NodeIndex::new(graph.node_count()));
78        for eq_trans in parser.egraph.equalities.transitive.keys() {
79            graph.add_node(Node::new(NodeKind::TransEquality(eq_trans)));
80        }
81        let mut eq_given_idx = FxHashMap::default();
82        eq_given_idx.try_reserve(parser.egraph.equalities.given.len())?;
83        for (eq_given, eq) in parser.egraph.equalities.given.iter_enumerated() {
84            match eq {
85                EqualityExpl::Congruence { uses, .. } => {
86                    for i in 0..uses.len() {
87                        let use_ = Some(NonMaxU32::new(i as u32).unwrap());
88                        let node =
89                            graph.add_node(Node::new(NodeKind::GivenEquality(eq_given, use_)));
90                        eq_given_idx.insert((eq_given, use_), RawNodeIndex(node));
91                    }
92                }
93                _ => {
94                    let node = graph.add_node(Node::new(NodeKind::GivenEquality(eq_given, None)));
95                    eq_given_idx.insert((eq_given, None), RawNodeIndex(node));
96                }
97            }
98        }
99        let proofs_idx = RawNodeIndex(NodeIndex::new(graph.node_count()));
100        for ps_idx in parser.proofs().keys() {
101            graph.add_node(Node::new(NodeKind::Proof(ps_idx)));
102        }
103        let cdcl_idx = RawNodeIndex(NodeIndex::new(graph.node_count()));
104        for cdcl in parser.cdcls().keys() {
105            graph.add_node(Node::new(NodeKind::Cdcl(cdcl)));
106        }
107
108        let stats = GraphStats {
109            hidden: graph.node_count() as u32,
110            disabled: 0,
111            generation: 0,
112        };
113        let mut self_ = RawInstGraph {
114            graph,
115            inst_idx,
116            enode_idx,
117            eq_given_idx,
118            eq_trans_idx,
119            proofs_idx,
120            cdcl_idx,
121            stats,
122        };
123
124        // Add instantiation blamed edges
125        for (idx, inst) in parser.insts.insts.iter_enumerated() {
126            for (i, blame) in parser.insts.matches[inst.match_]
127                .pattern_matches()
128                .enumerate()
129            {
130                let pattern_term = i as u16;
131                self_.add_edge(blame.enode, idx, EdgeKind::Blame { pattern_term });
132                for (i, eq) in blame.equalities.iter().enumerate() {
133                    self_.add_edge(
134                        *eq,
135                        idx,
136                        EdgeKind::BlameEq {
137                            pattern_term,
138                            eq_order: i as u16,
139                        },
140                    );
141                }
142            }
143        }
144
145        // Add enode blamed edges
146        for (idx, enode) in parser.egraph.enodes.iter_enumerated() {
147            match &enode.blame {
148                ENodeBlame::Inst((iidx, eqs)) => {
149                    self_.add_edge(*iidx, idx, EdgeKind::Yield);
150                    for &eq in eqs.iter() {
151                        self_.add_edge(eq, idx, EdgeKind::YieldEq);
152                    }
153                }
154                ENodeBlame::Proof(pidx) => self_.add_edge(*pidx, idx, EdgeKind::Asserted),
155                ENodeBlame::BoolConst | ENodeBlame::Unknown => (),
156            }
157        }
158
159        // Add given equality created edges
160        for (idx, eq) in parser.egraph.equalities.given.iter_enumerated() {
161            match eq {
162                EqualityExpl::Root { .. } => (),
163                EqualityExpl::Literal { eq, .. } => {
164                    if let Some(iidx) = parser[*eq].iblame {
165                        self_.add_edge(iidx, (idx, None), EdgeKind::EqualityFact)
166                    }
167                }
168                EqualityExpl::Congruence { uses, .. } => {
169                    for (use_, arg_eqs) in uses.iter().enumerate() {
170                        let use_ = Some(NonMaxU32::new(use_ as u32).unwrap());
171                        for arg_eq in arg_eqs.iter() {
172                            self_.add_edge(*arg_eq, (idx, use_), EdgeKind::EqualityCongruence);
173                        }
174                    }
175                }
176                EqualityExpl::Theory { .. } => (),
177                EqualityExpl::Axiom { .. } => (),
178                EqualityExpl::Unknown { .. } => (),
179            }
180        }
181
182        // Add transitive equality created edges
183        for (idx, eq) in parser.egraph.equalities.transitive.iter_enumerated() {
184            let all = eq.all(true);
185            for parent in all {
186                use TransitiveExplSegmentKind::*;
187                match parent.kind {
188                    Given((eq, use_)) => self_.add_edge(
189                        (eq, use_),
190                        idx,
191                        EdgeKind::TEqualitySimple {
192                            forward: parent.forward,
193                        },
194                    ),
195                    Transitive(eq) => self_.add_edge(
196                        eq,
197                        idx,
198                        EdgeKind::TEqualityTransitive {
199                            forward: parent.forward,
200                        },
201                    ),
202                    Error(..) => (),
203                }
204            }
205        }
206
207        // Add proof step edges
208        for (idx, ps) in parser.proofs().iter_enumerated() {
209            for pre in ps.prerequisites.iter() {
210                self_.add_edge(*pre, idx, EdgeKind::ProofStep)
211            }
212        }
213
214        for (iidx, inst) in parser.insts.insts.iter_enumerated() {
215            let Some(proof) = inst.kind.proof() else {
216                continue;
217            };
218            self_.add_edge(iidx, proof, EdgeKind::YieldProof);
219        }
220
221        // Add cdcl edges
222        for cidx in parser.cdcls().keys() {
223            let backlink = parser.lits.cdcl.backlink(cidx);
224            match (backlink.previous, backlink.backtrack) {
225                (Some(previous), Some(backtrack)) => {
226                    self_.add_edge(backtrack, cidx, EdgeKind::Cdcl(CdclEdge::RetryFrom));
227                    self_.add_edge(previous, cidx, EdgeKind::Cdcl(CdclEdge::Backtrack));
228                }
229                (Some(previous), None) => {
230                    self_.add_edge(previous, cidx, EdgeKind::Cdcl(CdclEdge::Decide))
231                }
232                (None, Some(sidetrack)) => {
233                    self_.add_edge(sidetrack, cidx, EdgeKind::Cdcl(CdclEdge::Sidetrack))
234                }
235                (None, None) => (),
236            }
237        }
238
239        debug_assert!(
240            !petgraph::algo::is_cyclic_directed(&*self_.graph),
241            "Graph is cyclic, this should not happen by construction!"
242        );
243        Ok(self_)
244    }
245    fn add_edge(
246        &mut self,
247        source: impl IndexesInstGraph,
248        target: impl IndexesInstGraph,
249        kind: EdgeKind,
250    ) {
251        let a = source.index(self).0;
252        let b = target.index(self).0;
253        self.graph.add_edge(a, b, kind);
254    }
255
256    pub fn index(&self, kind: NodeKind) -> RawNodeIndex {
257        match kind {
258            NodeKind::ENode(enode) => enode.index(self),
259            NodeKind::GivenEquality(eq, use_) => (eq, use_).index(self),
260            NodeKind::TransEquality(eq) => eq.index(self),
261            NodeKind::Instantiation(inst) => inst.index(self),
262            NodeKind::Proof(ps) => ps.index(self),
263            NodeKind::Cdcl(cdcl) => cdcl.index(self),
264        }
265    }
266
267    pub fn rev(&self) -> Reversed<&petgraph::graph::DiGraph<Node, EdgeKind, RawIx>> {
268        Reversed(&*self.graph)
269    }
270
271    pub fn visible_nodes(&self) -> usize {
272        self.graph.node_count() - self.stats.hidden as usize - self.stats.disabled as usize
273    }
274    pub fn node_indices(&self) -> impl Iterator<Item = RawNodeIndex> {
275        self.graph.node_indices().map(RawNodeIndex)
276    }
277
278    /// Similar to `self.graph.neighbors` but will walk through disabled nodes.
279    ///
280    /// Note: Iterating the neighbors is **not** a O(1) operation.
281    pub fn neighbors<'a>(
282        &'a self,
283        node: RawNodeIndex,
284        reach: &'a TiVec<RawNodeIndex, ReachNonDisabled>,
285    ) -> Neighbors<'a> {
286        self.neighbors_directed(node, Direction::Outgoing, reach)
287    }
288
289    /// Similar to `self.graph.neighbors_directed` but will walk through
290    /// disabled nodes.
291    ///
292    /// Note: Iterating the neighbors is **not** a O(1) operation.
293    pub fn neighbors_directed<'a>(
294        &'a self,
295        node: RawNodeIndex,
296        dir: Direction,
297        reach: &'a TiVec<RawNodeIndex, ReachNonDisabled>,
298    ) -> Neighbors<'a> {
299        let direct = self.graph.neighbors_directed(node.0, dir).detach();
300        let walk = WalkNeighbors {
301            dir,
302            visited: FxHashSet::default(),
303            stack: Vec::new(),
304            direct,
305        };
306        Neighbors {
307            raw: self,
308            reach,
309            walk,
310        }
311    }
312
313    pub fn inst_to_raw_idx(&self) -> impl Fn(InstIdx) -> RawNodeIndex {
314        let inst_idx = self.inst_idx;
315        move |inst| RawNodeIndex(NodeIndex::new(inst_idx.0.index() + usize::from(inst)))
316    }
317
318    pub fn hypotheses(&self, parser: &Z3Parser, proof: ProofIdx) -> Vec<ProofIdx> {
319        let proof = proof.index(self);
320        let node = &self[proof];
321        if !node.proof.under_hypothesis() {
322            return Default::default();
323        }
324        let mut hypotheses = Vec::new();
325        let graph = NodeFiltered::from_fn(&*self.graph, |n| self.graph[n].proof.under_hypothesis());
326        let dfs = Dfs::new(Reversed(&graph), proof.0);
327        for n in dfs.iter(Reversed(&graph)).map(RawNodeIndex) {
328            let Some(n) = self[n].kind().proof() else {
329                debug_assert!(false, "Expected proof node");
330                continue;
331            };
332            if parser[n].kind.is_hypothesis() {
333                hypotheses.push(n);
334            }
335        }
336        hypotheses.sort_unstable();
337        hypotheses
338    }
339}
340
/// Visibility bookkeeping for a `RawInstGraph`. The counters are kept in
/// sync by `GraphStats::set_state`.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug)]
pub struct GraphStats {
    /// Number of nodes in `NodeState::Hidden`. This starts at the total node
    /// count, since every node is created hidden.
    pub hidden: u32,
    /// Number of nodes in `NodeState::Disabled`.
    pub disabled: u32,
    /// How many times has the visibility of nodes been changed?
    /// Used to keep track of if the hidden graph needs to be recalculated.
    pub generation: u32,
}
350
351impl GraphStats {
352    pub fn set_state(&mut self, node: &mut Node, state: NodeState) -> bool {
353        if node.state == state {
354            return false;
355        }
356        self.generation = self.generation.wrapping_add(1);
357        match node.state {
358            NodeState::Disabled => self.disabled -= 1,
359            NodeState::Hidden => self.hidden -= 1,
360            _ => (),
361        }
362        match state {
363            NodeState::Disabled => self.disabled += 1,
364            NodeState::Hidden => self.hidden += 1,
365            _ => (),
366        }
367        node.state = state;
368        true
369    }
370}
371
/// Data stored for each node of the raw graph: the item it stands for, its
/// visibility state and per-node analysis results.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug, Clone)]
pub struct Node {
    /// Visibility state. New nodes start `Hidden`. `GraphStats::set_state`
    /// updates it together with the graph-wide counters.
    state: NodeState,
    /// The item this node stands for.
    kind: NodeKind,
    // NOTE(review): the meaning of `cost` is not visible in this file. It
    // starts at 0.0 and is presumably filled in by an analysis pass.
    pub cost: f64,
    /// Shortest/longest path lengths in the forward direction (see `Depth`);
    /// presumably measured from a root — confirm.
    pub fwd_depth: Depth,
    /// Shortest/longest path lengths in the backward direction (see `Depth`);
    /// presumably measured from a leaf — confirm.
    pub bwd_depth: Depth,
    /// Presumably the subgraph this node was assigned to and its position
    /// within it. `None` until assigned.
    pub subgraph: Option<(GraphIdx, u32)>,
    /// Parent counts, walking through disabled nodes (see `NextNodes`).
    pub parents: NextNodes,
    /// Child counts, walking through disabled nodes (see `NextNodes`).
    pub children: NextNodes,
    /// Proof-related flags of this node.
    pub proof: ProofReach,
}
385
/// Visibility of a node. `Hidden` and `Disabled` nodes are counted in
/// `GraphStats`. Neighbor walks (`RawInstGraph::neighbors_directed`) pass
/// through `Disabled` nodes instead of yielding them.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeState {
    Disabled,
    Hidden,
    Visible,
}
393
/// Minimum and maximum path length from a node to a root/leaf. Both default
/// to 0.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug, Clone, Copy, Default)]
pub struct Depth {
    /// What is the shortest path to a root/leaf
    pub min: u16,
    /// What is the longest path to a root/leaf
    pub max: u16,
}
402
/// Summary of a node's parents or children.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug, Clone, Default)]
pub struct NextNodes {
    // Issue 4: storing inst children in all nodes huge memory overhead
    // (`cfg(any())` is never true, so this field is compiled out.)
    #[cfg(any())]
    /// What are the immediate next instantiation nodes
    pub insts: FxHashSet<InstIdx>,
    /// How many parents/children does this node have (not-necessarily
    /// instantiation nodes), walking through disabled nodes.
    pub count: u32,
}
414
/// Bit set of proof-related flags attached to each node (`Node::proof`).
/// The default value has no flags set.
#[bitmask(u8)]
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[cfg_attr(feature = "mem_dbg", copy_type)]
#[derive(Default)]
pub enum ProofReach {
    /// Presumably: the node (a proof step) proves `false` — confirm.
    ProvesFalse,
    /// The node depends on a hypothesis. `RawInstGraph::hypotheses` only
    /// walks through nodes carrying this flag.
    UnderHypothesis,
    /// Presumably: some proof step is reachable from this node — confirm.
    ReachesProof,
    /// Presumably: a non-trivial proof step is reachable — confirm.
    ReachesNonTrivialProof,
    /// Presumably: a proof of `false` is reachable — confirm.
    ReachesFalse,

    /// Is this a CDCL dead branch (i.e. all children lead to a contradiction)
    CdclDeadBranch,
}
429
430impl ProofReach {
431    pub fn if_(self, cond: bool) -> Self {
432        cond.then_some(self).unwrap_or_default()
433    }
434
435    pub fn proves_false(self) -> bool {
436        self.contains(ProofReach::ProvesFalse)
437    }
438    pub fn under_hypothesis(self) -> bool {
439        self.contains(ProofReach::UnderHypothesis)
440    }
441    pub fn reaches_proof(self) -> bool {
442        self.contains(ProofReach::ReachesProof)
443    }
444    pub fn reaches_non_trivial_proof(self) -> bool {
445        self.contains(ProofReach::ReachesNonTrivialProof)
446    }
447    pub fn reaches_false(self) -> bool {
448        self.contains(ProofReach::ReachesFalse)
449    }
450
451    pub fn cdcl_dead_branch(self) -> bool {
452        self.contains(ProofReach::CdclDeadBranch)
453    }
454}
455
456impl Node {
457    fn new(kind: NodeKind) -> Self {
458        Self {
459            state: NodeState::Hidden,
460            cost: 0.0,
461            fwd_depth: Depth::default(),
462            bwd_depth: Depth::default(),
463            subgraph: None,
464            kind,
465            parents: NextNodes::default(),
466            children: NextNodes::default(),
467            proof: ProofReach::default(),
468        }
469    }
470    pub fn kind(&self) -> &NodeKind {
471        &self.kind
472    }
473
474    pub fn disabled(&self) -> bool {
475        matches!(self.state, NodeState::Disabled)
476    }
477    pub fn hidden(&self) -> bool {
478        matches!(self.state, NodeState::Hidden)
479    }
480    pub fn visible(&self) -> bool {
481        matches!(self.state, NodeState::Visible)
482    }
483    pub fn hidden_inst(&self) -> bool {
484        matches!(
485            (self.state, self.kind),
486            (NodeState::Hidden, NodeKind::Instantiation(_))
487        )
488    }
489
490    pub fn frame(&self, parser: &Z3Parser) -> Option<StackIdx> {
491        match *self.kind() {
492            NodeKind::Instantiation(iidx) => Some(parser[iidx].frame),
493            NodeKind::ENode(eidx) => Some(parser[eidx].frame),
494            NodeKind::GivenEquality(..) | NodeKind::TransEquality(_) => None,
495            NodeKind::Proof(psidx) => Some(parser[psidx].frame),
496            NodeKind::Cdcl(cdcl) => Some(parser[cdcl].frame),
497        }
498    }
499}
500
/// What a node of the raw graph stands for.
///
/// The parent/child notes describe the edges created by `RawInstGraph::new`.
/// See `EdgeKind` for the individual edge kinds.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug, Clone, Copy)]
pub enum NodeKind {
    /// Corresponds to `InstIdx`.
    ///
    /// **Parents:** (small) arbitrary count, will always be `ENode` or
    /// `TransEquality`.\
    /// **Children:** (small) arbitrary count: the `ENode`s it yields, plus the
    /// `Proof` step it yields, if any. NOTE(review): `EqualityFact` edges are
    /// added from `parser[eq].iblame` (bound as `iidx`), so `GivenEquality`
    /// children may also occur here — confirm.
    Instantiation(InstIdx),
    /// Corresponds to `ENodeIdx`.
    ///
    /// **Parents:** depends on the enode's blame:
    /// - the yielding `Instantiation` plus the equalities used to yield it;
    /// - or the asserting `Proof`;
    /// - or none.
    ///
    /// **Children:** (large) arbitrary count, will always be either
    /// `Instantiation` or `GivenEquality` of type `EqualityExpl::Literal`
    /// (but see the `EqualityFact` note on `Instantiation`).
    ENode(ENodeIdx),
    /// Corresponds to `EqGivenIdx`. The second field is the use index for an
    /// `EqualityExpl::Congruence` equality, which gets one node per use. It
    /// is `None` for every other kind of given equality.
    ///
    /// **Parents:**
    /// - a `Literal` has 0 or 1 parents, its blame (`EdgeKind::EqualityFact`);
    /// - a `Congruence` use has one parent per argument equality of that use
    ///   (`EdgeKind::EqualityCongruence`);
    /// - other kinds have none.
    ///
    /// **Children:** (small) arbitrary count: `TransEquality` nodes and, per
    /// `EdgeKind::YieldEq`, possibly `ENode`s.
    GivenEquality(EqGivenIdx, Option<NonMaxU32>),
    /// Corresponds to `EqTransIdx`.
    ///
    /// **Parents:** (small) arbitrary count, will always be `GivenEquality` or
    /// `TransEquality`. The number of immediately reachable `GivenEquality` can
    /// be found in `TransitiveExpl::given_len`.\
    /// **Children:** (large) arbitrary count, can be `GivenEquality`,
    /// `TransEquality` or `Instantiation`.
    TransEquality(EqTransIdx),
    /// Corresponds to `ProofIdx`.
    ///
    /// **Parents:** (large) arbitrary count, will always be `Proof` or
    /// `Instantiation`.\
    /// **Children:** (small?) arbitrary count: `Proof` steps, plus any
    /// `ENode`s it asserts (`EdgeKind::Asserted`).
    Proof(ProofIdx),
    /// Corresponds to `CdclIdx`. Only connected to other `Cdcl` nodes.
    ///
    /// **Parents:** will always have between 0 and 2 parents, if 2 then only
    /// one is a real edge and the other is a backtracking edge.\
    /// **Children:** (generally small) arbitrary count, depends on how many
    /// times we backtracked here.
    Cdcl(CdclIdx),
}
546
547impl fmt::Display for NodeKind {
548    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
549        match self {
550            NodeKind::ENode(enode) => write!(f, "{enode}"),
551            NodeKind::GivenEquality(eq, use_) => write!(
552                f,
553                "{eq}{}",
554                use_.filter(|u| *u != NonMaxU32::ZERO)
555                    .map(|u| format!("[{u}]"))
556                    .unwrap_or_default()
557            ),
558            NodeKind::TransEquality(eq) => write!(f, "{eq}"),
559            NodeKind::Instantiation(inst) => write!(f, "{inst}"),
560            NodeKind::Proof(ps) => write!(f, "{ps}"),
561            NodeKind::Cdcl(cdcl) => write!(f, "{cdcl}"),
562        }
563    }
564}
565
566impl NodeKind {
567    pub fn inst(&self) -> Option<InstIdx> {
568        match self {
569            Self::Instantiation(inst) => Some(*inst),
570            _ => None,
571        }
572    }
573    pub fn enode(&self) -> Option<ENodeIdx> {
574        match self {
575            Self::ENode(enode) => Some(*enode),
576            _ => None,
577        }
578    }
579    pub fn eq_given(&self) -> Option<(EqGivenIdx, Option<NonMaxU32>)> {
580        match self {
581            Self::GivenEquality(eq, use_) => Some((*eq, *use_)),
582            _ => None,
583        }
584    }
585    pub fn eq_trans(&self) -> Option<EqTransIdx> {
586        match self {
587            Self::TransEquality(eq) => Some(*eq),
588            _ => None,
589        }
590    }
591    pub fn proof(&self) -> Option<ProofIdx> {
592        match self {
593            Self::Proof(ps) => Some(*ps),
594            _ => None,
595        }
596    }
597    pub fn cdcl(&self) -> Option<CdclIdx> {
598        match self {
599            Self::Cdcl(cdcl) => Some(*cdcl),
600            _ => None,
601        }
602    }
603
604    /// Same as `reconnect_parents` but for children. Do we reconnect hidden
605    /// children of this visible node or just this node itself?
606    pub fn reconnect_child(&self, _child: &Self) -> bool {
607        // TODO: what behavior do we want here?
608        #[cfg(any())]
609        !matches!(
610            (self, child),
611            (
612                Self::ENode(..) | Self::TransEquality(..),
613                Self::Instantiation(..)
614            ) | (Self::Proof(..), Self::Proof(..))
615        );
616        false
617    }
618}
619
/// Kind of a directed edge in the raw graph. Each variant's doc gives the
/// `source -> target` node kinds.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[cfg_attr(feature = "mem_dbg", copy_type)]
#[derive(Debug, Clone, Copy)]
pub enum EdgeKind {
    /// Instantiation -> ENode
    Yield,
    /// GivenEquality -> ENode
    YieldEq,
    /// Proof (asserted) -> ENode
    Asserted,
    /// ENode -> Instantiation. `pattern_term` is the position of the blamed
    /// match within the match's `pattern_matches()`. It is truncated to `u16`.
    Blame { pattern_term: u16 },
    /// TransEquality -> Instantiation. `pattern_term` has the same meaning as
    /// for `Blame`. `eq_order` is the equality's position in that blame's
    /// `equalities`.
    BlameEq { pattern_term: u16, eq_order: u16 },
    /// ENode -> GivenEquality (`EqualityExpl::Literal`)
    ///
    /// NOTE(review): `RawInstGraph::new` adds this edge from
    /// `parser[eq].iblame` (bound as `iidx`). That suggests the source is an
    /// `Instantiation` rather than an `ENode` — confirm and update.
    EqualityFact,
    /// TransEquality -> GivenEquality (`EqualityExpl::Congruence`)
    EqualityCongruence,
    /// GivenEquality -> TransEquality (`TransitiveExplSegmentKind::Given`).
    /// `forward` is copied from the explanation segment.
    TEqualitySimple { forward: bool },
    /// TransEquality -> TransEquality (`TransitiveExplSegmentKind::Transitive`).
    /// `forward` is copied from the explanation segment.
    TEqualityTransitive { forward: bool },
    /// Proof -> Proof (prerequisite -> the step that uses it)
    ProofStep,
    /// Instantiation -> Proof
    YieldProof,
    /// Cdcl -> Cdcl
    Cdcl(CdclEdge),
}
649
/// Kind of an edge between two `Cdcl` nodes.
///
/// For each CDCL node, `RawInstGraph::new` inspects the node's backlink:
/// - only a `previous` link: a `Decide` edge;
/// - only a `backtrack` link: a `Sidetrack` edge;
/// - both links: a `Backtrack` edge from `previous` and a `RetryFrom` edge
///   from `backtrack`.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[cfg_attr(feature = "mem_dbg", copy_type)]
#[derive(Debug, Clone, Copy)]
pub enum CdclEdge {
    /// Edge deeper into the CDCL tree
    Decide,
    /// Edge back to a higher level in the tree
    Backtrack,
    /// Edge to a side branch which may later be popped by the user.
    Sidetrack,
    /// Edge linking a backtracked node to the correct place in the tree.
    RetryFrom,
}
663
664pub trait IndexesInstGraph {
665    fn index(&self, graph: &RawInstGraph) -> RawNodeIndex;
666}
667impl IndexesInstGraph for InstIdx {
668    fn index(&self, graph: &RawInstGraph) -> RawNodeIndex {
669        graph.inst_to_raw_idx()(*self)
670    }
671}
672impl IndexesInstGraph for ENodeIdx {
673    fn index(&self, graph: &RawInstGraph) -> RawNodeIndex {
674        RawNodeIndex(NodeIndex::new(
675            graph.enode_idx.0.index() + usize::from(*self),
676        ))
677    }
678}
679impl IndexesInstGraph for EqTransIdx {
680    fn index(&self, graph: &RawInstGraph) -> RawNodeIndex {
681        RawNodeIndex(NodeIndex::new(
682            graph.eq_trans_idx.0.index() + usize::from(*self),
683        ))
684    }
685}
686impl IndexesInstGraph for (EqGivenIdx, Option<NonMaxU32>) {
687    fn index(&self, graph: &RawInstGraph) -> RawNodeIndex {
688        graph.eq_given_idx[self]
689    }
690}
691impl IndexesInstGraph for ProofIdx {
692    fn index(&self, graph: &RawInstGraph) -> RawNodeIndex {
693        RawNodeIndex(NodeIndex::new(
694            graph.proofs_idx.0.index() + usize::from(*self),
695        ))
696    }
697}
698impl IndexesInstGraph for CdclIdx {
699    fn index(&self, graph: &RawInstGraph) -> RawNodeIndex {
700        RawNodeIndex(NodeIndex::new(
701            graph.cdcl_idx.0.index() + usize::from(*self),
702        ))
703    }
704}
705impl IndexesInstGraph for RawNodeIndex {
706    fn index(&self, _graph: &RawInstGraph) -> RawNodeIndex {
707        *self
708    }
709}
710
711impl<T: IndexesInstGraph> Index<T> for RawInstGraph {
712    type Output = Node;
713    fn index(&self, index: T) -> &Self::Output {
714        let index = index.index(self);
715        &self.graph[index.0]
716    }
717}
718impl<T: IndexesInstGraph> IndexMut<T> for RawInstGraph {
719    fn index_mut(&mut self, index: T) -> &mut Self::Output {
720        let index = index.index(self);
721        &mut self.graph[index.0]
722    }
723}
724
725impl Index<RawEdgeIndex> for RawInstGraph {
726    type Output = EdgeKind;
727    fn index(&self, index: RawEdgeIndex) -> &Self::Output {
728        &self.graph[index.0]
729    }
730}
731
732#[derive(Clone)]
733pub struct Neighbors<'a> {
734    pub raw: &'a RawInstGraph,
735    pub reach: &'a TiVec<RawNodeIndex, ReachNonDisabled>,
736    pub walk: WalkNeighbors,
737}
738
739impl<'a> Neighbors<'a> {
740    pub fn detach(self) -> WalkNeighbors {
741        self.walk
742    }
743
744    pub fn count_hidden(self) -> usize {
745        let raw = self.raw;
746        self.filter(|&ix| raw[ix].hidden()).count()
747    }
748}
749
750impl Iterator for Neighbors<'_> {
751    type Item = RawNodeIndex;
752    fn next(&mut self) -> Option<Self::Item> {
753        self.walk.next(self.raw, self.reach)
754    }
755}
756
757#[derive(Clone)]
758pub struct WalkNeighbors {
759    pub dir: Direction,
760    pub visited: FxHashSet<RawNodeIndex>,
761    pub stack: Vec<RawNodeIndex>,
762    pub direct: petgraph::graph::WalkNeighbors<RawIx>,
763}
764
765impl WalkNeighbors {
766    fn next_direct(
767        &mut self,
768        raw: &RawInstGraph,
769        reach: &TiVec<RawNodeIndex, ReachNonDisabled>,
770    ) -> Option<RawNodeIndex> {
771        while let Some((_, n)) = self.direct.next(&raw.graph) {
772            let n = RawNodeIndex(n);
773            let skip = reach.get(n).is_some_and(|v| !v.value());
774            if !skip {
775                return Some(n);
776            }
777        }
778        None
779    }
780
781    pub fn next(
782        &mut self,
783        raw: &RawInstGraph,
784        reach: &TiVec<RawNodeIndex, ReachNonDisabled>,
785    ) -> Option<RawNodeIndex> {
786        // TODO: decide if we want to prevent direct neighbors from being
787        // visited multiple times if there are multiple direct edges.
788        loop {
789            // let mut idx = None;
790            // while let Some((_, direct)) = self.direct.next(&raw.graph) {
791            //     let direct = RawNodeIndex(direct);
792            //     if self.visited.insert(direct) {
793            //         idx = Some(direct);
794            //         break;
795            //     }
796            // }
797            // let idx = idx.or_else(|| self.stack.pop())?;
798            let idx = self.next_direct(raw, reach).or_else(|| self.stack.pop())?;
799            let node = &raw[idx];
800            if !node.disabled() {
801                return Some(idx);
802            }
803            for n in raw.graph.neighbors_directed(idx.0, self.dir) {
804                let n = RawNodeIndex(n);
805                if self.visited.insert(n) {
806                    self.stack.push(n);
807                }
808            }
809        }
810    }
811}