1#![allow(
3 clippy::cast_possible_truncation,
4 clippy::cast_possible_wrap,
5 clippy::cast_sign_loss
6)]
7
8pub mod adapter_merge;
9pub mod dead_function_elimination;
10pub mod memory_layout;
11pub mod wasm_module;
12
13use std::collections::HashMap;
14
15use crate::pvm::Instruction;
16use crate::{Error, Result, SpiProgram};
17
18pub use wasm_module::WasmModule;
19
/// How calls to an unimplemented WASM import should be lowered.
///
/// Supplied per import name via [`CompileOptions::import_map`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ImportAction {
    /// Lower calls to the import to a trap.
    Trap,
    /// Lower calls to the import to a no-op.
    Nop,
}
28
/// Per-stage toggles for the compilation pipeline's optimizations.
///
/// All flags are enabled by `Default`.
#[derive(Debug, Clone)]
// Deliberately a bag of independent booleans — each flag is toggled
// individually from the CLI / options, so an enum would not fit.
#[allow(clippy::struct_excessive_bools)]
pub struct OptimizationFlags {
    /// Run LLVM's own optimization passes during WASM -> IR translation.
    pub llvm_passes: bool,
    /// Peephole optimization of the emitted instruction stream.
    pub peephole: bool,
    /// Register caching within a block.
    pub register_cache: bool,
    /// Fuse integer-compare + branch sequences.
    pub icmp_branch_fusion: bool,
    /// Shrink-wrap callee-saved register spills.
    pub shrink_wrap_callee_saves: bool,
    /// Eliminate stores whose values are never read.
    pub dead_store_elimination: bool,
    /// Propagate known constants.
    pub constant_propagation: bool,
    /// Function inlining (applied during LLVM translation).
    pub inlining: bool,
    /// Extend register caching across basic blocks.
    pub cross_block_cache: bool,
    /// Register allocation pass.
    pub register_allocation: bool,
    /// Drop unreachable local functions (their bodies become a trap).
    pub dead_function_elimination: bool,
    /// Elide jumps whose target immediately follows (fallthrough).
    pub fallthrough_jumps: bool,
}
60
61impl Default for OptimizationFlags {
62 fn default() -> Self {
63 Self {
64 llvm_passes: true,
65 peephole: true,
66 register_cache: true,
67 icmp_branch_fusion: true,
68 shrink_wrap_callee_saves: true,
69 dead_store_elimination: true,
70 constant_propagation: true,
71 inlining: true,
72 cross_block_cache: true,
73 register_allocation: true,
74 dead_function_elimination: true,
75 fallthrough_jumps: true,
76 }
77 }
78}
79
/// Options controlling a single WASM -> SPI compilation.
#[derive(Debug, Clone, Default)]
pub struct CompileOptions {
    /// Per-import lowering actions, keyed by import name. When `None`,
    /// only the built-in default mappings are accepted for non-intrinsic
    /// imports; when `Some`, the map alone decides.
    pub import_map: Option<HashMap<String, ImportAction>>,
    /// Optional adapter module (WAT text) merged into the input binary
    /// before parsing.
    pub adapter: Option<String>,
    /// Raw metadata bytes attached to the resulting SPI program.
    pub metadata: Vec<u8>,
    /// Per-stage optimization toggles (all enabled by default).
    pub optimizations: OptimizationFlags,
}
96
97pub use crate::abi::{ARGS_LEN_REG, ARGS_PTR_REG, RETURN_ADDR_REG, STACK_PTR_REG};
99
/// A pending direct-call patch recorded during function lowering.
///
/// Instruction indices are relative to the function they were emitted in
/// and are re-based onto the global stream before resolution.
#[derive(Debug, Clone)]
pub struct CallFixup {
    /// Index of the instruction that loads the return-address slot value.
    pub return_addr_instr: usize,
    /// Index of the `LoadImmJump` whose offset gets patched to the target.
    pub jump_instr: usize,
    /// Local index of the callee (used to look up its code offset).
    pub target_func: u32,
}
108
/// A pending indirect-call patch; only the return-address jump-table slot
/// needs filling — the jump target comes from the function table at
/// runtime.
#[derive(Debug, Clone)]
pub struct IndirectCallFixup {
    /// Index of the instruction that loads the return-address slot value.
    pub return_addr_instr: usize,
    /// Index of the indirect-jump instruction.
    pub jump_ind_instr: usize,
}
115
/// Capacity (64 KiB) of the read-only data region that holds the function
/// table and passive data segments.
const RO_DATA_SIZE: usize = 64 * 1024;
118
119pub fn compile(wasm: &[u8]) -> Result<SpiProgram> {
120 compile_with_options(wasm, &CompileOptions::default())
121}
122
123pub fn compile_with_options(wasm: &[u8], options: &CompileOptions) -> Result<SpiProgram> {
124 const KNOWN_INTRINSICS: &[&str] = &["host_call", "pvm_ptr"];
126 const DEFAULT_MAPPINGS: &[&str] = &["abort"];
128
129 let merged_wasm;
131 let wasm = if let Some(adapter_wat) = &options.adapter {
132 merged_wasm = adapter_merge::merge_adapter(wasm, adapter_wat)?;
133 &merged_wasm
134 } else {
135 wasm
136 };
137
138 let module = WasmModule::parse(wasm)?;
139
140 for name in &module.imported_func_names {
142 if KNOWN_INTRINSICS.contains(&name.as_str()) {
143 continue;
144 }
145 if let Some(import_map) = &options.import_map {
146 if import_map.contains_key(name) {
147 continue;
148 }
149 } else if DEFAULT_MAPPINGS.contains(&name.as_str()) {
150 continue;
151 }
152 return Err(Error::UnresolvedImport(format!(
153 "import '{name}' has no mapping. Provide a mapping via --imports or add it to the import map."
154 )));
155 }
156
157 compile_via_llvm(&module, options)
158}
159
/// Lower a parsed [`WasmModule`] to an [`SpiProgram`] via the LLVM
/// pipeline.
///
/// Pipeline stages:
/// 1. optional dead-function-elimination reachability analysis,
/// 2. WASM -> LLVM IR translation,
/// 3. RO_DATA layout for the function table and passive data segments,
/// 4. per-function lowering to PVM instructions (entry functions first),
/// 5. call-fixup resolution into a jump table and entry-jump patching,
/// 6. RO/RW data section construction and heap sizing.
pub fn compile_via_llvm(module: &WasmModule, options: &CompileOptions) -> Result<SpiProgram> {
    use crate::llvm_backend::{self, LoweringContext};
    use crate::llvm_frontend;
    use inkwell::context::Context;

    // Reachability set for dead-function elimination; `None` keeps every
    // function alive.
    let reachable_locals = if options.optimizations.dead_function_elimination {
        Some(dead_function_elimination::reachable_functions(module)?)
    } else {
        None
    };

    let context = Context::create();
    let llvm_module = llvm_frontend::translate_wasm_to_llvm(
        &context,
        module,
        options.optimizations.llvm_passes,
        options.optimizations.inlining,
        reachable_locals.as_ref(),
    )?;

    // Lay out passive data segments in RO_DATA. The function table (8
    // bytes per entry: jump ref word + type index word, see below)
    // occupies the front of the region; with no table, byte 0 is reserved
    // so no segment gets offset 0.
    let mut data_segment_offsets = std::collections::HashMap::new();
    let mut data_segment_lengths = std::collections::HashMap::new();
    let mut current_ro_offset = if module.function_table.is_empty() {
        1
    } else {
        module.function_table.len() * 8
    };

    // Address of each passive segment's runtime length word in the
    // globals region, keyed by segment index.
    let mut data_segment_length_addrs = std::collections::HashMap::new();
    let mut passive_ordinal = 0usize;

    for (idx, seg) in module.data_segments.iter().enumerate() {
        // Only passive segments (no static placement) go into RO_DATA;
        // active segments are handled by `build_rw_data`.
        if seg.offset.is_none() {
            if current_ro_offset + seg.data.len() > RO_DATA_SIZE {
                return Err(Error::Internal(format!(
                    "passive data segment {} (size {}) would overflow RO_DATA region ({} bytes used of {})",
                    idx,
                    seg.data.len(),
                    current_ro_offset,
                    RO_DATA_SIZE
                )));
            }
            data_segment_offsets.insert(idx as u32, current_ro_offset as u32);
            data_segment_lengths.insert(idx as u32, seg.data.len() as u32);
            data_segment_length_addrs.insert(
                idx as u32,
                memory_layout::data_segment_length_offset(module.globals.len(), passive_ordinal),
            );
            current_ro_offset += seg.data.len();
            passive_ordinal += 1;
        }
    }

    // Shared lowering context handed to the backend for every function.
    let ctx = LoweringContext {
        wasm_memory_base: module.wasm_memory_base,
        num_globals: module.globals.len(),
        function_signatures: module.function_signatures.clone(),
        type_signatures: module.type_signatures.clone(),
        function_table: module.function_table.clone(),
        num_imported_funcs: module.num_imported_funcs as usize,
        imported_func_names: module.imported_func_names.clone(),
        initial_memory_pages: module.memory_limits.initial_pages,
        max_memory_pages: module.max_memory_pages,
        stack_size: memory_layout::DEFAULT_STACK_SIZE,
        data_segment_offsets,
        data_segment_lengths,
        data_segment_length_addrs,
        wasm_import_map: options.import_map.clone(),
        optimizations: options.optimizations.clone(),
    };

    let mut all_instructions: Vec<Instruction> = Vec::new();
    // Fixups are recorded together with the instruction-index base of the
    // function they came from; resolved after all code is emitted.
    let mut all_call_fixups: Vec<(usize, CallFixup)> = Vec::new();
    let mut all_indirect_call_fixups: Vec<(usize, IndirectCallFixup)> = Vec::new();
    let mut function_offsets: Vec<usize> = vec![0; module.functions.len()];
    let mut next_call_return_idx: usize = 0;

    // Instruction 0: jump to the main entry (offset patched below once
    // function offsets are known).
    all_instructions.push(Instruction::Jump { offset: 0 });
    // Instruction 1: secondary entry jump, or a trap when there is none.
    if module.has_secondary_entry {
        all_instructions.push(Instruction::Jump { offset: 0 });
    } else {
        all_instructions.push(Instruction::Trap);
    }

    // Emit entry points first (main, then secondary), then the rest in
    // index order.
    let mut emission_order: Vec<usize> = Vec::with_capacity(module.functions.len());
    emission_order.push(module.main_func_local_idx);
    if let Some(secondary_idx) = module.secondary_entry_local_idx
        && secondary_idx != module.main_func_local_idx
    {
        emission_order.push(secondary_idx);
    }
    for idx in 0..module.functions.len() {
        if idx != module.main_func_local_idx && module.secondary_entry_local_idx != Some(idx) {
            emission_order.push(idx);
        }
    }

    for &local_func_idx in &emission_order {
        // Unreachable functions (under DFE) still get a valid offset so
        // stale references resolve, but their body is a single Trap.
        if reachable_locals
            .as_ref()
            .is_some_and(|r| !r.contains(&local_func_idx))
        {
            let func_start_offset: usize = all_instructions.iter().map(|i| i.encode().len()).sum();
            function_offsets[local_func_idx] = func_start_offset;
            all_instructions.push(Instruction::Trap);
            continue;
        }

        let global_func_idx = module.num_imported_funcs as usize + local_func_idx;
        let fn_name = format!("wasm_func_{global_func_idx}");
        let llvm_func = llvm_module
            .get_function(&fn_name)
            .ok_or_else(|| Error::Internal(format!("missing LLVM function: {fn_name}")))?;

        let is_main = local_func_idx == module.main_func_local_idx;
        let is_secondary = module.secondary_entry_local_idx == Some(local_func_idx);
        let is_entry = is_main || is_secondary;

        // Byte offset of this function's first instruction in the blob
        // (recomputed by summing encoded lengths of everything so far).
        let func_start_offset: usize = all_instructions.iter().map(|i| i.encode().len()).sum();
        function_offsets[local_func_idx] = func_start_offset;

        // Entry functions first call the WASM start function (if any),
        // preserving the args registers across the call on the stack.
        if let Some(start_local_idx) = module.start_func_local_idx.filter(|_| is_entry) {
            // Reserve 16 bytes of stack and spill the args registers.
            all_instructions.push(Instruction::AddImm64 {
                dst: STACK_PTR_REG,
                src: STACK_PTR_REG,
                value: -16,
            });
            all_instructions.push(Instruction::StoreIndU64 {
                base: STACK_PTR_REG,
                src: ARGS_PTR_REG,
                offset: 0,
            });
            all_instructions.push(Instruction::StoreIndU64 {
                base: STACK_PTR_REG,
                src: ARGS_LEN_REG,
                offset: 8,
            });

            // Return addresses use the `(slot + 1) * 2` jump-table
            // encoding (see `return_addr_jump_table_idx`).
            let call_return_addr = ((next_call_return_idx + 1) * 2) as i32;
            next_call_return_idx += 1;
            let current_instr_idx = all_instructions.len();
            all_instructions.push(Instruction::LoadImmJump {
                reg: RETURN_ADDR_REG,
                value: call_return_addr,
                offset: 0, // patched by resolve_call_fixups
            });

            // The same instruction both loads the return address and
            // jumps, so both fixup indices are 0 relative to it.
            all_call_fixups.push((
                current_instr_idx,
                CallFixup {
                    target_func: start_local_idx as u32,
                    return_addr_instr: 0,
                    jump_instr: 0,
                },
            ));

            // Restore the args registers and release the stack slot.
            all_instructions.push(Instruction::LoadIndU64 {
                dst: ARGS_PTR_REG,
                base: STACK_PTR_REG,
                offset: 0,
            });
            all_instructions.push(Instruction::LoadIndU64 {
                dst: ARGS_LEN_REG,
                base: STACK_PTR_REG,
                offset: 8,
            });
            all_instructions.push(Instruction::AddImm64 {
                dst: STACK_PTR_REG,
                src: STACK_PTR_REG,
                value: 16,
            });
        }

        let translation = llvm_backend::lower_function(
            llvm_func,
            &ctx,
            is_entry,
            global_func_idx,
            next_call_return_idx,
        )?;
        next_call_return_idx += translation.num_call_returns;

        // Re-base the function-local fixup indices onto the global
        // instruction stream.
        let instr_base = all_instructions.len();
        for fixup in translation.call_fixups {
            all_call_fixups.push((
                instr_base,
                CallFixup {
                    return_addr_instr: fixup.return_addr_instr,
                    jump_instr: fixup.jump_instr,
                    target_func: fixup.target_func,
                },
            ));
        }
        for fixup in translation.indirect_call_fixups {
            all_indirect_call_fixups.push((
                instr_base,
                IndirectCallFixup {
                    return_addr_instr: fixup.return_addr_instr,
                    jump_ind_instr: fixup.jump_ind_instr,
                },
            ));
        }

        all_instructions.extend(translation.instructions);
    }

    let (jump_table, func_entry_jump_table_base) = resolve_call_fixups(
        &mut all_instructions,
        &all_call_fixups,
        &all_indirect_call_fixups,
        &function_offsets,
    )?;

    // Patch the entry jumps now that function offsets are known.
    let main_offset = function_offsets[module.main_func_local_idx] as i32;
    if let Instruction::Jump { offset } = &mut all_instructions[0] {
        *offset = main_offset;
    }

    if let Some(secondary_idx) = module.secondary_entry_local_idx {
        // NOTE(review): the `- 5` presumably compensates for the encoded
        // size of the first entry jump (the secondary jump's offset being
        // relative to its own position, 5 bytes into the blob) — confirm
        // against the PVM Jump encoding.
        let secondary_offset = function_offsets[secondary_idx] as i32 - 5;
        if let Instruction::Jump { offset } = &mut all_instructions[1] {
            *offset = secondary_offset;
        }
    }

    // RO_DATA: function table entries (two 4-byte LE words each: jump
    // ref + type index), or a single reserved zero byte when there is no
    // table — mirroring the `current_ro_offset` layout above.
    let mut ro_data = vec![0u8];
    if !module.function_table.is_empty() {
        ro_data.clear();
        for &func_idx in &module.function_table {
            // Null slots and imported functions are not directly callable
            // through the table; both words are u32::MAX sentinels.
            if func_idx == u32::MAX || (func_idx as usize) < module.num_imported_funcs as usize {
                ro_data.extend_from_slice(&u32::MAX.to_le_bytes());
                ro_data.extend_from_slice(&u32::MAX.to_le_bytes());
            } else {
                let local_func_idx = func_idx as usize - module.num_imported_funcs as usize;
                // Jump-table reference in the `(slot + 1) * 2` encoding;
                // per-function entries start at func_entry_jump_table_base.
                let jump_ref = 2 * (func_entry_jump_table_base + local_func_idx + 1) as u32;
                ro_data.extend_from_slice(&jump_ref.to_le_bytes());
                let type_idx = *module
                    .function_type_indices
                    .get(local_func_idx)
                    .unwrap_or(&u32::MAX);
                ro_data.extend_from_slice(&type_idx.to_le_bytes());
            }
        }
    }

    // Append passive segment payloads in the same order as the offset
    // assignment loop above, so recorded offsets stay valid.
    for seg in &module.data_segments {
        if seg.offset.is_none() {
            ro_data.extend_from_slice(&seg.data);
        }
    }

    let blob = crate::pvm::ProgramBlob::new(all_instructions).with_jump_table(jump_table);
    let rw_data_section = build_rw_data(
        &module.data_segments,
        &module.global_init_values,
        module.memory_limits.initial_pages,
        module.wasm_memory_base,
        &ctx.data_segment_length_addrs,
        &ctx.data_segment_lengths,
    );

    let heap_pages = calculate_heap_pages(
        rw_data_section.len(),
        module.wasm_memory_base,
        module.memory_limits.initial_pages,
        module.functions.len(),
    )?;

    Ok(SpiProgram::new(blob)
        .with_heap_pages(heap_pages)
        .with_ro_data(ro_data)
        .with_rw_data(rw_data_section)
        .with_metadata(options.metadata.clone()))
}
456
457fn calculate_heap_pages(
472 rw_data_len: usize,
473 wasm_memory_base: i32,
474 initial_pages: u32,
475 num_functions: usize,
476) -> Result<u16> {
477 use wasm_module::MIN_INITIAL_WASM_PAGES;
478
479 let initial_pages = initial_pages.max(MIN_INITIAL_WASM_PAGES);
480 let wasm_memory_initial_end = wasm_memory_base as usize + (initial_pages as usize) * 64 * 1024;
481
482 let spilled_locals_end = memory_layout::SPILLED_LOCALS_BASE as usize
483 + num_functions * memory_layout::SPILLED_LOCALS_PER_FUNC as usize;
484
485 let end = spilled_locals_end.max(wasm_memory_initial_end);
486 let total_bytes = end - memory_layout::GLOBAL_MEMORY_BASE as usize;
487 let rw_pages = rw_data_len.div_ceil(4096);
488 let total_pages = total_bytes.div_ceil(4096);
489 let heap_pages = total_pages.saturating_sub(rw_pages) + 1;
490
491 u16::try_from(heap_pages).map_err(|_| {
492 Error::Internal(format!(
493 "heap size {heap_pages} pages exceeds u16::MAX ({}) — module too large",
494 u16::MAX
495 ))
496 })
497}
498
499pub(crate) fn build_rw_data(
501 data_segments: &[wasm_module::DataSegment],
502 global_init_values: &[i32],
503 initial_memory_pages: u32,
504 wasm_memory_base: i32,
505 data_segment_length_addrs: &std::collections::HashMap<u32, i32>,
506 data_segment_lengths: &std::collections::HashMap<u32, u32>,
507) -> Vec<u8> {
508 let num_passive_segments = data_segment_length_addrs.len();
511 let globals_end =
512 memory_layout::globals_region_size(global_init_values.len(), num_passive_segments);
513
514 let wasm_to_rw_offset = wasm_memory_base as u32 - 0x30000;
516
517 let data_end = data_segments
518 .iter()
519 .filter_map(|seg| {
520 seg.offset
521 .map(|off| wasm_to_rw_offset + off + seg.data.len() as u32)
522 })
523 .max()
524 .unwrap_or(0) as usize;
525
526 let total_size = globals_end.max(data_end);
527
528 if total_size == 0 {
529 return Vec::new();
530 }
531
532 let mut rw_data = vec![0u8; total_size];
533
534 for (i, &value) in global_init_values.iter().enumerate() {
536 let offset = i * 4;
537 if offset + 4 <= rw_data.len() {
538 rw_data[offset..offset + 4].copy_from_slice(&value.to_le_bytes());
539 }
540 }
541
542 let mem_size_offset = global_init_values.len() * 4;
544 if mem_size_offset + 4 <= rw_data.len() {
545 rw_data[mem_size_offset..mem_size_offset + 4]
546 .copy_from_slice(&initial_memory_pages.to_le_bytes());
547 }
548
549 for (&seg_idx, &addr) in data_segment_length_addrs {
552 if let Some(&length) = data_segment_lengths.get(&seg_idx) {
553 let rw_offset = (addr - memory_layout::GLOBAL_MEMORY_BASE) as usize;
555 if rw_offset + 4 <= rw_data.len() {
556 rw_data[rw_offset..rw_offset + 4].copy_from_slice(&length.to_le_bytes());
557 }
558 }
559 }
560
561 for seg in data_segments {
563 if let Some(offset) = seg.offset {
564 let rw_offset = (wasm_to_rw_offset + offset) as usize;
565 if rw_offset + seg.data.len() <= rw_data.len() {
566 rw_data[rw_offset..rw_offset + seg.data.len()].copy_from_slice(&seg.data);
567 }
568 }
569 }
570
571 if let Some(last_non_zero) = rw_data.iter().rposition(|&b| b != 0) {
574 rw_data.truncate(last_non_zero + 1);
575 } else {
576 rw_data.clear();
577 }
578
579 rw_data
580}
581
582fn return_addr_jump_table_idx(
592 instructions: &[Instruction],
593 return_addr_instr: usize,
594) -> Result<usize> {
595 let value = match instructions.get(return_addr_instr) {
596 Some(
597 Instruction::LoadImmJump { value, .. }
598 | Instruction::LoadImm { value, .. }
599 | Instruction::LoadImmJumpInd { value, .. },
600 ) => Some(*value),
601 _ => None,
602 };
603 match value {
604 Some(v) if v > 0 && v % 2 == 0 => Ok((v as usize / 2) - 1),
605 _ => Err(Error::Internal(format!(
606 "expected LoadImmJump/LoadImm/LoadImmJumpInd((idx+1)*2) at return_addr_instr {return_addr_instr}, got {:?}",
607 instructions.get(return_addr_instr)
608 ))),
609 }
610}
611
/// Resolve all direct and indirect call fixups.
///
/// Builds the jump table — return-address slots first, then one entry per
/// function — and patches each direct call's `LoadImmJump` offset to its
/// target. Returns the jump table plus the index where the per-function
/// entries begin.
///
/// NOTE(review): byte offsets are recomputed by summing `encode().len()`
/// over the instruction prefix for every fixup (quadratic overall).
/// Before caching prefix sums, confirm whether patching a `LoadImmJump`
/// offset can change its encoded length — if it can, cached sums would go
/// stale, and offsets computed before later patches may already be
/// inconsistent.
fn resolve_call_fixups(
    instructions: &mut [Instruction],
    call_fixups: &[(usize, CallFixup)],
    indirect_call_fixups: &[(usize, IndirectCallFixup)],
    function_offsets: &[usize],
) -> Result<(Vec<u32>, usize)> {
    // First pass: size the return-address region from the highest slot
    // index referenced by any fixup.
    let mut num_call_returns: usize = 0;

    for (instr_base, fixup) in call_fixups {
        let idx = return_addr_jump_table_idx(instructions, instr_base + fixup.return_addr_instr)?;
        num_call_returns = num_call_returns.max(idx + 1);
    }
    for (instr_base, fixup) in indirect_call_fixups {
        let idx = return_addr_jump_table_idx(instructions, instr_base + fixup.return_addr_instr)?;
        num_call_returns = num_call_returns.max(idx + 1);
    }

    let mut jump_table: Vec<u32> = vec![0u32; num_call_returns];

    for (instr_base, fixup) in call_fixups {
        let target_offset = function_offsets
            .get(fixup.target_func as usize)
            .ok_or_else(|| {
                Error::Unsupported(format!("call to unknown function {}", fixup.target_func))
            })?;

        let jump_idx = instr_base + fixup.jump_instr;

        // Return address = byte offset just past the jump instruction.
        let return_addr_offset: usize = instructions[..=jump_idx]
            .iter()
            .map(|i| i.encode().len())
            .sum();

        let slot = return_addr_jump_table_idx(instructions, instr_base + fixup.return_addr_instr)?;
        jump_table[slot] = return_addr_offset as u32;

        // The backend pre-assigned the `(slot + 1) * 2` return-address
        // immediate; verify it matches the slot derived above.
        let expected_addr = ((slot + 1) * 2) as i32;
        debug_assert!(
            matches!(&instructions[jump_idx], Instruction::LoadImmJump { value, .. } if *value == expected_addr),
            "pre-assigned jump table address mismatch: expected {expected_addr}, got {:?}",
            &instructions[jump_idx]
        );

        // Patch the jump to be relative to its own start offset.
        let jump_start_offset: usize = instructions[..jump_idx]
            .iter()
            .map(|i| i.encode().len())
            .sum();
        let relative_offset = (*target_offset as i32) - (jump_start_offset as i32);

        if let Instruction::LoadImmJump { offset, .. } = &mut instructions[jump_idx] {
            *offset = relative_offset;
        }
    }

    // Indirect calls only need their return-address slot filled; the jump
    // target comes from the function table at runtime.
    for (instr_base, fixup) in indirect_call_fixups {
        let jump_ind_idx = instr_base + fixup.jump_ind_instr;

        let return_addr_offset: usize = instructions[..=jump_ind_idx]
            .iter()
            .map(|i| i.encode().len())
            .sum();

        let slot = return_addr_jump_table_idx(instructions, instr_base + fixup.return_addr_instr)?;
        jump_table[slot] = return_addr_offset as u32;
    }

    // Append one entry per function, in local-index order; their encoding
    // base is returned so callers can compute `(base + idx + 1) * 2` refs.
    let func_entry_base = jump_table.len();
    for &offset in function_offsets {
        jump_table.push(offset as u32);
    }

    Ok((jump_table, func_entry_base))
}
694
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::build_rw_data;
    use super::memory_layout;
    use super::wasm_module::DataSegment;

    // No segments, no globals, no passive lengths: the RW section is all
    // zeros and trims to empty.
    #[test]
    fn build_rw_data_trims_all_zero_tail_to_empty() {
        let rw = build_rw_data(&[], &[], 0, 0x30000, &HashMap::new(), &HashMap::new());
        assert!(rw.is_empty());
    }

    // Trailing zero bytes are trimmed, but zeros between non-zero bytes
    // must survive.
    #[test]
    fn build_rw_data_preserves_internal_zeros_and_trims_trailing_zeros() {
        let data_segments = vec![DataSegment {
            offset: Some(0),
            data: vec![1, 0, 2, 0, 0],
        }];

        let rw = build_rw_data(
            &data_segments,
            &[],
            0,
            0x30000,
            &HashMap::new(),
            &HashMap::new(),
        );

        assert_eq!(rw, vec![1, 0, 2]);
    }

    // A passive-segment length word assigned to GLOBAL_MEMORY_BASE + 4
    // lands at RW offset 4 (little-endian: low byte first), and its
    // non-zero byte prevents trimming up to that point.
    #[test]
    fn build_rw_data_keeps_non_zero_passive_length_bytes() {
        let mut addrs = HashMap::new();
        addrs.insert(0u32, memory_layout::GLOBAL_MEMORY_BASE + 4);
        let mut lengths = HashMap::new();
        lengths.insert(0u32, 7u32);

        let rw = build_rw_data(&[], &[], 0, 0x30000, &addrs, &lengths);

        assert_eq!(rw, vec![0, 0, 0, 0, 7]);
    }

    // NOTE(review): the expected page counts below bake in the values of
    // MIN_INITIAL_WASM_PAGES, SPILLED_LOCALS_* and GLOBAL_MEMORY_BASE;
    // they will shift if those layout constants change.
    #[test]
    fn heap_pages_with_empty_rw_data_equals_total_pages_plus_one() {
        let pages = super::calculate_heap_pages(0, 0x33000, 0, 10).unwrap();
        assert_eq!(pages, 260);
    }

    // 8192 bytes of RW data is exactly two 4 KiB pages, which are
    // subtracted from the heap requirement.
    #[test]
    fn heap_pages_reduced_by_rw_data_pages() {
        let pages_no_rw = super::calculate_heap_pages(0, 0x33000, 0, 10).unwrap();
        let pages_with_rw = super::calculate_heap_pages(8192, 0x33000, 0, 10).unwrap();
        assert_eq!(pages_no_rw - pages_with_rw, 2);
    }

    // When RW data alone exceeds the required region, the result bottoms
    // out at the single slack page.
    #[test]
    fn heap_pages_saturates_at_one_for_large_rw_data() {
        let pages = super::calculate_heap_pages(2 * 1024 * 1024, 0x33000, 0, 10).unwrap();
        assert_eq!(pages, 1);
    }

    // Raising initial_pages above the minimum grows the requirement by
    // 16 heap pages per extra 64 KiB WASM page (516 - 260 = 256 pages,
    // consistent with a 16-page minimum — confirm MIN_INITIAL_WASM_PAGES).
    #[test]
    fn heap_pages_respects_initial_pages() {
        let pages = super::calculate_heap_pages(0, 0x33000, 32, 10).unwrap();
        assert_eq!(pages, 516);
    }
}