twasm_utils/stack_height/
mod.rs1use crate::std::string::String;
53use crate::std::vec::Vec;
54
55use tetsy_wasm::elements::{self, Type};
56use tetsy_wasm::builder;
57
/// Expands to the instruction sequence that replaces a single `call $callee_idx`.
///
/// The produced array of `Instruction`s:
/// 1. adds `$callee_stack_cost` to the stack height global;
/// 2. executes `unreachable` if the new height is (unsigned) greater than `$stack_limit`;
/// 3. performs the original `call`;
/// 4. subtracts `$callee_stack_cost` from the stack height global again.
///
/// Every argument is evaluated exactly once, and all wasm paths are resolved through
/// `$crate`, so the expansion does not depend on what the call site has imported.
macro_rules! instrument_call {
    ($callee_idx: expr, $callee_stack_cost: expr, $stack_height_global_idx: expr, $stack_limit: expr) => {{
        use $crate::tetsy_wasm::elements::Instruction::*;

        // Bind each argument once; they are referenced several times below.
        let callee_idx: u32 = $callee_idx;
        let cost: i32 = $callee_stack_cost;
        let global_idx: u32 = $stack_height_global_idx;
        let limit: i32 = $stack_limit as i32;

        [
            // stack_height += stack_cost(F)
            GetGlobal(global_idx),
            I32Const(cost),
            I32Add,
            SetGlobal(global_idx),
            // if stack_height > LIMIT: unreachable
            GetGlobal(global_idx),
            I32Const(limit),
            I32GtU,
            If($crate::tetsy_wasm::elements::BlockType::NoResult),
            Unreachable,
            End,
            // The original call.
            Call(callee_idx),
            // stack_height -= stack_cost(F)
            GetGlobal(global_idx),
            I32Const(cost),
            I32Sub,
            SetGlobal(global_idx),
        ]
    }};
}
85
86mod max_height;
87mod thunk;
88
/// Error that occurred while instrumenting a module for stack-height limiting.
///
/// Wraps a human-readable description of the problem. The messages produced in this
/// module (e.g. a missing function body or an out-of-range call target) indicate that
/// the input module is malformed in a way the instrumentation cannot handle.
#[derive(Debug)]
pub struct Error(String);

impl core::fmt::Display for Error {
    /// Formats the error as its underlying description.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(&self.0)
    }
}
94
/// State shared by the instrumentation passes of this module.
pub(crate) struct Context {
    /// Index of the mutable `i32` global that holds the current stack height.
    stack_height_global_idx: u32,
    /// Stack cost of every function, indexed by function index (imports included).
    func_stack_costs: Vec<u32>,
    /// Height above which instrumented calls execute `unreachable`.
    stack_limit: u32,
}

impl Context {
    /// Index of the global that tracks the stack height.
    fn stack_height_global_idx(&self) -> u32 {
        let Context { stack_height_global_idx, .. } = *self;
        stack_height_global_idx
    }

    /// Stack cost of the function at `func_idx`, or `None` if no cost was recorded
    /// for that index.
    fn stack_cost(&self, func_idx: u32) -> Option<u32> {
        let idx = func_idx as usize;
        self.func_stack_costs.get(idx).map(|&cost| cost)
    }

    /// Configured stack height limit.
    fn stack_limit(&self) -> u32 {
        let Context { stack_limit, .. } = *self;
        stack_limit
    }
}
117
118pub fn inject_limiter(
126 mut module: elements::Module,
127 stack_limit: u32,
128) -> Result<elements::Module, Error> {
129 let mut ctx = Context {
130 stack_height_global_idx: generate_stack_height_global(&mut module),
131 func_stack_costs: compute_stack_costs(&module)?,
132 stack_limit,
133 };
134
135 instrument_functions(&mut ctx, &mut module)?;
136 let module = thunk::generate_thunks(&mut ctx, module)?;
137
138 Ok(module)
139}
140
141fn generate_stack_height_global(module: &mut elements::Module) -> u32 {
143 let global_entry = builder::global()
144 .value_type()
145 .i32()
146 .mutable()
147 .init_expr(elements::Instruction::I32Const(0))
148 .build();
149
150 for section in module.sections_mut() {
152 if let elements::Section::Global(gs) = section {
153 gs.entries_mut().push(global_entry);
154 return (gs.entries().len() as u32) - 1;
155 }
156 }
157
158 module.sections_mut().push(elements::Section::Global(
160 elements::GlobalSection::with_entries(vec![global_entry]),
161 ));
162 0
163}
164
165fn compute_stack_costs(module: &elements::Module) -> Result<Vec<u32>, Error> {
169 let func_imports = module.import_count(elements::ImportCountType::Function);
170
171 (0..module.functions_space())
173 .map(|func_idx| {
174 if func_idx < func_imports {
175 Ok(0)
177 } else {
178 compute_stack_cost(func_idx as u32, &module)
179 }
180 })
181 .collect()
182}
183
184fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result<u32, Error> {
188 let func_imports = module.import_count(elements::ImportCountType::Function) as u32;
191 let defined_func_idx = func_idx.checked_sub(func_imports).ok_or_else(|| {
192 Error("This should be a index of a defined function".into())
193 })?;
194
195 let code_section = module.code_section().ok_or_else(|| {
196 Error("Due to validation code section should exists".into())
197 })?;
198 let body = &code_section
199 .bodies()
200 .get(defined_func_idx as usize)
201 .ok_or_else(|| Error("Function body is out of bounds".into()))?;
202 let locals_count = body.locals().len() as u32;
203
204 let max_stack_height =
205 max_height::compute(
206 defined_func_idx,
207 module
208 )?;
209
210 Ok(locals_count + max_stack_height)
211}
212
213fn instrument_functions(ctx: &mut Context, module: &mut elements::Module) -> Result<(), Error> {
214 for section in module.sections_mut() {
215 if let elements::Section::Code(code_section) = section {
216 for func_body in code_section.bodies_mut() {
217 let opcodes = func_body.code_mut();
218 instrument_function(ctx, opcodes)?;
219 }
220 }
221 }
222 Ok(())
223}
224
225fn instrument_function(
252 ctx: &mut Context,
253 instructions: &mut elements::Instructions,
254) -> Result<(), Error> {
255 use tetsy_wasm::elements::Instruction::*;
256
257 let mut cursor = 0;
258 loop {
259 if cursor >= instructions.elements().len() {
260 break;
261 }
262
263 enum Action {
264 InstrumentCall {
265 callee_idx: u32,
266 callee_stack_cost: u32,
267 },
268 Nop,
269 }
270
271 let action: Action = {
272 let instruction = &instructions.elements()[cursor];
273 match instruction {
274 Call(callee_idx) => {
275 let callee_stack_cost = ctx
276 .stack_cost(*callee_idx)
277 .ok_or_else(||
278 Error(
279 format!("Call to function that out-of-bounds: {}", callee_idx)
280 )
281 )?;
282
283 if callee_stack_cost > 0 {
286 Action::InstrumentCall {
287 callee_idx: *callee_idx,
288 callee_stack_cost,
289 }
290 } else {
291 Action::Nop
292 }
293 },
294 _ => Action::Nop,
295 }
296 };
297
298 match action {
299 Action::InstrumentCall { callee_idx, callee_stack_cost } => {
303 let new_seq = instrument_call!(
304 callee_idx,
305 callee_stack_cost as i32,
306 ctx.stack_height_global_idx(),
307 ctx.stack_limit()
308 );
309
310 let _ = instructions
316 .elements_mut()
317 .splice(cursor..(cursor + 1), new_seq.iter().cloned())
318 .count();
319
320 cursor += new_seq.len();
322 }
323 _ => {
325 cursor += 1;
326 }
327 }
328 }
329
330 Ok(())
331}
332
333fn resolve_func_type(
334 func_idx: u32,
335 module: &elements::Module,
336) -> Result<&elements::FunctionType, Error> {
337 let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]);
338 let functions = module
339 .function_section()
340 .map(|fs| fs.entries())
341 .unwrap_or(&[]);
342
343 let func_imports = module.import_count(elements::ImportCountType::Function);
344 let sig_idx = if func_idx < func_imports as u32 {
345 module
346 .import_section()
347 .expect("function import count is not zero; import section must exists; qed")
348 .entries()
349 .iter()
350 .filter_map(|entry| match entry.external() {
351 elements::External::Function(idx) => Some(*idx),
352 _ => None,
353 })
354 .nth(func_idx as usize)
355 .expect(
356 "func_idx is less than function imports count;
357 nth function import must be `Some`;
358 qed",
359 )
360 } else {
361 functions
362 .get(func_idx as usize - func_imports)
363 .ok_or_else(|| Error(format!("Function at index {} is not defined", func_idx)))?
364 .type_ref()
365 };
366 let Type::Function(ty) = types.get(sig_idx as usize).ok_or_else(|| {
367 Error(format!(
368 "Signature {} (specified by func {}) isn't defined",
369 sig_idx, func_idx
370 ))
371 })?;
372 Ok(ty)
373}
374
#[cfg(test)]
mod tests {
    extern crate wabt;
    use super::*;
    use tetsy_wasm::elements;

    /// Compiles WebAssembly text into a deserialized module, panicking on failure.
    fn parse_wat(source: &str) -> elements::Module {
        let wasm = wabt::wat2wasm(source).expect("Failed to wat2wasm");
        elements::deserialize_buffer(&wasm).expect("Failed to deserialize the module")
    }

    /// Serializes `module` and asserts that wabt reads and validates it successfully.
    fn validate_module(module: elements::Module) {
        let binary = elements::serialize(module).expect("Failed to serialize");
        let parsed = wabt::Module::read_binary(&binary, &Default::default())
            .expect("Wabt failed to read final binary");
        parsed.validate().expect("Invalid module");
    }

    #[test]
    fn test_with_params_and_result() {
        const WAT: &str = r#"
(module
    (func (export "i32.add") (param i32 i32) (result i32)
        get_local 0
        get_local 1
        i32.add
    )
)
"#;

        let instrumented =
            inject_limiter(parse_wat(WAT), 1024).expect("Failed to inject stack counter");
        validate_module(instrumented);
    }
}