1use anyhow::{anyhow, bail, Error};
2use std::cmp;
3use walrus::ir::Value;
4use walrus::FunctionBuilder;
5use walrus::{
6 ir::MemArg, ConstExpr, ExportItem, FunctionId, GlobalId, GlobalKind, InstrSeqBuilder, MemoryId,
7 Module, ValType,
8};
9use wasm_bindgen_wasm_conventions as wasm_conventions;
10
11pub const PAGE_SIZE: u32 = 1 << 16;
12const DEFAULT_THREAD_STACK_SIZE: u32 = 1 << 21; const ATOMIC_MEM_ARG: MemArg = MemArg {
14 align: 4,
15 offset: 0,
16};
17
18#[derive(Clone, Copy)]
19pub struct ThreadCount(walrus::LocalId);
20
21pub fn is_enabled(module: &Module) -> bool {
23 match wasm_conventions::get_memory(module) {
30 Ok(memory) => module.memories.get(memory).shared,
31 Err(_) => false,
32 }
33}
34
35pub fn run(module: &mut Module) -> Result<Option<ThreadCount>, Error> {
56 if !is_enabled(module) {
57 return Ok(None);
58 }
59
60 let memory = wasm_conventions::get_memory(module)?;
61
62 let static_data_align = 4;
69 let static_data_pages = 1;
70 let (base, addr) = allocate_static_data(module, memory, static_data_pages, static_data_align)?;
71
72 let mem = module.memories.get(memory);
73 assert!(mem.shared);
74 assert!(mem.import.is_some());
75 assert!(mem.data_segments.is_empty());
76
77 let tls = Tls {
78 init: delete_synthetic_func(module, "__wasm_init_tls")?,
79 size: delete_synthetic_global(module, "__tls_size")?,
80 align: delete_synthetic_global(module, "__tls_align")?,
81 base: wasm_conventions::get_tls_base(module)
82 .ok_or_else(|| anyhow!("failed to find tls base"))?,
83 };
84
85 let thread_counter_addr = addr as i32;
86
87 let stack_alloc =
88 module
89 .globals
90 .add_local(ValType::I32, true, false, ConstExpr::Value(Value::I32(0)));
91
92 let temp_stack = (base + static_data_pages * PAGE_SIZE) & !(static_data_align - 1);
94
95 const _: () = assert!(DEFAULT_THREAD_STACK_SIZE % PAGE_SIZE == 0);
96
97 let stack = Stack {
98 pointer: wasm_conventions::get_stack_pointer(module)
99 .ok_or_else(|| anyhow!("failed to find stack pointer"))?,
100 temp: temp_stack as i32,
101 temp_lock: thread_counter_addr + 4,
102 alloc: stack_alloc,
103 size: module.globals.add_local(
104 ValType::I32,
105 true,
106 false,
107 ConstExpr::Value(Value::I32(DEFAULT_THREAD_STACK_SIZE as i32)),
108 ),
109 };
110
111 let _ = module.exports.add("__stack_alloc", stack.alloc);
112
113 let thread_count = inject_start(module, &tls, &stack, thread_counter_addr, memory)?;
114
115 inject_destroy(module, &tls, &stack, memory)?;
132
133 Ok(Some(thread_count))
134}
135
136impl ThreadCount {
137 pub fn wrap_start(self, builder: &mut FunctionBuilder, start: FunctionId) {
138 builder.func_body().local_get(self.0).if_else(
141 None,
142 |_| {},
143 |body| {
144 body.call(start);
145 },
146 );
147 }
148}
149
150fn delete_synthetic_func(module: &mut Module, name: &str) -> Result<FunctionId, Error> {
151 match delete_synthetic_export(module, name)? {
152 walrus::ExportItem::Function(f) => Ok(f),
153 _ => bail!("`{}` must be a function", name),
154 }
155}
156
157fn delete_synthetic_global(module: &mut Module, name: &str) -> Result<u32, Error> {
158 let id = match delete_synthetic_export(module, name)? {
159 walrus::ExportItem::Global(g) => g,
160 _ => bail!("`{}` must be a global", name),
161 };
162 let g = match module.globals.get(id).kind {
163 walrus::GlobalKind::Local(g) => g,
164 walrus::GlobalKind::Import(_) => bail!("`{}` must not be an imported global", name),
165 };
166 match g {
167 ConstExpr::Value(Value::I32(v)) => Ok(v as u32),
168 _ => bail!("`{}` was not an `i32` constant", name),
169 }
170}
171
172fn delete_synthetic_export(module: &mut Module, name: &str) -> Result<ExportItem, Error> {
173 let item = module
174 .exports
175 .iter()
176 .find(|e| e.name == name)
177 .ok_or_else(|| anyhow!("failed to find `{}`", name))?;
178 let ret = item.item;
179 let id = item.id();
180 module.exports.delete(id);
181 Ok(ret)
182}
183
184fn allocate_static_data(
188 module: &mut Module,
189 memory: MemoryId,
190 pages: u32,
191 align: u32,
192) -> Result<(u32, u32), Error> {
193 let heap_base = module
198 .exports
199 .iter()
200 .filter(|e| e.name == "__heap_base")
201 .find_map(|e| match e.item {
202 ExportItem::Global(id) => Some(id),
203 _ => None,
204 });
205 let heap_base = match heap_base {
206 Some(idx) => idx,
207 None => bail!("failed to find `__heap_base` for injecting thread id"),
208 };
209
210 let (base, address) = {
215 let global = module.globals.get_mut(heap_base);
216 if global.ty != ValType::I32 {
217 bail!("the `__heap_base` global doesn't have the type `i32`");
218 }
219 if global.mutable {
220 bail!("the `__heap_base` global is unexpectedly mutable");
221 }
222 let offset = match &mut global.kind {
223 GlobalKind::Local(ConstExpr::Value(Value::I32(n))) => n,
224 _ => bail!("`__heap_base` not a locally defined `i32`"),
225 };
226
227 let address = (*offset as u32 + (align - 1)) & !(align - 1); let base = *offset;
229
230 *offset += (pages * PAGE_SIZE) as i32;
231
232 (base, address)
233 };
234
235 let memory = module.memories.get_mut(memory);
236 memory.initial += u64::from(pages);
237 memory.maximum = memory.maximum.map(|m| cmp::max(m, memory.initial));
238
239 Ok((base as u32, address))
240}
241
242struct Tls {
243 init: walrus::FunctionId,
244 size: u32,
245 align: u32,
246 base: GlobalId,
247}
248
249struct Stack {
250 pointer: GlobalId,
252 temp: i32,
254 temp_lock: i32,
256 alloc: GlobalId,
258 size: GlobalId,
260}
261
262fn inject_start(
263 module: &mut Module,
264 tls: &Tls,
265 stack: &Stack,
266 thread_counter_addr: i32,
267 memory: MemoryId,
268) -> Result<ThreadCount, Error> {
269 use walrus::ir::*;
270
271 let local = module.locals.add(ValType::I32);
272 let thread_count = module.locals.add(ValType::I32);
273 let stack_size = module.locals.add(ValType::I32);
274
275 let malloc = find_function(module, "__wbindgen_malloc")?;
276
277 let prev_start = wasm_bindgen_wasm_conventions::get_start(module);
278 let mut builder = FunctionBuilder::new(&mut module.types, &[ValType::I32], &[]);
279
280 if let Ok(prev_start) | Err(Some(prev_start)) = prev_start {
281 builder.func_body().call(prev_start);
282 }
283
284 let mut body = builder.func_body();
285
286 body.i32_const(thread_counter_addr)
290 .i32_const(1)
291 .atomic_rmw(memory, AtomicOp::Add, AtomicWidth::I32, ATOMIC_MEM_ARG)
292 .local_tee(thread_count)
293 .if_else(
294 None,
295 |body| {
299 body.local_get(stack_size).if_else(
300 None,
301 |body| {
302 body.local_get(stack_size).global_set(stack.size);
303 },
304 |_| (),
305 );
306
307 with_temp_stack(body, memory, stack, |body| {
309 body.global_get(stack.size)
310 .i32_const(16)
311 .call(malloc)
312 .local_tee(local);
313 });
314
315 body.global_set(stack.alloc);
317
318 body.global_get(stack.alloc)
320 .global_get(stack.size)
321 .binop(BinaryOp::I32Add)
322 .global_set(stack.pointer);
323 },
324 |_| {},
327 );
328
329 body.i32_const(tls.size as i32)
331 .i32_const(tls.align as i32)
332 .call(malloc)
333 .global_set(tls.base)
334 .global_get(tls.base)
335 .call(tls.init);
336
337 let id = builder.finish(vec![stack_size], &mut module.funcs);
338 module.start = Some(id);
339
340 Ok(ThreadCount(thread_count))
341}
342
343fn inject_destroy(
344 module: &mut Module,
345 tls: &Tls,
346 stack: &Stack,
347 memory: MemoryId,
348) -> Result<(), Error> {
349 let free = find_function(module, "__wbindgen_free")?;
350
351 let mut builder = FunctionBuilder::new(
352 &mut module.types,
353 &[ValType::I32, ValType::I32, ValType::I32],
354 &[],
355 );
356
357 builder.name("__wbindgen_thread_destroy".into());
358
359 let mut body = builder.func_body();
360
361 let tls_base = module.locals.add(ValType::I32);
364 let stack_alloc = module.locals.add(ValType::I32);
365 let stack_size = module.locals.add(ValType::I32);
366
367 body.local_get(tls_base).if_else(
372 None,
373 |body| {
374 body.local_get(tls_base)
375 .i32_const(tls.size as i32)
376 .i32_const(tls.align as i32)
377 .call(free);
378 },
379 |body| {
380 body.global_get(tls.base)
381 .i32_const(tls.size as i32)
382 .i32_const(tls.align as i32)
383 .call(free);
384
385 body.i32_const(i32::MIN).global_set(tls.base);
387 },
388 );
389
390 body.local_get(stack_alloc).if_else(
392 None,
393 |body| {
394 body.local_get(stack_alloc)
396 .local_get(stack_size)
397 .i32_const(DEFAULT_THREAD_STACK_SIZE as i32)
398 .local_get(stack_size)
399 .select(None)
400 .i32_const(16)
401 .call(free);
402 },
403 |body| {
404 with_temp_stack(body, memory, stack, |body| {
405 body.global_get(stack.alloc)
406 .global_get(stack.size)
407 .i32_const(16)
408 .call(free);
409 });
410
411 body.i32_const(0).global_set(stack.alloc);
413 },
414 );
415
416 let destroy_id = builder.finish(vec![tls_base, stack_alloc, stack_size], &mut module.funcs);
417
418 module.exports.add("__wbindgen_thread_destroy", destroy_id);
419
420 Ok(())
421}
422
423fn find_function(module: &Module, name: &str) -> Result<FunctionId, Error> {
424 let e = module
425 .exports
426 .iter()
427 .find(|e| e.name == name)
428 .ok_or_else(|| anyhow!("failed to find `{}`", name))?;
429 match e.item {
430 walrus::ExportItem::Function(f) => Ok(f),
431 _ => bail!("`{}` wasn't a function", name),
432 }
433}
434
435fn with_temp_stack(
438 body: &mut InstrSeqBuilder<'_>,
439 memory: MemoryId,
440 stack: &Stack,
441 block: impl Fn(&mut InstrSeqBuilder<'_>),
442) {
443 use walrus::ir::*;
444
445 body.i32_const(stack.temp).global_set(stack.pointer);
446
447 body.loop_(None, |loop_| {
448 let loop_id = loop_.id();
449
450 loop_
451 .i32_const(stack.temp_lock)
452 .i32_const(0)
453 .i32_const(1)
454 .cmpxchg(memory, AtomicWidth::I32, ATOMIC_MEM_ARG)
455 .if_else(
456 None,
457 |body| {
458 body.i32_const(stack.temp_lock)
459 .i32_const(1)
460 .i64_const(-1)
461 .atomic_wait(memory, ATOMIC_MEM_ARG, false)
462 .drop()
463 .br(loop_id);
464 },
465 |_| {},
466 );
467 });
468
469 block(body);
470
471 body.i32_const(stack.temp_lock)
472 .i32_const(0)
473 .store(memory, StoreKind::I32 { atomic: true }, ATOMIC_MEM_ARG)
474 .i32_const(stack.temp_lock)
475 .i32_const(1)
476 .atomic_notify(memory, ATOMIC_MEM_ARG)
477 .drop();
478}