swasm_utils/stack_height/
mod.rs

1//! The pass that tries to make stack overflows deterministic, by introducing
2//! an upper bound of the stack size.
3//!
4//! This pass introduces a global mutable variable to track stack height,
5//! and instruments all calls with preamble and postamble.
6//!
//! Stack height is increased prior to the call. Otherwise, the check would
//! be made after the stack frame is allocated.
9//!
10//! The preamble is inserted before the call. It increments
11//! the global stack height variable with statically determined "stack cost"
12//! of the callee. If after the increment the stack height exceeds
13//! the limit (specified by the `rules`) then execution traps.
14//! Otherwise, the call is executed.
15//!
16//! The postamble is inserted after the call. The purpose of the postamble is to decrease
17//! the stack height by the "stack cost" of the callee function.
18//!
19//! Note, that we can't instrument all possible ways to return from the function. The simplest
20//! example would be a trap issued by the host function.
21//! That means stack height global won't be equal to zero upon the next execution after such trap.
22//!
23//! # Thunks
24//!
//! Because stack height is increased prior to the call, a few problems arise:
26//!
27//! - Stack height isn't increased upon an entry to the first function, i.e. exported function.
28//! - Start function is executed externally (similar to exported functions).
29//! - It is statically unknown what function will be invoked in an indirect call.
30//!
//! The solution to these problems is to generate intermediate functions, called 'thunks', which
//! increase the stack height before and decrease it after the call to the original function, and
//! then make exported functions, table entries and the start section point to the corresponding thunks.
34//!
35//! # Stack cost
36//!
//! Stack cost of the function is calculated as a sum of its locals
//! and the maximal height of the value stack.
39//!
40//! All values are treated equally, as they have the same size.
41//!
//! The rationale is that this makes it possible to use even a very naive swasm executor, that is, one where:
43//!
44//! - values are implemented by a union, so each value takes a size equal to
45//!   the size of the largest possible value type this union can hold. (In MVP it is 8 bytes)
46//! - each value from the value stack is placed on the native stack.
47//! - each local variable and function argument is placed on the native stack.
48//! - arguments pushed by the caller are copied into callee stack rather than shared
49//!   between the frames.
50//! - upon entry into the function entire stack frame is allocated.
51
52use std::string::String;
53use std::vec::Vec;
54
55use swasm::elements::{self, Type};
56use swasm::builder;
57
/// Expands to the instruction array that replaces a single `call`.
///
/// The array consists of a preamble (charge the callee's stack cost to the
/// stack height global, then hit `unreachable` if the limit is exceeded),
/// the call itself, and a postamble (refund the callee's stack cost).
///
/// Arguments, in order: callee function index, callee stack cost (fed to
/// `I32Const`), index of the stack height global, and the stack limit.
/// Each argument expression is pasted at every place it is used.
macro_rules! instrument_call {
	($callee: expr, $cost: expr, $height_global: expr, $limit: expr) => {{
		use $crate::swasm::elements::Instruction::*;
		[
			// Preamble: stack_height += stack_cost(callee)
			GetGlobal($height_global),
			I32Const($cost),
			I32Add,
			SetGlobal($height_global),
			// Preamble: if stack_height > limit (unsigned compare): unreachable
			GetGlobal($height_global),
			I32Const($limit as i32),
			I32GtU,
			If(elements::BlockType::NoResult),
			Unreachable,
			End,
			// The call being wrapped
			Call($callee),
			// Postamble: stack_height -= stack_cost(callee)
			GetGlobal($height_global),
			I32Const($cost),
			I32Sub,
			SetGlobal($height_global),
		]
	}};
}
85
86mod max_height;
87mod thunk;
88
/// Error that occurred during processing the module.
///
/// This means that the module is invalid. The wrapped string is a
/// human-readable description of what was wrong.
#[derive(Debug)]
pub struct Error(String);

impl ::std::fmt::Display for Error {
	/// Writes the wrapped description verbatim.
	fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
		f.write_str(&self.0)
	}
}

// Lets callers propagate this error through `Box<dyn std::error::Error>`
// and similar generic error plumbing.
impl ::std::error::Error for Error {}
94
/// State shared between the individual steps of the stack height
/// instrumentation.
pub(crate) struct Context {
	/// Index of the stack height global; `None` until it has been generated.
	stack_height_global_idx: Option<u32>,
	/// Stack cost of every function, indexed by function index;
	/// `None` until the costs have been computed.
	func_stack_costs: Option<Vec<u32>>,
	/// Upper bound for the stack height.
	stack_limit: u32,
}

impl Context {
	/// Returns index in a global index space of a stack_height global variable.
	///
	/// # Panics
	///
	/// Panics if the global hasn't been generated yet.
	fn stack_height_global_idx(&self) -> u32 {
		// The trailing `\` keeps the message on a single line: without it the
		// literal would embed the newline and the source indentation.
		self.stack_height_global_idx.expect(
			"stack_height_global_idx isn't yet generated; \
			 Did you call `generate_stack_height_global`?",
		)
	}

	/// Returns `stack_cost` for `func_idx`, or `None` if `func_idx` is greater
	/// than the last function index.
	///
	/// # Panics
	///
	/// Panics if stack costs haven't been computed yet.
	fn stack_cost(&self, func_idx: u32) -> Option<u32> {
		self.func_stack_costs
			.as_ref()
			.expect(
				"func_stack_costs isn't yet computed; \
				 Did you call `compute_stack_costs`?",
			)
			.get(func_idx as usize)
			.cloned()
	}

	/// Returns stack limit specified by the rules.
	fn stack_limit(&self) -> u32 {
		self.stack_limit
	}
}
132
133/// Instrument a module with stack height limiter.
134///
135/// See module-level documentation for more details.
136///
137/// # Errors
138///
139/// Returns `Err` if module is invalid and can't be
140pub fn inject_limiter(
141	mut module: elements::Module,
142	stack_limit: u32,
143) -> Result<elements::Module, Error> {
144	let mut ctx = Context {
145		stack_height_global_idx: None,
146		func_stack_costs: None,
147		stack_limit,
148	};
149
150	generate_stack_height_global(&mut ctx, &mut module);
151	compute_stack_costs(&mut ctx, &module)?;
152	instrument_functions(&mut ctx, &mut module)?;
153	let module = thunk::generate_thunks(&mut ctx, module)?;
154
155	Ok(module)
156}
157
158/// Generate a new global that will be used for tracking current stack height.
159fn generate_stack_height_global(ctx: &mut Context, module: &mut elements::Module) {
160	let global_entry = builder::global()
161		.value_type()
162		.i32()
163		.mutable()
164		.init_expr(elements::Instruction::I32Const(0))
165		.build();
166
167	// Try to find an existing global section.
168	for section in module.sections_mut() {
169		if let elements::Section::Global(ref mut gs) = *section {
170			gs.entries_mut().push(global_entry);
171
172			let stack_height_global_idx = (gs.entries().len() as u32) - 1;
173			ctx.stack_height_global_idx = Some(stack_height_global_idx);
174			return;
175		}
176	}
177
178	// Existing section not found, create one!
179	module.sections_mut().push(elements::Section::Global(
180		elements::GlobalSection::with_entries(vec![global_entry]),
181	));
182	ctx.stack_height_global_idx = Some(0);
183}
184
185/// Calculate stack costs for all functions.
186///
187/// Returns a vector with a stack cost for each function, including imports.
188fn compute_stack_costs(ctx: &mut Context, module: &elements::Module) -> Result<(), Error> {
189	let func_imports = module.import_count(elements::ImportCountType::Function);
190	let mut func_stack_costs = vec![0; module.functions_space()];
191	// TODO: optimize!
192	for (func_idx, func_stack_cost) in func_stack_costs.iter_mut().enumerate() {
193		// We can't calculate stack_cost of the import functions.
194		if func_idx >= func_imports {
195			*func_stack_cost = compute_stack_cost(func_idx as u32, &module)?;
196		}
197	}
198
199	ctx.func_stack_costs = Some(func_stack_costs);
200	Ok(())
201}
202
203/// Stack cost of the given *defined* function is the sum of it's locals count (that is,
204/// number of arguments plus number of local variables) and the maximal stack
205/// height.
206fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result<u32, Error> {
207	// To calculate the cost of a function we need to convert index from
208	// function index space to defined function spaces.
209	let func_imports = module.import_count(elements::ImportCountType::Function) as u32;
210	let defined_func_idx = func_idx.checked_sub(func_imports).ok_or_else(|| {
211		Error("This should be a index of a defined function".into())
212	})?;
213
214	let code_section = module.code_section().ok_or_else(|| {
215		Error("Due to validation code section should exists".into())
216	})?;
217	let body = &code_section
218		.bodies()
219		.get(defined_func_idx as usize)
220		.ok_or_else(|| Error("Function body is out of bounds".into()))?;
221	let locals_count = body.locals().len() as u32;
222
223	let max_stack_height =
224		max_height::compute(
225			defined_func_idx,
226			module
227		)?;
228
229	Ok(locals_count + max_stack_height)
230}
231
232fn instrument_functions(ctx: &mut Context, module: &mut elements::Module) -> Result<(), Error> {
233	for section in module.sections_mut() {
234		if let elements::Section::Code(ref mut code_section) = *section {
235			for func_body in code_section.bodies_mut() {
236				let mut opcodes = func_body.code_mut();
237				instrument_function(ctx, opcodes)?;
238			}
239		}
240	}
241	Ok(())
242}
243
244/// This function searches `call` instructions and wrap each call
245/// with preamble and postamble.
246///
247/// Before:
248///
249/// ```text
250/// get_local 0
251/// get_local 1
252/// call 228
253/// drop
254/// ```
255///
256/// After:
257///
258/// ```text
259/// get_local 0
260/// get_local 1
261///
262/// < ... preamble ... >
263///
264/// call 228
265///
266/// < .. postamble ... >
267///
268/// drop
269/// ```
270fn instrument_function(
271	ctx: &mut Context,
272	instructions: &mut elements::Instructions,
273) -> Result<(), Error> {
274	use swasm::elements::Instruction::*;
275
276	let mut cursor = 0;
277	loop {
278		if cursor >= instructions.elements().len() {
279			break;
280		}
281
282		enum Action {
283			InstrumentCall {
284				callee_idx: u32,
285				callee_stack_cost: u32,
286			},
287			Nop,
288		}
289
290		let action: Action = {
291			let instruction = &instructions.elements()[cursor];
292			match *instruction {
293				Call(ref callee_idx) => {
294					let callee_stack_cost = ctx
295						.stack_cost(*callee_idx)
296						.ok_or_else(||
297							Error(
298								format!("Call to function that out-of-bounds: {}", callee_idx)
299							)
300						)?;
301
302					// Instrument only calls to a functions which stack_cost is
303					// non-zero.
304					if callee_stack_cost > 0 {
305						Action::InstrumentCall {
306							callee_idx: *callee_idx,
307							callee_stack_cost,
308						}
309					} else {
310						Action::Nop
311					}
312				},
313				_ => Action::Nop,
314			}
315		};
316
317		match action {
318			// We need to wrap a `call idx` instruction
319			// with a code that adjusts stack height counter
320			// and then restores it.
321			Action::InstrumentCall { callee_idx, callee_stack_cost } => {
322				let new_seq = instrument_call!(
323					callee_idx,
324					callee_stack_cost as i32,
325					ctx.stack_height_global_idx(),
326					ctx.stack_limit()
327				);
328
329				// Replace the original `call idx` instruction with
330				// a wrapped call sequence.
331				//
332				// To splice actually take a place, we need to consume iterator
333				// splice returns. So we just `count()` it.
334				let _ = instructions
335					.elements_mut()
336					.splice(cursor..(cursor + 1), new_seq.iter().cloned())
337					.count();
338
339				// Advance cursor to be after the inserted sequence.
340				cursor += new_seq.len();
341			}
342			// Do nothing for other instructions.
343			_ => {
344				cursor += 1;
345			}
346		}
347	}
348
349	Ok(())
350}
351
352fn resolve_func_type(
353	func_idx: u32,
354	module: &elements::Module,
355) -> Result<&elements::FunctionType, Error> {
356	let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]);
357	let functions = module
358		.function_section()
359		.map(|fs| fs.entries())
360		.unwrap_or(&[]);
361
362	let func_imports = module.import_count(elements::ImportCountType::Function);
363	let sig_idx = if func_idx < func_imports as u32 {
364		module
365			.import_section()
366			.expect("function import count is not zero; import section must exists; qed")
367			.entries()
368			.iter()
369			.filter_map(|entry| match *entry.external() {
370				elements::External::Function(ref idx) => Some(*idx),
371				_ => None,
372			})
373			.nth(func_idx as usize)
374			.expect(
375				"func_idx is less than function imports count;
376				nth function import must be `Some`;
377				qed",
378			)
379	} else {
380		functions
381			.get(func_idx as usize - func_imports)
382			.ok_or_else(|| Error(format!("Function at index {} is not defined", func_idx)))?
383			.type_ref()
384	};
385	let Type::Function(ref ty) = *types.get(sig_idx as usize).ok_or_else(|| {
386		Error(format!(
387			"Signature {} (specified by func {}) isn't defined",
388			sig_idx, func_idx
389		))
390	})?;
391	Ok(ty)
392}
393
#[cfg(test)]
mod tests {
	extern crate wabt;
	use swasm::elements;
	use super::*;

	/// Turns WAT text into a binary and deserializes that into a module.
	fn parse_wat(source: &str) -> elements::Module {
		let binary = wabt::wat2swasm(source).expect("Failed to wat2swasm");
		elements::deserialize_buffer(&binary).expect("Failed to deserialize the module")
	}

	/// Serializes the module and asks wabt to read and validate the result.
	fn validate_module(module: elements::Module) {
		let binary = elements::serialize(module).expect("Failed to serialize");
		wabt::Module::read_binary(&binary, &Default::default())
			.expect("Wabt failed to read final binary")
			.validate()
			.expect("Invalid module");
	}

	/// An exported function with parameters and a result must still be a
	/// valid module after the limiter is injected.
	#[test]
	fn test_with_params_and_result() {
		let source = r#"
(module
  (func (export "i32.add") (param i32 i32) (result i32)
    get_local 0
	get_local 1
	i32.add
  )
)
"#;

		let instrumented = inject_limiter(parse_wat(source), 1024)
			.expect("Failed to inject stack counter");
		validate_module(instrumented);
	}
}