1use async_executor::Executor;
2use async_shutdown::ShutdownManager;
3use futures::executor::{block_on as block_current_thread_on, LocalPool, LocalSpawner};
4use futures::task::{FutureObj, Spawn as _};
5use spawns_core::{enter, spawn, Spawn, Task};
6use std::boxed::Box;
7use std::future::Future;
8use std::num::NonZeroUsize;
9use std::sync::Arc;
10use std::thread;
11
12struct Spawner {
13 spawner: LocalSpawner,
14}
15
16impl Spawn for Spawner {
17 fn spawn(&self, task: Task) {
18 let Task { future, .. } = task;
19 self.spawner.spawn_obj(FutureObj::new(future)).unwrap()
20 }
21}
22
23struct ExecutorSpawner<'a> {
24 executor: &'a Executor<'static>,
25}
26
27impl<'a> ExecutorSpawner<'a> {
28 fn new(executor: &'a Executor<'static>) -> Self {
29 Self { executor }
30 }
31}
32
33impl Spawn for ExecutorSpawner<'_> {
34 fn spawn(&self, task: Task) {
35 let Task { future, .. } = task;
36 self.executor.spawn(Box::into_pin(future)).detach();
37 }
38}
39
40pub struct Blocking {
42 parallelism: usize,
43}
44
45impl Blocking {
46 pub fn new(parallelism: usize) -> Self {
52 Self { parallelism }
53 }
54
55 fn parallelism(&self) -> usize {
56 match self.parallelism {
57 0 => std::thread::available_parallelism().map_or(2, NonZeroUsize::get),
58 n => n,
59 }
60 }
61
62 fn run_until<T, F>(executor: &Executor<'static>, future: F) -> T
63 where
64 F: Future<Output = T> + Send + 'static,
65 {
66 let spawner = ExecutorSpawner::new(executor);
67 let _scope = enter(&spawner);
68 block_current_thread_on(executor.run(future))
69 }
70
71 pub fn block_on<T: Send + 'static, F: Future<Output = T> + Send + 'static>(
75 self,
76 future: F,
77 ) -> F::Output {
78 let threads = self.parallelism();
79 if threads == 1 {
80 return block_on(future);
81 }
82 let executor = Arc::new(Executor::new());
83 let shutdown = ShutdownManager::new();
84 let shutdown_signal = shutdown.wait_shutdown_triggered();
85 (2..=threads).for_each(|i| {
86 thread::Builder::new()
87 .name(format!("spawns-executor-{}/{}", i, threads))
88 .spawn({
89 let executor = executor.clone();
90 let shutdown_signal = shutdown_signal.clone();
91 move || Self::run_until(&executor, shutdown_signal)
92 })
93 .unwrap();
94 });
95 let _shutdown_on_drop = shutdown.trigger_shutdown_token(());
96 Self::run_until(&executor, future)
97 }
98}
99
100pub fn block_on<T: Send + 'static, F: Future<Output = T> + Send + 'static>(future: F) -> F::Output {
104 let mut pool = LocalPool::new();
105 let spawner = Spawner {
106 spawner: pool.spawner(),
107 };
108 let _scope = enter(&spawner);
109 pool.run_until(spawn(future)).unwrap()
110}
111
#[cfg(test)]
mod tests {
    use super::{block_on, Blocking};

    /// Minimal TCP echo fixture that exercises task spawning on each executor flavour.
    mod echo {
        use async_net::*;
        use futures_lite::io;
        use futures_lite::prelude::*;
        use spawns_core::{spawn, TaskHandle};

        /// Copies everything read from `stream` back into it; I/O errors just end the echo.
        async fn echo_stream(stream: TcpStream) {
            let (rx, tx) = io::split(stream);
            let _ = io::copy(rx, tx).await;
        }

        /// Accepts connections forever, spawning one echo task per connection.
        async fn echo_server(listener: TcpListener) {
            // Attached handles are kept for as long as the server runs. Presumably
            // dropping one would cancel its task; confirm in `spawns_core`.
            let mut connections = Vec::new();
            loop {
                let (stream, _peer) = listener.accept().await.unwrap();
                let handle = spawn(echo_stream(stream));
                connections.push(handle.attach());
            }
        }

        /// Binds an ephemeral localhost port and spawns the echo server on it.
        async fn start_echo_server() -> (u16, TaskHandle<()>) {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = listener.local_addr().unwrap().port();
            let server = spawn(echo_server(listener)).attach();
            (port, server)
        }

        /// Sends `data` to a freshly started echo server and returns what comes back.
        pub async fn echo_one(data: &[u8]) -> Vec<u8> {
            // Bound to a named (non-`_`) variable so the server lives until we return.
            let (port, _server) = start_echo_server().await;
            let addr = format!("127.0.0.1:{}", port);
            let mut stream = TcpStream::connect(addr).await.unwrap();
            stream.write_all(data).await.unwrap();
            // Closing the write side lets the server's copy hit EOF while we can
            // still read back the echoed bytes.
            stream.close().await.unwrap();
            let mut echoed = Vec::new();
            stream.read_to_end(&mut echoed).await.unwrap();
            echoed
        }
    }

    #[test]
    fn block_on_current_thread() {
        let payload = b"Hello! Current Thread Executor!";
        let echoed = block_on(echo::echo_one(payload));
        assert_eq!(&echoed[..], payload);
    }

    #[test]
    fn block_on_multi_thread() {
        let payload = b"Hello! Multi-Thread Executor!";
        let echoed = Blocking::new(4).block_on(echo::echo_one(payload));
        assert_eq!(&echoed[..], payload);
    }
}