Skip to main content

tidepool_optimize/
beta.rs

1use tidepool_eval::{Changed, Pass};
2use tidepool_repr::{replace_subtree, CoreExpr, CoreFrame};
3
/// Beta-reduction pass.
///
/// Looks for a *manifest* redex — an `App { fun, arg }` node whose `fun` child
/// is directly a `Lam { binder, body }` — and replaces that application with
/// `subst(body, binder, arg)`.
///
/// Each call to [`Pass::run`] reduces at most one redex. Call it repeatedly,
/// until it returns `false`, to reduce further redexes.
#[derive(Debug, Clone, Copy, Default)]
pub struct BetaReduce;
7
8impl Pass for BetaReduce {
9    fn run(&self, expr: &mut CoreExpr) -> Changed {
10        if expr.nodes.is_empty() {
11            return false;
12        }
13        match try_beta_reduce(expr) {
14            Some(new_expr) => {
15                *expr = new_expr;
16                true
17            }
18            None => false,
19        }
20    }
21
22    fn name(&self) -> &str {
23        "BetaReduce"
24    }
25}
26
27fn try_beta_reduce(expr: &CoreExpr) -> Option<CoreExpr> {
28    // Start from root (last node)
29    try_beta_at(expr, expr.nodes.len() - 1)
30}
31
32fn try_beta_at(expr: &CoreExpr, idx: usize) -> Option<CoreExpr> {
33    match &expr.nodes[idx] {
34        CoreFrame::App { fun, arg } => {
35            // Check if fun is a Lam
36            if let CoreFrame::Lam { binder, body } = &expr.nodes[*fun] {
37                // Found a manifest beta redex!
38                let body_tree = expr.extract_subtree(*body);
39                let arg_tree = expr.extract_subtree(*arg);
40                let substituted = tidepool_repr::subst::subst(&body_tree, *binder, &arg_tree);
41                Some(replace_subtree(expr, idx, &substituted))
42            } else {
43                // Try to find redex in children
44                try_beta_at(expr, *fun).or_else(|| try_beta_at(expr, *arg))
45            }
46        }
47        // For other nodes, try each child
48        other => {
49            let mut result = None;
50            // We need to visit children. Since map_layer is for remapping indices,
51            // we can use it to "visit" indices if we are careful.
52            // But it's easier to just match on the frame and visit children.
53            match other {
54                CoreFrame::Var(_) | CoreFrame::Lit(_) => {}
55                CoreFrame::App { .. } => {
56                    // App nodes are handled in the outer match — this arm should never fire.
57                    return None;
58                }
59                CoreFrame::Lam { body, .. } => {
60                    result = try_beta_at(expr, *body);
61                }
62                CoreFrame::LetNonRec { rhs, body, .. } => {
63                    result = try_beta_at(expr, *rhs).or_else(|| try_beta_at(expr, *body));
64                }
65                CoreFrame::LetRec { bindings, body } => {
66                    for (_, rhs) in bindings {
67                        result = try_beta_at(expr, *rhs);
68                        if result.is_some() {
69                            break;
70                        }
71                    }
72                    if result.is_none() {
73                        result = try_beta_at(expr, *body);
74                    }
75                }
76                CoreFrame::Case {
77                    scrutinee, alts, ..
78                } => {
79                    result = try_beta_at(expr, *scrutinee);
80                    if result.is_none() {
81                        for alt in alts {
82                            result = try_beta_at(expr, alt.body);
83                            if result.is_some() {
84                                break;
85                            }
86                        }
87                    }
88                }
89                CoreFrame::Con { fields, .. } => {
90                    for field in fields {
91                        result = try_beta_at(expr, *field);
92                        if result.is_some() {
93                            break;
94                        }
95                    }
96                }
97                CoreFrame::Join { rhs, body, .. } => {
98                    result = try_beta_at(expr, *rhs).or_else(|| try_beta_at(expr, *body));
99                }
100                CoreFrame::Jump { args, .. } => {
101                    for arg in args {
102                        result = try_beta_at(expr, *arg);
103                        if result.is_some() {
104                            break;
105                        }
106                    }
107                }
108                CoreFrame::PrimOp { args, .. } => {
109                    for arg in args {
110                        result = try_beta_at(expr, *arg);
111                        if result.is_some() {
112                            break;
113                        }
114                    }
115                }
116            }
117            result
118        }
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use tidepool_eval::{eval, Env, Value, VecHeap};
    use tidepool_repr::{Literal, PrimOpKind, VarId};

    /// Applies one `BetaReduce` step to `expr`, reporting whether it changed.
    fn step(expr: &mut CoreExpr) -> Changed {
        BetaReduce.run(expr)
    }

    /// Applies `BetaReduce` until it reports no further change.
    fn normalize(expr: &mut CoreExpr) {
        while step(expr) {}
    }

    /// Returns the root frame; the pass treats the last node as the root.
    fn root(expr: &CoreExpr) -> &CoreFrame {
        expr.nodes.last().expect("expression should not be empty")
    }

    #[test]
    fn test_beta_identity() {
        // (λx.x) 42 → 42
        let x = VarId(1);
        let mut expr = CoreExpr {
            nodes: vec![
                CoreFrame::Var(x),                     // 0: x
                CoreFrame::Lam { binder: x, body: 0 }, // 1: λx.x
                CoreFrame::Lit(Literal::LitInt(42)),   // 2: 42
                CoreFrame::App { fun: 1, arg: 2 },     // 3: (λx.x) 42
            ],
        };

        assert!(step(&mut expr));
        assert_eq!(expr.nodes, vec![CoreFrame::Lit(Literal::LitInt(42))]);
    }

    #[test]
    fn test_beta_const() {
        // (λx.λy.x) 1 → λy.1
        let x = VarId(1);
        let y = VarId(2);
        let mut expr = CoreExpr {
            nodes: vec![
                CoreFrame::Var(x),                     // 0: x
                CoreFrame::Lam { binder: y, body: 0 }, // 1: λy.x
                CoreFrame::Lam { binder: x, body: 1 }, // 2: λx.λy.x
                CoreFrame::Lit(Literal::LitInt(1)),    // 3: 1
                CoreFrame::App { fun: 2, arg: 3 },     // 4: (λx.λy.x) 1
            ],
        };

        assert!(step(&mut expr));
        match root(&expr) {
            CoreFrame::Lam { binder, body } => {
                assert_eq!(*binder, y);
                assert_eq!(expr.nodes[*body], CoreFrame::Lit(Literal::LitInt(1)));
            }
            other => panic!("Result should be Lam, got {:?}", other),
        }
    }

    #[test]
    fn test_beta_no_redex() {
        // λx.x has no redex: the pass must report no change and leave it intact.
        let x = VarId(1);
        let mut expr = CoreExpr {
            nodes: vec![
                CoreFrame::Var(x),                     // 0: x
                CoreFrame::Lam { binder: x, body: 0 }, // 1: λx.x
            ],
        };
        let before = expr.clone();

        assert!(!step(&mut expr));
        assert_eq!(expr.nodes, before.nodes);
    }

    #[test]
    fn test_beta_capture_avoiding() {
        // (λx.λy.x) y → λy'.y (y' fresh)
        let x = VarId(1);
        let y = VarId(2);
        let mut expr = CoreExpr {
            nodes: vec![
                CoreFrame::Var(x),                     // 0: x
                CoreFrame::Lam { binder: y, body: 0 }, // 1: λy.x
                CoreFrame::Lam { binder: x, body: 1 }, // 2: λx.λy.x
                CoreFrame::Var(y),                     // 3: y
                CoreFrame::App { fun: 2, arg: 3 },     // 4: (λx.λy.x) y
            ],
        };

        assert!(step(&mut expr));
        match root(&expr) {
            CoreFrame::Lam { binder, body } => {
                assert_ne!(*binder, y, "binder must be renamed to avoid capturing the free y");
                assert_eq!(expr.nodes[*body], CoreFrame::Var(y));
            }
            other => panic!("Result should be Lam, got {:?}", other),
        }
    }

    #[test]
    fn test_beta_preserves_eval() {
        // (λx. x + x) 21 evaluates to 42 both before and after reduction.
        let x = VarId(1);
        let expr_orig = CoreExpr {
            nodes: vec![
                CoreFrame::Var(x), // 0: x
                CoreFrame::PrimOp {
                    op: PrimOpKind::IntAdd,
                    args: vec![0, 0],
                }, // 1: x + x
                CoreFrame::Lam { binder: x, body: 1 }, // 2: λx. x + x
                CoreFrame::Lit(Literal::LitInt(21)), // 3: 21
                CoreFrame::App { fun: 2, arg: 3 }, // 4: (λx. x + x) 21
            ],
        };
        let mut expr_reduced = expr_orig.clone();
        // Without this check the comparison below would be vacuous: an
        // unchanged expression trivially evaluates to the same value.
        assert!(step(&mut expr_reduced), "the redex should have been reduced");

        let mut heap = VecHeap::new();
        let env = Env::new();
        let val_orig = eval(&expr_orig, &env, &mut heap).expect("Original eval failed");
        let val_reduced = eval(&expr_reduced, &env, &mut heap).expect("Reduced eval failed");

        match (&val_orig, &val_reduced) {
            (Value::Lit(l1), Value::Lit(l2)) => {
                assert_eq!(l1, l2);
                assert_eq!(*l1, Literal::LitInt(42));
            }
            _ => panic!(
                "Expected literal results, got {:?} and {:?}",
                val_orig, val_reduced
            ),
        }
    }

    #[test]
    fn test_beta_nested() {
        // (λx.x) ((λy.y) 42) →* 42
        let x = VarId(1);
        let y = VarId(2);
        let mut expr = CoreExpr {
            nodes: vec![
                CoreFrame::Var(y),                     // 0: y
                CoreFrame::Lam { binder: y, body: 0 }, // 1: λy.y
                CoreFrame::Lit(Literal::LitInt(42)),   // 2: 42
                CoreFrame::App { fun: 1, arg: 2 },     // 3: (λy.y) 42
                CoreFrame::Var(x),                     // 4: x
                CoreFrame::Lam { binder: x, body: 4 }, // 5: λx.x
                CoreFrame::App { fun: 5, arg: 3 },     // 6: (λx.x) ((λy.y) 42)
            ],
        };

        normalize(&mut expr);

        assert_eq!(expr.nodes, vec![CoreFrame::Lit(Literal::LitInt(42))]);
    }

    #[test]
    fn test_beta_under_lambda() {
        // λz. (λx.x) 42 → λz. 42
        let x = VarId(1);
        let z = VarId(3);
        let mut expr = CoreExpr {
            nodes: vec![
                CoreFrame::Var(x),                     // 0: x
                CoreFrame::Lam { binder: x, body: 0 }, // 1: λx.x
                CoreFrame::Lit(Literal::LitInt(42)),   // 2: 42
                CoreFrame::App { fun: 1, arg: 2 },     // 3: (λx.x) 42
                CoreFrame::Lam { binder: z, body: 3 }, // 4: λz. (λx.x) 42
            ],
        };

        assert!(step(&mut expr));
        match root(&expr) {
            CoreFrame::Lam { binder, body } => {
                assert_eq!(*binder, z);
                assert_eq!(expr.nodes[*body], CoreFrame::Lit(Literal::LitInt(42)));
            }
            other => panic!("Result should be Lam, got {:?}", other),
        }
    }

    #[test]
    fn test_beta_in_argument_position() {
        // f ((λx.x) 42) → f 42 — the head `f` is not a lambda, so the search
        // must continue into the argument.
        let f = VarId(1);
        let x = VarId(2);
        let mut expr = CoreExpr {
            nodes: vec![
                CoreFrame::Var(f),                     // 0: f
                CoreFrame::Var(x),                     // 1: x
                CoreFrame::Lam { binder: x, body: 1 }, // 2: λx.x
                CoreFrame::Lit(Literal::LitInt(42)),   // 3: 42
                CoreFrame::App { fun: 2, arg: 3 },     // 4: (λx.x) 42
                CoreFrame::App { fun: 0, arg: 4 },     // 5: f ((λx.x) 42)
            ],
        };

        assert!(step(&mut expr));
        match root(&expr) {
            CoreFrame::App { fun, arg } => {
                assert_eq!(expr.nodes[*fun], CoreFrame::Var(f));
                assert_eq!(expr.nodes[*arg], CoreFrame::Lit(Literal::LitInt(42)));
            }
            other => panic!("Result should be App, got {:?}", other),
        }
    }

    #[test]
    fn test_beta_one_redex_per_run_left_to_right() {
        // ((λx.x) 1) + ((λy.y) 2): a single run reduces only the left redex.
        let x = VarId(1);
        let y = VarId(2);
        let mut expr = CoreExpr {
            nodes: vec![
                CoreFrame::Var(x),                     // 0: x
                CoreFrame::Lam { binder: x, body: 0 }, // 1: λx.x
                CoreFrame::Lit(Literal::LitInt(1)),    // 2: 1
                CoreFrame::App { fun: 1, arg: 2 },     // 3: (λx.x) 1
                CoreFrame::Var(y),                     // 4: y
                CoreFrame::Lam { binder: y, body: 4 }, // 5: λy.y
                CoreFrame::Lit(Literal::LitInt(2)),    // 6: 2
                CoreFrame::App { fun: 5, arg: 6 },     // 7: (λy.y) 2
                CoreFrame::PrimOp {
                    op: PrimOpKind::IntAdd,
                    args: vec![3, 7],
                }, // 8: ((λx.x) 1) + ((λy.y) 2)
            ],
        };

        assert!(step(&mut expr));
        match root(&expr) {
            CoreFrame::PrimOp { args, .. } => {
                assert_eq!(args.len(), 2);
                assert_eq!(expr.nodes[args[0]], CoreFrame::Lit(Literal::LitInt(1)));
                assert!(
                    matches!(expr.nodes[args[1]], CoreFrame::App { .. }),
                    "right redex should still be pending after one run"
                );
            }
            other => panic!("Result should be PrimOp, got {:?}", other),
        }

        normalize(&mut expr);
        match root(&expr) {
            CoreFrame::PrimOp { args, .. } => {
                assert_eq!(expr.nodes[args[0]], CoreFrame::Lit(Literal::LitInt(1)));
                assert_eq!(expr.nodes[args[1]], CoreFrame::Lit(Literal::LitInt(2)));
            }
            other => panic!("Result should be PrimOp, got {:?}", other),
        }
    }
}