1#![allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
3
4use crate::{Error, Result};
5use wasmparser::{FunctionBody, GlobalType, Parser, Payload};
6
7use super::memory_layout;
8
/// Size limits of a WebAssembly linear memory, measured in wasm pages (64 KiB each).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MemoryLimits {
    /// Number of pages the memory starts with.
    pub initial_pages: u32,
    /// Upper bound on the page count, if the module declares one.
    pub max_pages: Option<u32>,
}

impl Default for MemoryLimits {
    /// One initial page and no declared maximum.
    fn default() -> Self {
        Self {
            initial_pages: 1,
            max_pages: None,
        }
    }
}
26
/// Lower bound on a module's initial memory size, in wasm pages.
///
/// NOTE(review): not referenced in this file; presumably applied by the code
/// generator when sizing memory — confirm at the use site.
pub(crate) const MIN_INITIAL_WASM_PAGES: u32 = 16;

/// Page count assumed as the maximum when the module's memory declares none
/// (16 pages of 64 KiB = 1 MiB).
///
/// `WasmModule::parse` raises the resulting maximum to at least the initial page
/// count, so this is a fallback rather than a hard cap.
const DEFAULT_MAX_PAGES: u32 = 16;
34
/// A data segment from the module's data section.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DataSegment {
    /// Start offset in linear memory for an active segment (its offset
    /// expression evaluated as an `i32` and reinterpreted as `u32`); `None`
    /// for a passive segment.
    pub offset: Option<u32>,
    /// Raw bytes of the segment.
    pub data: Vec<u8>,
}
42
43pub struct WasmModule<'a> {
45 pub functions: Vec<FunctionBody<'a>>,
48 pub func_types: Vec<wasmparser::FuncType>,
50 pub function_type_indices: Vec<u32>,
52 pub globals: Vec<GlobalType>,
54 pub global_init_values: Vec<i32>,
56 pub data_segments: Vec<DataSegment>,
58 pub memory_limits: MemoryLimits,
60 pub num_imported_funcs: u32,
62 pub imported_func_type_indices: Vec<u32>,
64 pub imported_func_names: Vec<String>,
66
67 pub main_func_local_idx: usize,
70 pub has_secondary_entry: bool,
72 pub secondary_entry_local_idx: Option<usize>,
74 pub start_func_local_idx: Option<usize>,
76 pub function_signatures: Vec<(usize, bool)>,
78 pub type_signatures: Vec<(usize, usize)>,
80 pub function_table: Vec<u32>,
82 pub exported_wasm_func_indices: Vec<u32>,
84 pub wasm_memory_base: i32,
86 pub max_memory_pages: u32,
88}
89
90impl<'a> WasmModule<'a> {
91 pub fn parse(wasm: &'a [u8]) -> Result<Self> {
93 wasmparser::validate(wasm)
94 .map_err(|e| Error::Internal(format!("WASM validation error: {e}")))?;
95
96 let mut functions = Vec::new();
97 let mut func_types: Vec<wasmparser::FuncType> = Vec::new();
98 let mut function_type_indices = Vec::new();
99 let mut globals: Vec<GlobalType> = Vec::new();
100 let mut global_init_values: Vec<i32> = Vec::new();
101 let mut main_func_idx: Option<u32> = None;
102 let mut secondary_entry_func_idx: Option<u32> = None;
103 let mut start_func_idx: Option<u32> = None;
104 let mut exported_wasm_func_indices: Vec<u32> = Vec::new();
105 let mut tables: Vec<wasmparser::TableType> = Vec::new();
106 let mut table_elements: Vec<(u32, u32, Vec<u32>)> = Vec::new();
107 let mut data_segments: Vec<DataSegment> = Vec::new();
108 let mut memory_limits = MemoryLimits::default();
109 let mut num_imported_funcs: u32 = 0;
110 let mut imported_func_type_indices: Vec<u32> = Vec::new();
111 let mut imported_func_names: Vec<String> = Vec::new();
112
113 for payload in Parser::new(0).parse_all(wasm) {
114 match payload? {
115 Payload::TypeSection(reader) => {
116 for rec_group in reader {
117 for sub_type in rec_group?.into_types() {
118 if let wasmparser::CompositeInnerType::Func(f) =
119 &sub_type.composite_type.inner
120 {
121 func_types.push(f.clone());
122 }
123 }
124 }
125 }
126 Payload::ImportSection(reader) => {
127 for import in reader {
128 let import = import?;
129 if let wasmparser::TypeRef::Func(type_idx) = import.ty {
130 num_imported_funcs += 1;
131 imported_func_type_indices.push(type_idx);
132 imported_func_names.push(import.name.to_string());
133 }
134 }
135 }
136 Payload::FunctionSection(reader) => {
137 for type_idx in reader {
138 function_type_indices.push(type_idx?);
139 }
140 }
141 Payload::GlobalSection(reader) => {
142 for global in reader {
143 let g = global?;
144 globals.push(g.ty);
145 let init_value = eval_const_i32(&g.init_expr)?;
146 global_init_values.push(init_value);
147 }
148 }
149 Payload::StartSection { func, .. } => {
150 start_func_idx = Some(func);
151 }
152 Payload::TableSection(reader) => {
153 for table in reader {
154 tables.push(table?.ty);
155 }
156 }
157 Payload::MemorySection(reader) => {
158 if let Some(memory) = reader.into_iter().next() {
159 let mem = memory?;
160 memory_limits = MemoryLimits {
161 initial_pages: mem.initial as u32,
162 max_pages: mem.maximum.map(|m| m as u32),
163 };
164 }
165 }
166 Payload::ElementSection(reader) => {
167 for element in reader {
168 let element = element?;
169 if let wasmparser::ElementKind::Active {
170 table_index,
171 offset_expr,
172 } = element.kind
173 {
174 let table_idx = table_index.unwrap_or(0);
175 let offset = eval_const_i32(&offset_expr)?;
176 let func_indices: Vec<u32> = match element.items {
177 wasmparser::ElementItems::Functions(reader) => {
178 reader.into_iter().collect::<std::result::Result<_, _>>()?
179 }
180 wasmparser::ElementItems::Expressions(_, reader) => {
181 let mut indices = Vec::new();
182 for expr in reader {
183 let expr = expr?;
184 if let Some(idx) = eval_const_ref(&expr) {
185 indices.push(idx);
186 }
187 }
188 indices
189 }
190 };
191 table_elements.push((table_idx, offset as u32, func_indices));
192 }
193 }
194 }
195 Payload::ExportSection(reader) => {
196 for export in reader {
197 let export = export?;
198 if export.kind == wasmparser::ExternalKind::Func {
199 exported_wasm_func_indices.push(export.index);
200 let is_imported = export.index < num_imported_funcs;
201 let is_main_name = matches!(
202 export.name,
203 "main"
204 | "refine"
205 | "refine_ext"
206 | "is_authorized"
207 | "is_authorized_ext"
208 );
209 let is_secondary_name =
210 matches!(export.name, "main2" | "accumulate" | "accumulate_ext");
211 if is_imported && (is_main_name || is_secondary_name) {
212 return Err(Error::Internal(format!(
213 "Entry export '{}' refers to imported function index {}",
214 export.name, export.index
215 )));
216 }
217 match export.name {
218 "main" => {
219 main_func_idx = Some(export.index);
220 }
221 "refine" | "refine_ext" | "is_authorized" | "is_authorized_ext"
222 if main_func_idx.is_none() =>
223 {
224 main_func_idx = Some(export.index);
225 }
226 "main2" => {
227 secondary_entry_func_idx = Some(export.index);
228 }
229 "accumulate" | "accumulate_ext"
230 if secondary_entry_func_idx.is_none() =>
231 {
232 secondary_entry_func_idx = Some(export.index);
233 }
234 _ => {}
235 }
236 }
237 }
238 }
239 Payload::CodeSectionEntry(body) => {
240 functions.push(body);
241 }
242 Payload::DataSection(reader) => {
243 for data in reader {
244 let data = data?;
245 match data.kind {
246 wasmparser::DataKind::Active {
247 memory_index: _,
248 offset_expr,
249 } => {
250 let offset = eval_const_i32(&offset_expr)? as u32;
251 data_segments.push(DataSegment {
252 offset: Some(offset),
253 data: data.data.to_vec(),
254 });
255 }
256 wasmparser::DataKind::Passive => {
257 data_segments.push(DataSegment {
258 offset: None,
259 data: data.data.to_vec(),
260 });
261 }
262 }
263 }
264 }
265 _ => {}
266 }
267 }
268
269 if functions.is_empty() {
270 return Err(Error::NoExportedFunction);
271 }
272
273 let main_func_local_idx = if let Some(idx) = main_func_idx {
275 idx as usize - num_imported_funcs as usize
276 } else {
277 tracing::warn!("No 'main' export found, defaulting to first local function");
278 0
279 };
280
281 let has_secondary_entry = secondary_entry_func_idx.is_some();
283 let secondary_entry_local_idx = secondary_entry_func_idx.and_then(|idx| {
284 idx.checked_sub(num_imported_funcs)
285 .map(|v| v as usize)
286 .or_else(|| {
287 tracing::warn!(
288 "secondary entry function {idx} is an imported function, ignoring"
289 );
290 None
291 })
292 });
293 let start_func_local_idx = start_func_idx.and_then(|idx| {
295 idx.checked_sub(num_imported_funcs)
296 .map(|v| v as usize)
297 .or_else(|| {
298 tracing::warn!("start function {idx} is an imported function, ignoring");
299 None
300 })
301 });
302
303 let function_signatures: Vec<(usize, bool)> = imported_func_type_indices
305 .iter()
306 .chain(function_type_indices.iter())
307 .map(|&type_idx| {
308 let func_type = func_types.get(type_idx as usize);
309 let num_params = func_type.map_or(0, |f| f.params().len());
310 let has_return = func_type.is_some_and(|f| !f.results().is_empty());
311 (num_params, has_return)
312 })
313 .collect();
314
315 let type_signatures: Vec<(usize, usize)> = func_types
317 .iter()
318 .map(|f| (f.params().len(), f.results().len()))
319 .collect();
320
321 let table_size = tables.first().map_or(0, |t| t.initial as usize);
323 let mut function_table: Vec<u32> = vec![u32::MAX; table_size];
324 for (table_idx, offset, func_indices) in &table_elements {
325 if *table_idx == 0 {
326 for (i, &func_idx) in func_indices.iter().enumerate() {
327 let idx = *offset as usize + i;
328 if idx < function_table.len() {
329 function_table[idx] = func_idx;
330 }
331 }
332 }
333 }
334
335 let num_passive_segments = data_segments
336 .iter()
337 .filter(|seg| seg.offset.is_none())
338 .count();
339 let wasm_memory_base =
341 memory_layout::compute_wasm_memory_base(globals.len(), num_passive_segments);
342
343 let max_memory_pages = memory_limits
347 .max_pages
348 .unwrap_or(DEFAULT_MAX_PAGES)
349 .max(memory_limits.initial_pages);
350
351 Ok(WasmModule {
352 functions,
353 func_types,
354 function_type_indices,
355 globals,
356 global_init_values,
357 data_segments,
358 memory_limits,
359 num_imported_funcs,
360 imported_func_type_indices,
361 imported_func_names,
362 main_func_local_idx,
363 has_secondary_entry,
364 secondary_entry_local_idx,
365 start_func_local_idx,
366 function_signatures,
367 type_signatures,
368 function_table,
369 exported_wasm_func_indices,
370 wasm_memory_base,
371 max_memory_pages,
372 })
373 }
374}
375
376fn eval_const_i32(expr: &wasmparser::ConstExpr) -> Result<i32> {
377 let mut reader = expr.get_binary_reader();
378 while !reader.eof() {
379 match reader.read_operator()? {
380 wasmparser::Operator::I32Const { value } => return Ok(value),
381 wasmparser::Operator::End => break,
382 _ => {}
383 }
384 }
385 Ok(0)
386}
387
388fn eval_const_ref(expr: &wasmparser::ConstExpr) -> Option<u32> {
389 let mut reader = expr.get_binary_reader();
390 while !reader.eof() {
391 if let Ok(op) = reader.read_operator() {
392 match op {
393 wasmparser::Operator::RefFunc { function_index } => return Some(function_index),
394 wasmparser::Operator::End => break,
395 _ => {}
396 }
397 } else {
398 break;
399 }
400 }
401 None
402}
403
#[cfg(test)]
mod tests {
    use super::WasmModule;

    /// Assembles WAT source into a wasm binary, panicking on malformed input.
    fn assemble(source: &str) -> Vec<u8> {
        wat::parse_str(source).expect("valid WAT")
    }

    /// `main` must win over an alias exported after it.
    #[test]
    fn main_export_name_overrides_alias() {
        let bytes = assemble(
            r#"(module
                (func $canonical_main (export "main"))
                (func $alias_main (export "refine"))
            )"#,
        );
        let parsed = WasmModule::parse(&bytes).expect("valid module");

        assert_eq!(parsed.main_func_local_idx, 0);
    }

    /// `main2` must win over a secondary alias exported after it.
    #[test]
    fn secondary_main2_export_name_overrides_alias() {
        let bytes = assemble(
            r#"(module
                (func $main (export "main"))
                (func $canonical_secondary (export "main2"))
                (func $alias_secondary (export "accumulate_ext"))
            )"#,
        );
        let parsed = WasmModule::parse(&bytes).expect("valid module");

        assert!(parsed.has_secondary_entry);
        assert_eq!(parsed.secondary_entry_local_idx, Some(1));
    }

    /// `main` must win even when the alias export is listed first.
    #[test]
    fn reverse_main_export_name_overrides_alias() {
        let bytes = assemble(
            r#"(module
                (func $canonical_main)
                (func $alias_main)
                (export "refine" (func $alias_main))
                (export "main" (func $canonical_main))
            )"#,
        );
        let parsed = WasmModule::parse(&bytes).expect("valid module");

        assert_eq!(parsed.main_func_local_idx, 0);
    }

    /// `main2` must win even when the secondary alias export is listed first.
    #[test]
    fn reverse_secondary_main2_export_name_overrides_alias() {
        let bytes = assemble(
            r#"(module
                (func $main (export "main"))
                (func $canonical_secondary)
                (func $alias_secondary)
                (export "accumulate_ext" (func $alias_secondary))
                (export "main2" (func $canonical_secondary))
            )"#,
        );
        let parsed = WasmModule::parse(&bytes).expect("valid module");

        assert!(parsed.has_secondary_entry);
        assert_eq!(parsed.secondary_entry_local_idx, Some(1));
    }

    /// An entry-point export that names an imported function is rejected.
    #[test]
    fn imported_entry_export_returns_error() {
        let bytes = assemble(
            r#"(module
                (import "env" "main_import" (func $main_import))
                (func $local_main)
                (export "main" (func $main_import))
            )"#,
        );

        let Err(err) = WasmModule::parse(&bytes) else {
            panic!("must reject imported main export");
        };
        match err {
            crate::Error::Internal(msg) => assert!(
                msg.contains("imported function index"),
                "unexpected error message: {msg}"
            ),
            err => panic!("unexpected error: {err}"),
        }
    }
}