1use tidepool_eval::{Changed, Pass};
2use tidepool_repr::{CoreExpr, CoreFrame, MapLayer};
3use std::collections::HashMap;
4
5pub struct BetaReduce;
8
9impl Pass for BetaReduce {
10 fn run(&self, expr: &mut CoreExpr) -> Changed {
11 if expr.nodes.is_empty() {
12 return false;
13 }
14 match try_beta_reduce(expr) {
15 Some(new_expr) => {
16 *expr = new_expr;
17 true
18 }
19 None => false,
20 }
21 }
22
23 fn name(&self) -> &str {
24 "BetaReduce"
25 }
26}
27
28fn try_beta_reduce(expr: &CoreExpr) -> Option<CoreExpr> {
29 try_beta_at(expr, expr.nodes.len() - 1)
31}
32
33fn try_beta_at(expr: &CoreExpr, idx: usize) -> Option<CoreExpr> {
34 match &expr.nodes[idx] {
35 CoreFrame::App { fun, arg } => {
36 if let CoreFrame::Lam { binder, body } = &expr.nodes[*fun] {
38 let body_tree = expr.extract_subtree(*body);
40 let arg_tree = expr.extract_subtree(*arg);
41 let substituted = tidepool_repr::subst::subst(&body_tree, *binder, &arg_tree);
42 Some(replace_subtree(expr, idx, &substituted))
43 } else {
44 try_beta_at(expr, *fun).or_else(|| try_beta_at(expr, *arg))
46 }
47 }
48 other => {
50 let mut result = None;
51 match other {
55 CoreFrame::Var(_) | CoreFrame::Lit(_) => {}
56 CoreFrame::App { .. } => {
57 unreachable!("App nodes are handled in the outer match")
58 }
59 CoreFrame::Lam { body, .. } => {
60 result = try_beta_at(expr, *body);
61 }
62 CoreFrame::LetNonRec { rhs, body, .. } => {
63 result = try_beta_at(expr, *rhs).or_else(|| try_beta_at(expr, *body));
64 }
65 CoreFrame::LetRec { bindings, body } => {
66 for (_, rhs) in bindings {
67 result = try_beta_at(expr, *rhs);
68 if result.is_some() {
69 break;
70 }
71 }
72 if result.is_none() {
73 result = try_beta_at(expr, *body);
74 }
75 }
76 CoreFrame::Case {
77 scrutinee, alts, ..
78 } => {
79 result = try_beta_at(expr, *scrutinee);
80 if result.is_none() {
81 for alt in alts {
82 result = try_beta_at(expr, alt.body);
83 if result.is_some() {
84 break;
85 }
86 }
87 }
88 }
89 CoreFrame::Con { fields, .. } => {
90 for field in fields {
91 result = try_beta_at(expr, *field);
92 if result.is_some() {
93 break;
94 }
95 }
96 }
97 CoreFrame::Join { rhs, body, .. } => {
98 result = try_beta_at(expr, *rhs).or_else(|| try_beta_at(expr, *body));
99 }
100 CoreFrame::Jump { args, .. } => {
101 for arg in args {
102 result = try_beta_at(expr, *arg);
103 if result.is_some() {
104 break;
105 }
106 }
107 }
108 CoreFrame::PrimOp { args, .. } => {
109 for arg in args {
110 result = try_beta_at(expr, *arg);
111 if result.is_some() {
112 break;
113 }
114 }
115 }
116 }
117 result
118 }
119 }
120}
121
122fn replace_subtree(expr: &CoreExpr, target_idx: usize, replacement: &CoreExpr) -> CoreExpr {
123 let mut new_nodes = Vec::new();
124 let mut old_to_new = HashMap::new();
125
126 fn rebuild(
127 expr: &CoreExpr,
128 idx: usize,
129 target: usize,
130 replacement: &CoreExpr,
131 new_nodes: &mut Vec<CoreFrame<usize>>,
132 old_to_new: &mut HashMap<usize, usize>,
133 ) -> usize {
134 if let Some(&ni) = old_to_new.get(&idx) {
135 return ni;
136 }
137
138 if idx == target {
139 let offset = new_nodes.len();
141 for node in &replacement.nodes {
142 let mapped = node.clone().map_layer(|i| i + offset);
143 new_nodes.push(mapped);
144 }
145 let root = new_nodes.len() - 1;
146 old_to_new.insert(idx, root);
147 return root;
148 }
149
150 let mapped = expr.nodes[idx]
151 .clone()
152 .map_layer(|child| rebuild(expr, child, target, replacement, new_nodes, old_to_new));
153 let new_idx = new_nodes.len();
154 new_nodes.push(mapped);
155 old_to_new.insert(idx, new_idx);
156 new_idx
157 }
158
159 rebuild(
160 expr,
161 expr.nodes.len() - 1,
162 target_idx,
163 replacement,
164 &mut new_nodes,
165 &mut old_to_new,
166 );
167 CoreExpr { nodes: new_nodes }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use tidepool_eval::{eval, Env, VecHeap};
    use tidepool_repr::{Literal, VarId};

    /// Wraps `nodes` in a `CoreExpr`, runs one `BetaReduce` step on it and
    /// returns the resulting expression together with the pass's verdict.
    fn reduce_once(nodes: Vec<CoreFrame<usize>>) -> (CoreExpr, bool) {
        let mut expr = CoreExpr { nodes };
        let changed = BetaReduce.run(&mut expr);
        (expr, changed)
    }

    #[test]
    fn test_beta_identity() {
        let x = VarId(1);
        // (\x. x) 42
        let (expr, changed) = reduce_once(vec![
            CoreFrame::Var(x),
            CoreFrame::Lam { binder: x, body: 0 },
            CoreFrame::Lit(Literal::LitInt(42)),
            CoreFrame::App { fun: 1, arg: 2 },
        ]);

        assert!(changed);
        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(42)));
    }

    #[test]
    fn test_beta_const() {
        let x = VarId(1);
        let y = VarId(2);
        // (\x. \y. x) 1
        let (expr, changed) = reduce_once(vec![
            CoreFrame::Var(x),
            CoreFrame::Lam { binder: y, body: 0 },
            CoreFrame::Lam { binder: x, body: 1 },
            CoreFrame::Lit(Literal::LitInt(1)),
            CoreFrame::App { fun: 2, arg: 3 },
        ]);

        assert!(changed);
        let root = expr.nodes.len() - 1;
        match &expr.nodes[root] {
            CoreFrame::Lam { binder, body } => {
                assert_eq!(*binder, y);
                match &expr.nodes[*body] {
                    CoreFrame::Lit(Literal::LitInt(1)) => {}
                    other => panic!("Body should be 1, got {:?}", other),
                }
            }
            other => panic!("Result should be Lam, got {:?}", other),
        }
    }

    #[test]
    fn test_beta_no_redex() {
        let x = VarId(1);
        // \x. x  -- a lambda on its own is not a redex
        let (_, changed) = reduce_once(vec![
            CoreFrame::Var(x),
            CoreFrame::Lam { binder: x, body: 0 },
        ]);
        assert!(!changed);
    }

    #[test]
    fn test_beta_capture_avoiding() {
        let x = VarId(1);
        let y = VarId(2);
        // (\x. \y. x) y  -- the free `y` argument must not be captured by the inner binder
        let (expr, changed) = reduce_once(vec![
            CoreFrame::Var(x),
            CoreFrame::Lam { binder: y, body: 0 },
            CoreFrame::Lam { binder: x, body: 1 },
            CoreFrame::Var(y),
            CoreFrame::App { fun: 2, arg: 3 },
        ]);

        assert!(changed);
        let root = expr.nodes.len() - 1;
        match &expr.nodes[root] {
            CoreFrame::Lam { binder, body } => {
                // The inner binder has to be renamed away from `y` ...
                assert_ne!(*binder, y);
                // ... while the substituted argument still refers to the outer `y`.
                match &expr.nodes[*body] {
                    CoreFrame::Var(v) => assert_eq!(*v, y),
                    _ => panic!("Body should be Var(y)"),
                }
            }
            _ => panic!("Result should be Lam"),
        }
    }

    #[test]
    fn test_beta_preserves_eval() {
        let x = VarId(1);
        // (\x. x + x) 21
        let expr_orig = CoreExpr {
            nodes: vec![
                CoreFrame::Var(x),
                CoreFrame::PrimOp {
                    op: tidepool_repr::PrimOpKind::IntAdd,
                    args: vec![0, 0],
                },
                CoreFrame::Lam { binder: x, body: 1 },
                CoreFrame::Lit(Literal::LitInt(21)),
                CoreFrame::App { fun: 2, arg: 3 },
            ],
        };
        let mut expr_reduced = expr_orig.clone();
        BetaReduce.run(&mut expr_reduced);

        let mut heap = VecHeap::new();
        let env = Env::new();

        let val_orig = eval(&expr_orig, &env, &mut heap).expect("Original eval failed");
        let val_reduced = eval(&expr_reduced, &env, &mut heap).expect("Reduced eval failed");

        // Both forms must evaluate to the same literal ...
        match (&val_orig, &val_reduced) {
            (tidepool_eval::Value::Lit(l1), tidepool_eval::Value::Lit(l2)) => {
                assert_eq!(l1, l2);
            }
            _ => panic!(
                "Expected literal results, got {:?} and {:?}",
                val_orig, val_reduced
            ),
        }

        // ... and that literal is 21 + 21.
        match val_orig {
            tidepool_eval::Value::Lit(Literal::LitInt(n)) => assert_eq!(n, 42),
            _ => panic!("Expected 42"),
        }
    }

    #[test]
    fn test_beta_nested() {
        let x = VarId(1);
        let y = VarId(2);
        // (\x. x) ((\y. y) 42)
        let mut expr = CoreExpr {
            nodes: vec![
                CoreFrame::Var(y),
                CoreFrame::Lam { binder: y, body: 0 },
                CoreFrame::Lit(Literal::LitInt(42)),
                CoreFrame::App { fun: 1, arg: 2 },
                CoreFrame::Var(x),
                CoreFrame::Lam { binder: x, body: 4 },
                CoreFrame::App { fun: 5, arg: 3 },
            ],
        };
        let pass = BetaReduce;

        // Each run performs one step; iterate until nothing is left to reduce.
        while pass.run(&mut expr) {}

        assert_eq!(expr.nodes.len(), 1);
        assert_eq!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(42)));
    }
}