// wasm_pvm/translate/mod.rs
1// Address calculations and jump offsets often require wrapping/truncation.
2#![allow(
3    clippy::cast_possible_truncation,
4    clippy::cast_possible_wrap,
5    clippy::cast_sign_loss
6)]
7
8pub mod adapter_merge;
9pub mod dead_function_elimination;
10pub mod memory_layout;
11pub mod wasm_module;
12
13use std::collections::HashMap;
14
15use crate::pvm::Instruction;
16use crate::{Error, Result, SpiProgram};
17
18pub use wasm_module::WasmModule;
19
/// Action to take when a WASM import is called.
///
/// Used as the value type of `CompileOptions::import_map` to decide how the
/// body of an otherwise-unresolved import is lowered.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ImportAction {
    /// Emit a trap (unreachable) instruction.
    Trap,
    /// Emit a no-op (return 0 for functions with return values).
    Nop,
}
28
/// Flags to enable/disable individual compiler optimizations.
/// All optimizations are enabled by default.
#[derive(Debug, Clone)]
// A flat bag of independent on/off switches is deliberate here: each pass gets
// its own documented field, which an enum/bitflags encoding would obscure.
#[allow(clippy::struct_excessive_bools)]
pub struct OptimizationFlags {
    /// Run LLVM optimization passes (mem2reg, instcombine, simplifycfg, gvn, dce).
    /// When false, also disables inlining (all LLVM passes are skipped).
    pub llvm_passes: bool,
    /// Run peephole optimizer (fallthrough removal, dead code elimination).
    pub peephole: bool,
    /// Enable per-block register cache (store-load forwarding).
    pub register_cache: bool,
    /// Fuse `ICmp` + Branch into a single PVM branch instruction.
    pub icmp_branch_fusion: bool,
    /// Only save/restore callee-saved registers (r9-r12) that are actually used.
    pub shrink_wrap_callee_saves: bool,
    /// Eliminate SP-relative stores whose target offset is never loaded from.
    pub dead_store_elimination: bool,
    /// Skip redundant `LoadImm`/`LoadImm64` when the register already holds the constant.
    pub constant_propagation: bool,
    /// Inline small functions at the LLVM IR level to eliminate call overhead.
    pub inlining: bool,
    /// Propagate register cache across single-predecessor block boundaries.
    pub cross_block_cache: bool,
    /// Allocate long-lived SSA values to physical registers (r5, r6) across block boundaries.
    pub register_allocation: bool,
    /// Eliminate unreachable functions not called from entry points or the function table.
    pub dead_function_elimination: bool,
    /// Eliminate unconditional jumps to the immediately following block (fallthrough).
    pub fallthrough_jumps: bool,
}
60
impl Default for OptimizationFlags {
    // Every optimization is on by default; new flags added to the struct
    // should also default to `true` here.
    fn default() -> Self {
        Self {
            llvm_passes: true,
            peephole: true,
            register_cache: true,
            icmp_branch_fusion: true,
            shrink_wrap_callee_saves: true,
            dead_store_elimination: true,
            constant_propagation: true,
            inlining: true,
            cross_block_cache: true,
            register_allocation: true,
            dead_function_elimination: true,
            fallthrough_jumps: true,
        }
    }
}
79
/// Options for compilation.
///
/// `Default` yields: no import map, no adapter, empty metadata, and all
/// optimizations enabled (via `OptimizationFlags::default`).
#[derive(Debug, Clone, Default)]
pub struct CompileOptions {
    /// Mapping from import function names to actions.
    /// When provided, all imports (except known intrinsics like `host_call` and `pvm_ptr`)
    /// must have a mapping or compilation will fail with `UnresolvedImport`.
    pub import_map: Option<HashMap<String, ImportAction>>,
    /// WAT source for an adapter module whose exports replace matching main imports.
    /// Applied before the text-based import map, so the two compose.
    pub adapter: Option<String>,
    /// Metadata blob to prepend to the SPI output.
    /// Typically contains the source filename and compiler version.
    pub metadata: Vec<u8>,
    /// Optimization flags controlling which compiler passes are enabled.
    pub optimizations: OptimizationFlags,
}
96
97// Re-export register constants from abi module
98pub use crate::abi::{ARGS_LEN_REG, ARGS_PTR_REG, RETURN_ADDR_REG, STACK_PTR_REG};
99
100// ── Call fixup types (shared with LLVM backend) ──
101
/// Pending patch for a direct call site, resolved by `resolve_call_fixups`
/// once all function byte offsets are known.
#[derive(Debug, Clone)]
pub struct CallFixup {
    /// Instruction index (relative to the owning function's first emitted
    /// instruction) of the return-address load.
    pub return_addr_instr: usize,
    /// Instruction index of the jump to patch; equals `return_addr_instr`
    /// when the call uses the combined `LoadImmJump` encoding.
    pub jump_instr: usize,
    /// Local (non-imported) index of the callee whose offset is patched in.
    pub target_func: u32,
}
108
/// Pending patch for an indirect (`call_indirect`) call site.
#[derive(Debug, Clone)]
pub struct IndirectCallFixup {
    /// Instruction index (relative to the owning function's first emitted
    /// instruction) of the return-address load.
    pub return_addr_instr: usize,
    /// Instruction index of the indirect jump.
    /// For `LoadImmJumpInd`, this equals `return_addr_instr`.
    pub jump_ind_instr: usize,
}
115
/// `RO_DATA` region size is 64KB (0x10000 to 0x1FFFF).
/// Holds the `call_indirect` dispatch table followed by passive data segments.
const RO_DATA_SIZE: usize = 64 * 1024;
118
119pub fn compile(wasm: &[u8]) -> Result<SpiProgram> {
120    compile_with_options(wasm, &CompileOptions::default())
121}
122
123pub fn compile_with_options(wasm: &[u8], options: &CompileOptions) -> Result<SpiProgram> {
124    // Known intrinsics that don't need import mappings.
125    const KNOWN_INTRINSICS: &[&str] = &["host_call", "pvm_ptr"];
126    // Default mappings applied when no explicit import map is provided.
127    const DEFAULT_MAPPINGS: &[&str] = &["abort"];
128
129    // Apply adapter merge if provided (produces a new WASM binary with fewer imports).
130    let merged_wasm;
131    let wasm = if let Some(adapter_wat) = &options.adapter {
132        merged_wasm = adapter_merge::merge_adapter(wasm, adapter_wat)?;
133        &merged_wasm
134    } else {
135        wasm
136    };
137
138    let module = WasmModule::parse(wasm)?;
139
140    // Validate that all imports are resolved.
141    for name in &module.imported_func_names {
142        if KNOWN_INTRINSICS.contains(&name.as_str()) {
143            continue;
144        }
145        if let Some(import_map) = &options.import_map {
146            if import_map.contains_key(name) {
147                continue;
148            }
149        } else if DEFAULT_MAPPINGS.contains(&name.as_str()) {
150            continue;
151        }
152        return Err(Error::UnresolvedImport(format!(
153            "import '{name}' has no mapping. Provide a mapping via --imports or add it to the import map."
154        )));
155    }
156
157    compile_via_llvm(&module, options)
158}
159
/// Lower a parsed [`WasmModule`] to a complete [`SpiProgram`] via LLVM.
///
/// Pipeline phases (labelled inline below): dead-function analysis,
/// WASM → LLVM IR translation, RO_DATA layout for passive data segments,
/// per-function LLVM → PVM lowering, call-fixup resolution / jump-table
/// construction, and final section assembly.
pub fn compile_via_llvm(module: &WasmModule, options: &CompileOptions) -> Result<SpiProgram> {
    use crate::llvm_backend::{self, LoweringContext};
    use crate::llvm_frontend;
    use inkwell::context::Context;

    // Phase 0: Dead function elimination — compute reachable set.
    // `None` means "everything is reachable" (the optimization is disabled).
    let reachable_locals = if options.optimizations.dead_function_elimination {
        Some(dead_function_elimination::reachable_functions(module)?)
    } else {
        None
    };

    // Phase 1: WASM → LLVM IR
    let context = Context::create();
    let llvm_module = llvm_frontend::translate_wasm_to_llvm(
        &context,
        module,
        options.optimizations.llvm_passes,
        options.optimizations.inlining,
        reachable_locals.as_ref(),
    )?;

    // Calculate RO_DATA offsets and lengths for passive data segments.
    // Offsets are relative to the start of the RO_DATA region; the dispatch
    // table (or a single dummy byte) occupies the front of that region.
    let mut data_segment_offsets = std::collections::HashMap::new();
    let mut data_segment_lengths = std::collections::HashMap::new();
    let mut current_ro_offset = if module.function_table.is_empty() {
        1 // dummy byte if no function table
    } else {
        module.function_table.len() * 8 // jump_ref + type_idx per entry
    };

    let mut data_segment_length_addrs = std::collections::HashMap::new();
    // Counts only passive segments, so their length slots are packed densely
    // after the globals region regardless of interleaved active segments.
    let mut passive_ordinal = 0usize;

    for (idx, seg) in module.data_segments.iter().enumerate() {
        if seg.offset.is_none() {
            // Check that segment fits within RO_DATA region
            if current_ro_offset + seg.data.len() > RO_DATA_SIZE {
                return Err(Error::Internal(format!(
                    "passive data segment {} (size {}) would overflow RO_DATA region ({} bytes used of {})",
                    idx,
                    seg.data.len(),
                    current_ro_offset,
                    RO_DATA_SIZE
                )));
            }
            data_segment_offsets.insert(idx as u32, current_ro_offset as u32);
            data_segment_lengths.insert(idx as u32, seg.data.len() as u32);
            data_segment_length_addrs.insert(
                idx as u32,
                memory_layout::data_segment_length_offset(module.globals.len(), passive_ordinal),
            );
            current_ro_offset += seg.data.len();
            passive_ordinal += 1;
        }
    }

    // Phase 2: Build lowering context
    let ctx = LoweringContext {
        wasm_memory_base: module.wasm_memory_base,
        num_globals: module.globals.len(),
        function_signatures: module.function_signatures.clone(),
        type_signatures: module.type_signatures.clone(),
        function_table: module.function_table.clone(),
        num_imported_funcs: module.num_imported_funcs as usize,
        imported_func_names: module.imported_func_names.clone(),
        initial_memory_pages: module.memory_limits.initial_pages,
        max_memory_pages: module.max_memory_pages,
        stack_size: memory_layout::DEFAULT_STACK_SIZE,
        data_segment_offsets,
        data_segment_lengths,
        data_segment_length_addrs,
        wasm_import_map: options.import_map.clone(),
        optimizations: options.optimizations.clone(),
    };

    // Phase 3: LLVM IR → PVM bytecode for each function
    let mut all_instructions: Vec<Instruction> = Vec::new();
    let mut all_call_fixups: Vec<(usize, CallFixup)> = Vec::new();
    let mut all_indirect_call_fixups: Vec<(usize, IndirectCallFixup)> = Vec::new();
    let mut function_offsets: Vec<usize> = vec![0; module.functions.len()];
    let mut next_call_return_idx: usize = 0;

    // Entry header: Jump to main (PC=0) + Trap or secondary Jump (PC=5).
    // When there's no secondary entry, we omit the Fallthrough padding (6 bytes instead of 10).
    // The `offset: 0` values are placeholders patched after layout (below).
    all_instructions.push(Instruction::Jump { offset: 0 });
    if module.has_secondary_entry {
        all_instructions.push(Instruction::Jump { offset: 0 });
    } else {
        all_instructions.push(Instruction::Trap);
    }

    // Build emission order: main first, then secondary (if any), then remaining in index order.
    // This places main immediately after the entry header, minimizing the entry Jump distance.
    let mut emission_order: Vec<usize> = Vec::with_capacity(module.functions.len());
    emission_order.push(module.main_func_local_idx);
    if let Some(secondary_idx) = module.secondary_entry_local_idx
        && secondary_idx != module.main_func_local_idx
    {
        emission_order.push(secondary_idx);
    }
    for idx in 0..module.functions.len() {
        if idx != module.main_func_local_idx && module.secondary_entry_local_idx != Some(idx) {
            emission_order.push(idx);
        }
    }

    for &local_func_idx in &emission_order {
        // Dead functions: emit a single Trap as a placeholder.
        // The function offset is still recorded so dispatch table indices stay valid.
        if reachable_locals
            .as_ref()
            .is_some_and(|r| !r.contains(&local_func_idx))
        {
            // NOTE(review): this prefix-sum over all encoded instructions runs
            // once per function, making layout O(n^2) in instruction count —
            // consider a running byte counter if compile time becomes an issue.
            let func_start_offset: usize = all_instructions.iter().map(|i| i.encode().len()).sum();
            function_offsets[local_func_idx] = func_start_offset;
            all_instructions.push(Instruction::Trap);
            continue;
        }

        let global_func_idx = module.num_imported_funcs as usize + local_func_idx;
        let fn_name = format!("wasm_func_{global_func_idx}");
        let llvm_func = llvm_module
            .get_function(&fn_name)
            .ok_or_else(|| Error::Internal(format!("missing LLVM function: {fn_name}")))?;

        let is_main = local_func_idx == module.main_func_local_idx;
        let is_secondary = module.secondary_entry_local_idx == Some(local_func_idx);
        let is_entry = is_main || is_secondary;

        // Byte offset of this function = total encoded size of everything emitted so far.
        let func_start_offset: usize = all_instructions.iter().map(|i| i.encode().len()).sum();
        function_offsets[local_func_idx] = func_start_offset;

        // If entry function and there's a start function, call it first.
        if let Some(start_local_idx) = module.start_func_local_idx.filter(|_| is_entry) {
            // Save r7 and r8 to stack.
            all_instructions.push(Instruction::AddImm64 {
                dst: STACK_PTR_REG,
                src: STACK_PTR_REG,
                value: -16,
            });
            all_instructions.push(Instruction::StoreIndU64 {
                base: STACK_PTR_REG,
                src: ARGS_PTR_REG,
                offset: 0,
            });
            all_instructions.push(Instruction::StoreIndU64 {
                base: STACK_PTR_REG,
                src: ARGS_LEN_REG,
                offset: 8,
            });

            // Call start function using LoadImmJump (combined load + jump).
            // Return addresses are pre-assigned as (idx + 1) * 2; the matching
            // jump-table slot is filled in by `resolve_call_fixups`.
            let call_return_addr = ((next_call_return_idx + 1) * 2) as i32;
            next_call_return_idx += 1;
            let current_instr_idx = all_instructions.len();
            all_instructions.push(Instruction::LoadImmJump {
                reg: RETURN_ADDR_REG,
                value: call_return_addr,
                offset: 0, // patched during fixup resolution
            });

            all_call_fixups.push((
                current_instr_idx,
                CallFixup {
                    target_func: start_local_idx as u32,
                    return_addr_instr: 0,
                    jump_instr: 0, // same instruction for LoadImmJump
                },
            ));

            // Restore r7 and r8.
            all_instructions.push(Instruction::LoadIndU64 {
                dst: ARGS_PTR_REG,
                base: STACK_PTR_REG,
                offset: 0,
            });
            all_instructions.push(Instruction::LoadIndU64 {
                dst: ARGS_LEN_REG,
                base: STACK_PTR_REG,
                offset: 8,
            });
            all_instructions.push(Instruction::AddImm64 {
                dst: STACK_PTR_REG,
                src: STACK_PTR_REG,
                value: 16,
            });
        }

        let translation = llvm_backend::lower_function(
            llvm_func,
            &ctx,
            is_entry,
            global_func_idx,
            next_call_return_idx,
        )?;
        next_call_return_idx += translation.num_call_returns;

        // Re-base the function-relative fixup indices onto the global
        // instruction stream before accumulating them.
        let instr_base = all_instructions.len();
        for fixup in translation.call_fixups {
            all_call_fixups.push((
                instr_base,
                CallFixup {
                    return_addr_instr: fixup.return_addr_instr,
                    jump_instr: fixup.jump_instr,
                    target_func: fixup.target_func,
                },
            ));
        }
        for fixup in translation.indirect_call_fixups {
            all_indirect_call_fixups.push((
                instr_base,
                IndirectCallFixup {
                    return_addr_instr: fixup.return_addr_instr,
                    jump_ind_instr: fixup.jump_ind_instr,
                },
            ));
        }

        all_instructions.extend(translation.instructions);
    }

    // Phase 4: Resolve call fixups and build jump table.
    let (jump_table, func_entry_jump_table_base) = resolve_call_fixups(
        &mut all_instructions,
        &all_call_fixups,
        &all_indirect_call_fixups,
        &function_offsets,
    )?;

    // Patch entry header jumps.
    let main_offset = function_offsets[module.main_func_local_idx] as i32;
    if let Instruction::Jump { offset } = &mut all_instructions[0] {
        *offset = main_offset;
    }

    if let Some(secondary_idx) = module.secondary_entry_local_idx {
        // The secondary entry jump sits at PC=5, so its offset is relative to that.
        let secondary_offset = function_offsets[secondary_idx] as i32 - 5;
        if let Instruction::Jump { offset } = &mut all_instructions[1] {
            *offset = secondary_offset;
        }
    }

    // Phase 5: Build dispatch table for call_indirect.
    // Each entry is 8 bytes: jump-table ref (u32) + type index (u32);
    // unreachable/imported entries are filled with u32::MAX sentinels.
    let mut ro_data = vec![0u8];
    if !module.function_table.is_empty() {
        ro_data.clear();
        for &func_idx in &module.function_table {
            if func_idx == u32::MAX || (func_idx as usize) < module.num_imported_funcs as usize {
                ro_data.extend_from_slice(&u32::MAX.to_le_bytes());
                ro_data.extend_from_slice(&u32::MAX.to_le_bytes());
            } else {
                let local_func_idx = func_idx as usize - module.num_imported_funcs as usize;
                let jump_ref = 2 * (func_entry_jump_table_base + local_func_idx + 1) as u32;
                ro_data.extend_from_slice(&jump_ref.to_le_bytes());
                let type_idx = *module
                    .function_type_indices
                    .get(local_func_idx)
                    .unwrap_or(&u32::MAX);
                ro_data.extend_from_slice(&type_idx.to_le_bytes());
            }
        }
    }

    // Append passive data segments to RO_DATA.
    // NOTE: This loop must iterate data_segments in the same order as the offset
    // calculation loop above, since data_segment_offsets indices depend on it.
    for seg in &module.data_segments {
        if seg.offset.is_none() {
            ro_data.extend_from_slice(&seg.data);
        }
    }

    let blob = crate::pvm::ProgramBlob::new(all_instructions).with_jump_table(jump_table);
    let rw_data_section = build_rw_data(
        &module.data_segments,
        &module.global_init_values,
        module.memory_limits.initial_pages,
        module.wasm_memory_base,
        &ctx.data_segment_length_addrs,
        &ctx.data_segment_lengths,
    );

    // Heap sizing uses the trimmed rw_data length (see calculate_heap_pages docs).
    let heap_pages = calculate_heap_pages(
        rw_data_section.len(),
        module.wasm_memory_base,
        module.memory_limits.initial_pages,
        module.functions.len(),
    )?;

    Ok(SpiProgram::new(blob)
        .with_heap_pages(heap_pages)
        .with_ro_data(ro_data)
        .with_rw_data(rw_data_section)
        .with_metadata(options.metadata.clone()))
}
456
457/// Calculate the number of 4KB PVM heap pages needed after `rw_data`.
458///
459/// `heap_pages` tells the runtime how many zero-initialized writable pages to allocate
460/// immediately after the `rw_data` blob. This covers the initial WASM linear memory,
461/// globals, and spilled locals that aren't already covered by `rw_data`.
462///
463/// By computing this **after** `build_rw_data()`, we use the actual (trimmed) `rw_data`
464/// length instead of guessing with headroom.
465///
466/// We add 1 extra page beyond the exact initial memory requirement. This ensures that
467/// the first `memory.grow` / sbrk allocation has a pre-allocated page available at the
468/// boundary of the initial WASM memory. Without it, PVM-in-PVM execution fails because
469/// the inner interpreter's page-fault handling at the exact heap boundary doesn't
470/// correctly propagate through the outer PVM.
471fn calculate_heap_pages(
472    rw_data_len: usize,
473    wasm_memory_base: i32,
474    initial_pages: u32,
475    num_functions: usize,
476) -> Result<u16> {
477    use wasm_module::MIN_INITIAL_WASM_PAGES;
478
479    let initial_pages = initial_pages.max(MIN_INITIAL_WASM_PAGES);
480    let wasm_memory_initial_end = wasm_memory_base as usize + (initial_pages as usize) * 64 * 1024;
481
482    let spilled_locals_end = memory_layout::SPILLED_LOCALS_BASE as usize
483        + num_functions * memory_layout::SPILLED_LOCALS_PER_FUNC as usize;
484
485    let end = spilled_locals_end.max(wasm_memory_initial_end);
486    let total_bytes = end - memory_layout::GLOBAL_MEMORY_BASE as usize;
487    let rw_pages = rw_data_len.div_ceil(4096);
488    let total_pages = total_bytes.div_ceil(4096);
489    let heap_pages = total_pages.saturating_sub(rw_pages) + 1;
490
491    u16::try_from(heap_pages).map_err(|_| {
492        Error::Internal(format!(
493            "heap size {heap_pages} pages exceeds u16::MAX ({}) — module too large",
494            u16::MAX
495        ))
496    })
497}
498
499/// Build the `rw_data` section from WASM data segments and global initializers.
500pub(crate) fn build_rw_data(
501    data_segments: &[wasm_module::DataSegment],
502    global_init_values: &[i32],
503    initial_memory_pages: u32,
504    wasm_memory_base: i32,
505    data_segment_length_addrs: &std::collections::HashMap<u32, i32>,
506    data_segment_lengths: &std::collections::HashMap<u32, u32>,
507) -> Vec<u8> {
508    // Calculate the minimum size needed for globals
509    // +1 for the compiler-managed memory size global, plus passive segment lengths
510    let num_passive_segments = data_segment_length_addrs.len();
511    let globals_end =
512        memory_layout::globals_region_size(global_init_values.len(), num_passive_segments);
513
514    // Calculate the size needed for data segments
515    let wasm_to_rw_offset = wasm_memory_base as u32 - 0x30000;
516
517    let data_end = data_segments
518        .iter()
519        .filter_map(|seg| {
520            seg.offset
521                .map(|off| wasm_to_rw_offset + off + seg.data.len() as u32)
522        })
523        .max()
524        .unwrap_or(0) as usize;
525
526    let total_size = globals_end.max(data_end);
527
528    if total_size == 0 {
529        return Vec::new();
530    }
531
532    let mut rw_data = vec![0u8; total_size];
533
534    // Initialize user globals
535    for (i, &value) in global_init_values.iter().enumerate() {
536        let offset = i * 4;
537        if offset + 4 <= rw_data.len() {
538            rw_data[offset..offset + 4].copy_from_slice(&value.to_le_bytes());
539        }
540    }
541
542    // Initialize compiler-managed memory size global (right after user globals)
543    let mem_size_offset = global_init_values.len() * 4;
544    if mem_size_offset + 4 <= rw_data.len() {
545        rw_data[mem_size_offset..mem_size_offset + 4]
546            .copy_from_slice(&initial_memory_pages.to_le_bytes());
547    }
548
549    // Initialize passive data segment effective lengths (right after memory size global).
550    // These are used by memory.init for bounds checking and zeroed by data.drop.
551    for (&seg_idx, &addr) in data_segment_length_addrs {
552        if let Some(&length) = data_segment_lengths.get(&seg_idx) {
553            // addr is absolute PVM address; convert to rw_data offset
554            let rw_offset = (addr - memory_layout::GLOBAL_MEMORY_BASE) as usize;
555            if rw_offset + 4 <= rw_data.len() {
556                rw_data[rw_offset..rw_offset + 4].copy_from_slice(&length.to_le_bytes());
557            }
558        }
559    }
560
561    // Copy data segments to their WASM memory locations
562    for seg in data_segments {
563        if let Some(offset) = seg.offset {
564            let rw_offset = (wasm_to_rw_offset + offset) as usize;
565            if rw_offset + seg.data.len() <= rw_data.len() {
566                rw_data[rw_offset..rw_offset + seg.data.len()].copy_from_slice(&seg.data);
567            }
568        }
569    }
570
571    // Trim trailing zeros to reduce SPI size. Heap pages are zero-initialized,
572    // so omitted high-address zero bytes are semantically equivalent.
573    if let Some(last_non_zero) = rw_data.iter().rposition(|&b| b != 0) {
574        rw_data.truncate(last_non_zero + 1);
575    } else {
576        rw_data.clear();
577    }
578
579    rw_data
580}
581
582/// Extract the pre-assigned jump-table index from a return-address load instruction.
583///
584/// Call return addresses are pre-assigned as `(idx + 1) * 2` at emission time.
585/// This helper recovers `idx` so that `resolve_call_fixups` can write the byte
586/// offset into the correct jump-table slot instead of appending in list order
587/// (which would desync when a function mixes direct and indirect calls).
588///
589/// Direct calls use `LoadImmJump`, while indirect calls use either `LoadImm` (legacy
590/// two-instruction sequence) or `LoadImmJumpInd` (combined return-addr load + jump).
591fn return_addr_jump_table_idx(
592    instructions: &[Instruction],
593    return_addr_instr: usize,
594) -> Result<usize> {
595    let value = match instructions.get(return_addr_instr) {
596        Some(
597            Instruction::LoadImmJump { value, .. }
598            | Instruction::LoadImm { value, .. }
599            | Instruction::LoadImmJumpInd { value, .. },
600        ) => Some(*value),
601        _ => None,
602    };
603    match value {
604        Some(v) if v > 0 && v % 2 == 0 => Ok((v as usize / 2) - 1),
605        _ => Err(Error::Internal(format!(
606            "expected LoadImmJump/LoadImm/LoadImmJumpInd((idx+1)*2) at return_addr_instr {return_addr_instr}, got {:?}",
607            instructions.get(return_addr_instr)
608        ))),
609    }
610}
611
/// Patch direct/indirect call sites and build the jump table.
///
/// For each fixup, computes the byte offset just past the jump instruction
/// (the return address) and writes it into the jump-table slot pre-assigned
/// at emission time; direct calls additionally get their `LoadImmJump`
/// offset patched to the callee's byte offset. Returns the jump table plus
/// the index where per-function entry slots begin (used by the dispatch
/// table for `call_indirect`).
///
/// NOTE(review): the repeated `encode().len()` prefix sums make this
/// O(fixups × instructions); correctness of a precomputed prefix table would
/// require that patching `offset` never changes an instruction's encoded
/// length — TODO confirm before optimizing.
fn resolve_call_fixups(
    instructions: &mut [Instruction],
    call_fixups: &[(usize, CallFixup)],
    indirect_call_fixups: &[(usize, IndirectCallFixup)],
    function_offsets: &[usize],
) -> Result<(Vec<u32>, usize)> {
    // Count total call-return entries by finding the maximum pre-assigned index.
    // Entries are written at their pre-assigned slot so mixed direct/indirect
    // call ordering within a function is preserved correctly.
    let mut num_call_returns: usize = 0;

    for (instr_base, fixup) in call_fixups {
        let idx = return_addr_jump_table_idx(instructions, instr_base + fixup.return_addr_instr)?;
        num_call_returns = num_call_returns.max(idx + 1);
    }
    for (instr_base, fixup) in indirect_call_fixups {
        let idx = return_addr_jump_table_idx(instructions, instr_base + fixup.return_addr_instr)?;
        num_call_returns = num_call_returns.max(idx + 1);
    }

    let mut jump_table: Vec<u32> = vec![0u32; num_call_returns];

    // Call return addresses (LoadImmJump/LoadImm/LoadImmJumpInd values) are pre-assigned at emission time,
    // so we only need to compute byte offsets for the jump table and patch Jump targets.
    // Write each entry at its pre-assigned index to keep values in sync.
    for (instr_base, fixup) in call_fixups {
        let target_offset = function_offsets
            .get(fixup.target_func as usize)
            .ok_or_else(|| {
                Error::Unsupported(format!("call to unknown function {}", fixup.target_func))
            })?;

        let jump_idx = instr_base + fixup.jump_instr;

        // Return address = byte offset after the LoadImmJump instruction.
        let return_addr_offset: usize = instructions[..=jump_idx]
            .iter()
            .map(|i| i.encode().len())
            .sum();

        let slot = return_addr_jump_table_idx(instructions, instr_base + fixup.return_addr_instr)?;
        jump_table[slot] = return_addr_offset as u32;

        // Verify pre-assigned jump table address matches actual index.
        let expected_addr = ((slot + 1) * 2) as i32;
        debug_assert!(
            matches!(&instructions[jump_idx], Instruction::LoadImmJump { value, .. } if *value == expected_addr),
            "pre-assigned jump table address mismatch: expected {expected_addr}, got {:?}",
            &instructions[jump_idx]
        );

        // Patch the offset field of LoadImmJump.
        // The jump offset is relative to the start of the jump instruction itself.
        let jump_start_offset: usize = instructions[..jump_idx]
            .iter()
            .map(|i| i.encode().len())
            .sum();
        let relative_offset = (*target_offset as i32) - (jump_start_offset as i32);

        if let Instruction::LoadImmJump { offset, .. } = &mut instructions[jump_idx] {
            *offset = relative_offset;
        }
    }

    // Indirect calls only need their jump-table slot filled; the jump target
    // itself comes from the dispatch table at runtime.
    for (instr_base, fixup) in indirect_call_fixups {
        let jump_ind_idx = instr_base + fixup.jump_ind_instr;

        let return_addr_offset: usize = instructions[..=jump_ind_idx]
            .iter()
            .map(|i| i.encode().len())
            .sum();

        let slot = return_addr_jump_table_idx(instructions, instr_base + fixup.return_addr_instr)?;
        jump_table[slot] = return_addr_offset as u32;
    }

    // Append one entry per function after the call-return slots; the dispatch
    // table references these as 2 * (func_entry_base + local_idx + 1).
    let func_entry_base = jump_table.len();
    for &offset in function_offsets {
        jump_table.push(offset as u32);
    }

    Ok((jump_table, func_entry_base))
}
694
#[cfg(test)]
mod tests {
    //! Unit tests for `build_rw_data` trimming behavior and
    //! `calculate_heap_pages` arithmetic.

    use std::collections::HashMap;

    use super::build_rw_data;
    use super::memory_layout;
    use super::wasm_module::DataSegment;

    #[test]
    fn build_rw_data_trims_all_zero_tail_to_empty() {
        let rw = build_rw_data(&[], &[], 0, 0x30000, &HashMap::new(), &HashMap::new());
        assert!(rw.is_empty());
    }

    #[test]
    fn build_rw_data_preserves_internal_zeros_and_trims_trailing_zeros() {
        let data_segments = vec![DataSegment {
            offset: Some(0),
            data: vec![1, 0, 2, 0, 0],
        }];

        let rw = build_rw_data(
            &data_segments,
            &[],
            0,
            0x30000,
            &HashMap::new(),
            &HashMap::new(),
        );

        assert_eq!(rw, vec![1, 0, 2]);
    }

    #[test]
    fn build_rw_data_keeps_non_zero_passive_length_bytes() {
        // Length slot at GLOBAL_MEMORY_BASE + 4 → rw_data offset 4.
        let mut addrs = HashMap::new();
        addrs.insert(0u32, memory_layout::GLOBAL_MEMORY_BASE + 4);
        let mut lengths = HashMap::new();
        lengths.insert(0u32, 7u32);

        let rw = build_rw_data(&[], &[], 0, 0x30000, &addrs, &lengths);

        assert_eq!(rw, vec![0, 0, 0, 0, 7]);
    }

    // ── calculate_heap_pages tests ──

    #[test]
    fn heap_pages_with_empty_rw_data_equals_total_pages_plus_one() {
        // wasm_memory_base = 0x33000 (typical), initial_pages = 0 (clamped to 16)
        // end = 0x33000 + 16*64*1024 = 0x33000 + 0x100000 = 0x133000
        // total_bytes = 0x133000 - 0x30000 = 0x103000 = 1060864
        // total_pages = ceil(1060864 / 4096) = 259
        // rw_pages = 0, heap_pages = 259 + 1 = 260
        let pages = super::calculate_heap_pages(0, 0x33000, 0, 10).unwrap();
        assert_eq!(pages, 260);
    }

    #[test]
    fn heap_pages_reduced_by_rw_data_pages() {
        // Same scenario but with 8192 bytes of rw_data (2 pages)
        let pages_no_rw = super::calculate_heap_pages(0, 0x33000, 0, 10).unwrap();
        let pages_with_rw = super::calculate_heap_pages(8192, 0x33000, 0, 10).unwrap();
        assert_eq!(pages_no_rw - pages_with_rw, 2);
    }

    #[test]
    fn heap_pages_saturates_at_one_for_large_rw_data() {
        // rw_data that covers more than total_pages still gets +1 headroom
        let pages = super::calculate_heap_pages(2 * 1024 * 1024, 0x33000, 0, 10).unwrap();
        assert_eq!(pages, 1);
    }

    #[test]
    fn heap_pages_respects_initial_pages() {
        // initial_pages = 32 (larger than MIN_INITIAL_WASM_PAGES=16)
        // end = 0x33000 + 32*64*1024 = 0x33000 + 0x200000 = 0x233000
        // total_bytes = 0x233000 - 0x30000 = 0x203000
        // total_pages = ceil(0x203000 / 4096) = 515
        // heap_pages = 515 + 1 = 516
        let pages = super::calculate_heap_pages(0, 0x33000, 32, 10).unwrap();
        assert_eq!(pages, 516);
    }
}