// wit_bindgen/rt/async_support.rs

1#![deny(missing_docs)]
2
3extern crate std;
4use core::sync::atomic::{AtomicU32, Ordering};
5use std::boxed::Box;
6use std::collections::BTreeMap;
7use std::ffi::c_void;
8use std::future::Future;
9use std::mem;
10use std::pin::Pin;
11use std::ptr;
12use std::sync::Arc;
13use std::task::{Context, Poll, Wake, Waker};
14
/// Development-only debug printing for this runtime.
///
/// Expands to a `std::eprintln!` of its arguments guarded by a constant that
/// is `false`, so nothing is printed while the format arguments are still
/// type-checked. No logging crate is used here to keep runtime dependencies
/// down; flip the flag below while developing to see the output.
macro_rules! rtdebug {
    ($($fmt:tt)*) => {{
        // Change this flag to enable debugging output.
        const RTDEBUG_ENABLED: bool = false;
        if RTDEBUG_ENABLED {
            std::eprintln!($($fmt)*);
        }
    }};
}
25
/// Helper macro to deduplicate foreign definitions of wasm functions.
///
/// On wasm targets (`target_family = "wasm"`) the listed functions are
/// declared in a real `unsafe extern "C"` block. Both the block-level
/// attributes (e.g. `#[link(wasm_import_module = ...)]`) and the per-function
/// attributes (e.g. `#[link_name = ...]`) are forwarded unchanged.
///
/// On every other target, each function is instead defined as an `unsafe fn`
/// shim with the same visibility and signature whose body is
/// `unreachable!()`. This keeps native compilation working, but any call
/// fails at runtime.
macro_rules! extern_wasm {
    (
        $(#[$extern_attr:meta])*
        unsafe extern "C" {
            $(
                $(#[$func_attr:meta])*
                $vis:vis fn $func_name:ident ( $($args:tt)* ) $(-> $ret:ty)?;
            )*
        }
    ) => {
        $(
            // Native stand-in. Note that `$func_attr` (e.g. `link_name`) is
            // deliberately not applied here since nothing is linked.
            #[cfg(not(target_family = "wasm"))]
            #[allow(unused)]
            $vis unsafe fn $func_name($($args)*) $(-> $ret)? {
                unreachable!();
            }
        )*

        // Real import declarations, only compiled for wasm.
        #[cfg(target_family = "wasm")]
        $(#[$extern_attr])*
        unsafe extern "C" {
            $(
                $(#[$func_attr])*
                $vis fn $func_name($($args)*) $(-> $ret)?;
            )*
        }
    };
}
59
60mod abi_buffer;
61mod cabi;
62mod error_context;
63mod future_support;
64#[cfg(feature = "inter-task-wakeup")]
65mod inter_task_wakeup;
66mod stream_support;
67mod subtask;
68#[cfg(feature = "inter-task-wakeup")]
69mod unit_stream;
70mod waitable;
71mod waitable_set;
72
73#[cfg(not(feature = "inter-task-wakeup"))]
74use inter_task_wakeup_disabled as inter_task_wakeup;
75#[cfg(not(feature = "inter-task-wakeup"))]
76mod inter_task_wakeup_disabled;
77
78use self::waitable_set::WaitableSet;
79pub use abi_buffer::*;
80pub use error_context::*;
81pub use future_support::*;
82pub use stream_support::*;
83#[doc(hidden)]
84pub use subtask::Subtask;
85#[cfg(feature = "inter-task-wakeup")]
86pub use unit_stream::*;
87
/// A pinned, boxed, type-erased future producing `()` that may borrow data for
/// `'a`; this is the form in which a task's root future is handed to
/// `FutureState::new`.
type BoxFuture<'a> = Pin<Box<dyn Future<Output = ()> + 'a>>;
89
90#[cfg(feature = "async-spawn")]
91mod spawn;
92#[cfg(feature = "async-spawn")]
93pub use spawn::spawn;
94#[cfg(not(feature = "async-spawn"))]
95mod spawn_disabled;
96#[cfg(not(feature = "async-spawn"))]
97use spawn_disabled as spawn;
98
/// Represents a task created by either a call to an async-lifted export or a
/// future run using `block_on` or `start_task`.
///
/// Note that, as for any Rust struct, the fields below are dropped in
/// declaration order after the explicit `Drop` impl for this type has run.
struct FutureState<'a> {
    /// Remaining work to do (if any) before this task can be considered "done".
    ///
    /// Note that we won't tell the host the task is done until this is drained
    /// and `waitables` is empty.
    tasks: spawn::Tasks<'a>,

    /// The waitable set containing waitables created by this task, if any.
    ///
    /// This starts out as `None` and is created lazily the first time a
    /// waitable is joined to it.
    waitable_set: Option<WaitableSet>,

    /// State of all waitables in `waitable_set`, keyed by waitable handle,
    /// along with the `(callback_ptr, callback)` pair registered for each.
    //
    // Note that this is a `BTreeMap` rather than a `HashMap` only because, as
    // of this writing, initializing the default hasher for `HashMap` requires
    // calling `wasi_snapshot_preview1::random_get`, which requires initializing
    // the `wasi_snapshot_preview1` adapter when targeting `wasm32-wasip2` and
    // later, and that's expensive enough that we'd prefer to avoid it for apps
    // which otherwise make no use of the adapter.
    waitables: BTreeMap<u32, (*mut c_void, unsafe extern "C" fn(*mut c_void, u32))>,

    /// Raw structure used to pass to `cabi::wasip3_task_set`.
    ///
    /// Its `ptr` field starts out null and is pointed back at this
    /// `FutureState` each time the task is installed.
    wasip3_task: cabi::wasip3_task,

    /// Rust-level state for the waker, notably the `sleep_state` recording
    /// whether this task has been woken.
    waker: Arc<FutureWaker>,

    /// Clone of `waker` field, but represented as `std::task::Waker`.
    waker_clone: Waker,

    /// State related to supporting inter-task wakeup scenarios.
    inter_task_wakeup: inter_task_wakeup::State,
}
135
impl FutureState<'_> {
    /// Creates the state for a new task whose root future is `future`.
    ///
    /// No waitable set is created here (`waitable_set` starts as `None`), and
    /// `wasip3_task.ptr` starts null until `with_p3_task_set` fills it in.
    fn new(future: BoxFuture<'_>) -> FutureState<'_> {
        let waker = Arc::new(FutureWaker::default());
        FutureState {
            // Both waker fields share the same `FutureWaker` allocation.
            waker_clone: waker.clone().into(),
            waker,
            tasks: spawn::Tasks::new(future),
            waitable_set: None,
            waitables: BTreeMap::new(),
            wasip3_task: cabi::wasip3_task {
                // This pointer is filled in before calling `wasip3_task_set`.
                ptr: ptr::null_mut(),
                version: cabi::WASIP3_TASK_V1,
                waitable_register,
                waitable_unregister,
            },
            inter_task_wakeup: Default::default(),
        }
    }

    /// Returns this task's waitable set, creating it on first use.
    fn get_or_create_waitable_set(&mut self) -> &WaitableSet {
        self.waitable_set.get_or_insert_with(WaitableSet::new)
    }

    /// Joins `waitable` to this task's waitable set, creating the set if
    /// needed.
    ///
    /// This does not touch the `waitables` map; callers update it separately.
    fn add_waitable(&mut self, waitable: u32) {
        self.get_or_create_waitable_set().join(waitable)
    }

    /// Detaches `waitable` through `WaitableSet::remove_waitable_from_all_sets`.
    ///
    /// This does not read or modify `self`. In particular, the `waitables`
    /// map is updated separately by callers.
    fn remove_waitable(&mut self, waitable: u32) {
        WaitableSet::remove_waitable_from_all_sets(waitable)
    }

    /// Returns whether any registered waitables are still outstanding, i.e.
    /// whether the `waitables` map is non-empty.
    fn remaining_work(&self) -> bool {
        !self.waitables.is_empty()
    }

    /// Handles the `event{0,1,2}` event codes and returns the `CallbackCode`
    /// describing what this task wants to do next.
    ///
    /// `EVENT_CANCEL` immediately produces `CallbackCode::Exit`. Any other
    /// known event is first delivered (unless it is `EVENT_NONE`), and then
    /// the task's futures are polled until one of these happens:
    ///
    /// * they all finish: `Exit`, or `Wait` if registered waitables remain;
    /// * they are pending after a wakeup: `Yield`;
    /// * they are blocked: `Wait` on this task's waitable set.
    ///
    /// Unknown event codes hit `unreachable!()`.
    fn callback(&mut self, event0: u32, event1: u32, event2: u32) -> CallbackCode {
        match event0 {
            EVENT_NONE => rtdebug!("EVENT_NONE"),
            EVENT_SUBTASK => rtdebug!("EVENT_SUBTASK({event1:#x}, {event2:#x})"),
            EVENT_STREAM_READ => rtdebug!("EVENT_STREAM_READ({event1:#x}, {event2:#x})"),
            EVENT_STREAM_WRITE => rtdebug!("EVENT_STREAM_WRITE({event1:#x}, {event2:#x})"),
            EVENT_FUTURE_READ => rtdebug!("EVENT_FUTURE_READ({event1:#x}, {event2:#x})"),
            EVENT_FUTURE_WRITE => rtdebug!("EVENT_FUTURE_WRITE({event1:#x}, {event2:#x})"),
            EVENT_CANCEL => {
                rtdebug!("EVENT_CANCEL");

                // Cancellation is mapped to destruction in Rust, so return a
                // code indicating we're done. The caller will then
                // appropriately deallocate this `FutureState` which will
                // transitively run all destructors.
                return CallbackCode::Exit;
            }
            _ => unreachable!(),
        }

        self.with_p3_task_set(|me| {
            // Transition our sleep state to ensure that the inter-task stream
            // isn't used since there's no need to use that here.
            me.waker
                .sleep_state
                .store(SLEEP_STATE_WOKEN, Ordering::Relaxed);

            // With all of our context now configured, deliver the event
            // notification this callback corresponds to.
            //
            // Note that this should happen after `waker.sleep_state` was set
            // to `SLEEP_STATE_WOKEN` above. That ensures a waker woken during
            // delivery won't signal our inter-task stream, since we're
            // already in the process of handling the future.
            if event0 != EVENT_NONE {
                me.deliver_waitable_event(event1, event2)
            }

            // If a read of the inter-task stream is still in progress (i.e.
            // the event in `event{1,2}` wasn't our own wakeup), cancel it,
            // since we're processing the future here anyway.
            me.cancel_inter_task_stream_read();

            loop {
                let mut context = Context::from_waker(&me.waker_clone);

                // On each turn of this loop reset the state to "polling"
                // which clears out any pending wakeup if one was sent. This
                // in theory helps minimize wakeups from previous iterations
                // happening in this iteration.
                me.waker
                    .sleep_state
                    .store(SLEEP_STATE_POLLING, Ordering::Relaxed);

                // Poll our future, seeing if it was able to make progress.
                let poll = me.tasks.poll_next(&mut context);

                match poll {
                    // A future completed, yay! Keep going to see if more have
                    // completed.
                    Poll::Ready(Some(())) => (),

                    // The task list is empty, but there might be remaining work
                    // in terms of waitables through the cabi interface. In this
                    // situation wait for all waitables to be resolved before
                    // signaling that our own task is done.
                    Poll::Ready(None) => {
                        assert!(me.tasks.is_empty());
                        if me.remaining_work() {
                            // A non-empty `waitables` map implies a waitable
                            // was joined, which created `waitable_set`.
                            let waitable = me.waitable_set.as_ref().unwrap().as_raw();
                            break CallbackCode::Wait(waitable);
                        } else {
                            break CallbackCode::Exit;
                        }
                    }

                    // Some future within `self.tasks` is not ready yet. If our
                    // `waker` was signaled then that means this is a yield
                    // operation, otherwise it means we're blocking on
                    // something.
                    Poll::Pending => {
                        assert!(!me.tasks.is_empty());
                        if me.waker.sleep_state.load(Ordering::Relaxed) == SLEEP_STATE_WOKEN {
                            // Before yielding, check (without blocking) whether
                            // a registered waitable already has an event. If so,
                            // deliver it and poll again instead of yielding.
                            if me.remaining_work() {
                                let (event0, event1, event2) =
                                    me.waitable_set.as_ref().unwrap().poll();
                                if event0 != EVENT_NONE {
                                    me.deliver_waitable_event(event1, event2);
                                    continue;
                                }
                            }
                            break CallbackCode::Yield;
                        }

                        // Transition our state to "sleeping" so wakeup
                        // notifications know that they need to signal the
                        // inter-task stream.
                        me.waker
                            .sleep_state
                            .store(SLEEP_STATE_SLEEPING, Ordering::Relaxed);
                        me.read_inter_task_stream();
                        // NOTE(review): this assumes `waitable_set` exists at
                        // this point (presumably via `read_inter_task_stream`
                        // or an earlier registration) — confirm.
                        let waitable = me.waitable_set.as_ref().unwrap().as_raw();
                        break CallbackCode::Wait(waitable);
                    }
                }
            }
        })
    }

    /// Delivers the `code` event for `waitable`.
    ///
    /// The waitable is first detached from its set. If the inter-task wakeup
    /// state claims the event (`consume_waitable_event` returns `true`),
    /// nothing else happens. Otherwise the registration stored in `waitables`
    /// is removed and its callback is invoked with `code`.
    ///
    /// The registration should be present because the map is kept in sync
    /// with the waitable set; this panics if it is not.
    fn deliver_waitable_event(&mut self, waitable: u32, code: u32) {
        self.remove_waitable(waitable);

        if self
            .inter_task_wakeup
            .consume_waitable_event(waitable, code)
        {
            return;
        }

        let (ptr, callback) = self.waitables.remove(&waitable).unwrap();
        unsafe {
            callback(ptr, code);
        }
    }

    /// Runs `f` with this task installed as the current p3 task.
    ///
    /// `wasip3_task.ptr` is pointed back at `self` and the task is registered
    /// with `cabi::wasip3_task_set`. The previously installed task is
    /// restored when this returns, including during unwinding, since the
    /// restore lives in a `Drop` impl.
    fn with_p3_task_set<R>(&mut self, f: impl FnOnce(&mut Self) -> R) -> R {
        // Finish our `wasip3_task` by initializing its self-referential pointer,
        // and then register it for the duration of this function with
        // `wasip3_task_set`. The previous value of `wasip3_task_set` will get
        // restored when this function returns.
        struct ResetTask(*mut cabi::wasip3_task);
        impl Drop for ResetTask {
            fn drop(&mut self) {
                unsafe {
                    cabi::wasip3_task_set(self.0);
                }
            }
        }
        let self_raw = self as *mut FutureState<'_>;
        self.wasip3_task.ptr = self_raw.cast();
        let prev = unsafe { cabi::wasip3_task_set(&mut self.wasip3_task) };
        let _reset = ResetTask(prev);

        f(self)
    }
}
323
324impl Drop for FutureState<'_> {
325    fn drop(&mut self) {
326        // If there's an active read of the inter-task stream, go ahead and
327        // cancel it, since we're about to drop the stream anyway.
328        self.cancel_inter_task_stream_read();
329
330        // If this state has active tasks then they need to be dropped which may
331        // execute arbitrary code. This arbitrary code might require the p3 APIs
332        // for managing waitables, notably around removing them. In this
333        // situation we ensure that the p3 task is set while futures are being
334        // destroyed.
335        if !self.tasks.is_empty() {
336            self.with_p3_task_set(|me| {
337                me.tasks = Default::default();
338            })
339        }
340    }
341}
342
343unsafe extern "C" fn waitable_register(
344    ptr: *mut c_void,
345    waitable: u32,
346    callback: unsafe extern "C" fn(*mut c_void, u32),
347    callback_ptr: *mut c_void,
348) -> *mut c_void {
349    let ptr = ptr.cast::<FutureState<'static>>();
350    assert!(!ptr.is_null());
351    unsafe {
352        (*ptr).add_waitable(waitable);
353        match (*ptr).waitables.insert(waitable, (callback_ptr, callback)) {
354            Some((prev, _)) => prev,
355            None => ptr::null_mut(),
356        }
357    }
358}
359
360unsafe extern "C" fn waitable_unregister(ptr: *mut c_void, waitable: u32) -> *mut c_void {
361    let ptr = ptr.cast::<FutureState<'static>>();
362    assert!(!ptr.is_null());
363    unsafe {
364        (*ptr).remove_waitable(waitable);
365        match (*ptr).waitables.remove(&waitable) {
366            Some((prev, _)) => prev,
367            None => ptr::null_mut(),
368        }
369    }
370}
371
/// The task is inside its poll loop right now; a wakeup needs no extra work.
const SLEEP_STATE_POLLING: u32 = 0x0;
/// A wakeup is already recorded for this task; further wakeups are no-ops.
const SLEEP_STATE_WOKEN: u32 = 0x1;
/// The task is neither being polled nor already woken.
///
/// A wakeup arriving in this state must signal the inter-task stream.
const SLEEP_STATE_SLEEPING: u32 = 0x2;
380
/// Waker state shared, via `Arc`, between a `FutureState` and every `Waker`
/// created from it.
///
/// `Default` leaves `sleep_state` at `0`, i.e. `SLEEP_STATE_POLLING`.
#[derive(Default)]
struct FutureWaker {
    /// One of `SLEEP_STATE_*` indicating the current status.
    sleep_state: AtomicU32,
    /// Inter-task wakeup state; `wake_by_ref` calls its `wake` when a wakeup
    /// arrives while `sleep_state` is `SLEEP_STATE_SLEEPING`.
    inter_task_stream: inter_task_wakeup::WakerState,
}
387
388impl Wake for FutureWaker {
389    fn wake(self: Arc<Self>) {
390        Self::wake_by_ref(&self)
391    }
392
393    fn wake_by_ref(self: &Arc<Self>) {
394        match self.sleep_state.swap(SLEEP_STATE_WOKEN, Ordering::Relaxed) {
395            // If this future was currently being polled, or if someone else
396            // already woke it up, then there's nothing to do.
397            SLEEP_STATE_POLLING | SLEEP_STATE_WOKEN => {}
398
399            // If this future is sleeping, however, then this is a cross-task
400            // wakeup meaning that we need to write to its wakeup stream.
401            other => {
402                assert_eq!(other, SLEEP_STATE_SLEEPING);
403                self.inter_task_stream.wake();
404            }
405        }
406    }
407}
408
/// Event code: no event; used to kick off or re-poll a task.
const EVENT_NONE: u32 = 0;
/// Event code: progress on a subtask.
const EVENT_SUBTASK: u32 = 1;
/// Event code: a stream read finished.
const EVENT_STREAM_READ: u32 = 2;
/// Event code: a stream write finished.
const EVENT_STREAM_WRITE: u32 = 3;
/// Event code: a future read finished.
const EVENT_FUTURE_READ: u32 = 4;
/// Event code: a future write finished.
const EVENT_FUTURE_WRITE: u32 = 5;
/// Event code: the current task was cancelled.
const EVENT_CANCEL: u32 = 6;
416
/// What a task wants to do next, as reported by `FutureState::callback`.
#[derive(PartialEq, Debug)]
enum CallbackCode {
    /// The task is finished.
    Exit,
    /// The task wants to yield and then be called back.
    Yield,
    /// The task is blocked on the waitable set with this raw handle.
    Wait(u32),
}

impl CallbackCode {
    /// Packs this code into its `u32` wire form.
    ///
    /// The code goes in the low 4 bits; for `Wait`, the waitable set handle
    /// is shifted above it.
    fn encode(self) -> u32 {
        let (code, payload) = match self {
            CallbackCode::Exit => (0, 0),
            CallbackCode::Yield => (1, 0),
            CallbackCode::Wait(waitable) => (2, waitable),
        };
        code | (payload << 4)
    }
}
433
/// Subtask status: not yet started.
const STATUS_STARTING: u32 = 0;
/// Subtask status: started but not yet returned.
const STATUS_STARTED: u32 = 1;
/// Subtask status: returned.
const STATUS_RETURNED: u32 = 2;
/// Subtask status: cancelled before it started.
const STATUS_STARTED_CANCELLED: u32 = 3;
/// Subtask status: cancelled before it returned.
const STATUS_RETURNED_CANCELLED: u32 = 4;
439
/// Raw return value meaning "blocked, not yet complete".
const BLOCKED: u32 = 0xffff_ffff;
/// Low-nibble status tag: the operation completed.
const COMPLETED: u32 = 0x0;
/// Low-nibble status tag: the other end was dropped.
const DROPPED: u32 = 0x1;
/// Low-nibble status tag: the operation was cancelled.
const CANCELLED: u32 = 0x2;

/// Return code of stream/future operations.
#[derive(PartialEq, Debug, Copy, Clone)]
enum ReturnCode {
    /// The operation is blocked and has not completed.
    Blocked,
    /// The operation completed with the specified number of items.
    Completed(u32),
    /// The other end is dropped, but before that the specified number of items
    /// were transferred.
    Dropped(u32),
    /// The operation was cancelled, but before that the specified number of
    /// items were transferred.
    Cancelled(u32),
}

impl ReturnCode {
    /// Decodes a raw stream/future operation result.
    ///
    /// `BLOCKED` maps to `ReturnCode::Blocked`. Any other value carries a
    /// status tag in its low 4 bits and the item count in the bits above.
    ///
    /// # Panics
    ///
    /// Panics if the status tag is not one of the known tags.
    fn decode(val: u32) -> ReturnCode {
        let (tag, amt) = (val & 0xf, val >> 4);
        match (val, tag) {
            (BLOCKED, _) => ReturnCode::Blocked,
            (_, COMPLETED) => ReturnCode::Completed(amt),
            (_, DROPPED) => ReturnCode::Dropped(amt),
            (_, CANCELLED) => ReturnCode::Cancelled(amt),
            _ => panic!("unknown return code {val:#x}"),
        }
    }
}
474
475/// Starts execution of the `task` provided, an asynchronous computation.
476///
477/// This is used for async-lifted exports at their definition site. The
478/// representation of the export is `task` and this function is called from the
479/// entrypoint. The code returned here is the same as the callback associated
480/// with this export, and the callback will be used if this task doesn't exit
481/// immediately with its result.
482#[doc(hidden)]
483pub fn start_task(task: impl Future<Output = ()> + 'static) -> i32 {
484    // Allocate a new `FutureState` which will track all state necessary for
485    // our exported task.
486    let state = Box::into_raw(Box::new(FutureState::new(Box::pin(task))));
487
488    // Store our `FutureState` into our context-local-storage slot and then
489    // pretend we got EVENT_NONE to kick off everything.
490    //
491    // SAFETY: we should own `context.set` as we're the root level exported
492    // task, and then `callback` is only invoked when context-local storage is
493    // valid.
494    unsafe {
495        assert!(context_get().is_null());
496        context_set(state.cast());
497        callback(EVENT_NONE, 0, 0) as i32
498    }
499}
500
501/// Handle a progress notification from the host regarding either a call to an
502/// async-lowered import or a stream/future read/write operation.
503///
504/// # Unsafety
505///
506/// This function assumes that `context_get()` returns a `FutureState`.
507#[doc(hidden)]
508pub unsafe fn callback(event0: u32, event1: u32, event2: u32) -> u32 {
509    // Acquire our context-local state, assert it's not-null, and then reset
510    // the state to null while we're running to help prevent any unintended
511    // usage.
512    let state = context_get().cast::<FutureState<'static>>();
513    assert!(!state.is_null());
514    unsafe {
515        context_set(ptr::null_mut());
516    }
517
518    // Use `state` to run the `callback` function in the context of our event
519    // codes we received. If the callback decides to exit then we're done with
520    // our future so deallocate it. Otherwise put our future back in
521    // context-local storage and forward the code.
522    unsafe {
523        let rc = (*state).callback(event0, event1, event2);
524        if rc == CallbackCode::Exit {
525            drop(Box::from_raw(state));
526        } else {
527            context_set(state.cast());
528        }
529        rtdebug!(" => (cb) {rc:?}");
530        rc.encode()
531    }
532}
533
534/// Run the specified future to completion, returning the result.
535///
536/// This uses `waitable-set.wait` to poll for progress on any in-progress calls
537/// to async-lowered imports as necessary.
538// TODO: refactor so `'static` bounds aren't necessary
539pub fn block_on<T: 'static>(future: impl Future<Output = T>) -> T {
540    let mut result = None;
541    let mut state = FutureState::new(Box::pin(async {
542        result = Some(future.await);
543    }));
544    let mut event = (EVENT_NONE, 0, 0);
545    loop {
546        match state.callback(event.0, event.1, event.2) {
547            CallbackCode::Exit => {
548                drop(state);
549                break result.unwrap();
550            }
551            CallbackCode::Yield => event = state.waitable_set.as_ref().unwrap().poll(),
552            CallbackCode::Wait(_) => event = state.waitable_set.as_ref().unwrap().wait(),
553        }
554    }
555}
556
557/// Call the `yield` canonical built-in function.
558///
559/// This yields control to the host temporarily, allowing other tasks to make
560/// progress. It's a good idea to call this inside a busy loop which does not
561/// otherwise ever yield control the host.
562///
563/// Note that this function is a blocking function, not an `async` function.
564/// That means that this is not an async yield which allows other tasks in this
565/// component to progress, but instead this will block the current function
566/// until the host gets back around to returning from this yield. Asynchronous
567/// functions should probably use [`yield_async`] instead.
568///
569/// # Return Value
570///
571/// This function returns a `bool` which indicates whether execution should
572/// continue after this yield point. A return value of `true` means that the
573/// task was not cancelled and execution should continue. A return value of
574/// `false`, however, means that the task was cancelled while it was suspended
575/// at this yield point. The caller should return back and exit from the task
576/// ASAP in this situation.
577pub fn yield_blocking() -> bool {
578    extern_wasm! {
579        #[link(wasm_import_module = "$root")]
580        unsafe extern "C" {
581            #[link_name = "[thread-yield]"]
582            fn yield_() -> bool;
583        }
584    }
585
586    // Note that the return value from the raw intrinsic is inverted, the
587    // canonical ABI returns "did this task get cancelled" while this function
588    // works as "should work continue going".
589    unsafe { !yield_() }
590}
591
/// The asynchronous counterpart to [`yield_blocking`].
///
/// Rather than blocking the current task, this gives the Rust-level executor
/// a chance to hand control back to the host temporarily. Other Rust-level
/// tasks may make progress during the yield.
///
/// # Return Value
///
/// Nothing is returned, unlike [`yield_blocking`]. If the component task is
/// cancelled while paused here, the future is dropped and Rust destructors
/// clean up the task. Callers only need to make sure the call is `.await`ed.
pub async fn yield_async() {
    /// Future that wakes itself and returns `Pending` exactly once, then
    /// completes on the next poll.
    struct YieldOnce {
        done: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            if !self.done {
                self.done = true;
                // Request an immediate re-poll so this acts as a yield
                // rather than a block.
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Poll::Ready(())
        }
    }

    YieldOnce { done: false }.await;
}
628
629/// Call the `backpressure.inc` canonical built-in function.
630pub fn backpressure_inc() {
631    extern_wasm! {
632        #[link(wasm_import_module = "$root")]
633        unsafe extern "C" {
634            #[link_name = "[backpressure-inc]"]
635            fn backpressure_inc();
636        }
637    }
638
639    unsafe { backpressure_inc() }
640}
641
642/// Call the `backpressure.dec` canonical built-in function.
643pub fn backpressure_dec() {
644    extern_wasm! {
645        #[link(wasm_import_module = "$root")]
646        unsafe extern "C" {
647            #[link_name = "[backpressure-dec]"]
648            fn backpressure_dec();
649        }
650    }
651
652    unsafe { backpressure_dec() }
653}
654
655fn context_get() -> *mut u8 {
656    extern_wasm! {
657        #[link(wasm_import_module = "$root")]
658        unsafe extern "C" {
659            #[link_name = "[context-get-0]"]
660            fn get() -> *mut u8;
661        }
662    }
663
664    unsafe { get() }
665}
666
667unsafe fn context_set(value: *mut u8) {
668    extern_wasm! {
669        #[link(wasm_import_module = "$root")]
670        unsafe extern "C" {
671            #[link_name = "[context-set-0]"]
672            fn set(value: *mut u8);
673        }
674    }
675
676    unsafe { set(value) }
677}
678
679#[doc(hidden)]
680pub struct TaskCancelOnDrop {
681    _priv: (),
682}
683
684impl TaskCancelOnDrop {
685    #[doc(hidden)]
686    pub fn new() -> TaskCancelOnDrop {
687        TaskCancelOnDrop { _priv: () }
688    }
689
690    #[doc(hidden)]
691    pub fn forget(self) {
692        mem::forget(self);
693    }
694}
695
696impl Drop for TaskCancelOnDrop {
697    fn drop(&mut self) {
698        extern_wasm! {
699            #[link(wasm_import_module = "[export]$root")]
700            unsafe extern "C" {
701                #[link_name = "[task-cancel]"]
702                fn cancel();
703            }
704        }
705
706        unsafe { cancel() }
707    }
708}