swasm_utils/stack_height/
mod.rs1use std::string::String;
53use std::vec::Vec;
54
55use swasm::elements::{self, Type};
56use swasm::builder;
57
/// Expands to an array of instructions meant to replace a single `call`.
///
/// The sequence:
/// 1. adds `$callee_stack_cost` to the stack-height global,
/// 2. traps (`unreachable`) if the new height is greater than `$stack_limit`
///    (compared as unsigned `i32`),
/// 3. performs the original `call $callee_idx`,
/// 4. subtracts `$callee_stack_cost` from the stack-height global again.
///
/// Note: `$stack_height_global_idx` is expanded four times and
/// `$callee_stack_cost` twice, so they should be side-effect free.
/// NOTE(review): the expansion names `elements::BlockType`, so the call site
/// must have `elements` in scope.
macro_rules! instrument_call {
    ($callee_idx: expr, $callee_stack_cost: expr, $stack_height_global_idx: expr, $stack_limit: expr) => {{
        use $crate::swasm::elements::Instruction::*;
        [
            // stack_height += callee_stack_cost
            GetGlobal($stack_height_global_idx),
            I32Const($callee_stack_cost),
            I32Add,
            SetGlobal($stack_height_global_idx),
            // if stack_height > stack_limit (unsigned): trap
            GetGlobal($stack_height_global_idx),
            I32Const($stack_limit as i32),
            I32GtU,
            If(elements::BlockType::NoResult),
            Unreachable,
            End,
            // The original call.
            Call($callee_idx),
            // stack_height -= callee_stack_cost
            GetGlobal($stack_height_global_idx),
            I32Const($callee_stack_cost),
            I32Sub,
            SetGlobal($stack_height_global_idx),
        ]
    }};
}
85
86mod max_height;
87mod thunk;
88
/// Error produced while instrumenting a module.
///
/// Wraps a human-readable description. The passes in this module raise it when
/// the module is inconsistent, e.g. a call to an out-of-bounds function index,
/// a missing code section, or a missing function body.
#[derive(Debug)]
pub struct Error(String);
94
/// State shared between the instrumentation passes.
///
/// The optional fields start as `None` and are filled in by the pass that
/// computes them. The accessors panic if a later pass reads a field before the
/// pass that fills it has run.
pub(crate) struct Context {
    /// Index of the injected stack-height global, as passed to
    /// `GetGlobal`/`SetGlobal`. Filled in by `generate_stack_height_global`.
    stack_height_global_idx: Option<u32>,
    /// Stack cost of every function, indexed by function index. Filled in by
    /// `compute_stack_costs`.
    func_stack_costs: Option<Vec<u32>>,
    /// Limit the stack-height counter may not exceed; instrumented calls trap
    /// when it would be exceeded.
    stack_limit: u32,
}

impl Context {
    /// Returns the index of the stack-height global.
    ///
    /// # Panics
    ///
    /// Panics if `generate_stack_height_global` has not run yet.
    fn stack_height_global_idx(&self) -> u32 {
        // The line continuation keeps source indentation out of the panic text.
        self.stack_height_global_idx.expect(
            "stack_height_global_idx isn't yet generated; \
             Did you call `generate_stack_height_global`?",
        )
    }

    /// Returns the stack cost of the function at `func_idx`, or `None` if
    /// `func_idx` is past the last known function.
    ///
    /// # Panics
    ///
    /// Panics if `compute_stack_costs` has not run yet.
    fn stack_cost(&self, func_idx: u32) -> Option<u32> {
        self.func_stack_costs
            .as_ref()
            .expect(
                "func_stack_costs isn't yet computed; \
                 Did you call `compute_stack_costs`?",
            )
            .get(func_idx as usize)
            .cloned()
    }

    /// Returns the configured stack limit.
    fn stack_limit(&self) -> u32 {
        self.stack_limit
    }
}
132
133pub fn inject_limiter(
141 mut module: elements::Module,
142 stack_limit: u32,
143) -> Result<elements::Module, Error> {
144 let mut ctx = Context {
145 stack_height_global_idx: None,
146 func_stack_costs: None,
147 stack_limit,
148 };
149
150 generate_stack_height_global(&mut ctx, &mut module);
151 compute_stack_costs(&mut ctx, &module)?;
152 instrument_functions(&mut ctx, &mut module)?;
153 let module = thunk::generate_thunks(&mut ctx, module)?;
154
155 Ok(module)
156}
157
158fn generate_stack_height_global(ctx: &mut Context, module: &mut elements::Module) {
160 let global_entry = builder::global()
161 .value_type()
162 .i32()
163 .mutable()
164 .init_expr(elements::Instruction::I32Const(0))
165 .build();
166
167 for section in module.sections_mut() {
169 if let elements::Section::Global(ref mut gs) = *section {
170 gs.entries_mut().push(global_entry);
171
172 let stack_height_global_idx = (gs.entries().len() as u32) - 1;
173 ctx.stack_height_global_idx = Some(stack_height_global_idx);
174 return;
175 }
176 }
177
178 module.sections_mut().push(elements::Section::Global(
180 elements::GlobalSection::with_entries(vec![global_entry]),
181 ));
182 ctx.stack_height_global_idx = Some(0);
183}
184
185fn compute_stack_costs(ctx: &mut Context, module: &elements::Module) -> Result<(), Error> {
189 let func_imports = module.import_count(elements::ImportCountType::Function);
190 let mut func_stack_costs = vec![0; module.functions_space()];
191 for (func_idx, func_stack_cost) in func_stack_costs.iter_mut().enumerate() {
193 if func_idx >= func_imports {
195 *func_stack_cost = compute_stack_cost(func_idx as u32, &module)?;
196 }
197 }
198
199 ctx.func_stack_costs = Some(func_stack_costs);
200 Ok(())
201}
202
203fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result<u32, Error> {
207 let func_imports = module.import_count(elements::ImportCountType::Function) as u32;
210 let defined_func_idx = func_idx.checked_sub(func_imports).ok_or_else(|| {
211 Error("This should be a index of a defined function".into())
212 })?;
213
214 let code_section = module.code_section().ok_or_else(|| {
215 Error("Due to validation code section should exists".into())
216 })?;
217 let body = &code_section
218 .bodies()
219 .get(defined_func_idx as usize)
220 .ok_or_else(|| Error("Function body is out of bounds".into()))?;
221 let locals_count = body.locals().len() as u32;
222
223 let max_stack_height =
224 max_height::compute(
225 defined_func_idx,
226 module
227 )?;
228
229 Ok(locals_count + max_stack_height)
230}
231
232fn instrument_functions(ctx: &mut Context, module: &mut elements::Module) -> Result<(), Error> {
233 for section in module.sections_mut() {
234 if let elements::Section::Code(ref mut code_section) = *section {
235 for func_body in code_section.bodies_mut() {
236 let mut opcodes = func_body.code_mut();
237 instrument_function(ctx, opcodes)?;
238 }
239 }
240 }
241 Ok(())
242}
243
244fn instrument_function(
271 ctx: &mut Context,
272 instructions: &mut elements::Instructions,
273) -> Result<(), Error> {
274 use swasm::elements::Instruction::*;
275
276 let mut cursor = 0;
277 loop {
278 if cursor >= instructions.elements().len() {
279 break;
280 }
281
282 enum Action {
283 InstrumentCall {
284 callee_idx: u32,
285 callee_stack_cost: u32,
286 },
287 Nop,
288 }
289
290 let action: Action = {
291 let instruction = &instructions.elements()[cursor];
292 match *instruction {
293 Call(ref callee_idx) => {
294 let callee_stack_cost = ctx
295 .stack_cost(*callee_idx)
296 .ok_or_else(||
297 Error(
298 format!("Call to function that out-of-bounds: {}", callee_idx)
299 )
300 )?;
301
302 if callee_stack_cost > 0 {
305 Action::InstrumentCall {
306 callee_idx: *callee_idx,
307 callee_stack_cost,
308 }
309 } else {
310 Action::Nop
311 }
312 },
313 _ => Action::Nop,
314 }
315 };
316
317 match action {
318 Action::InstrumentCall { callee_idx, callee_stack_cost } => {
322 let new_seq = instrument_call!(
323 callee_idx,
324 callee_stack_cost as i32,
325 ctx.stack_height_global_idx(),
326 ctx.stack_limit()
327 );
328
329 let _ = instructions
335 .elements_mut()
336 .splice(cursor..(cursor + 1), new_seq.iter().cloned())
337 .count();
338
339 cursor += new_seq.len();
341 }
342 _ => {
344 cursor += 1;
345 }
346 }
347 }
348
349 Ok(())
350}
351
352fn resolve_func_type(
353 func_idx: u32,
354 module: &elements::Module,
355) -> Result<&elements::FunctionType, Error> {
356 let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]);
357 let functions = module
358 .function_section()
359 .map(|fs| fs.entries())
360 .unwrap_or(&[]);
361
362 let func_imports = module.import_count(elements::ImportCountType::Function);
363 let sig_idx = if func_idx < func_imports as u32 {
364 module
365 .import_section()
366 .expect("function import count is not zero; import section must exists; qed")
367 .entries()
368 .iter()
369 .filter_map(|entry| match *entry.external() {
370 elements::External::Function(ref idx) => Some(*idx),
371 _ => None,
372 })
373 .nth(func_idx as usize)
374 .expect(
375 "func_idx is less than function imports count;
376 nth function import must be `Some`;
377 qed",
378 )
379 } else {
380 functions
381 .get(func_idx as usize - func_imports)
382 .ok_or_else(|| Error(format!("Function at index {} is not defined", func_idx)))?
383 .type_ref()
384 };
385 let Type::Function(ref ty) = *types.get(sig_idx as usize).ok_or_else(|| {
386 Error(format!(
387 "Signature {} (specified by func {}) isn't defined",
388 sig_idx, func_idx
389 ))
390 })?;
391 Ok(ty)
392}
393
#[cfg(test)]
mod tests {
    extern crate wabt;
    use swasm::elements;
    use super::*;

    /// Compiles WebAssembly text with wabt and deserializes the binary into an
    /// `elements::Module`.
    fn parse_wat(source: &str) -> elements::Module {
        let binary = wabt::wat2swasm(source).expect("Failed to wat2swasm");
        elements::deserialize_buffer(&binary).expect("Failed to deserialize the module")
    }

    /// Serializes `module` and asserts that wabt reads and validates it.
    fn validate_module(module: elements::Module) {
        let bytes = elements::serialize(module).expect("Failed to serialize");
        let parsed = wabt::Module::read_binary(&bytes, &Default::default())
            .expect("Wabt failed to read final binary");
        parsed.validate().expect("Invalid module");
    }

    #[test]
    fn test_with_params_and_result() {
        let wat = r#"
(module
  (func (export "i32.add") (param i32 i32) (result i32)
    get_local 0
    get_local 1
    i32.add
  )
)
"#;
        let instrumented =
            inject_limiter(parse_wat(wat), 1024).expect("Failed to inject stack counter");
        validate_module(instrumented);
    }
}