Skip to main content

tidepool_optimize/
case_reduce.rs

1use tidepool_eval::{Changed, Pass};
2use tidepool_repr::{AltCon, CoreExpr, CoreFrame, MapLayer};
3use std::collections::HashMap;
4
5/// A pass that performs case-of-known-constructor and case-of-known-literal reductions.
6pub struct CaseReduce;
7
8impl Pass for CaseReduce {
9    fn run(&self, expr: &mut CoreExpr) -> Changed {
10        if expr.nodes.is_empty() {
11            return false;
12        }
13        match try_case_reduce(expr) {
14            Some(new_expr) => {
15                *expr = new_expr;
16                true
17            }
18            None => false,
19        }
20    }
21
22    fn name(&self) -> &str {
23        "CaseReduce"
24    }
25}
26
27fn try_case_reduce(expr: &CoreExpr) -> Option<CoreExpr> {
28    try_case_reduce_at(expr, expr.nodes.len() - 1)
29}
30
31fn try_case_reduce_at(expr: &CoreExpr, idx: usize) -> Option<CoreExpr> {
32    match &expr.nodes[idx] {
33        CoreFrame::Case {
34            scrutinee,
35            binder,
36            alts,
37        } => {
38            match &expr.nodes[*scrutinee] {
39                CoreFrame::Con { tag, fields } => {
40                    // Find matching DataAlt or Default
41                    let alt = alts
42                        .iter()
43                        .find(|a| matches!(&a.con, AltCon::DataAlt(t) if t == tag))
44                        .or_else(|| alts.iter().find(|a| matches!(&a.con, AltCon::Default)));
45
46                    if let Some(alt) = alt {
47                        // Arity check for DataAlt: binders must match fields.
48                        // If mismatch, skip this reduction (malformed IR).
49                        if let AltCon::DataAlt(_) = &alt.con {
50                            if alt.binders.len() != fields.len() {
51                                return try_children(expr, idx);
52                            }
53                        }
54
55                        let mut body = extract_subtree(expr, alt.body);
56                        // Bind fields to alt binders
57                        if let AltCon::DataAlt(_) = &alt.con {
58                            for (alt_binder, field_idx) in alt.binders.iter().zip(fields.iter()) {
59                                let field_tree = extract_subtree(expr, *field_idx);
60                                body = tidepool_repr::subst::subst(&body, *alt_binder, &field_tree);
61                            }
62                        }
63                        // Substitute case binder with scrutinee
64                        let scrut_tree = extract_subtree(expr, *scrutinee);
65                        body = tidepool_repr::subst::subst(&body, *binder, &scrut_tree);
66                        Some(replace_subtree(expr, idx, &body))
67                    } else {
68                        // No matching alt — try children
69                        try_children(expr, idx)
70                    }
71                }
72                CoreFrame::Lit(lit) => {
73                    let alt = alts
74                        .iter()
75                        .find(|a| matches!(&a.con, AltCon::LitAlt(l) if l == lit))
76                        .or_else(|| alts.iter().find(|a| matches!(&a.con, AltCon::Default)));
77
78                    if let Some(alt) = alt {
79                        let mut body = extract_subtree(expr, alt.body);
80                        // Substitute case binder with scrutinee literal
81                        let scrut_tree = extract_subtree(expr, *scrutinee);
82                        body = tidepool_repr::subst::subst(&body, *binder, &scrut_tree);
83                        Some(replace_subtree(expr, idx, &body))
84                    } else {
85                        try_children(expr, idx)
86                    }
87                }
88                _ => try_children(expr, idx),
89            }
90        }
91        _ => try_children(expr, idx),
92    }
93}
94
95fn try_children(expr: &CoreExpr, idx: usize) -> Option<CoreExpr> {
96    let children = get_children(&expr.nodes[idx]);
97    for child in children {
98        if let Some(result) = try_case_reduce_at(expr, child) {
99            return Some(result);
100        }
101    }
102    None
103}
104
105fn get_children(frame: &CoreFrame<usize>) -> Vec<usize> {
106    match frame {
107        CoreFrame::Var(_) | CoreFrame::Lit(_) => vec![],
108        CoreFrame::App { fun, arg } => vec![*fun, *arg],
109        CoreFrame::Lam { body, .. } => vec![*body],
110        CoreFrame::LetNonRec { rhs, body, .. } => vec![*rhs, *body],
111        CoreFrame::LetRec { bindings, body, .. } => {
112            let mut c: Vec<usize> = bindings.iter().map(|(_, r)| *r).collect();
113            c.push(*body);
114            c
115        }
116        CoreFrame::Case {
117            scrutinee, alts, ..
118        } => {
119            let mut c = vec![*scrutinee];
120            for alt in alts {
121                c.push(alt.body);
122            }
123            c
124        }
125        CoreFrame::Con { fields, .. } => fields.clone(),
126        CoreFrame::Join { rhs, body, .. } => vec![*rhs, *body],
127        CoreFrame::Jump { args, .. } => args.clone(),
128        CoreFrame::PrimOp { args, .. } => args.clone(),
129    }
130}
131
132fn extract_subtree(expr: &CoreExpr, root_idx: usize) -> CoreExpr {
133    let mut new_nodes = Vec::new();
134    let mut old_to_new = HashMap::new();
135    collect(root_idx, expr, &mut new_nodes, &mut old_to_new);
136    CoreExpr { nodes: new_nodes }
137}
138
139fn collect(
140    idx: usize,
141    expr: &CoreExpr,
142    new_nodes: &mut Vec<CoreFrame<usize>>,
143    old_to_new: &mut HashMap<usize, usize>,
144) -> usize {
145    if let Some(&new_idx) = old_to_new.get(&idx) {
146        return new_idx;
147    }
148    let mapped = expr.nodes[idx]
149        .clone()
150        .map_layer(|child| collect(child, expr, new_nodes, old_to_new));
151    let new_idx = new_nodes.len();
152    new_nodes.push(mapped);
153    old_to_new.insert(idx, new_idx);
154    new_idx
155}
156
157fn replace_subtree(expr: &CoreExpr, target_idx: usize, replacement: &CoreExpr) -> CoreExpr {
158    let mut new_nodes = Vec::new();
159    let mut old_to_new = HashMap::new();
160    rebuild(
161        expr,
162        expr.nodes.len() - 1,
163        target_idx,
164        replacement,
165        &mut new_nodes,
166        &mut old_to_new,
167    );
168    CoreExpr { nodes: new_nodes }
169}
170
171fn rebuild(
172    expr: &CoreExpr,
173    idx: usize,
174    target: usize,
175    replacement: &CoreExpr,
176    new_nodes: &mut Vec<CoreFrame<usize>>,
177    old_to_new: &mut HashMap<usize, usize>,
178) -> usize {
179    if let Some(&ni) = old_to_new.get(&idx) {
180        return ni;
181    }
182    if idx == target {
183        let offset = new_nodes.len();
184        for node in &replacement.nodes {
185            new_nodes.push(node.clone().map_layer(|i| i + offset));
186        }
187        let root = new_nodes.len() - 1;
188        old_to_new.insert(idx, root);
189        return root;
190    }
191    let mapped = expr.nodes[idx]
192        .clone()
193        .map_layer(|child| rebuild(expr, child, target, replacement, new_nodes, old_to_new));
194    let new_idx = new_nodes.len();
195    new_nodes.push(mapped);
196    old_to_new.insert(idx, new_idx);
197    new_idx
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use tidepool_eval::env::Env;
    use tidepool_eval::heap::VecHeap;
    use tidepool_eval::value::Value;
    use tidepool_repr::{Alt, DataConId, Literal, PrimOpKind, VarId};

    /// Root frame (the last node) of a non-empty expression.
    fn root(expr: &CoreExpr) -> &CoreFrame<usize> {
        expr.nodes.last().expect("expression must not be empty")
    }

    #[test]
    fn test_case_known_con() {
        // case Con(tag=1, [42]) of w { DataAlt(1) [y] -> y }
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(42)), // 0
            CoreFrame::Con {
                tag: DataConId(1),
                fields: vec![0],
            }, // 1
            CoreFrame::Var(VarId(3)), // 2: y
            CoreFrame::Case {
                scrutinee: 1,
                binder: VarId(2), // w
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![VarId(3)],
                    body: 2,
                }],
            }, // 3
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;
        let changed = pass.run(&mut expr);
        assert!(changed);
        // Result should be Lit(42)
        assert_eq!(expr.nodes.len(), 1);
        assert!(matches!(expr.nodes[0], CoreFrame::Lit(Literal::LitInt(42))));
    }

    #[test]
    fn test_case_known_con_pair() {
        // case Con(tag=1, [1, 2]) of w { DataAlt(1) [a, b] -> PrimOp(IntAdd, [a, b]) }
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(1)), // 0
            CoreFrame::Lit(Literal::LitInt(2)), // 1
            CoreFrame::Con {
                tag: DataConId(1),
                fields: vec![0, 1],
            }, // 2
            CoreFrame::Var(VarId(10)), // 3: a
            CoreFrame::Var(VarId(11)), // 4: b
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![3, 4],
            }, // 5
            CoreFrame::Case {
                scrutinee: 2,
                binder: VarId(12),
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![VarId(10), VarId(11)],
                    body: 5,
                }],
            }, // 6
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;

        let mut heap = VecHeap::new();
        let val_before = tidepool_eval::eval(&expr, &Env::new(), &mut heap).unwrap();

        let changed = pass.run(&mut expr);
        assert!(changed);

        let mut heap2 = VecHeap::new();
        let val_after = tidepool_eval::eval(&expr, &Env::new(), &mut heap2).unwrap();

        match (val_before, val_after) {
            (Value::Lit(l1), Value::Lit(l2)) => {
                assert_eq!(l1, l2);
                if let Literal::LitInt(3) = l1 {
                    // OK
                } else {
                    panic!("Expected 3, got {:?}", l1);
                }
            }
            (v1, v2) => panic!("Value mismatch or not Lit: {:?}, {:?}", v1, v2),
        }
    }

    #[test]
    fn test_case_known_lit() {
        // case 3 of w { LitAlt(1) -> 10; LitAlt(3) -> 30; Default -> 99 }
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(3)),  // 0
            CoreFrame::Lit(Literal::LitInt(10)), // 1
            CoreFrame::Lit(Literal::LitInt(30)), // 2
            CoreFrame::Lit(Literal::LitInt(99)), // 3
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(10),
                alts: vec![
                    Alt {
                        con: AltCon::LitAlt(Literal::LitInt(1)),
                        binders: vec![],
                        body: 1,
                    },
                    Alt {
                        con: AltCon::LitAlt(Literal::LitInt(3)),
                        binders: vec![],
                        body: 2,
                    },
                    Alt {
                        con: AltCon::Default,
                        binders: vec![],
                        body: 3,
                    },
                ],
            }, // 4
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;
        let changed = pass.run(&mut expr);
        assert!(changed);
        // Result should be 30
        assert!(matches!(root(&expr), CoreFrame::Lit(Literal::LitInt(30))));
    }

    #[test]
    fn test_case_known_lit_default() {
        // case 3 of w { LitAlt(1) -> 10; Default -> 99 }
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(3)),  // 0
            CoreFrame::Lit(Literal::LitInt(10)), // 1
            CoreFrame::Lit(Literal::LitInt(99)), // 2
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(10),
                alts: vec![
                    Alt {
                        con: AltCon::LitAlt(Literal::LitInt(1)),
                        binders: vec![],
                        body: 1,
                    },
                    Alt {
                        con: AltCon::Default,
                        binders: vec![],
                        body: 2,
                    },
                ],
            }, // 3
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;
        let changed = pass.run(&mut expr);
        assert!(changed);
        // Result should be 99
        assert!(matches!(root(&expr), CoreFrame::Lit(Literal::LitInt(99))));
    }

    #[test]
    fn test_case_known_con_falls_back_to_default() {
        // case Con(tag=2, [42]) of w { DataAlt(1) [y] -> 1; Default -> 99 }
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(42)), // 0
            CoreFrame::Con {
                tag: DataConId(2),
                fields: vec![0],
            }, // 1
            CoreFrame::Lit(Literal::LitInt(1)),  // 2
            CoreFrame::Lit(Literal::LitInt(99)), // 3
            CoreFrame::Case {
                scrutinee: 1,
                binder: VarId(5),
                alts: vec![
                    Alt {
                        con: AltCon::DataAlt(DataConId(1)),
                        binders: vec![VarId(6)],
                        body: 2,
                    },
                    Alt {
                        con: AltCon::Default,
                        binders: vec![],
                        body: 3,
                    },
                ],
            }, // 4
        ];
        let mut expr = CoreExpr { nodes };
        let changed = CaseReduce.run(&mut expr);
        assert!(changed);
        assert!(matches!(root(&expr), CoreFrame::Lit(Literal::LitInt(99))));
    }

    #[test]
    fn test_case_unknown_untouched() {
        // case Var(x) of w { Default -> 42 }
        let nodes = vec![
            CoreFrame::Var(VarId(1)),            // 0: x
            CoreFrame::Lit(Literal::LitInt(42)), // 1
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(2),
                alts: vec![Alt {
                    con: AltCon::Default,
                    binders: vec![],
                    body: 1,
                }],
            }, // 2
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;
        let changed = pass.run(&mut expr);
        assert!(!changed);
    }

    #[test]
    fn test_empty_expr_untouched() {
        let mut expr = CoreExpr { nodes: vec![] };
        assert!(!CaseReduce.run(&mut expr));
        assert!(expr.nodes.is_empty());
    }

    #[test]
    fn test_case_arity_mismatch_untouched() {
        // case Con(tag=1, [42]) of w { DataAlt(1) [a, b] -> a }  -- malformed: 1 field, 2 binders
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(42)), // 0
            CoreFrame::Con {
                tag: DataConId(1),
                fields: vec![0],
            }, // 1
            CoreFrame::Var(VarId(3)), // 2: a
            CoreFrame::Case {
                scrutinee: 1,
                binder: VarId(2),
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![VarId(3), VarId(4)],
                    body: 2,
                }],
            }, // 3
        ];
        let mut expr = CoreExpr { nodes };
        assert!(!CaseReduce.run(&mut expr));
        assert_eq!(expr.nodes.len(), 4);
    }

    #[test]
    fn test_nested_case_reduced() {
        // PrimOp(IntAdd, [case 3 of w { LitAlt(3) -> 30; Default -> 99 }, 1])
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(3)),  // 0
            CoreFrame::Lit(Literal::LitInt(30)), // 1
            CoreFrame::Lit(Literal::LitInt(99)), // 2
            CoreFrame::Case {
                scrutinee: 0,
                binder: VarId(10),
                alts: vec![
                    Alt {
                        con: AltCon::LitAlt(Literal::LitInt(3)),
                        binders: vec![],
                        body: 1,
                    },
                    Alt {
                        con: AltCon::Default,
                        binders: vec![],
                        body: 2,
                    },
                ],
            }, // 3
            CoreFrame::Lit(Literal::LitInt(1)), // 4
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![3, 4],
            }, // 5
        ];
        let mut expr = CoreExpr { nodes };

        let mut heap = VecHeap::new();
        let val_before = tidepool_eval::eval(&expr, &Env::new(), &mut heap).unwrap();

        assert!(CaseReduce.run(&mut expr));
        assert!(!expr
            .nodes
            .iter()
            .any(|n| matches!(n, CoreFrame::Case { .. })));

        let mut heap2 = VecHeap::new();
        let val_after = tidepool_eval::eval(&expr, &Env::new(), &mut heap2).unwrap();

        match (val_before, val_after) {
            (Value::Lit(l1), Value::Lit(l2)) => {
                assert_eq!(l1, l2);
                assert_eq!(l2, Literal::LitInt(31));
            }
            (v1, v2) => panic!("Value mismatch or not Lit: {:?}, {:?}", v1, v2),
        }
    }

    #[test]
    fn test_case_binder_substituted() {
        // case Con(tag=1, [42]) of w { DataAlt(1) [y] -> w }
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(42)), // 0
            CoreFrame::Con {
                tag: DataConId(1),
                fields: vec![0],
            }, // 1
            CoreFrame::Var(VarId(2)), // 2: w
            CoreFrame::Case {
                scrutinee: 1,
                binder: VarId(2), // w
                alts: vec![Alt {
                    con: AltCon::DataAlt(DataConId(1)),
                    binders: vec![VarId(3)],
                    body: 2,
                }],
            }, // 3
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;
        let changed = pass.run(&mut expr);
        assert!(changed);
        // Result should be Con(tag=1, [42])
        if let CoreFrame::Con { tag, fields } = root(&expr) {
            assert_eq!(tag.0, 1);
            assert_eq!(fields.len(), 1);
            if let CoreFrame::Lit(Literal::LitInt(42)) = &expr.nodes[fields[0]] {
                // OK
            } else {
                panic!("Expected field to be 42");
            }
        } else {
            panic!("Expected Con, got {:?}", root(&expr));
        }
    }

    #[test]
    fn test_case_reduce_preserves_eval() {
        // case Con(tag=1, [1, 2]) of w { DataAlt(1) [a, b] -> a + b; Default -> 0 }
        let nodes = vec![
            CoreFrame::Lit(Literal::LitInt(1)), // 0
            CoreFrame::Lit(Literal::LitInt(2)), // 1
            CoreFrame::Con {
                tag: DataConId(1),
                fields: vec![0, 1],
            }, // 2
            CoreFrame::Var(VarId(10)), // 3: a
            CoreFrame::Var(VarId(11)), // 4: b
            CoreFrame::PrimOp {
                op: PrimOpKind::IntAdd,
                args: vec![3, 4],
            }, // 5
            CoreFrame::Lit(Literal::LitInt(0)), // 6
            CoreFrame::Case {
                scrutinee: 2,
                binder: VarId(12),
                alts: vec![
                    Alt {
                        con: AltCon::DataAlt(DataConId(1)),
                        binders: vec![VarId(10), VarId(11)],
                        body: 5,
                    },
                    Alt {
                        con: AltCon::Default,
                        binders: vec![],
                        body: 6,
                    },
                ],
            }, // 7
        ];
        let mut expr = CoreExpr { nodes };
        let pass = CaseReduce;

        let mut heap = VecHeap::new();
        let val_before = tidepool_eval::eval(&expr, &Env::new(), &mut heap).unwrap();

        // Without this the test would pass vacuously if no reduction happened.
        let changed = pass.run(&mut expr);
        assert!(changed);

        let mut heap2 = VecHeap::new();
        let val_after = tidepool_eval::eval(&expr, &Env::new(), &mut heap2).unwrap();

        match (val_before, val_after) {
            (Value::Lit(l1), Value::Lit(l2)) => assert_eq!(l1, l2),
            (Value::Con(t1, f1), Value::Con(t2, f2)) => {
                assert_eq!(t1, t2);
                assert_eq!(f1.len(), f2.len());
                // Simple check for literals in fields
                for (v1, v2) in f1.iter().zip(f2.iter()) {
                    if let (Value::Lit(ll1), Value::Lit(ll2)) = (v1, v2) {
                        assert_eq!(ll1, ll2);
                    }
                }
            }
            (v1, v2) => panic!(
                "Value mismatch or unsupported for eval check: {:?}, {:?}",
                v1, v2
            ),
        }
    }
}