wit_bindgen/rt/
async_support.rs

1#![deny(missing_docs)]
2
3extern crate std;
4use core::sync::atomic::{AtomicBool, Ordering};
5use std::boxed::Box;
6use std::collections::BTreeMap;
7use std::ffi::c_void;
8use std::future::Future;
9use std::mem;
10use std::pin::Pin;
11use std::ptr;
12use std::sync::Arc;
13use std::task::{Context, Poll, Wake, Waker};
14
// Development-only tracing for this module.
//
// The macro forwards its arguments to `std::eprintln!`, but only behind a
// compile-time switch that is currently off. A crate such as `log` is
// deliberately not used here to keep runtime dependencies down. The format
// arguments are still type-checked at every call site even while disabled.
macro_rules! rtdebug {
    ($($args:tt)*) => {{
        // Flip this to `true` to see the runtime's debug output on stderr.
        const RTDEBUG_ENABLED: bool = false;
        if RTDEBUG_ENABLED {
            std::eprintln!($($args)*);
        }
    }};
}
26
27mod abi_buffer;
28mod cabi;
29mod error_context;
30mod future_support;
31mod stream_support;
32mod subtask;
33mod waitable;
34mod waitable_set;
35
36use self::waitable_set::WaitableSet;
37pub use abi_buffer::*;
38pub use error_context::*;
39pub use future_support::*;
40pub use stream_support::*;
41#[doc(hidden)]
42pub use subtask::Subtask;
43
44type BoxFuture<'a> = Pin<Box<dyn Future<Output = ()> + 'a>>;
45
46#[cfg(feature = "async-spawn")]
47mod spawn;
48#[cfg(feature = "async-spawn")]
49pub use spawn::spawn;
50#[cfg(not(feature = "async-spawn"))]
51mod spawn_disabled;
52#[cfg(not(feature = "async-spawn"))]
53use spawn_disabled as spawn;
54
55/// Represents a task created by either a call to an async-lifted export or a
56/// future run using `block_on` or `start_task`.
57struct FutureState<'a> {
58    /// Remaining work to do (if any) before this task can be considered "done".
59    ///
60    /// Note that we won't tell the host the task is done until this is drained
61    /// and `waitables` is empty.
62    tasks: spawn::Tasks<'a>,
63
64    /// The waitable set containing waitables created by this task, if any.
65    waitable_set: Option<WaitableSet>,
66
67    /// State of all waitables in `waitable_set`, and the ptr/callback they're
68    /// associated with.
69    //
70    // Note that this is a `BTreeMap` rather than a `HashMap` only because, as
71    // of this writing, initializing the default hasher for `HashMap` requires
72    // calling `wasi_snapshot_preview1:random_get`, which requires initializing
73    // the `wasi_snapshot_preview1` adapter when targeting `wasm32-wasip2` and
74    // later, and that's expensive enough that we'd prefer to avoid it for apps
75    // which otherwise make no use of the adapter.
76    waitables: BTreeMap<u32, (*mut c_void, unsafe extern "C" fn(*mut c_void, u32))>,
77
78    /// Raw structure used to pass to `cabi::wasip3_task_set`
79    wasip3_task: cabi::wasip3_task,
80
81    /// Rust-level state for the waker, notably a bool as to whether this has
82    /// been woken.
83    waker: Arc<FutureWaker>,
84
85    /// Clone of `waker` field, but represented as `std::task::Waker`.
86    waker_clone: Waker,
87}
88
89impl FutureState<'_> {
90    fn new(future: BoxFuture<'_>) -> FutureState<'_> {
91        let waker = Arc::new(FutureWaker::default());
92        FutureState {
93            waker_clone: waker.clone().into(),
94            waker,
95            tasks: spawn::Tasks::new(future),
96            waitable_set: None,
97            waitables: BTreeMap::new(),
98            wasip3_task: cabi::wasip3_task {
99                // This pointer is filled in before calling `wasip3_task_set`.
100                ptr: ptr::null_mut(),
101                version: cabi::WASIP3_TASK_V1,
102                waitable_register,
103                waitable_unregister,
104            },
105        }
106    }
107
108    fn get_or_create_waitable_set(&mut self) -> &WaitableSet {
109        self.waitable_set.get_or_insert_with(WaitableSet::new)
110    }
111
112    fn add_waitable(&mut self, waitable: u32) {
113        self.get_or_create_waitable_set().join(waitable)
114    }
115
116    fn remove_waitable(&mut self, waitable: u32) {
117        WaitableSet::remove_waitable_from_all_sets(waitable)
118    }
119
120    fn remaining_work(&self) -> bool {
121        !self.waitables.is_empty()
122    }
123
124    /// Handles the `event{0,1,2}` event codes and returns a corresponding
125    /// return code along with a flag whether this future is "done" or not.
126    fn callback(&mut self, event0: u32, event1: u32, event2: u32) -> (u32, bool) {
127        match event0 {
128            EVENT_NONE => rtdebug!("EVENT_NONE"),
129            EVENT_SUBTASK => rtdebug!("EVENT_SUBTASK({event1:#x}, {event2:#x})"),
130            EVENT_STREAM_READ => rtdebug!("EVENT_STREAM_READ({event1:#x}, {event2:#x})"),
131            EVENT_STREAM_WRITE => rtdebug!("EVENT_STREAM_WRITE({event1:#x}, {event2:#x})"),
132            EVENT_FUTURE_READ => rtdebug!("EVENT_FUTURE_READ({event1:#x}, {event2:#x})"),
133            EVENT_FUTURE_WRITE => rtdebug!("EVENT_FUTURE_WRITE({event1:#x}, {event2:#x})"),
134            EVENT_CANCEL => {
135                rtdebug!("EVENT_CANCEL");
136
137                // Cancellation is mapped to destruction in Rust, so return a
138                // code/bool indicating we're done. The caller will then
139                // appropriately deallocate this `FutureState` which will
140                // transitively run all destructors.
141                return (CALLBACK_CODE_EXIT, true);
142            }
143            _ => unreachable!(),
144        }
145        if event0 != EVENT_NONE {
146            self.deliver_waitable_event(event1, event2)
147        }
148
149        self.poll()
150    }
151
152    /// Deliver the `code` event to the `waitable` store within our map. This
153    /// waitable should be present because it's part of the waitable set which
154    /// is kept in-sync with our map.
155    fn deliver_waitable_event(&mut self, waitable: u32, code: u32) {
156        self.remove_waitable(waitable);
157        let (ptr, callback) = self.waitables.remove(&waitable).unwrap();
158        unsafe {
159            callback(ptr, code);
160        }
161    }
162
163    /// Poll this task until it either completes or can't make immediate
164    /// progress.
165    ///
166    /// Returns the code representing what happened along with a boolean as to
167    /// whether this execution is done.
168    fn poll(&mut self) -> (u32, bool) {
169        self.with_p3_task_set(|me| {
170            let mut context = Context::from_waker(&me.waker_clone);
171
172            loop {
173                // Reset the waker before polling to clear out any pending
174                // notification, if any.
175                me.waker.0.store(false, Ordering::Relaxed);
176
177                // Poll our future, seeing if it was able to make progress.
178                let poll = me.tasks.poll_next(&mut context);
179
180                match poll {
181                    // A future completed, yay! Keep going to see if more have
182                    // completed.
183                    Poll::Ready(Some(())) => (),
184
185                    // The task list is empty, but there might be remaining work
186                    // in terms of waitables through the cabi interface. In this
187                    // situation wait for all waitables to be resolved before
188                    // signaling that our own task is done.
189                    Poll::Ready(None) => {
190                        assert!(me.tasks.is_empty());
191                        if me.remaining_work() {
192                            let waitable = me.waitable_set.as_ref().unwrap().as_raw();
193                            break (CALLBACK_CODE_WAIT | (waitable << 4), false);
194                        } else {
195                            break (CALLBACK_CODE_EXIT, true);
196                        }
197                    }
198
199                    // Some future within `self.tasks` is not ready yet. If our
200                    // `waker` was signaled then that means this is a yield
201                    // operation, otherwise it means we're blocking on
202                    // something.
203                    Poll::Pending => {
204                        assert!(!me.tasks.is_empty());
205                        if me.waker.0.load(Ordering::Relaxed) {
206                            break (CALLBACK_CODE_YIELD, false);
207                        }
208
209                        assert!(me.remaining_work());
210                        let waitable = me.waitable_set.as_ref().unwrap().as_raw();
211                        break (CALLBACK_CODE_WAIT | (waitable << 4), false);
212                    }
213                }
214            }
215        })
216    }
217
218    fn with_p3_task_set<R>(&mut self, f: impl FnOnce(&mut Self) -> R) -> R {
219        // Finish our `wasip3_task` by initializing its self-referential pointer,
220        // and then register it for the duration of this function with
221        // `wasip3_task_set`. The previous value of `wasip3_task_set` will get
222        // restored when this function returns.
223        struct ResetTask(*mut cabi::wasip3_task);
224        impl Drop for ResetTask {
225            fn drop(&mut self) {
226                unsafe {
227                    cabi::wasip3_task_set(self.0);
228                }
229            }
230        }
231        let self_raw = self as *mut FutureState<'_>;
232        self.wasip3_task.ptr = self_raw.cast();
233        let prev = unsafe { cabi::wasip3_task_set(&mut self.wasip3_task) };
234        let _reset = ResetTask(prev);
235
236        f(self)
237    }
238}
239
240impl Drop for FutureState<'_> {
241    fn drop(&mut self) {
242        // If this state has active tasks then they need to be dropped which may
243        // execute arbitrary code. This arbitrary code might require the p3 APIs
244        // for managing waitables, notably around removing them. In this
245        // situation we ensure that the p3 task is set while futures are being
246        // destroyed.
247        if !self.tasks.is_empty() {
248            self.with_p3_task_set(|me| {
249                me.tasks = Default::default();
250            })
251        }
252    }
253}
254
255unsafe extern "C" fn waitable_register(
256    ptr: *mut c_void,
257    waitable: u32,
258    callback: unsafe extern "C" fn(*mut c_void, u32),
259    callback_ptr: *mut c_void,
260) -> *mut c_void {
261    let ptr = ptr.cast::<FutureState<'static>>();
262    assert!(!ptr.is_null());
263    (*ptr).add_waitable(waitable);
264    match (*ptr).waitables.insert(waitable, (callback_ptr, callback)) {
265        Some((prev, _)) => prev,
266        None => ptr::null_mut(),
267    }
268}
269
270unsafe extern "C" fn waitable_unregister(ptr: *mut c_void, waitable: u32) -> *mut c_void {
271    let ptr = ptr.cast::<FutureState<'static>>();
272    assert!(!ptr.is_null());
273    (*ptr).remove_waitable(waitable);
274    match (*ptr).waitables.remove(&waitable) {
275        Some((prev, _)) => prev,
276        None => ptr::null_mut(),
277    }
278}
279
/// Waker state for a task: a single flag that is set whenever the task's
/// `Waker` is invoked. `FutureState::poll` clears the flag before each poll
/// and reads it afterwards to tell a yield apart from a genuine block.
#[derive(Default)]
struct FutureWaker(AtomicBool);

impl Wake for FutureWaker {
    /// Consuming wake; identical in effect to `wake_by_ref`.
    fn wake(self: Arc<Self>) {
        Wake::wake_by_ref(&self)
    }

    /// Records that a wake-up was requested.
    fn wake_by_ref(self: &Arc<Self>) {
        let FutureWaker(woken) = &**self;
        woken.store(true, Ordering::Relaxed)
    }
}
292
// Event kinds: the `event0` value that `FutureState::callback` dispatches on.

/// No event; just (re-)poll the task.
const EVENT_NONE: u32 = 0;
/// Event for a subtask waitable.
const EVENT_SUBTASK: u32 = 1;
/// Event for a stream read.
const EVENT_STREAM_READ: u32 = 2;
/// Event for a stream write.
const EVENT_STREAM_WRITE: u32 = 3;
/// Event for a future read.
const EVENT_FUTURE_READ: u32 = 4;
/// Event for a future write.
const EVENT_FUTURE_WRITE: u32 = 5;
/// The task was cancelled; `FutureState::callback` reports it as done.
const EVENT_CANCEL: u32 = 6;

// Callback return codes: the low four bits of the value a callback returns.

/// The task is finished.
const CALLBACK_CODE_EXIT: u32 = 0;
/// The task was woken while pending and wants to be polled again.
const CALLBACK_CODE_YIELD: u32 = 1;
/// The task is blocked on the waitable set encoded in the upper bits.
const CALLBACK_CODE_WAIT: u32 = 2;
/// Currently unused in this file (hence the leading underscore).
const _CALLBACK_CODE_POLL: u32 = 3;

// Status codes. These are not referenced in the part of the file shown here.
// NOTE(review): presumably subtask states consumed by the `subtask` module —
// confirm.
const STATUS_STARTING: u32 = 0;
const STATUS_STARTED: u32 = 1;
const STATUS_RETURNED: u32 = 2;
const STATUS_STARTED_CANCELLED: u32 = 3;
const STATUS_RETURNED_CANCELLED: u32 = 4;
311
/// Raw value meaning "the operation is blocked"; it occupies the whole word.
const BLOCKED: u32 = 0xffff_ffff;
/// Low-nibble tag: the operation completed.
const COMPLETED: u32 = 0x0;
/// Low-nibble tag: the other end was dropped.
const DROPPED: u32 = 0x1;
/// Low-nibble tag: the operation was cancelled.
const CANCELLED: u32 = 0x2;

/// Decoded result of a stream/future operation.
#[derive(PartialEq, Debug, Copy, Clone)]
enum ReturnCode {
    /// The operation has not completed; it is blocked.
    Blocked,
    /// The operation finished, transferring the given number of items.
    Completed(u32),
    /// The other end was dropped, after the given number of items had been
    /// transferred.
    Dropped(u32),
    /// The operation was cancelled, after the given number of items had been
    /// transferred.
    Cancelled(u32),
}

impl ReturnCode {
    /// Decodes a raw return value.
    ///
    /// `BLOCKED` is matched against the entire value. Otherwise the low four
    /// bits select the variant and the remaining upper bits are the item
    /// count.
    ///
    /// # Panics
    ///
    /// Panics if the low four bits are not a known tag.
    fn decode(val: u32) -> ReturnCode {
        let items = val >> 4;
        match (val, val & 0xf) {
            (BLOCKED, _) => ReturnCode::Blocked,
            (_, COMPLETED) => ReturnCode::Completed(items),
            (_, DROPPED) => ReturnCode::Dropped(items),
            (_, CANCELLED) => ReturnCode::Cancelled(items),
            _ => panic!("unknown return code {val:#x}"),
        }
    }
}
346
347/// Starts execution of the `task` provided, an asynchronous computation.
348///
349/// This is used for async-lifted exports at their definition site. The
350/// representation of the export is `task` and this function is called from the
351/// entrypoint. The code returned here is the same as the callback associated
352/// with this export, and the callback will be used if this task doesn't exit
353/// immediately with its result.
354#[doc(hidden)]
355pub fn start_task(task: impl Future<Output = ()> + 'static) -> i32 {
356    // Allocate a new `FutureState` which will track all state necessary for
357    // our exported task.
358    let state = Box::into_raw(Box::new(FutureState::new(Box::pin(task))));
359
360    // Store our `FutureState` into our context-local-storage slot and then
361    // pretend we got EVENT_NONE to kick off everything.
362    //
363    // SAFETY: we should own `context.set` as we're the root level exported
364    // task, and then `callback` is only invoked when context-local storage is
365    // valid.
366    unsafe {
367        assert!(context_get().is_null());
368        context_set(state.cast());
369        callback(EVENT_NONE, 0, 0) as i32
370    }
371}
372
373/// Handle a progress notification from the host regarding either a call to an
374/// async-lowered import or a stream/future read/write operation.
375///
376/// # Unsafety
377///
378/// This function assumes that `context_get()` returns a `FutureState`.
379#[doc(hidden)]
380pub unsafe fn callback(event0: u32, event1: u32, event2: u32) -> u32 {
381    // Acquire our context-local state, assert it's not-null, and then reset
382    // the state to null while we're running to help prevent any unintended
383    // usage.
384    let state = context_get().cast::<FutureState<'static>>();
385    assert!(!state.is_null());
386    unsafe {
387        context_set(ptr::null_mut());
388    }
389
390    // Use `state` to run the `callback` function in the context of our event
391    // codes we received. If the callback decides to exit then we're done with
392    // our future so deallocate it. Otherwise put our future back in
393    // context-local storage and forward the code.
394    unsafe {
395        let (rc, done) = (*state).callback(event0, event1, event2);
396        if done {
397            drop(Box::from_raw(state));
398        } else {
399            context_set(state.cast());
400        }
401        rtdebug!(" => (cb) {rc:#x}");
402        rc
403    }
404}
405
406/// Run the specified future to completion, returning the result.
407///
408/// This uses `waitable-set.wait` to poll for progress on any in-progress calls
409/// to async-lowered imports as necessary.
410// TODO: refactor so `'static` bounds aren't necessary
411pub fn block_on<T: 'static>(future: impl Future<Output = T>) -> T {
412    let mut result = None;
413    let mut state = FutureState::new(Box::pin(async {
414        result = Some(future.await);
415    }));
416    let mut event = (EVENT_NONE, 0, 0);
417    loop {
418        match state.callback(event.0, event.1, event.2) {
419            (_, true) => {
420                drop(state);
421                break result.unwrap();
422            }
423            (CALLBACK_CODE_YIELD, false) => event = state.waitable_set.as_ref().unwrap().poll(),
424            _ => event = state.waitable_set.as_ref().unwrap().wait(),
425        }
426    }
427}
428
/// Call the `yield` canonical built-in function.
///
/// Control is handed to the host for a while so that other tasks can make
/// progress. Calling this inside a busy loop that would otherwise never
/// yield to the host is a good idea.
///
/// This is a blocking function, not an `async` one. It is therefore not an
/// async yield that lets other tasks in this component run; the current
/// function is blocked until the host returns from the yield. Asynchronous
/// functions should probably use [`yield_async`] instead.
///
/// # Return Value
///
/// The returned `bool` says whether execution should carry on after this
/// yield point. `true` means the task was not cancelled and should continue.
/// `false` means the task was cancelled while suspended here; the caller
/// should return and exit from the task ASAP in that case.
pub fn yield_blocking() -> bool {
    #[cfg(not(target_arch = "wasm32"))]
    unsafe fn raw_yield() -> bool {
        unreachable!();
    }

    #[cfg(target_arch = "wasm32")]
    #[link(wasm_import_module = "$root")]
    extern "C" {
        #[link_name = "[thread-yield]"]
        fn raw_yield() -> bool;
    }

    // The raw intrinsic answers "did this task get cancelled", while this
    // function answers "should work continue", so the value is inverted.
    let cancelled = unsafe { raw_yield() };
    !cancelled
}
466
/// The asynchronous counterpart to [`yield_blocking`].
///
/// Rather than blocking the current task, this lets the Rust-level executor
/// hand control back to the host for a moment, which means other Rust-level
/// tasks may get to run during the yield as well.
///
/// # Return Value
///
/// Unlike [`yield_blocking`], nothing is returned. If this component task is
/// cancelled while paused at the yield point, the future is dropped and a
/// Rust-level destructor takes over to clean up the task. All a caller has to
/// do is make sure the call is `.await`ed.
pub async fn yield_async() {
    /// Future that is pending exactly once: the first poll wakes the waker
    /// and returns `Pending`, every later poll returns `Ready`.
    struct YieldOnce {
        already_yielded: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<()> {
            // Mark the yield as taken and look at what the flag was before.
            if mem::replace(&mut self.already_yielded, true) {
                return Poll::Ready(());
            }
            // Ask to be polled again right away.
            context.waker().wake_by_ref();
            Poll::Pending
        }
    }

    YieldOnce {
        already_yielded: false,
    }
    .await;
}
503
/// Call the `backpressure.set` canonical built-in function.
///
/// Passing `true` for `enabled` tells the host to hold back any new calls to
/// this component instance until further notice, i.e. until
/// `backpressure.set` is called again with `enabled` set to `false`.
#[deprecated = "use backpressure_{inc,dec} instead"]
pub fn backpressure_set(enabled: bool) {
    #[cfg(not(target_arch = "wasm32"))]
    unsafe fn raw_set(_: i32) {
        unreachable!();
    }

    #[cfg(target_arch = "wasm32")]
    #[link(wasm_import_module = "$root")]
    extern "C" {
        #[link_name = "[backpressure-set]"]
        fn raw_set(_: i32);
    }

    // The intrinsic takes the flag as an integer: 1 for on, 0 for off.
    let flag = i32::from(enabled);
    unsafe { raw_set(flag) }
}
525
/// Call the `backpressure.inc` canonical built-in function.
///
/// This is one half of the replacement for the deprecated
/// `backpressure_set`; the other half is [`backpressure_dec`].
pub fn backpressure_inc() {
    #[cfg(not(target_arch = "wasm32"))]
    unsafe fn raw_inc() {
        unreachable!();
    }

    #[cfg(target_arch = "wasm32")]
    #[link(wasm_import_module = "$root")]
    extern "C" {
        #[link_name = "[backpressure-inc]"]
        fn raw_inc();
    }

    unsafe {
        raw_inc();
    }
}
542
/// Call the `backpressure.dec` canonical built-in function.
///
/// This is one half of the replacement for the deprecated
/// `backpressure_set`; the other half is [`backpressure_inc`].
pub fn backpressure_dec() {
    #[cfg(not(target_arch = "wasm32"))]
    unsafe fn raw_dec() {
        unreachable!();
    }

    #[cfg(target_arch = "wasm32")]
    #[link(wasm_import_module = "$root")]
    extern "C" {
        #[link_name = "[backpressure-dec]"]
        fn raw_dec();
    }

    unsafe {
        raw_dec();
    }
}
559
/// Reads context-local storage slot 0 through the `[context-get-0]` import.
///
/// `start_task` and `callback` keep the current task's `FutureState` pointer
/// in this slot. On non-wasm32 targets no such import exists and calling this
/// hits `unreachable!()`.
fn context_get() -> *mut u8 {
    #[cfg(not(target_arch = "wasm32"))]
    unsafe fn raw_get() -> *mut u8 {
        unreachable!()
    }

    #[cfg(target_arch = "wasm32")]
    #[link(wasm_import_module = "$root")]
    extern "C" {
        #[link_name = "[context-get-0]"]
        fn raw_get() -> *mut u8;
    }

    let slot = unsafe { raw_get() };
    slot
}
575
/// Writes `value` to context-local storage slot 0 through the
/// `[context-set-0]` import.
///
/// On non-wasm32 targets no such import exists and calling this hits
/// `unreachable!()`.
///
/// # Safety
///
/// `callback` casts whatever is stored in this slot to a `FutureState`
/// pointer, so `value` must be null or a pointer that is valid for that use.
unsafe fn context_set(value: *mut u8) {
    #[cfg(not(target_arch = "wasm32"))]
    unsafe fn raw_set(_: *mut u8) {
        unreachable!()
    }

    #[cfg(target_arch = "wasm32")]
    #[link(wasm_import_module = "$root")]
    extern "C" {
        #[link_name = "[context-set-0]"]
        fn raw_set(value: *mut u8);
    }

    unsafe {
        raw_set(value);
    }
}
591
/// Guard that invokes the `[task-cancel]` import when it is dropped, unless
/// it has been disarmed with [`TaskCancelOnDrop::forget`].
#[doc(hidden)]
pub struct TaskCancelOnDrop {
    // Private unit field so the guard can only be built through `new`.
    _priv: (),
}

impl TaskCancelOnDrop {
    /// Creates an armed guard.
    #[doc(hidden)]
    pub fn new() -> TaskCancelOnDrop {
        Self { _priv: () }
    }

    /// Disarms the guard: it is consumed without its destructor running, so
    /// `[task-cancel]` is not called.
    #[doc(hidden)]
    pub fn forget(self) {
        let _ = mem::ManuallyDrop::new(self);
    }
}

impl Drop for TaskCancelOnDrop {
    fn drop(&mut self) {
        #[cfg(not(target_arch = "wasm32"))]
        unsafe fn raw_cancel() {
            unreachable!()
        }

        #[cfg(target_arch = "wasm32")]
        #[link(wasm_import_module = "[export]$root")]
        extern "C" {
            #[link_name = "[task-cancel]"]
            fn raw_cancel();
        }

        unsafe {
            raw_cancel();
        }
    }
}
625}