use anyhow::{anyhow, bail, Error};
use std::cmp;
use std::collections::{BTreeMap, HashMap, HashSet};
use std::mem;
use walrus::ir::*;
use walrus::{ElementId, ExportId, ImportId, InstrLocId, TypeId};
use walrus::{FunctionId, GlobalId, InitExpr, Module, TableId, ValType};
/// Initial size of the `externref` table created by `Context::prepare`, and
/// the initial value of the externref stack-pointer global created in
/// `Context::run` (shims lower that pointer to reserve slots below it).
const DEFAULT_MIN: u32 = 32;
/// Collects externref transformation requests (for imports, exports and
/// function-table elements) and applies them to a module in `run`.
#[derive(Default)]
pub struct Context {
/// Signature changes requested through `import_xform`, keyed by import;
/// drained by `Transform::process_imports`.
imports: HashMap<ImportId, Function>,
/// Signature changes requested through `export_xform`, keyed by export;
/// drained by `Transform::process_exports`.
exports: HashMap<ExportId, Function>,
/// `(original table index, signature change)` pairs requested through
/// `table_element_xform`, in the order their shims are placed in the table.
new_elements: Vec<(u32, Function)>,
/// First index past the end of every active segment of the main function
/// table seen in `prepare`; shims for `new_elements` are placed from here.
new_element_offset: u32,
/// Active segments of the main function table, keyed by starting offset.
elements: BTreeMap<u32, ElementId>,
/// The `externref` table added to the module by `prepare`.
table: Option<TableId>,
}
/// Information about the transformed module returned by `Context::run`.
pub struct Meta {
/// The `externref` table created by `Context::prepare`.
pub table: TableId,
/// The function that was exported as `__externref_table_alloc`, if any.
/// That export is removed from the module by `Context::run`.
pub alloc: Option<FunctionId>,
/// The function that was exported as `__externref_drop_slice`, if any.
/// That export is removed from the module by `Context::run`.
pub drop_slice: Option<FunctionId>,
/// The function that was exported as `__externref_heap_live_count`, if any.
/// That export is removed from the module by `Context::run`.
pub live_count: Option<FunctionId>,
}
/// State for one run of the externref transformation over a module.
struct Transform<'a> {
/// The requests being applied; each is removed as it is processed.
cx: &'a mut Context,
/// Imported functions recognised by `find_intrinsics`, and which rewrite
/// their call sites receive.
intrinsic_map: HashMap<FunctionId, Intrinsic>,
/// Transformed import -> shim that existing callers are redirected to.
import_map: HashMap<FunctionId, FunctionId>,
/// Every shim generated by `append_shim`; their bodies are not rewritten.
shims: HashSet<FunctionId>,
/// The `externref` table created by `Context::prepare`.
table: TableId,
/// The generated `__wbindgen_object_clone_ref` function, present only when
/// `__externref_table_alloc` was exported.
clone_ref: Option<FunctionId>,
/// The function exported as `__externref_table_alloc`, if any.
heap_alloc: Option<FunctionId>,
/// The function exported as `__externref_table_dealloc`, if any.
heap_dealloc: Option<FunctionId>,
/// Mutable i32 global holding the table index where the stack of borrowed
/// externrefs currently starts; initialised to `DEFAULT_MIN` and lowered
/// by shims for the duration of a call.
stack_pointer: GlobalId,
}
/// Describes how a function's signature changes under the transformation.
struct Function {
/// Parameter index -> whether the externref passed there is owned (`true`)
/// or not owned (`false`). Parameters absent from the map keep their type.
args: HashMap<usize, bool>,
/// Whether the function's single `i32` result is exposed as an `externref`.
ret_externref: bool,
}
/// Imported functions whose call sites are replaced by `rewrite_calls`.
enum Intrinsic {
/// `__wbindgen_externref_table_grow`: replaced by `table.grow` on the
/// externref table with a null initial value.
TableGrow,
/// `__wbindgen_externref_table_set_null`: replaced by `table.set` of a null
/// reference.
TableSetNull,
/// `__wbindgen_object_drop_ref`: redirected to the table dealloc function.
DropRef,
/// `__wbindgen_object_clone_ref`: redirected to the generated clone shim.
CloneRef,
}
impl Context {
pub fn prepare(&mut self, module: &mut Module) -> Result<(), Error> {
if let Some(t) = module.tables.main_function_table()? {
let t = module.tables.get(t);
for id in t.elem_segments.iter() {
let elem = module.elements.get(*id);
let offset = match &elem.kind {
walrus::ElementKind::Active { offset, .. } => offset,
_ => continue,
};
let offset = match offset {
walrus::InitExpr::Value(Value::I32(n)) => *n as u32,
other => bail!("invalid offset for segment of function table {:?}", other),
};
let max = offset + elem.members.len() as u32;
self.new_element_offset = cmp::max(self.new_element_offset, max);
self.elements.insert(offset, *id);
}
}
self.table = Some(
module
.tables
.add_local(DEFAULT_MIN, None, ValType::Externref),
);
Ok(())
}
pub fn import_xform(
&mut self,
id: ImportId,
externref: &[(usize, bool)],
ret_externref: bool,
) -> &mut Self {
if let Some(f) = self.function(externref, ret_externref) {
self.imports.insert(id, f);
}
self
}
pub fn export_xform(
&mut self,
id: ExportId,
externref: &[(usize, bool)],
ret_externref: bool,
) -> &mut Self {
if let Some(f) = self.function(externref, ret_externref) {
self.exports.insert(id, f);
}
self
}
pub fn table_element_xform(
&mut self,
idx: u32,
externref: &[(usize, bool)],
ret_externref: bool,
) -> Option<u32> {
self.function(externref, ret_externref).map(|f| {
self.new_elements.push((idx, f));
self.new_elements.len() as u32 + self.new_element_offset - 1
})
}
fn function(&self, externref: &[(usize, bool)], ret_externref: bool) -> Option<Function> {
if !ret_externref && externref.len() == 0 {
return None;
}
Some(Function {
args: externref.iter().cloned().collect(),
ret_externref,
})
}
pub fn run(&mut self, module: &mut Module) -> Result<Meta, Error> {
let table = self.table.unwrap();
let init = InitExpr::Value(Value::I32(DEFAULT_MIN as i32));
let stack_pointer = module.globals.add_local(ValType::I32, true, init);
let mut heap_alloc = None;
let mut heap_dealloc = None;
let mut drop_slice = None;
let mut live_count = None;
let mut to_delete = Vec::new();
for export in module.exports.iter() {
let f = match export.item {
walrus::ExportItem::Function(f) => f,
_ => continue,
};
match export.name.as_str() {
"__externref_table_alloc" => heap_alloc = Some(f),
"__externref_table_dealloc" => heap_dealloc = Some(f),
"__externref_drop_slice" => drop_slice = Some(f),
"__externref_heap_live_count" => live_count = Some(f),
_ => continue,
}
to_delete.push(export.id());
}
for id in to_delete {
module.exports.delete(id);
}
let mut clone_ref = None;
if let Some(heap_alloc) = heap_alloc {
let mut builder =
walrus::FunctionBuilder::new(&mut module.types, &[ValType::I32], &[ValType::I32]);
let arg = module.locals.add(ValType::I32);
let local = module.locals.add(ValType::I32);
let mut body = builder.func_body();
body.call(heap_alloc)
.local_tee(local)
.local_get(arg)
.table_get(table)
.table_set(table)
.local_get(local);
let func = builder.finish(vec![arg], &mut module.funcs);
let name = "__wbindgen_object_clone_ref".to_string();
module.funcs.get_mut(func).name = Some(name);
clone_ref = Some(func);
}
Transform {
cx: self,
intrinsic_map: HashMap::new(),
import_map: HashMap::new(),
shims: HashSet::new(),
table,
clone_ref,
heap_alloc,
heap_dealloc,
stack_pointer,
}
.run(module)?;
Ok(Meta {
table,
alloc: heap_alloc,
drop_slice,
live_count,
})
}
}
impl Transform<'_> {
    /// Applies every transformation registered on the [`Context`] to `module`.
    ///
    /// Intrinsic imports are identified first. Shims are then generated for
    /// the registered imports, exports and table elements, and finally calls
    /// in the remaining local functions are rewritten. Each `process_*` step
    /// must consume all of its requests, which the assertions check.
    fn run(&mut self, module: &mut Module) -> Result<(), Error> {
        self.find_intrinsics(module)?;
        self.process_imports(module)?;
        assert!(self.cx.imports.is_empty());
        self.process_exports(module)?;
        assert!(self.cx.exports.is_empty());
        self.process_elements(module)?;
        assert!(self.cx.new_elements.is_empty());

        // When no shim was generated, no call is rewritten at all.
        // NOTE(review): this also leaves calls to intrinsics untouched even
        // though `find_intrinsics` already renamed their imports to
        // `*_unused` — confirm that combination cannot occur in practice.
        if self.shims.is_empty() {
            return Ok(());
        }
        self.rewrite_calls(module)
    }

    /// Records which imported functions are intrinsics handled by this pass
    /// and renames each recognised import to `<name>_unused`.
    ///
    /// Every function import from `__wbindgen_externref_xform__` must be a
    /// known table intrinsic. From `__wbindgen_placeholder__` only the object
    /// drop/clone imports are claimed; all other imports are left alone.
    ///
    /// # Errors
    ///
    /// Fails on an unrecognised import from `__wbindgen_externref_xform__`.
    fn find_intrinsics(&mut self, module: &mut Module) -> Result<(), Error> {
        for import in module.imports.iter_mut() {
            let f = match import.kind {
                walrus::ImportKind::Function(f) => f,
                _ => continue,
            };
            let intrinsic = if import.module == "__wbindgen_externref_xform__" {
                match import.name.as_str() {
                    "__wbindgen_externref_table_grow" => Intrinsic::TableGrow,
                    "__wbindgen_externref_table_set_null" => Intrinsic::TableSetNull,
                    n => bail!("unknown intrinsic: {}", n),
                }
            } else if import.module == "__wbindgen_placeholder__" {
                match import.name.as_str() {
                    "__wbindgen_object_drop_ref" => Intrinsic::DropRef,
                    "__wbindgen_object_clone_ref" => Intrinsic::CloneRef,
                    _ => continue,
                }
            } else {
                continue;
            };
            self.intrinsic_map.insert(f, intrinsic);
            import.name = format!("{}_unused", import.name);
        }
        Ok(())
    }

    /// Returns the function exported as `__externref_table_alloc`.
    fn heap_alloc(&self) -> Result<FunctionId, Error> {
        self.heap_alloc
            .ok_or_else(|| anyhow!("failed to find the `__externref_table_alloc` function"))
    }

    /// Returns the generated `__wbindgen_object_clone_ref` function.
    fn clone_ref(&self) -> Result<FunctionId, Error> {
        self.clone_ref
            .ok_or_else(|| anyhow!("failed to find intrinsics to enable `clone_ref` function"))
    }

    /// Returns the function exported as `__externref_table_dealloc`.
    fn heap_dealloc(&self) -> Result<FunctionId, Error> {
        self.heap_dealloc
            .ok_or_else(|| anyhow!("failed to find the `__externref_table_dealloc` function"))
    }

    /// Generates a shim for every function import that has a pending request.
    ///
    /// The import itself is retyped to the externref signature, and the shim
    /// (which keeps the original signature) is recorded in `import_map` so
    /// that `rewrite_calls` can redirect existing callers to it.
    fn process_imports(&mut self, module: &mut Module) -> Result<(), Error> {
        for import in module.imports.iter_mut() {
            let f = match import.kind {
                walrus::ImportKind::Function(f) => f,
                _ => continue,
            };
            let func = match self.cx.imports.remove(&import.id()) {
                Some(s) => s,
                None => continue,
            };
            let (shim, externref_ty) = self.append_shim(
                f,
                &import.name,
                func,
                &mut module.types,
                &mut module.funcs,
                &mut module.locals,
            )?;
            self.import_map.insert(f, shim);
            match &mut module.funcs.get_mut(f).kind {
                walrus::FunctionKind::Import(f) => f.ty = externref_ty,
                _ => unreachable!(),
            }
        }
        Ok(())
    }

    /// Generates a shim for every function export that has a pending request
    /// and points the export at that shim instead.
    fn process_exports(&mut self, module: &mut Module) -> Result<(), Error> {
        for export in module.exports.iter_mut() {
            let f = match export.item {
                walrus::ExportItem::Function(f) => f,
                _ => continue,
            };
            let function = match self.cx.exports.remove(&export.id()) {
                Some(s) => s,
                None => continue,
            };
            let (shim, _externref_ty) = self.append_shim(
                f,
                &export.name,
                function,
                &mut module.types,
                &mut module.funcs,
                &mut module.locals,
            )?;
            export.item = shim.into();
        }
        Ok(())
    }

    /// Generates shims for the requested function-table elements and places
    /// them in a new active segment starting at `new_element_offset`.
    ///
    /// The table's initial size (and its maximum, if it has one) is raised so
    /// that the new segment fits. Nothing is added when no element was
    /// requested.
    ///
    /// # Errors
    ///
    /// Fails if elements were requested but the module has no function table,
    /// if a requested index is not covered by any recorded active segment, or
    /// if that segment has no function at the index.
    fn process_elements(&mut self, module: &mut Module) -> Result<(), Error> {
        let new_elements = mem::take(&mut self.cx.new_elements);
        if new_elements.is_empty() {
            return Ok(());
        }
        let table = match module.tables.main_function_table()? {
            Some(t) => t,
            None => bail!(
                "{} function table element(s) need an externref shim but the module has no function table",
                new_elements.len()
            ),
        };
        let table = module.tables.get_mut(table);

        let mut new_segment = Vec::with_capacity(new_elements.len());
        for (idx, function) in new_elements {
            // The segment containing `idx` is the one with the greatest
            // starting offset that is `<= idx`.
            let (&offset, &orig_element) = self
                .cx
                .elements
                .range(..=idx)
                .next_back()
                .ok_or_else(|| anyhow!("failed to find segment defining index {}", idx))?;
            let target = module
                .elements
                .get(orig_element)
                .members
                .get((idx - offset) as usize)
                .copied()
                .flatten()
                .ok_or_else(|| {
                    anyhow!("function index {} not present in element segment", idx)
                })?;
            let (shim, _externref_ty) = self.append_shim(
                target,
                &format!("closure{}", idx),
                function,
                &mut module.types,
                &mut module.funcs,
                &mut module.locals,
            )?;
            new_segment.push(Some(shim));
        }

        let new_max = self.cx.new_element_offset + new_segment.len() as u32;
        table.initial = cmp::max(table.initial, new_max);
        if let Some(max) = table.maximum {
            table.maximum = Some(cmp::max(max, new_max));
        }
        let kind = walrus::ElementKind::Active {
            table: table.id(),
            offset: InitExpr::Value(Value::I32(self.cx.new_element_offset as i32)),
        };
        let segment = module.elements.add(kind, ValType::Funcref, new_segment);
        table.elem_segments.insert(segment);
        Ok(())
    }

    /// Generates a shim around `shim_target` named `"<name> externref shim"`.
    ///
    /// For a local target (an export or a table element) the shim has the
    /// externref signature. It converts each externref argument into an `i32`
    /// table index and then calls the target:
    /// - an owned argument goes into a slot from `heap_alloc`;
    /// - a non-owned argument goes into a slot reserved below `stack_pointer`,
    ///   which is nulled and released after the call.
    ///
    /// An externref result is read back from its slot, which is then
    /// deallocated.
    ///
    /// For an imported target the shim keeps the original `i32` signature and
    /// calls the import, which the caller retypes to the returned type:
    /// - each index is loaded from the table, and owned slots are deallocated;
    /// - an externref result is stored into a fresh slot and its index is
    ///   returned.
    ///
    /// Returns the shim and the type with the externref signature.
    ///
    /// # Errors
    ///
    /// Fails if a needed `heap_alloc`/`heap_dealloc` function is missing.
    ///
    /// # Panics
    ///
    /// Panics if `func.ret_externref` is set but the target does not return
    /// exactly one `i32`.
    fn append_shim(
        &mut self,
        shim_target: FunctionId,
        name: &str,
        mut func: Function,
        types: &mut walrus::ModuleTypes,
        funcs: &mut walrus::ModuleFunctions,
        locals: &mut walrus::ModuleLocals,
    ) -> Result<(FunctionId, TypeId), Error> {
        let (is_export, ty) = match &funcs.get(shim_target).kind {
            walrus::FunctionKind::Import(f) => (false, f.ty),
            walrus::FunctionKind::Local(f) => (true, f.ty()),
            _ => unreachable!(),
        };

        let target_ty = types.get(ty);
        let target_ty_params = target_ty.params().to_vec();
        let target_ty_results = target_ty.results().to_vec();

        // Decide up front how each parameter is converted.
        enum Convert {
            None,
            Store { owned: bool },
            Load { owned: bool },
        }
        let mut param_tys = Vec::new();
        let mut param_convert = Vec::new();
        let mut externref_stack = 0;
        for (i, old_ty) in target_ty_params.iter().enumerate() {
            let is_owned = func.args.remove(&i);
            let new_ty = is_owned
                .map(|_which| ValType::Externref)
                .unwrap_or(old_ty.clone());
            param_tys.push(new_ty.clone());
            if new_ty == *old_ty {
                param_convert.push(Convert::None);
            } else if is_export {
                // Calling into wasm: the externref must be put in the table.
                param_convert.push(Convert::Store {
                    owned: is_owned.unwrap(),
                });
                if is_owned == Some(false) {
                    externref_stack += 1;
                }
            } else {
                // Calling an import: the externref is fetched from the table.
                param_convert.push(Convert::Load {
                    owned: is_owned.unwrap(),
                });
            }
        }

        let new_ret = if func.ret_externref {
            assert_eq!(target_ty_results.as_slice(), &[ValType::I32]);
            vec![ValType::Externref]
        } else {
            target_ty_results.clone()
        };
        let externref_ty = types.add(&param_tys, &new_ret);

        // An export's shim is what gets exported, so it has the externref
        // signature; an import's shim is what existing code calls, so it
        // keeps the original signature.
        let (shim_params, shim_results) = if is_export {
            (&param_tys, &new_ret)
        } else {
            (&target_ty_params, &target_ty_results)
        };
        let mut builder = walrus::FunctionBuilder::new(types, shim_params, shim_results);
        let mut body = builder.func_body();

        let params = shim_params
            .iter()
            .cloned()
            .map(|ty| locals.add(ty))
            .collect::<Vec<_>>();
        let fp = locals.add(ValType::I32);
        let scratch_i32 = locals.add(ValType::I32);
        let scratch_externref = locals.add(ValType::Externref);

        // Reserve `externref_stack` slots: fp = sp = sp - externref_stack.
        if externref_stack > 0 {
            body.global_get(self.stack_pointer)
                .const_(Value::I32(externref_stack))
                .binop(BinaryOp::I32Sub)
                .local_tee(fp)
                .global_set(self.stack_pointer);
        }

        let mut next_stack_offset = 0;
        for (i, convert) in param_convert.iter().enumerate() {
            match *convert {
                Convert::None => {
                    body.local_get(params[i]);
                }
                // Push table[idx], then release slot `idx`.
                Convert::Load { owned: true } => {
                    body.local_get(params[i])
                        .table_get(self.table)
                        .local_get(params[i])
                        .call(self.heap_dealloc()?);
                }
                Convert::Load { owned: false } => {
                    body.local_get(params[i]).table_get(self.table);
                }
                // Allocate a slot, store the externref there and push its index.
                Convert::Store { owned: true } => {
                    body.call(self.heap_alloc()?)
                        .local_tee(scratch_i32)
                        .local_get(params[i])
                        .table_set(self.table)
                        .local_get(scratch_i32);
                }
                // Store into the next reserved stack slot and push its index.
                Convert::Store { owned: false } => {
                    body.local_get(fp);
                    let idx_local = if next_stack_offset == 0 {
                        fp
                    } else {
                        body.i32_const(next_stack_offset)
                            .binop(BinaryOp::I32Add)
                            .local_tee(scratch_i32);
                        scratch_i32
                    };
                    next_stack_offset += 1;
                    body.local_get(params[i])
                        .table_set(self.table)
                        .local_get(idx_local);
                }
            }
        }

        body.call(shim_target);

        if func.ret_externref {
            if is_export {
                // Convert the returned index into its externref and free it.
                body.local_tee(scratch_i32)
                    .table_get(self.table)
                    .local_get(scratch_i32)
                    .call(self.heap_dealloc()?);
            } else {
                // Store the returned externref in a new slot, return its index.
                body.local_set(scratch_externref)
                    .call(self.heap_alloc()?)
                    .local_tee(scratch_i32)
                    .local_get(scratch_externref)
                    .table_set(self.table)
                    .local_get(scratch_i32);
            }
        }

        // Null out the reserved slots and restore the stack pointer.
        if externref_stack > 0 {
            for i in 0..externref_stack {
                body.local_get(fp);
                if i > 0 {
                    body.i32_const(i).binop(BinaryOp::I32Add);
                }
                body.ref_null(ValType::Externref);
                body.table_set(self.table);
            }
            body.local_get(fp)
                .i32_const(externref_stack)
                .binop(BinaryOp::I32Add)
                .global_set(self.stack_pointer);
        }

        let id = builder.finish(params, funcs);
        let name = format!("{} externref shim", name);
        funcs.get_mut(id).name = Some(name);
        self.shims.insert(id);
        Ok((id, externref_ty))
    }

    /// Rewrites call sites in every local function that is not a shim.
    ///
    /// Calls to transformed imports are redirected to their shims, and calls
    /// to intrinsics are replaced by table instructions or by calls to the
    /// dealloc/clone functions.
    ///
    /// # Errors
    ///
    /// Fails if the clone or dealloc function is unavailable while there is
    /// a local function to rewrite.
    fn rewrite_calls(&mut self, module: &mut Module) -> Result<(), Error> {
        /// Instruction visitor that performs the call-site rewrites.
        struct Rewrite<'a, 'b> {
            xform: &'a Transform<'b>,
            clone_ref: FunctionId,
            heap_dealloc: FunctionId,
            /// Per-function scratch local used by the `table.grow` rewrite.
            scratch_i32: LocalId,
        }

        impl VisitorMut for Rewrite<'_, '_> {
            fn start_instr_seq_mut(&mut self, seq: &mut InstrSeq) {
                // Walk backwards so that instructions inserted at `i` do not
                // shift the positions still to be visited.
                for i in (0..seq.instrs.len()).rev() {
                    let call = match &mut seq.instrs[i].0 {
                        Instr::Call(call) => call,
                        _ => continue,
                    };
                    let intrinsic = match self.xform.intrinsic_map.get(&call.func) {
                        Some(f) => f,
                        None => {
                            if let Some(f) = self.xform.import_map.get(&call.func) {
                                call.func = *f;
                            }
                            continue;
                        }
                    };
                    let ty = ValType::Externref;
                    match intrinsic {
                        // [delta] -> local.set tmp; ref.null; local.get tmp; table.grow
                        Intrinsic::TableGrow => {
                            seq.instrs[i].0 = TableGrow {
                                table: self.xform.table,
                            }
                            .into();
                            let loc: InstrLocId = seq.instrs[i].1;
                            let local = self.scratch_i32;
                            seq.instrs.insert(i, (LocalGet { local }.into(), loc));
                            seq.instrs.insert(i, (RefNull { ty }.into(), loc));
                            seq.instrs.insert(i, (LocalSet { local }.into(), loc));
                        }
                        // [idx] -> ref.null; table.set
                        Intrinsic::TableSetNull => {
                            seq.instrs[i].0 = TableSet {
                                table: self.xform.table,
                            }
                            .into();
                            let loc: InstrLocId = seq.instrs[i].1;
                            seq.instrs.insert(i, (RefNull { ty }.into(), loc));
                        }
                        Intrinsic::DropRef => call.func = self.heap_dealloc,
                        Intrinsic::CloneRef => call.func = self.clone_ref,
                    }
                }
            }
        }

        for (id, func) in module.funcs.iter_local_mut() {
            if self.shims.contains(&id) {
                continue;
            }
            let entry = func.entry_block();
            let scratch_i32 = module.locals.add(ValType::I32);
            dfs_pre_order_mut(
                &mut Rewrite {
                    clone_ref: self.clone_ref()?,
                    heap_dealloc: self.heap_dealloc()?,
                    xform: self,
                    scratch_i32,
                },
                func,
                entry,
            );
        }
        Ok(())
    }
}