slotted_egraphs/extract/
mod.rs1use crate::*;
2
3mod cost;
4pub use cost::*;
5
6mod with_ord;
7pub use with_ord::*;
8
9use std::collections::BinaryHeap;
10
11pub struct Extractor<L: Language, CF: CostFunction<L>> {
16 pub(crate) map: HashMap<Id, WithOrdRev<L, CF::Cost>>,
17}
18
19impl<L: Language, CF: CostFunction<L>> Extractor<L, CF> {
20 pub fn new<N: Analysis<L>>(eg: &EGraph<L, N>, cost_fn: CF) -> Self {
21 if CHECKS {
22 eg.check();
23 }
24
25 let mut map: HashMap<Id, WithOrdRev<L, CF::Cost>> = HashMap::default();
31 let mut queue: BinaryHeap<WithOrdRev<L, CF::Cost>> = BinaryHeap::new();
32
33 for id in eg.ids() {
34 for x in eg.enodes(id) {
35 if x.applied_id_occurrences().is_empty() {
36 let x = eg.class_nf(&x);
37 let c = cost_fn.cost(&x, |_| panic!());
38 queue.push(WithOrdRev(x, c));
39 }
40 }
41 }
42
43 while let Some(WithOrdRev(enode, c)) = queue.pop() {
44 let i = eg.lookup(&enode).unwrap();
45 if map.contains_key(&i.id) {
46 continue;
47 }
48 map.insert(i.id, WithOrdRev(enode, c));
49
50 for x in eg.usages(i.id).clone() {
51 if x.applied_id_occurrences()
52 .iter()
53 .all(|i| map.contains_key(&i.id))
54 {
55 if eg
56 .lookup(&x)
57 .map(|i| map.contains_key(&i.id))
58 .unwrap_or(false)
59 {
60 continue;
61 }
62 let x = eg.class_nf(&x);
63 let c = cost_fn.cost(&x, |i| map[&i].1.clone());
64 queue.push(WithOrdRev(x, c));
65 }
66 }
67 }
68
69 Self { map }
70 }
71
72 pub fn extract<N: Analysis<L>>(&self, i: &AppliedId, eg: &EGraph<L, N>) -> RecExpr<L> {
73 let i = eg.find_applied_id(i);
74
75 let mut children = Vec::new();
76
77 let l = self.map[&i.id].0.apply_slotmap(&i.m);
79 for child in l.applied_id_occurrences() {
80 let n = self.extract(&child, eg);
81 children.push(n);
82 }
83
84 RecExpr { node: l, children }
85 }
86
87 pub fn get_best_cost<N: Analysis<L>>(&self, i: &AppliedId) -> CF::Cost {
88 self.map[&i.id].1.clone()
89 }
90}
91
92pub fn ast_size_extract<L: Language, N: Analysis<L>>(
93 i: &AppliedId,
94 eg: &EGraph<L, N>,
95) -> RecExpr<L> {
96 extract::<L, N, AstSize>(i, eg)
97}
98
99pub fn extract<L: Language, N: Analysis<L>, CF: CostFunction<L> + Default>(
101 i: &AppliedId,
102 eg: &EGraph<L, N>,
103) -> RecExpr<L> {
104 let cost_fn = CF::default();
105 let extractor = Extractor::<L, CF>::new(eg, cost_fn);
106 let out = extractor.extract(&i, eg);
107 if CHECKS {
108 let i = eg.find_id(i.id);
109 let cost_fn = CF::default();
110 assert_eq!(cost_fn.cost_rec(&out), extractor.map[&i].1);
111 }
112 out
113}