Skip to main content

tidepool_optimize/
inline.rs

1use crate::occ::{get_occ, occ_analysis, Occ};
2use tidepool_eval::{Changed, Pass};
3use tidepool_repr::{CoreExpr, CoreFrame, MapLayer};
4use std::collections::HashMap;
5
6/// Inlining pass: eliminates single-use `LetNonRec` bindings by substituting the RHS directly at the use site.
7pub struct Inline;
8
9impl Pass for Inline {
10    fn run(&self, expr: &mut CoreExpr) -> Changed {
11        if expr.nodes.is_empty() {
12            return false;
13        }
14        let occ_map = occ_analysis(expr);
15        match try_inline(expr, &occ_map) {
16            Some(new_expr) => {
17                *expr = new_expr;
18                true
19            }
20            None => false,
21        }
22    }
23
24    fn name(&self) -> &str {
25        "Inline"
26    }
27}
28
29fn try_inline(expr: &CoreExpr, occ_map: &crate::occ::OccMap) -> Option<CoreExpr> {
30    try_inline_at(expr, expr.nodes.len() - 1, occ_map)
31}
32
33fn try_inline_at(expr: &CoreExpr, idx: usize, occ_map: &crate::occ::OccMap) -> Option<CoreExpr> {
34    match &expr.nodes[idx] {
35        CoreFrame::LetNonRec { binder, rhs, body } => {
36            if get_occ(occ_map, *binder) == Occ::Once {
37                // Inline: substitute binder -> rhs in body
38                let body_tree = expr.extract_subtree(*body);
39                let rhs_tree = expr.extract_subtree(*rhs);
40                let inlined = tidepool_repr::subst::subst(&body_tree, *binder, &rhs_tree);
41                Some(replace_subtree(expr, idx, &inlined))
42            } else {
43                // Try children
44                try_inline_at(expr, *rhs, occ_map).or_else(|| try_inline_at(expr, *body, occ_map))
45            }
46        }
47        // Never inline LetRec, even if Once (it might be recursive via own RHS)
48        _ => try_children(expr, idx, occ_map),
49    }
50}
51
52fn try_children(expr: &CoreExpr, idx: usize, occ_map: &crate::occ::OccMap) -> Option<CoreExpr> {
53    let children = get_children(&expr.nodes[idx]);
54    for child in children {
55        if let Some(result) = try_inline_at(expr, child, occ_map) {
56            return Some(result);
57        }
58    }
59    None
60}
61
62fn get_children(frame: &CoreFrame<usize>) -> Vec<usize> {
63    match frame {
64        CoreFrame::Var(_) | CoreFrame::Lit(_) => vec![],
65        CoreFrame::App { fun, arg } => vec![*fun, *arg],
66        CoreFrame::Lam { body, .. } => vec![*body],
67        CoreFrame::LetNonRec { rhs, body, .. } => vec![*rhs, *body],
68        CoreFrame::LetRec { bindings, body, .. } => {
69            let mut c: Vec<usize> = bindings.iter().map(|(_, r)| *r).collect();
70            c.push(*body);
71            c
72        }
73        CoreFrame::Case {
74            scrutinee, alts, ..
75        } => {
76            let mut c = vec![*scrutinee];
77            for alt in alts {
78                c.push(alt.body);
79            }
80            c
81        }
82        CoreFrame::Con { fields, .. } => fields.clone(),
83        CoreFrame::Join { rhs, body, .. } => vec![*rhs, *body],
84        CoreFrame::Jump { args, .. } => args.clone(),
85        CoreFrame::PrimOp { args, .. } => args.clone(),
86    }
87}
88
89fn replace_subtree(expr: &CoreExpr, target_idx: usize, replacement: &CoreExpr) -> CoreExpr {
90    let mut new_nodes = Vec::new();
91    let mut old_to_new = HashMap::new();
92    rebuild(
93        expr,
94        expr.nodes.len() - 1,
95        target_idx,
96        replacement,
97        &mut new_nodes,
98        &mut old_to_new,
99    );
100    CoreExpr { nodes: new_nodes }
101}
102
103fn rebuild(
104    expr: &CoreExpr,
105    idx: usize,
106    target: usize,
107    replacement: &CoreExpr,
108    new_nodes: &mut Vec<CoreFrame<usize>>,
109    old_to_new: &mut HashMap<usize, usize>,
110) -> usize {
111    if let Some(&ni) = old_to_new.get(&idx) {
112        return ni;
113    }
114    if idx == target {
115        let offset = new_nodes.len();
116        for node in &replacement.nodes {
117            new_nodes.push(node.clone().map_layer(|i| i + offset));
118        }
119        let root = new_nodes.len() - 1;
120        old_to_new.insert(idx, root);
121        return root;
122    }
123    let mapped = expr.nodes[idx]
124        .clone()
125        .map_layer(|child| rebuild(expr, child, target, replacement, new_nodes, old_to_new));
126    let new_idx = new_nodes.len();
127    new_nodes.push(mapped);
128    old_to_new.insert(idx, new_idx);
129    new_idx
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use tidepool_eval::{eval, Env, VecHeap};
    use tidepool_repr::{Literal, PrimOpKind, VarId};

    /// Wraps a flat node list in a `CoreExpr`; the last node is the root.
    fn tree(nodes: Vec<CoreFrame<usize>>) -> CoreExpr {
        CoreExpr { nodes }
    }

    // 1. let x = 42 in x -> 42. Binder Once, inlined.
    #[test]
    fn test_inline_single_use() {
        let x = VarId(1);
        let mut expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(42)), // 0
            CoreFrame::Var(x),                   // 1
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 1,
            }, // 2
        ]);
        let pass = Inline;
        let changed = pass.run(&mut expr);
        assert!(changed);
        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(42)));
    }

    // 2. let x = 42 in x + x -> unchanged. Binder Many, not inlined.
    #[test]
    fn test_inline_multi_use_preserved() {
        let x = VarId(1);
        let mut expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(42)), // 0
            CoreFrame::Var(x),                   // 1
            CoreFrame::Var(x),                   // 2
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![1, 2],
            }, // 3
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 3,
            }, // 4
        ]);
        let pass = Inline;
        let changed = pass.run(&mut expr);
        assert!(!changed);
    }

    // 3. let x = 42 in 0 -> unchanged by inline (DCE will handle dead bindings).
    #[test]
    fn test_inline_dead_preserved() {
        let x = VarId(1);
        let mut expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(42)), // 0
            CoreFrame::Lit(Literal::LitInt(0)),  // 1
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 1,
            }, // 2
        ]);
        let pass = Inline;
        let changed = pass.run(&mut expr);
        assert!(!changed);
    }

    // 4. let x = 1 in let y = x in y -> after two passes: 1.
    #[test]
    fn test_inline_nested() {
        let x = VarId(1);
        let y = VarId(2);
        let mut expr = tree(vec![
            CoreFrame::Lit(Literal::LitInt(1)), // 0
            CoreFrame::Var(x),                  // 1
            CoreFrame::Var(y),                  // 2
            CoreFrame::LetNonRec {
                binder: y,
                rhs: 1,
                body: 2,
            }, // 3
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 3,
            }, // 4
        ]);
        let pass = Inline;

        // Pass 1: inline x = 1 (outer let), producing: let y = 1 in y
        assert!(pass.run(&mut expr));
        // Pass 2: inline y = 1 (inner let), producing: 1
        assert!(pass.run(&mut expr));
        // Result should be 1
        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(1)));
    }

    // 5. letrec f = f in f -> unchanged. LetRec binder Once but must NOT inline.
    #[test]
    fn test_inline_letrec_not_inlined() {
        let f = VarId(1);
        let mut expr = tree(vec![
            CoreFrame::Var(f), // 0
            CoreFrame::Var(f), // 1
            CoreFrame::LetRec {
                bindings: vec![(f, 0)],
                body: 1,
            }, // 2
        ]);
        let pass = Inline;
        let changed = pass.run(&mut expr);
        assert!(!changed);
    }

    // 6. let x = y in \y. x -> \y'. y (fresh y').
    #[test]
    fn test_inline_capture_avoiding() {
        let x = VarId(1);
        let y = VarId(2);
        let mut expr = tree(vec![
            CoreFrame::Var(y),                     // 0: rhs
            CoreFrame::Var(x),                     // 1
            CoreFrame::Lam { binder: y, body: 1 }, // 2: body
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 2,
            }, // 3
        ]);
        let pass = Inline;
        let changed = pass.run(&mut expr);
        assert!(changed);

        // Result should be \y'. y
        let root = expr.nodes.len() - 1;
        if let CoreFrame::Lam { binder, body } = &expr.nodes[root] {
            // The lambda binder must have been renamed away from the free `y`.
            assert_ne!(*binder, y);
            if let CoreFrame::Var(v) = &expr.nodes[*body] {
                assert_eq!(*v, y);
            } else {
                panic!("Body should be Var(y)");
            }
        } else {
            panic!("Result should be Lam");
        }
    }

    // 7. Case A: let x = 21 in x (Once) must be inlined and evaluate to the
    //    same value as before. Case B: let x = 21 in x + x (Many) must be
    //    reported as unchanged and left structurally identical.
    #[test]
    fn test_inline_preserves_eval() {
        let x = VarId(1);

        // Case A: Once (should inline)
        let expr_once = tree(vec![
            CoreFrame::Lit(Literal::LitInt(21)),
            CoreFrame::Var(x),
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 1,
            },
        ]);
        let mut expr_once_reduced = expr_once.clone();
        // Without this check the eval comparison below would pass trivially
        // if the pass did nothing at all.
        assert!(Inline.run(&mut expr_once_reduced));

        let mut heap = VecHeap::new();
        let env = Env::new();
        let v1 = eval(&expr_once, &env, &mut heap).unwrap();
        let v2 = eval(&expr_once_reduced, &env, &mut heap).unwrap();
        match (v1, v2) {
            (tidepool_eval::Value::Lit(l1), tidepool_eval::Value::Lit(l2)) => assert_eq!(l1, l2),
            _ => panic!("Expected literals"),
        }

        // Case B: Many (should NOT inline)
        let mut expr_many = tree(vec![
            CoreFrame::Lit(Literal::LitInt(21)),
            CoreFrame::Var(x),
            CoreFrame::Var(x),
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![1, 2],
            },
            CoreFrame::LetNonRec {
                binder: x,
                rhs: 0,
                body: 3,
            },
        ]);
        let expr_many_orig = expr_many.clone();
        assert!(!Inline.run(&mut expr_many));
        assert_eq!(expr_many, expr_many_orig);
    }
}