Skip to main content

tidepool_codegen/emit/
case.rs

1use crate::emit::expr::{ensure_heap_ptr, force_thunk_ssaval};
2use crate::emit::*;
3use cranelift_codegen::ir::{
4    self, condcodes::IntCC, types, AbiParam, BlockArg, InstBuilder, MemFlags, Signature, Value,
5};
6use cranelift_frontend::FunctionBuilder;
7use cranelift_module::{Linkage, Module};
8use tidepool_repr::{Alt, AltCon, Literal, VarId};
9
10/// Emit Case dispatch. The scrutinee has already been evaluated (stack-safe).
11#[allow(clippy::too_many_arguments)]
12pub fn emit_case(
13    ctx: &mut EmitContext,
14    sess: &mut EmitSession,
15    builder: &mut FunctionBuilder,
16    scrut: SsaVal,
17    binder: &VarId,
18    alts: &[Alt<usize>],
19    tail: TailCtx,
20) -> Result<SsaVal, EmitError> {
21    // 1. Scrutinee already evaluated
22    let scrut_ptr = scrut.value();
23
24    // 2. Bind case binder (save old value for restore)
25    // NOTE: EnvGuard cannot be used here because it would borrow ctx.env mutably,
26    // preventing the use of ctx in subsequent emit_* calls.
27    let old_case_binder = ctx.env.insert(*binder, scrut);
28
29    // 3. Classify alts
30    let mut data_alts = Vec::new();
31    let mut lit_alts = Vec::new();
32    let mut default_alt = None;
33
34    for alt in alts {
35        match &alt.con {
36            AltCon::DataAlt(_) => data_alts.push(alt),
37            AltCon::LitAlt(_) => lit_alts.push(alt),
38            AltCon::Default => default_alt = Some(alt),
39        }
40    }
41
42    // 4. Create merge block
43    let merge_block = builder.create_block();
44    builder.append_block_param(merge_block, types::I64);
45
46    // 5. Dispatch
47    if !data_alts.is_empty() {
48        emit_data_dispatch(
49            ctx,
50            sess,
51            builder,
52            scrut_ptr,
53            &data_alts,
54            default_alt,
55            merge_block,
56            tail,
57        )?;
58    } else if !lit_alts.is_empty() {
59        emit_lit_dispatch(
60            ctx,
61            sess,
62            builder,
63            scrut,
64            &lit_alts,
65            default_alt,
66            merge_block,
67            tail,
68        )?;
69    } else if let Some(alt) = default_alt {
70        // Default only
71        let result = ctx.emit_node(sess, builder, alt.body, tail)?;
72        let result_ptr = ensure_heap_ptr(builder, sess.vmctx, sess.gc_sig, sess.oom_func, result);
73        builder
74            .ins()
75            .jump(merge_block, &[BlockArg::Value(result_ptr)]);
76    } else {
77        // No alts? Call runtime_case_trap to handle pending errors gracefully.
78        emit_case_trap(sess, builder, scrut_ptr, &[], merge_block)?;
79    }
80
81    // Seal merge block
82    builder.seal_block(merge_block);
83
84    // Switch to merge block
85    builder.switch_to_block(merge_block);
86    let result = builder.block_params(merge_block)[0];
87    builder.declare_value_needs_stack_map(result);
88    ctx.declare_env(builder);
89
90    // 6. Restore case binder
91    ctx.env.restore(*binder, old_case_binder);
92
93    Ok(SsaVal::HeapPtr(result))
94}
95
96#[allow(clippy::too_many_arguments)]
97fn emit_data_dispatch(
98    ctx: &mut EmitContext,
99    sess: &mut EmitSession,
100    builder: &mut FunctionBuilder,
101    initial_scrut_ptr: Value,
102    data_alts: &[&Alt<usize>],
103    default_alt: Option<&Alt<usize>>,
104    merge_block: ir::Block,
105    tail: TailCtx,
106) -> Result<(), EmitError> {
107    // 1. Force if needed (tag < 2: Closure or Thunk)
108    let tag = builder
109        .ins()
110        .load(types::I8, MemFlags::trusted(), initial_scrut_ptr, 0);
111    let needs_force = builder.ins().icmp_imm(IntCC::UnsignedLessThan, tag, 2);
112
113    let force_block = builder.create_block();
114    let dispatch_block = builder.create_block();
115    builder.append_block_param(dispatch_block, types::I64);
116
117    builder.ins().brif(
118        needs_force,
119        force_block,
120        &[],
121        dispatch_block,
122        &[BlockArg::Value(initial_scrut_ptr)],
123    );
124
125    // Force block: call host_fns::heap_force
126    builder.switch_to_block(force_block);
127    builder.seal_block(force_block);
128
129    let force_fn = sess
130        .pipeline
131        .module
132        .declare_function("heap_force", Linkage::Import, &{
133            let mut sig = Signature::new(sess.pipeline.isa.default_call_conv());
134            sig.params.push(AbiParam::new(types::I64)); // vmctx
135            sig.params.push(AbiParam::new(types::I64)); // thunk
136            sig.returns.push(AbiParam::new(types::I64)); // result
137            sig
138        })
139        .map_err(|e| EmitError::CraneliftError(e.to_string()))?;
140    let force_ref = sess
141        .pipeline
142        .module
143        .declare_func_in_func(force_fn, builder.func);
144
145    let call = builder
146        .ins()
147        .call(force_ref, &[sess.vmctx, initial_scrut_ptr]);
148    let force_result = builder.inst_results(call)[0];
149    builder.declare_value_needs_stack_map(force_result);
150    builder
151        .ins()
152        .jump(dispatch_block, &[BlockArg::Value(force_result)]);
153
154    // Dispatch block: actual pattern matching starts here
155    builder.switch_to_block(dispatch_block);
156    builder.seal_block(dispatch_block);
157    let scrut_ptr = builder.block_params(dispatch_block)[0];
158    builder.declare_value_needs_stack_map(scrut_ptr);
159
160    // Load con_tag as u64 from offset 8
161    let con_tag = builder
162        .ins()
163        .load(types::I64, MemFlags::trusted(), scrut_ptr, CON_TAG_OFFSET);
164
165    // Use comparison chain instead of jump table because DataConIds are large
166    // GHC Uniques (arbitrary u64 values), not small sequential integers.
167    for &alt in data_alts {
168        if let AltCon::DataAlt(tag) = &alt.con {
169            let alt_block = builder.create_block();
170            let next_check_block = builder.create_block();
171
172            let tag_val = builder.ins().iconst(types::I64, tag.0 as i64);
173            let eq = builder.ins().icmp(IntCC::Equal, con_tag, tag_val);
174            builder
175                .ins()
176                .brif(eq, alt_block, &[], next_check_block, &[]);
177
178            // Emit alt body
179            builder.switch_to_block(alt_block);
180            builder.seal_block(alt_block);
181            ctx.declare_env(builder);
182
183            // Bind pattern variables — do NOT force thunked fields.
184            // In Haskell, case alt binders are lazy. Thunked Con fields
185            // remain as thunks until used in a strict context (case scrutiny,
186            // primop args, etc.). Forcing here causes infinite loops for
187            // self-referencing structures like `xs = 1 : map (+1) xs`.
188            //
189            // INVARIANT: All strict consumers must force thunked values before
190            // reading heap layout. The forcing points are:
191            //   - emit_lit_dispatch: force_thunk_ssaval on scrutinee
192            //   - emit_data_dispatch: tag < 2 check → heap_force on scrutinee
193            //   - PrimOp collapse: force_thunk_ssaval on all args
194            //   - App collapse: tag check → heap_force on fun position
195            //   - unbox_int/unbox_double/unbox_float: defensive trap on TAG_THUNK
196            // See force_thunk_ssaval in expr.rs.
197            let mut scope = EnvScope::new();
198            // NOTE: EnvGuard cannot be used here because it would borrow ctx.env
199            // mutably, preventing the use of ctx in emit_node.
200            for (i, &binder) in alt.binders.iter().enumerate() {
201                let offset = CON_FIELDS_OFFSET + (8 * i as i32);
202                let field_val =
203                    builder
204                        .ins()
205                        .load(types::I64, MemFlags::trusted(), scrut_ptr, offset);
206                builder.declare_value_needs_stack_map(field_val);
207                ctx.env
208                    .insert_scoped(&mut scope, binder, SsaVal::HeapPtr(field_val));
209            }
210
211            let result = ctx.emit_node(sess, builder, alt.body, tail)?;
212            let result_ptr =
213                ensure_heap_ptr(builder, sess.vmctx, sess.gc_sig, sess.oom_func, result);
214            builder
215                .ins()
216                .jump(merge_block, &[BlockArg::Value(result_ptr)]);
217
218            // Restore pattern variable bindings
219            ctx.env.restore_scope(scope);
220
221            // Continue to next check
222            builder.switch_to_block(next_check_block);
223            builder.seal_block(next_check_block);
224        }
225    }
226
227    // Default or trap
228    if let Some(alt) = default_alt {
229        ctx.declare_env(builder);
230        let result = ctx.emit_node(sess, builder, alt.body, tail)?;
231        let result_ptr = ensure_heap_ptr(builder, sess.vmctx, sess.gc_sig, sess.oom_func, result);
232        builder
233            .ins()
234            .jump(merge_block, &[BlockArg::Value(result_ptr)]);
235    } else {
236        emit_case_trap(sess, builder, scrut_ptr, data_alts, merge_block)?;
237    }
238
239    Ok(())
240}
241
242/// Emit a call to `runtime_case_trap` instead of a bare `trap user2`.
243/// Passes the scrutinee pointer and expected alt tags for diagnostic output.
244fn emit_case_trap(
245    sess: &mut EmitSession,
246    builder: &mut FunctionBuilder,
247    scrut_ptr: Value,
248    data_alts: &[&Alt<usize>],
249    merge_block: ir::Block,
250) -> Result<(), EmitError> {
251    // Collect expected tags
252    let tags: Vec<u64> = data_alts
253        .iter()
254        .filter_map(|alt| {
255            if let AltCon::DataAlt(tag) = &alt.con {
256                Some(tag.0)
257            } else {
258                None
259            }
260        })
261        .collect();
262
263    // Store tags on stack
264    let num_alts = tags.len();
265    let ss = builder.create_sized_stack_slot(ir::StackSlotData::new(
266        ir::StackSlotKind::ExplicitSlot,
267        (num_alts * 8) as u32,
268        3, // align 8
269    ));
270    for (i, &tag) in tags.iter().enumerate() {
271        let tag_val = builder.ins().iconst(types::I64, tag as i64);
272        builder.ins().stack_store(tag_val, ss, (i * 8) as i32);
273    }
274    let tags_addr = builder.ins().stack_addr(types::I64, ss, 0);
275
276    let trap_fn = sess
277        .pipeline
278        .module
279        .declare_function("runtime_case_trap", Linkage::Import, &{
280            let mut sig = Signature::new(sess.pipeline.isa.default_call_conv());
281            sig.params.push(AbiParam::new(types::I64)); // scrut_ptr
282            sig.params.push(AbiParam::new(types::I64)); // num_alts
283            sig.params.push(AbiParam::new(types::I64)); // alt_tags
284            sig.returns.push(AbiParam::new(types::I64)); // returns poison ptr
285            sig
286        })
287        .map_err(|e| EmitError::CraneliftError(e.to_string()))?;
288    let trap_ref = sess
289        .pipeline
290        .module
291        .declare_func_in_func(trap_fn, builder.func);
292    let num_alts_val = builder.ins().iconst(types::I64, num_alts as i64);
293    let call = builder
294        .ins()
295        .call(trap_ref, &[scrut_ptr, num_alts_val, tags_addr]);
296    let result = builder.inst_results(call)[0];
297    builder.ins().jump(merge_block, &[BlockArg::Value(result)]);
298    Ok(())
299}
300
301#[allow(clippy::too_many_arguments)]
302fn emit_lit_dispatch(
303    ctx: &mut EmitContext,
304    sess: &mut EmitSession,
305    builder: &mut FunctionBuilder,
306    scrut: SsaVal,
307    lit_alts: &[&Alt<usize>],
308    default_alt: Option<&Alt<usize>>,
309    merge_block: ir::Block,
310    tail: TailCtx,
311) -> Result<(), EmitError> {
312    // Force thunked scrutinees: literal case dispatch is strict —
313    // ThunkCon fields extracted by data alt matching may still be thunks.
314    let scrut = force_thunk_ssaval(sess.pipeline, builder, sess.vmctx, scrut)?;
315
316    // Unbox scrutinee: Raw values are already unboxed, HeapPtr needs LIT_VALUE_OFFSET load
317    let scrut_value = match scrut {
318        SsaVal::Raw(v, _) => v,
319        SsaVal::HeapPtr(ptr) => {
320            builder
321                .ins()
322                .load(types::I64, MemFlags::trusted(), ptr, LIT_VALUE_OFFSET)
323        }
324    };
325
326    for &alt in lit_alts {
327        let alt_block = builder.create_block();
328        let next_check_block = builder.create_block();
329
330        if let AltCon::LitAlt(lit) = &alt.con {
331            match lit {
332                Literal::LitInt(n) => {
333                    let lit_val = builder.ins().iconst(types::I64, *n);
334                    let eq = builder.ins().icmp(IntCC::Equal, scrut_value, lit_val);
335                    builder
336                        .ins()
337                        .brif(eq, alt_block, &[], next_check_block, &[]);
338                }
339                Literal::LitWord(n) => {
340                    let lit_val = builder.ins().iconst(types::I64, *n as i64);
341                    let eq = builder.ins().icmp(IntCC::Equal, scrut_value, lit_val);
342                    builder
343                        .ins()
344                        .brif(eq, alt_block, &[], next_check_block, &[]);
345                }
346                Literal::LitChar(c) => {
347                    let lit_val = builder.ins().iconst(types::I64, *c as i64);
348                    let eq = builder.ins().icmp(IntCC::Equal, scrut_value, lit_val);
349                    builder
350                        .ins()
351                        .brif(eq, alt_block, &[], next_check_block, &[]);
352                }
353                Literal::LitFloat(bits) => {
354                    let scrut_f64 = builder.ins().bitcast(
355                        types::F64,
356                        MemFlags::new().with_endianness(ir::Endianness::Little),
357                        scrut_value,
358                    );
359                    let lit_val = builder.ins().f64const(f64::from_bits(*bits));
360                    let eq = builder
361                        .ins()
362                        .fcmp(ir::condcodes::FloatCC::Equal, scrut_f64, lit_val);
363                    builder
364                        .ins()
365                        .brif(eq, alt_block, &[], next_check_block, &[]);
366                }
367                Literal::LitDouble(bits) => {
368                    let scrut_f64 = builder.ins().bitcast(
369                        types::F64,
370                        MemFlags::new().with_endianness(ir::Endianness::Little),
371                        scrut_value,
372                    );
373                    let lit_val = builder.ins().f64const(f64::from_bits(*bits));
374                    let eq = builder
375                        .ins()
376                        .fcmp(ir::condcodes::FloatCC::Equal, scrut_f64, lit_val);
377                    builder
378                        .ins()
379                        .brif(eq, alt_block, &[], next_check_block, &[]);
380                }
381                Literal::LitString(_) => {
382                    return Err(EmitError::NotYetImplemented("LitString in Case".into()))
383                }
384            }
385        }
386
387        // Emit alt body
388        builder.switch_to_block(alt_block);
389        builder.seal_block(alt_block);
390        ctx.declare_env(builder);
391        let result = ctx.emit_node(sess, builder, alt.body, tail)?;
392        let result_ptr = ensure_heap_ptr(builder, sess.vmctx, sess.gc_sig, sess.oom_func, result);
393        builder
394            .ins()
395            .jump(merge_block, &[BlockArg::Value(result_ptr)]);
396
397        // Continue to next check
398        builder.switch_to_block(next_check_block);
399        builder.seal_block(next_check_block);
400    }
401
402    // Default or trap
403    if let Some(alt) = default_alt {
404        ctx.declare_env(builder);
405        let result = ctx.emit_node(sess, builder, alt.body, tail)?;
406        let result_ptr = ensure_heap_ptr(builder, sess.vmctx, sess.gc_sig, sess.oom_func, result);
407        builder
408            .ins()
409            .jump(merge_block, &[BlockArg::Value(result_ptr)]);
410    } else {
411        // No alts matched.
412        // We pass empty data_alts since these are lit alts.
413        emit_case_trap(sess, builder, scrut_value, &[], merge_block)?;
414    }
415
416    Ok(())
417}