1use async_executor::Executor;
2use async_shutdown::ShutdownManager;
3use futures::executor::{block_on as block_current_thread_on, LocalPool, LocalSpawner};
4use futures::task::{FutureObj, Spawn as _};
5use spawns_core::{enter, spawn, Spawn, Task};
6use std::boxed::Box;
7use std::future::Future;
8use std::num::NonZeroUsize;
9use std::sync::Arc;
10use std::thread;
11
12struct Spawner {
13 spawner: LocalSpawner,
14}
15
16impl Spawn for Spawner {
17 fn spawn(&self, task: Task) {
18 let Task { future, .. } = task;
19 self.spawner.spawn_obj(FutureObj::new(future)).unwrap()
20 }
21}
22
23struct ExecutorSpawner<'a> {
24 executor: &'a Executor<'static>,
25}
26
27impl<'a> ExecutorSpawner<'a> {
28 fn new(executor: &'a Executor<'static>) -> Self {
29 Self { executor }
30 }
31}
32
33impl Spawn for ExecutorSpawner<'_> {
34 fn spawn(&self, task: Task) {
35 let Task { future, .. } = task;
36 self.executor.spawn(Box::into_pin(future)).detach();
37 }
38}
39
40pub struct Blocking {
42 parallelism: usize,
43}
44
45impl Blocking {
46 pub fn new(parallelism: usize) -> Self {
52 Self { parallelism }
53 }
54
55 fn parallelism(&self) -> usize {
56 match self.parallelism {
57 0 => std::thread::available_parallelism().map_or(2, NonZeroUsize::get),
58 n => n,
59 }
60 }
61
62 fn run_until<T, F>(executor: &Executor<'static>, future: F) -> T
63 where
64 F: Future<Output = T> + Send + 'static,
65 {
66 let spawner = ExecutorSpawner::new(executor);
67 let _scope = enter(&spawner);
68 block_current_thread_on(executor.run(future))
69 }
70
71 pub fn block_on<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
77 self,
78 future: F,
79 ) -> F::Output {
80 let threads = self.parallelism();
81 if threads == 1 {
82 return block_on(future);
83 }
84 let executor = Arc::new(Executor::new());
85 let shutdown = ShutdownManager::new();
86 let shutdown_signal = shutdown.wait_shutdown_triggered();
87 (2..=threads).for_each(|i| {
88 thread::Builder::new()
89 .name(format!("spawns-executor-{i}/{threads}"))
90 .spawn({
91 let executor = executor.clone();
92 let shutdown_signal = shutdown_signal.clone();
93 move || Self::run_until(&executor, shutdown_signal)
94 })
95 .unwrap();
96 });
97 let _shutdown_on_drop = shutdown.trigger_shutdown_token(());
98 Self::run_until(&executor, future)
99 }
100}
101
102pub fn block_on<T: Send + 'static, F: Future<Output = T> + Send + 'static>(future: F) -> F::Output {
108 let mut pool = LocalPool::new();
109 let spawner = Spawner {
110 spawner: pool.spawner(),
111 };
112 let _scope = enter(&spawner);
113 pool.run_until(spawn(future)).unwrap()
114}
115
// Tests for both entry points: the single-threaded free `block_on` and the multi-threaded
// `Blocking::block_on`.
#[cfg(test)]
mod tests {
    use super::{block_current_thread_on, block_on, Blocking};
    use spawns_core as spawns;

    // Minimal TCP echo server and client, used to push real async I/O and task spawning
    // through the executors under test.
    mod echo {
        use async_net::*;
        use futures_lite::io;
        use futures_lite::prelude::*;
        use spawns_core::{spawn, TaskHandle};

        // Copies everything read from `stream` back into it. Stops at EOF or on the first
        // I/O error; the copy result is deliberately ignored.
        async fn echo_stream(stream: TcpStream) {
            let (reader, writer) = io::split(stream);
            let _ = io::copy(reader, writer).await;
        }

        // Accepts connections forever and spawns one `echo_stream` task per connection.
        // The `attach()`ed handles are pushed into `echos`, presumably so each connection
        // task lives as long as this accept loop (confirm `attach` semantics in
        // spawns_core).
        async fn echo_server(listener: TcpListener) {
            let mut echos = vec![];
            loop {
                let (conn, _addr) = listener.accept().await.unwrap();
                echos.push(spawn(echo_stream(conn)).attach());
            }
        }

        // Binds to an OS-chosen localhost port (port 0) and spawns the accept loop.
        // Returns the bound port together with the server task's attached handle.
        async fn start_echo_server() -> (u16, TaskHandle<()>) {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = listener.local_addr().unwrap().port();
            let handle = spawn(echo_server(listener));
            (port, handle.attach())
        }

        // Starts a fresh echo server, sends `data`, closes the stream for writing, and
        // returns everything read back until EOF.
        pub async fn echo_one(data: &[u8]) -> Vec<u8> {
            // Bound to a named (not `_`) variable so the server handle stays alive for the
            // whole exchange.
            let (port, _server_handle) = start_echo_server().await;
            let mut stream = TcpStream::connect(format!("127.0.0.1:{port}"))
                .await
                .unwrap();
            stream.write_all(data).await.unwrap();
            // NOTE(review): relies on `close()` shutting down only the write direction, so
            // the server's copy sees EOF while we can still read the echo below.
            stream.close().await.unwrap();
            let mut buf = vec![];
            stream.read_to_end(&mut buf).await.unwrap();
            buf
        }
    }

    // Round-trips a message through the echo server on the single-threaded runner.
    #[test]
    fn block_on_current_thread() {
        let msg = b"Hello! Current Thread Executor!";
        let result = block_on(echo::echo_one(msg));
        assert_eq!(&result[..], msg);
    }

    // Same round trip on the multi-threaded runner with 4 threads.
    #[test]
    fn block_on_multi_thread() {
        let msg = b"Hello! Multi-Thread Executor!";
        let result = Blocking::new(4).block_on(echo::echo_one(msg));
        assert_eq!(&result[..], msg);
    }

    // The main future returns the handle of a 30-second timer task without awaiting it.
    // After `block_on` returns, awaiting that handle must report cancellation rather than
    // waiting for the timer.
    #[test]
    fn task_cancelled_after_main_return_current_thread() {
        use async_io::Timer;
        use std::time::Duration;
        // Returning the un-awaited task handle is the point of this test.
        #[allow(clippy::async_yields_async)]
        let handle = block_on(async {
            spawns::spawn(async { Timer::after(Duration::from_secs(30)).await })
        });
        let err = block_current_thread_on(handle).unwrap_err();
        assert!(err.is_cancelled());
    }

    // Multi-threaded variant of the cancellation test above.
    #[test]
    fn task_cancelled_after_main_return_multi_thread() {
        use async_io::Timer;
        use std::time::Duration;
        // Returning the un-awaited task handle is the point of this test.
        #[allow(clippy::async_yields_async)]
        let handle = Blocking::new(4).block_on(async {
            spawns::spawn(async { Timer::after(Duration::from_secs(30)).await })
        });
        let err = block_current_thread_on(handle).unwrap_err();
        assert!(err.is_cancelled());
    }
}
198}