Skip to main content

slotted_egraphs/rewrite/
subst_method.rs

1use crate::*;
2
3/// Specifies a certain implementation of how substitution `b[x := t]` is implemented internally.
4pub trait SubstMethod<L: Language, N: Analysis<L>> {
5    fn new_boxed() -> Box<dyn SubstMethod<L, N>>
6    where
7        Self: Sized;
8    fn subst(
9        &mut self,
10        b: AppliedId,
11        x: AppliedId,
12        t: AppliedId,
13        eg: &mut EGraph<L, N>,
14    ) -> AppliedId;
15}
16
17/// A [SubstMethod] that uses the [EGraph::get_syn_expr] of an e-class to do substitution on it.
18pub struct SynExprSubst;
19
20impl<L: Language, N: Analysis<L>> SubstMethod<L, N> for SynExprSubst {
21    fn new_boxed() -> Box<dyn SubstMethod<L, N>> {
22        Box::new(SynExprSubst)
23    }
24
25    fn subst(
26        &mut self,
27        b: AppliedId,
28        x: AppliedId,
29        t: AppliedId,
30        eg: &mut EGraph<L, N>,
31    ) -> AppliedId {
32        let term = eg.get_syn_expr(&eg.synify_app_id(b));
33        do_term_subst(eg, &term, &x, &t)
34    }
35}
36
37/// A [SubstMethod] that extracts the smallest term (measured by [AstSize]) of an e-class to do substitution on it.
38pub struct ExtractionSubst;
39
40impl<L: Language, N: Analysis<L>> SubstMethod<L, N> for ExtractionSubst {
41    fn new_boxed() -> Box<dyn SubstMethod<L, N>> {
42        Box::new(ExtractionSubst)
43    }
44
45    fn subst(
46        &mut self,
47        b: AppliedId,
48        x: AppliedId,
49        t: AppliedId,
50        eg: &mut EGraph<L, N>,
51    ) -> AppliedId {
52        let term = ast_size_extract::<L, N>(&b, eg);
53        do_term_subst(eg, &term, &x, &t)
54    }
55}
56
57// returns re[x := t]
58fn do_term_subst<L: Language, N: Analysis<L>>(
59    eg: &mut EGraph<L, N>,
60    re: &RecExpr<L>,
61    x: &AppliedId,
62    t: &AppliedId,
63) -> AppliedId {
64    let mut n = re.node.clone();
65    let mut refs: Vec<&mut AppliedId> = n.applied_id_occurrences_mut();
66    if CHECKS {
67        assert_eq!(re.children.len(), refs.len());
68    }
69    for i in 0..refs.len() {
70        *(refs[i]) = do_term_subst(eg, &re.children[i], x, t);
71    }
72    let app_id = eg.add_syn(n);
73
74    if app_id == *x {
75        return t.clone();
76    } else {
77        app_id
78    }
79}