yielding_executor/
single_threaded.rs

1//! # Single-threaded executor
2//!
//! This executor works *strictly* in a single-threaded environment. To spawn a task, use
//! [`spawn`]. To drive the executor, use [`start`], or [`run`] to block until a given future
//! completes.
5//!
//! There is no need to create an instance of the executor: one is provisioned automatically as a
//! thread-local instance.
8//!
9//! ## Example
10//!
11//! ```
12//! use tokio::sync::*;
13//! use yielding_executor::single_threaded::{spawn, start};
14//! let (sender, receiver) = oneshot::channel::<()>();
15//! let _task = spawn(async move {
16//!    // Complete when something is received
17//!    let _ = receiver.await;
18//! });
19//! // Send data to be received
20//! let _ = sender.send(());
21//! start();
22//! ```
23use futures::channel::oneshot;
24use futures::task::{waker_ref, ArcWake};
25#[cfg(feature = "debug")]
26use std::any::{type_name, TypeId};
27use std::cell::UnsafeCell;
28use std::collections::BTreeMap;
29use std::future::Future;
30use std::pin::Pin;
31use std::sync::Arc;
32use std::task::{Context, Poll};
33
/// Numeric identifier the executor hands out to every spawned task (see `Executor::counter`).
type Token = usize;
36
/// Type information about a spawned task's future (only recorded with the `debug` feature).
#[cfg(feature = "debug")]
#[derive(Clone, Debug)]
pub struct TypeInfo {
    /// `TypeId` of the future; `None` when the future's type is not `'static`.
    type_id: Option<TypeId>,
    /// Name of the future's type as reported by `std::any::type_name`.
    type_name: &'static str,
}

#[cfg(feature = "debug")]
impl TypeInfo {
    /// Records both the name and the `TypeId` of a `'static` type `T`.
    fn new<T>() -> Self
    where
        T: 'static,
    {
        Self {
            type_id: Some(TypeId::of::<T>()),
            ..Self::new_non_static::<T>()
        }
    }

    /// Records only the name of `T`; `TypeId` is unavailable for non-`'static` types.
    fn new_non_static<T>() -> Self {
        Self {
            type_id: None,
            type_name: type_name::<T>(),
        }
    }

    /// Returns the task's type name.
    pub fn type_name(&self) -> &'static str {
        self.type_name
    }

    /// Returns the task's [`std::any::TypeId`].
    ///
    /// `None` means the type does not have a `'static` lifetime.
    pub fn type_id(&self) -> Option<TypeId> {
        self.type_id
    }
}
76
77/// Task information
78#[derive(Clone)]
79#[must_use]
80pub struct Task {
81    token: Token,
82    #[cfg(feature = "debug")]
83    type_info: Arc<TypeInfo>,
84}
85
86impl PartialEq for Task {
87    fn eq(&self, other: &Self) -> bool {
88        self.token == other.token
89    }
90}
91
92impl Eq for Task {}
93
94impl PartialOrd for Task {
95    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
96        self.token.partial_cmp(&other.token)
97    }
98}
99
100impl Ord for Task {
101    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
102        self.token.cmp(&other.token)
103    }
104}
105
106impl Task {
107    #[cfg(feature = "debug")]
108    #[allow(missing_docs)]
109    pub fn type_info(&self) -> &TypeInfo {
110        self.type_info.as_ref()
111    }
112}
113
114/// Task handle
115///
116/// Implements [`std::future::Future`] to allow for waiting for task completion
117pub struct TaskHandle<T> {
118    receiver: oneshot::Receiver<T>,
119    task: Task,
120}
121
122impl<T> TaskHandle<T> {
123    /// Returns a copy of task information record
124    pub fn task(&self) -> Task {
125        self.task.clone()
126    }
127}
128
/// Error returned when awaiting a [`TaskHandle`] fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinError {
    /// The task was dropped before it produced a value (for example, it was evicted).
    Canceled,
}

impl std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            JoinError::Canceled => f.write_str("task was canceled"),
        }
    }
}

impl std::error::Error for JoinError {}
135
136impl<T> Future for TaskHandle<T> {
137    type Output = Result<T, JoinError>;
138    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
139        match self.receiver.try_recv() {
140            Err(oneshot::Canceled) => Poll::Ready(Err(JoinError::Canceled)),
141            Ok(Some(result)) => Poll::Ready(Ok(result)),
142            Ok(None) => {
143                cx.waker().wake_by_ref();
144                Poll::Pending
145            }
146        }
147    }
148}
149
150impl ArcWake for Task {
151    fn wake_by_ref(arc_self: &Arc<Self>) {
152        EXECUTOR.with(|cell| (unsafe { &mut *cell.get() }).enqueue(arc_self.clone()));
153    }
154}
155
156/// Single-threaded executor
157struct Executor {
158    counter: Token,
159    futures: BTreeMap<Task, Pin<Box<dyn Future<Output = ()>>>>,
160    queue: Vec<Arc<Task>>,
161}
162
163impl Executor {
164    fn new() -> Self {
165        Self {
166            counter: 0,
167            futures: BTreeMap::new(),
168            queue: vec![],
169        }
170    }
171
172    fn enqueue(&mut self, task: Arc<Task>) {
173        if self.futures.contains_key(&task) {
174            self.queue.insert(0, task);
175        }
176    }
177
178    fn spawn<F, T>(&mut self, fut: F) -> TaskHandle<T>
179    where
180        F: Future<Output = T> + 'static,
181        T: 'static,
182    {
183        let token = self.counter;
184        self.counter = self.counter.wrapping_add(1);
185        let task = Task {
186            token,
187            #[cfg(feature = "debug")]
188            type_info: Arc::new(TypeInfo::new::<F>()),
189        };
190
191        let (sender, receiver) = oneshot::channel();
192
193        self.futures.insert(task.clone(), unsafe {
194            Pin::new_unchecked(Box::new(async move {
195                let _ = sender.send(fut.await);
196            }))
197        });
198        self.queue.push(Arc::new(task.clone()));
199        TaskHandle { receiver, task }
200    }
201
202    fn spawn_non_static<F, T>(&mut self, fut: F) -> TaskHandle<T>
203    where
204        F: Future<Output = T>,
205    {
206        let token = self.counter;
207        self.counter = self.counter.wrapping_add(1);
208        let task = Task {
209            token,
210            #[cfg(feature = "debug")]
211            type_info: Arc::new(TypeInfo::new_non_static::<F>()),
212        };
213
214        let (sender, receiver) = oneshot::channel();
215
216        self.futures.insert(task.clone(), unsafe {
217            Pin::new_unchecked(std::mem::transmute::<_, Box<dyn Future<Output = ()>>>(
218                Box::new(async move {
219                    let _ = sender.send(fut.await);
220                }) as Box<dyn Future<Output = ()>>,
221            ))
222        });
223        self.queue.push(Arc::new(task.clone()));
224        TaskHandle { receiver, task }
225    }
226}
227
228thread_local! {
229  static EXECUTOR: UnsafeCell<Executor> = UnsafeCell::new(Executor::new()) ;
230}
231
232thread_local! {
233  static UNTIL: UnsafeCell<Option<Task>> = UnsafeCell::new(None) ;
234}
235
236thread_local! {
237  static UNTIL_SATISFIED: UnsafeCell<bool> = UnsafeCell::new(false) ;
238}
239
240thread_local! {
241  static WHILE_FN: UnsafeCell<Option<Box<dyn FnMut() -> bool>>> = UnsafeCell::new(None) ;
242}
243
244thread_local! {
245  static YIELD: UnsafeCell<bool> = UnsafeCell::new(true) ;
246}
247
248thread_local! {
249  static EXIT_LOOP: UnsafeCell<bool> = UnsafeCell::new(false) ;
250}
251
252/// Spawn a task
253pub fn spawn<F, T>(fut: F) -> TaskHandle<T>
254where
255    F: Future<Output = T> + 'static,
256    T: 'static,
257{
258    EXECUTOR.with(|cell| (unsafe { &mut *cell.get() }).spawn(fut))
259}
260
261/// Run tasks until completion of a future
262///
263/// ## Important
264///
265/// This function will yield to the environment if configured to do so.
266///
267pub fn run<F, R>(fut: F) -> R
268where
269    F: Future<Output = R>,
270{
271    let mut handle = EXECUTOR.with(|cell| (unsafe { &mut *cell.get() }).spawn_non_static(fut));
272    YIELD.with(|cell| unsafe {
273        *cell.get() = false;
274    });
275    run_until(handle.task());
276    YIELD.with(|cell| unsafe {
277        *cell.get() = true;
278    });
279    loop {
280        match handle.receiver.try_recv() {
281            Ok(None) => {}
282            Ok(Some(v)) => return v,
283            Err(_) => unreachable!(), // the data was sent at this point
284        }
285    }
286}
287
288/// Run the executor
289///
290/// The `until` promise and `while` function will remain unchanged.
291pub fn start() {
292    run_internal();
293}
294
295/// Reset execution conditions
296///
297/// Unsets the until promise and the while fn as well as their resolution statuses.
298pub fn reset_yield_conditions() {
299    UNTIL.with(|cell| unsafe { *cell.get() = None });
300    UNTIL_SATISFIED.with(|cell| unsafe { *cell.get() = false });
301    WHILE_FN.with(|cell| unsafe { *cell.get() = None });
302}
303
304/// Run the executor until a promise resolves
305///
306/// If `until` is `None`, it will run until all tasks have been completed. Otherwise, it'll wait
307/// until passed task is complete, or unless a `cooperative` feature has been enabled and control
308/// has been yielded to the environment. In this case the function will return but the environment
309/// might schedule further execution of this executor in the background after termination of the
310/// function enclosing invocation of this [`run`]
311pub fn run_until(until: Task) {
312    UNTIL.with(|cell| unsafe { *cell.get() = Some(until) });
313    UNTIL_SATISFIED.with(|cell| unsafe { *cell.get() = false });
314    run_internal();
315}
316
317/// Run the executor while a function returns true
318///
319/// The function passed as `condition` will run on every loop of the executor. The executor will
320/// yield anytime the `condition` evaluates to `true`. You can restart execution by issuing another
321/// `run` command
322pub fn run_while<F>(condition: F)
323where
324    F: FnMut() -> bool + 'static,
325{
326    WHILE_FN.with(|cell| unsafe { *cell.get() = Some(Box::new(condition)) });
327
328    run_internal();
329}
330
331// Returns `true` if `until` task completed, or there was no `until` task and every task was
332// completed.
333//
334// Returns `false` if loop exit was requested
335fn run_internal() -> bool {
336    let until = UNTIL.with(|cell| unsafe { &*cell.get() });
337    let exit_condition_met = UNTIL_SATISFIED.with(|cell| unsafe { *cell.get() });
338    if exit_condition_met {
339        return true;
340    }
341    EXECUTOR.with(|cell| loop {
342        let task = (unsafe { &mut *cell.get() }).queue.pop();
343
344        if let Some(task) = task {
345            let future = (unsafe { &mut *cell.get() }).futures.get_mut(&task);
346            let ready = future.map_or(false, |future| {
347                let waker = waker_ref(&task);
348                let context = &mut Context::from_waker(&*waker);
349                let ready = matches!(future.as_mut().poll(context), Poll::Ready(_));
350                ready
351            });
352            if ready {
353                (unsafe { &mut *cell.get() }).futures.remove(&task);
354
355                if let Some(Task { ref token, .. }) = until {
356                    if *token == task.token {
357                        UNTIL_SATISFIED.with(|cell| unsafe { *cell.get() = true });
358                        return true;
359                    }
360                }
361            }
362        }
363        if until.is_none() && (unsafe { &mut *cell.get() }).futures.is_empty() {
364            UNTIL_SATISFIED.with(|cell| unsafe { *cell.get() = true });
365            return true;
366        }
367
368        let should_continue =
369            WHILE_FN.with(|cell| unsafe { (&mut *cell.get()).as_mut().map_or(true, |f| (f)()) });
370
371        let exit_requested = EXIT_LOOP.with(|cell| {
372            let v = cell.get();
373            let result = unsafe { *v };
374            // Clear the flag
375            unsafe {
376                *v = false;
377            }
378            result
379        }) && YIELD.with(|cell| unsafe { *cell.get() });
380
381        if exit_requested || !should_continue {
382            return false;
383        }
384
385        if (unsafe { &mut *cell.get() }).queue.is_empty()
386            && !(unsafe { &mut *cell.get() }).futures.is_empty()
387        {
388            // the executor is starving
389            for task in (unsafe { &mut *cell.get() }).futures.keys() {
390                (unsafe { &mut *cell.get() }).enqueue(Arc::new(task.clone()));
391            }
392        }
393    })
394}
395
396/// Returns the number of tasks currently registered with the executor
397#[must_use]
398pub fn tasks_count() -> usize {
399    EXECUTOR.with(|cell| {
400        let executor = unsafe { &mut *cell.get() };
401        executor.futures.len()
402    })
403}
404
405/// Returns the number of tasks currently in the queue to execute
406#[must_use]
407pub fn queued_tasks_count() -> usize {
408    EXECUTOR.with(|cell| (unsafe { &mut *cell.get() }).queue.len())
409}
410
411/// Returns all tasks that haven't completed yet
412#[must_use]
413pub fn tasks() -> Vec<Task> {
414    EXECUTOR.with(|cell| {
415        (unsafe { &*cell.get() })
416            .futures
417            .keys()
418            .map(Task::clone)
419            .collect()
420    })
421}
422
423/// Returns tokens for queued tasks
424#[must_use]
425pub fn queued_tasks() -> Vec<Task> {
426    EXECUTOR.with(|cell| {
427        (unsafe { &*cell.get() })
428            .queue
429            .iter()
430            .map(|t| Task::clone(t))
431            .collect()
432    })
433}
434
435/// Removes all tasks from the executor
436///
437/// ## Caution
438///
439/// Evicted tasks won't be able to get re-scheduled when they will be woken up.
440pub fn evict_all() {
441    EXECUTOR.with(|cell| unsafe { *cell.get() = Executor::new() });
442}
443
/// Test hook: overrides the token the executor will hand out next.
#[cfg(test)]
fn set_counter(counter: usize) {
    EXECUTOR.with(|cell| {
        let executor = unsafe { &mut *cell.get() };
        executor.counter = counter;
    });
}
448
#[cfg(test)]
mod tests {
    // Each test drives this thread's executor directly and ends with `reset_yield_conditions()`
    // plus `evict_all()`, so no tasks or loop conditions leak into a later test on the same thread.

    use super::*;
    // Call counter used by `test_while`'s condition.
    thread_local! {
      static NUM: UnsafeCell<u32> = UnsafeCell::new(0) ;
    }

    // With no `until` task, `start()` runs until every task has completed. The only task can
    // finish because its message is sent before the executor starts.
    #[test]
    fn test() {
        use tokio::sync::*;
        let (sender, receiver) = oneshot::channel::<()>();
        let _handle = spawn(async move {
            let _ = receiver.await;
        });
        let _ = sender.send(());
        start();
        reset_yield_conditions();
        evict_all();
    }

    // `handle1` never receives anything (its sender is kept alive but unused), yet `run_until`
    // must return once `handle2` has completed.
    #[test]
    fn test_until() {
        use tokio::sync::*;
        let (_sender1, receiver1) = oneshot::channel::<()>();
        let _handle1 = spawn(async move {
            let _ = receiver1.await;
        });
        let (sender2, receiver2) = oneshot::channel::<()>();
        let handle2 = spawn(async move {
            let _ = receiver2.await;
        });
        let _ = sender2.send(());
        run_until(handle2.task());
        reset_yield_conditions();
        evict_all();
    }

    // `handle1` never completes, so the loop only ends through the `while` condition, which
    // returns `false` on its 6th invocation.
    #[test]
    fn test_while() {
        use tokio::sync::*;
        let (_sender1, receiver1) = oneshot::channel::<()>();
        let _handle1 = spawn(async move {
            let _ = receiver1.await;
        });
        let (sender2, receiver2) = oneshot::channel::<()>();
        let _handle2 = spawn(async move {
            let _ = receiver2.await;
        });
        let _ = sender2.send(());

        run_while(move || {
            let num = NUM.with(|cell| unsafe {
                *cell.get() += 1;
                *cell.get()
            });
            num < 6
        });
        let num = NUM.with(|cell| unsafe { *cell.get() });

        assert_eq!(num, 6);

        reset_yield_conditions();

        evict_all();
    }

    // Checks `tasks_count` / `queued_tasks_count` both from inside a running task and after
    // `run_until` returns.
    #[test]
    fn test_counts() {
        use tokio::sync::oneshot;
        let (sender, mut receiver) = oneshot::channel();
        let (sender2, receiver2) = oneshot::channel::<()>();
        let handle1 = spawn(async move {
            let _ = receiver2.await;
            let _ = sender.send((tasks_count(), queued_tasks_count()));
        });
        let _handle2 = spawn(async move {
            let _ = sender2.send(());
            futures::future::pending::<()>().await; // this will never end
        });
        run_until(handle1.task());
        let (tasks_, queued_tasks_) = receiver.try_recv().unwrap();
        // handle1 + handle2
        assert_eq!(tasks_, 2);
        // handle1 is being executed, handle2 has nothing new
        assert_eq!(queued_tasks_, 0);
        // handle1 is gone
        assert_eq!(tasks_count(), 1);
        // handle2 still has nothing new
        assert_eq!(queued_tasks_count(), 0);
        reset_yield_conditions();
        evict_all();
    }

    // Waking a task after eviction must neither re-register nor re-queue it.
    #[test]
    fn evicted_tasks_dont_requeue() {
        use tokio::sync::*;
        let (_sender, receiver) = oneshot::channel::<()>();
        let handle = spawn(async move {
            let _ = receiver.await;
        });
        assert_eq!(tasks_count(), 1);
        evict_all();
        assert_eq!(tasks_count(), 0);
        ArcWake::wake_by_ref(&Arc::new(handle.task()));
        assert_eq!(tasks_count(), 0);
        assert_eq!(queued_tasks_count(), 0);
        reset_yield_conditions();
        evict_all();
    }

    // The token counter wraps around instead of overflowing.
    #[test]
    fn token_exhaustion() {
        set_counter(usize::MAX);
        // this should be fine anyway
        let handle0 = spawn(async move {});
        // this should NOT crash
        let handle = spawn(async move {});
        // new token should be different and wrap back to the beginning
        assert!(handle.task().token != handle0.task().token);
        assert_eq!(handle.task().token, 0);
        reset_yield_conditions();
        evict_all();
    }

    // `run` keeps polling other tasks while blocking: the spawned task delivers the value the
    // blocked-on future awaits.
    #[test]
    fn blocking_on() {
        use tokio::sync::*;
        let (sender, receiver) = oneshot::channel::<u8>();
        let _handle = spawn(async move {
            let _ = sender.send(1);
        });
        let result = run(async move { receiver.await.unwrap() });
        assert_eq!(result, 1);
        reset_yield_conditions();
        evict_all();
    }

    // The task yields twice before sending; `run` must keep polling it until the value arrives.
    #[test]
    fn starvation() {
        use tokio::sync::*;
        let (sender, receiver) = oneshot::channel();
        let _handle = spawn(async move {
            tokio::task::yield_now().await;
            tokio::task::yield_now().await;
            let _ = sender.send(());
        });
        run(async move { receiver.await.unwrap() });
        reset_yield_conditions();
        evict_all();
    }

    // With the `debug` feature, `tasks()` exposes the recorded type of the spawned future.
    #[cfg(feature = "debug")]
    #[test]
    fn task_type_info() {
        spawn(futures::future::pending::<()>());
        assert!(tasks()[0]
            .type_info()
            .type_name()
            .contains("future::pending::Pending"));
        assert_eq!(
            tasks()[0].type_info().type_id().unwrap(),
            TypeId::of::<futures::future::Pending<()>>()
        );
        reset_yield_conditions();
        evict_all();
        assert_eq!(tasks().len(), 0);
    }

    // `handle3` awaits `handle2`'s `TaskHandle`, which resolves to `handle2`'s output (100) once
    // `handle1` has signalled `handle2`.
    #[test]
    fn joining() {
        use tokio::sync::*;
        let (sender, receiver) = oneshot::channel();
        let (sender1, mut receiver1) = oneshot::channel();
        let _handle1 = spawn(async move {
            let _ = sender.send(());
        });

        let handle2 = spawn(async move {
            let _ = receiver.await;
            100u8
        });

        let handle3 = spawn(async move {
            let _ = sender1.send(handle2.await);
        });
        run_until(handle3.task());

        assert_eq!(receiver1.try_recv().unwrap().unwrap(), 100);
        reset_yield_conditions();

        evict_all();
    }
}