Skip to main content

shape_jit/translator/
osr_compiler.rs

//! OSR (On-Stack Replacement) Loop Compilation
//!
//! Compiles hot loop bodies to native code via Cranelift IR so that execution
//! can transfer mid-run from the bytecode interpreter to JIT-compiled code.
//!
//! # OSR ABI
//! `extern "C" fn(ctx_ptr: *mut u8, _unused: *const u8) -> u64`
//! - Returns 0 on normal loop exit (locals written back to ctx).
//! - Returns `u64::MAX` on deoptimization (locals partially written back).
//!
//! # Escape Analysis / Scalar Replacement
//! The escape analysis pass (Phase 5) identifies small non-escaping arrays
//! for scalar replacement in the whole-function JIT compiler. OSR compilation
//! does NOT support the NewArray/GetProp/SetLocalIndex opcodes (they fail the
//! preflight check in `is_osr_supported_opcode`), so scalar replacement does
//! not apply to OSR-compiled loop bodies. If OSR support for array opcodes is
//! added in the future, deopt materialization must reconstruct scalar-replaced
//! arrays from their SSA variable elements before writing locals back to ctx.

20use std::collections::{HashMap, HashSet};
21
22use cranelift::prelude::*;
23use cranelift_module::{Linkage, Module};
24
25use shape_vm::bytecode::{DeoptInfo, Instruction, OpCode, Operand, OsrEntryPoint};
26use shape_vm::type_tracking::{FrameDescriptor, SlotKind};
27
28use super::loop_analysis::LoopInfo;
29
30/// Result of compiling a loop body for OSR entry.
31#[derive(Debug)]
32pub struct OsrCompilationResult {
33    /// Native code pointer for the compiled loop body.
34    pub native_code: *const u8,
35    /// OSR entry point metadata (live locals, kinds, bytecode IPs).
36    pub entry_point: OsrEntryPoint,
37    /// Deopt info for all guard points within the compiled loop.
38    pub deopt_points: Vec<DeoptInfo>,
39}
40
41// SAFETY: native_code pointer is valid for the lifetime of the JIT compilation
42// and is only used within the VM execution context.
43unsafe impl Send for OsrCompilationResult {}
44
/// Number of local-variable slots in the JIT context buffer. The locals
/// region covers u64 indices 8..264, i.e. 256 slots.
const JIT_LOCALS_CAP: usize = 256;

/// Byte offset of local slot 0 in the JIT context buffer: locals begin at
/// u64 index 8, so the offset is eight 8-byte words.
const LOCALS_BYTE_OFFSET: i32 = 8 * std::mem::size_of::<u64>() as i32;
51
52/// Check whether an opcode is in the supported MVP set for OSR compilation.
53fn is_osr_supported_opcode(opcode: OpCode, operand: &Option<Operand>) -> bool {
54    use shape_vm::bytecode::BuiltinFunction as BF;
55    match opcode {
56        // Stack
57        OpCode::PushConst | OpCode::PushNull | OpCode::Pop | OpCode::Dup | OpCode::Swap => true,
58        // Variables
59        OpCode::LoadLocal
60        | OpCode::LoadLocalTrusted
61        | OpCode::StoreLocal
62        | OpCode::StoreLocalTyped => true,
63        OpCode::LoadModuleBinding
64        | OpCode::StoreModuleBinding
65        | OpCode::StoreModuleBindingTyped => true,
66        // Arithmetic (Int)
67        OpCode::AddInt
68        | OpCode::SubInt
69        | OpCode::MulInt
70        | OpCode::DivInt
71        | OpCode::ModInt
72        | OpCode::PowInt => true,
73        // Arithmetic (Number)
74        OpCode::AddNumber
75        | OpCode::SubNumber
76        | OpCode::MulNumber
77        | OpCode::DivNumber
78        | OpCode::ModNumber
79        | OpCode::PowNumber => true,
80        // Neg
81        OpCode::Neg => true,
82        // Comparison (Int)
83        OpCode::GtInt
84        | OpCode::LtInt
85        | OpCode::GteInt
86        | OpCode::LteInt
87        | OpCode::EqInt
88        | OpCode::NeqInt => true,
89        // Comparison (Number)
90        OpCode::GtNumber
91        | OpCode::LtNumber
92        | OpCode::GteNumber
93        | OpCode::LteNumber
94        | OpCode::EqNumber
95        | OpCode::NeqNumber => true,
96        // Logic
97        OpCode::And | OpCode::Or | OpCode::Not => true,
98        // Control
99        OpCode::Jump
100        | OpCode::JumpIfFalse
101        | OpCode::JumpIfFalseTrusted
102        | OpCode::JumpIfTrue
103        | OpCode::LoopStart
104        | OpCode::LoopEnd
105        | OpCode::Break
106        | OpCode::Continue => true,
107        // Coercion / width casts
108        OpCode::IntToNumber | OpCode::NumberToInt | OpCode::CastWidth => true,
109        // Return (mapped to loop exit)
110        OpCode::Return | OpCode::ReturnValue => true,
111        // Misc
112        OpCode::Nop | OpCode::Halt | OpCode::Debug => true,
113        // BuiltinCall: only selected math builtins
114        OpCode::BuiltinCall => {
115            if let Some(Operand::Builtin(bf)) = operand {
116                matches!(
117                    bf,
118                    BF::Abs
119                        | BF::Sqrt
120                        | BF::Min
121                        | BF::Max
122                        | BF::Floor
123                        | BF::Ceil
124                        | BF::Round
125                        | BF::Pow
126                )
127            } else {
128                false
129            }
130        }
131        _ => false,
132    }
133}
134
135/// Compile a loop body for OSR (On-Stack Replacement) entry.
136///
137/// Emits Cranelift IR for the loop body with the OSR ABI:
138/// `extern "C" fn(ctx_ptr: *mut u8, _: *const u8) -> u64`
139///
140/// - Returns 0 on normal loop exit (locals written back to ctx).
141/// - Returns `u64::MAX` on deoptimization (locals partially written back).
142///
143/// # Arguments
144/// * `jit` - The JIT compiler instance (owns the Cranelift module).
145/// * `function` - The function containing the target loop.
146/// * `instructions` - The full instruction stream of the function.
147/// * `loop_info` - Analysis results for the target loop (from `analyze_loops`).
148/// * `frame_descriptor` - Typed frame layout for slot marshaling.
149pub fn compile_osr_loop(
150    jit: &mut crate::compiler::JITCompiler,
151    function: &shape_vm::bytecode::Function,
152    instructions: &[Instruction],
153    loop_info: &LoopInfo,
154    frame_descriptor: &FrameDescriptor,
155) -> Result<OsrCompilationResult, String> {
156    // Validate the loop bounds are within the instruction stream.
157    if loop_info.header_idx >= instructions.len() {
158        return Err(format!(
159            "OSR loop header {} is out of bounds (instruction count: {})",
160            loop_info.header_idx,
161            instructions.len()
162        ));
163    }
164    if loop_info.end_idx >= instructions.len() {
165        return Err(format!(
166            "OSR loop end {} is out of bounds (instruction count: {})",
167            loop_info.end_idx,
168            instructions.len()
169        ));
170    }
171
172    // Preflight: reject loops containing unsupported opcodes.
173    for idx in loop_info.header_idx..=loop_info.end_idx {
174        let instr = &instructions[idx];
175        if !is_osr_supported_opcode(instr.opcode, &instr.operand) {
176            return Err(format!(
177                "OSR unsupported opcode {:?} at instruction {}",
178                instr.opcode, idx
179            ));
180        }
181    }
182
183    // Compute live locals: union of read and written sets.
184    let mut live_locals: Vec<u16> = loop_info
185        .body_locals_read
186        .union(&loop_info.body_locals_written)
187        .copied()
188        .collect();
189    live_locals.sort_unstable();
190
191    // Check all locals fit within JIT locals capacity.
192    for &local_idx in &live_locals {
193        if local_idx as usize >= JIT_LOCALS_CAP {
194            return Err(format!(
195                "OSR local index {} exceeds JIT_LOCALS_CAP ({})",
196                local_idx, JIT_LOCALS_CAP
197            ));
198        }
199    }
200
201    // Map each live local to its SlotKind from the frame descriptor.
202    let local_kinds: Vec<SlotKind> = live_locals
203        .iter()
204        .map(|&slot| {
205            frame_descriptor
206                .slots
207                .get(slot as usize)
208                .copied()
209                .unwrap_or(SlotKind::Unknown)
210        })
211        .collect();
212
213    let entry_point = OsrEntryPoint {
214        bytecode_ip: loop_info.header_idx,
215        live_locals: live_locals.clone(),
216        local_kinds: local_kinds.clone(),
217        exit_ip: loop_info.end_idx + 1,
218    };
219
220    // Body locals written — used for epilogue (only write back modified locals).
221    let body_locals_written: HashSet<u16> = loop_info.body_locals_written.clone();
222
223    // --- Cranelift compilation ---
224
225    // Declare the OSR function: (i64, i64) -> i64
226    let func_name = format!("osr_loop_f{}_ip{}", function.arity, loop_info.header_idx);
227    let mut sig = jit.module_mut().make_signature();
228    sig.params.push(AbiParam::new(types::I64)); // ctx_ptr
229    sig.params.push(AbiParam::new(types::I64)); // unused
230    sig.returns.push(AbiParam::new(types::I64)); // result (0 or u64::MAX)
231
232    let func_id = jit
233        .module_mut()
234        .declare_function(&func_name, Linkage::Export, &sig)
235        .map_err(|e| format!("Failed to declare OSR function: {}", e))?;
236
237    let mut ctx = cranelift::codegen::Context::new();
238    ctx.func.signature = sig;
239
240    {
241        let mut builder = FunctionBuilder::new(&mut ctx.func, jit.builder_context_mut());
242
243        // Create blocks
244        let entry_block = builder.create_block();
245        let exit_block = builder.create_block();
246        let deopt_block = builder.create_block();
247
248        // Pre-scan for jump targets inside the loop body to create blocks
249        let mut block_map: HashMap<usize, Block> = HashMap::new();
250        // The loop header gets its own block (this is the main loop block)
251        let header_block = builder.create_block();
252        block_map.insert(loop_info.header_idx, header_block);
253
254        for idx in loop_info.header_idx..=loop_info.end_idx {
255            let instr = &instructions[idx];
256            match instr.opcode {
257                OpCode::Jump
258                | OpCode::JumpIfFalse
259                | OpCode::JumpIfFalseTrusted
260                | OpCode::JumpIfTrue => {
261                    if let Some(Operand::Offset(off)) = instr.operand {
262                        let target = (idx as i64 + off as i64 + 1) as usize;
263                        if target >= loop_info.header_idx
264                            && target <= loop_info.end_idx + 1
265                            && !block_map.contains_key(&target)
266                        {
267                            let blk = builder.create_block();
268                            block_map.insert(target, blk);
269                        }
270                    }
271                }
272                _ => {}
273            }
274            // Also create a block for the instruction after a conditional branch
275            // (fall-through target)
276            match instr.opcode {
277                OpCode::JumpIfFalse | OpCode::JumpIfFalseTrusted | OpCode::JumpIfTrue => {
278                    let fall_through = idx + 1;
279                    if fall_through >= loop_info.header_idx
280                        && fall_through <= loop_info.end_idx
281                        && !block_map.contains_key(&fall_through)
282                    {
283                        let blk = builder.create_block();
284                        block_map.insert(fall_through, blk);
285                    }
286                }
287                _ => {}
288            }
289        }
290
291        // Declare Cranelift variables for all live locals
292        let max_local = live_locals.iter().copied().max().unwrap_or(0) as usize;
293        for local_idx in 0..=max_local {
294            builder.declare_var(Variable::new(local_idx), types::I64);
295        }
296        // Declare compile-time stack variables (generous upper bound)
297        let stack_var_base = JIT_LOCALS_CAP;
298        let max_stack_depth = 32usize;
299        for s in 0..max_stack_depth {
300            builder.declare_var(Variable::new(stack_var_base + s), types::I64);
301        }
302
303        // ---- Entry block: load live locals from JIT context buffer ----
304        builder.append_block_params_for_function_params(entry_block);
305        builder.switch_to_block(entry_block);
306        builder.seal_block(entry_block);
307
308        let ctx_ptr = builder.block_params(entry_block)[0];
309
310        // Load live locals from context buffer
311        for &local_idx in &live_locals {
312            let offset = LOCALS_BYTE_OFFSET + (local_idx as i32) * 8;
313            let val = builder
314                .ins()
315                .load(types::I64, MemFlags::trusted(), ctx_ptr, offset);
316            builder.def_var(Variable::new(local_idx as usize), val);
317        }
318
319        // Jump to loop header block
320        builder.ins().jump(header_block, &[]);
321
322        // ---- Compile loop body instructions ----
323        // Compile-time operand stack depth tracker
324        let mut stack_depth: usize = 0;
325        // Manual block termination tracking (replaces builder.is_filled())
326        let mut block_terminated: bool = false;
327
328        macro_rules! stack_push {
329            ($builder:expr, $val:expr, $depth:expr) => {{
330                let var = Variable::new(stack_var_base + $depth);
331                $builder.def_var(var, $val);
332                $depth += 1;
333            }};
334        }
335        macro_rules! stack_pop {
336            ($builder:expr, $depth:expr) => {{
337                $depth -= 1;
338                let var = Variable::new(stack_var_base + $depth);
339                $builder.use_var(var)
340            }};
341        }
342
343        for idx in loop_info.header_idx..=loop_info.end_idx {
344            // Switch to the block for this instruction if one exists
345            if let Some(&blk) = block_map.get(&idx) {
346                if idx != loop_info.header_idx || block_terminated {
347                    if !block_terminated {
348                        builder.ins().jump(blk, &[]);
349                    }
350                }
351                builder.switch_to_block(blk);
352                block_terminated = false;
353                // Don't seal loop header yet (it has a back-edge)
354                if idx != loop_info.header_idx {
355                    builder.seal_block(blk);
356                }
357            }
358
359            // Skip instruction emission if block already terminated
360            if block_terminated {
361                continue;
362            }
363
364            let instr = &instructions[idx];
365            match instr.opcode {
366                OpCode::Nop | OpCode::Debug | OpCode::LoopStart => {
367                    // No-ops in JIT
368                }
369
370                OpCode::LoopEnd => {
371                    // Back-edge: jump to header
372                    builder.ins().jump(header_block, &[]);
373                    block_terminated = true;
374                }
375
376                OpCode::PushNull => {
377                    let null = builder
378                        .ins()
379                        .iconst(types::I64, crate::nan_boxing::TAG_NULL as i64);
380                    stack_push!(builder, null, stack_depth);
381                }
382
383                OpCode::PushConst => {
384                    if let Some(Operand::Const(_const_idx)) = instr.operand {
385                        // For OSR MVP, we deopt on constants we can't resolve inline.
386                        // The JitCompilationBackend will provide constant resolution
387                        // in a future pass.
388                        let null = builder
389                            .ins()
390                            .iconst(types::I64, crate::nan_boxing::TAG_NULL as i64);
391                        stack_push!(builder, null, stack_depth);
392                    }
393                }
394
395                OpCode::Pop => {
396                    if stack_depth > 0 {
397                        let _ = stack_pop!(builder, stack_depth);
398                    }
399                }
400
401                OpCode::Dup => {
402                    if stack_depth > 0 {
403                        let var = Variable::new(stack_var_base + stack_depth - 1);
404                        let val = builder.use_var(var);
405                        stack_push!(builder, val, stack_depth);
406                    }
407                }
408
409                OpCode::Swap => {
410                    if stack_depth >= 2 {
411                        let var_a = Variable::new(stack_var_base + stack_depth - 1);
412                        let var_b = Variable::new(stack_var_base + stack_depth - 2);
413                        let a = builder.use_var(var_a);
414                        let b = builder.use_var(var_b);
415                        builder.def_var(var_a, b);
416                        builder.def_var(var_b, a);
417                    }
418                }
419
420                OpCode::LoadLocal | OpCode::LoadLocalTrusted => {
421                    if let Some(Operand::Local(local_idx)) = instr.operand {
422                        let val = builder.use_var(Variable::new(local_idx as usize));
423                        stack_push!(builder, val, stack_depth);
424                    }
425                }
426
427                OpCode::StoreLocal => {
428                    if let Some(Operand::Local(local_idx)) = instr.operand {
429                        if stack_depth > 0 {
430                            let val = stack_pop!(builder, stack_depth);
431                            builder.def_var(Variable::new(local_idx as usize), val);
432                        }
433                    }
434                }
435
436                OpCode::StoreLocalTyped => {
437                    if let Some(Operand::TypedLocal(local_idx, _width)) = instr.operand {
438                        if stack_depth > 0 {
439                            let val = stack_pop!(builder, stack_depth);
440                            // OSR MVP: store without truncation (width enforcement
441                            // is done by the interpreter; JIT uses same raw i64).
442                            builder.def_var(Variable::new(local_idx as usize), val);
443                        }
444                    }
445                }
446
447                // Integer arithmetic: values in JIT context are raw i64 for Int64 slots.
448                OpCode::AddInt => {
449                    if stack_depth >= 2 {
450                        let b = stack_pop!(builder, stack_depth);
451                        let a = stack_pop!(builder, stack_depth);
452                        let result = builder.ins().iadd(a, b);
453                        stack_push!(builder, result, stack_depth);
454                    }
455                }
456                OpCode::SubInt => {
457                    if stack_depth >= 2 {
458                        let b = stack_pop!(builder, stack_depth);
459                        let a = stack_pop!(builder, stack_depth);
460                        let result = builder.ins().isub(a, b);
461                        stack_push!(builder, result, stack_depth);
462                    }
463                }
464                OpCode::MulInt => {
465                    if stack_depth >= 2 {
466                        let b = stack_pop!(builder, stack_depth);
467                        let a = stack_pop!(builder, stack_depth);
468                        let result = builder.ins().imul(a, b);
469                        stack_push!(builder, result, stack_depth);
470                    }
471                }
472                OpCode::DivInt => {
473                    if stack_depth >= 2 {
474                        let b = stack_pop!(builder, stack_depth);
475                        let a = stack_pop!(builder, stack_depth);
476                        let result = builder.ins().sdiv(a, b);
477                        stack_push!(builder, result, stack_depth);
478                    }
479                }
480                OpCode::ModInt => {
481                    if stack_depth >= 2 {
482                        let b = stack_pop!(builder, stack_depth);
483                        let a = stack_pop!(builder, stack_depth);
484                        let result = builder.ins().srem(a, b);
485                        stack_push!(builder, result, stack_depth);
486                    }
487                }
488                OpCode::PowInt => {
489                    // Power is complex — deopt for now
490                    builder.ins().jump(deopt_block, &[]);
491                    block_terminated = true;
492                }
493
494                // Float arithmetic: values are NaN-boxed f64 bit patterns.
495                // Bitcast to f64, operate, bitcast back.
496                OpCode::AddNumber => {
497                    if stack_depth >= 2 {
498                        let b = stack_pop!(builder, stack_depth);
499                        let a = stack_pop!(builder, stack_depth);
500                        let a_f = builder.ins().bitcast(types::F64, MemFlags::new(), a);
501                        let b_f = builder.ins().bitcast(types::F64, MemFlags::new(), b);
502                        let r_f = builder.ins().fadd(a_f, b_f);
503                        let result = builder.ins().bitcast(types::I64, MemFlags::new(), r_f);
504                        stack_push!(builder, result, stack_depth);
505                    }
506                }
507                OpCode::SubNumber => {
508                    if stack_depth >= 2 {
509                        let b = stack_pop!(builder, stack_depth);
510                        let a = stack_pop!(builder, stack_depth);
511                        let a_f = builder.ins().bitcast(types::F64, MemFlags::new(), a);
512                        let b_f = builder.ins().bitcast(types::F64, MemFlags::new(), b);
513                        let r_f = builder.ins().fsub(a_f, b_f);
514                        let result = builder.ins().bitcast(types::I64, MemFlags::new(), r_f);
515                        stack_push!(builder, result, stack_depth);
516                    }
517                }
518                OpCode::MulNumber => {
519                    if stack_depth >= 2 {
520                        let b = stack_pop!(builder, stack_depth);
521                        let a = stack_pop!(builder, stack_depth);
522                        let a_f = builder.ins().bitcast(types::F64, MemFlags::new(), a);
523                        let b_f = builder.ins().bitcast(types::F64, MemFlags::new(), b);
524                        let r_f = builder.ins().fmul(a_f, b_f);
525                        let result = builder.ins().bitcast(types::I64, MemFlags::new(), r_f);
526                        stack_push!(builder, result, stack_depth);
527                    }
528                }
529                OpCode::DivNumber => {
530                    if stack_depth >= 2 {
531                        let b = stack_pop!(builder, stack_depth);
532                        let a = stack_pop!(builder, stack_depth);
533                        let a_f = builder.ins().bitcast(types::F64, MemFlags::new(), a);
534                        let b_f = builder.ins().bitcast(types::F64, MemFlags::new(), b);
535                        let r_f = builder.ins().fdiv(a_f, b_f);
536                        let result = builder.ins().bitcast(types::I64, MemFlags::new(), r_f);
537                        stack_push!(builder, result, stack_depth);
538                    }
539                }
540                OpCode::ModNumber => {
541                    if stack_depth >= 2 {
542                        let b = stack_pop!(builder, stack_depth);
543                        let a = stack_pop!(builder, stack_depth);
544                        let a_f = builder.ins().bitcast(types::F64, MemFlags::new(), a);
545                        let b_f = builder.ins().bitcast(types::F64, MemFlags::new(), b);
546                        // fmod: a - trunc(a/b) * b
547                        let div = builder.ins().fdiv(a_f, b_f);
548                        let trunced = builder.ins().trunc(div);
549                        let prod = builder.ins().fmul(trunced, b_f);
550                        let r_f = builder.ins().fsub(a_f, prod);
551                        let result = builder.ins().bitcast(types::I64, MemFlags::new(), r_f);
552                        stack_push!(builder, result, stack_depth);
553                    }
554                }
555                OpCode::PowNumber => {
556                    // Power is complex — deopt
557                    builder.ins().jump(deopt_block, &[]);
558                    block_terminated = true;
559                }
560
561                OpCode::Neg => {
562                    if stack_depth >= 1 {
563                        let val = stack_pop!(builder, stack_depth);
564                        let result = builder.ins().ineg(val);
565                        stack_push!(builder, result, stack_depth);
566                    }
567                }
568
569                // Integer comparisons: compare raw i64, produce i64 (0 or 1)
570                OpCode::LtInt => {
571                    if stack_depth >= 2 {
572                        let b = stack_pop!(builder, stack_depth);
573                        let a = stack_pop!(builder, stack_depth);
574                        let cmp = builder.ins().icmp(IntCC::SignedLessThan, a, b);
575                        let result = builder.ins().uextend(types::I64, cmp);
576                        stack_push!(builder, result, stack_depth);
577                    }
578                }
579                OpCode::GtInt => {
580                    if stack_depth >= 2 {
581                        let b = stack_pop!(builder, stack_depth);
582                        let a = stack_pop!(builder, stack_depth);
583                        let cmp = builder.ins().icmp(IntCC::SignedGreaterThan, a, b);
584                        let result = builder.ins().uextend(types::I64, cmp);
585                        stack_push!(builder, result, stack_depth);
586                    }
587                }
588                OpCode::LteInt => {
589                    if stack_depth >= 2 {
590                        let b = stack_pop!(builder, stack_depth);
591                        let a = stack_pop!(builder, stack_depth);
592                        let cmp = builder.ins().icmp(IntCC::SignedLessThanOrEqual, a, b);
593                        let result = builder.ins().uextend(types::I64, cmp);
594                        stack_push!(builder, result, stack_depth);
595                    }
596                }
597                OpCode::GteInt => {
598                    if stack_depth >= 2 {
599                        let b = stack_pop!(builder, stack_depth);
600                        let a = stack_pop!(builder, stack_depth);
601                        let cmp = builder.ins().icmp(IntCC::SignedGreaterThanOrEqual, a, b);
602                        let result = builder.ins().uextend(types::I64, cmp);
603                        stack_push!(builder, result, stack_depth);
604                    }
605                }
606                OpCode::EqInt => {
607                    if stack_depth >= 2 {
608                        let b = stack_pop!(builder, stack_depth);
609                        let a = stack_pop!(builder, stack_depth);
610                        let cmp = builder.ins().icmp(IntCC::Equal, a, b);
611                        let result = builder.ins().uextend(types::I64, cmp);
612                        stack_push!(builder, result, stack_depth);
613                    }
614                }
615                OpCode::NeqInt => {
616                    if stack_depth >= 2 {
617                        let b = stack_pop!(builder, stack_depth);
618                        let a = stack_pop!(builder, stack_depth);
619                        let cmp = builder.ins().icmp(IntCC::NotEqual, a, b);
620                        let result = builder.ins().uextend(types::I64, cmp);
621                        stack_push!(builder, result, stack_depth);
622                    }
623                }
624
625                // Float comparisons: bitcast to f64, compare, produce i64
626                OpCode::LtNumber => {
627                    if stack_depth >= 2 {
628                        let b = stack_pop!(builder, stack_depth);
629                        let a = stack_pop!(builder, stack_depth);
630                        let a_f = builder.ins().bitcast(types::F64, MemFlags::new(), a);
631                        let b_f = builder.ins().bitcast(types::F64, MemFlags::new(), b);
632                        let cmp = builder.ins().fcmp(FloatCC::LessThan, a_f, b_f);
633                        let result = builder.ins().uextend(types::I64, cmp);
634                        stack_push!(builder, result, stack_depth);
635                    }
636                }
637                OpCode::GtNumber => {
638                    if stack_depth >= 2 {
639                        let b = stack_pop!(builder, stack_depth);
640                        let a = stack_pop!(builder, stack_depth);
641                        let a_f = builder.ins().bitcast(types::F64, MemFlags::new(), a);
642                        let b_f = builder.ins().bitcast(types::F64, MemFlags::new(), b);
643                        let cmp = builder.ins().fcmp(FloatCC::GreaterThan, a_f, b_f);
644                        let result = builder.ins().uextend(types::I64, cmp);
645                        stack_push!(builder, result, stack_depth);
646                    }
647                }
648                OpCode::LteNumber => {
649                    if stack_depth >= 2 {
650                        let b = stack_pop!(builder, stack_depth);
651                        let a = stack_pop!(builder, stack_depth);
652                        let a_f = builder.ins().bitcast(types::F64, MemFlags::new(), a);
653                        let b_f = builder.ins().bitcast(types::F64, MemFlags::new(), b);
654                        let cmp = builder.ins().fcmp(FloatCC::LessThanOrEqual, a_f, b_f);
655                        let result = builder.ins().uextend(types::I64, cmp);
656                        stack_push!(builder, result, stack_depth);
657                    }
658                }
659                OpCode::GteNumber => {
660                    if stack_depth >= 2 {
661                        let b = stack_pop!(builder, stack_depth);
662                        let a = stack_pop!(builder, stack_depth);
663                        let a_f = builder.ins().bitcast(types::F64, MemFlags::new(), a);
664                        let b_f = builder.ins().bitcast(types::F64, MemFlags::new(), b);
665                        let cmp = builder.ins().fcmp(FloatCC::GreaterThanOrEqual, a_f, b_f);
666                        let result = builder.ins().uextend(types::I64, cmp);
667                        stack_push!(builder, result, stack_depth);
668                    }
669                }
670                OpCode::EqNumber => {
671                    if stack_depth >= 2 {
672                        let b = stack_pop!(builder, stack_depth);
673                        let a = stack_pop!(builder, stack_depth);
674                        let a_f = builder.ins().bitcast(types::F64, MemFlags::new(), a);
675                        let b_f = builder.ins().bitcast(types::F64, MemFlags::new(), b);
676                        let cmp = builder.ins().fcmp(FloatCC::Equal, a_f, b_f);
677                        let result = builder.ins().uextend(types::I64, cmp);
678                        stack_push!(builder, result, stack_depth);
679                    }
680                }
681                OpCode::NeqNumber => {
682                    if stack_depth >= 2 {
683                        let b = stack_pop!(builder, stack_depth);
684                        let a = stack_pop!(builder, stack_depth);
685                        let a_f = builder.ins().bitcast(types::F64, MemFlags::new(), a);
686                        let b_f = builder.ins().bitcast(types::F64, MemFlags::new(), b);
687                        let cmp = builder.ins().fcmp(FloatCC::NotEqual, a_f, b_f);
688                        let result = builder.ins().uextend(types::I64, cmp);
689                        stack_push!(builder, result, stack_depth);
690                    }
691                }
692
693                // Logic: operands are i64 (0 = false, nonzero = true)
694                OpCode::And => {
695                    if stack_depth >= 2 {
696                        let b = stack_pop!(builder, stack_depth);
697                        let a = stack_pop!(builder, stack_depth);
698                        let result = builder.ins().band(a, b);
699                        stack_push!(builder, result, stack_depth);
700                    }
701                }
702                OpCode::Or => {
703                    if stack_depth >= 2 {
704                        let b = stack_pop!(builder, stack_depth);
705                        let a = stack_pop!(builder, stack_depth);
706                        let result = builder.ins().bor(a, b);
707                        stack_push!(builder, result, stack_depth);
708                    }
709                }
710                OpCode::Not => {
711                    if stack_depth >= 1 {
712                        let val = stack_pop!(builder, stack_depth);
713                        let zero = builder.ins().iconst(types::I64, 0);
714                        let cmp = builder.ins().icmp(IntCC::Equal, val, zero);
715                        let result = builder.ins().uextend(types::I64, cmp);
716                        stack_push!(builder, result, stack_depth);
717                    }
718                }
719
720                // Coercion
721                OpCode::IntToNumber => {
722                    if stack_depth >= 1 {
723                        let val = stack_pop!(builder, stack_depth);
724                        // Raw i64 → f64 → bitcast to i64 (NaN-boxed)
725                        let f = builder.ins().fcvt_from_sint(types::F64, val);
726                        let result = builder.ins().bitcast(types::I64, MemFlags::new(), f);
727                        stack_push!(builder, result, stack_depth);
728                    }
729                }
730                OpCode::NumberToInt => {
731                    if stack_depth >= 1 {
732                        let val = stack_pop!(builder, stack_depth);
733                        // NaN-boxed f64 → f64 → truncate to i64
734                        let f = builder.ins().bitcast(types::F64, MemFlags::new(), val);
735                        let result = builder.ins().fcvt_to_sint_sat(types::I64, f);
736                        stack_push!(builder, result, stack_depth);
737                    }
738                }
739
740                OpCode::CastWidth => {
741                    if stack_depth >= 1 {
742                        if let Some(Operand::Width(width)) = &instr.operand {
743                            if let Some(int_w) = width.to_int_width() {
744                                let val = stack_pop!(builder, stack_depth);
745                                let mask = int_w.mask() as i64;
746                                let mask_val = builder.ins().iconst(types::I64, mask);
747                                let truncated = builder.ins().band(val, mask_val);
748                                let result = if int_w.is_signed() {
749                                    let bits = int_w.bits() as i64;
750                                    let shift = 64 - bits;
751                                    let shift_val = builder.ins().iconst(types::I64, shift);
752                                    let shifted = builder.ins().ishl(truncated, shift_val);
753                                    builder.ins().sshr(shifted, shift_val)
754                                } else {
755                                    truncated
756                                };
757                                stack_push!(builder, result, stack_depth);
758                            }
759                        }
760                    }
761                }
762
763                // Control flow
764                OpCode::Jump => {
765                    if let Some(Operand::Offset(off)) = instr.operand {
766                        let target = (idx as i64 + off as i64 + 1) as usize;
767                        if target > loop_info.end_idx {
768                            builder.ins().jump(exit_block, &[]);
769                        } else if let Some(&blk) = block_map.get(&target) {
770                            builder.ins().jump(blk, &[]);
771                        } else {
772                            builder.ins().jump(deopt_block, &[]);
773                        }
774                        block_terminated = true;
775                    }
776                }
777
778                OpCode::JumpIfFalse | OpCode::JumpIfFalseTrusted => {
779                    if let Some(Operand::Offset(off)) = instr.operand {
780                        let target = (idx as i64 + off as i64 + 1) as usize;
781                        if stack_depth > 0 {
782                            let cond = stack_pop!(builder, stack_depth);
783                            let zero = builder.ins().iconst(types::I64, 0);
784                            let is_false = builder.ins().icmp(IntCC::Equal, cond, zero);
785
786                            let target_block = if target > loop_info.end_idx {
787                                exit_block
788                            } else {
789                                block_map.get(&target).copied().unwrap_or(deopt_block)
790                            };
791                            let fall_through =
792                                block_map.get(&(idx + 1)).copied().unwrap_or(deopt_block);
793
794                            builder
795                                .ins()
796                                .brif(is_false, target_block, &[], fall_through, &[]);
797                            block_terminated = true;
798                        }
799                    }
800                }
801
802                OpCode::JumpIfTrue => {
803                    if let Some(Operand::Offset(off)) = instr.operand {
804                        let target = (idx as i64 + off as i64 + 1) as usize;
805                        if stack_depth > 0 {
806                            let cond = stack_pop!(builder, stack_depth);
807                            let zero = builder.ins().iconst(types::I64, 0);
808                            let is_true = builder.ins().icmp(IntCC::NotEqual, cond, zero);
809
810                            let target_block = if target > loop_info.end_idx {
811                                exit_block
812                            } else {
813                                block_map.get(&target).copied().unwrap_or(deopt_block)
814                            };
815                            let fall_through =
816                                block_map.get(&(idx + 1)).copied().unwrap_or(deopt_block);
817
818                            builder
819                                .ins()
820                                .brif(is_true, target_block, &[], fall_through, &[]);
821                            block_terminated = true;
822                        }
823                    }
824                }
825
826                OpCode::Break => {
827                    builder.ins().jump(exit_block, &[]);
828                    block_terminated = true;
829                }
830
831                OpCode::Continue => {
832                    builder.ins().jump(header_block, &[]);
833                    block_terminated = true;
834                }
835
836                OpCode::Return | OpCode::ReturnValue => {
837                    builder.ins().jump(exit_block, &[]);
838                    block_terminated = true;
839                }
840
841                OpCode::Halt => {
842                    builder.ins().jump(exit_block, &[]);
843                    block_terminated = true;
844                }
845
846                // Module bindings: not in JIT context buffer. Deopt if encountered.
847                OpCode::LoadModuleBinding
848                | OpCode::StoreModuleBinding
849                | OpCode::StoreModuleBindingTyped => {
850                    builder.ins().jump(deopt_block, &[]);
851                    block_terminated = true;
852                }
853
854                // Builtin calls: math functions. Deopt for MVP.
855                OpCode::BuiltinCall => {
856                    builder.ins().jump(deopt_block, &[]);
857                    block_terminated = true;
858                }
859
860                _ => {
861                    // Unsupported opcode — should have been caught by preflight
862                    builder.ins().jump(deopt_block, &[]);
863                    block_terminated = true;
864                }
865            }
866        }
867
868        // Seal the loop header block (all predecessors are now known)
869        builder.seal_block(header_block);
870
871        // ---- Exit block: store modified locals back, return 0 ----
872        builder.switch_to_block(exit_block);
873        builder.seal_block(exit_block);
874
875        for &local_idx in &live_locals {
876            if body_locals_written.contains(&local_idx) {
877                let val = builder.use_var(Variable::new(local_idx as usize));
878                let offset = LOCALS_BYTE_OFFSET + (local_idx as i32) * 8;
879                builder
880                    .ins()
881                    .store(MemFlags::trusted(), val, ctx_ptr, offset);
882            }
883        }
884        let zero_ret = builder.ins().iconst(types::I64, 0);
885        builder.ins().return_(&[zero_ret]);
886
887        // ---- Deopt block: store ALL live locals back, return u64::MAX ----
888        builder.switch_to_block(deopt_block);
889        builder.seal_block(deopt_block);
890
891        for &local_idx in &live_locals {
892            let val = builder.use_var(Variable::new(local_idx as usize));
893            let offset = LOCALS_BYTE_OFFSET + (local_idx as i32) * 8;
894            builder
895                .ins()
896                .store(MemFlags::trusted(), val, ctx_ptr, offset);
897        }
898        let deopt_sentinel = builder.ins().iconst(types::I64, u64::MAX as i64);
899        builder.ins().return_(&[deopt_sentinel]);
900
901        builder.finalize();
902    }
903
904    // Compile and define the function
905    jit.module_mut()
906        .define_function(func_id, &mut ctx)
907        .map_err(|e| format!("Failed to define OSR function: {}", e))?;
908    jit.module_mut().clear_context(&mut ctx);
909    jit.module_mut()
910        .finalize_definitions()
911        .map_err(|e| format!("Failed to finalize OSR function: {}", e))?;
912
913    let code_ptr = jit.module_mut().get_finalized_function(func_id);
914
915    Ok(OsrCompilationResult {
916        native_code: code_ptr,
917        entry_point,
918        deopt_points: Vec::new(),
919    })
920}