Skip to main content

tidepool_optimize/
dce.rs

1use crate::occ::{get_occ, occ_analysis, Occ};
2use tidepool_eval::{Changed, Pass};
3use tidepool_repr::{CoreExpr, CoreFrame, MapLayer};
4use std::collections::HashMap;
5
/// Dead Code Elimination pass.
///
/// Driven by occurrence analysis, this pass drops bindings nobody refers to:
/// - a `LetNonRec` whose binder is dead is replaced by its body;
/// - a `LetRec` group is replaced by its body only when *every* binder in the
///   group is dead (a single live binder keeps the whole group).
///
/// Each call to `run` performs at most one such rewrite: the first removable
/// binding found, checking a node before its children. The driver presumably
/// re-runs the pass until it reports no change — confirm against the pipeline.
pub struct Dce;
10
11impl Pass for Dce {
12    fn run(&self, expr: &mut CoreExpr) -> Changed {
13        if expr.nodes.is_empty() {
14            return false;
15        }
16        let occ_map = occ_analysis(expr);
17        match try_dce(expr, &occ_map) {
18            Some(new_expr) => {
19                *expr = new_expr;
20                true
21            }
22            None => false,
23        }
24    }
25
26    fn name(&self) -> &str {
27        "Dce"
28    }
29}
30
31fn try_dce(expr: &CoreExpr, occ_map: &crate::occ::OccMap) -> Option<CoreExpr> {
32    try_dce_at(expr, expr.nodes.len() - 1, occ_map)
33}
34
35fn try_dce_at(expr: &CoreExpr, idx: usize, occ_map: &crate::occ::OccMap) -> Option<CoreExpr> {
36    match &expr.nodes[idx] {
37        CoreFrame::LetNonRec { binder, body, .. } => {
38            if get_occ(occ_map, *binder) == Occ::Dead {
39                // Drop the binding, keep just body
40                let body_tree = expr.extract_subtree(*body);
41                Some(replace_subtree(expr, idx, &body_tree))
42            } else {
43                try_children(expr, idx, occ_map)
44            }
45        }
46        CoreFrame::LetRec { bindings, body } => {
47            let all_dead = bindings
48                .iter()
49                .all(|(binder, _)| get_occ(occ_map, *binder) == Occ::Dead);
50            if all_dead {
51                // Drop the entire group, keep just body
52                let body_tree = expr.extract_subtree(*body);
53                Some(replace_subtree(expr, idx, &body_tree))
54            } else {
55                try_children(expr, idx, occ_map)
56            }
57        }
58        _ => try_children(expr, idx, occ_map),
59    }
60}
61
62fn try_children(expr: &CoreExpr, idx: usize, occ_map: &crate::occ::OccMap) -> Option<CoreExpr> {
63    let children = get_children(&expr.nodes[idx]);
64    for child in children {
65        if let Some(result) = try_dce_at(expr, child, occ_map) {
66            return Some(result);
67        }
68    }
69    None
70}
71
72fn get_children(frame: &CoreFrame<usize>) -> Vec<usize> {
73    match frame {
74        CoreFrame::Var(_) | CoreFrame::Lit(_) => vec![],
75        CoreFrame::App { fun, arg } => vec![*fun, *arg],
76        CoreFrame::Lam { body, .. } => vec![*body],
77        CoreFrame::LetNonRec { rhs, body, .. } => vec![*rhs, *body],
78        CoreFrame::LetRec { bindings, body } => {
79            let mut c: Vec<usize> = bindings.iter().map(|(_, r)| *r).collect();
80            c.push(*body);
81            c
82        }
83        CoreFrame::Case {
84            scrutinee, alts, ..
85        } => {
86            let mut c = vec![*scrutinee];
87            for alt in alts {
88                c.push(alt.body);
89            }
90            c
91        }
92        CoreFrame::Con { fields, .. } => fields.clone(),
93        CoreFrame::Join { rhs, body, .. } => vec![*rhs, *body],
94        CoreFrame::Jump { args, .. } => args.clone(),
95        CoreFrame::PrimOp { args, .. } => args.clone(),
96    }
97}
98
99fn replace_subtree(expr: &CoreExpr, target_idx: usize, replacement: &CoreExpr) -> CoreExpr {
100    let mut new_nodes = Vec::new();
101    let mut old_to_new = HashMap::new();
102    rebuild(
103        expr,
104        expr.nodes.len() - 1,
105        target_idx,
106        replacement,
107        &mut new_nodes,
108        &mut old_to_new,
109    );
110    CoreExpr { nodes: new_nodes }
111}
112
113fn rebuild(
114    expr: &CoreExpr,
115    idx: usize,
116    target: usize,
117    replacement: &CoreExpr,
118    new_nodes: &mut Vec<CoreFrame<usize>>,
119    old_to_new: &mut HashMap<usize, usize>,
120) -> usize {
121    if let Some(&ni) = old_to_new.get(&idx) {
122        return ni;
123    }
124    if idx == target {
125        let offset = new_nodes.len();
126        for node in &replacement.nodes {
127            new_nodes.push(node.clone().map_layer(|i| i + offset));
128        }
129        let root = new_nodes.len() - 1;
130        old_to_new.insert(idx, root);
131        return root;
132    }
133    let mapped = expr.nodes[idx]
134        .clone()
135        .map_layer(|child| rebuild(expr, child, target, replacement, new_nodes, old_to_new));
136    let new_idx = new_nodes.len();
137    new_nodes.push(mapped);
138    old_to_new.insert(idx, new_idx);
139    new_idx
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use tidepool_eval::{eval, Env, VecHeap};
    use tidepool_repr::{Literal, VarId};

    // Helper to build a tree
    fn tree(nodes: Vec<CoreFrame<usize>>) -> CoreExpr {
        CoreExpr { nodes }
    }

    // 1. test_dce_dead_let: let x = 42 in 0 -> 0. Binder Dead, dropped.
    #[test]
    fn test_dce_dead_let() {
        let x = VarId(1);
        let expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(42)), // 0: rhs
            CoreFrame::Lit(Literal::LitInt(0)),  // 1: body
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 1,
            }, // 2: root
        ]);
        let mut dce_expr = expr;
        let changed = Dce.run(&mut dce_expr);
        assert!(changed);
        assert_eq!(dce_expr.nodes.len(), 1);
        assert_eq!(dce_expr.nodes[0], CoreFrame::Lit(Literal::LitInt(0)));
    }

    // 2. test_dce_live_let_preserved: let x = 42 in x -> unchanged. Binder Once, kept.
    #[test]
    fn test_dce_live_let_preserved() {
        let x = VarId(1);
        let expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(42)), // 0: rhs
            CoreFrame::Var(x),                   // 1: body
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 1,
            }, // 2: root
        ]);
        let mut dce_expr = expr.clone();
        let changed = Dce.run(&mut dce_expr);
        assert!(!changed);
        assert_eq!(dce_expr, expr);
    }

    // 3. test_dce_letrec_all_dead: letrec { f = 1; g = 2 } in 0 -> 0. All Dead, drop entire group.
    #[test]
    fn test_dce_letrec_all_dead() {
        let f = VarId(1);
        let g = VarId(2);
        let expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(1)), // 0: f's rhs
            CoreFrame::Lit(Literal::LitInt(2)), // 1: g's rhs
            CoreFrame::Lit(Literal::LitInt(0)), // 2: body
            CoreFrame::LetRec {
                bindings: vec![(f, 0), (g, 1)],
                body: 2,
            }, // 3: root
        ]);
        let mut dce_expr = expr;
        let changed = Dce.run(&mut dce_expr);
        assert!(changed);
        assert_eq!(dce_expr.nodes.len(), 1);
        assert_eq!(dce_expr.nodes[0], CoreFrame::Lit(Literal::LitInt(0)));
    }

    // 4. test_dce_letrec_one_live: letrec { f = g; g = 1 } in f -> unchanged.
    // f is Once (live), keep entire group even though g might be referenced only by f.
    #[test]
    fn test_dce_letrec_one_live() {
        let f = VarId(1);
        let g = VarId(2);
        let expr = tree(vec![
            CoreFrame::Var(g),                  // 0: f's rhs
            CoreFrame::Lit(Literal::LitInt(1)), // 1: g's rhs
            CoreFrame::Var(f),                  // 2: body
            CoreFrame::LetRec {
                bindings: vec![(f, 0), (g, 1)],
                body: 2,
            }, // 3: root
        ]);
        let mut dce_expr = expr.clone();
        let changed = Dce.run(&mut dce_expr);
        assert!(!changed);
        assert_eq!(dce_expr, expr);
    }

    // 5. test_dce_nested: let x = 42 in let y = 0 in x -> after DCE drops y's let, result is let x = 42 in x.
    #[test]
    fn test_dce_nested() {
        let x = VarId(1);
        let y = VarId(2);
        let expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(42)), // 0: x's rhs
            CoreFrame::Lit(Literal::LitInt(0)),  // 1: y's rhs
            CoreFrame::Var(x),                   // 2: y's body
            CoreFrame::LetNonRec {
                binder: y,
                rhs: 1,
                body: 2,
            }, // 3: x's body
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 3,
            }, // 4: root
        ]);
        let mut dce_expr = expr;
        let changed = Dce.run(&mut dce_expr);
        assert!(changed);
        // Should have dropped y
        // let x = 42 in x
        assert_eq!(dce_expr.nodes.len(), 3);
        // The root should be a LetNonRec for x
        let root_idx = dce_expr.nodes.len() - 1;
        if let CoreFrame::LetNonRec { binder, .. } = &dce_expr.nodes[root_idx] {
            assert_eq!(*binder, x);
        } else {
            panic!(
                "Root should be LetNonRec for x, got {:?}",
                dce_expr.nodes[root_idx]
            );
        }
    }

    // 6. test_dce_preserves_eval: let x = 42 in let y = 99 in x -> eval before/after, verify match.
    #[test]
    fn test_dce_preserves_eval() {
        let x = VarId(1);
        let y = VarId(2);
        let expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(42)), // 0: x's rhs
            CoreFrame::Lit(Literal::LitInt(99)), // 1: y's rhs
            CoreFrame::Var(x),                   // 2: y's body
            CoreFrame::LetNonRec {
                binder: y,
                rhs: 1,
                body: 2,
            }, // 3: x's body
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 3,
            }, // 4: root
        ]);
        let mut dce_expr = expr.clone();

        let mut heap = VecHeap::new();
        let env = Env::new();

        let val_orig = eval(&expr, &env, &mut heap).expect("Original eval failed");

        let changed = Dce.run(&mut dce_expr);
        assert!(changed);

        let val_dce = eval(&dce_expr, &env, &mut heap).expect("DCE eval failed");

        match (val_orig, val_dce) {
            (tidepool_eval::Value::Lit(l1), tidepool_eval::Value::Lit(l2)) => assert_eq!(l1, l2),
            _ => panic!("Expected literals"),
        }
    }

    // 7. test_dce_empty_expr: an empty expression has no root; the pass must
    // report no change rather than panic.
    #[test]
    fn test_dce_empty_expr() {
        let mut dce_expr = tree(vec![]);
        let changed = Dce.run(&mut dce_expr);
        assert!(!changed);
        assert!(dce_expr.nodes.is_empty());
    }

    // 8. test_dce_dead_let_under_app: f (let y = 42 in 7) -> f 7.
    // The dead let is not the root, so DCE must descend via try_children and
    // rebuild the enclosing App with remapped child indices.
    #[test]
    fn test_dce_dead_let_under_app() {
        let f = VarId(1);
        let y = VarId(2);
        let expr = tree(vec![
            CoreFrame::Var(f),                   // 0: fun
            CoreFrame::Lit(Literal::LitInt(42)), // 1: y's rhs
            CoreFrame::Lit(Literal::LitInt(7)),  // 2: y's body
            CoreFrame::LetNonRec {
                binder: y,
                rhs: 1,
                body: 2,
            }, // 3: arg
            CoreFrame::App { fun: 0, arg: 3 }, // 4: root
        ]);
        let mut dce_expr = expr;
        let changed = Dce.run(&mut dce_expr);
        assert!(changed);
        // Var f, Lit 7, App: the let and its rhs are gone.
        assert_eq!(dce_expr.nodes.len(), 3);
        let root_idx = dce_expr.nodes.len() - 1;
        match &dce_expr.nodes[root_idx] {
            CoreFrame::App { fun, arg } => {
                assert_eq!(dce_expr.nodes[*fun], CoreFrame::Var(f));
                assert_eq!(dce_expr.nodes[*arg], CoreFrame::Lit(Literal::LitInt(7)));
            }
            other => panic!("Root should be App, got {:?}", other),
        }
    }
}