// wasm_bindgen_threads_xform/lib.rs

1use anyhow::{anyhow, bail, Error};
2use std::cmp;
3use walrus::ir::Value;
4use walrus::FunctionBuilder;
5use walrus::{
6    ir::MemArg, ConstExpr, ExportItem, FunctionId, GlobalId, GlobalKind, InstrSeqBuilder, MemoryId,
7    Module, ValType,
8};
9use wasm_bindgen_wasm_conventions as wasm_conventions;
10
/// Size of one Wasm linear-memory page, in bytes (64 KiB).
pub const PAGE_SIZE: u32 = 1 << 16;
/// Stack size given to a spawned thread when the injected start function
/// is called with a stack size of 0 (i.e. "use the default").
const DEFAULT_THREAD_STACK_SIZE: u32 = 1 << 21; // 2MB
/// Memory argument shared by all injected atomic instructions: 4-byte
/// alignment, no constant offset.
const ATOMIC_MEM_ARG: MemArg = MemArg {
    align: 4,
    offset: 0,
};
17
/// Handle to the local variable, in the injected start function, that holds
/// the thread-counter value read at startup (0 for the first thread).
/// Used by [`ThreadCount::wrap_start`] to run code on the first thread only.
#[derive(Clone, Copy)]
pub struct ThreadCount(walrus::LocalId);
20
21/// Is threaded Wasm enabled?
22pub fn is_enabled(module: &Module) -> bool {
23    // Compatibility with older LLVM outputs. Newer LLVM outputs, when
24    // atomics are enabled, emit a shared memory. That's a good indicator
25    // that we have work to do. If shared memory isn't enabled, though then
26    // this isn't an atomic module so there's nothing to do. We still allow,
27    // though, an environment variable to force us to go down this path to
28    // remain compatible with older LLVM outputs.
29    match wasm_conventions::get_memory(module) {
30        Ok(memory) => module.memories.get(memory).shared,
31        Err(_) => false,
32    }
33}
34
/// Execute the transformation on the parsed Wasm module specified.
///
/// This function will prepare `Module` to be run on multiple threads,
/// performing steps such as:
///
/// * All data segments are switched to "passive" data segments to ensure
///   they're only initialized once (coming later)
/// * If memory is exported from this module, it is instead switched to
///   being imported (with the same parameters).
/// * The imported memory is required to be `shared`, ensuring it's backed
///   by a `SharedArrayBuffer` on the web.
/// * A `global` for a thread ID is injected.
/// * Four bytes in linear memory are reserved for the counter of thread
///   IDs.
/// * A `start` function is injected (or prepended if one already exists)
///   which initializes memory for the first thread and otherwise allocates
///   thread ids for all threads.
/// * Some stack space is prepared for each thread after the first one.
///
/// Returns `Ok(None)` — leaving the module untouched — when the module
/// isn't a threaded (shared-memory) module.
///
/// More and/or less may happen here over time, stay tuned!
pub fn run(module: &mut Module) -> Result<Option<ThreadCount>, Error> {
    if !is_enabled(module) {
        return Ok(None);
    }

    let memory = wasm_conventions::get_memory(module)?;

    // Now we need to allocate extra static memory for:
    // - A thread id counter.
    // - A temporary stack for calls to `malloc()` and `free()`.
    // - A lock to synchronize usage of the above stack.
    // For this, we allocate 1 extra page of memory (should be enough as temporary
    // stack) and grab the first 2 _aligned_ i32 words to use as counter and lock.
    let static_data_align = 4;
    let static_data_pages = 1;
    let (base, addr) = allocate_static_data(module, memory, static_data_pages, static_data_align)?;

    let mem = module.memories.get(memory);
    assert!(mem.shared);
    assert!(mem.import.is_some());
    assert!(mem.data_segments.is_empty());

    // Collect the linker-synthesized TLS machinery; the synthetic exports
    // are deleted as a side effect.
    let tls = Tls {
        init: delete_synthetic_func(module, "__wasm_init_tls")?,
        size: delete_synthetic_global(module, "__tls_size")?,
        align: delete_synthetic_global(module, "__tls_align")?,
        base: wasm_conventions::get_tls_base(module)
            .ok_or_else(|| anyhow!("failed to find tls base"))?,
    };

    let thread_counter_addr = addr as i32;

    // Mutable global recording this thread's stack allocation (0 until the
    // injected start function assigns it).
    let stack_alloc =
        module
            .globals
            .add_local(ValType::I32, true, false, ConstExpr::Value(Value::I32(0)));

    // Make sure the temporary stack is aligned down
    let temp_stack = (base + static_data_pages * PAGE_SIZE) & !(static_data_align - 1);

    const _: () = assert!(DEFAULT_THREAD_STACK_SIZE % PAGE_SIZE == 0);

    let stack = Stack {
        pointer: wasm_conventions::get_stack_pointer(module)
            .ok_or_else(|| anyhow!("failed to find stack pointer"))?,
        temp: temp_stack as i32,
        // The lock occupies the aligned i32 word right after the counter.
        temp_lock: thread_counter_addr + 4,
        alloc: stack_alloc,
        size: module.globals.add_local(
            ValType::I32,
            true,
            false,
            ConstExpr::Value(Value::I32(DEFAULT_THREAD_STACK_SIZE as i32)),
        ),
    };

    // Export the stack-allocation global so a "leader" agent can read a
    // follower's value when calling `__wbindgen_thread_destroy` on its behalf.
    let _ = module.exports.add("__stack_alloc", stack.alloc);

    let thread_count = inject_start(module, &tls, &stack, thread_counter_addr, memory)?;

    // we expose a `__wbindgen_thread_destroy()` helper function that deallocates stack space.
    //
    // ## Safety
    // After calling this function in a given agent, the instance should be considered
    // "destroyed" and any further invocations into it will trigger UB. This function
    // should not be called from an agent that cannot block (e.g. the main document thread).
    //
    // You can also call it from a "leader" agent, passing appropriate values, if said leader
    // is in charge of cleaning up after a "follower" agent. In that case:
    // - The "appropriate values" are the values of the `__tls_base` and `__stack_alloc` globals
    //   and the stack size from the follower thread, after initialization.
    // - The leader does _not_ need to block.
    // - Similar restrictions apply: the follower thread should be considered unusable afterwards,
    //   the leader should not call this function with the same set of parameters twice.
    // - Moreover, concurrent calls can lead to UB: the follower could be in the middle of a
    //   call while the leader is destroying its stack! You should make sure that this cannot happen.
    inject_destroy(module, &tls, &stack, memory)?;

    Ok(Some(thread_count))
}
135
136impl ThreadCount {
137    pub fn wrap_start(self, builder: &mut FunctionBuilder, start: FunctionId) {
138        // We only want to call the start function if we are in the first thread.
139        // The thread counter should be 0 for the first thread.
140        builder.func_body().local_get(self.0).if_else(
141            None,
142            |_| {},
143            |body| {
144                body.call(start);
145            },
146        );
147    }
148}
149
150fn delete_synthetic_func(module: &mut Module, name: &str) -> Result<FunctionId, Error> {
151    match delete_synthetic_export(module, name)? {
152        walrus::ExportItem::Function(f) => Ok(f),
153        _ => bail!("`{}` must be a function", name),
154    }
155}
156
157fn delete_synthetic_global(module: &mut Module, name: &str) -> Result<u32, Error> {
158    let id = match delete_synthetic_export(module, name)? {
159        walrus::ExportItem::Global(g) => g,
160        _ => bail!("`{}` must be a global", name),
161    };
162    let g = match module.globals.get(id).kind {
163        walrus::GlobalKind::Local(g) => g,
164        walrus::GlobalKind::Import(_) => bail!("`{}` must not be an imported global", name),
165    };
166    match g {
167        ConstExpr::Value(Value::I32(v)) => Ok(v as u32),
168        _ => bail!("`{}` was not an `i32` constant", name),
169    }
170}
171
172fn delete_synthetic_export(module: &mut Module, name: &str) -> Result<ExportItem, Error> {
173    let item = module
174        .exports
175        .iter()
176        .find(|e| e.name == name)
177        .ok_or_else(|| anyhow!("failed to find `{}`", name))?;
178    let ret = item.item;
179    let id = item.id();
180    module.exports.delete(id);
181    Ok(ret)
182}
183
/// Allocates extra space for static data. Returns `(base, addr)`, where:
/// - `base` is the starting address of the extra `pages`.
/// - `addr` is the _first_ address in that chunk that is aligned to `align`.
fn allocate_static_data(
    module: &mut Module,
    memory: MemoryId,
    pages: u32,
    align: u32,
) -> Result<(u32, u32), Error> {
    // First up, look for a `__heap_base` export which is injected by LLD as
    // part of the linking process. Note that `__heap_base` should in theory be
    // *after* the stack and data, which means it's at the very end of the
    // address space and should be safe for us to inject extra pages of data at.
    let heap_base = module
        .exports
        .iter()
        .filter(|e| e.name == "__heap_base")
        .find_map(|e| match e.item {
            ExportItem::Global(id) => Some(id),
            _ => None,
        });
    let heap_base = match heap_base {
        Some(idx) => idx,
        None => bail!("failed to find `__heap_base` for injecting thread id"),
    };

    // Now we need to bump up `__heap_base` by a few pages. Do lots of validation
    // here to make sure that `__heap_base` is a non-mutable integer, and then do
    // some logic to ensure that we return the correct, aligned `address` as
    // specified by `align`.
    let (base, address) = {
        let global = module.globals.get_mut(heap_base);
        if global.ty != ValType::I32 {
            bail!("the `__heap_base` global doesn't have the type `i32`");
        }
        if global.mutable {
            bail!("the `__heap_base` global is unexpectedly mutable");
        }
        let offset = match &mut global.kind {
            GlobalKind::Local(ConstExpr::Value(Value::I32(n))) => n,
            _ => bail!("`__heap_base` not a locally defined `i32`"),
        };

        let address = (*offset as u32 + (align - 1)) & !(align - 1); // align up
        let base = *offset;

        // Bump `__heap_base` past the injected pages so the allocator never
        // hands them out.
        *offset += (pages * PAGE_SIZE) as i32;

        (base, address)
    };

    // Grow the memory's declared limits so they cover the injected pages.
    let memory = module.memories.get_mut(memory);
    memory.initial += u64::from(pages);
    memory.maximum = memory.maximum.map(|m| cmp::max(m, memory.initial));

    Ok((base as u32, address))
}
241
/// TLS machinery synthesized by the linker, collected by `run`.
struct Tls {
    /// The `__wasm_init_tls` function (its export is deleted by `run`).
    init: walrus::FunctionId,
    /// Byte size of one thread's TLS area (value of `__tls_size`).
    size: u32,
    /// Required alignment of the TLS area (value of `__tls_align`).
    align: u32,
    /// Global holding the current thread's TLS base address.
    base: GlobalId,
}
248
/// Globals and addresses used to manage per-thread stacks.
struct Stack {
    /// The stack pointer global
    pointer: GlobalId,
    /// The address of a small, "scratch-space" stack
    temp: i32,
    /// The address of a lock for the temporary stack
    temp_lock: i32,
    /// A global to store allocated stack
    alloc: GlobalId,
    /// The size of the stack
    size: GlobalId,
}
261
/// Builds a new `start` function and installs it on the module.
///
/// The injected function takes one `i32` parameter — the desired stack size
/// for this thread, where 0 means "keep the default" — and:
/// * calls the module's previous start function, if there was one;
/// * atomically increments the thread-id counter at `thread_counter_addr`;
/// * for every thread but the first, allocates a stack via
///   `__wbindgen_malloc` (run on the shared scratch stack) and points the
///   stack pointer at the top of that allocation;
/// * allocates and initializes this thread's TLS area.
///
/// Returns the local holding the pre-increment counter value, which is 0
/// only on the first thread.
fn inject_start(
    module: &mut Module,
    tls: &Tls,
    stack: &Stack,
    thread_counter_addr: i32,
    memory: MemoryId,
) -> Result<ThreadCount, Error> {
    use walrus::ir::*;

    // Scratch local holding the base address returned by `malloc`.
    let local = module.locals.add(ValType::I32);
    let thread_count = module.locals.add(ValType::I32);
    let stack_size = module.locals.add(ValType::I32);

    let malloc = find_function(module, "__wbindgen_malloc")?;

    let prev_start = wasm_bindgen_wasm_conventions::get_start(module);
    let mut builder = FunctionBuilder::new(&mut module.types, &[ValType::I32], &[]);

    // Preserve any pre-existing start function by calling it first.
    if let Ok(prev_start) | Err(Some(prev_start)) = prev_start {
        builder.func_body().call(prev_start);
    }

    let mut body = builder.func_body();

    // Perform an if/else based on whether we're the first thread or not. Our
    // thread ID will be zero if we're the first thread, otherwise it'll be
    // nonzero (assuming we don't overflow...)
    body.i32_const(thread_counter_addr)
        .i32_const(1)
        .atomic_rmw(memory, AtomicOp::Add, AtomicWidth::I32, ATOMIC_MEM_ARG)
        .local_tee(thread_count)
        .if_else(
            None,
            // If our thread id is nonzero then we're the second or greater thread, so
            // we give ourselves a stack and we update our stack
            // pointer as the default stack pointer is surely wrong for us.
            |body| {
                // A nonzero `stack_size` parameter overrides the default
                // stack size stored in the global.
                body.local_get(stack_size).if_else(
                    None,
                    |body| {
                        body.local_get(stack_size).global_set(stack.size);
                    },
                    |_| (),
                );

                // local = malloc(stack.size, align) [aka base]
                // `malloc` needs a usable stack itself, so run it on the
                // lock-protected scratch stack.
                with_temp_stack(body, memory, stack, |body| {
                    body.global_get(stack.size)
                        .i32_const(16)
                        .call(malloc)
                        .local_tee(local);
                });

                // stack.alloc = base
                body.global_set(stack.alloc);

                // stack_pointer = base + stack.size
                body.global_get(stack.alloc)
                    .global_get(stack.size)
                    .binop(BinaryOp::I32Add)
                    .global_set(stack.pointer);
            },
            // If the thread id is zero then the default stack pointer works for
            // us.
            |_| {},
        );

    // Afterwards we need to initialize our thread-local state:
    // __wasm_init_tls(malloc(__tls_size, __tls_align))
    body.i32_const(tls.size as i32)
        .i32_const(tls.align as i32)
        .call(malloc)
        .global_set(tls.base)
        .global_get(tls.base)
        .call(tls.init);

    // `stack_size` is the function's single parameter; install as `start`.
    let id = builder.finish(vec![stack_size], &mut module.funcs);
    module.start = Some(id);

    Ok(ThreadCount(thread_count))
}
342
/// Builds and exports `__wbindgen_thread_destroy(tls_base, stack_alloc,
/// stack_size)`, which releases a thread's TLS area and stack via
/// `__wbindgen_free`.
///
/// Zero-valued parameters mean "destroy the calling agent's own resources":
/// the TLS/stack globals are read and then poisoned so further use traps.
/// Nonzero parameters identify a "follower" thread being cleaned up by a
/// leader. See the safety discussion at the call site in `run`.
fn inject_destroy(
    module: &mut Module,
    tls: &Tls,
    stack: &Stack,
    memory: MemoryId,
) -> Result<(), Error> {
    let free = find_function(module, "__wbindgen_free")?;

    let mut builder = FunctionBuilder::new(
        &mut module.types,
        &[ValType::I32, ValType::I32, ValType::I32],
        &[],
    );

    builder.name("__wbindgen_thread_destroy".into());

    let mut body = builder.func_body();

    // if no explicit parameters are passed (i.e. their value is 0) then we assume
    // we're being called from the agent that must be destroyed and rely on its globals
    let tls_base = module.locals.add(ValType::I32);
    let stack_alloc = module.locals.add(ValType::I32);
    let stack_size = module.locals.add(ValType::I32);

    // Ideally, at this point, we would destroy the values stored in TLS.
    // We can't really do that without help from the standard library.
    // See https://github.com/rustwasm/wasm-bindgen/pull/2769#issuecomment-1015775467.

    // free the TLS area: `__wbindgen_free(base, tls.size, tls.align)`
    body.local_get(tls_base).if_else(
        None,
        // An explicit base was passed: free a follower thread's TLS.
        |body| {
            body.local_get(tls_base)
                .i32_const(tls.size as i32)
                .i32_const(tls.align as i32)
                .call(free);
        },
        // No base passed: free our own TLS, read from the global.
        |body| {
            body.global_get(tls.base)
                .i32_const(tls.size as i32)
                .i32_const(tls.align as i32)
                .call(free);

            // set tls.base = i32::MIN to trigger invalid memory
            body.i32_const(i32::MIN).global_set(tls.base);
        },
    );

    // free the stack calling `__wbindgen_free(stack.alloc, stack.size)`
    body.local_get(stack_alloc).if_else(
        None,
        |body| {
            // we're destroying somebody else's stack, so we can use our own.
            // `select` picks the passed `stack_size` when it's nonzero and
            // falls back to the default stack size otherwise.
            body.local_get(stack_alloc)
                .local_get(stack_size)
                .i32_const(DEFAULT_THREAD_STACK_SIZE as i32)
                .local_get(stack_size)
                .select(None)
                .i32_const(16)
                .call(free);
        },
        |body| {
            // Freeing our own stack: `free` can't run on the very stack
            // being freed, so borrow the shared scratch stack.
            with_temp_stack(body, memory, stack, |body| {
                body.global_get(stack.alloc)
                    .global_get(stack.size)
                    .i32_const(16)
                    .call(free);
            });

            // set stack.alloc = 0 to trigger invalid memory
            body.i32_const(0).global_set(stack.alloc);
        },
    );

    let destroy_id = builder.finish(vec![tls_base, stack_alloc, stack_size], &mut module.funcs);

    module.exports.add("__wbindgen_thread_destroy", destroy_id);

    Ok(())
}
422
423fn find_function(module: &Module, name: &str) -> Result<FunctionId, Error> {
424    let e = module
425        .exports
426        .iter()
427        .find(|e| e.name == name)
428        .ok_or_else(|| anyhow!("failed to find `{}`", name))?;
429    match e.item {
430        walrus::ExportItem::Function(f) => Ok(f),
431        _ => bail!("`{}` wasn't a function", name),
432    }
433}
434
/// Wraps the instructions fed by `block()` so that they can assume that the temporary, scratch
/// stack is usable. Clobbers `stack.pointer`.
///
/// The scratch stack is shared between threads, so access is serialized
/// with a lock word at `stack.temp_lock`: acquisition spins on a compare-
/// and-exchange, parking losers with an atomic wait, and release stores 0
/// and notifies one waiter.
fn with_temp_stack(
    body: &mut InstrSeqBuilder<'_>,
    memory: MemoryId,
    stack: &Stack,
    block: impl Fn(&mut InstrSeqBuilder<'_>),
) {
    use walrus::ir::*;

    // Point the stack pointer at the scratch stack.
    body.i32_const(stack.temp).global_set(stack.pointer);

    // Acquire: loop until cmpxchg(temp_lock, expected 0, replacement 1)
    // observes 0 (i.e. we took the lock).
    body.loop_(None, |loop_| {
        let loop_id = loop_.id();

        loop_
            .i32_const(stack.temp_lock)
            .i32_const(0)
            .i32_const(1)
            .cmpxchg(memory, AtomicWidth::I32, ATOMIC_MEM_ARG)
            .if_else(
                None,
                // The old value was nonzero: somebody else holds the lock.
                // Wait (timeout -1, i.e. indefinitely) until notified, drop
                // the wait result, and retry the acquisition loop.
                |body| {
                    body.i32_const(stack.temp_lock)
                        .i32_const(1)
                        .i64_const(-1)
                        .atomic_wait(memory, ATOMIC_MEM_ARG, false)
                        .drop()
                        .br(loop_id);
                },
                |_| {},
            );
    });

    // Emit the caller's instructions while the scratch stack is held.
    block(body);

    // Release: store 0 into the lock word, then wake at most one waiter.
    body.i32_const(stack.temp_lock)
        .i32_const(0)
        .store(memory, StoreKind::I32 { atomic: true }, ATOMIC_MEM_ARG)
        .i32_const(stack.temp_lock)
        .i32_const(1)
        .atomic_notify(memory, ATOMIC_MEM_ARG)
        .drop();
}