twasm_utils/stack_height/
mod.rs

1//! The pass that tries to make stack overflows deterministic, by introducing
2//! an upper bound of the stack size.
3//!
4//! This pass introduces a global mutable variable to track stack height,
5//! and instruments all calls with preamble and postamble.
6//!
//! Stack height is increased prior to the call. Otherwise, the check would
//! be made after the stack frame is allocated.
9//!
//! The preamble is inserted before the call. It increments
//! the global stack height variable by the statically determined "stack cost"
//! of the callee. If after the increment the stack height exceeds
//! the limit (the `stack_limit` argument of `inject_limiter`), then execution traps.
//! Otherwise, the call is executed.
15//!
16//! The postamble is inserted after the call. The purpose of the postamble is to decrease
17//! the stack height by the "stack cost" of the callee function.
18//!
//! Note that we can't instrument all possible ways to return from the function. The simplest
//! example would be a trap issued by a host function.
//! That means the stack height global won't be equal to zero upon the next execution after such a trap.
22//!
23//! # Thunks
24//!
//! Because the stack height is increased prior to the call, a few problems arise:
//!
//! - The stack height isn't increased upon entry to the first function, i.e. an exported function.
//! - The start function is executed externally (similar to exported functions).
//! - It is statically unknown which function will be invoked by an indirect call.
30//!
//! The solution to these problems is to generate intermediate functions, called 'thunks', which
//! increase the stack height before, and decrease it after, the call to the original function,
//! and then make exported functions, table entries and the start section point to the
//! corresponding thunks.
34//!
35//! # Stack cost
36//!
//! The stack cost of a function is calculated as the sum of its locals
//! and the maximal height of the value stack.
39//!
40//! All values are treated equally, as they have the same size.
41//!
42//! The rationale is that this makes it possible to use the following very naive wasm executor:
43//!
44//! - values are implemented by a union, so each value takes a size equal to
45//!   the size of the largest possible value type this union can hold. (In MVP it is 8 bytes)
46//! - each value from the value stack is placed on the native stack.
47//! - each local variable and function argument is placed on the native stack.
48//! - arguments pushed by the caller are copied into callee stack rather than shared
49//!   between the frames.
//! - upon entry into the function, the entire stack frame is allocated.
51
52use crate::std::string::String;
53use crate::std::vec::Vec;
54
55use tetsy_wasm::elements::{self, Type};
56use tetsy_wasm::builder;
57
/// Macro to generate preamble and postamble around a single `call`.
///
/// Expands to an array of instructions that:
///
/// 1. adds `$callee_stack_cost` to the stack height global;
/// 2. traps with `unreachable` if the new height is greater than `$stack_limit`
///    (compared as unsigned 32-bit integers, `i32.gt_u`);
/// 3. performs the original `call $callee_idx`;
/// 4. subtracts `$callee_stack_cost` from the stack height global again.
///
/// Every argument expression is evaluated exactly once, before the array is built, and all
/// paths are resolved through `$crate`, so the call site needs no particular imports.
macro_rules! instrument_call {
	($callee_idx: expr, $callee_stack_cost: expr, $stack_height_global_idx: expr, $stack_limit: expr) => {{
		use $crate::tetsy_wasm::elements::Instruction::*;
		// Bind the arguments once; the emitted sequence refers to each of them repeatedly.
		let callee = $callee_idx;
		let cost = $callee_stack_cost;
		let height_global = $stack_height_global_idx;
		// Reinterpret the limit's bits as i32; the comparison below is unsigned anyway.
		let limit = $stack_limit as i32;
		[
			// stack_height += stack_cost(F)
			GetGlobal(height_global),
			I32Const(cost),
			I32Add,
			SetGlobal(height_global),
			// if stack_height > LIMIT: unreachable
			GetGlobal(height_global),
			I32Const(limit),
			I32GtU,
			If($crate::tetsy_wasm::elements::BlockType::NoResult),
			Unreachable,
			End,
			// Original call
			Call(callee),
			// stack_height -= stack_cost(F)
			GetGlobal(height_global),
			I32Const(cost),
			I32Sub,
			SetGlobal(height_global),
		]
	}};
}
85
86mod max_height;
87mod thunk;
88
/// Error produced while instrumenting a module.
///
/// Wraps a human-readable description of what went wrong. Receiving this
/// error means the input module is invalid.
#[derive(Debug)]
pub struct Error(String);
94
/// State shared by the instrumentation passes of this module.
pub(crate) struct Context {
	/// Index of the stack height global, used as the operand of `get_global`/`set_global`.
	stack_height_global_idx: u32,
	/// Stack cost of every function, indexed by function index (imported functions included).
	func_stack_costs: Vec<u32>,
	/// Value the stack height may reach before an instrumented call traps.
	stack_limit: u32,
}

impl Context {
	/// Index of the stack height global variable.
	fn stack_height_global_idx(&self) -> u32 {
		self.stack_height_global_idx
	}

	/// Stack cost of the function at `func_idx`, or `None` when the index is out of range.
	fn stack_cost(&self, func_idx: u32) -> Option<u32> {
		let position = func_idx as usize;
		self.func_stack_costs.get(position).copied()
	}

	/// The stack limit the module is being instrumented with.
	fn stack_limit(&self) -> u32 {
		self.stack_limit
	}
}
117
118/// Instrument a module with stack height limiter.
119///
120/// See module-level documentation for more details.
121///
122/// # Errors
123///
124/// Returns `Err` if module is invalid and can't be
125pub fn inject_limiter(
126	mut module: elements::Module,
127	stack_limit: u32,
128) -> Result<elements::Module, Error> {
129	let mut ctx = Context {
130		stack_height_global_idx: generate_stack_height_global(&mut module),
131		func_stack_costs: compute_stack_costs(&module)?,
132		stack_limit,
133	};
134
135	instrument_functions(&mut ctx, &mut module)?;
136	let module = thunk::generate_thunks(&mut ctx, module)?;
137
138	Ok(module)
139}
140
141/// Generate a new global that will be used for tracking current stack height.
142fn generate_stack_height_global(module: &mut elements::Module) -> u32 {
143	let global_entry = builder::global()
144		.value_type()
145		.i32()
146		.mutable()
147		.init_expr(elements::Instruction::I32Const(0))
148		.build();
149
150	// Try to find an existing global section.
151	for section in module.sections_mut() {
152		if let elements::Section::Global(gs) = section {
153			gs.entries_mut().push(global_entry);
154			return (gs.entries().len() as u32) - 1;
155		}
156	}
157
158	// Existing section not found, create one!
159	module.sections_mut().push(elements::Section::Global(
160		elements::GlobalSection::with_entries(vec![global_entry]),
161	));
162	0
163}
164
165/// Calculate stack costs for all functions.
166///
167/// Returns a vector with a stack cost for each function, including imports.
168fn compute_stack_costs(module: &elements::Module) -> Result<Vec<u32>, Error> {
169	let func_imports = module.import_count(elements::ImportCountType::Function);
170
171	// TODO: optimize!
172	(0..module.functions_space())
173		.map(|func_idx| {
174			if func_idx < func_imports {
175				// We can't calculate stack_cost of the import functions.
176				Ok(0)
177			} else {
178				compute_stack_cost(func_idx as u32, &module)
179			}
180		})
181		.collect()
182}
183
184/// Stack cost of the given *defined* function is the sum of it's locals count (that is,
185/// number of arguments plus number of local variables) and the maximal stack
186/// height.
187fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result<u32, Error> {
188	// To calculate the cost of a function we need to convert index from
189	// function index space to defined function spaces.
190	let func_imports = module.import_count(elements::ImportCountType::Function) as u32;
191	let defined_func_idx = func_idx.checked_sub(func_imports).ok_or_else(|| {
192		Error("This should be a index of a defined function".into())
193	})?;
194
195	let code_section = module.code_section().ok_or_else(|| {
196		Error("Due to validation code section should exists".into())
197	})?;
198	let body = &code_section
199		.bodies()
200		.get(defined_func_idx as usize)
201		.ok_or_else(|| Error("Function body is out of bounds".into()))?;
202	let locals_count = body.locals().len() as u32;
203
204	let max_stack_height =
205		max_height::compute(
206			defined_func_idx,
207			module
208		)?;
209
210	Ok(locals_count + max_stack_height)
211}
212
213fn instrument_functions(ctx: &mut Context, module: &mut elements::Module) -> Result<(), Error> {
214	for section in module.sections_mut() {
215		if let elements::Section::Code(code_section) = section {
216			for func_body in code_section.bodies_mut() {
217				let opcodes = func_body.code_mut();
218				instrument_function(ctx, opcodes)?;
219			}
220		}
221	}
222	Ok(())
223}
224
225/// This function searches `call` instructions and wrap each call
226/// with preamble and postamble.
227///
228/// Before:
229///
230/// ```text
231/// get_local 0
232/// get_local 1
233/// call 228
234/// drop
235/// ```
236///
237/// After:
238///
239/// ```text
240/// get_local 0
241/// get_local 1
242///
243/// < ... preamble ... >
244///
245/// call 228
246///
247/// < .. postamble ... >
248///
249/// drop
250/// ```
251fn instrument_function(
252	ctx: &mut Context,
253	instructions: &mut elements::Instructions,
254) -> Result<(), Error> {
255	use tetsy_wasm::elements::Instruction::*;
256
257	let mut cursor = 0;
258	loop {
259		if cursor >= instructions.elements().len() {
260			break;
261		}
262
263		enum Action {
264			InstrumentCall {
265				callee_idx: u32,
266				callee_stack_cost: u32,
267			},
268			Nop,
269		}
270
271		let action: Action = {
272			let instruction = &instructions.elements()[cursor];
273			match instruction {
274				Call(callee_idx) => {
275					let callee_stack_cost = ctx
276						.stack_cost(*callee_idx)
277						.ok_or_else(||
278							Error(
279								format!("Call to function that out-of-bounds: {}", callee_idx)
280							)
281						)?;
282
283					// Instrument only calls to a functions which stack_cost is
284					// non-zero.
285					if callee_stack_cost > 0 {
286						Action::InstrumentCall {
287							callee_idx: *callee_idx,
288							callee_stack_cost,
289						}
290					} else {
291						Action::Nop
292					}
293				},
294				_ => Action::Nop,
295			}
296		};
297
298		match action {
299			// We need to wrap a `call idx` instruction
300			// with a code that adjusts stack height counter
301			// and then restores it.
302			Action::InstrumentCall { callee_idx, callee_stack_cost } => {
303				let new_seq = instrument_call!(
304					callee_idx,
305					callee_stack_cost as i32,
306					ctx.stack_height_global_idx(),
307					ctx.stack_limit()
308				);
309
310				// Replace the original `call idx` instruction with
311				// a wrapped call sequence.
312				//
313				// To splice actually take a place, we need to consume iterator
314				// splice returns. So we just `count()` it.
315				let _ = instructions
316					.elements_mut()
317					.splice(cursor..(cursor + 1), new_seq.iter().cloned())
318					.count();
319
320				// Advance cursor to be after the inserted sequence.
321				cursor += new_seq.len();
322			}
323			// Do nothing for other instructions.
324			_ => {
325				cursor += 1;
326			}
327		}
328	}
329
330	Ok(())
331}
332
333fn resolve_func_type(
334	func_idx: u32,
335	module: &elements::Module,
336) -> Result<&elements::FunctionType, Error> {
337	let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]);
338	let functions = module
339		.function_section()
340		.map(|fs| fs.entries())
341		.unwrap_or(&[]);
342
343	let func_imports = module.import_count(elements::ImportCountType::Function);
344	let sig_idx = if func_idx < func_imports as u32 {
345		module
346			.import_section()
347			.expect("function import count is not zero; import section must exists; qed")
348			.entries()
349			.iter()
350			.filter_map(|entry| match entry.external() {
351				elements::External::Function(idx) => Some(*idx),
352				_ => None,
353			})
354			.nth(func_idx as usize)
355			.expect(
356				"func_idx is less than function imports count;
357				nth function import must be `Some`;
358				qed",
359			)
360	} else {
361		functions
362			.get(func_idx as usize - func_imports)
363			.ok_or_else(|| Error(format!("Function at index {} is not defined", func_idx)))?
364			.type_ref()
365	};
366	let Type::Function(ty) = types.get(sig_idx as usize).ok_or_else(|| {
367		Error(format!(
368			"Signature {} (specified by func {}) isn't defined",
369			sig_idx, func_idx
370		))
371	})?;
372	Ok(ty)
373}
374
#[cfg(test)]
mod tests {
	extern crate wabt;
	use tetsy_wasm::elements;
	use super::*;

	/// Compiles WAT source to a wasm binary with wabt and parses it into a module.
	fn parse_wat(source: &str) -> elements::Module {
		let wasm = wabt::wat2wasm(source).expect("Failed to wat2wasm");
		elements::deserialize_buffer(&wasm).expect("Failed to deserialize the module")
	}

	/// Serializes `module` and has wabt validate the resulting binary; panics if it is invalid.
	fn validate_module(module: elements::Module) {
		let binary = elements::serialize(module).expect("Failed to serialize");
		let wabt_module = wabt::Module::read_binary(&binary, &Default::default())
			.expect("Wabt failed to read final binary");
		wabt_module.validate().expect("Invalid module");
	}

	#[test]
	fn test_with_params_and_result() {
		let source = r#"
(module
	(func (export "i32.add") (param i32 i32) (result i32)
		get_local 0
	get_local 1
	i32.add
	)
)
"#;
		let instrumented = inject_limiter(parse_wat(source), 1024)
			.expect("Failed to inject stack counter");
		validate_module(instrumented);
	}
}