1use std::{future, future::Future, io, sync, task, thread};
31
32pub fn run<F, T>(blocking_fn: F) -> impl Future<Output = T>
40where
41 F: FnOnce() -> T + Send + 'static,
42 T: Send + 'static,
43{
44 run_with_builder(thread::Builder::new(), blocking_fn)
45 .expect("failed to spawn thread")
46 .0
47}
48
/// Runs `blocking_fn` on a new thread configured by `builder`.
///
/// Returns a future that resolves to the function's return value, together
/// with the worker thread's `JoinHandle`. Polling the future never blocks.
/// The worker wakes the most recently registered waker once, when it finishes.
///
/// # Errors
///
/// Returns the error from [`thread::Builder::spawn`] if the thread cannot be
/// created.
///
/// # Panics
///
/// If `blocking_fn` panics, polling the returned future panics instead of
/// staying pending forever. The original panic payload is still reported
/// through the returned `JoinHandle`.
pub fn run_with_builder<F, T>(
    builder: thread::Builder,
    blocking_fn: F,
) -> io::Result<(impl Future<Output = T>, thread::JoinHandle<()>)>
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    /// Progress of the worker thread, shared between it and the future.
    enum State<U> {
        /// Worker still running; holds the waker from the latest poll, if any.
        Running(Option<task::Waker>),
        /// Worker returned; its output has not been collected yet.
        Done(U),
        /// Worker unwound before producing an output.
        Panicked,
        /// Output was already returned by an earlier poll.
        Consumed,
    }

    type Shared<U> = sync::Arc<sync::Mutex<State<U>>>;

    /// Locks the shared state, ignoring poisoning. Every critical section
    /// below leaves `State` valid, so a panic elsewhere (e.g. in a
    /// user-supplied waker) must not cascade into the future.
    fn lock<U>(shared: &Shared<U>) -> sync::MutexGuard<'_, State<U>> {
        shared.lock().unwrap_or_else(sync::PoisonError::into_inner)
    }

    /// Owned by the worker thread. Its `Drop` runs both on normal return and
    /// while unwinding, so completion is always published and the pending
    /// task is always woken.
    struct Completion<U> {
        shared: Shared<U>,
        output: Option<U>,
    }

    impl<U> Drop for Completion<U> {
        fn drop(&mut self) {
            let next = match self.output.take() {
                Some(output) => State::Done(output),
                // No output means `blocking_fn` unwound (or never ran).
                None => State::Panicked,
            };
            let mut state = lock(&self.shared);
            let previous = std::mem::replace(&mut *state, next);
            // Release the lock before waking so an executor that polls
            // inline from `wake` cannot deadlock on this mutex.
            drop(state);
            if let State::Running(Some(waker)) = previous {
                waker.wake();
            }
        }
    }

    let shared: Shared<T> = sync::Arc::new(sync::Mutex::new(State::Running(None)));

    let mut completion = Completion {
        shared: sync::Arc::clone(&shared),
        output: None,
    };
    let handle = builder.spawn(move || {
        // If this call unwinds, `completion` is dropped with no output and
        // the future observes `State::Panicked`.
        completion.output = Some(blocking_fn());
    })?;

    let future = future::poll_fn(move |cx| {
        let mut state = lock(&shared);
        if let State::Running(slot) = &mut *state {
            // Registering under the same lock the worker uses to finish
            // guarantees the wake-up cannot be lost.
            match slot.as_mut() {
                Some(waker) => waker.clone_from(cx.waker()),
                None => *slot = Some(cx.waker().clone()),
            }
            return task::Poll::Pending;
        }
        match std::mem::replace(&mut *state, State::Consumed) {
            State::Done(output) => task::Poll::Ready(output),
            State::Panicked => {
                *state = State::Panicked;
                // Unlock first so the panic does not poison the mutex.
                drop(state);
                panic!("blocking function panicked before producing an output");
            }
            // Polled again after returning `Ready`: nothing will wake us.
            State::Consumed => task::Poll::Pending,
            State::Running(_) => unreachable!("the running state returns early above"),
        }
    });

    Ok((future, handle))
}
99
#[cfg(test)]
mod tests {
    use super::{run, run_with_builder};
    use std::{thread, time};

    /// How long each simulated blocking task sleeps.
    const DUR: time::Duration = time::Duration::from_millis(250);
    /// Value every simulated blocking task returns.
    const OUT: i32 = 42;

    /// Simulated blocking work: sleeps for `DUR`, then returns `OUT`.
    fn blocking_task() -> i32 {
        thread::sleep(DUR);
        OUT
    }

    /// Asserts that at least one but fewer than two `DUR`s have elapsed since
    /// `start`, i.e. the blocking work overlapped instead of running serially.
    #[track_caller]
    fn assert_took_about_dur(start: time::Instant) {
        let elapsed = start.elapsed();
        assert!(
            DUR <= elapsed && elapsed < DUR * 2,
            "expected between {:?} and {:?}, took {:?}",
            DUR,
            DUR * 2,
            elapsed
        );
    }

    #[tokio::test]
    async fn single() {
        let start = time::Instant::now();
        let output = run(blocking_task).await;
        assert_took_about_dur(start);
        assert_eq!(output, OUT);
    }

    #[tokio::test]
    async fn parallel() {
        let start = time::Instant::now();
        #[rustfmt::skip]
        tokio::join!(
            run(blocking_task), run(blocking_task), run(blocking_task), run(blocking_task),
            run(blocking_task), run(blocking_task), run(blocking_task), run(blocking_task),
            run(blocking_task), run(blocking_task), run(blocking_task), run(blocking_task),
            run(blocking_task), run(blocking_task), run(blocking_task), run(blocking_task),
            run(blocking_task), run(blocking_task), run(blocking_task), run(blocking_task),
            run(blocking_task), run(blocking_task), run(blocking_task), run(blocking_task),
            run(blocking_task), run(blocking_task), run(blocking_task), run(blocking_task),
            run(blocking_task), run(blocking_task), run(blocking_task), run(blocking_task),
        );
        assert_took_about_dur(start);
    }

    #[tokio::test]
    async fn mix_with_tokio() {
        let start = time::Instant::now();
        #[rustfmt::skip]
        tokio::join!(
            run(blocking_task), tokio::time::sleep(DUR), run(blocking_task), tokio::time::sleep(DUR),
            run(blocking_task), tokio::time::sleep(DUR), run(blocking_task), tokio::time::sleep(DUR),
            run(blocking_task), tokio::time::sleep(DUR), run(blocking_task), tokio::time::sleep(DUR),
            run(blocking_task), tokio::time::sleep(DUR), run(blocking_task), tokio::time::sleep(DUR),
            run(blocking_task), tokio::time::sleep(DUR), run(blocking_task), tokio::time::sleep(DUR),
            run(blocking_task), tokio::time::sleep(DUR), run(blocking_task), tokio::time::sleep(DUR),
            run(blocking_task), tokio::time::sleep(DUR), run(blocking_task), tokio::time::sleep(DUR),
            run(blocking_task), tokio::time::sleep(DUR), run(blocking_task), tokio::time::sleep(DUR),
        );
        assert_took_about_dur(start);
    }

    #[tokio::test]
    async fn delayed_await() {
        let start = time::Instant::now();
        let ft = run(blocking_task);
        // Deliberately block past the task's completion before awaiting.
        thread::sleep(DUR * 125 / 100);
        let output = ft.await;
        assert_took_about_dur(start);
        assert_eq!(output, OUT);
    }

    #[tokio::test]
    async fn builder() {
        let name = "test run_with_builder()";
        let builder = thread::Builder::new().name(name.into());
        let start = time::Instant::now();
        let (ft, jh) = run_with_builder(builder, blocking_task).unwrap();
        assert_eq!(jh.thread().name(), Some(name));
        assert!(!jh.is_finished());
        let output = ft.await;
        // The output is published from inside the worker's closure, so the
        // thread may still be exiting when `ft` resolves. Asserting
        // `jh.is_finished()` here would be racy; joining waits for the exit.
        jh.join().expect("worker thread panicked");
        assert_took_about_dur(start);
        assert_eq!(output, OUT);
    }

    #[test]
    fn sync_wait() {
        use std::{future::Future as _, pin, sync, task};

        /// Waker that counts how many times it has been woken.
        struct MockWaker(sync::Mutex<u8>);
        impl task::Wake for MockWaker {
            fn wake(self: sync::Arc<Self>) {
                *self.0.lock().unwrap() += 1;
            }
        }
        let waker_inner = sync::Arc::new(MockWaker(Default::default()));
        let waker = sync::Arc::clone(&waker_inner).into();
        let mut context = task::Context::from_waker(&waker);

        let builder = thread::Builder::new();
        let start = time::Instant::now();
        let (ft, jh) = run_with_builder(builder, blocking_task).unwrap();
        let mut ft = pin::pin!(ft);

        let poll_result = ft.as_mut().poll(&mut context);
        assert!(!jh.is_finished());
        assert_eq!(poll_result, task::Poll::Pending);
        let poll_result = ft.as_mut().poll(&mut context);
        assert!(!jh.is_finished());
        assert_eq!(poll_result, task::Poll::Pending);

        jh.join().unwrap();
        let poll_result = ft.as_mut().poll(&mut context);
        assert_eq!(poll_result, task::Poll::Ready(OUT));

        assert_took_about_dur(start);
        assert_eq!(*waker_inner.0.lock().unwrap(), 1);
    }
}