smt_scope/analysis/graph/raw.rs

1use std::{
2    fmt,
3    ops::{Index, IndexMut},
4};
5
6use bitmask_enum::bitmask;
7#[cfg(feature = "mem_dbg")]
8use mem_dbg::{MemDbg, MemSize};
9use petgraph::{
10    graph::NodeIndex,
11    visit::{Dfs, NodeFiltered, Reversed, Walker},
12    Direction::{self},
13};
14
15use crate::{
16    graph_idx,
17    items::{
18        CdclIdx, ENodeBlame, ENodeIdx, EqGivenIdx, EqTransIdx, EqualityExpl, GraphIdx, InstIdx,
19        ProofIdx, StackIdx, TransitiveExplSegmentKind,
20    },
21    DiGraph, FxHashMap, FxHashSet, NonMaxU32, Result, TiVec, Z3Parser,
22};
23
24use super::analysis::reconnect::{ReachKind, ReachNonDisabled};
25
26graph_idx!(raw_idx, RawNodeIndex, RawEdgeIndex, RawIx);
27
28#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
29#[derive(Debug)]
30pub struct RawInstGraph {
31    pub graph: DiGraph<Node, EdgeKind, RawIx>,
32    inst_idx: RawNodeIndex,
33    enode_idx: RawNodeIndex,
34    eq_trans_idx: RawNodeIndex,
35    eq_given_idx: FxHashMap<(EqGivenIdx, Option<NonMaxU32>), RawNodeIndex>,
36    proofs_idx: RawNodeIndex,
37    cdcl_idx: RawNodeIndex,
38
39    pub(crate) stats: GraphStats,
40}
41
42struct GraphTryReserve;
43impl GraphTryReserve {
44    fn try_reserve_exact(nodes: usize, edges: usize) -> Result<()> {
45        type Nodes = Vec<petgraph::graph::Node<Node>>;
46        type Edges = Vec<petgraph::graph::Edge<EdgeKind>>;
47        let mut n = Nodes::new();
48        n.try_reserve_exact(nodes)?;
49        let mut e = Edges::new();
50        e.try_reserve_exact(edges)?;
51        Ok(())
52    }
53}
54
55impl RawInstGraph {
56    pub fn new(parser: &Z3Parser) -> Result<Self> {
57        let total_nodes = parser.insts.insts.len()
58            + parser.egraph.enodes.len()
59            + parser.egraph.equalities.given.len()
60            + parser.egraph.equalities.transitive.len()
61            + parser.proofs().len()
62            + parser.cdcls().len();
63        let edges_estimate = parser.insts.insts.len()
64            + parser.egraph.equalities.transitive.len()
65            + parser.proofs().len();
66        GraphTryReserve::try_reserve_exact(total_nodes, edges_estimate)?;
67
68        let mut graph = DiGraph::with_capacity(total_nodes, edges_estimate);
69        let inst_idx = RawNodeIndex(NodeIndex::new(graph.node_count()));
70        for inst in parser.insts.insts.keys() {
71            graph.add_node(Node::new(NodeKind::Instantiation(inst)));
72        }
73        let enode_idx = RawNodeIndex(NodeIndex::new(graph.node_count()));
74        for enode in parser.egraph.enodes.keys() {
75            graph.add_node(Node::new(NodeKind::ENode(enode)));
76        }
77        let eq_trans_idx = RawNodeIndex(NodeIndex::new(graph.node_count()));
78        for eq_trans in parser.egraph.equalities.transitive.keys() {
79            graph.add_node(Node::new(NodeKind::TransEquality(eq_trans)));
80        }
81        let mut eq_given_idx = FxHashMap::default();
82        eq_given_idx.try_reserve(parser.egraph.equalities.given.len())?;
83        for (eq_given, eq) in parser.egraph.equalities.given.iter_enumerated() {
84            match eq {
85                EqualityExpl::Congruence { uses, .. } => {
86                    for i in 0..uses.len() {
87                        let use_ = Some(NonMaxU32::new(i as u32).unwrap());
88                        let node =
89                            graph.add_node(Node::new(NodeKind::GivenEquality(eq_given, use_)));
90                        eq_given_idx.insert((eq_given, use_), RawNodeIndex(node));
91                    }
92                }
93                _ => {
94                    let node = graph.add_node(Node::new(NodeKind::GivenEquality(eq_given, None)));
95                    eq_given_idx.insert((eq_given, None), RawNodeIndex(node));
96                }
97            }
98        }
99        let proofs_idx = RawNodeIndex(NodeIndex::new(graph.node_count()));
100        for ps_idx in parser.proofs().keys() {
101            graph.add_node(Node::new(NodeKind::Proof(ps_idx)));
102        }
103        let cdcl_idx = RawNodeIndex(NodeIndex::new(graph.node_count()));
104        for cdcl in parser.cdcls().keys() {
105            graph.add_node(Node::new(NodeKind::Cdcl(cdcl)));
106        }
107
108        let stats = GraphStats {
109            hidden: graph.node_count() as u32,
110            disabled: 0,
111            generation: 0,
112        };
113        let mut self_ = RawInstGraph {
114            graph,
115            inst_idx,
116            enode_idx,
117            eq_given_idx,
118            eq_trans_idx,
119            proofs_idx,
120            cdcl_idx,
121            stats,
122        };
123
124        // Add instantiation blamed edges
125        for (idx, inst) in parser.insts.insts.iter_enumerated() {
126            let match_ = &parser.insts.matches[inst.match_];
127            // if let Some(blame) = match_.kind.quant_idx().and_then(|qidx| parser[qidx].blame) {
128            //     self_.add_edge(blame, idx, EdgeKind::AssertedQuant);
129            // }
130
131            for (i, blame) in match_.pattern_matches().enumerate() {
132                let pattern_term = i as u16;
133                self_.add_edge(blame.enode, idx, EdgeKind::Blame { pattern_term });
134                for (i, eq) in blame.equalities.iter().enumerate() {
135                    self_.add_edge(
136                        *eq,
137                        idx,
138                        EdgeKind::BlameEq {
139                            pattern_term,
140                            eq_order: i as u16,
141                        },
142                    );
143                }
144            }
145        }
146
147        // Add enode blamed edges
148        for (idx, enode) in parser.egraph.enodes.iter_enumerated() {
149            match &enode.blame {
150                ENodeBlame::Inst((iidx, eqs)) => {
151                    self_.add_edge(*iidx, idx, EdgeKind::Yield);
152                    for &eq in eqs.iter() {
153                        self_.add_edge(eq, idx, EdgeKind::YieldEq);
154                    }
155                }
156                ENodeBlame::Proof(pidx) => self_.add_edge(*pidx, idx, EdgeKind::Asserted),
157                ENodeBlame::BoolConst | ENodeBlame::Unknown => (),
158            }
159        }
160
161        // Add given equality created edges
162        for (idx, eq) in parser.egraph.equalities.given.iter_enumerated() {
163            match eq {
164                EqualityExpl::Root { .. } => (),
165                EqualityExpl::Literal { eq, .. } => {
166                    let literal = &parser[*eq];
167                    if let Some(enode) = literal.enode {
168                        self_.add_edge(enode, (idx, None), EdgeKind::EqualityFact)
169                    }
170                    if let Some(iidx) = literal.iblame {
171                        let enode_blame =
172                            literal.enode.and_then(|enode| parser[enode].blame.inst());
173                        if enode_blame.is_none_or(|enode_blame| enode_blame != iidx) {
174                            self_.add_edge(iidx, (idx, None), EdgeKind::EqualityProof)
175                        }
176                    }
177                }
178                EqualityExpl::Congruence { uses, .. } => {
179                    for (use_, arg_eqs) in uses.iter().enumerate() {
180                        let use_ = Some(NonMaxU32::new(use_ as u32).unwrap());
181                        for arg_eq in arg_eqs.iter() {
182                            self_.add_edge(*arg_eq, (idx, use_), EdgeKind::EqualityCongruence);
183                        }
184                    }
185                }
186                EqualityExpl::Theory { .. } => (),
187                EqualityExpl::Axiom { .. } => (),
188                EqualityExpl::Unknown { .. } => (),
189            }
190        }
191
192        // Add transitive equality created edges
193        for (idx, eq) in parser.egraph.equalities.transitive.iter_enumerated() {
194            let all = eq.all(true);
195            for parent in all {
196                use TransitiveExplSegmentKind::*;
197                match parent.kind {
198                    Given((eq, use_)) => self_.add_edge(
199                        (eq, use_),
200                        idx,
201                        EdgeKind::TEqualitySimple {
202                            forward: parent.forward,
203                        },
204                    ),
205                    Transitive(eq) => self_.add_edge(
206                        eq,
207                        idx,
208                        EdgeKind::TEqualityTransitive {
209                            forward: parent.forward,
210                        },
211                    ),
212                    Error(..) => (),
213                }
214            }
215        }
216
217        // Add proof step edges
218        for (idx, ps) in parser.proofs().iter_enumerated() {
219            for pre in ps.prerequisites.iter() {
220                self_.add_edge(*pre, idx, EdgeKind::ProofStep)
221            }
222        }
223
224        for (iidx, inst) in parser.insts.insts.iter_enumerated() {
225            let Some(proof) = inst.kind.proof() else {
226                continue;
227            };
228            self_.add_edge(iidx, proof, EdgeKind::YieldProof);
229        }
230
231        // Add cdcl edges
232        for cidx in parser.cdcls().keys() {
233            let backlink = parser.lits.cdcl.backlink(cidx);
234            match (backlink.previous, backlink.backtrack) {
235                (Some(previous), Some(backtrack)) => {
236                    self_.add_edge(backtrack, cidx, EdgeKind::Cdcl(CdclEdge::RetryFrom));
237                    self_.add_edge(previous, cidx, EdgeKind::Cdcl(CdclEdge::Backtrack));
238                }
239                (Some(previous), None) => {
240                    self_.add_edge(previous, cidx, EdgeKind::Cdcl(CdclEdge::Decide))
241                }
242                (None, Some(sidetrack)) => {
243                    self_.add_edge(sidetrack, cidx, EdgeKind::Cdcl(CdclEdge::Sidetrack))
244                }
245                (None, None) => (),
246            }
247        }
248
249        debug_assert!(
250            !petgraph::algo::is_cyclic_directed(&*self_.graph),
251            "Graph is cyclic, this should not happen by construction!"
252        );
253        Ok(self_)
254    }
255    fn add_edge(
256        &mut self,
257        source: impl IndexesInstGraph,
258        target: impl IndexesInstGraph,
259        kind: EdgeKind,
260    ) {
261        let a = source.index(self).0;
262        let b = target.index(self).0;
263        self.graph.add_edge(a, b, kind);
264    }
265
266    pub fn index(&self, kind: NodeKind) -> RawNodeIndex {
267        match kind {
268            NodeKind::ENode(enode) => enode.index(self),
269            NodeKind::GivenEquality(eq, use_) => (eq, use_).index(self),
270            NodeKind::TransEquality(eq) => eq.index(self),
271            NodeKind::Instantiation(inst) => inst.index(self),
272            NodeKind::Proof(ps) => ps.index(self),
273            NodeKind::Cdcl(cdcl) => cdcl.index(self),
274        }
275    }
276
277    pub fn rev(&self) -> Reversed<&petgraph::graph::DiGraph<Node, EdgeKind, RawIx>> {
278        Reversed(&*self.graph)
279    }
280
281    pub fn visible_nodes(&self) -> usize {
282        self.graph.node_count() - self.stats.hidden as usize - self.stats.disabled as usize
283    }
284    pub fn node_indices(&self) -> impl Iterator<Item = RawNodeIndex> {
285        self.graph.node_indices().map(RawNodeIndex)
286    }
287
288    /// Similar to `self.graph.neighbors` but will walk through disabled nodes.
289    ///
290    /// Note: Iterating the neighbors is **not** a O(1) operation.
291    pub fn neighbors<'a>(
292        &'a self,
293        node: RawNodeIndex,
294        reach: &'a TiVec<RawNodeIndex, ReachNonDisabled>,
295    ) -> Neighbors<'a> {
296        self.neighbors_directed(node, Direction::Outgoing, reach)
297    }
298
299    /// Similar to `self.graph.neighbors_directed` but will walk through
300    /// disabled nodes.
301    ///
302    /// Note: Iterating the neighbors is **not** a O(1) operation.
303    pub fn neighbors_directed<'a>(
304        &'a self,
305        node: RawNodeIndex,
306        dir: Direction,
307        reach: &'a TiVec<RawNodeIndex, ReachNonDisabled>,
308    ) -> Neighbors<'a> {
309        let direct = self.graph.neighbors_directed(node.0, dir).detach();
310        let walk = WalkNeighbors {
311            dir,
312            visited: FxHashSet::default(),
313            stack: Vec::new(),
314            direct,
315        };
316        Neighbors {
317            raw: self,
318            reach,
319            walk,
320        }
321    }
322
323    pub fn inst_to_raw_idx(&self) -> impl Fn(InstIdx) -> RawNodeIndex {
324        let inst_idx = self.inst_idx;
325        move |inst| RawNodeIndex(NodeIndex::new(inst_idx.0.index() + usize::from(inst)))
326    }
327
328    pub fn hypotheses(&self, parser: &Z3Parser, proof: ProofIdx) -> Vec<ProofIdx> {
329        let proof = proof.index(self);
330        let node = &self[proof];
331        if !node.proof.under_hypothesis() {
332            return Default::default();
333        }
334        let mut hypotheses = Vec::new();
335        let graph = NodeFiltered::from_fn(&*self.graph, |n| self.graph[n].proof.under_hypothesis());
336        let dfs = Dfs::new(Reversed(&graph), proof.0);
337        for n in dfs.iter(Reversed(&graph)).map(RawNodeIndex) {
338            let Some(n) = self[n].kind().proof() else {
339                debug_assert!(false, "Expected proof node");
340                continue;
341            };
342            if parser[n].kind.is_hypothesis() {
343                hypotheses.push(n);
344            }
345        }
346        hypotheses.sort_unstable();
347        hypotheses
348    }
349}
350
351#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
352#[derive(Debug)]
353pub struct GraphStats {
354    pub hidden: u32,
355    pub disabled: u32,
356    /// How many times has the visibility of nodes been changed?
357    /// Used to keep track of if the hidden graph needs to be recalculated.
358    pub generation: u32,
359}
360
361impl GraphStats {
362    pub fn set_state(&mut self, node: &mut Node, state: NodeState) -> bool {
363        if node.state == state {
364            return false;
365        }
366        self.generation = self.generation.wrapping_add(1);
367        match node.state {
368            NodeState::Disabled => self.disabled -= 1,
369            NodeState::Hidden => self.hidden -= 1,
370            _ => (),
371        }
372        match state {
373            NodeState::Disabled => self.disabled += 1,
374            NodeState::Hidden => self.hidden += 1,
375            _ => (),
376        }
377        node.state = state;
378        true
379    }
380}
381
382#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
383#[derive(Debug, Clone)]
384pub struct Node {
385    state: NodeState,
386    kind: NodeKind,
387    pub cost: f64,
388    pub fwd_depth: Depth,
389    pub bwd_depth: Depth,
390    pub subgraph: Option<(GraphIdx, u32)>,
391    pub parents: NextNodes,
392    pub children: NextNodes,
393    pub proof: ProofReach,
394}
395
/// Visibility state of a graph [`Node`].
///
/// Nodes start out `Hidden`. Neighbor walks step through `Disabled` nodes
/// rather than reporting them.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeState {
    /// Walked through transparently by `RawInstGraph::neighbors`.
    Disabled,
    /// Present but not shown; the initial state of every node.
    Hidden,
    /// Shown.
    Visible,
}
403
/// Shortest and longest path length from a node to a root/leaf.
///
/// Defaults to zero for both bounds.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug, Clone, Copy, Default)]
pub struct Depth {
    /// Length of the shortest path to a root/leaf.
    pub min: u16,
    /// Length of the longest path to a root/leaf.
    pub max: u16,
}
412
/// Summary of a node's neighbors in one direction (parents or children).
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[derive(Debug, Clone, Default)]
pub struct NextNodes {
    // Issue 4: storing instantiation children in every node has a huge
    // memory overhead, so this field is compiled out for now.
    #[cfg(any())]
    /// What are the immediate next instantiation nodes
    pub insts: FxHashSet<InstIdx>,
    /// Number of parents/children of this node (of any kind, not only
    /// instantiations), counted while walking through disabled nodes.
    pub count: u32,
}
424
425#[bitmask(u8)]
426#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
427#[cfg_attr(feature = "mem_dbg", copy_type)]
428#[derive(Default)]
429pub enum ProofReach {
430    ProvesFalse,
431    UnderHypothesis,
432    ReachesProof,
433    ReachesNonTrivialProof,
434    ReachesFalse,
435
436    /// Is this a CDCL dead branch (i.e. all children lead to a contradiction)
437    CdclDeadBranch,
438}
439
440impl ProofReach {
441    pub fn if_(self, cond: bool) -> Self {
442        if cond {
443            self
444        } else {
445            Default::default()
446        }
447    }
448
449    pub fn proves_false(self) -> bool {
450        self.contains(ProofReach::ProvesFalse)
451    }
452    pub fn under_hypothesis(self) -> bool {
453        self.contains(ProofReach::UnderHypothesis)
454    }
455    pub fn reaches_proof(self) -> bool {
456        self.contains(ProofReach::ReachesProof)
457    }
458    pub fn reaches_non_trivial_proof(self) -> bool {
459        self.contains(ProofReach::ReachesNonTrivialProof)
460    }
461    pub fn reaches_false(self) -> bool {
462        self.contains(ProofReach::ReachesFalse)
463    }
464
465    pub fn cdcl_dead_branch(self) -> bool {
466        self.contains(ProofReach::CdclDeadBranch)
467    }
468}
469
470impl Node {
471    fn new(kind: NodeKind) -> Self {
472        Self {
473            state: NodeState::Hidden,
474            cost: 0.0,
475            fwd_depth: Depth::default(),
476            bwd_depth: Depth::default(),
477            subgraph: None,
478            kind,
479            parents: NextNodes::default(),
480            children: NextNodes::default(),
481            proof: ProofReach::default(),
482        }
483    }
484    pub fn kind(&self) -> &NodeKind {
485        &self.kind
486    }
487
488    pub fn disabled(&self) -> bool {
489        matches!(self.state, NodeState::Disabled)
490    }
491    pub fn hidden(&self) -> bool {
492        matches!(self.state, NodeState::Hidden)
493    }
494    pub fn visible(&self) -> bool {
495        matches!(self.state, NodeState::Visible)
496    }
497    pub fn hidden_inst(&self) -> bool {
498        matches!(
499            (self.state, self.kind),
500            (NodeState::Hidden, NodeKind::Instantiation(_))
501        )
502    }
503
504    pub fn frame(&self, parser: &Z3Parser) -> Option<StackIdx> {
505        match *self.kind() {
506            NodeKind::Instantiation(iidx) => Some(parser[iidx].frame),
507            NodeKind::ENode(eidx) => Some(parser[eidx].frame),
508            NodeKind::GivenEquality(..) | NodeKind::TransEquality(_) => None,
509            NodeKind::Proof(psidx) => Some(parser[psidx].frame),
510            NodeKind::Cdcl(cdcl) => Some(parser[cdcl].frame),
511        }
512    }
513}
514
515#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
516#[derive(Debug, Clone, Copy)]
517pub enum NodeKind {
518    /// Corresponds to `InstIdx`.
519    ///
520    /// **Parents:** (small) arbitrary count, will always be `ENode` or
521    /// `TransEquality`.\
522    /// **Children:** (small) arbitrary count, will always be `ENode`.
523    Instantiation(InstIdx),
524    /// Corresponds to `ENodeIdx`.
525    ///
526    /// **Parents:** will always have 0 or 1 parents, if 1 then this will be an `Instantiation`.\
527    /// **Children:** (large) arbitrary count, will always be either
528    /// `Instantiation` or `GivenEquality` of type `EqualityExpl::Literal`.
529    ENode(ENodeIdx),
530    /// Corresponds to `EqGivenIdx`.
531    ///
532    /// **Parents:** will always have 0 or 1 parents, if 1 then this will be an
533    /// `ENode` or a `TransEquality` depending on if it's a `Literal` or
534    /// `Congruence` resp.\
535    /// **Children:** (small) arbitrary count, will always be `TransEquality` of
536    /// type.
537    GivenEquality(EqGivenIdx, Option<NonMaxU32>),
538    /// Corresponds to `EqTransIdx`.
539    ///
540    /// **Parents:** (small) arbitrary count, will always be `GivenEquality` or
541    /// `TransEquality`. The number of immediately reachable `GivenEquality` can
542    /// be found in `TransitiveExpl::given_len`.\
543    /// **Children:** (large) arbitrary count, can be `GivenEquality`,
544    /// `TransEquality` or `Instantiation`.
545    TransEquality(EqTransIdx),
546    /// Corresponds to `ProofIdx`.
547    ///
548    /// **Parents:** (large) arbitrary count, will always be `Proof` or
549    /// `Instantiation`.
550    /// **Children:** (small?) arbitrary count, will always be `Proof`.
551    Proof(ProofIdx),
552    /// Corresponds to `CdclIdx`. Only connected to other `Cdcl` nodes.
553    ///
554    /// **Parents:** will always have between 0 and 2 parents, if 2 then only
555    /// one is a real edge and the other is a backtracking edge.
556    /// **Children:** (generally small) arbitrary count, depends on how many
557    /// times we backtracked here.
558    Cdcl(CdclIdx),
559}
560
561impl fmt::Display for NodeKind {
562    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
563        match self {
564            NodeKind::ENode(enode) => write!(f, "{enode}"),
565            NodeKind::GivenEquality(eq, use_) => write!(
566                f,
567                "{eq}{}",
568                use_.filter(|u| *u != NonMaxU32::ZERO)
569                    .map(|u| format!("[{u}]"))
570                    .unwrap_or_default()
571            ),
572            NodeKind::TransEquality(eq) => write!(f, "{eq}"),
573            NodeKind::Instantiation(inst) => write!(f, "{inst}"),
574            NodeKind::Proof(ps) => write!(f, "{ps}"),
575            NodeKind::Cdcl(cdcl) => write!(f, "{cdcl}"),
576        }
577    }
578}
579
580impl NodeKind {
581    pub fn inst(&self) -> Option<InstIdx> {
582        match self {
583            Self::Instantiation(inst) => Some(*inst),
584            _ => None,
585        }
586    }
587    pub fn enode(&self) -> Option<ENodeIdx> {
588        match self {
589            Self::ENode(enode) => Some(*enode),
590            _ => None,
591        }
592    }
593    pub fn eq_given(&self) -> Option<(EqGivenIdx, Option<NonMaxU32>)> {
594        match self {
595            Self::GivenEquality(eq, use_) => Some((*eq, *use_)),
596            _ => None,
597        }
598    }
599    pub fn eq_trans(&self) -> Option<EqTransIdx> {
600        match self {
601            Self::TransEquality(eq) => Some(*eq),
602            _ => None,
603        }
604    }
605    pub fn proof(&self) -> Option<ProofIdx> {
606        match self {
607            Self::Proof(ps) => Some(*ps),
608            _ => None,
609        }
610    }
611    pub fn cdcl(&self) -> Option<CdclIdx> {
612        match self {
613            Self::Cdcl(cdcl) => Some(*cdcl),
614            _ => None,
615        }
616    }
617
618    /// Same as `reconnect_parents` but for children. Do we reconnect hidden
619    /// children of this visible node or just this node itself?
620    pub fn reconnect_child(&self, _child: &Self) -> bool {
621        // TODO: what behavior do we want here?
622        #[cfg(any())]
623        !matches!(
624            (self, child),
625            (
626                Self::ENode(..) | Self::TransEquality(..),
627                Self::Instantiation(..)
628            ) | (Self::Proof(..), Self::Proof(..))
629        );
630        false
631    }
632}
633
634#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
635#[cfg_attr(feature = "mem_dbg", copy_type)]
636#[derive(Debug, Clone, Copy)]
637pub enum EdgeKind {
638    /// Instantiation -> ENode
639    Yield,
640    /// GivenEquality -> ENode
641    YieldEq,
642    /// Proof (asserted) -> ENode
643    Asserted,
644    /// ENode -> Instantiation
645    Blame { pattern_term: u16 },
646    /// TransEquality -> Instantiation
647    BlameEq { pattern_term: u16, eq_order: u16 },
648    /// ENode -> GivenEquality (`EqualityExpl::Literal`)
649    EqualityFact,
650    /// Instantiation -> GivenEquality (`EqualityExpl::Literal`)
651    EqualityProof,
652    /// TransEquality -> GivenEquality (`EqualityExpl::Congruence`)
653    EqualityCongruence,
654    /// GivenEquality -> TransEquality (`TransitiveExplSegmentKind::Leaf`)
655    TEqualitySimple { forward: bool },
656    /// TransEquality -> TransEquality (`TransitiveExplSegmentKind::Transitive`)
657    TEqualityTransitive { forward: bool },
658    /// Proof -> Proof
659    ProofStep,
660    /// Instantiation -> Proof
661    YieldProof,
662    /// Cdcl -> Cdcl
663    Cdcl(CdclEdge),
664}
665
/// Sub-kind of an [`EdgeKind::Cdcl`] edge between two CDCL nodes.
#[cfg_attr(feature = "mem_dbg", derive(MemSize, MemDbg))]
#[cfg_attr(feature = "mem_dbg", copy_type)]
#[derive(Debug, Clone, Copy)]
pub enum CdclEdge {
    /// Goes one level deeper into the CDCL tree.
    Decide,
    /// Leads back up to a higher level of the tree.
    Backtrack,
    /// Leads into a side branch that the user may pop again later.
    Sidetrack,
    /// Links a node that was backtracked to with its proper place in the tree.
    RetryFrom,
}
679
680pub trait IndexesInstGraph {
681    fn index(&self, graph: &RawInstGraph) -> RawNodeIndex;
682}
683impl IndexesInstGraph for InstIdx {
684    fn index(&self, graph: &RawInstGraph) -> RawNodeIndex {
685        graph.inst_to_raw_idx()(*self)
686    }
687}
688impl IndexesInstGraph for ENodeIdx {
689    fn index(&self, graph: &RawInstGraph) -> RawNodeIndex {
690        RawNodeIndex(NodeIndex::new(
691            graph.enode_idx.0.index() + usize::from(*self),
692        ))
693    }
694}
695impl IndexesInstGraph for EqTransIdx {
696    fn index(&self, graph: &RawInstGraph) -> RawNodeIndex {
697        RawNodeIndex(NodeIndex::new(
698            graph.eq_trans_idx.0.index() + usize::from(*self),
699        ))
700    }
701}
702impl IndexesInstGraph for (EqGivenIdx, Option<NonMaxU32>) {
703    fn index(&self, graph: &RawInstGraph) -> RawNodeIndex {
704        graph.eq_given_idx[self]
705    }
706}
707impl IndexesInstGraph for ProofIdx {
708    fn index(&self, graph: &RawInstGraph) -> RawNodeIndex {
709        RawNodeIndex(NodeIndex::new(
710            graph.proofs_idx.0.index() + usize::from(*self),
711        ))
712    }
713}
714impl IndexesInstGraph for CdclIdx {
715    fn index(&self, graph: &RawInstGraph) -> RawNodeIndex {
716        RawNodeIndex(NodeIndex::new(
717            graph.cdcl_idx.0.index() + usize::from(*self),
718        ))
719    }
720}
721impl IndexesInstGraph for RawNodeIndex {
722    fn index(&self, _graph: &RawInstGraph) -> RawNodeIndex {
723        *self
724    }
725}
726
727impl<T: IndexesInstGraph> Index<T> for RawInstGraph {
728    type Output = Node;
729    fn index(&self, index: T) -> &Self::Output {
730        let index = index.index(self);
731        &self.graph[index.0]
732    }
733}
734impl<T: IndexesInstGraph> IndexMut<T> for RawInstGraph {
735    fn index_mut(&mut self, index: T) -> &mut Self::Output {
736        let index = index.index(self);
737        &mut self.graph[index.0]
738    }
739}
740
741impl Index<RawEdgeIndex> for RawInstGraph {
742    type Output = EdgeKind;
743    fn index(&self, index: RawEdgeIndex) -> &Self::Output {
744        &self.graph[index.0]
745    }
746}
747
748#[derive(Clone)]
749pub struct Neighbors<'a> {
750    pub raw: &'a RawInstGraph,
751    pub reach: &'a TiVec<RawNodeIndex, ReachNonDisabled>,
752    pub walk: WalkNeighbors,
753}
754
755impl<'a> Neighbors<'a> {
756    pub fn detach(self) -> WalkNeighbors {
757        self.walk
758    }
759
760    pub fn count_hidden(self) -> usize {
761        let raw = self.raw;
762        self.filter(|&ix| raw[ix].hidden()).count()
763    }
764}
765
766impl Iterator for Neighbors<'_> {
767    type Item = RawNodeIndex;
768    fn next(&mut self) -> Option<Self::Item> {
769        self.walk.next(self.raw, self.reach)
770    }
771}
772
773#[derive(Clone)]
774pub struct WalkNeighbors {
775    pub dir: Direction,
776    pub visited: FxHashSet<RawNodeIndex>,
777    pub stack: Vec<RawNodeIndex>,
778    pub direct: petgraph::graph::WalkNeighbors<RawIx>,
779}
780
781impl WalkNeighbors {
782    fn next_direct(
783        &mut self,
784        raw: &RawInstGraph,
785        reach: &TiVec<RawNodeIndex, ReachNonDisabled>,
786    ) -> Option<RawNodeIndex> {
787        while let Some((_, n)) = self.direct.next(&raw.graph) {
788            let n = RawNodeIndex(n);
789            let skip = reach.get(n).is_some_and(|v| !v.value());
790            if !skip {
791                return Some(n);
792            }
793        }
794        None
795    }
796
797    pub fn next(
798        &mut self,
799        raw: &RawInstGraph,
800        reach: &TiVec<RawNodeIndex, ReachNonDisabled>,
801    ) -> Option<RawNodeIndex> {
802        // TODO: decide if we want to prevent direct neighbors from being
803        // visited multiple times if there are multiple direct edges.
804        loop {
805            // let mut idx = None;
806            // while let Some((_, direct)) = self.direct.next(&raw.graph) {
807            //     let direct = RawNodeIndex(direct);
808            //     if self.visited.insert(direct) {
809            //         idx = Some(direct);
810            //         break;
811            //     }
812            // }
813            // let idx = idx.or_else(|| self.stack.pop())?;
814            let idx = self.next_direct(raw, reach).or_else(|| self.stack.pop())?;
815            let node = &raw[idx];
816            if !node.disabled() {
817                return Some(idx);
818            }
819            for n in raw.graph.neighbors_directed(idx.0, self.dir) {
820                let n = RawNodeIndex(n);
821                if self.visited.insert(n) {
822                    self.stack.push(n);
823                }
824            }
825        }
826    }
827}