thread_future/thread_future.rs
1use futures::task::AtomicWaker;
2use futures_core::FusedFuture;
3use pin_project::{pin_project, pinned_drop};
4use std::{
5 any::Any,
6 pin::Pin,
7 sync::{Arc, atomic::AtomicBool},
8 task::{Context, Poll},
9 thread::JoinHandle,
10};
11
/// A token through which a worker thread can be asked to stop early.
///
/// The token must be `Clone`: the future keeps one copy and hands a clone
/// to the work function running on the thread.
pub trait CancellationToken: Clone {
    /// Request cancellation.
    ///
    /// Takes `&self`, so implementations need some form of interior
    /// mutability to record the request.
    fn cancel(&self);
}
15
/// An ultra-simple cancellation token that can be cloned and shared across threads.
///
/// All clones share the same underlying flag: cancelling any clone is
/// observed by every other clone. Once set, the flag is never cleared.
#[derive(Debug, Clone, Default)]
pub struct SimpleCancellationToken {
    /// Shared flag; `true` once cancellation has been requested.
    cancelled: Arc<AtomicBool>,
}

impl SimpleCancellationToken {
    /// Create a fresh token that has not been cancelled.
    ///
    /// Equivalent to [`SimpleCancellationToken::default`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            cancelled: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Request cancellation. Visible to every clone of this token.
    pub fn cancel(&self) {
        self.cancelled
            .store(true, std::sync::atomic::Ordering::SeqCst);
    }

    /// Returns `true` once [`cancel`](Self::cancel) has been called on this
    /// token or on any of its clones.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        self.cancelled.load(std::sync::atomic::Ordering::SeqCst)
    }
}
38
39impl CancellationToken for SimpleCancellationToken {
40 fn cancel(&self) {
41 Self::cancel(self);
42 }
43}
44
/// Lifecycle of the thread tracked by a [ThreadFuture].
///
/// `T` is the value produced by the thread and `F` is the work function
/// that will be run on it.
///
/// NOTE(review): within this file the state is moved in and out with
/// `std::mem::replace` rather than accessed through the generated
/// `ThreadFutureStateProj` projection — confirm whether the pin projection
/// (and the `#[pin]` on `NotStarted`) is still needed.
#[pin_project(project = ThreadFutureStateProj)]
pub enum ThreadFutureState<T, F> {
    /// The thread has not been started yet, the
    /// work function is waiting to be handed to it.
    NotStarted(#[pin] F),

    /// The thread is running. The handle is kept so the
    /// result can be joined once the thread is done.
    Running(JoinHandle<T>),

    /// The thread completed or failed, the value was
    /// returned already. We have nothing to do.
    Completed,

    /// Internal placeholder stored while `poll` has temporarily
    /// taken the real state out to compute the next one.
    Polling,
}
62
63/// Create a future that wraps a thread. The thread is lazily created.
64/// The thread is *not* (and cannot be) terminated when this structure
65/// is dropped. The thread must behave nicely and check the
66/// cancellation token given to see if it should terminate.
67///
68/// **Cancellation safety:** If the thread is still running when this future is dropped,
69/// the thread will not be forcefully terminated, but the cancellation token will be
70/// set so that the thread can check it and terminate early if it wants to. If you
71/// want to allow the thread to continue running without setting the cancellation token,
72/// you can call [Self::detach_on_drop] or [Self::detach_on_drop_ref] to prevent the
73/// cancellation token from being set on drop.
74#[pin_project(PinnedDrop)]
75pub struct ThreadFuture<T, F, C>
76where
77 C: CancellationToken + Send + 'static,
78{
79 /// The inner polling state of the wrapper
80 state: ThreadFutureState<T, F>,
81 /// `true` if we should cancel the thread when we are
82 /// dropped instead of letting it live.
83 cancel_on_drop: bool,
84 /// A cancellation token shared with the thread
85 /// that the thread can check to see if it should stop early.
86 cancellation_token: C,
87 /// Atomic waker used to communicate to the future when the thread has completed.
88 waker: Arc<AtomicWaker>,
89}
90
91impl<T, F> ThreadFuture<T, F, SimpleCancellationToken> {
92 /// Create a new future-tracked thread using the work function given.
93 ///
94 /// The thread will be lazily spawned on the first poll of this future.
95 ///
96 /// The default [SimpleCancellationToken] will be provided to the
97 /// thread work function. Check this token to see if the thread
98 /// should exit.
99 ///
100 /// See [Self::new_eager] for eagerly spawning the thread with the default [SimpleCancellationToken].
101 /// See [Self::new_with_cancellation] for providing a custom cancellation token.
102 pub fn new(work: F) -> Self
103 where
104 F: (FnOnce(SimpleCancellationToken) -> T) + Send + 'static,
105 T: Send + 'static,
106 {
107 Self {
108 state: ThreadFutureState::NotStarted(work),
109 cancel_on_drop: true,
110 cancellation_token: SimpleCancellationToken::new(),
111 waker: Arc::new(AtomicWaker::new()),
112 }
113 }
114
115 /// Create a new future-tracked thread using the work function given.
116 ///
117 /// The thread will be eagerly spawned during the call to this function.
118 ///
119 /// See [Self::new] for lazily spawning the thread with the default [SimpleCancellationToken].
120 /// See [Self::new_eager_with_cancellation] for providing a custom cancellation token.
121 pub fn new_eager(work: F) -> Self
122 where
123 F: (FnOnce(SimpleCancellationToken) -> T) + Send + 'static,
124 T: Send + 'static,
125 {
126 let cancellation_token = SimpleCancellationToken::new();
127 let waker = Arc::new(AtomicWaker::new());
128
129 let join_handle = Self::spawn_thread(work, cancellation_token.clone(), waker.clone());
130
131 let state = ThreadFutureState::Running(join_handle);
132
133 Self {
134 state,
135 cancel_on_drop: true,
136 cancellation_token,
137 waker,
138 }
139 }
140}
141
142impl<T, F, C> ThreadFuture<T, F, C>
143where
144 F: (FnOnce(C) -> T) + Send + 'static,
145 T: Send + 'static,
146 C: CancellationToken + Send + 'static,
147{
148 /// Create a new future-tracked thread using the work function given.
149 ///
150 /// The thread will be lazily spawned on the first poll of this future.
151 ///
152 /// Provide a custom cancellation token that implements [CancellationToken]
153 /// to share with the thread. The thread can check this token to see if it should
154 /// exit.
155 ///
156 /// See [Self::new_eager] for eagerly spawning the thread with the default [SimpleCancellationToken].
157 /// See [Self::new_eager_with_cancellation] for providing a custom cancellation token.
158 pub fn new_with_cancellation(work: F, cancellation_token: C) -> Self {
159 let waker = Arc::new(AtomicWaker::new());
160
161 Self {
162 state: ThreadFutureState::NotStarted(work),
163 cancel_on_drop: true,
164 cancellation_token,
165 waker,
166 }
167 }
168
169 /// Create a new future-tracked thread using the work function given.
170 ///
171 /// The thread will be eagerly spawned during the call to this function.
172 ///
173 /// See [Self::new] for lazily spawning the thread with the default [SimpleCancellationToken].
174 /// See [Self::new_with_cancellation] for providing a custom cancellation token.
175 pub fn new_eager_with_cancellation(work: F, cancellation_token: C) -> Self {
176 let waker = Arc::new(AtomicWaker::new());
177
178 let join_handle = Self::spawn_thread(work, cancellation_token.clone(), waker.clone());
179
180 let state = ThreadFutureState::Running(join_handle);
181
182 Self {
183 state,
184 cancel_on_drop: true,
185 cancellation_token,
186 waker,
187 }
188 }
189
190 /// When called, will instruct the wrapper to not to
191 /// activate the cancellation token when dropped.
192 pub fn detach_on_drop(mut self) -> Self {
193 self.cancel_on_drop = false;
194 self
195 }
196
197 /// Same as [Self::detach_on_drop], but can be called on a mutable
198 /// reference to the future instead of consuming it.
199 pub fn detach_on_drop_ref(&mut self) {
200 self.cancel_on_drop = false;
201 }
202
203 /// Check if the cancellation token will be activated when this future is dropped.
204 pub fn is_cancel_on_drop(&self) -> bool {
205 self.cancel_on_drop
206 }
207
208 /// Get a reference to the cancellation token internally stored.
209 pub fn cancellation_token(&self) -> &C {
210 &self.cancellation_token
211 }
212
213 /// Activate the internal cancellation token.
214 pub fn cancel(&self) {
215 self.cancellation_token.cancel();
216 }
217
218 /// Internal helper to spawn a thread with the given work function, cancellation token, and waker.
219 fn spawn_thread(work: F, cancel_token: C, waker: Arc<AtomicWaker>) -> JoinHandle<T>
220 where
221 F: (FnOnce(C) -> T) + Send + 'static,
222 T: Send + 'static,
223 {
224 std::thread::spawn(move || {
225 let result = work(cancel_token);
226 waker.wake();
227 result
228 })
229 }
230}
231
232/// If a thread fails to join, this is the error
233/// it may return. This can be any value from
234/// within the thread panic, so we don't know
235/// what it will be.
236///
237/// See [JoinHandle::join](std::thread::JoinHandle::join) for more details.
238type JoinError = Box<dyn Any + Send + 'static>;
239
240impl<T, F, C> Future for ThreadFuture<T, F, C>
241where
242 F: (FnOnce(C) -> T) + Send + 'static,
243 T: Send + 'static,
244 C: CancellationToken + Send + 'static,
245{
246 type Output = Result<T, JoinError>;
247
248 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
249 let this = self.project();
250
251 // Steal the current state, and make sure we replace it
252 // in the match statement below.
253 let current_state = std::mem::replace(this.state, ThreadFutureState::Polling);
254
255 match current_state {
256 ThreadFutureState::NotStarted(work) => {
257 // Create a new reference to our internal atomic waker, and
258 // register the executor's waker to be notified when the thread
259 // completes.
260 let waker = this.waker.clone();
261 waker.register(cx.waker());
262 let cancellation_token = this.cancellation_token.clone();
263 let join_handle = Self::spawn_thread(work, cancellation_token, waker);
264 *this.state = ThreadFutureState::Running(join_handle);
265 Poll::Pending
266 }
267 ThreadFutureState::Running(join_handle) => {
268 // In our implementation, we shouldn't be polled again until
269 // the thread wakes via the waker we copied to it. However,
270 // we can't assume all async runtimes will be nice and wait
271 // to poll us again, so if we are polled again before
272 // we've used the waker, make sure we don't try to join
273 // too soon.
274 if !join_handle.is_finished() {
275 // If we haven't finished yet, register the latest waker
276 this.waker.register(cx.waker());
277 }
278
279 // After potentially loading that waker, we must check once
280 // more if the thread has finished, since it could have finished
281 // between the last check and the new waker registration.
282 if join_handle.is_finished() {
283 *this.state = ThreadFutureState::Completed;
284 return Poll::Ready(join_handle.join());
285 } else {
286 // Move the state back in for the next poll
287 *this.state = ThreadFutureState::Running(join_handle);
288 return Poll::Pending;
289 }
290 }
291 // If we get polled after we completed, we will forever be
292 // pending.
293 ThreadFutureState::Completed => {
294 *this.state = ThreadFutureState::Completed;
295 Poll::Pending
296 }
297 ThreadFutureState::Polling => {
298 unreachable!(
299 "Intermediate polling state reached, this should not be possible unless the poll function was interrupted during processing!"
300 )
301 }
302 }
303 }
304}
305
306#[pinned_drop]
307impl<T, F, C> PinnedDrop for ThreadFuture<T, F, C>
308where
309 C: CancellationToken + Send + 'static,
310{
311 fn drop(self: Pin<&mut Self>) {
312 let this = self.project();
313
314 // If we are supposed to cancel the thread on drop, then
315 // set the cancellation token so the thread can check it and
316 // terminate early if it wants to.
317 if *this.cancel_on_drop {
318 this.cancellation_token.cancel();
319 }
320 }
321}
322
323/// We know when the future is terminated when the thread
324/// has completed, since we will never poll again after that.
325///
326/// The future is *not* terminated while the thread is still running.
327/// This means, even if you activate the cancellation token but the
328/// thread has not exited yet, the future is still not terminated.
329impl<T, F, C> FusedFuture for ThreadFuture<T, F, C>
330where
331 F: (FnOnce(C) -> T) + Send + 'static,
332 T: Send + 'static,
333 C: CancellationToken + Send + 'static,
334{
335 fn is_terminated(&self) -> bool {
336 matches!(self.state, ThreadFutureState::Completed)
337 }
338}