// wasm_pvm/translate/mod.rs
1// Address calculations and jump offsets often require wrapping/truncation.
2#![allow(
3    clippy::cast_possible_truncation,
4    clippy::cast_possible_wrap,
5    clippy::cast_sign_loss
6)]
7
8pub mod adapter_merge;
9pub mod dead_function_elimination;
10pub use crate::memory_layout;
11pub mod stats;
12pub mod wasm_module;
13
14use std::collections::HashMap;
15
16use crate::pvm::Instruction;
17use crate::{Error, Result, SpiProgram};
18
19pub use wasm_module::WasmModule;
20
/// Action to take when a WASM import is called.
///
/// Selected per import name via `CompileOptions::import_map`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ImportAction {
    /// Emit a trap (unreachable) instruction.
    Trap,
    /// Emit a no-op (return 0 for functions with return values).
    Nop,
    /// Emit a PVM `ecalli` instruction with the given index.
    /// Arguments are loaded into data registers (r7-r12), return value from r7.
    Ecalli(u32),
}
32
/// Flags to enable/disable individual compiler optimizations.
/// All optimizations are enabled by default.
#[derive(Debug, Clone)]
// Many independent on/off switches by design; an enum/bitset would obscure the per-pass docs.
#[allow(clippy::struct_excessive_bools)]
pub struct OptimizationFlags {
    /// Run LLVM optimization passes (mem2reg, instcombine, simplifycfg, gvn, dce).
    /// When false, also disables inlining (all LLVM passes are skipped).
    pub llvm_passes: bool,
    /// Run peephole optimizer (fallthrough removal, dead code elimination).
    pub peephole: bool,
    /// Enable per-block register cache (store-load forwarding).
    pub register_cache: bool,
    /// Fuse `ICmp` + Branch into a single PVM branch instruction.
    pub icmp_branch_fusion: bool,
    /// Only save/restore callee-saved registers (r9-r12) that are actually used.
    pub shrink_wrap_callee_saves: bool,
    /// Eliminate SP-relative stores whose target offset is never loaded from.
    pub dead_store_elimination: bool,
    /// Skip redundant `LoadImm`/`LoadImm64` when the register already holds the constant.
    pub constant_propagation: bool,
    /// Inline small functions at the LLVM IR level to eliminate call overhead.
    pub inlining: bool,
    /// Propagate register cache across single-predecessor block boundaries.
    pub cross_block_cache: bool,
    /// Allocate long-lived SSA values to physical registers (r5, r6) across block boundaries.
    pub register_allocation: bool,
    /// Eliminate unreachable functions not called from entry points or the function table.
    pub dead_function_elimination: bool,
    /// Eliminate unconditional jumps to the immediately following block (fallthrough).
    pub fallthrough_jumps: bool,
    /// Lower the minimum-use threshold for register allocation candidates from 2 to 1.
    /// Captures more values (e.g. two-branch if-else patterns) at the cost of slightly
    /// more `MoveReg` traffic in small leaf functions.
    pub aggressive_register_allocation: bool,
    /// Allocate r5/r6 (`abi::SCRATCH1`/`SCRATCH2`) in all functions that don't
    /// clobber them (no bulk memory ops, no funnel shifts). In non-leaf functions,
    /// spill/reload around calls is handled automatically.
    pub allocate_scratch_regs: bool,
    /// Allocate r7/r8 (`RETURN_VALUE_REG`/`ARGS_LEN_REG`) in all functions.
    /// These are caller-saved and idle after the prologue; in non-leaf functions,
    /// they are invalidated after calls via arity-aware predicate.
    pub allocate_caller_saved_regs: bool,
    /// Skip stack stores at definition for register-allocated values (lazy spill).
    /// Values are only written to the stack when required (call clobber, return,
    /// phi reads, eviction). Requires `register_allocation` to be effective.
    pub lazy_spill: bool,
    /// Max LLVM IR instructions for a function to be inlineable.
    /// Functions exceeding this are marked `noinline`. `None` uses LLVM's
    /// default (225). Default: `Some(5)` — only tiny helpers are inlined.
    /// Only effective when `inlining` is `true`.
    pub inline_threshold: Option<u32>,
}
85
impl Default for OptimizationFlags {
    /// All optimizations on, with the inline threshold capped at 5 IR
    /// instructions (only tiny helpers are inlined).
    fn default() -> Self {
        Self {
            llvm_passes: true,
            peephole: true,
            register_cache: true,
            icmp_branch_fusion: true,
            shrink_wrap_callee_saves: true,
            dead_store_elimination: true,
            constant_propagation: true,
            inlining: true,
            cross_block_cache: true,
            register_allocation: true,
            dead_function_elimination: true,
            fallthrough_jumps: true,
            aggressive_register_allocation: true,
            allocate_scratch_regs: true,
            allocate_caller_saved_regs: true,
            lazy_spill: true,
            inline_threshold: Some(5),
        }
    }
}
109
/// Options for compilation.
///
/// All fields default to "off"/empty via `Default`; see `compile` for the
/// zero-configuration entry point.
#[derive(Debug, Clone, Default)]
pub struct CompileOptions {
    /// Mapping from import function names to actions.
    /// When provided, all imports (except known intrinsics like `host_call_N` and `pvm_ptr`)
    /// must have a mapping or compilation will fail with `UnresolvedImport`.
    pub import_map: Option<HashMap<String, ImportAction>>,
    /// WAT source for an adapter module whose exports replace matching main imports.
    /// Applied before the text-based import map, so the two compose.
    pub adapter: Option<String>,
    /// Metadata blob to prepend to the SPI output.
    /// Typically contains the source filename and compiler version.
    pub metadata: Vec<u8>,
    /// Optimization flags controlling which compiler passes are enabled.
    pub optimizations: OptimizationFlags,
    /// Override the maximum memory pages (memory.grow ceiling).
    /// When set, this takes precedence over both the WASM-declared max and the compiler default.
    pub max_memory_pages: Option<u32>,
}
129
130// Re-export register constants from abi module
131pub use crate::abi::{ARGS_LEN_REG, ARGS_PTR_REG, RETURN_ADDR_REG, STACK_PTR_REG};
132
133// ── Call fixup types (shared with LLVM backend) ──
134
/// Pending patch for a direct call site, resolved by `resolve_call_fixups`
/// once all function byte offsets are known.
#[derive(Debug, Clone)]
pub struct CallFixup {
    // Instruction index (relative to the function's instruction base) of the
    // return-address load. Equals `jump_instr` for combined `LoadImmJump`.
    pub return_addr_instr: usize,
    // Instruction index of the jump to the callee.
    pub jump_instr: usize,
    // Local (non-imported) index of the callee function.
    pub target_func: u32,
}
141
/// Pending patch for an indirect (`call_indirect`) call site, resolved by
/// `resolve_call_fixups` alongside the direct-call fixups.
#[derive(Debug, Clone)]
pub struct IndirectCallFixup {
    // Instruction index (relative to the function's instruction base) of the
    // return-address load.
    pub return_addr_instr: usize,
    // For `LoadImmJumpInd`, this equals `return_addr_instr`.
    pub jump_ind_instr: usize,
}
148
/// `RO_DATA` region size is 64KB (0x10000 to 0x1FFFF).
/// The indirect-call dispatch table and all passive data segments must fit here;
/// `compile_via_llvm` errors if a passive segment would overflow this region.
const RO_DATA_SIZE: usize = 64 * 1024;
151
152/// Check if an import name is a known compiler intrinsic (`host_call_N` variants, `pvm_ptr`).
153fn is_known_intrinsic(name: &str) -> bool {
154    name == "pvm_ptr" || crate::abi::parse_host_call_variant(name).is_some()
155}
156
/// Default mappings applied when no explicit import map is provided.
/// Imports listed here resolve to a trap (reported as "trap (default)" in stats).
const DEFAULT_MAPPINGS: &[&str] = &["abort"];
159
160pub fn compile(wasm: &[u8]) -> Result<SpiProgram> {
161    compile_with_options(wasm, &CompileOptions::default())
162}
163
164pub fn compile_with_options(wasm: &[u8], options: &CompileOptions) -> Result<SpiProgram> {
165    let (program, _) = compile_with_stats(wasm, options)?;
166    Ok(program)
167}
168
/// Compile a WASM binary to an SPI program, also returning compilation statistics.
///
/// Pipeline: optional adapter merge → parse → max-pages override →
/// import validation (collecting per-import resolutions for stats) →
/// LLVM-based lowering via `compile_via_llvm` → stats assembly.
///
/// # Errors
/// Returns `Error::UnresolvedImport` when a non-intrinsic import has no
/// mapping, plus any parse/merge/lowering errors from the underlying phases.
pub fn compile_with_stats(
    wasm: &[u8],
    options: &CompileOptions,
) -> Result<(SpiProgram, stats::CompileStats)> {
    // Apply adapter merge if provided (produces a new WASM binary with fewer imports).
    // `merged_wasm` lives in the outer scope so the borrow in `wasm` stays valid.
    let merged_wasm;
    let wasm = if let Some(adapter_wat) = &options.adapter {
        merged_wasm = adapter_merge::merge_adapter(wasm, adapter_wat)?;
        &merged_wasm
    } else {
        wasm
    };

    let mut module = WasmModule::parse(wasm)?;

    // Apply max_memory_pages override if provided.
    // Clamped so the max is never below the initial page count.
    if let Some(max_pages) = options.max_memory_pages {
        module.max_memory_pages = max_pages.max(module.memory_limits.initial_pages);
    }

    // Validate imports and collect resolutions.
    let mut import_resolutions = Vec::new();
    for name in &module.imported_func_names {
        // Known intrinsics bypass the import map entirely.
        if is_known_intrinsic(name) {
            let action = if name == "pvm_ptr" || name == "host_call_r8" {
                "intrinsic"
            } else {
                "ecalli"
            };
            import_resolutions.push(stats::ImportResolution {
                name: name.clone(),
                action: action.to_string(),
            });
            continue;
        }
        if let Some(import_map) = &options.import_map {
            if let Some(action) = import_map.get(name) {
                let action_str = match action {
                    ImportAction::Trap => "trap".to_string(),
                    ImportAction::Nop => "nop".to_string(),
                    ImportAction::Ecalli(idx) => format!("ecalli:{idx}"),
                };
                import_resolutions.push(stats::ImportResolution {
                    name: name.clone(),
                    action: action_str,
                });
                continue;
            }
        } else if DEFAULT_MAPPINGS.contains(&name.as_str()) {
            // No explicit map: fall back to the built-in default mappings.
            import_resolutions.push(stats::ImportResolution {
                name: name.clone(),
                action: "trap (default)".to_string(),
            });
            continue;
        }
        // Reached when the import is neither an intrinsic, nor mapped, nor a default.
        return Err(Error::UnresolvedImport(format!(
            "import '{name}' has no mapping. Provide a mapping via --imports or add it to the import map."
        )));
    }

    // Active segments have a static offset; passive ones are placed in RO_DATA
    // and initialized on demand (memory.init).
    let active_data_segments = module
        .data_segments
        .iter()
        .filter(|s| s.offset.is_some())
        .count();
    let passive_data_segments = module
        .data_segments
        .iter()
        .filter(|s| s.offset.is_none())
        .count();
    let globals_region_bytes =
        memory_layout::globals_region_size(module.globals.len(), passive_data_segments);

    let result = compile_via_llvm(&module, options)?;

    let spi_blob_bytes = result.program.encode().len();

    let compile_stats = stats::CompileStats {
        local_functions: module.functions.len(),
        imported_functions: module.num_imported_funcs as usize,
        globals: module.globals.len(),
        active_data_segments,
        passive_data_segments,
        function_table_entries: module.function_table.len(),
        initial_memory_pages: module.memory_limits.initial_pages,
        max_memory_pages: module.max_memory_pages,
        wasm_declared_max_pages: module.memory_limits.max_pages,
        import_resolutions,
        wasm_memory_base: module.wasm_memory_base,
        globals_region_bytes,
        ro_data_bytes: result.program.ro_data().len(),
        rw_data_bytes: result.program.rw_data().len(),
        heap_pages: result.program.heap_pages(),
        stack_size: memory_layout::DEFAULT_STACK_SIZE,
        pvm_instructions: result.pvm_instructions,
        code_bytes: result.code_bytes,
        jump_table_entries: result.jump_table_entries,
        dead_functions_eliminated: result.dead_functions_eliminated,
        spi_blob_bytes,
        functions: result.function_stats,
    };

    Ok((result.program, compile_stats))
}
273
/// Internal result of `compile_via_llvm`, carrying both the program and stats.
struct CompilationOutput {
    // The fully assembled SPI program (code blob, RO/RW data, heap pages, metadata).
    program: SpiProgram,
    // Per-function lowering statistics, in emission order.
    function_stats: Vec<stats::FunctionStats>,
    // Number of unreachable functions replaced by a single Trap placeholder.
    dead_functions_eliminated: usize,
    // Total PVM instruction count across all functions (incl. entry header).
    pvm_instructions: usize,
    // Total encoded code size in bytes.
    code_bytes: usize,
    // Number of entries in the jump table built during fixup resolution.
    jump_table_entries: usize,
}
283
/// Lower a parsed `WasmModule` to a complete `CompilationOutput` via LLVM.
///
/// Phases: (0) dead-function reachability, (1) WASM → LLVM IR,
/// (2) lowering-context setup (RO_DATA layout for passive segments),
/// (3) per-function LLVM IR → PVM bytecode, (4) call-fixup resolution and
/// jump table, (5) dispatch table / RO_DATA / RW_DATA / heap assembly.
fn compile_via_llvm(module: &WasmModule, options: &CompileOptions) -> Result<CompilationOutput> {
    use crate::llvm_backend::{self, LoweringContext};
    use crate::llvm_frontend;
    use inkwell::context::Context;

    // Phase 0: Dead function elimination — compute reachable set.
    let reachable_locals = if options.optimizations.dead_function_elimination {
        Some(dead_function_elimination::reachable_functions(module)?)
    } else {
        None
    };

    // Phase 1: WASM → LLVM IR
    let context = Context::create();
    let llvm_module = llvm_frontend::translate_wasm_to_llvm(
        &context,
        module,
        options.optimizations.llvm_passes,
        options.optimizations.inlining,
        options.optimizations.inline_threshold,
        reachable_locals.as_ref(),
    )?;

    // Calculate RO_DATA offsets and lengths for passive data segments.
    // Passive segments are placed after the dispatch table (or a single dummy byte).
    let mut data_segment_offsets = std::collections::HashMap::new();
    let mut data_segment_lengths = std::collections::HashMap::new();
    let mut current_ro_offset = if module.function_table.is_empty() {
        1 // dummy byte if no function table
    } else {
        module.function_table.len() * 8 // jump_ref + type_idx per entry
    };

    let mut data_segment_length_addrs = std::collections::HashMap::new();
    let mut passive_ordinal = 0usize;

    for (idx, seg) in module.data_segments.iter().enumerate() {
        if seg.offset.is_none() {
            // Check that segment fits within RO_DATA region
            if current_ro_offset + seg.data.len() > RO_DATA_SIZE {
                return Err(Error::Internal(format!(
                    "passive data segment {} (size {}) would overflow RO_DATA region ({} bytes used of {})",
                    idx,
                    seg.data.len(),
                    current_ro_offset,
                    RO_DATA_SIZE
                )));
            }
            data_segment_offsets.insert(idx as u32, current_ro_offset as u32);
            data_segment_lengths.insert(idx as u32, seg.data.len() as u32);
            data_segment_length_addrs.insert(
                idx as u32,
                memory_layout::data_segment_length_offset(module.globals.len(), passive_ordinal),
            );
            current_ro_offset += seg.data.len();
            passive_ordinal += 1;
        }
    }

    // Phase 2: Build lowering context
    let param_overflow_base =
        memory_layout::compute_param_overflow_base(module.globals.len(), passive_ordinal);
    let ctx = LoweringContext {
        wasm_memory_base: module.wasm_memory_base,
        num_globals: module.globals.len(),
        param_overflow_base,
        function_signatures: module.function_signatures.clone(),
        type_signatures: module.type_signatures.clone(),
        function_table: module.function_table.clone(),
        num_imported_funcs: module.num_imported_funcs as usize,
        imported_func_names: module.imported_func_names.clone(),
        initial_memory_pages: module.memory_limits.initial_pages,
        max_memory_pages: module.max_memory_pages,
        stack_size: memory_layout::DEFAULT_STACK_SIZE,
        data_segment_offsets,
        data_segment_lengths,
        data_segment_length_addrs,
        wasm_import_map: options.import_map.clone(),
        optimizations: options.optimizations.clone(),
    };

    // Phase 3: LLVM IR → PVM bytecode for each function
    let mut all_instructions: Vec<Instruction> = Vec::new();
    let mut all_call_fixups: Vec<(usize, CallFixup)> = Vec::new();
    let mut all_indirect_call_fixups: Vec<(usize, IndirectCallFixup)> = Vec::new();
    let mut function_offsets: Vec<usize> = vec![0; module.functions.len()];
    let mut next_call_return_idx: usize = 0;
    let mut function_stats: Vec<stats::FunctionStats> = Vec::with_capacity(module.functions.len());
    let mut dead_functions_eliminated: usize = 0;

    // Entry header: Jump to main (PC=0) + Trap or secondary Jump (PC=5).
    // When there's no secondary entry, we omit the Fallthrough padding (6 bytes instead of 10).
    all_instructions.push(Instruction::Jump { offset: 0 });
    if module.has_secondary_entry {
        all_instructions.push(Instruction::Jump { offset: 0 });
    } else {
        all_instructions.push(Instruction::Trap);
    }

    // Build emission order: main first, then secondary (if any), then remaining in index order.
    // This places main immediately after the entry header, minimizing the entry Jump distance.
    let mut emission_order: Vec<usize> = Vec::with_capacity(module.functions.len());
    emission_order.push(module.main_func_local_idx);
    if let Some(secondary_idx) = module.secondary_entry_local_idx
        && secondary_idx != module.main_func_local_idx
    {
        emission_order.push(secondary_idx);
    }
    for idx in 0..module.functions.len() {
        if idx != module.main_func_local_idx && module.secondary_entry_local_idx != Some(idx) {
            emission_order.push(idx);
        }
    }

    for &local_func_idx in &emission_order {
        // Dead functions: emit a single Trap as a placeholder.
        // The function offset is still recorded so dispatch table indices stay valid.
        if reachable_locals
            .as_ref()
            .is_some_and(|r| !r.contains(&local_func_idx))
        {
            let func_start_offset: usize = all_instructions.iter().map(|i| i.encode().len()).sum();
            function_offsets[local_func_idx] = func_start_offset;
            all_instructions.push(Instruction::Trap);
            dead_functions_eliminated += 1;
            let global_func_idx = module.num_imported_funcs as usize + local_func_idx;
            function_stats.push(stats::FunctionStats {
                name: format!("wasm_func_{global_func_idx}"),
                index: local_func_idx,
                instruction_count: 1,
                frame_size: 0,
                is_leaf: true,
                is_entry: false,
                is_dead: true,
                regalloc: stats::FunctionRegAllocStats::default(),
                pre_dse_instructions: 0,
                pre_peephole_instructions: 0,
            });
            continue;
        }

        let global_func_idx = module.num_imported_funcs as usize + local_func_idx;
        let fn_name = format!("wasm_func_{global_func_idx}");
        let llvm_func = llvm_module
            .get_function(&fn_name)
            .ok_or_else(|| Error::Internal(format!("missing LLVM function: {fn_name}")))?;

        let is_main = local_func_idx == module.main_func_local_idx;
        let is_secondary = module.secondary_entry_local_idx == Some(local_func_idx);
        let is_entry = is_main || is_secondary;

        // Byte offset of this function = encoded size of everything emitted so far.
        let func_start_offset: usize = all_instructions.iter().map(|i| i.encode().len()).sum();
        function_offsets[local_func_idx] = func_start_offset;

        // If entry function and there's a start function, call it first.
        if let Some(start_local_idx) = module.start_func_local_idx.filter(|_| is_entry) {
            // Save r7 and r8 to stack.
            all_instructions.push(Instruction::AddImm64 {
                dst: STACK_PTR_REG,
                src: STACK_PTR_REG,
                value: -16,
            });
            all_instructions.push(Instruction::StoreIndU64 {
                base: STACK_PTR_REG,
                src: ARGS_PTR_REG,
                offset: 0,
            });
            all_instructions.push(Instruction::StoreIndU64 {
                base: STACK_PTR_REG,
                src: ARGS_LEN_REG,
                offset: 8,
            });

            // Call start function using LoadImmJump (combined load + jump).
            // Return addresses are pre-assigned as (idx + 1) * 2 — see
            // `return_addr_jump_table_idx` for the inverse mapping.
            let call_return_addr = ((next_call_return_idx + 1) * 2) as i32;
            next_call_return_idx += 1;
            let current_instr_idx = all_instructions.len();
            all_instructions.push(Instruction::LoadImmJump {
                reg: RETURN_ADDR_REG,
                value: call_return_addr,
                offset: 0, // patched during fixup resolution
            });

            all_call_fixups.push((
                current_instr_idx,
                CallFixup {
                    target_func: start_local_idx as u32,
                    return_addr_instr: 0,
                    jump_instr: 0, // same instruction for LoadImmJump
                },
            ));

            // Restore r7 and r8.
            all_instructions.push(Instruction::LoadIndU64 {
                dst: ARGS_PTR_REG,
                base: STACK_PTR_REG,
                offset: 0,
            });
            all_instructions.push(Instruction::LoadIndU64 {
                dst: ARGS_LEN_REG,
                base: STACK_PTR_REG,
                offset: 8,
            });
            all_instructions.push(Instruction::AddImm64 {
                dst: STACK_PTR_REG,
                src: STACK_PTR_REG,
                value: 16,
            });
        }

        let translation = llvm_backend::lower_function(
            llvm_func,
            &ctx,
            is_entry,
            global_func_idx,
            next_call_return_idx,
        )?;
        next_call_return_idx += translation.num_call_returns;

        // Re-base the function-local fixup indices against the global instruction list.
        let instr_base = all_instructions.len();
        for fixup in translation.call_fixups {
            all_call_fixups.push((
                instr_base,
                CallFixup {
                    return_addr_instr: fixup.return_addr_instr,
                    jump_instr: fixup.jump_instr,
                    target_func: fixup.target_func,
                },
            ));
        }
        for fixup in translation.indirect_call_fixups {
            all_indirect_call_fixups.push((
                instr_base,
                IndirectCallFixup {
                    return_addr_instr: fixup.return_addr_instr,
                    jump_ind_instr: fixup.jump_ind_instr,
                },
            ));
        }

        let ls = &translation.lowering_stats;
        function_stats.push(stats::FunctionStats {
            name: fn_name,
            index: local_func_idx,
            instruction_count: translation.instructions.len(),
            frame_size: ls.frame_size,
            is_leaf: ls.is_leaf,
            is_entry,
            is_dead: false,
            regalloc: stats::FunctionRegAllocStats {
                total_values: ls.regalloc_total_values,
                allocated_values: ls.regalloc_allocated_values,
                registers_used: ls
                    .regalloc_registers_used
                    .iter()
                    .map(|r| format!("r{r}"))
                    .collect(),
                skipped_reason: ls.regalloc_skipped_reason.map(String::from),
                load_hits: ls.regalloc_load_hits,
                load_reloads: ls.regalloc_load_reloads,
                load_moves: ls.regalloc_load_moves,
                store_hits: ls.regalloc_store_hits,
                store_moves: ls.regalloc_store_moves,
            },
            pre_dse_instructions: ls.pre_dse_instructions,
            pre_peephole_instructions: ls.pre_peephole_instructions,
        });

        all_instructions.extend(translation.instructions);
    }

    // Phase 4: Resolve call fixups and build jump table.
    let (jump_table, func_entry_jump_table_base) = resolve_call_fixups(
        &mut all_instructions,
        &all_call_fixups,
        &all_indirect_call_fixups,
        &function_offsets,
    )?;

    // Patch entry header jumps.
    // Main's Jump sits at byte 0, so its PC-relative offset equals the absolute offset.
    let main_offset = function_offsets[module.main_func_local_idx] as i32;
    if let Instruction::Jump { offset } = &mut all_instructions[0] {
        *offset = main_offset;
    }

    if let Some(secondary_idx) = module.secondary_entry_local_idx {
        // The secondary Jump sits at PC=5 (see entry header above), hence the -5.
        let secondary_offset = function_offsets[secondary_idx] as i32 - 5;
        if let Instruction::Jump { offset } = &mut all_instructions[1] {
            *offset = secondary_offset;
        }
    }

    // Phase 5: Build dispatch table for call_indirect.
    // Starts as a single dummy byte (matching current_ro_offset = 1 above) when
    // there is no function table.
    let mut ro_data = vec![0u8];
    if !module.function_table.is_empty() {
        ro_data.clear();
        for &func_idx in &module.function_table {
            // Null entries and imported functions get an all-ones sentinel pair.
            if func_idx == u32::MAX || (func_idx as usize) < module.num_imported_funcs as usize {
                ro_data.extend_from_slice(&u32::MAX.to_le_bytes());
                ro_data.extend_from_slice(&u32::MAX.to_le_bytes());
            } else {
                let local_func_idx = func_idx as usize - module.num_imported_funcs as usize;
                let jump_ref = 2 * (func_entry_jump_table_base + local_func_idx + 1) as u32;
                ro_data.extend_from_slice(&jump_ref.to_le_bytes());
                let type_idx = *module
                    .function_type_indices
                    .get(local_func_idx)
                    .unwrap_or(&u32::MAX);
                ro_data.extend_from_slice(&type_idx.to_le_bytes());
            }
        }
    }

    // Append passive data segments to RO_DATA.
    // NOTE: This loop must iterate data_segments in the same order as the offset
    // calculation loop above, since data_segment_offsets indices depend on it.
    for seg in &module.data_segments {
        if seg.offset.is_none() {
            ro_data.extend_from_slice(&seg.data);
        }
    }

    // Capture stats before moving instructions into the blob.
    let pvm_instructions = all_instructions.len();
    let code_bytes: usize = all_instructions.iter().map(|i| i.encode().len()).sum();
    let jump_table_entries = jump_table.len();

    let blob = crate::pvm::ProgramBlob::new(all_instructions).with_jump_table(jump_table);
    let rw_data_section = build_rw_data(
        &module.data_segments,
        &module.global_init_values,
        module.memory_limits.initial_pages,
        module.wasm_memory_base,
        &ctx.data_segment_length_addrs,
        &ctx.data_segment_lengths,
    );

    let heap_pages = calculate_heap_pages(
        rw_data_section.len(),
        module.wasm_memory_base,
        module.memory_limits.initial_pages,
    )?;

    let program = SpiProgram::new(blob)
        .with_heap_pages(heap_pages)
        .with_ro_data(ro_data)
        .with_rw_data(rw_data_section)
        .with_metadata(options.metadata.clone());

    Ok(CompilationOutput {
        program,
        function_stats,
        dead_functions_eliminated,
        pvm_instructions,
        code_bytes,
        jump_table_entries,
    })
}
641
642/// Calculate the number of 4KB PVM heap pages needed after `rw_data`.
643///
644/// `heap_pages` tells the runtime how many zero-initialized writable pages to allocate
645/// immediately after the `rw_data` blob. This covers the initial WASM linear memory,
646/// globals, and spilled locals that aren't already covered by `rw_data`.
647///
648/// By computing this **after** `build_rw_data()`, we use the actual (trimmed) `rw_data`
649/// length instead of guessing with headroom.
650///
651/// We add 1 extra page beyond the exact initial memory requirement. This ensures that
652/// the first `memory.grow` / sbrk allocation has a pre-allocated page available at the
653/// boundary of the initial WASM memory. Without it, PVM-in-PVM execution fails because
654/// the inner interpreter's page-fault handling at the exact heap boundary doesn't
655/// correctly propagate through the outer PVM.
656fn calculate_heap_pages(
657    rw_data_len: usize,
658    wasm_memory_base: i32,
659    initial_pages: u32,
660) -> Result<u16> {
661    use wasm_module::MIN_INITIAL_WASM_PAGES;
662
663    let initial_pages = initial_pages.max(MIN_INITIAL_WASM_PAGES);
664    let wasm_memory_initial_end = wasm_memory_base as usize + (initial_pages as usize) * 64 * 1024;
665
666    let total_bytes = wasm_memory_initial_end - memory_layout::GLOBAL_MEMORY_BASE as usize;
667    let rw_pages = rw_data_len.div_ceil(4096);
668    let total_pages = total_bytes.div_ceil(4096);
669    let heap_pages = total_pages.saturating_sub(rw_pages) + 1;
670
671    u16::try_from(heap_pages).map_err(|_| {
672        Error::Internal(format!(
673            "heap size {heap_pages} pages exceeds u16::MAX ({}) — module too large",
674            u16::MAX
675        ))
676    })
677}
678
/// Build the `rw_data` section from WASM data segments and global initializers.
///
/// Layout (offsets relative to the start of `rw_data`): user globals (4 bytes
/// each), then the compiler-managed memory-size global, then passive-segment
/// effective lengths (at addresses supplied by `memory_layout`), then active
/// data segments at their `wasm_memory_base`-relative locations. Trailing
/// zeros are trimmed — heap pages are zero-initialized, so omitting them is
/// semantically equivalent.
pub(crate) fn build_rw_data(
    data_segments: &[wasm_module::DataSegment],
    global_init_values: &[i32],
    initial_memory_pages: u32,
    wasm_memory_base: i32,
    data_segment_length_addrs: &std::collections::HashMap<u32, i32>,
    data_segment_lengths: &std::collections::HashMap<u32, u32>,
) -> Vec<u8> {
    // Calculate the minimum size needed for globals
    // +1 for the compiler-managed memory size global, plus passive segment lengths
    let num_passive_segments = data_segment_length_addrs.len();
    let globals_end =
        memory_layout::globals_region_size(global_init_values.len(), num_passive_segments);

    // Calculate the size needed for data segments
    // NOTE(review): assumes wasm_memory_base >= 0x30000 (presumably the rw_data
    // base address) — would wrap otherwise; confirm against memory_layout.
    let wasm_to_rw_offset = wasm_memory_base as u32 - 0x30000;

    let data_end = data_segments
        .iter()
        .filter_map(|seg| {
            seg.offset
                .map(|off| wasm_to_rw_offset + off + seg.data.len() as u32)
        })
        .max()
        .unwrap_or(0) as usize;

    // Section must cover whichever region ends later.
    let total_size = globals_end.max(data_end);

    if total_size == 0 {
        return Vec::new();
    }

    let mut rw_data = vec![0u8; total_size];

    // Initialize user globals
    for (i, &value) in global_init_values.iter().enumerate() {
        let offset = i * 4;
        if offset + 4 <= rw_data.len() {
            rw_data[offset..offset + 4].copy_from_slice(&value.to_le_bytes());
        }
    }

    // Initialize compiler-managed memory size global (right after user globals)
    let mem_size_offset = global_init_values.len() * 4;
    if mem_size_offset + 4 <= rw_data.len() {
        rw_data[mem_size_offset..mem_size_offset + 4]
            .copy_from_slice(&initial_memory_pages.to_le_bytes());
    }

    // Initialize passive data segment effective lengths (right after memory size global).
    // These are used by memory.init for bounds checking and zeroed by data.drop.
    for (&seg_idx, &addr) in data_segment_length_addrs {
        if let Some(&length) = data_segment_lengths.get(&seg_idx) {
            // addr is absolute PVM address; convert to rw_data offset
            let rw_offset = (addr - memory_layout::GLOBAL_MEMORY_BASE) as usize;
            if rw_offset + 4 <= rw_data.len() {
                rw_data[rw_offset..rw_offset + 4].copy_from_slice(&length.to_le_bytes());
            }
        }
    }

    // Copy data segments to their WASM memory locations
    for seg in data_segments {
        if let Some(offset) = seg.offset {
            let rw_offset = (wasm_to_rw_offset + offset) as usize;
            if rw_offset + seg.data.len() <= rw_data.len() {
                rw_data[rw_offset..rw_offset + seg.data.len()].copy_from_slice(&seg.data);
            }
        }
    }

    // Trim trailing zeros to reduce SPI size. Heap pages are zero-initialized,
    // so omitted high-address zero bytes are semantically equivalent.
    if let Some(last_non_zero) = rw_data.iter().rposition(|&b| b != 0) {
        rw_data.truncate(last_non_zero + 1);
    } else {
        rw_data.clear();
    }

    rw_data
}
761
762/// Extract the pre-assigned jump-table index from a return-address load instruction.
763///
764/// Call return addresses are pre-assigned as `(idx + 1) * 2` at emission time.
765/// This helper recovers `idx` so that `resolve_call_fixups` can write the byte
766/// offset into the correct jump-table slot instead of appending in list order
767/// (which would desync when a function mixes direct and indirect calls).
768///
769/// Direct calls use `LoadImmJump`, while indirect calls use either `LoadImm` (legacy
770/// two-instruction sequence) or `LoadImmJumpInd` (combined return-addr load + jump).
771fn return_addr_jump_table_idx(
772    instructions: &[Instruction],
773    return_addr_instr: usize,
774) -> Result<usize> {
775    let value = match instructions.get(return_addr_instr) {
776        Some(
777            Instruction::LoadImmJump { value, .. }
778            | Instruction::LoadImm { value, .. }
779            | Instruction::LoadImmJumpInd { value, .. },
780        ) => Some(*value),
781        _ => None,
782    };
783    match value {
784        Some(v) if v > 0 && v % 2 == 0 => Ok((v as usize / 2) - 1),
785        _ => Err(Error::Internal(format!(
786            "expected LoadImmJump/LoadImm/LoadImmJumpInd((idx+1)*2) at return_addr_instr {return_addr_instr}, got {:?}",
787            instructions.get(return_addr_instr)
788        ))),
789    }
790}
791
792fn resolve_call_fixups(
793    instructions: &mut [Instruction],
794    call_fixups: &[(usize, CallFixup)],
795    indirect_call_fixups: &[(usize, IndirectCallFixup)],
796    function_offsets: &[usize],
797) -> Result<(Vec<u32>, usize)> {
798    // Count total call-return entries by finding the maximum pre-assigned index.
799    // Entries are written at their pre-assigned slot so mixed direct/indirect
800    // call ordering within a function is preserved correctly.
801    let mut num_call_returns: usize = 0;
802
803    for (instr_base, fixup) in call_fixups {
804        let idx = return_addr_jump_table_idx(instructions, instr_base + fixup.return_addr_instr)?;
805        num_call_returns = num_call_returns.max(idx + 1);
806    }
807    for (instr_base, fixup) in indirect_call_fixups {
808        let idx = return_addr_jump_table_idx(instructions, instr_base + fixup.return_addr_instr)?;
809        num_call_returns = num_call_returns.max(idx + 1);
810    }
811
812    let mut jump_table: Vec<u32> = vec![0u32; num_call_returns];
813
814    // Call return addresses (LoadImmJump/LoadImm/LoadImmJumpInd values) are pre-assigned at emission time,
815    // so we only need to compute byte offsets for the jump table and patch Jump targets.
816    // Write each entry at its pre-assigned index to keep values in sync.
817    for (instr_base, fixup) in call_fixups {
818        let target_offset = function_offsets
819            .get(fixup.target_func as usize)
820            .ok_or_else(|| {
821                Error::Unsupported(format!("call to unknown function {}", fixup.target_func))
822            })?;
823
824        let jump_idx = instr_base + fixup.jump_instr;
825
826        // Return address = byte offset after the LoadImmJump instruction.
827        let return_addr_offset: usize = instructions[..=jump_idx]
828            .iter()
829            .map(|i| i.encode().len())
830            .sum();
831
832        let slot = return_addr_jump_table_idx(instructions, instr_base + fixup.return_addr_instr)?;
833        jump_table[slot] = return_addr_offset as u32;
834
835        // Verify pre-assigned jump table address matches actual index.
836        let expected_addr = ((slot + 1) * 2) as i32;
837        debug_assert!(
838            matches!(&instructions[jump_idx], Instruction::LoadImmJump { value, .. } if *value == expected_addr),
839            "pre-assigned jump table address mismatch: expected {expected_addr}, got {:?}",
840            &instructions[jump_idx]
841        );
842
843        // Patch the offset field of LoadImmJump.
844        let jump_start_offset: usize = instructions[..jump_idx]
845            .iter()
846            .map(|i| i.encode().len())
847            .sum();
848        let relative_offset = (*target_offset as i32) - (jump_start_offset as i32);
849
850        if let Instruction::LoadImmJump { offset, .. } = &mut instructions[jump_idx] {
851            *offset = relative_offset;
852        }
853    }
854
855    for (instr_base, fixup) in indirect_call_fixups {
856        let jump_ind_idx = instr_base + fixup.jump_ind_instr;
857
858        let return_addr_offset: usize = instructions[..=jump_ind_idx]
859            .iter()
860            .map(|i| i.encode().len())
861            .sum();
862
863        let slot = return_addr_jump_table_idx(instructions, instr_base + fixup.return_addr_instr)?;
864        jump_table[slot] = return_addr_offset as u32;
865    }
866
867    let func_entry_base = jump_table.len();
868    for &offset in function_offsets {
869        jump_table.push(offset as u32);
870    }
871
872    Ok((jump_table, func_entry_base))
873}
874
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::build_rw_data;
    use super::memory_layout;
    use super::wasm_module::DataSegment;

    #[test]
    fn build_rw_data_trims_all_zero_tail_to_empty() {
        // No segments, no globals: everything is zero, so the result is empty.
        let rw = build_rw_data(&[], &[], 0, 0x30000, &HashMap::new(), &HashMap::new());
        assert!(rw.is_empty());
    }

    #[test]
    fn build_rw_data_preserves_internal_zeros_and_trims_trailing_zeros() {
        let segments = [DataSegment {
            offset: Some(0),
            data: vec![1, 0, 2, 0, 0],
        }];

        let rw = build_rw_data(
            &segments,
            &[],
            0,
            0x30000,
            &HashMap::new(),
            &HashMap::new(),
        );

        // The trailing zero pair is dropped; the interior zero survives.
        assert_eq!(rw, vec![1, 0, 2]);
    }

    #[test]
    fn build_rw_data_keeps_non_zero_passive_length_bytes() {
        let addrs = HashMap::from([(0u32, memory_layout::GLOBAL_MEMORY_BASE + 4)]);
        let lengths = HashMap::from([(0u32, 7u32)]);

        let rw = build_rw_data(&[], &[], 0, 0x30000, &addrs, &lengths);

        // Length word lands at rw offset 4; leading zeros are kept, not trimmed.
        assert_eq!(rw, vec![0, 0, 0, 0, 7]);
    }

    // ── calculate_heap_pages tests ──

    #[test]
    fn heap_pages_with_empty_rw_data_equals_total_pages_plus_one() {
        // With wasm_memory_base = 0x31000 (typical with few globals) and
        // initial_pages = 0 (clamped to 16):
        //   end         = 0x31000 + 16 * 64 KiB = 0x131000
        //   total_bytes = 0x131000 - 0x30000    = 0x101000 (1052672)
        //   total_pages = ceil(1052672 / 4096)  = 257
        //   heap_pages  = 257 - 0 + 1           = 258
        assert_eq!(super::calculate_heap_pages(0, 0x31000, 0).unwrap(), 258);
    }

    #[test]
    fn heap_pages_reduced_by_rw_data_pages() {
        // Same scenario, but 8192 bytes of rw_data occupy exactly 2 pages.
        let without_rw = super::calculate_heap_pages(0, 0x31000, 0).unwrap();
        let with_rw = super::calculate_heap_pages(8192, 0x31000, 0).unwrap();
        assert_eq!(without_rw - with_rw, 2);
    }

    #[test]
    fn heap_pages_saturates_at_one_for_large_rw_data() {
        // Even when rw_data covers more than total_pages, the +1 headroom
        // page remains.
        let pages = super::calculate_heap_pages(2 * 1024 * 1024, 0x31000, 0).unwrap();
        assert_eq!(pages, 1);
    }

    #[test]
    fn heap_pages_respects_initial_pages() {
        // initial_pages = 32 exceeds MIN_INITIAL_WASM_PAGES (16):
        //   end         = 0x31000 + 32 * 64 KiB = 0x231000
        //   total_bytes = 0x231000 - 0x30000    = 0x201000
        //   total_pages = ceil(0x201000 / 4096) = 513
        //   heap_pages  = 513 + 1               = 514
        assert_eq!(super::calculate_heap_pages(0, 0x31000, 32).unwrap(), 514);
    }
}