Skip to main content

wasm_pvm/translate/
wasm_module.rs

1// Parsing code uses casts to convert WASM u64 fields to PVM u32/usize types.
2#![allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
3
4use crate::{Error, Result};
5use wasmparser::{FunctionBody, GlobalType, Parser, Payload};
6
7use super::memory_layout;
8
/// Parsed WASM memory limits (first memory of the memory section).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MemoryLimits {
    /// Initial memory size in 64KB pages.
    pub initial_pages: u32,
    /// Maximum memory size in pages (None = no explicit limit).
    pub max_pages: Option<u32>,
}

impl Default for MemoryLimits {
    /// One initial page and no declared maximum.
    ///
    /// This is the value `WasmModule::parse` keeps when the module has no memory section.
    fn default() -> Self {
        Self {
            initial_pages: 1,
            max_pages: None,
        }
    }
}
26
/// Lower bound on pre-allocated WASM pages for programs that declare `(memory 0)`.
///
/// Expressed as 1 MiB divided by the 64 KiB WASM page size (= 16 pages), which is enough
/// for `AssemblyScript` programs compiled with --runtime stub.
pub(crate) const MIN_INITIAL_WASM_PAGES: u32 = (1024 * 1024) / (64 * 1024);

/// `memory.grow` page ceiling applied when the WASM module declares no maximum.
///
/// 1 MiB / 64 KiB = 16 pages — a conservative default aligned with PVM recommendations.
const DEFAULT_MAX_PAGES: u32 = (1024 * 1024) / (64 * 1024);
34
/// A data segment parsed from the WASM data section (active or passive).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DataSegment {
    /// Offset in WASM linear memory (active only). None for passive segments.
    pub offset: Option<u32>,
    /// The actual data bytes.
    pub data: Vec<u8>,
}
42
/// Parsed and pre-processed WASM module, usable by both legacy and LLVM pipelines.
///
/// Built by `WasmModule::parse`. Two function index spaces appear below:
/// *global* indices list imported functions first and then local ones (as in the WASM
/// binary), while *local* indices count only the functions defined in the code section
/// (i.e. global index minus `num_imported_funcs`).
pub struct WasmModule<'a> {
    // --- Raw parsed section data ---
    /// Function bodies from the code section, in order (indexed by local function index).
    pub functions: Vec<FunctionBody<'a>>,
    /// All function types declared in the type section (non-function composite types are
    /// skipped, so indices follow only the function types encountered).
    pub func_types: Vec<wasmparser::FuncType>,
    /// Type index for each local function (parallels `functions`).
    pub function_type_indices: Vec<u32>,
    /// Types of the globals defined in the global section. Imported globals are NOT
    /// included here.
    pub globals: Vec<GlobalType>,
    /// Initial value of each entry in `globals`, as computed by `eval_const_i32`.
    pub global_init_values: Vec<i32>,
    /// Data segments from the data section, in declaration order. Both active segments
    /// (`offset` is `Some`) and passive segments (`offset` is `None`) are kept.
    pub data_segments: Vec<DataSegment>,
    /// Memory limits parsed from the memory section (defaults when there is none).
    pub memory_limits: MemoryLimits,
    /// Number of imported functions (precede local functions in global index space).
    pub num_imported_funcs: u32,
    /// Type indices for imported functions.
    pub imported_func_type_indices: Vec<u32>,
    /// Field names (the second import name, not the module name) of imported functions.
    pub imported_func_names: Vec<String>,

    // --- Derived data ---
    /// Local function index of the main entry point (`main`, or an alias such as
    /// `refine`); falls back to 0 when no such export exists.
    pub main_func_local_idx: usize,
    /// Whether the WASM module exports a secondary entry point ("main2", or the
    /// "accumulate" / "accumulate_ext" aliases).
    pub has_secondary_entry: bool,
    /// Local function index of the secondary entry point (None if import or absent).
    pub secondary_entry_local_idx: Option<usize>,
    /// Local function index of the start function (None if import or absent).
    pub start_func_local_idx: Option<usize>,
    /// (`num_params`, `has_return`) for each function (imports first, then locals),
    /// i.e. indexed by global function index. An out-of-range type index yields
    /// `(0, false)`.
    pub function_signatures: Vec<(usize, bool)>,
    /// (`num_params`, `num_results`) for each type.
    pub type_signatures: Vec<(usize, usize)>,
    /// Function table for indirect calls (`u32::MAX` = invalid entry), sized to the
    /// initial size of the first locally defined table and filled from active element
    /// segments targeting table 0.
    pub function_table: Vec<u32>,
    /// WASM global indices of all exported functions (may include re-exported imports).
    pub exported_wasm_func_indices: Vec<u32>,
    /// Base address of WASM linear memory in PVM address space.
    pub wasm_memory_base: i32,
    /// Maximum WASM memory pages available for memory.grow: the declared maximum, or
    /// `DEFAULT_MAX_PAGES` when none is declared, but never less than the initial size.
    pub max_memory_pages: u32,
}
89
90impl<'a> WasmModule<'a> {
91    /// Parse and validate a WASM binary, producing a `WasmModule` with all derived data.
92    pub fn parse(wasm: &'a [u8]) -> Result<Self> {
93        wasmparser::validate(wasm)
94            .map_err(|e| Error::Internal(format!("WASM validation error: {e}")))?;
95
96        let mut functions = Vec::new();
97        let mut func_types: Vec<wasmparser::FuncType> = Vec::new();
98        let mut function_type_indices = Vec::new();
99        let mut globals: Vec<GlobalType> = Vec::new();
100        let mut global_init_values: Vec<i32> = Vec::new();
101        let mut main_func_idx: Option<u32> = None;
102        let mut secondary_entry_func_idx: Option<u32> = None;
103        let mut start_func_idx: Option<u32> = None;
104        let mut exported_wasm_func_indices: Vec<u32> = Vec::new();
105        let mut tables: Vec<wasmparser::TableType> = Vec::new();
106        let mut table_elements: Vec<(u32, u32, Vec<u32>)> = Vec::new();
107        let mut data_segments: Vec<DataSegment> = Vec::new();
108        let mut memory_limits = MemoryLimits::default();
109        let mut num_imported_funcs: u32 = 0;
110        let mut imported_func_type_indices: Vec<u32> = Vec::new();
111        let mut imported_func_names: Vec<String> = Vec::new();
112
113        for payload in Parser::new(0).parse_all(wasm) {
114            match payload? {
115                Payload::TypeSection(reader) => {
116                    for rec_group in reader {
117                        for sub_type in rec_group?.into_types() {
118                            if let wasmparser::CompositeInnerType::Func(f) =
119                                &sub_type.composite_type.inner
120                            {
121                                func_types.push(f.clone());
122                            }
123                        }
124                    }
125                }
126                Payload::ImportSection(reader) => {
127                    for import in reader {
128                        let import = import?;
129                        if let wasmparser::TypeRef::Func(type_idx) = import.ty {
130                            num_imported_funcs += 1;
131                            imported_func_type_indices.push(type_idx);
132                            imported_func_names.push(import.name.to_string());
133                        }
134                    }
135                }
136                Payload::FunctionSection(reader) => {
137                    for type_idx in reader {
138                        function_type_indices.push(type_idx?);
139                    }
140                }
141                Payload::GlobalSection(reader) => {
142                    for global in reader {
143                        let g = global?;
144                        globals.push(g.ty);
145                        let init_value = eval_const_i32(&g.init_expr)?;
146                        global_init_values.push(init_value);
147                    }
148                }
149                Payload::StartSection { func, .. } => {
150                    start_func_idx = Some(func);
151                }
152                Payload::TableSection(reader) => {
153                    for table in reader {
154                        tables.push(table?.ty);
155                    }
156                }
157                Payload::MemorySection(reader) => {
158                    if let Some(memory) = reader.into_iter().next() {
159                        let mem = memory?;
160                        memory_limits = MemoryLimits {
161                            initial_pages: mem.initial as u32,
162                            max_pages: mem.maximum.map(|m| m as u32),
163                        };
164                    }
165                }
166                Payload::ElementSection(reader) => {
167                    for element in reader {
168                        let element = element?;
169                        if let wasmparser::ElementKind::Active {
170                            table_index,
171                            offset_expr,
172                        } = element.kind
173                        {
174                            let table_idx = table_index.unwrap_or(0);
175                            let offset = eval_const_i32(&offset_expr)?;
176                            let func_indices: Vec<u32> = match element.items {
177                                wasmparser::ElementItems::Functions(reader) => {
178                                    reader.into_iter().collect::<std::result::Result<_, _>>()?
179                                }
180                                wasmparser::ElementItems::Expressions(_, reader) => {
181                                    let mut indices = Vec::new();
182                                    for expr in reader {
183                                        let expr = expr?;
184                                        if let Some(idx) = eval_const_ref(&expr) {
185                                            indices.push(idx);
186                                        }
187                                    }
188                                    indices
189                                }
190                            };
191                            table_elements.push((table_idx, offset as u32, func_indices));
192                        }
193                    }
194                }
195                Payload::ExportSection(reader) => {
196                    for export in reader {
197                        let export = export?;
198                        if export.kind == wasmparser::ExternalKind::Func {
199                            exported_wasm_func_indices.push(export.index);
200                            let is_imported = export.index < num_imported_funcs;
201                            let is_main_name = matches!(
202                                export.name,
203                                "main"
204                                    | "refine"
205                                    | "refine_ext"
206                                    | "is_authorized"
207                                    | "is_authorized_ext"
208                            );
209                            let is_secondary_name =
210                                matches!(export.name, "main2" | "accumulate" | "accumulate_ext");
211                            if is_imported && (is_main_name || is_secondary_name) {
212                                return Err(Error::Internal(format!(
213                                    "Entry export '{}' refers to imported function index {}",
214                                    export.name, export.index
215                                )));
216                            }
217                            match export.name {
218                                "main" => {
219                                    main_func_idx = Some(export.index);
220                                }
221                                "refine" | "refine_ext" | "is_authorized" | "is_authorized_ext"
222                                    if main_func_idx.is_none() =>
223                                {
224                                    main_func_idx = Some(export.index);
225                                }
226                                "main2" => {
227                                    secondary_entry_func_idx = Some(export.index);
228                                }
229                                "accumulate" | "accumulate_ext"
230                                    if secondary_entry_func_idx.is_none() =>
231                                {
232                                    secondary_entry_func_idx = Some(export.index);
233                                }
234                                _ => {}
235                            }
236                        }
237                    }
238                }
239                Payload::CodeSectionEntry(body) => {
240                    functions.push(body);
241                }
242                Payload::DataSection(reader) => {
243                    for data in reader {
244                        let data = data?;
245                        match data.kind {
246                            wasmparser::DataKind::Active {
247                                memory_index: _,
248                                offset_expr,
249                            } => {
250                                let offset = eval_const_i32(&offset_expr)? as u32;
251                                data_segments.push(DataSegment {
252                                    offset: Some(offset),
253                                    data: data.data.to_vec(),
254                                });
255                            }
256                            wasmparser::DataKind::Passive => {
257                                data_segments.push(DataSegment {
258                                    offset: None,
259                                    data: data.data.to_vec(),
260                                });
261                            }
262                        }
263                    }
264                }
265                _ => {}
266            }
267        }
268
269        if functions.is_empty() {
270            return Err(Error::NoExportedFunction);
271        }
272
273        // Convert main_func_idx from global to local function index
274        let main_func_local_idx = if let Some(idx) = main_func_idx {
275            idx as usize - num_imported_funcs as usize
276        } else {
277            tracing::warn!("No 'main' export found, defaulting to first local function");
278            0
279        };
280
281        // Resolve secondary entry from global to local function index
282        let has_secondary_entry = secondary_entry_func_idx.is_some();
283        let secondary_entry_local_idx = secondary_entry_func_idx.and_then(|idx| {
284            idx.checked_sub(num_imported_funcs)
285                .map(|v| v as usize)
286                .or_else(|| {
287                    tracing::warn!(
288                        "secondary entry function {idx} is an imported function, ignoring"
289                    );
290                    None
291                })
292        });
293        // Resolve start function from global to local function index
294        let start_func_local_idx = start_func_idx.and_then(|idx| {
295            idx.checked_sub(num_imported_funcs)
296                .map(|v| v as usize)
297                .or_else(|| {
298                    tracing::warn!("start function {idx} is an imported function, ignoring");
299                    None
300                })
301        });
302
303        // Build function signatures: (num_params, has_return) indexed by global function index
304        let function_signatures: Vec<(usize, bool)> = imported_func_type_indices
305            .iter()
306            .chain(function_type_indices.iter())
307            .map(|&type_idx| {
308                let func_type = func_types.get(type_idx as usize);
309                let num_params = func_type.map_or(0, |f| f.params().len());
310                let has_return = func_type.is_some_and(|f| !f.results().is_empty());
311                (num_params, has_return)
312            })
313            .collect();
314
315        // Build type signatures: (num_params, num_results) for each type
316        let type_signatures: Vec<(usize, usize)> = func_types
317            .iter()
318            .map(|f| (f.params().len(), f.results().len()))
319            .collect();
320
321        // Build function table from element sections
322        let table_size = tables.first().map_or(0, |t| t.initial as usize);
323        let mut function_table: Vec<u32> = vec![u32::MAX; table_size];
324        for (table_idx, offset, func_indices) in &table_elements {
325            if *table_idx == 0 {
326                for (i, &func_idx) in func_indices.iter().enumerate() {
327                    let idx = *offset as usize + i;
328                    if idx < function_table.len() {
329                        function_table[idx] = func_idx;
330                    }
331                }
332            }
333        }
334
335        let num_passive_segments = data_segments
336            .iter()
337            .filter(|seg| seg.offset.is_none())
338            .count();
339        // Compute WASM memory base
340        let wasm_memory_base =
341            memory_layout::compute_wasm_memory_base(globals.len(), num_passive_segments);
342
343        // max_memory_pages is the runtime limit for memory.grow (hardcoded in PVM code).
344        // When the WASM module doesn't declare a max, use DEFAULT_MAX_PAGES (1 MB).
345        // When it does, respect its preference (but warn in CLI output if large).
346        let max_memory_pages = memory_limits
347            .max_pages
348            .unwrap_or(DEFAULT_MAX_PAGES)
349            .max(memory_limits.initial_pages);
350
351        Ok(WasmModule {
352            functions,
353            func_types,
354            function_type_indices,
355            globals,
356            global_init_values,
357            data_segments,
358            memory_limits,
359            num_imported_funcs,
360            imported_func_type_indices,
361            imported_func_names,
362            main_func_local_idx,
363            has_secondary_entry,
364            secondary_entry_local_idx,
365            start_func_local_idx,
366            function_signatures,
367            type_signatures,
368            function_table,
369            exported_wasm_func_indices,
370            wasm_memory_base,
371            max_memory_pages,
372        })
373    }
374}
375
376fn eval_const_i32(expr: &wasmparser::ConstExpr) -> Result<i32> {
377    let mut reader = expr.get_binary_reader();
378    while !reader.eof() {
379        match reader.read_operator()? {
380            wasmparser::Operator::I32Const { value } => return Ok(value),
381            wasmparser::Operator::End => break,
382            _ => {}
383        }
384    }
385    Ok(0)
386}
387
388fn eval_const_ref(expr: &wasmparser::ConstExpr) -> Option<u32> {
389    let mut reader = expr.get_binary_reader();
390    while !reader.eof() {
391        if let Ok(op) = reader.read_operator() {
392            match op {
393                wasmparser::Operator::RefFunc { function_index } => return Some(function_index),
394                wasmparser::Operator::End => break,
395                _ => {}
396            }
397        } else {
398            break;
399        }
400    }
401    None
402}
403
#[cfg(test)]
mod tests {
    use super::WasmModule;

    /// Assemble WAT text into a WASM binary, failing the test on malformed input.
    fn assemble(text: &str) -> Vec<u8> {
        wat::parse_str(text).expect("valid WAT")
    }

    /// `main` exported ahead of its `refine` alias must be selected as the entry.
    #[test]
    fn main_export_name_overrides_alias() {
        let binary = assemble(
            r#"(module
                (func $canonical_main (export "main"))
                (func $alias_main (export "refine"))
            )"#,
        );
        let parsed = WasmModule::parse(&binary).expect("valid module");

        assert_eq!(parsed.main_func_local_idx, 0);
    }

    /// `main2` exported ahead of its `accumulate_ext` alias must be the secondary entry.
    #[test]
    fn secondary_main2_export_name_overrides_alias() {
        let binary = assemble(
            r#"(module
                (func $main (export "main"))
                (func $canonical_secondary (export "main2"))
                (func $alias_secondary (export "accumulate_ext"))
            )"#,
        );
        let parsed = WasmModule::parse(&binary).expect("valid module");

        assert!(parsed.has_secondary_entry);
        assert_eq!(parsed.secondary_entry_local_idx, Some(1));
    }

    /// `main` must still win when the `refine` alias is exported first.
    #[test]
    fn reverse_main_export_name_overrides_alias() {
        let binary = assemble(
            r#"(module
                (func $canonical_main)
                (func $alias_main)
                (export "refine" (func $alias_main))
                (export "main" (func $canonical_main))
            )"#,
        );
        let parsed = WasmModule::parse(&binary).expect("valid module");

        assert_eq!(parsed.main_func_local_idx, 0);
    }

    /// `main2` must still win when the `accumulate_ext` alias is exported first.
    #[test]
    fn reverse_secondary_main2_export_name_overrides_alias() {
        let binary = assemble(
            r#"(module
                (func $main (export "main"))
                (func $canonical_secondary)
                (func $alias_secondary)
                (export "accumulate_ext" (func $alias_secondary))
                (export "main2" (func $canonical_secondary))
            )"#,
        );
        let parsed = WasmModule::parse(&binary).expect("valid module");

        assert!(parsed.has_secondary_entry);
        assert_eq!(parsed.secondary_entry_local_idx, Some(1));
    }

    /// An entry-point export that names an imported function is rejected.
    #[test]
    fn imported_entry_export_returns_error() {
        let binary = assemble(
            r#"(module
                (import "env" "main_import" (func $main_import))
                (func $local_main)
                (export "main" (func $main_import))
            )"#,
        );

        let msg = match WasmModule::parse(&binary) {
            Err(crate::Error::Internal(msg)) => msg,
            Err(err) => panic!("unexpected error: {err}"),
            Ok(_) => panic!("must reject imported main export"),
        };
        assert!(
            msg.contains("imported function index"),
            "unexpected error message: {msg}"
        );
    }
}