use std::convert::TryInto;
use std::fmt;
use std::sync::Mutex;
use wasmer::wasmparser::{Operator, Type as WpType, TypeOrFuncType as WpTypeOrFuncType};
use wasmer::{
ExportIndex, FunctionMiddleware, GlobalInit, GlobalType, Instance, LocalFunctionIndex,
MiddlewareError, MiddlewareReaderState, ModuleMiddleware, Mutability, Type,
};
use wasmer_types::GlobalIndex;
use wasmer_vm::ModuleInfo;
#[derive(Clone)]
struct MeteringGlobalIndexes(GlobalIndex, GlobalIndex);
impl MeteringGlobalIndexes {
fn remaining_points(&self) -> GlobalIndex {
self.0
}
fn points_exhausted(&self) -> GlobalIndex {
self.1
}
}
impl fmt::Debug for MeteringGlobalIndexes {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("MeteringGlobalIndexes")
.field("remaining_points", &self.remaining_points())
.field("points_exhausted", &self.points_exhausted())
.finish()
}
}
pub struct Metering<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> {
initial_limit: u64,
cost_function: F,
global_indexes: Mutex<Option<MeteringGlobalIndexes>>,
}
pub struct FunctionMetering<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> {
cost_function: F,
global_indexes: MeteringGlobalIndexes,
accumulated_cost: u64,
}
#[derive(Debug, PartialEq)]
pub enum MeteringPoints {
Remaining(u64),
Exhausted,
}
impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> Metering<F> {
pub fn new(initial_limit: u64, cost_function: F) -> Self {
Self {
initial_limit,
cost_function,
global_indexes: Mutex::new(None),
}
}
}
impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> fmt::Debug for Metering<F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Metering")
.field("initial_limit", &self.initial_limit)
.field("cost_function", &"<function>")
.field("global_indexes", &self.global_indexes)
.finish()
}
}
impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync + 'static> ModuleMiddleware
for Metering<F>
{
fn generate_function_middleware(&self, _: LocalFunctionIndex) -> Box<dyn FunctionMiddleware> {
Box::new(FunctionMetering {
cost_function: self.cost_function,
global_indexes: self.global_indexes.lock().unwrap().clone().unwrap(),
accumulated_cost: 0,
})
}
fn transform_module_info(&self, module_info: &mut ModuleInfo) {
let mut global_indexes = self.global_indexes.lock().unwrap();
if global_indexes.is_some() {
panic!("Metering::transform_module_info: Attempting to use a `Metering` middleware from multiple modules.");
}
let remaining_points_global_index = module_info
.globals
.push(GlobalType::new(Type::I64, Mutability::Var));
module_info
.global_initializers
.push(GlobalInit::I64Const(self.initial_limit as i64));
module_info.exports.insert(
"wasmer_metering_remaining_points".to_string(),
ExportIndex::Global(remaining_points_global_index),
);
let points_exhausted_global_index = module_info
.globals
.push(GlobalType::new(Type::I32, Mutability::Var));
module_info
.global_initializers
.push(GlobalInit::I32Const(0));
module_info.exports.insert(
"wasmer_metering_points_exhausted".to_string(),
ExportIndex::Global(points_exhausted_global_index),
);
*global_indexes = Some(MeteringGlobalIndexes(
remaining_points_global_index,
points_exhausted_global_index,
))
}
}
impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> fmt::Debug for FunctionMetering<F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FunctionMetering")
.field("cost_function", &"<function>")
.field("global_indexes", &self.global_indexes)
.finish()
}
}
impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> FunctionMiddleware
for FunctionMetering<F>
{
fn feed<'a>(
&mut self,
operator: Operator<'a>,
state: &mut MiddlewareReaderState<'a>,
) -> Result<(), MiddlewareError> {
self.accumulated_cost += (self.cost_function)(&operator);
match operator {
Operator::Loop { .. }
| Operator::End
| Operator::Else
| Operator::Br { .. }
| Operator::BrTable { .. }
| Operator::BrIf { .. }
| Operator::Call { .. }
| Operator::CallIndirect { .. }
| Operator::Return
=> {
if self.accumulated_cost > 0 {
state.extend(&[
Operator::GlobalGet { global_index: self.global_indexes.remaining_points().as_u32() },
Operator::I64Const { value: self.accumulated_cost as i64 },
Operator::I64LtU,
Operator::If { ty: WpTypeOrFuncType::Type(WpType::EmptyBlockType) },
Operator::I32Const { value: 1 },
Operator::GlobalSet { global_index: self.global_indexes.points_exhausted().as_u32() },
Operator::Unreachable,
Operator::End,
Operator::GlobalGet { global_index: self.global_indexes.remaining_points().as_u32() },
Operator::I64Const { value: self.accumulated_cost as i64 },
Operator::I64Sub,
Operator::GlobalSet { global_index: self.global_indexes.remaining_points().as_u32() },
]);
self.accumulated_cost = 0;
}
}
_ => {}
}
state.push_operator(operator);
Ok(())
}
}
pub fn get_remaining_points(instance: &Instance) -> MeteringPoints {
let exhausted: i32 = instance
.exports
.get_global("wasmer_metering_points_exhausted")
.expect("Can't get `wasmer_metering_points_exhausted` from Instance")
.get()
.try_into()
.expect("`wasmer_metering_points_exhausted` from Instance has wrong type");
if exhausted > 0 {
return MeteringPoints::Exhausted;
}
let points = instance
.exports
.get_global("wasmer_metering_remaining_points")
.expect("Can't get `wasmer_metering_remaining_points` from Instance")
.get()
.try_into()
.expect("`wasmer_metering_remaining_points` from Instance has wrong type");
MeteringPoints::Remaining(points)
}
pub fn set_remaining_points(instance: &Instance, points: u64) {
instance
.exports
.get_global("wasmer_metering_remaining_points")
.expect("Can't get `wasmer_metering_remaining_points` from Instance")
.set(points.into())
.expect("Can't set `wasmer_metering_remaining_points` in Instance");
instance
.exports
.get_global("wasmer_metering_points_exhausted")
.expect("Can't get `wasmer_metering_points_exhausted` from Instance")
.set(0i32.into())
.expect("Can't set `wasmer_metering_points_exhausted` in Instance");
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use wasmer::{imports, wat2wasm, CompilerConfig, Cranelift, Module, Store, JIT};
fn cost_function(operator: &Operator) -> u64 {
match operator {
Operator::LocalGet { .. } | Operator::I32Const { .. } => 1,
Operator::I32Add { .. } => 2,
_ => 0,
}
}
fn bytecode() -> Vec<u8> {
wat2wasm(
br#"
(module
(type $add_t (func (param i32) (result i32)))
(func $add_one_f (type $add_t) (param $value i32) (result i32)
local.get $value
i32.const 1
i32.add)
(export "add_one" (func $add_one_f)))
"#,
)
.unwrap()
.into()
}
#[test]
fn get_remaining_points_works() {
let metering = Arc::new(Metering::new(10, cost_function));
let mut compiler_config = Cranelift::default();
compiler_config.push_middleware(metering.clone());
let store = Store::new(&JIT::new(compiler_config).engine());
let module = Module::new(&store, bytecode()).unwrap();
let instance = Instance::new(&module, &imports! {}).unwrap();
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(10)
);
let add_one = instance
.exports
.get_function("add_one")
.unwrap()
.native::<i32, i32>()
.unwrap();
add_one.call(1).unwrap();
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(6)
);
add_one.call(1).unwrap();
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(2)
);
assert!(add_one.call(1).is_err());
assert_eq!(get_remaining_points(&instance), MeteringPoints::Exhausted);
}
#[test]
fn set_remaining_points_works() {
let metering = Arc::new(Metering::new(10, cost_function));
let mut compiler_config = Cranelift::default();
compiler_config.push_middleware(metering.clone());
let store = Store::new(&JIT::new(compiler_config).engine());
let module = Module::new(&store, bytecode()).unwrap();
let instance = Instance::new(&module, &imports! {}).unwrap();
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(10)
);
let add_one = instance
.exports
.get_function("add_one")
.unwrap()
.native::<i32, i32>()
.unwrap();
set_remaining_points(&instance, 12);
add_one.call(1).unwrap();
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(8)
);
add_one.call(1).unwrap();
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(4)
);
add_one.call(1).unwrap();
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(0)
);
assert!(add_one.call(1).is_err());
assert_eq!(get_remaining_points(&instance), MeteringPoints::Exhausted);
set_remaining_points(&instance, 4);
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(4)
);
}
}