wit_bindgen/rt/
async_support.rs

1#![deny(missing_docs)]
2// TODO: Switch to interior mutability (e.g. use Mutexes or thread-local
3// RefCells) and remove this, since even in single-threaded mode `static mut`
4// references can be a hazard due to recursive access.
5#![allow(static_mut_refs)]
6
7extern crate std;
8use core::sync::atomic::{AtomicBool, Ordering};
9use std::boxed::Box;
10use std::collections::BTreeMap;
11use std::ffi::c_void;
12use std::future::Future;
13use std::mem;
14use std::pin::Pin;
15use std::ptr;
16use std::sync::Arc;
17use std::task::{Context, Poll, Wake, Waker};
18use std::vec::Vec;
19
20use futures::channel::oneshot;
21use futures::future::FutureExt;
22use futures::stream::{FuturesUnordered, StreamExt};
23
macro_rules! rtdebug {
    ($($args:tt)*) => {{
        // Development-only tracing for this runtime: flip `ENABLED` to `true`
        // to print the messages to stderr. A logging crate such as `log` is
        // deliberately not used so that runtime dependencies stay minimal.
        const ENABLED: bool = false;
        if ENABLED {
            std::eprintln!($($args)*);
        }
    }};
}
35
36mod abi_buffer;
37mod cabi;
38mod error_context;
39mod future_support;
40mod stream_support;
41mod subtask;
42mod waitable;
43mod waitable_set;
44
45use self::waitable_set::WaitableSet;
46pub use abi_buffer::*;
47pub use error_context::*;
48pub use future_support::*;
49pub use stream_support::*;
50#[doc(hidden)]
51pub use subtask::Subtask;
52
53pub use futures;
54
/// A heap-allocated, pinned, type-erased unit future, as stored in a task's
/// list of futures. (A `Box<dyn Trait>` trait object defaults to `'static`.)
type BoxFuture = Pin<Box<dyn Future<Output = ()>>>;
56
57/// Represents a task created by either a call to an async-lifted export or a
58/// future run using `block_on` or `start_task`.
59struct FutureState {
60    /// Remaining work to do (if any) before this task can be considered "done".
61    ///
62    /// Note that we won't tell the host the task is done until this is drained
63    /// and `waitables` is empty.
64    tasks: FuturesUnordered<BoxFuture>,
65
66    /// The waitable set containing waitables created by this task, if any.
67    waitable_set: Option<WaitableSet>,
68
69    /// State of all waitables in `waitable_set`, and the ptr/callback they're
70    /// associated with.
71    //
72    // Note that this is a `BTreeMap` rather than a `HashMap` only because, as
73    // of this writing, initializing the default hasher for `HashMap` requires
74    // calling `wasi_snapshot_preview1:random_get`, which requires initializing
75    // the `wasi_snapshot_preview1` adapter when targeting `wasm32-wasip2` and
76    // later, and that's expensive enough that we'd prefer to avoid it for apps
77    // which otherwise make no use of the adapter.
78    waitables: BTreeMap<u32, (*mut c_void, unsafe extern "C" fn(*mut c_void, u32))>,
79
80    /// Raw structure used to pass to `cabi::wasip3_task_set`
81    wasip3_task: cabi::wasip3_task,
82
83    /// Rust-level state for the waker, notably a bool as to whether this has
84    /// been woken.
85    waker: Arc<FutureWaker>,
86
87    /// Clone of `waker` field, but represented as `std::task::Waker`.
88    waker_clone: Waker,
89}
90
91impl FutureState {
92    fn new(future: BoxFuture) -> FutureState {
93        let waker = Arc::new(FutureWaker::default());
94        FutureState {
95            waker_clone: waker.clone().into(),
96            waker,
97            tasks: [future].into_iter().collect(),
98            waitable_set: None,
99            waitables: BTreeMap::new(),
100            wasip3_task: cabi::wasip3_task {
101                // This pointer is filled in before calling `wasip3_task_set`.
102                ptr: ptr::null_mut(),
103                version: cabi::WASIP3_TASK_V1,
104                waitable_register,
105                waitable_unregister,
106            },
107        }
108    }
109
110    fn get_or_create_waitable_set(&mut self) -> &WaitableSet {
111        self.waitable_set.get_or_insert_with(WaitableSet::new)
112    }
113
114    fn add_waitable(&mut self, waitable: u32) {
115        self.get_or_create_waitable_set().join(waitable)
116    }
117
118    fn remove_waitable(&mut self, waitable: u32) {
119        WaitableSet::remove_waitable_from_all_sets(waitable)
120    }
121
122    fn remaining_work(&self) -> bool {
123        !self.waitables.is_empty()
124    }
125
126    /// Handles the `event{0,1,2}` event codes and returns a corresponding
127    /// return code along with a flag whether this future is "done" or not.
128    fn callback(&mut self, event0: u32, event1: u32, event2: u32) -> (u32, bool) {
129        match event0 {
130            EVENT_NONE => rtdebug!("EVENT_NONE"),
131            EVENT_SUBTASK => rtdebug!("EVENT_SUBTASK({event1:#x}, {event2:#x})"),
132            EVENT_STREAM_READ => rtdebug!("EVENT_STREAM_READ({event1:#x}, {event2:#x})"),
133            EVENT_STREAM_WRITE => rtdebug!("EVENT_STREAM_WRITE({event1:#x}, {event2:#x})"),
134            EVENT_FUTURE_READ => rtdebug!("EVENT_FUTURE_READ({event1:#x}, {event2:#x})"),
135            EVENT_FUTURE_WRITE => rtdebug!("EVENT_FUTURE_WRITE({event1:#x}, {event2:#x})"),
136            EVENT_CANCEL => {
137                rtdebug!("EVENT_CANCEL");
138
139                // Cancellation is mapped to destruction in Rust, so return a
140                // code/bool indicating we're done. The caller will then
141                // appropriately deallocate this `FutureState` which will
142                // transitively run all destructors.
143                return (CALLBACK_CODE_EXIT, true);
144            }
145            _ => unreachable!(),
146        }
147        if event0 != EVENT_NONE {
148            self.deliver_waitable_event(event1, event2)
149        }
150
151        self.poll()
152    }
153
154    /// Deliver the `code` event to the `waitable` store within our map. This
155    /// waitable should be present because it's part of the waitable set which
156    /// is kept in-sync with our map.
157    fn deliver_waitable_event(&mut self, waitable: u32, code: u32) {
158        self.remove_waitable(waitable);
159        let (ptr, callback) = self.waitables.remove(&waitable).unwrap();
160        unsafe {
161            callback(ptr, code);
162        }
163    }
164
165    /// Poll this task until it either completes or can't make immediate
166    /// progress.
167    ///
168    /// Returns the code representing what happened along with a boolean as to
169    /// whether this execution is done.
170    fn poll(&mut self) -> (u32, bool) {
171        self.with_p3_task_set(|me| {
172            let mut context = Context::from_waker(&me.waker_clone);
173
174            loop {
175                // Reset the waker before polling to clear out any pending
176                // notification, if any.
177                me.waker.0.store(false, Ordering::Relaxed);
178
179                // Poll our future, handling `SPAWNED` around this.
180                let poll;
181                unsafe {
182                    poll = me.tasks.poll_next_unpin(&mut context);
183                    if !SPAWNED.is_empty() {
184                        me.tasks.extend(SPAWNED.drain(..));
185                    }
186                }
187
188                match poll {
189                    // A future completed, yay! Keep going to see if more have
190                    // completed.
191                    Poll::Ready(Some(())) => (),
192
193                    // The `FuturesUnordered` list is empty meaning that there's no
194                    // more work left to do, so we're done.
195                    Poll::Ready(None) => {
196                        assert!(!me.remaining_work());
197                        assert!(me.tasks.is_empty());
198                        break (CALLBACK_CODE_EXIT, true);
199                    }
200
201                    // Some future within `FuturesUnordered` is not ready yet. If
202                    // our `waker` was signaled then that means this is a yield
203                    // operation, otherwise it means we're blocking on something.
204                    Poll::Pending => {
205                        assert!(!me.tasks.is_empty());
206                        if me.waker.0.load(Ordering::Relaxed) {
207                            break (CALLBACK_CODE_YIELD, false);
208                        }
209
210                        assert!(me.remaining_work());
211                        let waitable = me.waitable_set.as_ref().unwrap().as_raw();
212                        break (CALLBACK_CODE_WAIT | (waitable << 4), false);
213                    }
214                }
215            }
216        })
217    }
218
219    fn with_p3_task_set<R>(&mut self, f: impl FnOnce(&mut Self) -> R) -> R {
220        // Finish our `wasip3_task` by initializing its self-referential pointer,
221        // and then register it for the duration of this function with
222        // `wasip3_task_set`. The previous value of `wasip3_task_set` will get
223        // restored when this function returns.
224        struct ResetTask(*mut cabi::wasip3_task);
225        impl Drop for ResetTask {
226            fn drop(&mut self) {
227                unsafe {
228                    cabi::wasip3_task_set(self.0);
229                }
230            }
231        }
232        let self_raw = self as *mut FutureState;
233        self.wasip3_task.ptr = self_raw.cast();
234        let prev = unsafe { cabi::wasip3_task_set(&mut self.wasip3_task) };
235        let _reset = ResetTask(prev);
236
237        f(self)
238    }
239}
240
241impl Drop for FutureState {
242    fn drop(&mut self) {
243        // If this state has active tasks then they need to be dropped which may
244        // execute arbitrary code. This arbitrary code might require the p3 APIs
245        // for managing waitables, notably around removing them. In this
246        // situation we ensure that the p3 task is set while futures are being
247        // destroyed.
248        if !self.tasks.is_empty() {
249            self.with_p3_task_set(|me| {
250                me.tasks = Default::default();
251            })
252        }
253    }
254}
255
256unsafe extern "C" fn waitable_register(
257    ptr: *mut c_void,
258    waitable: u32,
259    callback: unsafe extern "C" fn(*mut c_void, u32),
260    callback_ptr: *mut c_void,
261) -> *mut c_void {
262    let ptr = ptr.cast::<FutureState>();
263    assert!(!ptr.is_null());
264    (*ptr).add_waitable(waitable);
265    match (*ptr).waitables.insert(waitable, (callback_ptr, callback)) {
266        Some((prev, _)) => prev,
267        None => ptr::null_mut(),
268    }
269}
270
271unsafe extern "C" fn waitable_unregister(ptr: *mut c_void, waitable: u32) -> *mut c_void {
272    let ptr = ptr.cast::<FutureState>();
273    assert!(!ptr.is_null());
274    (*ptr).remove_waitable(waitable);
275    match (*ptr).waitables.remove(&waitable) {
276        Some((prev, _)) => prev,
277        None => ptr::null_mut(),
278    }
279}
280
/// Wake-up flag shared between a `FutureState` and the `Waker` it hands to
/// its futures.
///
/// Waking only records that a wake-up happened. `FutureState::poll` clears
/// the flag before each poll and inspects it afterwards.
#[derive(Default)]
struct FutureWaker(AtomicBool);

impl Wake for FutureWaker {
    fn wake(self: Arc<Self>) {
        self.wake_by_ref()
    }

    fn wake_by_ref(self: &Arc<Self>) {
        let FutureWaker(woken) = &**self;
        woken.store(true, Ordering::Relaxed)
    }
}
293
/// Futures queued by `spawn` that have not yet been handed to a task.
///
/// `FutureState::poll` moves everything queued here into the `tasks` of the
/// task it is polling, right after each call that polls that task's futures.
//
// NOTE(review): being a `static mut`, this relies on the runtime being
// single-threaded; see the TODO about `static mut` at the top of this file.
static mut SPAWNED: Vec<BoxFuture> = Vec::new();
297
// Event codes received by `callback` in `event0`.
const EVENT_NONE: u32 = 0x0;
const EVENT_SUBTASK: u32 = 0x1;
const EVENT_STREAM_READ: u32 = 0x2;
const EVENT_STREAM_WRITE: u32 = 0x3;
const EVENT_FUTURE_READ: u32 = 0x4;
const EVENT_FUTURE_WRITE: u32 = 0x5;
const EVENT_CANCEL: u32 = 0x6;

// Codes returned from `callback` to the host. `CALLBACK_CODE_WAIT` is combined
// with a waitable-set handle shifted left by 4 bits.
const CALLBACK_CODE_EXIT: u32 = 0x0;
const CALLBACK_CODE_YIELD: u32 = 0x1;
const CALLBACK_CODE_WAIT: u32 = 0x2;
const _CALLBACK_CODE_POLL: u32 = 0x3;

// Status codes; presumably consumed by the `subtask` module.
const STATUS_STARTING: u32 = 0x0;
const STATUS_STARTED: u32 = 0x1;
const STATUS_RETURNED: u32 = 0x2;
const STATUS_STARTED_CANCELLED: u32 = 0x3;
const STATUS_RETURNED_CANCELLED: u32 = 0x4;

// Raw stream/future operation results, decoded by `ReturnCode::decode`.
// `BLOCKED` is a standalone sentinel. Any other value carries one of the
// remaining codes in its low 4 bits.
const BLOCKED: u32 = u32::MAX;
const COMPLETED: u32 = 0x0;
const DROPPED: u32 = 0x1;
const CANCELLED: u32 = 0x2;
321
/// Result of a stream/future read or write, decoded from its raw `u32`
/// status word.
#[derive(PartialEq, Debug, Copy, Clone)]
enum ReturnCode {
    /// Not finished: the operation is still blocked.
    Blocked,
    /// Finished, having transferred the given number of items.
    Completed(u32),
    /// The other end was dropped after the given number of items had been
    /// transferred.
    Dropped(u32),
    /// The operation was cancelled after the given number of items had been
    /// transferred.
    Cancelled(u32),
}

impl ReturnCode {
    /// Decodes a raw status word.
    ///
    /// `BLOCKED` is a sentinel on its own. Any other value holds a result
    /// kind in its low 4 bits and an item count in the bits above them.
    ///
    /// # Panics
    ///
    /// Panics if the low 4 bits hold an unrecognized result kind.
    fn decode(val: u32) -> ReturnCode {
        let count = val >> 4;
        match (val, val & 0xf) {
            (BLOCKED, _) => ReturnCode::Blocked,
            (_, COMPLETED) => ReturnCode::Completed(count),
            (_, DROPPED) => ReturnCode::Dropped(count),
            (_, CANCELLED) => ReturnCode::Cancelled(count),
            _ => panic!("unknown return code {val:#x}"),
        }
    }
}
351
352/// Starts execution of the `task` provided, an asynchronous computation.
353///
354/// This is used for async-lifted exports at their definition site. The
355/// representation of the export is `task` and this function is called from the
356/// entrypoint. The code returned here is the same as the callback associated
357/// with this export, and the callback will be used if this task doesn't exit
358/// immediately with its result.
359#[doc(hidden)]
360pub fn start_task(task: impl Future<Output = ()> + 'static) -> i32 {
361    // Allocate a new `FutureState` which will track all state necessary for
362    // our exported task.
363    let state = Box::into_raw(Box::new(FutureState::new(Box::pin(task))));
364
365    // Store our `FutureState` into our context-local-storage slot and then
366    // pretend we got EVENT_NONE to kick off everything.
367    //
368    // SAFETY: we should own `context.set` as we're the root level exported
369    // task, and then `callback` is only invoked when context-local storage is
370    // valid.
371    unsafe {
372        assert!(context_get().is_null());
373        context_set(state.cast());
374        callback(EVENT_NONE, 0, 0) as i32
375    }
376}
377
378/// Handle a progress notification from the host regarding either a call to an
379/// async-lowered import or a stream/future read/write operation.
380///
381/// # Unsafety
382///
383/// This function assumes that `context_get()` returns a `FutureState`.
384#[doc(hidden)]
385pub unsafe fn callback(event0: u32, event1: u32, event2: u32) -> u32 {
386    // Acquire our context-local state, assert it's not-null, and then reset
387    // the state to null while we're running to help prevent any unintended
388    // usage.
389    let state = context_get().cast::<FutureState>();
390    assert!(!state.is_null());
391    unsafe {
392        context_set(ptr::null_mut());
393    }
394
395    // Use `state` to run the `callback` function in the context of our event
396    // codes we received. If the callback decides to exit then we're done with
397    // our future so deallocate it. Otherwise put our future back in
398    // context-local storage and forward the code.
399    unsafe {
400        let (rc, done) = (*state).callback(event0, event1, event2);
401        if done {
402            drop(Box::from_raw(state));
403        } else {
404            context_set(state.cast());
405        }
406        rtdebug!(" => (cb) {rc:#x}");
407        rc
408    }
409}
410
411/// Defer the specified future to be run after the current async-lifted export
412/// task has returned a value.
413///
414/// The task will remain in a running state until all spawned futures have
415/// completed.
416pub fn spawn(future: impl Future<Output = ()> + 'static) {
417    unsafe { SPAWNED.push(Box::pin(future)) }
418}
419
420/// Run the specified future to completion, returning the result.
421///
422/// This uses `waitable-set.wait` to poll for progress on any in-progress calls
423/// to async-lowered imports as necessary.
424// TODO: refactor so `'static` bounds aren't necessary
425pub fn block_on<T: 'static>(future: impl Future<Output = T> + 'static) -> T {
426    let (tx, mut rx) = oneshot::channel();
427    let state = &mut FutureState::new(Box::pin(future.map(move |v| drop(tx.send(v)))) as BoxFuture);
428    let mut event = (EVENT_NONE, 0, 0);
429    loop {
430        match state.callback(event.0, event.1, event.2) {
431            (_, true) => break rx.try_recv().unwrap().unwrap(),
432            (CALLBACK_CODE_YIELD, false) => event = state.waitable_set.as_ref().unwrap().poll(),
433            _ => event = state.waitable_set.as_ref().unwrap().wait(),
434        }
435    }
436}
437
/// Calls the `yield` canonical built-in function (imported as
/// `[thread-yield]`).
///
/// This temporarily yields control to the host, allowing other tasks to make
/// progress. It is a good idea to call this inside a busy loop that would
/// otherwise never yield control to the host.
///
/// Note that this is a blocking function, not an `async` one. It is not an
/// async yield that lets other tasks in this component progress; it blocks the
/// current function until the host gets back around to returning from the
/// yield. Asynchronous functions should probably use [`yield_async`] instead.
///
/// # Return Value
///
/// Returns whether execution should continue after this yield point. `true`
/// means the task was not cancelled and should keep going. `false` means the
/// task was cancelled while suspended here, and the caller should return and
/// exit from the task as soon as possible.
pub fn yield_blocking() -> bool {
    #[cfg(target_arch = "wasm32")]
    #[link(wasm_import_module = "$root")]
    extern "C" {
        #[link_name = "[thread-yield]"]
        fn thread_yield() -> bool;
    }

    #[cfg(not(target_arch = "wasm32"))]
    unsafe fn thread_yield() -> bool {
        unreachable!();
    }

    // The intrinsic reports "was this task cancelled", which is the inverse of
    // this function's "should work continue" result.
    let cancelled = unsafe { thread_yield() };
    !cancelled
}
475
/// The asynchronous counterpart to [`yield_blocking`].
///
/// Instead of blocking the current task, this gives the Rust-level executor a
/// chance to hand control back to the host temporarily. Other Rust-level tasks
/// may therefore also progress during this yield.
///
/// # Return Value
///
/// Nothing is returned. If this component task is cancelled while paused
/// here, the future is dropped and Rust destructors take over cleaning up the
/// task. The only requirement is to `.await` the call.
pub async fn yield_async() {
    /// A future that is pending exactly once, waking itself before reporting
    /// `Pending`.
    #[derive(Default)]
    struct YieldOnce {
        polled: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            if !self.polled {
                // Wake ourselves so the executor treats this as a yield rather
                // than as blocking on some external event.
                self.polled = true;
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Poll::Ready(())
        }
    }

    YieldOnce::default().await
}
512
/// Calls the `backpressure.set` canonical built-in function.
///
/// Passing `true` asks the host to defer any new calls into this component
/// instance until further notice, i.e. until this is called again with
/// `false`. Only available on `wasm32`; on other targets this panics.
#[deprecated = "use backpressure_{inc,dec} instead"]
pub fn backpressure_set(enabled: bool) {
    #[cfg(target_arch = "wasm32")]
    #[link(wasm_import_module = "$root")]
    extern "C" {
        #[link_name = "[backpressure-set]"]
        fn raw_backpressure_set(enabled: i32);
    }

    #[cfg(not(target_arch = "wasm32"))]
    unsafe fn raw_backpressure_set(_: i32) {
        unreachable!();
    }

    // The built-in takes the flag as an `i32`: `1` enables, `0` disables.
    unsafe { raw_backpressure_set(i32::from(enabled)) }
}
534
/// Calls the `backpressure.inc` canonical built-in function.
///
/// Only available on `wasm32`; on other targets this panics.
pub fn backpressure_inc() {
    #[cfg(target_arch = "wasm32")]
    #[link(wasm_import_module = "$root")]
    extern "C" {
        #[link_name = "[backpressure-inc]"]
        fn raw_backpressure_inc();
    }

    #[cfg(not(target_arch = "wasm32"))]
    unsafe fn raw_backpressure_inc() {
        unreachable!();
    }

    unsafe { raw_backpressure_inc() }
}
551
/// Calls the `backpressure.dec` canonical built-in function.
///
/// Only available on `wasm32`; on other targets this panics.
pub fn backpressure_dec() {
    #[cfg(target_arch = "wasm32")]
    #[link(wasm_import_module = "$root")]
    extern "C" {
        #[link_name = "[backpressure-dec]"]
        fn raw_backpressure_dec();
    }

    #[cfg(not(target_arch = "wasm32"))]
    unsafe fn raw_backpressure_dec() {
        unreachable!();
    }

    unsafe { raw_backpressure_dec() }
}
568
/// Reads the pointer stored in context-local storage slot 0, using the
/// `[context-get-0]` built-in.
///
/// Only available on `wasm32`; on other targets this panics.
fn context_get() -> *mut u8 {
    #[cfg(target_arch = "wasm32")]
    #[link(wasm_import_module = "$root")]
    extern "C" {
        #[link_name = "[context-get-0]"]
        fn context_get_0() -> *mut u8;
    }

    #[cfg(not(target_arch = "wasm32"))]
    unsafe fn context_get_0() -> *mut u8 {
        unreachable!()
    }

    unsafe { context_get_0() }
}
584
/// Stores `value` in context-local storage slot 0, using the
/// `[context-set-0]` built-in.
///
/// Only available on `wasm32`; on other targets this panics.
unsafe fn context_set(value: *mut u8) {
    #[cfg(target_arch = "wasm32")]
    #[link(wasm_import_module = "$root")]
    extern "C" {
        #[link_name = "[context-set-0]"]
        fn context_set_0(value: *mut u8);
    }

    #[cfg(not(target_arch = "wasm32"))]
    unsafe fn context_set_0(_: *mut u8) {
        unreachable!()
    }

    unsafe { context_set_0(value) }
}
600
/// Guard that invokes the `[task-cancel]` built-in when dropped, unless it is
/// disarmed first with [`TaskCancelOnDrop::forget`].
#[doc(hidden)]
pub struct TaskCancelOnDrop {
    _priv: (),
}

impl TaskCancelOnDrop {
    /// Creates an armed guard.
    #[doc(hidden)]
    pub fn new() -> TaskCancelOnDrop {
        Self { _priv: () }
    }

    /// Disarms the guard: consumes it without running its `Drop` impl, so
    /// `[task-cancel]` is never called.
    #[doc(hidden)]
    pub fn forget(self) {
        mem::forget(self)
    }
}

impl Drop for TaskCancelOnDrop {
    fn drop(&mut self) {
        #[cfg(target_arch = "wasm32")]
        #[link(wasm_import_module = "[export]$root")]
        extern "C" {
            #[link_name = "[task-cancel]"]
            fn task_cancel();
        }

        #[cfg(not(target_arch = "wasm32"))]
        unsafe fn task_cancel() {
            unreachable!()
        }

        unsafe { task_cancel() }
    }
}