spawns_executor/
lib.rs

use async_executor::Executor;
use async_shutdown::ShutdownManager;
use futures::executor::{block_on as block_current_thread_on, LocalPool, LocalSpawner};
use futures::task::{FutureObj, Spawn as _};
use spawns_core::{enter, spawn, Spawn, Task};
use std::boxed::Box;
use std::future::Future;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::thread;

/// Spawns tasks onto a current-thread [`LocalPool`]; used by [block_on].
struct Spawner {
    spawner: LocalSpawner,
}

impl Spawn for Spawner {
    fn spawn(&self, task: Task) {
        let Task { future, .. } = task;
        self.spawner.spawn_obj(FutureObj::new(future)).unwrap()
    }
}

/// Spawns tasks onto a shared [`Executor`]; used by [Blocking].
struct ExecutorSpawner<'a> {
    executor: &'a Executor<'static>,
}

impl<'a> ExecutorSpawner<'a> {
    fn new(executor: &'a Executor<'static>) -> Self {
        Self { executor }
    }
}

impl Spawn for ExecutorSpawner<'_> {
    fn spawn(&self, task: Task) {
        let Task { future, .. } = task;
        self.executor.spawn(Box::into_pin(future)).detach();
    }
}

/// Executor construct to block on a future until completion.
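///
/// # Example
///
/// A minimal usage sketch (assuming this crate is exposed as `spawns_executor`):
///
/// ```no_run
/// use spawns_executor::Blocking;
///
/// // Run a future to completion across four threads.
/// let sum = Blocking::new(4).block_on(async { 1 + 2 });
/// assert_eq!(sum, 3);
/// ```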
pub struct Blocking {
    parallelism: usize,
}

impl Blocking {
    /// Creates an executor to run a future and its assistant tasks.
    ///
    /// # Notable behaviors
    /// * `0` means [thread::available_parallelism], falling back to `2` if unavailable.
    /// * `1` behaves identically to [block_on].
    pub fn new(parallelism: usize) -> Self {
        Self { parallelism }
    }

    fn parallelism(&self) -> usize {
        match self.parallelism {
            0 => std::thread::available_parallelism().map_or(2, NonZeroUsize::get),
            n => n,
        }
    }

    /// Drives `executor` on the current thread until `future` completes,
    /// installing it as the spawn target for assistant tasks via [enter].
    fn run_until<T, F>(executor: &Executor<'static>, future: F) -> T
    where
        F: Future<Output = T> + Send + 'static,
    {
        let spawner = ExecutorSpawner::new(executor);
        let _scope = enter(&spawner);
        block_current_thread_on(executor.run(future))
    }

    /// Blocks the current thread and runs the given future until completion.
    ///
    /// All tasks will be cancelled sooner or later after return.
    ///
    /// Uses [spawn] to spawn assistant tasks.
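    ///
    /// # Example
    ///
    /// A sketch of spawning an assistant task from the blocked-on future
    /// (assuming this crate is exposed as `spawns_executor`):
    ///
    /// ```no_run
    /// use spawns_core::spawn;
    /// use spawns_executor::Blocking;
    ///
    /// let value = Blocking::new(0).block_on(async {
    ///     // Assistant tasks run on the same executor; any still pending
    ///     // when `block_on` returns are cancelled sooner or later.
    ///     spawn(async { 21 * 2 }).await.unwrap()
    /// });
    /// assert_eq!(value, 42);
    /// ```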
    pub fn block_on<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
        self,
        future: F,
    ) -> F::Output {
        let threads = self.parallelism();
        if threads == 1 {
            return block_on(future);
        }
        let executor = Arc::new(Executor::new());
        let shutdown = ShutdownManager::new();
        let shutdown_signal = shutdown.wait_shutdown_triggered();
        // Spawn `threads - 1` worker threads that drive the executor until shutdown.
        (2..=threads).for_each(|i| {
            thread::Builder::new()
                .name(format!("spawns-executor-{i}/{threads}"))
                .spawn({
                    let executor = executor.clone();
                    let shutdown_signal = shutdown_signal.clone();
                    move || Self::run_until(&executor, shutdown_signal)
                })
                .unwrap();
        });
        // Dropping this token on return triggers shutdown of the worker threads.
        let _shutdown_on_drop = shutdown.trigger_shutdown_token(());
        Self::run_until(&executor, future)
    }
}

/// Blocks the current thread and runs the given future until completion.
///
/// All tasks will be cancelled after return.
///
/// Uses [spawn] to spawn assistant tasks.
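///
/// # Example
///
/// A current-thread sketch (assuming this crate is exposed as `spawns_executor`):
///
/// ```no_run
/// use spawns_core::spawn;
/// use spawns_executor::block_on;
///
/// let value = block_on(async {
///     // The assistant task runs on this same thread, interleaved with this future.
///     spawn(async { "pong" }).await.unwrap()
/// });
/// assert_eq!(value, "pong");
/// ```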
pub fn block_on<T: Send + 'static, F: Future<Output = T> + Send + 'static>(future: F) -> F::Output {
    let mut pool = LocalPool::new();
    let spawner = Spawner {
        spawner: pool.spawner(),
    };
    let _scope = enter(&spawner);
    // Spawn the root future into the pool and drive the pool until it completes.
    pool.run_until(spawn(future)).unwrap()
}

#[cfg(test)]
mod tests {
    use super::{block_current_thread_on, block_on, Blocking};
    use spawns_core as spawns;

    mod echo {
        // Everything in this module is runtime agnostic.
        use async_net::*;
        use futures_lite::io;
        use futures_lite::prelude::*;
        use spawns_core::{spawn, TaskHandle};

        async fn echo_stream(stream: TcpStream) {
            let (reader, writer) = io::split(stream);
            let _ = io::copy(reader, writer).await;
        }

        async fn echo_server(listener: TcpListener) {
            let mut echos = vec![];
            loop {
                let (conn, _addr) = listener.accept().await.unwrap();
                echos.push(spawn(echo_stream(conn)).attach());
            }
        }

        async fn start_echo_server() -> (u16, TaskHandle<()>) {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = listener.local_addr().unwrap().port();
            let handle = spawn(echo_server(listener));
            (port, handle.attach())
        }

        pub async fn echo_one(data: &[u8]) -> Vec<u8> {
            let (port, _server_handle) = start_echo_server().await;
            let mut stream = TcpStream::connect(format!("127.0.0.1:{port}"))
                .await
                .unwrap();
            stream.write_all(data).await.unwrap();
            stream.close().await.unwrap();
            let mut buf = vec![];
            stream.read_to_end(&mut buf).await.unwrap();
            buf
        }
    }

    #[test]
    fn block_on_current_thread() {
        let msg = b"Hello! Current Thread Executor!";
        let result = block_on(echo::echo_one(msg));
        assert_eq!(&result[..], msg);
    }

    #[test]
    fn block_on_multi_thread() {
        let msg = b"Hello! Multi-Thread Executor!";
        let result = Blocking::new(4).block_on(echo::echo_one(msg));
        assert_eq!(&result[..], msg);
    }

    #[test]
    fn task_cancelled_after_main_return_current_thread() {
        use async_io::Timer;
        use std::time::Duration;
        #[allow(clippy::async_yields_async)]
        let handle = block_on(async {
            spawns::spawn(async { Timer::after(Duration::from_secs(30)).await })
        });
        let err = block_current_thread_on(handle).unwrap_err();
        assert!(err.is_cancelled());
    }

    #[test]
    fn task_cancelled_after_main_return_multi_thread() {
        use async_io::Timer;
        use std::time::Duration;
        #[allow(clippy::async_yields_async)]
        let handle = Blocking::new(4).block_on(async {
            spawns::spawn(async { Timer::after(Duration::from_secs(30)).await })
        });
        let err = block_current_thread_on(handle).unwrap_err();
        assert!(err.is_cancelled());
    }
}