use std::cmp;
use std::collections::HashMap;
use std::env;
use std::mem;
use anyhow::{anyhow, bail, Error};
use walrus::ir::Value;
use walrus::{DataId, FunctionId, InitExpr, ValType};
use walrus::{ExportItem, GlobalId, GlobalKind, ImportKind, MemoryId, Module};
use wasm_bindgen_wasm_conventions as wasm_conventions;
/// Size in bytes of one WebAssembly linear-memory page (64 KiB).
const PAGE_SIZE: u32 = 64 * 1024;
/// Configuration for the threads transformation pass.
///
/// Construct with [`Config::new`], adjust with the builder-style setters, then
/// apply with [`Config::run`].
#[derive(Debug, Clone)]
pub struct Config {
    /// Maximum memory size, in bytes, requested for the shared memory.
    maximum_memory: u32,
    /// Size, in bytes, of the stack obtained for each thread other than the
    /// first.
    thread_stack_size: u32,
    /// Forces the pass on even when the module's memory is not already shared.
    enabled: bool,
}
impl Config {
pub fn new() -> Config {
Config {
maximum_memory: 1 << 30,
thread_stack_size: 1 << 20,
enabled: env::var("WASM_BINDGEN_THREADS").is_ok(),
}
}
pub fn is_enabled(&self, module: &Module) -> bool {
if self.enabled {
return true;
}
match wasm_conventions::get_memory(module) {
Ok(memory) => module.memories.get(memory).shared,
Err(_) => false,
}
}
pub fn maximum_memory(&mut self, max: u32) -> &mut Config {
self.maximum_memory = max;
self
}
pub fn thread_stack_size(&mut self, size: u32) -> &mut Config {
self.thread_stack_size = size;
self
}
pub fn run(&self, module: &mut Module) -> Result<(), Error> {
if !self.is_enabled(module) {
return Ok(());
}
let memory = wasm_conventions::get_memory(module)?;
let stack_pointer = wasm_conventions::get_shadow_stack_pointer(module)
.ok_or_else(|| anyhow!("failed to find shadow stack pointer"))?;
let addr = allocate_static_data(module, memory, 4, 4)?;
let zero = InitExpr::Value(Value::I32(0));
let globals = Globals {
thread_id: module.globals.add_local(ValType::I32, true, zero),
thread_tcb: module.globals.add_local(ValType::I32, true, zero),
};
let mem = module.memories.get_mut(memory);
let memory_init = if mem.shared {
let prev_max = mem.maximum.unwrap();
assert!(mem.import.is_some());
mem.maximum = Some(cmp::max(self.maximum_memory / PAGE_SIZE, prev_max));
assert!(mem.data_segments.is_empty());
InitMemory::Call {
wasm_init_memory: delete_synthetic_func(module, "__wasm_init_memory")?,
wasm_init_tls: delete_synthetic_func(module, "__wasm_init_tls")?,
tls_size: delete_synthetic_global(module, "__tls_size")?,
}
} else {
update_memory(module, memory, self.maximum_memory)?;
InitMemory::Segments(switch_data_segments_to_passive(module, memory)?)
};
inject_start(
module,
memory_init,
&globals,
addr,
stack_pointer,
self.thread_stack_size,
memory,
)?;
implement_thread_intrinsics(module, &globals)?;
Ok(())
}
}
fn delete_synthetic_func(module: &mut Module, name: &str) -> Result<FunctionId, Error> {
match delete_synthetic_export(module, name)? {
walrus::ExportItem::Function(f) => Ok(f),
_ => bail!("`{}` must be a function", name),
}
}
fn delete_synthetic_global(module: &mut Module, name: &str) -> Result<u32, Error> {
let id = match delete_synthetic_export(module, name)? {
walrus::ExportItem::Global(g) => g,
_ => bail!("`{}` must be a global", name),
};
let g = match module.globals.get(id).kind {
walrus::GlobalKind::Local(g) => g,
walrus::GlobalKind::Import(_) => bail!("`{}` must not be an imported global", name),
};
match g {
InitExpr::Value(Value::I32(v)) => Ok(v as u32),
_ => bail!("`{}` was not an `i32` constant", name),
}
}
fn delete_synthetic_export(module: &mut Module, name: &str) -> Result<ExportItem, Error> {
let item = module
.exports
.iter()
.find(|e| e.name == name)
.ok_or_else(|| anyhow!("failed to find `{}`", name))?;
let ret = item.item;
let id = item.id();
module.exports.delete(id);
Ok(ret)
}
/// A data segment that was switched from active to passive, plus what is
/// needed to copy it into memory later with `memory.init`.
struct PassiveSegment {
    /// Number of bytes in the segment.
    len: u32,
    /// Where the segment was originally placed: a constant address or the
    /// value of a global.
    offset: InitExpr,
    /// The (now passive) data segment.
    id: DataId,
}
/// Converts the active data segments attached to `memory` into passive ones.
///
/// Returns each segment's id, original offset and length, so they can be
/// initialized explicitly later. As a side effect, the memory's
/// `data_segments` collection is left empty. Segments that were already
/// passive are skipped and not returned.
fn switch_data_segments_to_passive(
    module: &mut Module,
    memory: MemoryId,
) -> Result<Vec<PassiveSegment>, Error> {
    // Take ownership of the ids so `module.data` can be mutated in the loop.
    // NOTE(review): the returned order follows this collection's iteration
    // order. If it is unordered, overlapping segments could be initialized in
    // a different order than the module declares them; confirm.
    let segment_ids = mem::replace(
        &mut module.memories.get_mut(memory).data_segments,
        Default::default(),
    );
    let mut converted = Vec::new();
    for segment_id in segment_ids {
        let segment = module.data.get_mut(segment_id);
        let offset = if let walrus::DataKind::Active(active) = &segment.kind {
            match active.location {
                walrus::ActiveDataLocation::Absolute(n) => {
                    walrus::InitExpr::Value(walrus::ir::Value::I32(n as i32))
                }
                walrus::ActiveDataLocation::Relative(global) => walrus::InitExpr::Global(global),
            }
        } else {
            continue;
        };
        segment.kind = walrus::DataKind::Passive;
        converted.push(PassiveSegment {
            id: segment_id,
            offset,
            len: segment.value.len() as u32,
        });
    }
    Ok(converted)
}
fn update_memory(module: &mut Module, memory: MemoryId, max: u32) -> Result<MemoryId, Error> {
assert!(max % PAGE_SIZE == 0);
let memory = module.memories.get_mut(memory);
if memory.import.is_none() {
let id = module
.imports
.add("env", "memory", ImportKind::Memory(memory.id()));
memory.import = Some(id);
}
if !memory.shared {
memory.shared = true;
if memory.maximum.is_none() {
memory.maximum = Some(max / PAGE_SIZE);
}
}
Ok(memory.id())
}
/// Mutable `i32` globals injected by the threads pass.
struct Globals {
    /// Backing storage for the `__wbindgen_tcb_get` / `__wbindgen_tcb_set`
    /// intrinsics.
    thread_tcb: GlobalId,
    /// This thread's id. It is set from an atomic fetch-add on a counter in the
    /// injected start function, and read by `__wbindgen_current_id`.
    thread_id: GlobalId,
}
fn allocate_static_data(
module: &mut Module,
memory: MemoryId,
size: u32,
align: u32,
) -> Result<u32, Error> {
let heap_base = module
.exports
.iter()
.filter(|e| e.name == "__heap_base")
.filter_map(|e| match e.item {
ExportItem::Global(id) => Some(id),
_ => None,
})
.next();
let heap_base = match heap_base {
Some(idx) => idx,
None => bail!("failed to find `__heap_base` for injecting thread id"),
};
let (address, add_a_page) = {
let global = module.globals.get_mut(heap_base);
if global.ty != ValType::I32 {
bail!("the `__heap_base` global doesn't have the type `i32`");
}
if global.mutable {
bail!("the `__heap_base` global is unexpectedly mutable");
}
let offset = match &mut global.kind {
GlobalKind::Local(InitExpr::Value(Value::I32(n))) => n,
_ => bail!("`__heap_base` not a locally defined `i32`"),
};
let address = (*offset as u32 + (align - 1)) & !(align - 1);
let add_a_page = (address + size) / PAGE_SIZE != address / PAGE_SIZE;
*offset = (address + size) as i32;
(address, add_a_page)
};
if add_a_page {
let memory = module.memories.get_mut(memory);
memory.initial += 1;
memory.maximum = memory.maximum.map(|m| cmp::max(m, memory.initial));
}
Ok(address)
}
enum InitMemory {
Segments(Vec<PassiveSegment>),
Call {
wasm_init_memory: walrus::FunctionId,
wasm_init_tls: walrus::FunctionId,
tls_size: u32,
},
}
/// Builds a new start function for per-thread initialization and installs it
/// as the module's start function, chaining to any previous start function.
///
/// The generated function does the following:
/// 1. Atomically adds 1 to the `i32` counter at `addr`. The counter's previous
///    value is stored in `globals.thread_id`.
/// 2. If that id is nonzero, grows memory by `stack_size` bytes, trapping if
///    `memory.grow` fails. It then points `stack_pointer` at the end of the
///    newly grown region. If the id is zero, it initializes memory according
///    to `memory_init` instead.
/// 3. For `InitMemory::Call`, allocates `tls_size` bytes via
///    `__wbindgen_malloc` on every thread and passes the result to
///    `wasm_init_tls`.
/// 4. Calls the module's previous start function, if there was one.
///
/// # Panics
///
/// Panics if `stack_size` is not a multiple of `PAGE_SIZE`.
///
/// # Errors
///
/// Fails if `memory_init` is `InitMemory::Call` and `__wbindgen_malloc` cannot
/// be found.
fn inject_start(
module: &mut Module,
memory_init: InitMemory,
globals: &Globals,
addr: u32,
stack_pointer: GlobalId,
stack_size: u32,
memory: MemoryId,
) -> Result<(), Error> {
use walrus::ir::*;
// Stacks are obtained with `memory.grow`, which works in whole pages.
assert!(stack_size % PAGE_SIZE == 0);
let mut builder = walrus::FunctionBuilder::new(&mut module.types, &[], &[]);
let local = module.locals.add(ValType::I32);
let mut body = builder.func_body();
// local = thread_id = atomic fetch-add(*addr, 1), i.e. the counter's old
// value. The first thread to run this sees 0.
body.i32_const(addr as i32)
.i32_const(1)
.atomic_rmw(
memory,
AtomicOp::Add,
AtomicWidth::I32,
MemArg {
align: 4,
offset: 0,
},
)
.local_tee(local)
.global_set(globals.thread_id);
// Branch on the id: nonzero takes the first arm, zero takes the second.
body.local_get(local);
body.if_else(
None,
|body| {
// Nonzero id: this thread needs its own stack.
// local = memory.grow(stack_size / PAGE_SIZE), which is the previous
// size in pages, or -1 on failure.
body.i32_const((stack_size / PAGE_SIZE) as i32)
.memory_grow(memory)
.local_set(local);
// Trap if the grow failed: skip `unreachable` unless local == -1.
body.block(None, |body| {
let target = body.id();
body.local_get(local)
.i32_const(-1)
.binop(BinaryOp::I32Ne)
.br_if(target)
.unreachable();
});
// stack_pointer = local * PAGE_SIZE + stack_size, the end of the
// region just grown.
body.local_get(local)
.i32_const(PAGE_SIZE as i32)
.binop(BinaryOp::I32Mul)
.i32_const(stack_size as i32)
.binop(BinaryOp::I32Add)
.global_set(stack_pointer);
},
|body| {
// Zero id: keep the existing stack pointer and initialize memory.
match &memory_init {
InitMemory::Segments(segments) => {
for segment in segments {
// Push the destination: the segment's original offset.
match segment.offset {
InitExpr::Global(id) => body.global_get(id),
InitExpr::Value(v) => body.const_(v),
};
// memory.init(dest, src = 0, len), then drop the segment
// since it is never needed again.
body.i32_const(0)
.i32_const(segment.len as i32)
.memory_init(memory, segment.id)
.data_drop(segment.id);
}
}
InitMemory::Call {
wasm_init_memory, ..
} => {
body.call(*wasm_init_memory);
}
}
},
);
// Runs on every thread, after the if/else above: give this thread its own
// TLS block of `tls_size` bytes.
if let InitMemory::Call {
wasm_init_tls,
tls_size,
..
} = memory_init
{
let malloc = find_wbindgen_malloc(module)?;
body.i32_const(tls_size as i32)
.call(malloc)
.call(wasm_init_tls);
}
// Chain to the pre-existing start function, if any, after our setup.
if let Some(id) = module.start.take() {
body.call(id);
}
let id = builder.finish(Vec::new(), &mut module.funcs);
module.start = Some(id);
Ok(())
}
fn find_wbindgen_malloc(module: &Module) -> Result<FunctionId, Error> {
let e = module
.exports
.iter()
.find(|e| e.name == "__wbindgen_malloc")
.ok_or_else(|| anyhow!("failed to find `__wbindgen_malloc`"))?;
match e.item {
walrus::ExportItem::Function(f) => Ok(f),
_ => bail!("`__wbindgen_malloc` wasn't a funtion"),
}
}
/// Replaces calls to the function imports from `__wbindgen_thread_xform__`
/// with direct accesses to the injected globals:
///
/// * `__wbindgen_current_id` (`() -> i32`) becomes `global.get thread_id`.
/// * `__wbindgen_tcb_get` (`() -> i32`) becomes `global.get thread_tcb`.
/// * `__wbindgen_tcb_set` (`(i32) -> ()`) becomes `global.set thread_tcb`.
///
/// Calls are rewritten in every locally-defined function. The imports
/// themselves are not removed here.
///
/// # Errors
///
/// Fails if an import from that module is not a function, has an unknown
/// name, or has the wrong signature.
fn implement_thread_intrinsics(module: &mut Module, globals: &Globals) -> Result<(), Error> {
use walrus::ir::*;
// Imported function id -> which intrinsic it implements.
let mut map = HashMap::new();
enum Intrinsic {
GetThreadId,
GetTcb,
SetTcb,
}
let imports = module
.imports
.iter()
.filter(|i| i.module == "__wbindgen_thread_xform__");
for import in imports {
let function = match import.kind {
ImportKind::Function(id) => module.funcs.get(id),
_ => bail!("non-function import from special module"),
};
let ty = module.types.get(function.ty());
// Each signature must match the stack effect of the global access that
// will replace the call, so the swap below can be done in place.
match &import.name[..] {
"__wbindgen_current_id" => {
if !ty.params().is_empty() || ty.results() != &[ValType::I32] {
bail!("`__wbindgen_current_id` intrinsic has the wrong signature");
}
map.insert(function.id(), Intrinsic::GetThreadId);
}
"__wbindgen_tcb_get" => {
if !ty.params().is_empty() || ty.results() != &[ValType::I32] {
bail!("`__wbindgen_tcb_get` intrinsic has the wrong signature");
}
map.insert(function.id(), Intrinsic::GetTcb);
}
"__wbindgen_tcb_set" => {
if !ty.results().is_empty() || ty.params() != &[ValType::I32] {
bail!("`__wbindgen_tcb_set` intrinsic has the wrong signature");
}
map.insert(function.id(), Intrinsic::SetTcb);
}
other => bail!("unknown thread intrinsic: {}", other),
}
}
// Instruction visitor that rewrites calls to the intrinsics found above.
struct Visitor<'a> {
map: &'a HashMap<FunctionId, Intrinsic>,
globals: &'a Globals,
}
// Walk the body of every locally-defined function, starting at its entry
// block.
module.funcs.iter_local_mut().for_each(|(_id, func)| {
let entry = func.entry_block();
dfs_pre_order_mut(&mut Visitor { map: &map, globals }, func, entry);
});
impl VisitorMut for Visitor<'_> {
fn visit_instr_mut(&mut self, instr: &mut Instr, _loc: &mut InstrLocId) {
// Only `call` instructions can target an intrinsic.
let call = match instr {
Instr::Call(e) => e,
_ => return,
};
match self.map.get(&call.func) {
Some(Intrinsic::GetThreadId) => {
*instr = GlobalGet {
global: self.globals.thread_id,
}
.into();
}
Some(Intrinsic::GetTcb) => {
*instr = GlobalGet {
global: self.globals.thread_tcb,
}
.into();
}
Some(Intrinsic::SetTcb) => {
*instr = GlobalSet {
global: self.globals.thread_tcb,
}
.into();
}
// A call to some other function: leave it alone.
None => {}
}
}
}
Ok(())
}