pulsar/
executor.rs

//! Executor abstraction.
//!
//! This crate is compatible with both Tokio and async-std: the two runtimes
//! are abstracted behind the [Executor] trait.
5use futures::{Future, Stream};
6use std::{ops::Deref, pin::Pin, sync::Arc};
7
/// Indicates which executor is used, as reported by [Executor::kind].
///
/// The derives let callers log the kind and compare it directly
/// (`executor.kind() == ExecutorKind::Tokio`) instead of having to `match`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ExecutorKind {
    /// Tokio executor
    Tokio,
    /// async-std executor
    AsyncStd,
}
15
16/// Wrapper trait abstracting the Tokio and async-std executors
17pub trait Executor: Clone + Send + Sync + 'static {
18    /// spawns a new task
19    #[allow(clippy::clippy::result_unit_err)]
20    fn spawn(&self, f: Pin<Box<dyn Future<Output = ()> + Send>>) -> Result<(), ()>;
21    /// spawns a new blocking task
22    fn spawn_blocking<F, Res>(&self, f: F) -> JoinHandle<Res>
23    where
24        F: FnOnce() -> Res + Send + 'static,
25        Res: Send + 'static;
26
27    /// returns a Stream that will produce at regular intervals
28    fn interval(&self, duration: std::time::Duration) -> Interval;
29    /// waits for a configurable time
30    fn delay(&self, duration: std::time::Duration) -> Delay;
31
32    /// returns which executor is currently used
33    // test at runtime and manually choose the implementation
34    // because we cannot (yet) have async trait methods,
35    // so we cannot move the TCP connection here
36    fn kind(&self) -> ExecutorKind;
37}
38
/// Wrapper for the Tokio executor
#[cfg(feature = "tokio-runtime")]
#[derive(Clone, Debug)]
pub struct TokioExecutor;

#[cfg(feature = "tokio-runtime")]
impl Executor for TokioExecutor {
    /// Spawns `f` on the Tokio runtime of the calling context.
    ///
    /// Returns `Err(())` when there is no current Tokio runtime, instead of
    /// panicking the way a bare `tokio::task::spawn` call does there.
    fn spawn(&self, f: Pin<Box<dyn Future<Output = ()> + Send>>) -> Result<(), ()> {
        let runtime = tokio::runtime::Handle::try_current().map_err(|_| ())?;
        // The task is detached: its join handle is intentionally discarded.
        drop(runtime.spawn(f));
        Ok(())
    }

    /// Runs `f` through `tokio::task::spawn_blocking`.
    fn spawn_blocking<F, Res>(&self, f: F) -> JoinHandle<Res>
    where
        F: FnOnce() -> Res + Send + 'static,
        Res: Send + 'static,
    {
        JoinHandle::Tokio(tokio::task::spawn_blocking(f))
    }

    /// Wraps `tokio::time::interval(duration)`.
    fn interval(&self, duration: std::time::Duration) -> Interval {
        Interval::Tokio(tokio::time::interval(duration))
    }

    /// Wraps `tokio::time::sleep(duration)`.
    fn delay(&self, duration: std::time::Duration) -> Delay {
        Delay::Tokio(tokio::time::sleep(duration))
    }

    fn kind(&self) -> ExecutorKind {
        ExecutorKind::Tokio
    }
}
71
/// Executor implementation delegating to async-std.
#[cfg(feature = "async-std-runtime")]
#[derive(Clone, Debug)]
pub struct AsyncStdExecutor;

#[cfg(feature = "async-std-runtime")]
impl Executor for AsyncStdExecutor {
    /// Hands `f` to `async_std::task::spawn`; this never reports an error.
    fn spawn(&self, f: Pin<Box<dyn Future<Output = ()> + Send>>) -> Result<(), ()> {
        // The returned handle is discarded; callers cannot await the task.
        drop(async_std::task::spawn(f));
        Ok(())
    }

    /// Hands the blocking closure to `async_std::task::spawn_blocking`.
    fn spawn_blocking<F, Res>(&self, f: F) -> JoinHandle<Res>
    where
        F: FnOnce() -> Res + Send + 'static,
        Res: Send + 'static,
    {
        let blocking_task = async_std::task::spawn_blocking(f);
        JoinHandle::AsyncStd(blocking_task)
    }

    /// Builds the tick stream from `async_std::stream::interval`.
    fn interval(&self, duration: std::time::Duration) -> Interval {
        let ticks = async_std::stream::interval(duration);
        Interval::AsyncStd(ticks)
    }

    /// Delays an immediately-ready unit future by `duration` and boxes it.
    fn delay(&self, duration: std::time::Duration) -> Delay {
        use async_std::prelude::FutureExt;
        let delayed_unit = async_std::future::ready(()).delay(duration);
        Delay::AsyncStd(Box::pin(delayed_unit))
    }

    fn kind(&self) -> ExecutorKind {
        ExecutorKind::AsyncStd
    }
}
105
106impl<Exe: Executor> Executor for Arc<Exe> {
107    fn spawn(&self, f: Pin<Box<dyn Future<Output = ()> + Send>>) -> Result<(), ()> {
108        self.deref().spawn(f)
109    }
110
111    fn spawn_blocking<F, Res>(&self, f: F) -> JoinHandle<Res>
112    where
113        F: FnOnce() -> Res + Send + 'static,
114        Res: Send + 'static,
115    {
116        self.deref().spawn_blocking(f)
117    }
118
119    fn interval(&self, duration: std::time::Duration) -> Interval {
120        self.deref().interval(duration)
121    }
122
123    fn delay(&self, duration: std::time::Duration) -> Delay {
124        self.deref().delay(duration)
125    }
126
127    fn kind(&self) -> ExecutorKind {
128        self.deref().kind()
129    }
130}
131
/// Future returned by [Executor::spawn_blocking] to await the task's result.
///
/// Resolves to `Some(value)` on success. For the Tokio variant, an `Err` from
/// tokio's `JoinHandle` is discarded (via `Result::ok`) and surfaces as `None`.
pub enum JoinHandle<T> {
    /// wrapper for tokio's `JoinHandle`
    #[cfg(feature = "tokio-runtime")]
    Tokio(tokio::task::JoinHandle<T>),
    /// wrapper for async-std's `JoinHandle`
    #[cfg(feature = "async-std-runtime")]
    AsyncStd(async_std::task::JoinHandle<T>),
    /// Placeholder used only when no runtime feature is enabled. It exists so
    /// that `T` is used (otherwise the enum would not compile); polling it panics.
    #[cfg(all(not(feature = "tokio-runtime"), not(feature = "async-std-runtime")))]
    PlaceHolder(T),
}

// `JoinHandle` never pins its fields structurally: the runtime handles are
// polled through the safe `Pin::new`, and the placeholder's `T` is never pinned.
// Declaring it `Unpin` for every `T` is therefore sound. It is also required:
// with `PlaceHolder(T)` present, the auto-derived `Unpin` would demand
// `T: Unpin`, and the generic `self.get_mut()` below would not compile.
impl<T> Unpin for JoinHandle<T> {}

use std::task::Poll;
impl<T> Future for JoinHandle<T> {
    type Output = Option<T>;

    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> std::task::Poll<Self::Output> {
        match self.get_mut() {
            #[cfg(feature = "tokio-runtime")]
            JoinHandle::Tokio(j) => Pin::new(j).poll(cx).map(Result::ok),
            #[cfg(feature = "async-std-runtime")]
            JoinHandle::AsyncStd(j) => Pin::new(j).poll(cx).map(Some),
            #[cfg(all(not(feature = "tokio-runtime"), not(feature = "async-std-runtime")))]
            JoinHandle::PlaceHolder(_) => {
                // `cx` is only consumed by the runtime-backed arms.
                let _ = cx;
                unimplemented!("please activate one of the following cargo features: tokio-runtime, async-std-runtime")
            }
        }
    }
}
168
169/// a `Stream` producing a `()` at rgular time intervals
170pub enum Interval {
171    /// wrapper for tokio's interval
172    #[cfg(feature = "tokio-runtime")]
173    Tokio(tokio::time::Interval),
174    /// wrapper for async-std's interval
175    #[cfg(feature = "async-std-runtime")]
176    AsyncStd(async_std::stream::Interval),
177    #[cfg(all(not(feature = "tokio-runtime"), not(feature = "async-std-runtime")))]
178    PlaceHolder,
179}
180
181impl Stream for Interval {
182    type Item = ();
183
184    fn poll_next(
185        self: Pin<&mut Self>,
186        cx: &mut std::task::Context,
187    ) -> std::task::Poll<Option<Self::Item>> {
188        unsafe {
189            match Pin::get_unchecked_mut(self) {
190                #[cfg(feature = "tokio-runtime")]
191                Interval::Tokio(j) => match Pin::new_unchecked(j).poll_tick(cx) {
192                    Poll::Pending => Poll::Pending,
193                    Poll::Ready(_) => Poll::Ready(Some(())),
194                },
195                #[cfg(feature = "async-std-runtime")]
196                Interval::AsyncStd(j) => match Pin::new_unchecked(j).poll_next(cx) {
197                    Poll::Pending => Poll::Pending,
198                    Poll::Ready(v) => Poll::Ready(v),
199                },
200                #[cfg(all(not(feature = "tokio-runtime"), not(feature = "async-std-runtime")))]
201                Interval::PlaceHolder => {
202                    unimplemented!("please activate one of the following cargo features: tokio-runtime, async-std-runtime")
203                }
204            }
205        }
206    }
207}
208
/// A future producing a `()` after some time.
///
/// Returned by [Executor::delay].
pub enum Delay {
    /// wrapper around tokio's `Sleep`
    #[cfg(feature = "tokio-runtime")]
    Tokio(tokio::time::Sleep),
    /// wrapper around async-std's `Delay`
    #[cfg(feature = "async-std-runtime")]
    AsyncStd(Pin<Box<dyn Future<Output = ()> + Send>>),
    /// Placeholder used only when no runtime feature is enabled, mirroring
    /// `Interval` and `JoinHandle`. It keeps the `match` below well-formed in
    /// that configuration; polling it panics.
    #[cfg(all(not(feature = "tokio-runtime"), not(feature = "async-std-runtime")))]
    PlaceHolder,
}

impl Future for Delay {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> std::task::Poll<Self::Output> {
        // SAFETY: the `&mut Delay` is only used to reach the active variant's
        // field in place. Nothing is moved out of or swapped into `self`, and
        // `Delay` has neither a `Drop` impl nor a manual `Unpin` impl, so any
        // field we pin below stays where it was pinned.
        let this = unsafe { self.get_unchecked_mut() };
        match this {
            #[cfg(feature = "tokio-runtime")]
            Delay::Tokio(d) => {
                // SAFETY: `d` lives inside the pinned `Delay` and is never moved
                // (see above), so projecting the pin onto it is sound. `Sleep` is
                // polled through `Future::poll`, which needs a `Pin<&mut Sleep>`.
                let sleep = unsafe { Pin::new_unchecked(d) };
                sleep.poll(cx)
            }
            // The boxed future is already pinned on the heap, so no `unsafe` is needed.
            #[cfg(feature = "async-std-runtime")]
            Delay::AsyncStd(j) => j.as_mut().poll(cx),
            #[cfg(all(not(feature = "tokio-runtime"), not(feature = "async-std-runtime")))]
            Delay::PlaceHolder => {
                // `cx` is only consumed by the runtime-backed arms.
                let _ = cx;
                unimplemented!("please activate one of the following cargo features: tokio-runtime, async-std-runtime")
            }
        }
    }
}