Skip to main content

tidepool_codegen/
datacon_env.rs

1use tidepool_repr::{CoreExpr, CoreFrame, DataConTable, TreeBuilder, VarId};
2
3/// Wrap a CoreExpr with let-bindings for all data constructors from the table.
4///
5/// For each DataCon with arity N:
6/// - arity 0: `let dc_var = Con(id, []) in ...`
7/// - arity 1: `let dc_var = \v0 -> Con(id, [v0]) in ...`
8/// - arity 2: `let dc_var = \v0 -> \v1 -> Con(id, [v0, v1]) in ...`
9/// - etc.
10///
11/// The binding VarId matches `VarId(dc.id.0)`, which is what the GHC Core translator
12/// uses to reference data constructors as function values.
13pub fn wrap_with_datacon_env(expr: &CoreExpr, table: &DataConTable) -> CoreExpr {
14    let mut b = TreeBuilder::new();
15
16    // First, push all nodes from the original expression
17    let mut src = TreeBuilder::new();
18    for node in &expr.nodes {
19        src.push(node.clone());
20    }
21    let base = b.push_tree(src);
22    let root = base + expr.nodes.len() - 1;
23
24    // Collect datacons sorted by id for deterministic output
25    let mut datacons: Vec<_> = table.iter().collect();
26    datacons.sort_by_key(|dc| dc.id.0);
27
28    let mut body = root;
29
30    for dc in &datacons {
31        let binder = VarId(dc.id.0);
32        let arity = dc.rep_arity as usize;
33
34        if arity == 0 {
35            // Con(id, [])
36            let con = b.push(CoreFrame::Con {
37                tag: dc.id,
38                fields: vec![],
39            });
40            body = b.push(CoreFrame::LetNonRec {
41                binder,
42                rhs: con,
43                body,
44            });
45        } else {
46            // Build curried lambda chain: \v0 -> \v1 -> ... -> Con(id, [v0, v1, ...])
47            // Fresh vars use a hash of the DataConId to avoid collisions
48            let fresh_base = dc.id.0.wrapping_mul(0x517cc1b727220a95).wrapping_add(0xFFFF_0000_0000_0000);
49            let fresh_vars: Vec<VarId> = (0..arity)
50                .map(|i| VarId(fresh_base.wrapping_add(i as u64)))
51                .collect();
52
53            // Build Con(id, [v0, v1, ...]) — innermost
54            let var_indices: Vec<usize> = fresh_vars
55                .iter()
56                .map(|v| b.push(CoreFrame::Var(*v)))
57                .collect();
58            let mut inner = b.push(CoreFrame::Con {
59                tag: dc.id,
60                fields: var_indices,
61            });
62
63            // Wrap in lambdas from inside out
64            for v in fresh_vars.iter().rev() {
65                inner = b.push(CoreFrame::Lam {
66                    binder: *v,
67                    body: inner,
68                });
69            }
70
71            body = b.push(CoreFrame::LetNonRec {
72                binder,
73                rhs: inner,
74                body,
75            });
76        }
77    }
78
79    b.build()
80}