use std::cmp;
use std::collections::HashMap;
use std::mem;
use failure::{bail, format_err, Error};
use walrus::ir::Value;
use walrus::{DataId, FunctionId, InitExpr, LocalFunction, ValType};
use walrus::{ExportItem, GlobalId, GlobalKind, ImportKind, MemoryId, Module};
// Size of a single WebAssembly linear-memory page: 64 KiB.
const PAGE_SIZE: u32 = 1 << 16;
/// Configuration for the wasm threading transformation.
pub struct Config {
    // Upper bound, in bytes, used to fill in the memory's declared maximum
    // when the module doesn't already specify one (see `update_memory`).
    maximum_memory: u32,
    // Stack size, in bytes, carved out for each thread other than the
    // first at startup (see `start_with_init_memory`).
    thread_stack_size: u32,
}
impl Config {
    /// Returns the default configuration: memory capped at 1 GiB and
    /// 1 MiB of stack per spawned thread.
    pub fn new() -> Config {
        Config {
            maximum_memory: 1 << 30,
            thread_stack_size: 1 << 20,
        }
    }

    /// Overrides the maximum amount of memory, in bytes, the module's
    /// memory may grow to. Returns `&mut self` for chaining.
    pub fn maximum_memory(&mut self, max: u32) -> &mut Config {
        self.maximum_memory = max;
        self
    }

    /// Overrides the stack size, in bytes, allocated for each newly
    /// spawned thread. Returns `&mut self` for chaining.
    pub fn thread_stack_size(&mut self, size: u32) -> &mut Config {
        self.thread_stack_size = size;
        self
    }

    /// Applies the whole transformation to `module`: marks the memory
    /// shared/imported, converts data segments to passive, injects the
    /// per-thread globals and thread-id counter, synthesizes the start
    /// function, and rewrites the thread intrinsics.
    pub fn run(&self, module: &mut Module) -> Result<(), Error> {
        let memory = update_memory(module, self.maximum_memory)?;
        let segments = switch_data_segments_to_passive(module, memory)?;
        // NB: the stack pointer must be located *before* the two mutable
        // globals below are added — `find_stack_pointer` infers it by
        // counting mutable, locally-defined i32 globals.
        let stack_pointer = find_stack_pointer(module)?;

        let init_zero = InitExpr::Value(Value::I32(0));
        let globals = Globals {
            thread_id: module.globals.add_local(ValType::I32, true, init_zero),
            thread_tcb: module.globals.add_local(ValType::I32, true, init_zero),
        };
        let counter_addr = inject_thread_id_counter(module, memory)?;
        start_with_init_memory(
            module,
            &segments,
            &globals,
            counter_addr,
            stack_pointer,
            self.thread_stack_size,
            memory,
        );
        implement_thread_intrinsics(module, &globals)?;
        Ok(())
    }
}
/// Bookkeeping for a data segment that was converted from active to
/// passive so it can be copied into memory at runtime via `memory.init`.
struct PassiveSegment {
    // Id of the now-passive data segment in the module.
    id: DataId,
    // The original active segment's offset expression into linear memory.
    offset: InitExpr,
    // Length of the segment's contents, in bytes.
    len: u32,
}
/// Detaches every active data segment from `memory` and re-adds its
/// contents to the module as a passive segment, returning the bookkeeping
/// needed to replay the initialization at runtime.
fn switch_data_segments_to_passive(
    module: &mut Module,
    memory: MemoryId,
) -> Result<Vec<PassiveSegment>, Error> {
    // Take ownership of the memory's active segments, leaving it empty.
    let memory = module.memories.get_mut(memory);
    let data = mem::replace(&mut memory.data, Default::default());

    let ret = data
        .into_iter()
        .map(|(offset, value)| {
            let len = value.len() as u32;
            let id = module.data.add(value);
            PassiveSegment { id, offset, len }
        })
        .collect();
    Ok(ret)
}
/// Prepares the module's (single) memory for threading: ensures it is
/// imported from `env::memory` so the host can hand every thread the same
/// instance, marks it shared, and gives it a maximum size (required for
/// shared memories) derived from `max` if none was declared.
///
/// Returns the id of the memory, or an error if the module has zero or
/// more than one memory, or if `max` isn't page-aligned.
fn update_memory(module: &mut Module, max: u32) -> Result<MemoryId, Error> {
    // `max` is user-configurable (`Config::maximum_memory`), so report a
    // proper error instead of panicking on a bad value.
    if max % PAGE_SIZE != 0 {
        bail!("maximum memory must be a multiple of the page size ({} bytes)", PAGE_SIZE);
    }
    let mut memories = module.memories.iter_mut();
    let memory = memories
        .next()
        .ok_or_else(|| format_err!("currently incompatible with no memory modules"))?;
    if memories.next().is_some() {
        bail!("only one memory is currently supported");
    }

    // Make sure the memory is imported so the host can supply the shared
    // `WebAssembly.Memory` that all threads alias.
    if memory.import.is_none() {
        let id = module
            .imports
            .add("env", "memory", ImportKind::Memory(memory.id()));
        memory.import = Some(id);
    }

    // Flag the memory as shared; shared memories need a declared maximum,
    // so fill one in from the configured cap when absent.
    if !memory.shared {
        memory.shared = true;
        if memory.maximum.is_none() {
            memory.maximum = Some(max / PAGE_SIZE);
        }
    }
    Ok(memory.id())
}
/// Ids of the mutable i32 globals injected to carry per-thread state.
struct Globals {
    // Holds this thread's id; assigned from the injected atomic counter
    // in the synthesized start function.
    thread_id: GlobalId,
    // Backs the `__wbindgen_tcb_get`/`__wbindgen_tcb_set` intrinsics.
    // Presumably a thread-control-block pointer — inferred from the name.
    thread_tcb: GlobalId,
}
/// Reserves four bytes of linear memory at the (aligned) `__heap_base` to
/// act as a shared atomic counter from which each thread derives its id,
/// bumping `__heap_base` past the reservation so heap allocations don't
/// clobber it. Returns the counter's address.
fn inject_thread_id_counter(module: &mut Module, memory: MemoryId) -> Result<u32, Error> {
    // Locate the global exported as `__heap_base` (typically emitted by
    // the linker — TODO confirm which toolchains guarantee this export).
    let heap_base = module
        .exports
        .iter()
        .filter(|e| e.name == "__heap_base")
        .filter_map(|e| match e.item {
            ExportItem::Global(id) => Some(id),
            _ => None,
        })
        .next();
    let heap_base = match heap_base {
        Some(idx) => idx,
        None => bail!("failed to find `__heap_base` for injecting thread id"),
    };

    // Validate the global's shape, then carve out our 4 bytes and advance
    // the heap base past them.
    let (address, add_a_page) = {
        let global = module.globals.get_mut(heap_base);
        if global.ty != ValType::I32 {
            bail!("the `__heap_base` global doesn't have the type `i32`");
        }
        if global.mutable {
            bail!("the `__heap_base` global is unexpectedly mutable");
        }
        let offset = match &mut global.kind {
            GlobalKind::Local(InitExpr::Value(Value::I32(n))) => n,
            _ => bail!("`__heap_base` not a locally defined `i32`"),
        };
        // Round up to 4-byte alignment for the i32 atomic access.
        let address = (*offset as u32 + 3) & !3;
        // If the counter would cross into a new page relative to the heap
        // base, that page may not exist yet — remember to grow the memory.
        let add_a_page = (address + 4) / PAGE_SIZE != address / PAGE_SIZE;
        *offset = (address + 4) as i32;
        (address, add_a_page)
    };

    if add_a_page {
        let memory = module.memories.get_mut(memory);
        memory.initial += 1;
        // Keep the declared maximum at least as large as the new initial.
        memory.maximum = memory.maximum.map(|m| cmp::max(m, memory.initial));
    }
    Ok(address)
}
/// Infers the shadow stack pointer global: it is assumed to be the single
/// mutable, locally-defined `i32` global in the module. Returns `None`
/// when there is no such global, and errors when there is more than one
/// (the heuristic can't disambiguate).
fn find_stack_pointer(module: &mut Module) -> Result<Option<GlobalId>, Error> {
    let mut found = None;
    for global in module.globals.iter() {
        // Skip anything that can't be the stack pointer: wrong type,
        // immutable, or imported rather than locally defined.
        if global.ty != ValType::I32 || !global.mutable {
            continue;
        }
        if let GlobalKind::Import(_) = global.kind {
            continue;
        }
        if found.is_some() {
            bail!("too many mutable globals to infer the stack pointer");
        }
        found = Some(global.id());
    }
    Ok(found)
}
/// Synthesizes a start function that every thread runs on instantiation
/// and installs it as `module.start`, chaining to any previous start
/// function. The generated code:
///
/// 1. atomically increments the counter at `addr`, keeping the previous
///    value as this thread's id (published in `globals.thread_id`);
/// 2. if the id is nonzero (not the first thread), grows the memory by
///    `stack_size` bytes and points the stack pointer at the top of the
///    fresh region (presumably the shadow stack grows downward — the
///    pointer is set to base + stack_size);
/// 3. if the id is zero (first thread), replays every passive data
///    segment into memory via `memory.init`;
/// 4. drops the passive segments and calls the old start function, if any.
fn start_with_init_memory(
    module: &mut Module,
    segments: &[PassiveSegment],
    globals: &Globals,
    addr: u32,
    stack_pointer: Option<GlobalId>,
    stack_size: u32,
    memory: MemoryId,
) {
    use walrus::ir::*;
    // Stacks are allocated with `memory.grow`, which works in whole pages.
    assert!(stack_size % PAGE_SIZE == 0);
    let mut builder = walrus::FunctionBuilder::new();
    let mut exprs = Vec::new();

    // thread_id = atomic.rmw.add(addr, 1): fetch our id, bump the counter.
    let local = module.locals.add(ValType::I32);
    let addr = builder.i32_const(addr as i32);
    let one = builder.i32_const(1);
    let thread_id = builder.atomic_rmw(
        memory,
        AtomicOp::Add,
        AtomicWidth::I32,
        MemArg {
            align: 4,
            offset: 0,
        },
        addr,
        one,
    );
    // Stash the id in a scratch local (for the branch condition below) and
    // publish it through the `thread_id` global.
    let thread_id = builder.local_tee(local, thread_id);
    let global_set = builder.global_set(globals.thread_id, thread_id);
    exprs.push(global_set);
    let thread_id_is_nonzero = builder.local_get(local);

    // "then" arm — taken on every thread except the first: allocate a stack.
    let mut block = builder.if_else_block(Box::new([]), Box::new([]));
    if let Some(stack_pointer) = stack_pointer {
        // local = memory.grow(stack_size / PAGE_SIZE)  (old size in pages)
        let grow_amount = block.i32_const((stack_size / PAGE_SIZE) as i32);
        let memory_growth = block.memory_grow(memory, grow_amount);
        let set_local = block.local_set(local, memory_growth);
        block.expr(set_local);
        // `memory.grow` returns -1 on failure; trap in that case.
        let if_negative_trap = {
            let mut block = block.block(Box::new([]), Box::new([]));
            let lhs = block.local_get(local);
            let rhs = block.i32_const(-1);
            let condition = block.binop(BinaryOp::I32Ne, lhs, rhs);
            let id = block.id();
            // Branch past the `unreachable` when growth succeeded.
            let br_if = block.br_if(condition, id, Box::new([]));
            block.expr(br_if);
            let unreachable = block.unreachable();
            block.expr(unreachable);
            id
        };
        block.expr(if_negative_trap.into());
        // stack_pointer = old_page_count * PAGE_SIZE + stack_size,
        // i.e. the top of the region we just grew.
        let get_local = block.local_get(local);
        let page_size = block.i32_const(PAGE_SIZE as i32);
        let sp_base = block.binop(BinaryOp::I32Mul, get_local, page_size);
        let stack_size = block.i32_const(stack_size as i32);
        let sp = block.binop(BinaryOp::I32Add, sp_base, stack_size);
        let set_stack_pointer = block.global_set(stack_pointer, sp);
        block.expr(set_stack_pointer);
    }
    let if_nonzero_block = block.id();
    drop(block);

    // "else" arm — first thread only: copy each passive segment into
    // memory at its original offset.
    let if_zero_block = {
        let mut block = builder.if_else_block(Box::new([]), Box::new([]));
        for segment in segments {
            let zero = block.i32_const(0);
            let offset = match segment.offset {
                InitExpr::Global(id) => block.global_get(id),
                InitExpr::Value(v) => block.const_(v),
            };
            let len = block.i32_const(segment.len as i32);
            let init = block.memory_init(memory, segment.id, offset, zero, len);
            block.expr(init);
        }
        block.id()
    };
    let block = builder.if_else(thread_id_is_nonzero, if_nonzero_block, if_zero_block);
    exprs.push(block);

    // The passive segments are no longer needed in this instance.
    for segment in segments {
        exprs.push(builder.data_drop(segment.id));
    }

    // Chain to the module's previous start function, if it had one.
    if let Some(id) = module.start.take() {
        exprs.push(builder.call(id, Box::new([])));
    }
    let ty = module.types.add(&[], &[]);
    let id = builder.finish(ty, Vec::new(), exprs, module);
    module.start = Some(id);
}
/// Rewrites calls to the `__wbindgen_thread_xform__` intrinsic imports
/// (`__wbindgen_current_id`, `__wbindgen_tcb_get`, `__wbindgen_tcb_set`)
/// into direct reads/writes of the injected `thread_id` / `thread_tcb`
/// globals. Errors if an intrinsic has an unexpected name or signature.
///
/// NOTE(review): the now-unused intrinsic imports themselves are left in
/// the module here — presumably removed by a later GC pass; confirm.
fn implement_thread_intrinsics(module: &mut Module, globals: &Globals) -> Result<(), Error> {
    use walrus::ir::*;
    let mut map = HashMap::new();
    // Which intrinsic a given imported function id corresponds to.
    enum Intrinsic {
        GetThreadId,
        GetTcb,
        SetTcb,
    }

    // Validate each intrinsic's signature and record its function id.
    let imports = module
        .imports
        .iter()
        .filter(|i| i.module == "__wbindgen_thread_xform__");
    for import in imports {
        let function = match import.kind {
            ImportKind::Function(id) => module.funcs.get(id),
            _ => bail!("non-function import from special module"),
        };
        let ty = module.types.get(function.ty());
        match &import.name[..] {
            "__wbindgen_current_id" => {
                // () -> i32
                if !ty.params().is_empty() || ty.results() != &[ValType::I32] {
                    bail!("`__wbindgen_current_id` intrinsic has the wrong signature");
                }
                map.insert(function.id(), Intrinsic::GetThreadId);
            }
            "__wbindgen_tcb_get" => {
                // () -> i32
                if !ty.params().is_empty() || ty.results() != &[ValType::I32] {
                    bail!("`__wbindgen_tcb_get` intrinsic has the wrong signature");
                }
                map.insert(function.id(), Intrinsic::GetTcb);
            }
            "__wbindgen_tcb_set" => {
                // (i32) -> ()
                if !ty.results().is_empty() || ty.params() != &[ValType::I32] {
                    bail!("`__wbindgen_tcb_set` intrinsic has the wrong signature");
                }
                map.insert(function.id(), Intrinsic::SetTcb);
            }
            other => bail!("unknown thread intrinsic: {}", other),
        }
    }

    // Mutable IR visitor that swaps intrinsic calls for global accesses.
    struct Visitor<'a> {
        map: &'a HashMap<FunctionId, Intrinsic>,
        globals: &'a Globals,
        func: &'a mut LocalFunction,
    }

    // Walk every locally-defined function body from its entry block.
    module.funcs.iter_local_mut().for_each(|(_id, func)| {
        let mut entry = func.entry_block();
        Visitor {
            map: &map,
            globals,
            func,
        }
        .visit_block_id_mut(&mut entry);
    });

    impl VisitorMut for Visitor<'_> {
        fn local_function_mut(&mut self) -> &mut LocalFunction {
            self.func
        }
        fn visit_expr_mut(&mut self, expr: &mut Expr) {
            // Only `Call` expressions are interesting; recurse otherwise.
            let call = match expr {
                Expr::Call(e) => e,
                other => return other.visit_mut(self),
            };
            match self.map.get(&call.func) {
                // __wbindgen_current_id() => global.get $thread_id
                Some(Intrinsic::GetThreadId) => {
                    assert!(call.args.is_empty());
                    *expr = GlobalGet {
                        global: self.globals.thread_id,
                    }
                    .into();
                }
                // __wbindgen_tcb_get() => global.get $thread_tcb
                Some(Intrinsic::GetTcb) => {
                    assert!(call.args.is_empty());
                    *expr = GlobalGet {
                        global: self.globals.thread_tcb,
                    }
                    .into();
                }
                // __wbindgen_tcb_set(x) => global.set $thread_tcb, x
                Some(Intrinsic::SetTcb) => {
                    assert_eq!(call.args.len(), 1);
                    // Rewrite intrinsics inside the argument first; it is
                    // kept as the value of the new `global.set`.
                    call.args[0].visit_mut(self);
                    *expr = GlobalSet {
                        global: self.globals.thread_tcb,
                        value: call.args[0],
                    }
                    .into();
                }
                // Ordinary call: just recurse into it.
                None => call.visit_mut(self),
            }
        }
    }
    Ok(())
}