xwasmi_validation/
lib.rs

1// TODO: Uncomment
2// #![warn(missing_docs)]
3
4#![cfg_attr(not(feature = "std"), no_std)]
5//// alloc is required in no_std
6#![cfg_attr(not(feature = "std"), feature(alloc, alloc_prelude))]
7
8#[cfg(not(feature = "std"))]
9#[macro_use]
10extern crate alloc;
11#[cfg(feature = "std")]
12extern crate std as alloc;
13
14pub mod stack;
15
16/// Index of default linear memory.
17pub const DEFAULT_MEMORY_INDEX: u32 = 0;
18/// Index of default table.
19pub const DEFAULT_TABLE_INDEX: u32 = 0;
20
/// Maximum number of linear memory pages that an xwasm instance supports.
22pub const LINEAR_MEMORY_MAX_PAGES: u32 = 65536;
23
24#[allow(unused_imports)]
25use alloc::prelude::v1::*;
26use core::fmt;
27#[cfg(feature = "std")]
28use std::error;
29
30#[cfg(not(feature = "std"))]
31use hashbrown::HashSet;
32#[cfg(feature = "std")]
33use std::collections::HashSet;
34
35use self::context::ModuleContextBuilder;
36use xwasm::elements::{
37    BlockType, External, FuncBody, GlobalEntry, GlobalType, InitExpr, Instruction, Internal,
38    MemoryType, Module, ResizableLimits, TableType, Type, ValueType,
39};
40
41pub mod context;
42pub mod func;
43pub mod util;
44
45#[cfg(test)]
46mod tests;
47
48// TODO: Consider using a type other than String, because
49// of formatting machinary is not welcomed in higgsfield runtimes.
50#[derive(Debug)]
51pub struct Error(pub String);
52
53impl fmt::Display for Error {
54    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
55        write!(f, "{}", self.0)
56    }
57}
58
#[cfg(feature = "std")]
impl error::Error for Error {
    /// Returns the stored message.
    fn description(&self) -> &str {
        self.0.as_str()
    }
}
65
66impl From<stack::Error> for Error {
67    fn from(e: stack::Error) -> Error {
68        Error(format!("Stack: {}", e))
69    }
70}
71
72pub trait Validator {
73    type Output;
74    type FuncValidator: FuncValidator;
75    fn new(module: &Module) -> Self;
76    fn on_function_validated(
77        &mut self,
78        index: u32,
79        output: <<Self as Validator>::FuncValidator as FuncValidator>::Output,
80    );
81    fn finish(self) -> Self::Output;
82}
83
84pub trait FuncValidator {
85    type Output;
86    fn new(ctx: &func::FunctionValidationContext, body: &FuncBody) -> Self;
87    fn next_instruction(
88        &mut self,
89        ctx: &mut func::FunctionValidationContext,
90        instruction: &Instruction,
91    ) -> Result<(), Error>;
92    fn finish(self) -> Self::Output;
93}
94
95/// A module validator that just validates modules and produces no result.
96pub struct PlainValidator;
97
98impl Validator for PlainValidator {
99    type Output = ();
100    type FuncValidator = PlainFuncValidator;
101    fn new(_module: &Module) -> PlainValidator {
102        PlainValidator
103    }
104    fn on_function_validated(
105        &mut self,
106        _index: u32,
107        _output: <<Self as Validator>::FuncValidator as FuncValidator>::Output,
108    ) -> () {
109        ()
110    }
111    fn finish(self) -> () {
112        ()
113    }
114}
115
116/// A function validator that just validates modules and produces no result.
117pub struct PlainFuncValidator;
118
119impl FuncValidator for PlainFuncValidator {
120    type Output = ();
121
122    fn new(_ctx: &func::FunctionValidationContext, _body: &FuncBody) -> PlainFuncValidator {
123        PlainFuncValidator
124    }
125
126    fn next_instruction(
127        &mut self,
128        ctx: &mut func::FunctionValidationContext,
129        instruction: &Instruction,
130    ) -> Result<(), Error> {
131        ctx.step(instruction)
132    }
133
134    fn finish(self) -> () {
135        ()
136    }
137}
138
139pub fn validate_module<V: Validator>(module: &Module) -> Result<V::Output, Error> {
140    let mut context_builder = ModuleContextBuilder::new();
141    let mut imported_globals = Vec::new();
142    let mut validation = V::new(&module);
143
144    // Copy types from module as is.
145    context_builder.set_types(
146        module
147            .type_section()
148            .map(|ts| {
149                ts.types()
150                    .into_iter()
151                    .map(|&Type::Function(ref ty)| ty)
152                    .cloned()
153                    .collect()
154            })
155            .unwrap_or_default(),
156    );
157
158    // Fill elements with imported values.
159    for import_entry in module
160        .import_section()
161        .map(|i| i.entries())
162        .unwrap_or_default()
163    {
164        match *import_entry.external() {
165            External::Function(idx) => context_builder.push_func_type_index(idx),
166            External::Table(ref table) => context_builder.push_table(table.clone()),
167            External::Memory(ref memory) => context_builder.push_memory(memory.clone()),
168            External::Global(ref global) => {
169                context_builder.push_global(global.clone());
170                imported_globals.push(global.clone());
171            }
172        }
173    }
174
175    // Concatenate elements with defined in the module.
176    if let Some(function_section) = module.function_section() {
177        for func_entry in function_section.entries() {
178            context_builder.push_func_type_index(func_entry.type_ref())
179        }
180    }
181    if let Some(table_section) = module.table_section() {
182        for table_entry in table_section.entries() {
183            validate_table_type(table_entry)?;
184            context_builder.push_table(table_entry.clone());
185        }
186    }
187    if let Some(mem_section) = module.memory_section() {
188        for mem_entry in mem_section.entries() {
189            validate_memory_type(mem_entry)?;
190            context_builder.push_memory(mem_entry.clone());
191        }
192    }
193    if let Some(global_section) = module.global_section() {
194        for global_entry in global_section.entries() {
195            validate_global_entry(global_entry, &imported_globals)?;
196            context_builder.push_global(global_entry.global_type().clone());
197        }
198    }
199
200    let context = context_builder.build();
201
202    let function_section_len = module
203        .function_section()
204        .map(|s| s.entries().len())
205        .unwrap_or(0);
206    let code_section_len = module.code_section().map(|s| s.bodies().len()).unwrap_or(0);
207    if function_section_len != code_section_len {
208        return Err(Error(format!(
209            "length of function section is {}, while len of code section is {}",
210            function_section_len, code_section_len
211        )));
212    }
213
214    // validate every function body in user modules
215    if function_section_len != 0 {
216        // tests use invalid code
217        let function_section = module
218            .function_section()
219            .expect("function_section_len != 0; qed");
220        let code_section = module
221            .code_section()
222            .expect("function_section_len != 0; function_section_len == code_section_len; qed");
223        // check every function body
224        for (index, function) in function_section.entries().iter().enumerate() {
225            let function_body = code_section
226                .bodies()
227                .get(index as usize)
228                .ok_or(Error(format!("Missing body for function {}", index)))?;
229
230            let output = func::drive::<V::FuncValidator>(&context, function, function_body)
231                .map_err(|Error(ref msg)| {
232                    Error(format!(
233                        "Function #{} reading/validation error: {}",
234                        index, msg
235                    ))
236                })?;
237            validation.on_function_validated(index as u32, output);
238        }
239    }
240
241    // validate start section
242    if let Some(start_fn_idx) = module.start_section() {
243        let (params, return_ty) = context.require_function(start_fn_idx)?;
244        if return_ty != BlockType::NoResult || params.len() != 0 {
245            return Err(Error(
246                "start function expected to have type [] -> []".into(),
247            ));
248        }
249    }
250
251    // validate export section
252    if let Some(export_section) = module.export_section() {
253        let mut export_names = HashSet::with_capacity(export_section.entries().len());
254        for export in export_section.entries() {
255            // HashSet::insert returns false if item already in set.
256            let duplicate = export_names.insert(export.field()) == false;
257            if duplicate {
258                return Err(Error(format!("duplicate export {}", export.field())));
259            }
260            match *export.internal() {
261                Internal::Function(function_index) => {
262                    context.require_function(function_index)?;
263                }
264                Internal::Global(global_index) => {
265                    context.require_global(global_index, Some(false))?;
266                }
267                Internal::Memory(memory_index) => {
268                    context.require_memory(memory_index)?;
269                }
270                Internal::Table(table_index) => {
271                    context.require_table(table_index)?;
272                }
273            }
274        }
275    }
276
277    // validate import section
278    if let Some(import_section) = module.import_section() {
279        for import in import_section.entries() {
280            match *import.external() {
281                External::Function(function_type_index) => {
282                    context.require_function_type(function_type_index)?;
283                }
284                External::Global(ref global_type) => {
285                    if global_type.is_mutable() {
286                        return Err(Error(format!(
287                            "trying to import mutable global {}",
288                            import.field()
289                        )));
290                    }
291                }
292                External::Memory(ref memory_type) => {
293                    validate_memory_type(memory_type)?;
294                }
295                External::Table(ref table_type) => {
296                    validate_table_type(table_type)?;
297                }
298            }
299        }
300    }
301
302    // there must be no greater than 1 table in tables index space
303    if context.tables().len() > 1 {
304        return Err(Error(format!(
305            "too many tables in index space: {}",
306            context.tables().len()
307        )));
308    }
309
310    // there must be no greater than 1 linear memory in memory index space
311    if context.memories().len() > 1 {
312        return Err(Error(format!(
313            "too many memory regions in index space: {}",
314            context.memories().len()
315        )));
316    }
317
318    // use data section to initialize linear memory regions
319    if let Some(data_section) = module.data_section() {
320        for data_segment in data_section.entries() {
321            context.require_memory(data_segment.index())?;
322            let init_ty = expr_const_type(data_segment.offset(), context.globals())?;
323            if init_ty != ValueType::I32 {
324                return Err(Error("segment offset should return I32".into()));
325            }
326        }
327    }
328
329    // use element section to fill tables
330    if let Some(element_section) = module.elements_section() {
331        for element_segment in element_section.entries() {
332            context.require_table(element_segment.index())?;
333
334            let init_ty = expr_const_type(element_segment.offset(), context.globals())?;
335            if init_ty != ValueType::I32 {
336                return Err(Error("segment offset should return I32".into()));
337            }
338
339            for function_index in element_segment.members() {
340                context.require_function(*function_index)?;
341            }
342        }
343    }
344
345    Ok(validation.finish())
346}
347
348fn validate_limits(limits: &ResizableLimits) -> Result<(), Error> {
349    if let Some(maximum) = limits.maximum() {
350        if limits.initial() > maximum {
351            return Err(Error(format!(
352                "maximum limit {} is less than minimum {}",
353                maximum,
354                limits.initial()
355            )));
356        }
357    }
358    Ok(())
359}
360
361fn validate_memory_type(memory_type: &MemoryType) -> Result<(), Error> {
362    let initial = memory_type.limits().initial();
363    let maximum: Option<u32> = memory_type.limits().maximum();
364    validate_memory(initial, maximum).map_err(Error)
365}
366
367pub fn validate_memory(initial: u32, maximum: Option<u32>) -> Result<(), String> {
368    if initial > LINEAR_MEMORY_MAX_PAGES {
369        return Err(format!(
370            "initial memory size must be at most {} pages",
371            LINEAR_MEMORY_MAX_PAGES
372        ));
373    }
374    if let Some(maximum) = maximum {
375        if initial > maximum {
376            return Err(format!(
377                "maximum limit {} is less than minimum {}",
378                maximum, initial,
379            ));
380        }
381
382        if maximum > LINEAR_MEMORY_MAX_PAGES {
383            return Err(format!(
384                "maximum memory size must be at most {} pages",
385                LINEAR_MEMORY_MAX_PAGES
386            ));
387        }
388    }
389    Ok(())
390}
391
392fn validate_table_type(table_type: &TableType) -> Result<(), Error> {
393    validate_limits(table_type.limits())
394}
395
396fn validate_global_entry(global_entry: &GlobalEntry, globals: &[GlobalType]) -> Result<(), Error> {
397    let init = global_entry.init_expr();
398    let init_expr_ty = expr_const_type(init, globals)?;
399    if init_expr_ty != global_entry.global_type().content_type() {
400        return Err(Error(format!(
401            "Trying to initialize variable of type {:?} with value of type {:?}",
402            global_entry.global_type().content_type(),
403            init_expr_ty
404        )));
405    }
406    Ok(())
407}
408
409/// Returns type of this constant expression.
410fn expr_const_type(init_expr: &InitExpr, globals: &[GlobalType]) -> Result<ValueType, Error> {
411    let code = init_expr.code();
412    if code.len() != 2 {
413        return Err(Error(
414            "Init expression should always be with length 2".into(),
415        ));
416    }
417    let expr_ty: ValueType = match code[0] {
418        Instruction::I32Const(_) => ValueType::I32,
419        Instruction::I64Const(_) => ValueType::I64,
420        Instruction::F32Const(_) => ValueType::F32,
421        Instruction::F64Const(_) => ValueType::F64,
422        Instruction::GetGlobal(idx) => match globals.get(idx as usize) {
423            Some(target_global) => {
424                if target_global.is_mutable() {
425                    return Err(Error(format!("Global {} is mutable", idx)));
426                }
427                target_global.content_type()
428            }
429            None => {
430                return Err(Error(format!(
431                    "Global {} doesn't exists or not yet defined",
432                    idx
433                )));
434            }
435        },
436        _ => return Err(Error("Non constant opcode in init expr".into())),
437    };
438    if code[1] != Instruction::End {
439        return Err(Error("Expression doesn't ends with `end` opcode".into()));
440    }
441    Ok(expr_ty)
442}