// smt_scope/analysis/graph/analysis/cost.rs

1use petgraph::Direction;
2
3use crate::{
4    analysis::{
5        raw::{Node, NodeKind},
6        RawNodeIndex,
7    },
8    Z3Parser,
9};
10
11use super::run::{Initialiser, TransferInitialiser};
12
13pub trait CostInitialiser<const FORWARD: bool = false> {
14    /// The starting value for a node.
15    fn base(&mut self, node: &Node, parser: &Z3Parser) -> f64;
16    fn assign(&mut self, node: &mut Node, value: f64) {
17        node.cost = value;
18    }
19    /// Called between initialisations of different subgraphs.
20    fn reset(&mut self) {}
21    type Observed;
22    fn observe(&mut self, node: &Node, parser: &Z3Parser) -> Self::Observed;
23    fn transfer(
24        &mut self,
25        from: &Node,
26        from_idx: RawNodeIndex,
27        to_idx: usize,
28        to_all: &[Self::Observed],
29    ) -> f64;
30}
31impl<C: CostInitialiser<FORWARD>, const FORWARD: bool> Initialiser<FORWARD, 0> for C {
32    type Value = f64;
33    fn direction() -> Direction {
34        if FORWARD {
35            Direction::Outgoing
36        } else {
37            Direction::Incoming
38        }
39    }
40    fn base(&mut self, node: &Node, parser: &Z3Parser) -> Self::Value {
41        CostInitialiser::base(self, node, parser)
42    }
43    fn assign(&mut self, node: &mut Node, value: Self::Value) {
44        CostInitialiser::assign(self, node, value)
45    }
46    fn reset(&mut self) {
47        CostInitialiser::reset(self)
48    }
49}
50impl<C: CostInitialiser<FORWARD>, const FORWARD: bool> TransferInitialiser<FORWARD, 0> for C {
51    type Observed = C::Observed;
52    fn observe(&mut self, node: &Node, parser: &Z3Parser) -> Self::Observed {
53        CostInitialiser::observe(self, node, parser)
54    }
55    fn transfer(
56        &mut self,
57        from: &Node,
58        from_idx: RawNodeIndex,
59        to_idx: usize,
60        to_all: &[Self::Observed],
61    ) -> Self::Value {
62        CostInitialiser::transfer(self, from, from_idx, to_idx, to_all)
63    }
64    fn add(&mut self, node: &mut Node, value: Self::Value) {
65        node.cost += value;
66    }
67}
68
69pub struct DefaultCost;
70impl CostInitialiser for DefaultCost {
71    fn base(&mut self, node: &Node, _parser: &Z3Parser) -> f64 {
72        match node.kind() {
73            NodeKind::Instantiation(_) | NodeKind::Cdcl(_) if !node.disabled() => 1.0,
74            _ => 0.0,
75        }
76    }
77    type Observed = usize;
78    fn observe(&mut self, node: &Node, parser: &Z3Parser) -> Self::Observed {
79        match node.kind() {
80            NodeKind::ENode(_) => 1,
81            NodeKind::GivenEquality(_, _) => 1,
82            NodeKind::TransEquality(eq) => {
83                parser[*eq].given_len.map(|l| l.get()).unwrap_or_default()
84            }
85            NodeKind::Instantiation(_) => 1,
86            NodeKind::Proof(_) => 0,
87            NodeKind::Cdcl(_) => 1,
88        }
89    }
90    fn transfer(
91        &mut self,
92        child: &Node,
93        _from_idx: RawNodeIndex,
94        parent_idx: usize,
95        parents: &[Self::Observed],
96    ) -> f64 {
97        let total = parents.iter().sum::<usize>();
98        if total == 0 || child.kind().proof().is_some() {
99            return 0.0;
100        }
101        child.cost * parents[parent_idx] as f64 / total as f64
102    }
103}
104
105pub struct ProofCost;
106impl CostInitialiser<true> for ProofCost {
107    fn base(&mut self, node: &Node, _parser: &Z3Parser) -> f64 {
108        match node.kind() {
109            NodeKind::Proof(_) if !node.disabled() => 1.0,
110            _ => 0.0,
111        }
112    }
113    fn assign(&mut self, node: &mut Node, value: f64) {
114        if node.kind().proof().is_some() {
115            node.cost = value;
116        }
117    }
118
119    type Observed = usize;
120    fn observe(&mut self, node: &Node, _parser: &Z3Parser) -> Self::Observed {
121        match node.kind() {
122            NodeKind::Proof(_) => 1,
123            _ => 0,
124        }
125    }
126    fn transfer(
127        &mut self,
128        node: &Node,
129        _from_idx: RawNodeIndex,
130        child_idx: usize,
131        children: &[Self::Observed],
132    ) -> f64 {
133        if node.kind().proof().is_none() {
134            return 0.0;
135        }
136
137        let total = children.iter().sum::<usize>();
138        if total == 0 {
139            return 0.0;
140        }
141        node.cost * children[child_idx] as f64 / total as f64
142    }
143}