twasmi_validation/
lib.rs

1// TODO: Uncomment
2// #![warn(missing_docs)]
3
4#![cfg_attr(not(feature = "std"), no_std)]
5
6#[cfg(not(feature = "std"))]
7#[macro_use]
8extern crate alloc;
9#[cfg(feature = "std")]
10extern crate std as alloc;
11
12pub mod stack;
13
/// Index of the default linear memory (`validate_module` rejects modules with more than one).
pub const DEFAULT_MEMORY_INDEX: u32 = 0;
/// Index of the default table (`validate_module` rejects modules with more than one).
pub const DEFAULT_TABLE_INDEX: u32 = 0;

/// Maximal number of linear-memory pages that a wasm instance supports (2^16).
pub const LINEAR_MEMORY_MAX_PAGES: u32 = 1 << 16;
21
22use alloc::{string::String, vec::Vec};
23use core::fmt;
24#[cfg(feature = "std")]
25use std::error;
26
27use self::context::ModuleContextBuilder;
28use tetsy_wasm::elements::{
29    BlockType, ExportEntry, External, FuncBody, GlobalEntry, GlobalType, InitExpr, Instruction,
30    Internal, MemoryType, Module, ResizableLimits, TableType, Type, ValueType,
31};
32
33pub mod context;
34pub mod func;
35pub mod util;
36
37#[cfg(test)]
38mod tests;
39
// TODO: Consider using a type other than `String`, because the formatting
// machinery it pulls in is not welcome in substrate runtimes.
/// Validation error carrying a human-readable description of what went wrong.
#[derive(Debug)]
pub struct Error(pub String);

impl fmt::Display for Error {
    /// Writes the stored message verbatim (no padding or alignment is applied).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}
50
#[cfg(feature = "std")]
impl error::Error for Error {
    /// Returns the stored message; kept for callers of the legacy `description` API.
    fn description(&self) -> &str {
        self.0.as_str()
    }
}
57
58impl From<stack::Error> for Error {
59    fn from(e: stack::Error) -> Error {
60        Error(format!("Stack: {}", e))
61    }
62}
63
64pub trait Validator {
65    type Output;
66    type FuncValidator: FuncValidator;
67    fn new(module: &Module) -> Self;
68    fn on_function_validated(
69        &mut self,
70        index: u32,
71        output: <<Self as Validator>::FuncValidator as FuncValidator>::Output,
72    );
73    fn finish(self) -> Self::Output;
74}
75
/// Validator for a single function body.
///
/// NOTE(review): presumably `func::drive` creates it with [`FuncValidator::new`], feeds it each
/// instruction of the body through [`FuncValidator::next_instruction`] and then calls
/// [`FuncValidator::finish`] — confirm against `func.rs`.
pub trait FuncValidator {
    /// Value produced once the whole body has been validated.
    type Output;
    /// Creates a validator for `body` given the function's validation context `ctx`.
    fn new(ctx: &func::FunctionValidationContext, body: &FuncBody) -> Self;
    /// Checks one instruction, returning an error if it is invalid. `ctx` is mutable so the
    /// implementation can advance the validation state.
    fn next_instruction(
        &mut self,
        ctx: &mut func::FunctionValidationContext,
        instruction: &Instruction,
    ) -> Result<(), Error>;
    /// Consumes the validator and returns its output.
    fn finish(self) -> Self::Output;
}
86
87/// A module validator that just validates modules and produces no result.
88pub struct PlainValidator;
89
90impl Validator for PlainValidator {
91    type Output = ();
92    type FuncValidator = PlainFuncValidator;
93    fn new(_module: &Module) -> PlainValidator {
94        PlainValidator
95    }
96    fn on_function_validated(
97        &mut self,
98        _index: u32,
99        _output: <<Self as Validator>::FuncValidator as FuncValidator>::Output,
100    ) -> () {
101        ()
102    }
103    fn finish(self) -> () {
104        ()
105    }
106}
107
108/// A function validator that just validates modules and produces no result.
109pub struct PlainFuncValidator;
110
111impl FuncValidator for PlainFuncValidator {
112    type Output = ();
113
114    fn new(_ctx: &func::FunctionValidationContext, _body: &FuncBody) -> PlainFuncValidator {
115        PlainFuncValidator
116    }
117
118    fn next_instruction(
119        &mut self,
120        ctx: &mut func::FunctionValidationContext,
121        instruction: &Instruction,
122    ) -> Result<(), Error> {
123        ctx.step(instruction)
124    }
125
126    fn finish(self) -> () {
127        ()
128    }
129}
130
131pub fn validate_module<V: Validator>(module: &Module) -> Result<V::Output, Error> {
132    let mut context_builder = ModuleContextBuilder::new();
133    let mut imported_globals = Vec::new();
134    let mut validation = V::new(&module);
135
136    // Copy types from module as is.
137    context_builder.set_types(
138        module
139            .type_section()
140            .map(|ts| {
141                ts.types()
142                    .into_iter()
143                    .map(|&Type::Function(ref ty)| ty)
144                    .cloned()
145                    .collect()
146            })
147            .unwrap_or_default(),
148    );
149
150    // Fill elements with imported values.
151    for import_entry in module
152        .import_section()
153        .map(|i| i.entries())
154        .unwrap_or_default()
155    {
156        match *import_entry.external() {
157            External::Function(idx) => context_builder.push_func_type_index(idx),
158            External::Table(ref table) => context_builder.push_table(table.clone()),
159            External::Memory(ref memory) => context_builder.push_memory(memory.clone()),
160            External::Global(ref global) => {
161                context_builder.push_global(global.clone());
162                imported_globals.push(global.clone());
163            }
164        }
165    }
166
167    // Concatenate elements with defined in the module.
168    if let Some(function_section) = module.function_section() {
169        for func_entry in function_section.entries() {
170            context_builder.push_func_type_index(func_entry.type_ref())
171        }
172    }
173    if let Some(table_section) = module.table_section() {
174        for table_entry in table_section.entries() {
175            validate_table_type(table_entry)?;
176            context_builder.push_table(table_entry.clone());
177        }
178    }
179    if let Some(mem_section) = module.memory_section() {
180        for mem_entry in mem_section.entries() {
181            validate_memory_type(mem_entry)?;
182            context_builder.push_memory(mem_entry.clone());
183        }
184    }
185    if let Some(global_section) = module.global_section() {
186        for global_entry in global_section.entries() {
187            validate_global_entry(global_entry, &imported_globals)?;
188            context_builder.push_global(global_entry.global_type().clone());
189        }
190    }
191
192    let context = context_builder.build();
193
194    let function_section_len = module
195        .function_section()
196        .map(|s| s.entries().len())
197        .unwrap_or(0);
198    let code_section_len = module.code_section().map(|s| s.bodies().len()).unwrap_or(0);
199    if function_section_len != code_section_len {
200        return Err(Error(format!(
201            "length of function section is {}, while len of code section is {}",
202            function_section_len, code_section_len
203        )));
204    }
205
206    // validate every function body in user modules
207    if function_section_len != 0 {
208        // tests use invalid code
209        let function_section = module
210            .function_section()
211            .expect("function_section_len != 0; qed");
212        let code_section = module
213            .code_section()
214            .expect("function_section_len != 0; function_section_len == code_section_len; qed");
215        // check every function body
216        for (index, function) in function_section.entries().iter().enumerate() {
217            let function_body = code_section
218                .bodies()
219                .get(index as usize)
220                .ok_or(Error(format!("Missing body for function {}", index)))?;
221
222            let output = func::drive::<V::FuncValidator>(&context, function, function_body)
223                .map_err(|Error(ref msg)| {
224                    Error(format!(
225                        "Function #{} reading/validation error: {}",
226                        index, msg
227                    ))
228                })?;
229            validation.on_function_validated(index as u32, output);
230        }
231    }
232
233    // validate start section
234    if let Some(start_fn_idx) = module.start_section() {
235        let (params, return_ty) = context.require_function(start_fn_idx)?;
236        if return_ty != BlockType::NoResult || params.len() != 0 {
237            return Err(Error(
238                "start function expected to have type [] -> []".into(),
239            ));
240        }
241    }
242
243    // validate export section
244    if let Some(export_section) = module.export_section() {
245        let mut export_names = export_section
246            .entries()
247            .iter()
248            .map(ExportEntry::field)
249            .collect::<Vec<_>>();
250
251        export_names.sort_unstable();
252
253        for (fst, snd) in export_names.iter().zip(export_names.iter().skip(1)) {
254            if fst == snd {
255                return Err(Error(format!("duplicate export {}", fst)));
256            }
257        }
258
259        for export in export_section.entries() {
260            match *export.internal() {
261                Internal::Function(function_index) => {
262                    context.require_function(function_index)?;
263                }
264                Internal::Global(global_index) => {
265                    context.require_global(global_index, None)?;
266                }
267                Internal::Memory(memory_index) => {
268                    context.require_memory(memory_index)?;
269                }
270                Internal::Table(table_index) => {
271                    context.require_table(table_index)?;
272                }
273            }
274        }
275    }
276
277    // validate import section
278    if let Some(import_section) = module.import_section() {
279        for import in import_section.entries() {
280            match *import.external() {
281                External::Function(function_type_index) => {
282                    context.require_function_type(function_type_index)?;
283                }
284                External::Global(_) => {}
285                External::Memory(ref memory_type) => {
286                    validate_memory_type(memory_type)?;
287                }
288                External::Table(ref table_type) => {
289                    validate_table_type(table_type)?;
290                }
291            }
292        }
293    }
294
295    // there must be no greater than 1 table in tables index space
296    if context.tables().len() > 1 {
297        return Err(Error(format!(
298            "too many tables in index space: {}",
299            context.tables().len()
300        )));
301    }
302
303    // there must be no greater than 1 linear memory in memory index space
304    if context.memories().len() > 1 {
305        return Err(Error(format!(
306            "too many memory regions in index space: {}",
307            context.memories().len()
308        )));
309    }
310
311    // use data section to initialize linear memory regions
312    if let Some(data_section) = module.data_section() {
313        for data_segment in data_section.entries() {
314            context.require_memory(data_segment.index())?;
315            let offset = data_segment
316                .offset()
317                .as_ref()
318                .ok_or_else(|| Error("passive memory segments are not supported".into()))?;
319            let init_ty = expr_const_type(&offset, context.globals())?;
320            if init_ty != ValueType::I32 {
321                return Err(Error("segment offset should return I32".into()));
322            }
323        }
324    }
325
326    // use element section to fill tables
327    if let Some(element_section) = module.elements_section() {
328        for element_segment in element_section.entries() {
329            context.require_table(element_segment.index())?;
330            let offset = element_segment
331                .offset()
332                .as_ref()
333                .ok_or_else(|| Error("passive element segments are not supported".into()))?;
334            let init_ty = expr_const_type(&offset, context.globals())?;
335            if init_ty != ValueType::I32 {
336                return Err(Error("segment offset should return I32".into()));
337            }
338
339            for function_index in element_segment.members() {
340                context.require_function(*function_index)?;
341            }
342        }
343    }
344
345    Ok(validation.finish())
346}
347
348fn validate_limits(limits: &ResizableLimits) -> Result<(), Error> {
349    if let Some(maximum) = limits.maximum() {
350        if limits.initial() > maximum {
351            return Err(Error(format!(
352                "maximum limit {} is less than minimum {}",
353                maximum,
354                limits.initial()
355            )));
356        }
357    }
358    Ok(())
359}
360
361fn validate_memory_type(memory_type: &MemoryType) -> Result<(), Error> {
362    let initial = memory_type.limits().initial();
363    let maximum: Option<u32> = memory_type.limits().maximum();
364    validate_memory(initial, maximum).map_err(Error)
365}
366
367pub fn validate_memory(initial: u32, maximum: Option<u32>) -> Result<(), String> {
368    if initial > LINEAR_MEMORY_MAX_PAGES {
369        return Err(format!(
370            "initial memory size must be at most {} pages",
371            LINEAR_MEMORY_MAX_PAGES
372        ));
373    }
374    if let Some(maximum) = maximum {
375        if initial > maximum {
376            return Err(format!(
377                "maximum limit {} is less than minimum {}",
378                maximum, initial,
379            ));
380        }
381
382        if maximum > LINEAR_MEMORY_MAX_PAGES {
383            return Err(format!(
384                "maximum memory size must be at most {} pages",
385                LINEAR_MEMORY_MAX_PAGES
386            ));
387        }
388    }
389    Ok(())
390}
391
392fn validate_table_type(table_type: &TableType) -> Result<(), Error> {
393    validate_limits(table_type.limits())
394}
395
396fn validate_global_entry(global_entry: &GlobalEntry, globals: &[GlobalType]) -> Result<(), Error> {
397    let init = global_entry.init_expr();
398    let init_expr_ty = expr_const_type(init, globals)?;
399    if init_expr_ty != global_entry.global_type().content_type() {
400        return Err(Error(format!(
401            "Trying to initialize variable of type {:?} with value of type {:?}",
402            global_entry.global_type().content_type(),
403            init_expr_ty
404        )));
405    }
406    Ok(())
407}
408
409/// Returns type of this constant expression.
410fn expr_const_type(init_expr: &InitExpr, globals: &[GlobalType]) -> Result<ValueType, Error> {
411    let code = init_expr.code();
412    if code.len() != 2 {
413        return Err(Error(
414            "Init expression should always be with length 2".into(),
415        ));
416    }
417    let expr_ty: ValueType = match code[0] {
418        Instruction::I32Const(_) => ValueType::I32,
419        Instruction::I64Const(_) => ValueType::I64,
420        Instruction::F32Const(_) => ValueType::F32,
421        Instruction::F64Const(_) => ValueType::F64,
422        Instruction::GetGlobal(idx) => match globals.get(idx as usize) {
423            Some(target_global) => {
424                if target_global.is_mutable() {
425                    return Err(Error(format!("Global {} is mutable", idx)));
426                }
427                target_global.content_type()
428            }
429            None => {
430                return Err(Error(format!(
431                    "Global {} doesn't exists or not yet defined",
432                    idx
433                )));
434            }
435        },
436        _ => return Err(Error("Non constant opcode in init expr".into())),
437    };
438    if code[1] != Instruction::End {
439        return Err(Error("Expression doesn't ends with `end` opcode".into()));
440    }
441    Ok(expr_ty)
442}