1use tidepool_eval::{Changed, Pass};
2use tidepool_repr::{get_children, replace_subtree, AltCon, CoreExpr, CoreFrame};
3
4pub struct CaseReduce;
6
7impl Pass for CaseReduce {
8 fn run(&self, expr: &mut CoreExpr) -> Changed {
9 if expr.nodes.is_empty() {
10 return false;
11 }
12 match try_case_reduce(expr) {
13 Some(new_expr) => {
14 *expr = new_expr;
15 true
16 }
17 None => false,
18 }
19 }
20
21 fn name(&self) -> &str {
22 "CaseReduce"
23 }
24}
25
26fn try_case_reduce(expr: &CoreExpr) -> Option<CoreExpr> {
27 try_case_reduce_at(expr, expr.nodes.len() - 1)
28}
29
30fn try_case_reduce_at(expr: &CoreExpr, idx: usize) -> Option<CoreExpr> {
31 match &expr.nodes[idx] {
32 CoreFrame::Case {
33 scrutinee,
34 binder,
35 alts,
36 } => {
37 match &expr.nodes[*scrutinee] {
38 CoreFrame::Con { tag, fields } => {
39 let alt = alts
41 .iter()
42 .find(|a| matches!(&a.con, AltCon::DataAlt(t) if t == tag))
43 .or_else(|| alts.iter().find(|a| matches!(&a.con, AltCon::Default)));
44
45 if let Some(alt) = alt {
46 if let AltCon::DataAlt(_) = &alt.con {
49 if alt.binders.len() != fields.len() {
50 return try_children(expr, idx);
51 }
52 }
53
54 let mut body = expr.extract_subtree(alt.body);
55 if let AltCon::DataAlt(_) = &alt.con {
57 for (alt_binder, field_idx) in alt.binders.iter().zip(fields.iter()) {
58 let field_tree = expr.extract_subtree(*field_idx);
59 body = tidepool_repr::subst::subst(&body, *alt_binder, &field_tree);
60 }
61 }
62 let scrut_tree = expr.extract_subtree(*scrutinee);
64 body = tidepool_repr::subst::subst(&body, *binder, &scrut_tree);
65 Some(replace_subtree(expr, idx, &body))
66 } else {
67 try_children(expr, idx)
69 }
70 }
71 CoreFrame::Lit(lit) => {
72 let alt = alts
73 .iter()
74 .find(|a| matches!(&a.con, AltCon::LitAlt(l) if l == lit))
75 .or_else(|| alts.iter().find(|a| matches!(&a.con, AltCon::Default)));
76
77 if let Some(alt) = alt {
78 let mut body = expr.extract_subtree(alt.body);
79 let scrut_tree = expr.extract_subtree(*scrutinee);
81 body = tidepool_repr::subst::subst(&body, *binder, &scrut_tree);
82 Some(replace_subtree(expr, idx, &body))
83 } else {
84 try_children(expr, idx)
85 }
86 }
87 _ => try_children(expr, idx),
88 }
89 }
90 _ => try_children(expr, idx),
91 }
92}
93
94fn try_children(expr: &CoreExpr, idx: usize) -> Option<CoreExpr> {
95 let children = get_children(&expr.nodes[idx]);
96 for child in children {
97 if let Some(result) = try_case_reduce_at(expr, child) {
98 return Some(result);
99 }
100 }
101 None
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use tidepool_eval::env::Env;
    use tidepool_eval::heap::VecHeap;
    use tidepool_eval::value::Value;
    use tidepool_repr::{Alt, DataConId, Literal, PrimOpKind, VarId};

    /// Wraps `nodes` in a `CoreExpr`, runs `CaseReduce` once, and returns the
    /// (possibly rewritten) expression together with the pass's change flag.
    fn run_case_reduce(nodes: Vec<CoreFrame>) -> (CoreExpr, bool) {
        let mut expr = CoreExpr { nodes };
        let changed = CaseReduce.run(&mut expr);
        (expr, changed)
    }

    /// The last node of a flat expression is its root.
    fn root(expr: &CoreExpr) -> &CoreFrame {
        &expr.nodes[expr.nodes.len() - 1]
    }

    #[test]
    fn test_case_known_con() {
        let nodes = vec![
            // 0: field of the scrutinee
            CoreFrame::Lit(Literal::LitInt(42)),
            // 1: scrutinee
            CoreFrame::Con {
                tag: DataConId(1),
                fields: vec![0],
            },
            // 2: alternative body, the alt binder itself
            CoreFrame::Var(VarId(3)),
            // 3: root
            CoreFrame::Case {
                scrutinee: 1,
                binder: VarId(2),
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![VarId(3)],
                    body: 2,
                }],
            },
        ];

        let (expr, changed) = run_case_reduce(nodes);

        assert!(changed);
        // Exactly one node left, and it is the literal field.
        assert!(matches!(
            expr.nodes.as_slice(),
            [CoreFrame::Lit(Literal::LitInt(42))]
        ));
    }

    #[test]
    fn test_case_known_con_pair() {
        let nodes = vec![
            // 0, 1: fields of the scrutinee
            CoreFrame::Lit(Literal::LitInt(1)),
            CoreFrame::Lit(Literal::LitInt(2)),
            // 2: scrutinee
            CoreFrame::Con {
                tag: DataConId(1),
                fields: vec![0, 1],
            },
            // 3, 4, 5: alternative body adding the two alt binders
            CoreFrame::Var(VarId(10)),
            CoreFrame::Var(VarId(11)),
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![3, 4],
            },
            // 6: root
            CoreFrame::Case {
                scrutinee: 2,
                binder: VarId(12),
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![VarId(10), VarId(11)],
                    body: 5,
                }],
            },
        ];
        let mut expr = CoreExpr { nodes };

        let mut heap_before = VecHeap::new();
        let before = tidepool_eval::eval(&expr, &Env::new(), &mut heap_before).unwrap();

        assert!(CaseReduce.run(&mut expr));

        let mut heap_after = VecHeap::new();
        let after = tidepool_eval::eval(&expr, &Env::new(), &mut heap_after).unwrap();

        match (before, after) {
            (Value::Lit(l1), Value::Lit(l2)) => {
                assert_eq!(l1, l2);
                assert!(matches!(l1, Literal::LitInt(3)), "Expected 3, got {:?}", l1);
            }
            (v1, v2) => panic!("Value mismatch or not Lit: {:?}, {:?}", v1, v2),
        }
    }

    #[test]
    fn test_case_known_lit() {
        let nodes = vec![
            // 0: scrutinee
            CoreFrame::Lit(Literal::LitInt(3)),
            // 1, 2, 3: bodies for the alternatives 1, 3 and default
            CoreFrame::Lit(Literal::LitInt(10)),
            CoreFrame::Lit(Literal::LitInt(30)),
            CoreFrame::Lit(Literal::LitInt(99)),
            // 4: root
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(10),
                alts: vec![
                    Alt {
                        con: AltCon::LitAlt(Literal::LitInt(1)),
                        binders: vec![],
                        body: 1,
                    },
                    Alt {
                        con: AltCon::LitAlt(Literal::LitInt(3)),
                        binders: vec![],
                        body: 2,
                    },
                    Alt {
                        con: AltCon::Default,
                        binders: vec![],
                        body: 3,
                    },
                ],
            },
        ];

        let (expr, changed) = run_case_reduce(nodes);

        assert!(changed);
        assert!(matches!(root(&expr), CoreFrame::Lit(Literal::LitInt(30))));
    }

    #[test]
    fn test_case_known_lit_default() {
        let nodes = vec![
            // 0: scrutinee, matching no literal alternative
            CoreFrame::Lit(Literal::LitInt(3)),
            // 1, 2: bodies for the alternatives 1 and default
            CoreFrame::Lit(Literal::LitInt(10)),
            CoreFrame::Lit(Literal::LitInt(99)),
            // 3: root
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(10),
                alts: vec![
                    Alt {
                        con: AltCon::LitAlt(Literal::LitInt(1)),
                        binders: vec![],
                        body: 1,
                    },
                    Alt {
                        con: AltCon::Default,
                        binders: vec![],
                        body: 2,
                    },
                ],
            },
        ];

        let (expr, changed) = run_case_reduce(nodes);

        assert!(changed);
        assert!(matches!(root(&expr), CoreFrame::Lit(Literal::LitInt(99))));
    }

    #[test]
    fn test_case_unknown_untouched() {
        let nodes = vec![
            // 0: scrutinee is a variable, so nothing is known about it
            CoreFrame::Var(VarId(1)),
            // 1: default body
            CoreFrame::Lit(Literal::LitInt(42)),
            // 2: root
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(2),
                alts: vec![Alt {
                    con: AltCon::Default,
                    binders: vec![],
                    body: 1,
                }],
            },
        ];

        let (_, changed) = run_case_reduce(nodes);

        assert!(!changed);
    }

    #[test]
    fn test_case_binder_substituted() {
        let nodes = vec![
            // 0: field of the scrutinee
            CoreFrame::Lit(Literal::LitInt(42)),
            // 1: scrutinee
            CoreFrame::Con {
                tag: DataConId(1),
                fields: vec![0],
            },
            // 2: alternative body, the case binder itself
            CoreFrame::Var(VarId(2)),
            // 3: root
            CoreFrame::Case {
                scrutinee: 1,
                binder: VarId(2),
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![VarId(3)],
                    body: 2,
                }],
            },
        ];

        let (expr, changed) = run_case_reduce(nodes);

        assert!(changed);
        match root(&expr) {
            CoreFrame::Con { tag, fields } => {
                assert_eq!(tag.0, 1);
                assert_eq!(fields.len(), 1);
                assert!(
                    matches!(&expr.nodes[fields[0]], CoreFrame::Lit(Literal::LitInt(42))),
                    "Expected field to be 42"
                );
            }
            other => panic!("Expected Con, got {:?}", other),
        }
    }

    #[test]
    fn test_case_reduce_preserves_eval() {
        let nodes = vec![
            // 0, 1: fields of the scrutinee
            CoreFrame::Lit(Literal::LitInt(1)),
            CoreFrame::Lit(Literal::LitInt(2)),
            // 2: scrutinee
            CoreFrame::Con {
                tag: DataConId(1),
                fields: vec![0, 1],
            },
            // 3, 4, 5: constructor-alternative body adding the two alt binders
            CoreFrame::Var(VarId(10)),
            CoreFrame::Var(VarId(11)),
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![3, 4],
            },
            // 6: default body
            CoreFrame::Lit(Literal::LitInt(0)),
            // 7: root
            CoreFrame::Case {
                scrutinee: 2,
                binder: VarId(12),
                alts: vec![
                    Alt {
                        con: AltCon::DataAlt(DataConId(1)),
                        binders: vec![VarId(10), VarId(11)],
                        body: 5,
                    },
                    Alt {
                        con: AltCon::Default,
                        binders: vec![],
                        body: 6,
                    },
                ],
            },
        ];
        let mut expr = CoreExpr { nodes };

        let mut heap_before = VecHeap::new();
        let before = tidepool_eval::eval(&expr, &Env::new(), &mut heap_before).unwrap();

        CaseReduce.run(&mut expr);

        let mut heap_after = VecHeap::new();
        let after = tidepool_eval::eval(&expr, &Env::new(), &mut heap_after).unwrap();

        match (before, after) {
            (Value::Lit(l1), Value::Lit(l2)) => assert_eq!(l1, l2),
            (Value::Con(t1, f1), Value::Con(t2, f2)) => {
                assert_eq!(t1, t2);
                assert_eq!(f1.len(), f2.len());
                // Only literal fields are compared; other field values are skipped.
                for pair in f1.iter().zip(f2.iter()) {
                    if let (Value::Lit(ll1), Value::Lit(ll2)) = pair {
                        assert_eq!(ll1, ll2);
                    }
                }
            }
            (v1, v2) => panic!(
                "Value mismatch or unsupported for eval check: {:?}, {:?}",
                v1, v2
            ),
        }
    }
}
391}