1use futures::channel::oneshot;
24use futures::task::{waker_ref, ArcWake};
25#[cfg(feature = "debug")]
26use std::any::{type_name, TypeId};
27use std::cell::UnsafeCell;
28use std::collections::BTreeMap;
29use std::future::Future;
30use std::pin::Pin;
31use std::sync::Arc;
32use std::task::{Context, Poll};
33
/// Numeric identifier handed out to each spawned task; `Task` equality and
/// ordering are based solely on this value.
type Token = usize;
36
/// Type information recorded for a task's future when the `debug` feature is
/// enabled.
#[cfg(feature = "debug")]
#[derive(Clone, Debug)]
#[allow(missing_docs)]
pub struct TypeInfo {
    // `None` when the future was spawned without a `'static` bound, because
    // `TypeId::of` requires `T: 'static`.
    type_id: Option<TypeId>,
    // Output of `std::any::type_name` for the future's type.
    type_name: &'static str,
}
44
#[cfg(feature = "debug")]
impl TypeInfo {
    /// Records both the name and the `TypeId` of a `'static` type `T`.
    fn new<T>() -> Self
    where
        T: 'static,
    {
        let id = TypeId::of::<T>();
        Self {
            type_id: Some(id),
            type_name: type_name::<T>(),
        }
    }

    /// Records only the name of `T`; no `TypeId` is available without a
    /// `'static` bound.
    fn new_non_static<T>() -> Self {
        let name = type_name::<T>();
        Self {
            type_id: None,
            type_name: name,
        }
    }

    /// Returns the type name reported by `std::any::type_name`.
    pub fn type_name(&self) -> &'static str {
        let Self { type_name, .. } = *self;
        type_name
    }

    /// Returns the `TypeId`, or `None` if the type was recorded without a
    /// `'static` bound.
    pub fn type_id(&self) -> Option<TypeId> {
        let Self { type_id, .. } = *self;
        type_id
    }
}
76
77#[derive(Clone)]
79#[must_use]
80pub struct Task {
81 token: Token,
82 #[cfg(feature = "debug")]
83 type_info: Arc<TypeInfo>,
84}
85
86impl PartialEq for Task {
87 fn eq(&self, other: &Self) -> bool {
88 self.token == other.token
89 }
90}
91
// `PartialEq` compares only the `usize` token, which is a full equivalence
// relation, so `Eq` holds.
impl Eq for Task {}
93
94impl PartialOrd for Task {
95 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
96 self.token.partial_cmp(&other.token)
97 }
98}
99
100impl Ord for Task {
101 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
102 self.token.cmp(&other.token)
103 }
104}
105
106impl Task {
107 #[cfg(feature = "debug")]
108 #[allow(missing_docs)]
109 pub fn type_info(&self) -> &TypeInfo {
110 self.type_info.as_ref()
111 }
112}
113
/// Handle to a spawned task: await it to obtain the task's output, or call
/// `task()` to get the task's [`Task`] identifier.
pub struct TaskHandle<T> {
    // Paired with the sender captured by the spawned wrapper future, which
    // sends the task's output here when it completes.
    receiver: oneshot::Receiver<T>,
    task: Task,
}
121
122impl<T> TaskHandle<T> {
123 pub fn task(&self) -> Task {
125 self.task.clone()
126 }
127}
128
/// Error produced when awaiting a [`TaskHandle`] whose task can no longer
/// deliver a result.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JoinError {
    /// The task's result sender was dropped before a value was sent, for
    /// example because the task was evicted with `evict_all`.
    Canceled,
}

impl std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            JoinError::Canceled => f.write_str("task was canceled"),
        }
    }
}

// Lets callers propagate `JoinError` with `?` into `Box<dyn Error>` and
// similar error types.
impl std::error::Error for JoinError {}
135
136impl<T> Future for TaskHandle<T> {
137 type Output = Result<T, JoinError>;
138 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
139 match self.receiver.try_recv() {
140 Err(oneshot::Canceled) => Poll::Ready(Err(JoinError::Canceled)),
141 Ok(Some(result)) => Poll::Ready(Ok(result)),
142 Ok(None) => {
143 cx.waker().wake_by_ref();
144 Poll::Pending
145 }
146 }
147 }
148}
149
150impl ArcWake for Task {
151 fn wake_by_ref(arc_self: &Arc<Self>) {
152 EXECUTOR.with(|cell| (unsafe { &mut *cell.get() }).enqueue(arc_self.clone()));
153 }
154}
155
/// Per-thread executor state, stored in the `EXECUTOR` thread-local.
struct Executor {
    // Next token to hand out; advanced with `wrapping_add`, so it wraps to 0.
    counter: Token,
    // Futures of tasks that have not completed yet; removed once they return
    // `Poll::Ready`.
    futures: BTreeMap<Task, Pin<Box<dyn Future<Output = ()>>>>,
    // Tasks waiting to be polled; `run_internal` pops from the back.
    queue: Vec<Arc<Task>>,
}
162
163impl Executor {
164 fn new() -> Self {
165 Self {
166 counter: 0,
167 futures: BTreeMap::new(),
168 queue: vec![],
169 }
170 }
171
172 fn enqueue(&mut self, task: Arc<Task>) {
173 if self.futures.contains_key(&task) {
174 self.queue.insert(0, task);
175 }
176 }
177
178 fn spawn<F, T>(&mut self, fut: F) -> TaskHandle<T>
179 where
180 F: Future<Output = T> + 'static,
181 T: 'static,
182 {
183 let token = self.counter;
184 self.counter = self.counter.wrapping_add(1);
185 let task = Task {
186 token,
187 #[cfg(feature = "debug")]
188 type_info: Arc::new(TypeInfo::new::<F>()),
189 };
190
191 let (sender, receiver) = oneshot::channel();
192
193 self.futures.insert(task.clone(), unsafe {
194 Pin::new_unchecked(Box::new(async move {
195 let _ = sender.send(fut.await);
196 }))
197 });
198 self.queue.push(Arc::new(task.clone()));
199 TaskHandle { receiver, task }
200 }
201
202 fn spawn_non_static<F, T>(&mut self, fut: F) -> TaskHandle<T>
203 where
204 F: Future<Output = T>,
205 {
206 let token = self.counter;
207 self.counter = self.counter.wrapping_add(1);
208 let task = Task {
209 token,
210 #[cfg(feature = "debug")]
211 type_info: Arc::new(TypeInfo::new_non_static::<F>()),
212 };
213
214 let (sender, receiver) = oneshot::channel();
215
216 self.futures.insert(task.clone(), unsafe {
217 Pin::new_unchecked(std::mem::transmute::<_, Box<dyn Future<Output = ()>>>(
218 Box::new(async move {
219 let _ = sender.send(fut.await);
220 }) as Box<dyn Future<Output = ()>>,
221 ))
222 });
223 self.queue.push(Arc::new(task.clone()));
224 TaskHandle { receiver, task }
225 }
226}
227
// Executor for the current thread: registered futures plus the run queue.
thread_local! {
    static EXECUTOR: UnsafeCell<Executor> = UnsafeCell::new(Executor::new()) ;
}

// Task whose completion makes `run_internal` return `true`; set by `run_until`,
// cleared by `reset_yield_conditions`.
thread_local! {
    static UNTIL: UnsafeCell<Option<Task>> = UnsafeCell::new(None) ;
}

// Set to `true` by `run_internal` when its exit condition is met. While it is
// `true`, `run_internal` returns immediately. Cleared by `run_until` and
// `reset_yield_conditions`.
thread_local! {
    static UNTIL_SATISFIED: UnsafeCell<bool> = UnsafeCell::new(false) ;
}

// Condition installed by `run_while`; `run_internal` calls it once per loop
// iteration and stops when it returns `false`.
thread_local! {
    static WHILE_FN: UnsafeCell<Option<Box<dyn FnMut() -> bool>>> = UnsafeCell::new(None) ;
}

// Whether `run_internal` honors `EXIT_LOOP`; `run` sets it to `false` while
// blocking on its future.
thread_local! {
    static YIELD: UnsafeCell<bool> = UnsafeCell::new(true) ;
}

// Exit-request flag, read and reset to `false` by `run_internal` on every
// iteration. NOTE(review): nothing in this file sets it to `true`; confirm
// whether a setter exists elsewhere.
thread_local! {
    static EXIT_LOOP: UnsafeCell<bool> = UnsafeCell::new(false) ;
}
251
252pub fn spawn<F, T>(fut: F) -> TaskHandle<T>
254where
255 F: Future<Output = T> + 'static,
256 T: 'static,
257{
258 EXECUTOR.with(|cell| (unsafe { &mut *cell.get() }).spawn(fut))
259}
260
261pub fn run<F, R>(fut: F) -> R
268where
269 F: Future<Output = R>,
270{
271 let mut handle = EXECUTOR.with(|cell| (unsafe { &mut *cell.get() }).spawn_non_static(fut));
272 YIELD.with(|cell| unsafe {
273 *cell.get() = false;
274 });
275 run_until(handle.task());
276 YIELD.with(|cell| unsafe {
277 *cell.get() = true;
278 });
279 loop {
280 match handle.receiver.try_recv() {
281 Ok(None) => {}
282 Ok(Some(v)) => return v,
283 Err(_) => unreachable!(), }
285 }
286}
287
288pub fn start() {
292 run_internal();
293}
294
295pub fn reset_yield_conditions() {
299 UNTIL.with(|cell| unsafe { *cell.get() = None });
300 UNTIL_SATISFIED.with(|cell| unsafe { *cell.get() = false });
301 WHILE_FN.with(|cell| unsafe { *cell.get() = None });
302}
303
304pub fn run_until(until: Task) {
312 UNTIL.with(|cell| unsafe { *cell.get() = Some(until) });
313 UNTIL_SATISFIED.with(|cell| unsafe { *cell.get() = false });
314 run_internal();
315}
316
317pub fn run_while<F>(condition: F)
323where
324 F: FnMut() -> bool + 'static,
325{
326 WHILE_FN.with(|cell| unsafe { *cell.get() = Some(Box::new(condition)) });
327
328 run_internal();
329}
330
331fn run_internal() -> bool {
336 let until = UNTIL.with(|cell| unsafe { &*cell.get() });
337 let exit_condition_met = UNTIL_SATISFIED.with(|cell| unsafe { *cell.get() });
338 if exit_condition_met {
339 return true;
340 }
341 EXECUTOR.with(|cell| loop {
342 let task = (unsafe { &mut *cell.get() }).queue.pop();
343
344 if let Some(task) = task {
345 let future = (unsafe { &mut *cell.get() }).futures.get_mut(&task);
346 let ready = future.map_or(false, |future| {
347 let waker = waker_ref(&task);
348 let context = &mut Context::from_waker(&*waker);
349 let ready = matches!(future.as_mut().poll(context), Poll::Ready(_));
350 ready
351 });
352 if ready {
353 (unsafe { &mut *cell.get() }).futures.remove(&task);
354
355 if let Some(Task { ref token, .. }) = until {
356 if *token == task.token {
357 UNTIL_SATISFIED.with(|cell| unsafe { *cell.get() = true });
358 return true;
359 }
360 }
361 }
362 }
363 if until.is_none() && (unsafe { &mut *cell.get() }).futures.is_empty() {
364 UNTIL_SATISFIED.with(|cell| unsafe { *cell.get() = true });
365 return true;
366 }
367
368 let should_continue =
369 WHILE_FN.with(|cell| unsafe { (&mut *cell.get()).as_mut().map_or(true, |f| (f)()) });
370
371 let exit_requested = EXIT_LOOP.with(|cell| {
372 let v = cell.get();
373 let result = unsafe { *v };
374 unsafe {
376 *v = false;
377 }
378 result
379 }) && YIELD.with(|cell| unsafe { *cell.get() });
380
381 if exit_requested || !should_continue {
382 return false;
383 }
384
385 if (unsafe { &mut *cell.get() }).queue.is_empty()
386 && !(unsafe { &mut *cell.get() }).futures.is_empty()
387 {
388 for task in (unsafe { &mut *cell.get() }).futures.keys() {
390 (unsafe { &mut *cell.get() }).enqueue(Arc::new(task.clone()));
391 }
392 }
393 })
394}
395
396#[must_use]
398pub fn tasks_count() -> usize {
399 EXECUTOR.with(|cell| {
400 let executor = unsafe { &mut *cell.get() };
401 executor.futures.len()
402 })
403}
404
405#[must_use]
407pub fn queued_tasks_count() -> usize {
408 EXECUTOR.with(|cell| (unsafe { &mut *cell.get() }).queue.len())
409}
410
411#[must_use]
413pub fn tasks() -> Vec<Task> {
414 EXECUTOR.with(|cell| {
415 (unsafe { &*cell.get() })
416 .futures
417 .keys()
418 .map(Task::clone)
419 .collect()
420 })
421}
422
423#[must_use]
425pub fn queued_tasks() -> Vec<Task> {
426 EXECUTOR.with(|cell| {
427 (unsafe { &*cell.get() })
428 .queue
429 .iter()
430 .map(|t| Task::clone(t))
431 .collect()
432 })
433}
434
435pub fn evict_all() {
441 EXECUTOR.with(|cell| unsafe { *cell.get() = Executor::new() });
442}
443
/// Test helper: overrides the next token the executor will hand out.
#[cfg(test)]
fn set_counter(counter: usize) {
    EXECUTOR.with(|cell| {
        // SAFETY: the executor is thread-local, so only this thread can reach it.
        let executor = unsafe { &mut *cell.get() };
        executor.counter = counter;
    });
}
448
// Tests for the thread-local executor. Each test ends with
// `reset_yield_conditions()` and `evict_all()` to leave the executor state clean.
#[cfg(test)]
mod tests {

    use super::*;
    // Counter incremented by the `run_while` condition in `test_while`.
    thread_local! {
        static NUM: UnsafeCell<u32> = UnsafeCell::new(0) ;
    }

    // A task whose channel is already signalled completes under `start`.
    #[test]
    fn test() {
        use tokio::sync::*;
        let (sender, receiver) = oneshot::channel::<()>();
        let _handle = spawn(async move {
            let _ = receiver.await;
        });
        let _ = sender.send(());
        start();
        reset_yield_conditions();
        evict_all();
    }

    // `run_until` returns once `handle2` completes, even though `handle1` never
    // does (its sender is never used).
    #[test]
    fn test_until() {
        use tokio::sync::*;
        let (_sender1, receiver1) = oneshot::channel::<()>();
        let _handle1 = spawn(async move {
            let _ = receiver1.await;
        });
        let (sender2, receiver2) = oneshot::channel::<()>();
        let handle2 = spawn(async move {
            let _ = receiver2.await;
        });
        let _ = sender2.send(());
        run_until(handle2.task());
        reset_yield_conditions();
        evict_all();
    }

    // `run_while` stops when its condition returns false (on the 6th call),
    // even though `handle1` never completes.
    #[test]
    fn test_while() {
        use tokio::sync::*;
        let (_sender1, receiver1) = oneshot::channel::<()>();
        let _handle1 = spawn(async move {
            let _ = receiver1.await;
        });
        let (sender2, receiver2) = oneshot::channel::<()>();
        let _handle2 = spawn(async move {
            let _ = receiver2.await;
        });
        let _ = sender2.send(());

        run_while(move || {
            let num = NUM.with(|cell| unsafe {
                *cell.get() += 1;
                *cell.get()
            });
            num < 6
        });
        let num = NUM.with(|cell| unsafe { *cell.get() });

        assert_eq!(num, 6);

        reset_yield_conditions();

        evict_all();
    }

    // Counts observed from inside `handle1`: both tasks are still registered
    // and nothing else is queued. Afterwards only `handle2` (pending forever)
    // remains.
    #[test]
    fn test_counts() {
        use tokio::sync::oneshot;
        let (sender, mut receiver) = oneshot::channel();
        let (sender2, receiver2) = oneshot::channel::<()>();
        let handle1 = spawn(async move {
            let _ = receiver2.await;
            let _ = sender.send((tasks_count(), queued_tasks_count()));
        });
        let _handle2 = spawn(async move {
            let _ = sender2.send(());
            futures::future::pending::<()>().await;
        });
        run_until(handle1.task());
        let (tasks_, queued_tasks_) = receiver.try_recv().unwrap();
        assert_eq!(tasks_, 2);
        assert_eq!(queued_tasks_, 0);
        assert_eq!(tasks_count(), 1);
        assert_eq!(queued_tasks_count(), 0);
        reset_yield_conditions();
        evict_all();
    }

    // Waking a task after `evict_all` must not re-queue it, since its future
    // is gone.
    #[test]
    fn evicted_tasks_dont_requeue() {
        use tokio::sync::*;
        let (_sender, receiver) = oneshot::channel::<()>();
        let handle = spawn(async move {
            let _ = receiver.await;
        });
        assert_eq!(tasks_count(), 1);
        evict_all();
        assert_eq!(tasks_count(), 0);
        ArcWake::wake_by_ref(&Arc::new(handle.task()));
        assert_eq!(tasks_count(), 0);
        assert_eq!(queued_tasks_count(), 0);
        reset_yield_conditions();
        evict_all();
    }

    // The token counter wraps from `usize::MAX` to 0.
    #[test]
    fn token_exhaustion() {
        set_counter(usize::MAX);
        let handle0 = spawn(async move {});
        let handle = spawn(async move {});
        assert!(handle.task().token != handle0.task().token);
        assert_eq!(handle.task().token, 0);
        reset_yield_conditions();
        evict_all();
    }

    // `run` returns its future's output while also driving other spawned tasks.
    #[test]
    fn blocking_on() {
        use tokio::sync::*;
        let (sender, receiver) = oneshot::channel::<u8>();
        let _handle = spawn(async move {
            let _ = sender.send(1);
        });
        let result = run(async move { receiver.await.unwrap() });
        assert_eq!(result, 1);
        reset_yield_conditions();
        evict_all();
    }

    // `run` keeps polling a task that yields twice before sending.
    #[test]
    fn starvation() {
        use tokio::sync::*;
        let (sender, receiver) = oneshot::channel();
        let _handle = spawn(async move {
            tokio::task::yield_now().await;
            tokio::task::yield_now().await;
            let _ = sender.send(());
        });
        run(async move { receiver.await.unwrap() });
        reset_yield_conditions();
        evict_all();
    }

    // With the `debug` feature, a spawned task records its future's type name
    // and `TypeId`.
    #[cfg(feature = "debug")]
    #[test]
    fn task_type_info() {
        spawn(futures::future::pending::<()>());
        assert!(tasks()[0]
            .type_info()
            .type_name()
            .contains("future::pending::Pending"));
        assert_eq!(
            tasks()[0].type_info().type_id().unwrap(),
            TypeId::of::<futures::future::Pending<()>>()
        );
        reset_yield_conditions();
        evict_all();
        assert_eq!(tasks().len(), 0);
    }

    // A task can await another task's `TaskHandle` and receive its output.
    #[test]
    fn joining() {
        use tokio::sync::*;
        let (sender, receiver) = oneshot::channel();
        let (sender1, mut receiver1) = oneshot::channel();
        let _handle1 = spawn(async move {
            let _ = sender.send(());
        });

        let handle2 = spawn(async move {
            let _ = receiver.await;
            100u8
        });

        let handle3 = spawn(async move {
            let _ = sender1.send(handle2.await);
        });
        run_until(handle3.task());

        assert_eq!(receiver1.try_recv().unwrap().unwrap(), 100);
        reset_yield_conditions();

        evict_all();
    }
}