wta_executor/
lib.rs

1#![forbid(unsafe_code)]
2#![warn(clippy::pedantic)]
3#![allow(clippy::missing_errors_doc)]
4#![allow(clippy::missing_panics_doc)]
5
6use std::{
7    cell::RefCell,
8    future::Future,
9    pin::Pin,
10    sync::{Arc, Mutex},
11    task::{Context, Poll, Wake, Waker},
12    thread::Thread,
13};
14
15use crossbeam_queue::SegQueue;
16use futures::{channel::oneshot, FutureExt};
17
/// A type-erased, heap-pinned unit of work that the executor can queue and poll.
///
/// Every task is a `Send + Sync` future producing `()`; `Executor::spawn` wraps
/// user futures so their output is forwarded through a channel, leaving a `()`
/// future behind.
pub type Task = Pin<Box<dyn Future<Output = ()> + Send + Sync + 'static>>;
19
20thread_local! {
21    static EXECUTOR: RefCell<Option<Arc<Executor>>> = RefCell::new(None);
22}
23
24pub(crate) fn context<R>(f: impl FnOnce(&Arc<Executor>) -> R) -> R {
25    EXECUTOR.with(|e| {
26        let e = e.borrow();
27        let e = e
28            .as_ref()
29            .expect("spawn called outside of an executor context");
30        f(e)
31    })
32}
33
34pub fn spawn<F>(fut: F) -> JoinHandle<F::Output>
35where
36    F: Future + Send + Sync + 'static,
37    F::Output: Send,
38{
39    context(|e| e.spawn(fut))
40}
41
42#[derive(Default)]
43pub struct Executor {
44    tasks: SegQueue<Task>,
45    threads: SegQueue<Thread>,
46}
47
48impl Executor {
49    /// register this executor on the current thread
50    pub fn register(self: &Arc<Self>) {
51        EXECUTOR.with(|exec| *exec.borrow_mut() = Some(self.clone()));
52    }
53
54    fn wake(&self, task: Task) {
55        self.tasks.push(task);
56        // self.condvar.notify_one();
57        if let Some(t) = self.threads.pop() {
58            t.unpark();
59        };
60    }
61
62    pub fn poll_once(self: Arc<Self>) {
63        // Take one task from the queue.
64        let mut task = {
65            loop {
66                // try acquire a task from the queue
67                if let Some(task) = self.tasks.pop() {
68                    break task;
69                }
70                // park this thread
71                self.threads.push(std::thread::current());
72                std::thread::park();
73            }
74        };
75
76        let wake = Arc::new(TaskWaker {
77            task: Mutex::new(None),
78            executor: self,
79        });
80        let waker = Waker::from(wake.clone());
81        let mut cx = Context::from_waker(&waker);
82
83        if task.as_mut().poll(&mut cx).is_pending() {
84            wake.task.lock().unwrap().replace(task);
85        }
86    }
87
88    pub fn spawn<F>(&self, fut: F) -> JoinHandle<F::Output>
89    where
90        F: Future + Send + Sync + 'static,
91        F::Output: Send,
92    {
93        let (sender, handle) = JoinHandle::new();
94
95        // Pin the future. Also wrap it s.t. it sends it's output over the channel
96        let fut = Box::pin(fut.map(|out| sender.send(out).unwrap_or_default()));
97        // insert the task into the runtime and signal that it is ready for processing
98        self.wake(fut);
99
100        // return the handle to the spawner so that it can be `await`ed with it's output value
101        handle
102    }
103}
104
105struct TaskWaker {
106    executor: Arc<Executor>,
107    task: Mutex<Option<Task>>,
108}
109
110impl Wake for TaskWaker {
111    fn wake(self: Arc<Self>) {
112        self.wake_by_ref();
113    }
114    fn wake_by_ref(self: &Arc<Self>) {
115        if let Some(task) = self.task.lock().unwrap().take() {
116            self.executor.wake(task);
117        }
118    }
119}
120
121pub struct JoinHandle<R>(oneshot::Receiver<R>);
122
123impl<R> Unpin for JoinHandle<R> {}
124
125impl<R> JoinHandle<R> {
126    #[must_use]
127    pub fn new() -> (oneshot::Sender<R>, Self) {
128        let (sender, receiver) = oneshot::channel();
129        (sender, Self(receiver))
130    }
131}
132
133impl<R> Future for JoinHandle<R> {
134    type Output = R;
135
136    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
137        // poll the inner channel for the spawned future's result
138        self.0.poll_unpin(cx).map(Result::unwrap)
139    }
140}