use std::{
future::Future,
mem::MaybeUninit,
ptr::{self, addr_of},
sync::atomic::{AtomicPtr, Ordering},
};
use js_sys::{
wasm_bindgen::{closure::Closure, JsValue},
Promise,
};
use wasm_bindgen::prelude::*;
// JS bindings used by `Scheduler` to pick and drive its queuing strategy.
#[wasm_bindgen]
extern "C" {
/// The global `queueMicrotask(callback)` function, used to queue `closure`
/// on the JS microtask queue.
/// NOTE(review): JS calls the callback with no arguments, so the `JsValue`
/// parameter presumably receives `undefined` — confirm.
#[wasm_bindgen]
fn queueMicrotask(closure: &Closure<dyn FnMut(JsValue)>);
/// The JS global object; `Scheduler::new` obtains it by casting the result
/// of `js_sys::global()`.
type Global;
/// Reads the `queueMicrotask` property of the global object. Despite its
/// name this returns the property value itself (a function, or `undefined`
/// when unsupported), not a boolean; the caller tests it with `is_function()`.
#[wasm_bindgen(method, getter, js_name = queueMicrotask)]
fn hasQueueMicrotask(this: &Global) -> JsValue;
}
use crate::arena::{Index, TaskArena};
/// How a poll pass over the enqueued tasks is queued on the JS side.
///
/// One value is created lazily per thread (see `Scheduler::schedule_polling`)
/// and picked by `Scheduler::new` according to whether the global object
/// exposes `queueMicrotask`.
pub(crate) enum Scheduler {
/// `queueMicrotask` is available and is called directly.
Microtask {
/// Callback running `RUNTIME.poll_enqueued()`; stored here so the same
/// closure is reused for every queued pass.
poll_enqueued: Closure<dyn FnMut(JsValue)>,
},
/// Fallback when `queueMicrotask` is missing: the callback is attached to an
/// already-resolved promise with `then`.
Promise {
/// Promise created with `Promise::resolve(undefined)` in `Scheduler::new`.
promise: Promise,
/// Callback running `RUNTIME.poll_enqueued()`, reused for every pass.
poll_enqueued: Closure<dyn FnMut(JsValue)>,
},
}
impl Scheduler {
fn new() -> Self {
let has_queue_microtask = js_sys::global()
.unchecked_into::<Global>()
.hasQueueMicrotask()
.is_function();
let poll_enqueued = Closure::new(|_| unsafe { RUNTIME.poll_enqueued() });
if has_queue_microtask {
Scheduler::Microtask { poll_enqueued }
} else {
Scheduler::Promise {
promise: Promise::resolve(&JsValue::undefined()),
poll_enqueued,
}
}
}
pub(crate) unsafe fn schedule_polling(head: *mut ()) {
thread_local! {
static SCHEDULER: Scheduler = Scheduler::new();
}
RUNTIME.poll_head.store(head, Ordering::Release);
SCHEDULER.with(|scheduler| match scheduler {
Scheduler::Microtask { poll_enqueued } => {
queueMicrotask(poll_enqueued);
}
Scheduler::Promise {
promise,
poll_enqueued,
} => {
let _ = promise.then(poll_enqueued);
}
});
}
}
/// Global state of the task runtime: the queue of tasks awaiting a poll and
/// the arena that stores them. Task positions are passed around as type-erased
/// raw pointers (`*mut ()`) and turned back into arena indices with
/// `Index::from_raw`.
pub(crate) struct Runtime {
/// First task of the chain the next poll pass will walk. It is written by
/// `Scheduler::schedule_polling` and taken (reset to null) by `poll_enqueued`.
poll_head: AtomicPtr<()>,
/// Most recently enqueued task. It is null initially and after each poll pass
/// detaches its chain. `spawn_local` links new tasks behind it via the tail's
/// `next_enqueued` field.
pub(crate) poll_tail: AtomicPtr<()>,
/// NOTE(review): not read in this file; presumably the end of a free-slot
/// list maintained by `crate::arena` — confirm.
pub(crate) free_tail: AtomicPtr<()>,
/// Slot the next `spawn_local` call will occupy; null (the initial value)
/// makes `spawn_local` use the arena's first slot.
pub(crate) next: AtomicPtr<()>,
/// Storage for the spawned tasks.
arena: TaskArena,
}
// SAFETY: NOTE(review) — this allows `RUNTIME` to be a `static`. The sequences
// in `spawn_local` (load `next`, then store it; swap `poll_tail`, then link)
// are not atomic as a whole, so this presumably relies on the runtime only
// being used from a single thread — confirm before enabling wasm threads.
unsafe impl Sync for Runtime {}
/// The single runtime instance shared by `spawn_local` and `Scheduler`.
pub(crate) static RUNTIME: Runtime = Runtime::new();
impl Runtime {
    /// Creates an empty runtime: all task pointers start out null and the arena
    /// is freshly constructed. `const` so it can initialise the `RUNTIME` static.
    const fn new() -> Self {
        Self {
            arena: TaskArena::new(),
            next: AtomicPtr::new(ptr::null_mut()),
            free_tail: AtomicPtr::new(ptr::null_mut()),
            poll_tail: AtomicPtr::new(ptr::null_mut()),
            poll_head: AtomicPtr::new(ptr::null_mut()),
        }
    }

    /// Runs one poll pass over the chain of enqueued tasks.
    ///
    /// The chain is detached first: `poll_head` is taken and `poll_tail` is
    /// reset to null. Any `spawn_local` call made while this pass runs therefore
    /// sees an empty queue and schedules a fresh pass, rather than extending the
    /// chain being walked. Each task is then polled in order. `poll` yields the
    /// raw index of the next task, and null ends the walk.
    ///
    /// # Safety
    ///
    /// Every pointer in the chain is passed to `Index::from_raw`.
    /// NOTE(review): callers presumably must guarantee the chain only holds
    /// pointers produced by the arena — confirm against `crate::arena`.
    unsafe fn poll_enqueued(&'static self) {
        let mut pending = self.poll_head.swap(ptr::null_mut(), Ordering::AcqRel);
        self.poll_tail.store(ptr::null_mut(), Ordering::Release);
        while let Some(task) = ptr::NonNull::new(pending) {
            pending = Index::from_raw(task.as_ptr()).poll();
        }
    }
}
/// Spawns `future` onto the runtime. It is polled later, never inside this call.
///
/// The future is boxed into the arena slot recorded in `RUNTIME.next` (slot 0
/// when that pointer is still null) and appended to the queue of tasks
/// awaiting a poll:
///
/// * if the queue was empty (`poll_tail` was null), the task starts a new chain
///   and a poll pass is scheduled with `Scheduler::schedule_polling`;
/// * otherwise the task is linked behind the previous tail through its
///   `next_enqueued` pointer. It is presumably reached by the pass that was
///   scheduled when that chain started.
pub fn spawn_local<F>(future: F)
where
    F: Future<Output = ()> + 'static,
{
    let recorded = RUNTIME.next.load(Ordering::Acquire);
    let (slot, index) = if !recorded.is_null() {
        // SAFETY: NOTE(review) — `recorded` was stored below by an earlier call
        // from `Index::next_index`; assumed to be a pointer `from_raw` accepts.
        (recorded, unsafe { Index::from_raw(recorded) })
    } else {
        // Nothing recorded yet: use the first slot of the arena.
        (addr_of!(RUNTIME.arena) as *mut (), Index::new(&RUNTIME.arena, 0))
    };

    // Claim the slot, then record where the following spawn should go.
    let occupancy = index.set_as_occupied();
    let upcoming = index.next_index(occupancy);
    RUNTIME.next.store(upcoming, Ordering::Release);

    // SAFETY: NOTE(review) — assumes the slot just claimed is not accessed
    // elsewhere until it is enqueued below.
    unsafe {
        *index.handle().task.get() = MaybeUninit::new(Box::pin(future));
    }

    let previous_tail = RUNTIME.poll_tail.swap(slot, Ordering::AcqRel);
    if previous_tail.is_null() {
        // SAFETY: `slot` is the task whose future was stored just above.
        unsafe { Scheduler::schedule_polling(slot) };
    } else {
        // SAFETY: NOTE(review) — assumes the previous tail is still a live,
        // enqueued slot whose `next_enqueued` may be written here.
        unsafe {
            Index::from_raw(previous_tail)
                .handle()
                .next_enqueued
                .store(slot, Ordering::Release);
        }
    }
}
// Tests for `spawn_local`, compiled only for wasm32 test builds and run
// through `wasm_bindgen_test`.
#[cfg(all(test, target_arch = "wasm32"))]
mod tests {
use std::ops::FnMut;
use futures_channel::oneshot;
use js_sys::{
wasm_bindgen::{closure::Closure, JsValue},
Promise,
};
use wasm_bindgen_futures::JsFuture;
use wasm_bindgen_test::*;
use super::spawn_local;
/// A spawned task runs and can hand a value back to the awaiting test.
#[wasm_bindgen_test]
async fn spawn_local_runs() {
let (tx, rx) = oneshot::channel::<u32>();
spawn_local(async {
tx.send(42).unwrap();
});
assert_eq!(rx.await.unwrap(), 42);
}
/// A task spawned from inside a running task is deferred. It must not have
/// run when a `then` callback fires that was registered on a resolved promise
/// before the nested spawn. It must still run eventually.
#[wasm_bindgen_test]
async fn spawn_local_nested() {
// `ta`/`ra`: observed by the in-between callback; `ts`/`rs`: awaited
// afterwards; `tx`/`rx`: signals the outer task finished.
let (ta, mut ra) = oneshot::channel::<u32>();
let (ts, rs) = oneshot::channel::<u32>();
let (tx, rx) = oneshot::channel::<u32>();
let promise = Promise::resolve(&JsValue::null());
spawn_local(async move {
let inbetween = Closure::wrap(Box::new(move |_| {
assert_eq!(
ra.try_recv().unwrap(),
None,
"Nested task should not have run yet"
);
}) as Box<dyn FnMut(JsValue)>);
// The `Closure` binding is shadowed here but not dropped, so it stays
// alive until this async block ends and JS can still call it.
let inbetween = promise.then(&inbetween);
// Spawned after the `then` registration above; the assertion in the
// callback checks that this task is not polled before it.
spawn_local(async {
ta.send(0xdead).unwrap();
ts.send(0xbeaf).unwrap();
});
JsFuture::from(inbetween).await.unwrap();
assert_eq!(
rs.await.unwrap(),
0xbeaf,
"Nested task should run eventually"
);
tx.send(42).unwrap();
});
assert_eq!(rx.await.unwrap(), 42);
}
/// Spawning a task that completes immediately, followed by another task, works
/// and the second task still runs. NOTE(review): by its name this presumably
/// guards against a JS exception in that situation — confirm.
#[wasm_bindgen_test]
async fn spawn_local_err_no_exception() {
let (tx, rx) = oneshot::channel::<u32>();
spawn_local(async {});
spawn_local(async {
tx.send(42).unwrap();
});
let val = rx.await.unwrap();
assert_eq!(val, 42);
}
}