//! Code generation for Core `case` expressions.
//!
//! Source path: tidepool_codegen/emit/case.rs

1use crate::pipeline::CodegenPipeline;
2use crate::emit::*;
3use crate::emit::expr::ensure_heap_ptr;
4use tidepool_repr::{VarId, Alt, AltCon, Literal, CoreExpr};
5use cranelift_codegen::ir::{self, types, InstBuilder, MemFlags, Value, condcodes::IntCC, TrapCode};
6use cranelift_frontend::FunctionBuilder;
7
8/// Emit Case dispatch.
9pub fn emit_case(
10    ctx: &mut EmitContext,
11    pipeline: &mut CodegenPipeline,
12    builder: &mut FunctionBuilder,
13    vmctx: Value,
14    gc_sig: ir::SigRef,
15    tree: &CoreExpr,
16    scrutinee_idx: usize,
17    binder: &VarId,
18    alts: &[Alt<usize>],
19) -> Result<SsaVal, EmitError> {
20    // 1. Emit scrutinee
21    let scrut = ctx.emit_node(pipeline, builder, vmctx, gc_sig, tree, scrutinee_idx)?;
22    let scrut_ptr = scrut.value();
23
24    // 2. Bind case binder
25    ctx.env.insert(*binder, scrut);
26
27    // 3. Classify alts
28    let mut data_alts = Vec::new();
29    let mut lit_alts = Vec::new();
30    let mut default_alt = None;
31
32    for alt in alts {
33        match &alt.con {
34            AltCon::DataAlt(_) => data_alts.push(alt),
35            AltCon::LitAlt(_) => lit_alts.push(alt),
36            AltCon::Default => default_alt = Some(alt),
37        }
38    }
39
40    // 4. Create merge block
41    let merge_block = builder.create_block();
42    builder.append_block_param(merge_block, types::I64);
43
44    // 5. Dispatch
45    if !data_alts.is_empty() {
46        emit_data_dispatch(
47            ctx, pipeline, builder, vmctx, gc_sig, tree, scrut_ptr, &data_alts, default_alt, merge_block,
48        )?;
49    } else if !lit_alts.is_empty() {
50        emit_lit_dispatch(
51            ctx, pipeline, builder, vmctx, gc_sig, tree, scrut, &lit_alts, default_alt, merge_block,
52        )?;
53    } else if let Some(alt) = default_alt {
54        // Default only
55        let result = ctx.emit_node(pipeline, builder, vmctx, gc_sig, tree, alt.body)?;
56        let result_ptr = ensure_heap_ptr(builder, vmctx, gc_sig, result);
57        builder.ins().jump(merge_block, &[result_ptr]);
58    } else {
59        // No alts? Trap.
60        builder.ins().trap(TrapCode::unwrap_user(2));
61    }
62
63    // Seal merge block
64    builder.seal_block(merge_block);
65
66    // Switch to merge block
67    builder.switch_to_block(merge_block);
68    let result = builder.block_params(merge_block)[0];
69    builder.declare_value_needs_stack_map(result);
70
71    // 6. Clean up case binder
72    ctx.env.remove(binder);
73
74    Ok(SsaVal::HeapPtr(result))
75}
76
77fn emit_data_dispatch(
78    ctx: &mut EmitContext,
79    pipeline: &mut CodegenPipeline,
80    builder: &mut FunctionBuilder,
81    vmctx: Value,
82    gc_sig: ir::SigRef,
83    tree: &CoreExpr,
84    scrut_ptr: Value,
85    data_alts: &[&Alt<usize>],
86    default_alt: Option<&Alt<usize>>,
87    merge_block: ir::Block,
88) -> Result<(), EmitError> {
89    // Load con_tag as u64 from offset 8
90    let con_tag = builder.ins().load(types::I64, MemFlags::trusted(), scrut_ptr, CON_TAG_OFFSET);
91
92    // Use comparison chain instead of jump table because DataConIds are large
93    // GHC Uniques (arbitrary u64 values), not small sequential integers.
94    for &alt in data_alts {
95        if let AltCon::DataAlt(tag) = &alt.con {
96            let alt_block = builder.create_block();
97            let next_check_block = builder.create_block();
98
99            let tag_val = builder.ins().iconst(types::I64, tag.0 as i64);
100            let eq = builder.ins().icmp(IntCC::Equal, con_tag, tag_val);
101            builder.ins().brif(eq, alt_block, &[], next_check_block, &[]);
102
103            // Emit alt body
104            builder.switch_to_block(alt_block);
105            builder.seal_block(alt_block);
106
107            // Bind pattern variables
108            let mut bound_vars = Vec::new();
109            for (i, &binder) in alt.binders.iter().enumerate() {
110                let offset = CON_FIELDS_START + (8 * i as i32);
111                let field_val = builder.ins().load(types::I64, MemFlags::trusted(), scrut_ptr, offset);
112                builder.declare_value_needs_stack_map(field_val);
113                ctx.env.insert(binder, SsaVal::HeapPtr(field_val));
114                bound_vars.push(binder);
115            }
116
117            let result = ctx.emit_node(pipeline, builder, vmctx, gc_sig, tree, alt.body)?;
118            let result_ptr = ensure_heap_ptr(builder, vmctx, gc_sig, result);
119            builder.ins().jump(merge_block, &[result_ptr]);
120
121            // Clean up
122            for binder in bound_vars {
123                ctx.env.remove(&binder);
124            }
125
126            // Continue to next check
127            builder.switch_to_block(next_check_block);
128            builder.seal_block(next_check_block);
129        }
130    }
131
132    // Default or trap
133    if let Some(alt) = default_alt {
134        let result = ctx.emit_node(pipeline, builder, vmctx, gc_sig, tree, alt.body)?;
135        let result_ptr = ensure_heap_ptr(builder, vmctx, gc_sig, result);
136        builder.ins().jump(merge_block, &[result_ptr]);
137    } else {
138        builder.ins().trap(TrapCode::unwrap_user(2));
139    }
140
141    Ok(())
142}
143
144fn emit_lit_dispatch(
145    ctx: &mut EmitContext,
146    pipeline: &mut CodegenPipeline,
147    builder: &mut FunctionBuilder,
148    vmctx: Value,
149    gc_sig: ir::SigRef,
150    tree: &CoreExpr,
151    scrut: SsaVal,
152    lit_alts: &[&Alt<usize>],
153    default_alt: Option<&Alt<usize>>,
154    merge_block: ir::Block,
155) -> Result<(), EmitError> {
156    // Unbox scrutinee: Raw values are already unboxed, HeapPtr needs LIT_VALUE_OFFSET load
157    let scrut_value = match scrut {
158        SsaVal::Raw(v, _) => v,
159        SsaVal::HeapPtr(ptr) => builder.ins().load(types::I64, MemFlags::trusted(), ptr, LIT_VALUE_OFFSET),
160    };
161
162    for &alt in lit_alts {
163        let alt_block = builder.create_block();
164        let next_check_block = builder.create_block();
165
166        if let AltCon::LitAlt(lit) = &alt.con {
167            match lit {
168                Literal::LitInt(n) => {
169                    let lit_val = builder.ins().iconst(types::I64, *n);
170                    let eq = builder.ins().icmp(IntCC::Equal, scrut_value, lit_val);
171                    builder.ins().brif(eq, alt_block, &[], next_check_block, &[]);
172                }
173                Literal::LitWord(n) => {
174                    let lit_val = builder.ins().iconst(types::I64, *n as i64);
175                    let eq = builder.ins().icmp(IntCC::Equal, scrut_value, lit_val);
176                    builder.ins().brif(eq, alt_block, &[], next_check_block, &[]);
177                }
178                Literal::LitChar(c) => {
179                    let lit_val = builder.ins().iconst(types::I64, *c as i64);
180                    let eq = builder.ins().icmp(IntCC::Equal, scrut_value, lit_val);
181                    builder.ins().brif(eq, alt_block, &[], next_check_block, &[]);
182                }
183                Literal::LitFloat(bits) => {
184                    let scrut_f64 = builder.ins().bitcast(types::F64, MemFlags::new().with_endianness(ir::Endianness::Little), scrut_value);
185                    let lit_val = builder.ins().f64const(f64::from_bits(*bits));
186                    let eq = builder.ins().fcmp(ir::condcodes::FloatCC::Equal, scrut_f64, lit_val);
187                    builder.ins().brif(eq, alt_block, &[], next_check_block, &[]);
188                }
189                Literal::LitDouble(bits) => {
190                    let scrut_f64 = builder.ins().bitcast(types::F64, MemFlags::new().with_endianness(ir::Endianness::Little), scrut_value);
191                    let lit_val = builder.ins().f64const(f64::from_bits(*bits));
192                    let eq = builder.ins().fcmp(ir::condcodes::FloatCC::Equal, scrut_f64, lit_val);
193                    builder.ins().brif(eq, alt_block, &[], next_check_block, &[]);
194                }
195                Literal::LitString(_) => return Err(EmitError::NotYetImplemented("LitString in Case".into())),
196            }
197        }
198
199        // Emit alt body
200        builder.switch_to_block(alt_block);
201        builder.seal_block(alt_block);
202        let result = ctx.emit_node(pipeline, builder, vmctx, gc_sig, tree, alt.body)?;
203        let result_ptr = ensure_heap_ptr(builder, vmctx, gc_sig, result);
204        builder.ins().jump(merge_block, &[result_ptr]);
205
206        // Continue to next check
207        builder.switch_to_block(next_check_block);
208        builder.seal_block(next_check_block);
209    }
210
211    // Default or trap
212    if let Some(alt) = default_alt {
213        let result = ctx.emit_node(pipeline, builder, vmctx, gc_sig, tree, alt.body)?;
214        let result_ptr = ensure_heap_ptr(builder, vmctx, gc_sig, result);
215        builder.ins().jump(merge_block, &[result_ptr]);
216    } else {
217        builder.ins().trap(TrapCode::unwrap_user(2));
218    }
219
220    Ok(())
221}