1use tidepool_eval::{Changed, Pass};
2use tidepool_repr::{AltCon, CoreExpr, CoreFrame, MapLayer};
3use std::collections::HashMap;
4
/// Optimisation pass that reduces `Case` expressions whose scrutinee is a
/// known constructor application (`Con`) or a literal (`Lit`).
///
/// The matching alternative (or the `Default` alternative) is selected at
/// compile time, and its body replaces the whole `Case`.
#[derive(Debug, Clone, Copy, Default)]
pub struct CaseReduce;
7
8impl Pass for CaseReduce {
9 fn run(&self, expr: &mut CoreExpr) -> Changed {
10 if expr.nodes.is_empty() {
11 return false;
12 }
13 match try_case_reduce(expr) {
14 Some(new_expr) => {
15 *expr = new_expr;
16 true
17 }
18 None => false,
19 }
20 }
21
22 fn name(&self) -> &str {
23 "CaseReduce"
24 }
25}
26
27fn try_case_reduce(expr: &CoreExpr) -> Option<CoreExpr> {
28 try_case_reduce_at(expr, expr.nodes.len() - 1)
29}
30
31fn try_case_reduce_at(expr: &CoreExpr, idx: usize) -> Option<CoreExpr> {
32 match &expr.nodes[idx] {
33 CoreFrame::Case {
34 scrutinee,
35 binder,
36 alts,
37 } => {
38 match &expr.nodes[*scrutinee] {
39 CoreFrame::Con { tag, fields } => {
40 let alt = alts
42 .iter()
43 .find(|a| matches!(&a.con, AltCon::DataAlt(t) if t == tag))
44 .or_else(|| alts.iter().find(|a| matches!(&a.con, AltCon::Default)));
45
46 if let Some(alt) = alt {
47 if let AltCon::DataAlt(_) = &alt.con {
50 if alt.binders.len() != fields.len() {
51 return try_children(expr, idx);
52 }
53 }
54
55 let mut body = extract_subtree(expr, alt.body);
56 if let AltCon::DataAlt(_) = &alt.con {
58 for (alt_binder, field_idx) in alt.binders.iter().zip(fields.iter()) {
59 let field_tree = extract_subtree(expr, *field_idx);
60 body = tidepool_repr::subst::subst(&body, *alt_binder, &field_tree);
61 }
62 }
63 let scrut_tree = extract_subtree(expr, *scrutinee);
65 body = tidepool_repr::subst::subst(&body, *binder, &scrut_tree);
66 Some(replace_subtree(expr, idx, &body))
67 } else {
68 try_children(expr, idx)
70 }
71 }
72 CoreFrame::Lit(lit) => {
73 let alt = alts
74 .iter()
75 .find(|a| matches!(&a.con, AltCon::LitAlt(l) if l == lit))
76 .or_else(|| alts.iter().find(|a| matches!(&a.con, AltCon::Default)));
77
78 if let Some(alt) = alt {
79 let mut body = extract_subtree(expr, alt.body);
80 let scrut_tree = extract_subtree(expr, *scrutinee);
82 body = tidepool_repr::subst::subst(&body, *binder, &scrut_tree);
83 Some(replace_subtree(expr, idx, &body))
84 } else {
85 try_children(expr, idx)
86 }
87 }
88 _ => try_children(expr, idx),
89 }
90 }
91 _ => try_children(expr, idx),
92 }
93}
94
95fn try_children(expr: &CoreExpr, idx: usize) -> Option<CoreExpr> {
96 let children = get_children(&expr.nodes[idx]);
97 for child in children {
98 if let Some(result) = try_case_reduce_at(expr, child) {
99 return Some(result);
100 }
101 }
102 None
103}
104
105fn get_children(frame: &CoreFrame<usize>) -> Vec<usize> {
106 match frame {
107 CoreFrame::Var(_) | CoreFrame::Lit(_) => vec![],
108 CoreFrame::App { fun, arg } => vec![*fun, *arg],
109 CoreFrame::Lam { body, .. } => vec![*body],
110 CoreFrame::LetNonRec { rhs, body, .. } => vec![*rhs, *body],
111 CoreFrame::LetRec { bindings, body, .. } => {
112 let mut c: Vec<usize> = bindings.iter().map(|(_, r)| *r).collect();
113 c.push(*body);
114 c
115 }
116 CoreFrame::Case {
117 scrutinee, alts, ..
118 } => {
119 let mut c = vec![*scrutinee];
120 for alt in alts {
121 c.push(alt.body);
122 }
123 c
124 }
125 CoreFrame::Con { fields, .. } => fields.clone(),
126 CoreFrame::Join { rhs, body, .. } => vec![*rhs, *body],
127 CoreFrame::Jump { args, .. } => args.clone(),
128 CoreFrame::PrimOp { args, .. } => args.clone(),
129 }
130}
131
132fn extract_subtree(expr: &CoreExpr, root_idx: usize) -> CoreExpr {
133 let mut new_nodes = Vec::new();
134 let mut old_to_new = HashMap::new();
135 collect(root_idx, expr, &mut new_nodes, &mut old_to_new);
136 CoreExpr { nodes: new_nodes }
137}
138
139fn collect(
140 idx: usize,
141 expr: &CoreExpr,
142 new_nodes: &mut Vec<CoreFrame<usize>>,
143 old_to_new: &mut HashMap<usize, usize>,
144) -> usize {
145 if let Some(&new_idx) = old_to_new.get(&idx) {
146 return new_idx;
147 }
148 let mapped = expr.nodes[idx]
149 .clone()
150 .map_layer(|child| collect(child, expr, new_nodes, old_to_new));
151 let new_idx = new_nodes.len();
152 new_nodes.push(mapped);
153 old_to_new.insert(idx, new_idx);
154 new_idx
155}
156
157fn replace_subtree(expr: &CoreExpr, target_idx: usize, replacement: &CoreExpr) -> CoreExpr {
158 let mut new_nodes = Vec::new();
159 let mut old_to_new = HashMap::new();
160 rebuild(
161 expr,
162 expr.nodes.len() - 1,
163 target_idx,
164 replacement,
165 &mut new_nodes,
166 &mut old_to_new,
167 );
168 CoreExpr { nodes: new_nodes }
169}
170
171fn rebuild(
172 expr: &CoreExpr,
173 idx: usize,
174 target: usize,
175 replacement: &CoreExpr,
176 new_nodes: &mut Vec<CoreFrame<usize>>,
177 old_to_new: &mut HashMap<usize, usize>,
178) -> usize {
179 if let Some(&ni) = old_to_new.get(&idx) {
180 return ni;
181 }
182 if idx == target {
183 let offset = new_nodes.len();
184 for node in &replacement.nodes {
185 new_nodes.push(node.clone().map_layer(|i| i + offset));
186 }
187 let root = new_nodes.len() - 1;
188 old_to_new.insert(idx, root);
189 return root;
190 }
191 let mapped = expr.nodes[idx]
192 .clone()
193 .map_layer(|child| rebuild(expr, child, target, replacement, new_nodes, old_to_new));
194 let new_idx = new_nodes.len();
195 new_nodes.push(mapped);
196 old_to_new.insert(idx, new_idx);
197 new_idx
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use tidepool_eval::env::Env;
    use tidepool_eval::heap::VecHeap;
    use tidepool_eval::value::Value;
    use tidepool_repr::{Alt, DataConId, Literal, PrimOpKind, VarId};

    #[test]
    fn test_case_known_con() {
        // case K 42 of b { K x -> x }  ==>  42
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(42)), // 0
            CoreFrame::Con {
                tag: DataConId(1),
                fields: vec![0],
            }, // 1
            CoreFrame::Var(VarId(3)), // 2
            CoreFrame::Case {
                scrutinee: 1,
                binder: VarId(2),
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![VarId(3)],
                    body: 2,
                }],
            }, // 3
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;
        let changed = pass.run(&mut expr);
        assert!(changed);
        assert_eq!(expr.nodes.len(), 1);
        assert!(matches!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(42))));
    }

    #[test]
    fn test_case_known_con_pair() {
        // case K 1 2 of b { K x y -> x + y }  ==>  1 + 2
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(1)), // 0
            CoreFrame::Lit(Literal::LitInt(2)), // 1
            CoreFrame::Con {
                tag: DataConId(1),
                fields: vec![0, 1],
            }, // 2
            CoreFrame::Var(VarId(10)), // 3
            CoreFrame::Var(VarId(11)), // 4
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![3, 4],
            }, // 5
            CoreFrame::Case {
                scrutinee: 2,
                binder: VarId(12),
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![VarId(10), VarId(11)],
                    body: 5,
                }],
            }, // 6
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;

        let mut heap = VecHeap::new();
        let val_before = tidepool_eval::eval(&expr, &Env::new(), &mut heap).unwrap();

        let changed = pass.run(&mut expr);
        assert!(changed);

        let mut heap2 = VecHeap::new();
        let val_after = tidepool_eval::eval(&expr, &Env::new(), &mut heap2).unwrap();

        match (val_before, val_after) {
            (Value::Lit(l1), Value::Lit(l2)) => {
                assert_eq!(l1, l2);
                assert_eq!(l1, Literal::LitInt(3));
            }
            (v1, v2) => panic!("Value mismatch or not Lit: {:?}, {:?}", v1, v2),
        }
    }

    #[test]
    fn test_case_known_lit() {
        // case 3 of b { 1 -> 10; 3 -> 30; _ -> 99 }  ==>  30
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(3)),  // 0
            CoreFrame::Lit(Literal::LitInt(10)), // 1
            CoreFrame::Lit(Literal::LitInt(30)), // 2
            CoreFrame::Lit(Literal::LitInt(99)), // 3
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(10),
                alts: vec![
                    Alt {
                        con: AltCon::LitAlt(Literal::LitInt(1)),
                        binders: vec![],
                        body: 1,
                    },
                    Alt {
                        con: AltCon::LitAlt(Literal::LitInt(3)),
                        binders: vec![],
                        body: 2,
                    },
                    Alt {
                        con: AltCon::Default,
                        binders: vec![],
                        body: 3,
                    },
                ],
            }, // 4
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;
        let changed = pass.run(&mut expr);
        assert!(changed);
        assert!(matches!(
            expr.nodes[expr.nodes.len() - 1],
            CoreFrame::Lit(Literal::LitInt(30))
        ));
    }

    #[test]
    fn test_case_known_lit_default() {
        // case 3 of b { 1 -> 10; _ -> 99 }  ==>  99
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(3)),  // 0
            CoreFrame::Lit(Literal::LitInt(10)), // 1
            CoreFrame::Lit(Literal::LitInt(99)), // 2
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(10),
                alts: vec![
                    Alt {
                        con: AltCon::LitAlt(Literal::LitInt(1)),
                        binders: vec![],
                        body: 1,
                    },
                    Alt {
                        con: AltCon::Default,
                        binders: vec![],
                        body: 2,
                    },
                ],
            }, // 3
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;
        let changed = pass.run(&mut expr);
        assert!(changed);
        assert!(matches!(
            expr.nodes[expr.nodes.len() - 1],
            CoreFrame::Lit(Literal::LitInt(99))
        ));
    }

    #[test]
    fn test_case_unknown_untouched() {
        // case v of b { _ -> 42 }: the scrutinee is not a known value.
        let nodes = vec![
            CoreFrame::Var(VarId(1)),            // 0
            CoreFrame::Lit(Literal::LitInt(42)), // 1
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(2),
                alts: vec![Alt {
                    con: AltCon::Default,
                    binders: vec![],
                    body: 1,
                }],
            }, // 2
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;
        let changed = pass.run(&mut expr);
        assert!(!changed);
    }

    #[test]
    fn test_case_binder_substituted() {
        // case K 42 of b { K x -> b }  ==>  K 42
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(42)), // 0
            CoreFrame::Con {
                tag: DataConId(1),
                fields: vec![0],
            }, // 1
            CoreFrame::Var(VarId(2)), // 2
            CoreFrame::Case {
                scrutinee: 1,
                binder: VarId(2),
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![VarId(3)],
                    body: 2,
                }],
            }, // 3
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;
        let changed = pass.run(&mut expr);
        assert!(changed);
        match &expr.nodes[expr.nodes.len() - 1] {
            CoreFrame::Con { tag, fields } => {
                assert_eq!(tag.0, 1);
                assert_eq!(fields.len(), 1);
                assert!(
                    matches!(&expr.nodes[fields[0]], CoreFrame::Lit(Literal::LitInt(42))),
                    "Expected field to be 42"
                );
            }
            other => panic!("Expected Con, got {:?}", other),
        }
    }

    #[test]
    fn test_case_reduce_preserves_eval() {
        // case K 1 2 of b { K x y -> x + y; _ -> 0 }
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(1)), // 0
            CoreFrame::Lit(Literal::LitInt(2)), // 1
            CoreFrame::Con {
                tag: DataConId(1),
                fields: vec![0, 1],
            }, // 2
            CoreFrame::Var(VarId(10)), // 3
            CoreFrame::Var(VarId(11)), // 4
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![3, 4],
            }, // 5
            CoreFrame::Lit(Literal::LitInt(0)), // 6
            CoreFrame::Case {
                scrutinee: 2,
                binder: VarId(12),
                alts: vec![
                    Alt {
                        con: AltCon::DataAlt(DataConId(1)),
                        binders: vec![VarId(10), VarId(11)],
                        body: 5,
                    },
                    Alt {
                        con: AltCon::Default,
                        binders: vec![],
                        body: 6,
                    },
                ],
            }, // 7
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;

        let mut heap = VecHeap::new();
        let val_before = tidepool_eval::eval(&expr, &Env::new(), &mut heap).unwrap();

        // Without this the test would pass trivially if the pass did nothing.
        assert!(pass.run(&mut expr));

        let mut heap2 = VecHeap::new();
        let val_after = tidepool_eval::eval(&expr, &Env::new(), &mut heap2).unwrap();

        match (val_before, val_after) {
            (Value::Lit(l1), Value::Lit(l2)) => assert_eq!(l1, l2),
            (Value::Con(t1, f1), Value::Con(t2, f2)) => {
                assert_eq!(t1, t2);
                assert_eq!(f1.len(), f2.len());
                for (v1, v2) in f1.iter().zip(f2.iter()) {
                    if let (Value::Lit(ll1), Value::Lit(ll2)) = (v1, v2) {
                        assert_eq!(ll1, ll2);
                    }
                }
            }
            (v1, v2) => panic!(
                "Value mismatch or unsupported for eval check: {:?}, {:?}",
                v1, v2
            ),
        }
    }

    #[test]
    fn test_empty_expr_untouched() {
        let mut expr = CoreExpr { nodes: vec![] };
        let pass = CaseReduce;
        assert!(!pass.run(&mut expr));
        assert!(expr.nodes.is_empty());
    }

    #[test]
    fn test_case_con_arity_mismatch_untouched() {
        // The alternative binds two variables but the constructor has one field.
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(42)), // 0
            CoreFrame::Con {
                tag: DataConId(1),
                fields: vec![0],
            }, // 1
            CoreFrame::Var(VarId(3)), // 2
            CoreFrame::Case {
                scrutinee: 1,
                binder: VarId(2),
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![VarId(3), VarId(4)],
                    body: 2,
                }],
            }, // 3
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;
        assert!(!pass.run(&mut expr));
        assert_eq!(expr.nodes.len(), 4);
    }

    #[test]
    fn test_case_lit_no_matching_alt_untouched() {
        // case 5 of b { 1 -> 10 }: no alternative matches and there is no default.
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(5)),  // 0
            CoreFrame::Lit(Literal::LitInt(10)), // 1
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(1),
                alts: vec![Alt {
                    con: AltCon::LitAlt(Literal::LitInt(1)),
                    binders: vec![],
                    body: 1,
                }],
            }, // 2
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;
        assert!(!pass.run(&mut expr));
        assert_eq!(expr.nodes.len(), 3);
    }

    #[test]
    fn test_case_nested_in_con_reduced() {
        // K7 (case K1 42 of b { K1 x -> x })  ==>  K7 42
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(42)), // 0
            CoreFrame::Con {
                tag: DataConId(1),
                fields: vec![0],
            }, // 1
            CoreFrame::Var(VarId(3)), // 2
            CoreFrame::Case {
                scrutinee: 1,
                binder: VarId(2),
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![VarId(3)],
                    body: 2,
                }],
            }, // 3
            CoreFrame::Con {
                tag: DataConId(7),
                fields: vec![3],
            }, // 4
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;
        assert!(pass.run(&mut expr));
        match expr.nodes.last() {
            Some(CoreFrame::Con { tag, fields }) => {
                assert_eq!(tag.0, 7);
                assert_eq!(fields.len(), 1);
                assert!(matches!(
                    &expr.nodes[fields[0]],
                    CoreFrame::Lit(Literal::LitInt(42))
                ));
            }
            other => panic!("Expected Con, got {:?}", other),
        }
    }
}