1use futures::channel::oneshot;
24use futures::task::{waker_ref, ArcWake};
25#[cfg(feature = "debug")]
26use std::any::{type_name, TypeId};
27use std::cell::UnsafeCell;
28use std::collections::BTreeMap;
29use std::future::Future;
30use std::pin::Pin;
31use std::sync::Arc;
32use std::task::{Context, Poll};
33
/// Identifier assigned to each spawned task by the thread-local executor.
///
/// Tokens come from a per-executor counter that wraps on overflow, so a token
/// may be reused after `usize::MAX` spawns.
type Token = usize;
36
/// Diagnostic description of the future type a task was spawned with.
///
/// Only compiled with the `debug` feature.
#[cfg(feature = "debug")]
#[derive(Clone, Debug)]
pub struct TypeInfo {
    /// `TypeId` of the future; `None` for futures spawned without a `'static` bound.
    type_id: Option<TypeId>,
    /// Name of the future's type as reported by `std::any::type_name`.
    type_name: &'static str,
}

#[cfg(feature = "debug")]
impl TypeInfo {
    /// Describes a `'static` type, recording both its name and its `TypeId`.
    fn new<T: 'static>() -> Self {
        let name = type_name::<T>();
        let id = TypeId::of::<T>();
        TypeInfo {
            type_id: Some(id),
            type_name: name,
        }
    }

    /// Describes a type that may borrow data. `TypeId` requires `'static`,
    /// so only the name is recorded.
    fn new_non_static<T>() -> Self {
        let name = type_name::<T>();
        TypeInfo {
            type_id: None,
            type_name: name,
        }
    }

    /// Returns the future's type name (as produced by `std::any::type_name`).
    pub fn type_name(&self) -> &'static str {
        let Self { type_name: name, .. } = *self;
        name
    }

    /// Returns the future's `TypeId`, or `None` if it was spawned as non-`'static`.
    pub fn type_id(&self) -> Option<TypeId> {
        let Self { type_id: id, .. } = *self;
        id
    }
}
75
76#[derive(Clone)]
78pub struct Task {
79 token: Token,
80 #[cfg(feature = "debug")]
81 type_info: Arc<TypeInfo>,
82}
83
84impl PartialEq for Task {
85 fn eq(&self, other: &Self) -> bool {
86 self.token == other.token
87 }
88}
89
90impl Eq for Task {}
91
92impl PartialOrd for Task {
93 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
94 self.token.partial_cmp(&other.token)
95 }
96}
97
98impl Ord for Task {
99 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
100 self.token.cmp(&other.token)
101 }
102}
103
104impl Task {
105 #[cfg(feature = "debug")]
106 pub fn type_info(&self) -> &TypeInfo {
107 self.type_info.as_ref()
108 }
109}
110
111pub struct TaskHandle<T> {
115 receiver: oneshot::Receiver<T>,
116 task: Task,
117}
118
119impl<T> TaskHandle<T> {
120 pub fn task(&self) -> Task {
122 self.task.clone()
123 }
124}
125
/// Error produced when awaiting a `TaskHandle` whose task can no longer
/// deliver its output because its sending half was dropped (for example when
/// the task was evicted).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinError {
    /// The task was dropped before it produced a value.
    Canceled,
}

impl std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            JoinError::Canceled => f.write_str("task was canceled before producing a value"),
        }
    }
}

impl std::error::Error for JoinError {}
132
133impl<T> Future for TaskHandle<T> {
134 type Output = Result<T, JoinError>;
135 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
136 match self.receiver.try_recv() {
137 Err(oneshot::Canceled) => Poll::Ready(Err(JoinError::Canceled)),
138 Ok(Some(result)) => Poll::Ready(Ok(result)),
139 Ok(None) => {
140 cx.waker().wake_by_ref();
141 Poll::Pending
142 }
143 }
144 }
145}
146
147impl ArcWake for Task {
148 fn wake_by_ref(arc_self: &Arc<Self>) {
149 EXECUTOR.with(|cell| (unsafe { &mut *cell.get() }).enqueue(arc_self.clone()))
150 }
151}
152
153struct Executor {
155 counter: Token,
156 futures: BTreeMap<Task, Pin<Box<dyn Future<Output = ()>>>>,
157 queue: Vec<Arc<Task>>,
158}
159
160impl Executor {
161 fn new() -> Self {
162 Self {
163 counter: 0,
164 futures: BTreeMap::new(),
165 queue: vec![],
166 }
167 }
168
169 fn enqueue(&mut self, task: Arc<Task>) {
170 if self.futures.contains_key(&task) {
171 self.queue.insert(0, task);
172 }
173 }
174
175 fn spawn<F, T>(&mut self, fut: F) -> TaskHandle<T>
176 where
177 F: Future<Output = T> + 'static,
178 T: 'static,
179 {
180 let token = self.counter;
181 self.counter = self.counter.wrapping_add(1);
182 let task = Task {
183 token,
184 #[cfg(feature = "debug")]
185 type_info: Arc::new(TypeInfo::new::<F>()),
186 };
187
188 let (sender, receiver) = oneshot::channel();
189
190 self.futures.insert(task.clone(), unsafe {
191 Pin::new_unchecked(Box::new(async move {
192 let _ = sender.send(fut.await);
193 }) as Box<dyn Future<Output = ()>>)
194 });
195 self.queue.push(Arc::new(task.clone()));
196 TaskHandle { receiver, task }
197 }
198
199 fn spawn_non_static<F, T>(&mut self, fut: F) -> TaskHandle<T>
200 where
201 F: Future<Output = T>,
202 {
203 let token = self.counter;
204 self.counter = self.counter.wrapping_add(1);
205 let task = Task {
206 token,
207 #[cfg(feature = "debug")]
208 type_info: Arc::new(TypeInfo::new_non_static::<F>()),
209 };
210
211 let (sender, receiver) = oneshot::channel();
212
213 self.futures.insert(task.clone(), unsafe {
214 Pin::new_unchecked(std::mem::transmute::<_, Box<dyn Future<Output = ()>>>(
215 Box::new(async move {
216 let _ = sender.send(fut.await);
217 }) as Box<dyn Future<Output = ()>>,
218 ))
219 });
220 self.queue.push(Arc::new(task.clone()));
221 TaskHandle { receiver, task }
222 }
223}
224
225thread_local! {
226 static EXECUTOR: UnsafeCell<Executor> = UnsafeCell::new(Executor::new()) ;
227}
228
229thread_local! {
230 static UNTIL: UnsafeCell<Option<Task>> = UnsafeCell::new(None) ;
231}
232
233thread_local! {
234 static UNTIL_SATISFIED: UnsafeCell<bool> = UnsafeCell::new(false) ;
235}
236
237thread_local! {
238 static YIELD: UnsafeCell<bool> = UnsafeCell::new(true) ;
239}
240
241thread_local! {
242 static EXIT_LOOP: UnsafeCell<bool> = UnsafeCell::new(false) ;
243}
244
245pub fn spawn<F, T>(fut: F) -> TaskHandle<T>
247where
248 F: Future<Output = T> + 'static,
249 T: 'static,
250{
251 EXECUTOR.with(|cell| (unsafe { &mut *cell.get() }).spawn(fut))
252}
253
254#[cfg(not(feature = "cooperative"))]
258pub fn block_on<F, R>(fut: F) -> R
259where
260 F: Future<Output = R>,
261{
262 let mut handle = EXECUTOR.with(|cell| (unsafe { &mut *cell.get() }).spawn_non_static(fut));
265 run(Some(handle.task()));
266 loop {
267 match handle.receiver.try_recv() {
268 Ok(None) => {}
269 Ok(Some(v)) => return v,
270 Err(_) => unreachable!(), }
272 }
273}
274
/// Cooperative-mode `block_on`: runs the executor until `fut` completes and
/// returns its output.
///
/// Yielding is disabled while blocking. Otherwise `run` could return early on
/// an `EXIT_LOOP` request before the task has produced a value. The previous
/// `YIELD` setting is restored afterwards, even if a task panics, so nested
/// calls do not re-enable yielding for an enclosing `block_on`.
#[cfg(feature = "cooperative")]
pub fn block_on<F, R>(fut: F) -> R
where
    F: Future<Output = R>,
{
    // Restores the saved `YIELD` value when dropped.
    struct RestoreYield(bool);

    impl Drop for RestoreYield {
        fn drop(&mut self) {
            let previous = self.0;
            YIELD.with(|cell| unsafe { *cell.get() = previous });
        }
    }

    let mut handle = EXECUTOR.with(|cell| (unsafe { &mut *cell.get() }).spawn_non_static(fut));
    let restore_yield =
        RestoreYield(YIELD.with(|cell| unsafe { std::mem::replace(&mut *cell.get(), false) }));
    run(Some(handle.task()));
    drop(restore_yield);
    loop {
        match handle.receiver.try_recv() {
            Ok(None) => {}
            Ok(Some(v)) => return v,
            Err(_) => unreachable!(),
        }
    }
}
304
305pub fn run(until: Option<Task>) {
313 UNTIL.with(|cell| unsafe { *cell.get() = until });
314 UNTIL_SATISFIED.with(|cell| unsafe { *cell.get() = false });
315 run_internal();
316}
317
318fn run_internal() -> bool {
323 let until = UNTIL.with(|cell| unsafe { &*cell.get() });
324 let exit_condition_met = UNTIL_SATISFIED.with(|cell| unsafe { *cell.get() });
325 if exit_condition_met {
326 return true;
327 }
328 EXECUTOR.with(|cell| loop {
329 let task = (unsafe { &mut *cell.get() }).queue.pop();
330
331 if let Some(task) = task {
332 let future = (unsafe { &mut *cell.get() }).futures.get_mut(&task);
333 let ready = if let Some(future) = future {
334 let waker = waker_ref(&task);
335 let context = &mut Context::from_waker(&*waker);
336 let ready = matches!(future.as_mut().poll(context), Poll::Ready(_));
337 ready
338 } else {
339 false
340 };
341 if ready {
342 (unsafe { &mut *cell.get() }).futures.remove(&task);
343
344 if let Some(Task { ref token, .. }) = until {
345 if *token == task.token {
346 UNTIL_SATISFIED.with(|cell| unsafe { *cell.get() = true });
347 return true;
348 }
349 }
350 }
351 }
352 if until.is_none() && (unsafe { &mut *cell.get() }).futures.is_empty() {
353 UNTIL_SATISFIED.with(|cell| unsafe { *cell.get() = true });
354 return true;
355 }
356
357 let exit_requested = EXIT_LOOP.with(|cell| {
358 let v = cell.get();
359 let result = unsafe { *v };
360 unsafe {
362 *v = false;
363 }
364 result
365 }) && YIELD.with(|cell| unsafe { *cell.get() });
366
367 if exit_requested {
368 return false;
369 }
370
371 if (unsafe { &mut *cell.get() }).queue.is_empty()
372 && !(unsafe { &mut *cell.get() }).futures.is_empty()
373 {
374 for task in (unsafe { &mut *cell.get() }).futures.keys() {
376 (unsafe { &mut *cell.get() }).enqueue(Arc::new(task.clone()));
377 }
378 }
379 })
380}
381
// Browser-only cooperative yielding primitives. On its first poll, each future
// below sets `EXIT_LOOP`, which makes `run_internal` return when `YIELD` is
// enabled. It also schedules a JavaScript callback (setTimeout,
// requestAnimationFrame or requestIdleCallback) that calls `run_internal` again
// later.
//
// NOTE(review): while waiting, these futures re-wake themselves without setting
// `EXIT_LOOP`. If `run_internal` is entered before the JS callback has fired
// (e.g. from another callback), or `YIELD` is disabled (cooperative
// `block_on`), the loop may spin synchronously. The callback it waits for then
// cannot run. Confirm this cannot hang the page.
#[cfg(all(
    feature = "cooperative",
    target_arch = "wasm32",
    not(target_os = "wasi")
))]
mod cooperative {
    use super::{run_internal, EXIT_LOOP};
    use pin_project::pin_project;
    use std::cell::UnsafeCell;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Arc;
    use std::task::{Context, Poll};
    use std::time::Duration;
    use wasm_bindgen::prelude::*;

    #[wasm_bindgen]
    extern "C" {
        // JS `setTimeout(callback, delay_ms)`.
        #[wasm_bindgen(js_name = "setTimeout")]
        fn set_timeout(_: JsValue, delay: u32);

        // JS `requestIdleCallback(callback, options)`.
        #[cfg(feature = "requestIdleCallback")]
        #[wasm_bindgen(js_name = "requestIdleCallback")]
        fn request_idle_callback(_: JsValue, options: &JsValue);

        // JS `requestAnimationFrame(callback)`.
        #[cfg(feature = "cooperative-browser")]
        #[wasm_bindgen(js_name = "requestAnimationFrame")]
        fn request_animation_frame(_: JsValue);

    }

    /// Wraps `future` so that it yields to the JS event loop once (via
    /// `setTimeout`) before its output is returned.
    #[pin_project]
    struct TimeoutYield<F, O>
    where
        F: Future<Output = O> + 'static,
    {
        /// Set on the first poll, once the yield has been initiated.
        yielded: bool,
        /// Delay passed to `setTimeout`; `None` means 0 ms.
        duration: Option<Duration>,
        /// Set after `Poll::Ready` has been returned; later polls stay pending.
        done: bool,
        /// Output of `future` if it completed before the timer fired.
        output: Option<O>,
        /// The wrapped future.
        #[pin]
        future: F,
        /// Set to `true` by the `setTimeout` callback.
        ready: Arc<UnsafeCell<bool>>,
    }

    impl<F, O> Future for TimeoutYield<F, O>
    where
        F: Future<Output = O> + 'static,
    {
        type Output = O;
        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            // Output already delivered: never complete again.
            if self.done {
                return Poll::Pending;
            }
            // Still waiting for the timer: re-wake ourselves without requesting an
            // exit from the executor loop.
            if self.yielded && !unsafe { *self.ready.get() } {
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            // True only on the very first poll.
            let should_yield = !self.yielded;
            let this = self.project();
            // The timer has fired and the inner future finished on an earlier
            // poll: hand back the stored output.
            if *this.yielded && unsafe { *this.ready.get() } && this.output.is_some() {
                let output = this.output.take().unwrap();
                *this.done = true;
                return Poll::Ready(output);
            }
            // The inner future is polled on every pass, including the first one.
            match (should_yield, this.future.poll(cx)) {
                // First poll (whatever the inner future returned), or the inner
                // future is still pending.
                (_, result @ Poll::Pending) | (true, result) => {
                    *this.yielded = true;
                    // Always true here: this module is only compiled for wasm32.
                    if cfg!(target_arch = "wasm32") {
                        if let Some(duration) = this.duration {
                            // For a non-zero delay, also re-enter the executor
                            // after 0 ms (presumably so other tasks keep running
                            // while this one waits).
                            if duration.as_millis() > 0 {
                                set_timeout(
                                    Closure::once_into_js(move || {
                                        run_internal();
                                    }),
                                    0,
                                );
                            }
                        }

                        // Only the first poll schedules the timer that flips `ready`.
                        if should_yield {
                            let ready = this.ready.clone();

                            // NOTE: the millisecond count is truncated with `as u32`.
                            set_timeout(
                                Closure::once_into_js(move || {
                                    unsafe { *ready.get() = true };
                                    run_internal();
                                }),
                                this.duration
                                    .unwrap_or(Duration::from_millis(0))
                                    .as_millis() as u32,
                            );
                        }
                        // NOTE(review): on later polls where the inner future is
                        // still pending, no new callback is scheduled for
                        // `yield_async` (duration `None`, `should_yield` false),
                        // yet an exit is still requested. The executor may then sit
                        // idle until some other callback calls `run_internal`.
                        // Verify this is intended.
                        EXIT_LOOP.with(|cell| unsafe { *cell.get() = true });
                    }
                    // Keep an early result until the timer has fired.
                    if let Poll::Ready(output) = result {
                        this.output.replace(output);
                    }
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
                // The yield is over and the inner future just completed.
                (false, Poll::Ready(output)) => {
                    *this.done = true;
                    Poll::Ready(output)
                }
            }
        }
    }

    /// Returns a future that yields to the JS event loop and completes after a
    /// `setTimeout` of `duration` (in whole milliseconds) has fired.
    pub fn yield_timeout(duration: Duration) -> impl Future<Output = ()> {
        TimeoutYield {
            future: futures::future::ready(()),
            duration: Some(duration),
            output: None,
            yielded: false,
            done: false,
            ready: Arc::new(UnsafeCell::new(false)),
        }
    }

    /// Wraps `future` so that it first yields to the JS event loop via a
    /// 0 ms `setTimeout`. It resolves with the future's output once that timer
    /// has fired and the future has completed.
    pub fn yield_async<F, O>(future: F) -> impl Future<Output = O>
    where
        F: Future<Output = O> + 'static,
    {
        TimeoutYield {
            future,
            duration: None,
            output: None,
            yielded: false,
            done: false,
            ready: Arc::new(UnsafeCell::new(false)),
        }
    }

    /// Future that resolves with the timestamp passed to the next
    /// `requestAnimationFrame` callback.
    #[cfg(feature = "cooperative-browser")]
    #[pin_project]
    struct AnimationFrameYield {
        /// Set once the animation-frame request has been issued.
        yielded: bool,
        /// Set after `Poll::Ready` has been returned; later polls stay pending.
        done: bool,
        /// Timestamp written by the JS callback.
        output: Arc<UnsafeCell<Option<f64>>>,
    }

    #[cfg(feature = "cooperative-browser")]
    impl Future for AnimationFrameYield {
        type Output = f64;
        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if self.done {
                return Poll::Pending;
            }
            let should_yield = !self.yielded;
            let this = self.project();
            // The callback has delivered its timestamp: complete.
            if *this.yielded && unsafe { &*this.output.get() }.is_some() {
                let output = unsafe { &mut *this.output.get() }.take().unwrap();
                *this.done = true;
                return Poll::Ready(output);
            }

            // First poll: request a frame and ask the executor loop to return.
            if should_yield {
                *this.yielded = true;
                // Always true here: this module is only compiled for wasm32.
                if cfg!(target_arch = "wasm32") {
                    let output = this.output.clone();
                    request_animation_frame(Closure::once_into_js(move |timestamp| {
                        unsafe { &mut *output.get() }.replace(timestamp);
                        run_internal();
                    }));
                    EXIT_LOOP.with(|cell| unsafe { *cell.get() = true });
                }
            }

            cx.waker().wake_by_ref();

            Poll::Pending
        }
    }

    /// Yields to the browser until the next animation frame and resolves with
    /// the frame's timestamp.
    #[cfg(feature = "cooperative-browser")]
    pub fn yield_animation_frame() -> impl Future<Output = f64> {
        AnimationFrameYield {
            output: Arc::new(UnsafeCell::new(None)),
            yielded: false,
            done: false,
        }
    }

    /// Future that resolves with the `IdleDeadline` passed to a
    /// `requestIdleCallback` callback.
    #[cfg(feature = "requestIdleCallback")]
    #[pin_project]
    struct UntilIdleYield {
        /// Forwarded as the `timeout` option (milliseconds, truncated to `u32`).
        timeout: Option<Duration>,
        /// Set once the idle-callback request has been issued.
        yielded: bool,
        /// Set after `Poll::Ready` has been returned; later polls stay pending.
        done: bool,
        /// Deadline written by the JS callback.
        output: Arc<UnsafeCell<Option<web_sys::IdleDeadline>>>,
    }

    #[cfg(feature = "requestIdleCallback")]
    impl Future for UntilIdleYield {
        type Output = web_sys::IdleDeadline;
        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if self.done {
                return Poll::Pending;
            }
            let should_yield = !self.yielded;
            let this = self.project();
            // The callback has delivered its deadline: complete.
            if *this.yielded && unsafe { &*this.output.get() }.is_some() {
                let output = unsafe { &mut *this.output.get() }.take().unwrap();
                *this.done = true;
                return Poll::Ready(output);
            }

            // First poll: register the idle callback and ask the executor loop to return.
            if should_yield {
                *this.yielded = true;
                // Always true here: this module is only compiled for wasm32.
                if cfg!(target_arch = "wasm32") {
                    // Build `{ timeout: <ms> }` (or `{}`) for the options argument.
                    let map = js_sys::Map::new();
                    if let Some(timeout) = this.timeout {
                        map.set(&"timeout".into(), &(timeout.as_millis() as u32).into());
                    }
                    let options =
                        js_sys::Object::from_entries(&map).unwrap_or(js_sys::Object::new());
                    let output = this.output.clone();
                    request_idle_callback(
                        Closure::once_into_js(move |timestamp| {
                            unsafe { &mut *output.get() }.replace(timestamp);
                            run_internal();
                        }),
                        &options.into(),
                    );
                    EXIT_LOOP.with(|cell| unsafe { *cell.get() = true });
                }
            }

            cx.waker().wake_by_ref();

            Poll::Pending
        }
    }

    /// Yields until the browser reports idle time (or `timeout` elapses, when
    /// given) and resolves with the `IdleDeadline` from `requestIdleCallback`.
    #[cfg(feature = "requestIdleCallback")]
    pub fn yield_until_idle(
        timeout: Option<Duration>,
    ) -> impl Future<Output = web_sys::IdleDeadline> {
        UntilIdleYield {
            timeout,
            output: Arc::new(UnsafeCell::new(None)),
            yielded: false,
            done: false,
        }
    }
}
668
// Re-export the browser cooperative-yielding helpers (`yield_timeout`,
// `yield_async`, and the feature-gated animation-frame / idle variants) from
// this module, under the same cfg as the `cooperative` module itself.
#[cfg(all(
    feature = "cooperative",
    target_arch = "wasm32",
    not(target_os = "wasi")
))]
pub use cooperative::*;
675
676pub fn tasks_count() -> usize {
678 EXECUTOR.with(|cell| {
679 let executor = unsafe { &mut *cell.get() };
680 executor.futures.len()
681 })
682}
683
684pub fn queued_tasks_count() -> usize {
686 EXECUTOR.with(|cell| (unsafe { &mut *cell.get() }).queue.len())
687}
688
689pub fn tasks() -> Vec<Task> {
691 EXECUTOR.with(|cell| {
692 (unsafe { &*cell.get() })
693 .futures
694 .keys()
695 .map(|t| Task::clone(&t))
696 .collect()
697 })
698}
699
700pub fn queued_tasks() -> Vec<Task> {
702 EXECUTOR.with(|cell| {
703 (unsafe { &*cell.get() })
704 .queue
705 .iter()
706 .map(|t| Task::clone(&t))
707 .collect()
708 })
709}
710
711pub fn evict_all() {
717 EXECUTOR.with(|cell| unsafe { *cell.get() = Executor::new() });
718}
719
/// Test hook: sets the token counter so the next spawned task receives `counter`.
#[cfg(test)]
fn set_counter(counter: usize) {
    EXECUTOR.with(move |cell| {
        let executor = unsafe { &mut *cell.get() };
        executor.counter = counter;
    })
}
724
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(all(target_arch = "wasm32", target_os = "unknown"))]
    use wasm_bindgen_test::*;

    /// A task awaiting a channel completes once the message is sent and
    /// `run(None)` drains the executor.
    #[cfg_attr(not(all(target_arch = "wasm32", target_os = "unknown")), test)]
    #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), wasm_bindgen_test)]
    fn test() {
        use tokio::sync::*;
        let (sender, receiver) = oneshot::channel::<()>();
        let _handle = spawn(async move {
            let _ = receiver.await;
        });
        let _ = sender.send(());
        run(None);
        evict_all();
    }

    /// `run(Some(task))` returns once that task finishes, even while another
    /// task stays pending.
    #[cfg_attr(not(all(target_arch = "wasm32", target_os = "unknown")), test)]
    #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), wasm_bindgen_test)]
    fn test_until() {
        use tokio::sync::*;
        let (_sender1, receiver1) = oneshot::channel::<()>();
        let _handle1 = spawn(async move {
            let _ = receiver1.await;
        });
        let (sender2, receiver2) = oneshot::channel::<()>();
        let handle2 = spawn(async move {
            let _ = receiver2.await;
        });
        let _ = sender2.send(());
        run(Some(handle2.task()));
        evict_all();
    }

    /// Task and queue counts, observed from inside a task and after `run`.
    #[cfg_attr(not(all(target_arch = "wasm32", target_os = "unknown")), test)]
    #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), wasm_bindgen_test)]
    fn test_counts() {
        use tokio::sync::*;
        let (sender, mut receiver) = oneshot::channel();
        let (sender2, receiver2) = oneshot::channel::<()>();
        let handle1 = spawn(async move {
            let _ = receiver2.await;
            let _ = sender.send((tasks_count(), queued_tasks_count()));
        });
        let _handle2 = spawn(async move {
            let _ = sender2.send(());
            futures::future::pending::<()>().await
        });
        run(Some(handle1.task()));
        let (tasks_, queued_tasks_) = receiver.try_recv().unwrap();
        // Both tasks were alive while the first one was running...
        assert_eq!(tasks_, 2);
        assert_eq!(queued_tasks_, 0);
        // ...and only the forever-pending one remains afterwards.
        assert_eq!(tasks_count(), 1);
        assert_eq!(queued_tasks_count(), 0);
        evict_all();
    }

    /// Waking a task after it has been evicted must not requeue it.
    #[cfg_attr(not(all(target_arch = "wasm32", target_os = "unknown")), test)]
    #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), wasm_bindgen_test)]
    fn evicted_tasks_dont_requeue() {
        use tokio::sync::*;
        let (_sender, receiver) = oneshot::channel::<()>();
        let handle = spawn(async move {
            let _ = receiver.await;
        });
        assert_eq!(tasks_count(), 1);
        evict_all();
        assert_eq!(tasks_count(), 0);
        ArcWake::wake_by_ref(&Arc::new(handle.task()));
        assert_eq!(tasks_count(), 0);
        assert_eq!(queued_tasks_count(), 0);
        evict_all();
    }

    /// Awaiting the handle of an evicted task reports cancellation.
    #[cfg_attr(not(all(target_arch = "wasm32", target_os = "unknown")), test)]
    #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), wasm_bindgen_test)]
    fn evicted_task_handle_resolves_to_canceled() {
        let handle = spawn(futures::future::pending::<u8>());
        evict_all();
        let result = block_on(handle);
        assert!(matches!(result, Err(JoinError::Canceled)));
        evict_all();
    }

    /// The token counter wraps around instead of overflowing.
    #[cfg_attr(not(all(target_arch = "wasm32", target_os = "unknown")), test)]
    #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), wasm_bindgen_test)]
    fn token_exhaustion() {
        set_counter(usize::MAX);
        let handle0 = spawn(async move {});
        let handle = spawn(async move {});
        assert!(handle.task().token != handle0.task().token);
        assert_eq!(handle.task().token, 0);
        evict_all();
    }

    /// `block_on` also drives previously spawned tasks.
    #[cfg_attr(not(all(target_arch = "wasm32", target_os = "unknown")), test)]
    #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), wasm_bindgen_test)]
    fn blocking_on() {
        use tokio::sync::*;
        let (sender, receiver) = oneshot::channel::<u8>();
        let _handle = spawn(async move {
            let _ = sender.send(1);
        });
        let result = block_on(async move { receiver.await.unwrap() });
        assert_eq!(result, 1);
        evict_all();
    }

    /// A blocked-on future must not starve a spawned task that yields before
    /// sending. Reaching the end of the test is the assertion.
    #[cfg_attr(not(all(target_arch = "wasm32", target_os = "unknown")), test)]
    #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), wasm_bindgen_test)]
    fn starvation() {
        use tokio::sync::*;
        let (sender, receiver) = oneshot::channel();
        let _handle = spawn(async move {
            tokio::task::yield_now().await;
            tokio::task::yield_now().await;
            let _ = sender.send(());
        });
        let () = block_on(async move { receiver.await.unwrap() });
        evict_all();
    }

    /// With the `debug` feature, tasks record the spawned future's type.
    #[cfg(feature = "debug")]
    #[cfg_attr(not(all(target_arch = "wasm32", target_os = "unknown")), test)]
    #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), wasm_bindgen_test)]
    fn task_type_info() {
        spawn(futures::future::pending::<()>());
        assert!(tasks()[0]
            .type_info()
            .type_name()
            .contains("future::pending::Pending"));
        assert_eq!(
            tasks()[0].type_info().type_id().unwrap(),
            TypeId::of::<futures::future::Pending<()>>()
        );
        evict_all();
        assert_eq!(tasks().len(), 0);
    }

    /// A task can await another task's handle and receive its output.
    #[cfg_attr(not(all(target_arch = "wasm32", target_os = "unknown")), test)]
    #[cfg_attr(all(target_arch = "wasm32", target_os = "unknown"), wasm_bindgen_test)]
    fn joining() {
        use tokio::sync::*;
        let (sender, receiver) = oneshot::channel();
        let (sender1, mut receiver1) = oneshot::channel();
        let _handle1 = spawn(async move {
            let _ = sender.send(());
        });

        let handle2 = spawn(async move {
            let _ = receiver.await;
            100u8
        });

        let handle3 = spawn(async move {
            let _ = sender1.send(handle2.await);
        });
        run(Some(handle3.task()));

        assert_eq!(receiver1.try_recv().unwrap().unwrap(), 100);

        evict_all();
    }
}
888}