Skip to main content

thread_future/
thread_future.rs

1use futures::task::AtomicWaker;
2use futures_core::FusedFuture;
3use pin_project::{pin_project, pinned_drop};
4use std::{
5    any::Any,
6    pin::Pin,
7    sync::{Arc, atomic::AtomicBool},
8    task::{Context, Poll},
9    thread::JoinHandle,
10};
11
/// A cloneable handle through which a worker thread can be asked to stop.
///
/// [ThreadFuture] keeps one clone of the token and hands another one to the
/// work function, so calling [CancellationToken::cancel] on one clone is
/// expected to be observable through the others. How (and how often) the
/// thread checks for cancellation is up to the implementing type.
pub trait CancellationToken
where
    Self: Clone,
{
    /// Request that the work holding this token stop as soon as it can.
    fn cancel(&self);
}
15
/// An ultra-simple cancellation token that can be cloned and shared across threads.
///
/// Every clone shares a single flag, so cancelling any clone is visible through
/// all of them. Cancellation is one-way: nothing ever resets the flag.
#[derive(Debug, Clone, Default)]
pub struct SimpleCancellationToken {
    /// Shared "has been cancelled" flag; starts out `false`.
    cancelled: Arc<AtomicBool>,
}

impl SimpleCancellationToken {
    /// Create a new, not-yet-cancelled token (equivalent to [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Mark this token, and therefore every clone of it, as cancelled.
    ///
    /// Calling this more than once is harmless.
    pub fn cancel(&self) {
        self.cancelled
            .store(true, std::sync::atomic::Ordering::SeqCst);
    }

    /// Returns `true` once [`Self::cancel`] has been called on this token or any clone of it.
    pub fn is_cancelled(&self) -> bool {
        self.cancelled.load(std::sync::atomic::Ordering::SeqCst)
    }
}
38
39impl CancellationToken for SimpleCancellationToken {
40    fn cancel(&self) {
41        Self::cancel(self);
42    }
43}
44
45#[pin_project(project = ThreadFutureStateProj)]
46pub enum ThreadFutureState<T, F> {
47    /// The thread has not been started yet, the
48    /// work function is waiting
49    NotStarted(#[pin] F),
50
51    /// The thread is running.
52    Running(JoinHandle<T>),
53
54    /// The thread completed or failed, the value was
55    /// returned already. We have nothing to do.
56    Completed,
57
58    /// Internal state where the poll state machine
59    /// is being computed.
60    Polling,
61}
62
63/// Create a future that wraps a thread. The thread is lazily created.
64/// The thread is *not* (and cannot be) terminated when this structure
65/// is dropped. The thread must behave nicely and check the
66/// cancellation token given to see if it should terminate.
67///
68/// **Cancellation safety:** If the thread is still running when this future is dropped,
69/// the thread will not be forcefully terminated, but the cancellation token will be
70/// set so that the thread can check it and terminate early if it wants to. If you
71/// want to allow the thread to continue running without setting the cancellation token,
72/// you can call [Self::detach_on_drop] or [Self::detach_on_drop_ref] to prevent the
73/// cancellation token from being set on drop.
74#[pin_project(PinnedDrop)]
75pub struct ThreadFuture<T, F, C>
76where
77    C: CancellationToken + Send + 'static,
78{
79    /// The inner polling state of the wrapper
80    state: ThreadFutureState<T, F>,
81    /// `true` if we should cancel the thread when we are
82    /// dropped instead of letting it live.
83    cancel_on_drop: bool,
84    /// A cancellation token shared with the thread
85    /// that the thread can check to see if it should stop early.
86    cancellation_token: C,
87    /// Atomic waker used to communicate to the future when the thread has completed.
88    waker: Arc<AtomicWaker>,
89}
90
91impl<T, F> ThreadFuture<T, F, SimpleCancellationToken> {
92    /// Create a new future-tracked thread using the work function given.
93    ///
94    /// The thread will be lazily spawned on the first poll of this future.
95    ///
96    /// The default [SimpleCancellationToken] will be provided to the
97    /// thread work function. Check this token to see if the thread
98    /// should exit.
99    ///
100    /// See [Self::new_eager] for eagerly spawning the thread with the default [SimpleCancellationToken].
101    /// See [Self::new_with_cancellation] for providing a custom cancellation token.
102    pub fn new(work: F) -> Self
103    where
104        F: (FnOnce(SimpleCancellationToken) -> T) + Send + 'static,
105        T: Send + 'static,
106    {
107        Self {
108            state: ThreadFutureState::NotStarted(work),
109            cancel_on_drop: true,
110            cancellation_token: SimpleCancellationToken::new(),
111            waker: Arc::new(AtomicWaker::new()),
112        }
113    }
114
115    /// Create a new future-tracked thread using the work function given.
116    ///
117    /// The thread will be eagerly spawned during the call to this function.
118    ///
119    /// See [Self::new] for lazily spawning the thread with the default [SimpleCancellationToken].
120    /// See [Self::new_eager_with_cancellation] for providing a custom cancellation token.
121    pub fn new_eager(work: F) -> Self
122    where
123        F: (FnOnce(SimpleCancellationToken) -> T) + Send + 'static,
124        T: Send + 'static,
125    {
126        let cancellation_token = SimpleCancellationToken::new();
127        let waker = Arc::new(AtomicWaker::new());
128
129        let join_handle = Self::spawn_thread(work, cancellation_token.clone(), waker.clone());
130
131        let state = ThreadFutureState::Running(join_handle);
132
133        Self {
134            state,
135            cancel_on_drop: true,
136            cancellation_token,
137            waker,
138        }
139    }
140}
141
142impl<T, F, C> ThreadFuture<T, F, C>
143where
144    F: (FnOnce(C) -> T) + Send + 'static,
145    T: Send + 'static,
146    C: CancellationToken + Send + 'static,
147{
148    /// Create a new future-tracked thread using the work function given.
149    ///
150    /// The thread will be lazily spawned on the first poll of this future.
151    ///
152    /// Provide a custom cancellation token that implements [CancellationToken]
153    /// to share with the thread. The thread can check this token to see if it should
154    /// exit.
155    ///
156    /// See [Self::new_eager] for eagerly spawning the thread with the default [SimpleCancellationToken].
157    /// See [Self::new_eager_with_cancellation] for providing a custom cancellation token.
158    pub fn new_with_cancellation(work: F, cancellation_token: C) -> Self {
159        let waker = Arc::new(AtomicWaker::new());
160
161        Self {
162            state: ThreadFutureState::NotStarted(work),
163            cancel_on_drop: true,
164            cancellation_token,
165            waker,
166        }
167    }
168
169    /// Create a new future-tracked thread using the work function given.
170    ///
171    /// The thread will be eagerly spawned during the call to this function.
172    ///
173    /// See [Self::new] for lazily spawning the thread with the default [SimpleCancellationToken].
174    /// See [Self::new_with_cancellation] for providing a custom cancellation token.
175    pub fn new_eager_with_cancellation(work: F, cancellation_token: C) -> Self {
176        let waker = Arc::new(AtomicWaker::new());
177
178        let join_handle = Self::spawn_thread(work, cancellation_token.clone(), waker.clone());
179
180        let state = ThreadFutureState::Running(join_handle);
181
182        Self {
183            state,
184            cancel_on_drop: true,
185            cancellation_token,
186            waker,
187        }
188    }
189
190    /// When called, will instruct the wrapper to not to
191    /// activate the cancellation token when dropped.
192    pub fn detach_on_drop(mut self) -> Self {
193        self.cancel_on_drop = false;
194        self
195    }
196
197    /// Same as [Self::detach_on_drop], but can be called on a mutable
198    /// reference to the future instead of consuming it.
199    pub fn detach_on_drop_ref(&mut self) {
200        self.cancel_on_drop = false;
201    }
202
203    /// Check if the cancellation token will be activated when this future is dropped.
204    pub fn is_cancel_on_drop(&self) -> bool {
205        self.cancel_on_drop
206    }
207
208    /// Get a reference to the cancellation token internally stored.
209    pub fn cancellation_token(&self) -> &C {
210        &self.cancellation_token
211    }
212
213    /// Activate the internal cancellation token.
214    pub fn cancel(&self) {
215        self.cancellation_token.cancel();
216    }
217
218    /// Internal helper to spawn a thread with the given work function, cancellation token, and waker.
219    fn spawn_thread(work: F, cancel_token: C, waker: Arc<AtomicWaker>) -> JoinHandle<T>
220    where
221        F: (FnOnce(C) -> T) + Send + 'static,
222        T: Send + 'static,
223    {
224        std::thread::spawn(move || {
225            let result = work(cancel_token);
226            waker.wake();
227            result
228        })
229    }
230}
231
232/// If a thread fails to join, this is the error
233/// it may return. This can be any value from
234/// within the thread panic, so we don't know
235/// what it will be.
236///
237/// See [JoinHandle::join](std::thread::JoinHandle::join) for more details.
238type JoinError = Box<dyn Any + Send + 'static>;
239
240impl<T, F, C> Future for ThreadFuture<T, F, C>
241where
242    F: (FnOnce(C) -> T) + Send + 'static,
243    T: Send + 'static,
244    C: CancellationToken + Send + 'static,
245{
246    type Output = Result<T, JoinError>;
247
248    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
249        let this = self.project();
250
251        // Steal the current state, and make sure we replace it
252        // in the match statement below.
253        let current_state = std::mem::replace(this.state, ThreadFutureState::Polling);
254
255        match current_state {
256            ThreadFutureState::NotStarted(work) => {
257                // Create a new reference to our internal atomic waker, and
258                // register the executor's waker to be notified when the thread
259                // completes.
260                let waker = this.waker.clone();
261                waker.register(cx.waker());
262                let cancellation_token = this.cancellation_token.clone();
263                let join_handle = Self::spawn_thread(work, cancellation_token, waker);
264                *this.state = ThreadFutureState::Running(join_handle);
265                Poll::Pending
266            }
267            ThreadFutureState::Running(join_handle) => {
268                // In our implementation, we shouldn't be polled again until
269                // the thread wakes via the waker we copied to it. However,
270                // we can't assume all async runtimes will be nice and wait
271                // to poll us again, so if we are polled again before
272                // we've used the waker, make sure we don't try to join
273                // too soon.
274                if !join_handle.is_finished() {
275                    // If we haven't finished yet, register the latest waker
276                    this.waker.register(cx.waker());
277                }
278
279                // After potentially loading that waker, we must check once
280                // more if the thread has finished, since it could have finished
281                // between the last check and the new waker registration.
282                if join_handle.is_finished() {
283                    *this.state = ThreadFutureState::Completed;
284                    return Poll::Ready(join_handle.join());
285                } else {
286                    // Move the state back in for the next poll
287                    *this.state = ThreadFutureState::Running(join_handle);
288                    return Poll::Pending;
289                }
290            }
291            // If we get polled after we completed, we will forever be
292            // pending.
293            ThreadFutureState::Completed => {
294                *this.state = ThreadFutureState::Completed;
295                Poll::Pending
296            }
297            ThreadFutureState::Polling => {
298                unreachable!(
299                    "Intermediate polling state reached, this should not be possible unless the poll function was interrupted during processing!"
300                )
301            }
302        }
303    }
304}
305
306#[pinned_drop]
307impl<T, F, C> PinnedDrop for ThreadFuture<T, F, C>
308where
309    C: CancellationToken + Send + 'static,
310{
311    fn drop(self: Pin<&mut Self>) {
312        let this = self.project();
313
314        // If we are supposed to cancel the thread on drop, then
315        // set the cancellation token so the thread can check it and
316        // terminate early if it wants to.
317        if *this.cancel_on_drop {
318            this.cancellation_token.cancel();
319        }
320    }
321}
322
323/// We know when the future is terminated when the thread
324/// has completed, since we will never poll again after that.
325///
326/// The future is *not* terminated while the thread is still running.
327/// This means, even if you activate the cancellation token but the
328/// thread has not exited yet, the future is still not terminated.
329impl<T, F, C> FusedFuture for ThreadFuture<T, F, C>
330where
331    F: (FnOnce(C) -> T) + Send + 'static,
332    T: Send + 'static,
333    C: CancellationToken + Send + 'static,
334{
335    fn is_terminated(&self) -> bool {
336        matches!(self.state, ThreadFutureState::Completed)
337    }
338}