Skip to main content

tidepool_optimize/
beta.rs

1use tidepool_eval::{Changed, Pass};
2use tidepool_repr::{CoreExpr, CoreFrame, MapLayer};
3use std::collections::HashMap;
4
/// Beta-reduction optimisation pass.
///
/// Looks for an `App { fun, arg }` node whose `fun` is a manifest
/// `Lam { binder, body }` and replaces that application with
/// `subst(body, binder, arg)`.
pub struct BetaReduce;
8
9impl Pass for BetaReduce {
10    fn run(&self, expr: &mut CoreExpr) -> Changed {
11        if expr.nodes.is_empty() {
12            return false;
13        }
14        match try_beta_reduce(expr) {
15            Some(new_expr) => {
16                *expr = new_expr;
17                true
18            }
19            None => false,
20        }
21    }
22
23    fn name(&self) -> &str {
24        "BetaReduce"
25    }
26}
27
28fn try_beta_reduce(expr: &CoreExpr) -> Option<CoreExpr> {
29    // Start from root (last node)
30    try_beta_at(expr, expr.nodes.len() - 1)
31}
32
33fn try_beta_at(expr: &CoreExpr, idx: usize) -> Option<CoreExpr> {
34    match &expr.nodes[idx] {
35        CoreFrame::App { fun, arg } => {
36            // Check if fun is a Lam
37            if let CoreFrame::Lam { binder, body } = &expr.nodes[*fun] {
38                // Found a manifest beta redex!
39                let body_tree = expr.extract_subtree(*body);
40                let arg_tree = expr.extract_subtree(*arg);
41                let substituted = tidepool_repr::subst::subst(&body_tree, *binder, &arg_tree);
42                Some(replace_subtree(expr, idx, &substituted))
43            } else {
44                // Try to find redex in children
45                try_beta_at(expr, *fun).or_else(|| try_beta_at(expr, *arg))
46            }
47        }
48        // For other nodes, try each child
49        other => {
50            let mut result = None;
51            // We need to visit children. Since map_layer is for remapping indices,
52            // we can use it to "visit" indices if we are careful.
53            // But it's easier to just match on the frame and visit children.
54            match other {
55                CoreFrame::Var(_) | CoreFrame::Lit(_) => {}
56                CoreFrame::App { .. } => {
57                    unreachable!("App nodes are handled in the outer match")
58                }
59                CoreFrame::Lam { body, .. } => {
60                    result = try_beta_at(expr, *body);
61                }
62                CoreFrame::LetNonRec { rhs, body, .. } => {
63                    result = try_beta_at(expr, *rhs).or_else(|| try_beta_at(expr, *body));
64                }
65                CoreFrame::LetRec { bindings, body } => {
66                    for (_, rhs) in bindings {
67                        result = try_beta_at(expr, *rhs);
68                        if result.is_some() {
69                            break;
70                        }
71                    }
72                    if result.is_none() {
73                        result = try_beta_at(expr, *body);
74                    }
75                }
76                CoreFrame::Case {
77                    scrutinee, alts, ..
78                } => {
79                    result = try_beta_at(expr, *scrutinee);
80                    if result.is_none() {
81                        for alt in alts {
82                            result = try_beta_at(expr, alt.body);
83                            if result.is_some() {
84                                break;
85                            }
86                        }
87                    }
88                }
89                CoreFrame::Con { fields, .. } => {
90                    for field in fields {
91                        result = try_beta_at(expr, *field);
92                        if result.is_some() {
93                            break;
94                        }
95                    }
96                }
97                CoreFrame::Join { rhs, body, .. } => {
98                    result = try_beta_at(expr, *rhs).or_else(|| try_beta_at(expr, *body));
99                }
100                CoreFrame::Jump { args, .. } => {
101                    for arg in args {
102                        result = try_beta_at(expr, *arg);
103                        if result.is_some() {
104                            break;
105                        }
106                    }
107                }
108                CoreFrame::PrimOp { args, .. } => {
109                    for arg in args {
110                        result = try_beta_at(expr, *arg);
111                        if result.is_some() {
112                            break;
113                        }
114                    }
115                }
116            }
117            result
118        }
119    }
120}
121
122fn replace_subtree(expr: &CoreExpr, target_idx: usize, replacement: &CoreExpr) -> CoreExpr {
123    let mut new_nodes = Vec::new();
124    let mut old_to_new = HashMap::new();
125
126    fn rebuild(
127        expr: &CoreExpr,
128        idx: usize,
129        target: usize,
130        replacement: &CoreExpr,
131        new_nodes: &mut Vec<CoreFrame<usize>>,
132        old_to_new: &mut HashMap<usize, usize>,
133    ) -> usize {
134        if let Some(&ni) = old_to_new.get(&idx) {
135            return ni;
136        }
137
138        if idx == target {
139            // Splice replacement
140            let offset = new_nodes.len();
141            for node in &replacement.nodes {
142                let mapped = node.clone().map_layer(|i| i + offset);
143                new_nodes.push(mapped);
144            }
145            let root = new_nodes.len() - 1;
146            old_to_new.insert(idx, root);
147            return root;
148        }
149
150        let mapped = expr.nodes[idx]
151            .clone()
152            .map_layer(|child| rebuild(expr, child, target, replacement, new_nodes, old_to_new));
153        let new_idx = new_nodes.len();
154        new_nodes.push(mapped);
155        old_to_new.insert(idx, new_idx);
156        new_idx
157    }
158
159    rebuild(
160        expr,
161        expr.nodes.len() - 1,
162        target_idx,
163        replacement,
164        &mut new_nodes,
165        &mut old_to_new,
166    );
167    CoreExpr { nodes: new_nodes }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use tidepool_eval::{eval, Env, Value, VecHeap};
    use tidepool_repr::{Literal, VarId};

    /// `(λx.x) 42` reduces to the bare literal `42`.
    #[test]
    fn test_beta_identity() {
        let x = VarId(1);
        let nodes = vec![
            CoreFrame::Var(x),                     // 0: x
            CoreFrame::Lam { binder: x, body: 0 }, // 1: λx.x
            CoreFrame::Lit(Literal::LitInt(42)),   // 2: 42
            CoreFrame::App { fun: 1, arg: 2 },     // 3: (λx.x) 42
        ];
        let mut expr = CoreExpr { nodes };
        let pass = BetaReduce;
        let changed = pass.run(&mut expr);

        assert!(changed);
        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(42)));
    }

    /// `(λx.λy.x) 1` reduces to `λy.1`.
    #[test]
    fn test_beta_const() {
        let x = VarId(1);
        let y = VarId(2);
        let nodes = vec![
            CoreFrame::Var(x),                     // 0: x
            CoreFrame::Lam { binder: y, body: 0 }, // 1: λy.x
            CoreFrame::Lam { binder: x, body: 1 }, // 2: λx.λy.x
            CoreFrame::Lit(Literal::LitInt(1)),    // 3: 1
            CoreFrame::App { fun: 2, arg: 3 },     // 4: (λx.λy.x) 1
        ];
        let mut expr = CoreExpr { nodes };
        let pass = BetaReduce;
        let changed = pass.run(&mut expr);

        assert!(changed);
        // The root (last node) should now be λy.1.
        let root = expr.nodes.len() - 1;
        match &expr.nodes[root] {
            CoreFrame::Lam { binder, body } => {
                assert_eq!(*binder, y);
                match &expr.nodes[*body] {
                    CoreFrame::Lit(Literal::LitInt(1)) => {}
                    other => panic!("Body should be 1, got {:?}", other),
                }
            }
            other => panic!("Result should be Lam, got {:?}", other),
        }
    }

    /// A lone lambda contains no application, so the pass reports no change.
    #[test]
    fn test_beta_no_redex() {
        let x = VarId(1);
        let nodes = vec![
            CoreFrame::Var(x),                     // 0: x
            CoreFrame::Lam { binder: x, body: 0 }, // 1: λx.x
        ];
        let mut expr = CoreExpr { nodes };
        let pass = BetaReduce;
        let changed = pass.run(&mut expr);
        assert!(!changed);
    }

    /// `(λx.λy.x) y` must not capture the free `y`: the inner binder is
    /// renamed and the body still refers to the original `y`.
    #[test]
    fn test_beta_capture_avoiding() {
        let x = VarId(1);
        let y = VarId(2);
        let nodes = vec![
            CoreFrame::Var(x),                     // 0: x
            CoreFrame::Lam { binder: y, body: 0 }, // 1: λy.x
            CoreFrame::Lam { binder: x, body: 1 }, // 2: λx.λy.x
            CoreFrame::Var(y),                     // 3: y
            CoreFrame::App { fun: 2, arg: 3 },     // 4: (λx.λy.x) y
        ];
        let mut expr = CoreExpr { nodes };
        let pass = BetaReduce;
        let changed = pass.run(&mut expr);

        assert!(changed);
        let root = expr.nodes.len() - 1;
        match &expr.nodes[root] {
            CoreFrame::Lam { binder, body } => {
                assert_ne!(*binder, y); // Should be renamed
                match &expr.nodes[*body] {
                    CoreFrame::Var(v) => assert_eq!(*v, y), // Should refer to the free y
                    _ => panic!("Body should be Var(y)"),
                }
            }
            _ => panic!("Result should be Lam"),
        }
    }

    /// Reducing `(λx. x + x) 21` must not change what it evaluates to (42).
    #[test]
    fn test_beta_preserves_eval() {
        let x = VarId(1);
        let nodes = vec![
            CoreFrame::Var(x), // 0: x
            CoreFrame::PrimOp {
                op: tidepool_repr::PrimOpKind::IntAdd,
                args: vec![0, 0],
            }, // 1: x + x
            CoreFrame::Lam { binder: x, body: 1 }, // 2: λx. x + x
            CoreFrame::Lit(Literal::LitInt(21)), // 3: 21
            CoreFrame::App { fun: 2, arg: 3 }, // 4: (λx. x + x) 21
        ];
        let expr_orig = CoreExpr { nodes };
        let mut expr_reduced = expr_orig.clone();
        let pass = BetaReduce;
        // If the pass silently did nothing, both expressions would be identical
        // and the comparison below would prove nothing.
        assert!(
            pass.run(&mut expr_reduced),
            "BetaReduce should have reduced the root redex"
        );

        let mut heap = VecHeap::new();
        let env = Env::new();

        let val_orig = eval(&expr_orig, &env, &mut heap).expect("Original eval failed");
        let val_reduced = eval(&expr_reduced, &env, &mut heap).expect("Reduced eval failed");

        match (&val_orig, &val_reduced) {
            (Value::Lit(l1), Value::Lit(l2)) => assert_eq!(l1, l2),
            _ => panic!(
                "Expected literal results, got {:?} and {:?}",
                val_orig, val_reduced
            ),
        }

        match val_orig {
            Value::Lit(Literal::LitInt(n)) => assert_eq!(n, 42),
            _ => panic!("Expected 42"),
        }
    }

    /// `(λx.x) ((λy.y) 42)` needs two reductions; running the pass to a
    /// fixpoint leaves just `42`.
    #[test]
    fn test_beta_nested() {
        let x = VarId(1);
        let y = VarId(2);
        let nodes = vec![
            CoreFrame::Var(y),                     // 0: y
            CoreFrame::Lam { binder: y, body: 0 }, // 1: λy.y
            CoreFrame::Lit(Literal::LitInt(42)),   // 2: 42
            CoreFrame::App { fun: 1, arg: 2 },     // 3: (λy.y) 42
            CoreFrame::Var(x),                     // 4: x
            CoreFrame::Lam { binder: x, body: 4 }, // 5: λx.x
            CoreFrame::App { fun: 5, arg: 3 },     // 6: (λx.x) ((λy.y) 42)
        ];
        let mut expr = CoreExpr { nodes };
        let pass = BetaReduce;

        // Each run performs a single reduction, so iterate until nothing changes.
        while pass.run(&mut expr) {}

        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(42)));
    }
}