tokio_tasker/
tasker.rs

1use std::fmt;
2use std::future::Future;
3use std::mem;
4use std::pin::Pin;
5use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
6use std::sync::Arc;
7use std::task::{Context, Poll, Waker};
8use std::vec;
9
10use futures_util::stream::FuturesUnordered;
11use futures_util::task::noop_waker_ref;
12use futures_util::{FutureExt as _, StreamExt};
13use parking_lot::Mutex;
14use tokio::sync::Notify;
15use tokio::task::{self, JoinError, JoinHandle};
16
17use crate::{JoinStream, Signaller, Stopper};
18
19#[derive(Default)]
20pub(crate) struct Handles {
21    handles: FuturesUnordered<JoinHandle<()>>,
22    waker: Option<Waker>,
23}
24
25impl Handles {
26    fn new() -> Self {
27        Self {
28            handles: FuturesUnordered::new(),
29            waker: None,
30        }
31    }
32
33    fn wake(&mut self) {
34        self.waker.take().map(Waker::wake);
35        // ^ notifies runtime to repoll JoinStream, if any
36    }
37
38    fn push(&mut self, handle: JoinHandle<()>) {
39        self.handles.push(handle);
40        self.wake();
41    }
42
43    pub(crate) fn set_waker(&mut self, cx: &Context<'_>) {
44        self.waker = Some(cx.waker().clone());
45    }
46
47    pub(crate) fn poll_next(
48        &mut self,
49        cx: &mut Context<'_>,
50    ) -> Poll<Option<Result<(), JoinError>>> {
51        self.handles.poll_next_unpin(cx)
52    }
53}
54
55/// Internal data shared by [`Tasker`] as well as [`Stopper`] clones.
56pub(crate) struct Shared {
57    pub(crate) handles: Mutex<Handles>,
58    /// Number of outstanding `Tasker` clones.
59    // NB. We can't use Arc's strong count as `Stopper` also needs a (strong) clone.
60    num_clones: AtomicU32,
61    finished_clones: AtomicU32,
62    pub(crate) stopped: AtomicBool,
63    pub(crate) notify_stop: Notify,
64}
65
66impl Shared {
67    pub(crate) fn new() -> Self {
68        Self {
69            handles: Mutex::new(Handles::new()),
70            num_clones: AtomicU32::new(1),
71            finished_clones: AtomicU32::new(0),
72            stopped: AtomicBool::new(false),
73            notify_stop: Notify::new(),
74        }
75    }
76
77    /// Returns `true` if this call signalled stopping or `false`
78    /// if the [`Tasker`] was already stopped.
79    pub(crate) fn stop(&self) -> bool {
80        let stop = !self.stopped.swap(true, Ordering::SeqCst);
81        if stop {
82            self.notify_stop.notify_waiters();
83        }
84
85        stop
86    }
87
88    pub(crate) fn all_finished(&self) -> bool {
89        self.finished_clones.load(Ordering::SeqCst) == self.num_clones.load(Ordering::SeqCst)
90    }
91
92    pub(crate) fn ptr(&self) -> *const Self {
93        self as _
94    }
95}
96
97impl Default for Shared {
98    fn default() -> Self {
99        Self::new()
100    }
101}
102
103/// Manages a group of tasks.
104///
105/// _See [library-level][crate] documentation._
106pub struct Tasker {
107    shared: Pin<Arc<Shared>>,
108}
109
110impl Tasker {
111    /// Constructs and empty task set.
112    pub fn new() -> Self {
113        Self {
114            shared: Arc::pin(Shared::new()),
115        }
116    }
117
118    /// Add a tokio task handle to the task group.
119    ///
120    /// It is your responsibility to make sure the task
121    /// is stopped by this group's [`Stopper`] future.
122    pub fn add_handle(&self, handle: JoinHandle<()>) {
123        self.shared.handles.lock().push(handle);
124    }
125
126    /// Spawn a `!Send` future the local task set and add its handle to the task group.
127    ///
128    /// It is your responsibility to make sure the task
129    /// is stopped by this group's [`Stopper`] future.
130    pub fn spawn<F>(&self, future: F)
131    where
132        F: Future + Send + 'static,
133        F::Output: Send + 'static,
134    {
135        let handle = task::spawn(future.map(|_| ()));
136        self.add_handle(handle);
137    }
138
139    /// Spawn a tokio task and add its handle to the task group.
140    /// See [`tokio::task::spawn_local()`].
141    ///
142    /// It is your responsibility to make sure the task
143    /// is stopped by this group's [`Stopper`] future.
144    pub fn spawn_local<F>(&self, future: F)
145    where
146        F: Future + 'static,
147        F::Output: 'static,
148    {
149        let handle = task::spawn_local(future.map(|_| ()));
150        self.add_handle(handle);
151    }
152
153    /// Dispense a [`Stopper`], a future that will resolve once `.stop()`
154    /// is called on any `Tasker` (or signaller [`Signaller`]) clone.
155    ///
156    /// The `Stopper` future can be used with our [`.unless()`]
157    /// extension to stop `Future`s, with [`.take_until()`] to stop `Stream`s,
158    /// as part of `tokio::select!()`, and similar...
159    ///
160    /// [`.unless()`]: crate::FutureExt::unless()
161    /// [`.take_until()`]: futures_util::stream::StreamExt::take_until()
162    pub fn stopper(&self) -> Stopper {
163        Stopper::new(&self.shared)
164    }
165
166    /// Stop the tasks in the group.
167    ///
168    /// This will resolve all [`Stopper`] futures (including ones obtained after this call).
169    ///
170    /// Returns `true` if this was the first effective stop call
171    /// or `false` if the group was already signalled to stop.
172    pub fn stop(&self) -> bool {
173        self.shared.stop()
174    }
175
176    /// Dispense a [`Signaller`], a `Tasker` clone which can be used to signal stopping,
177    /// but doesn't require dropping or calling [`finish()`][Tasker::finish()].
178    pub fn signaller(&self) -> Signaller {
179        let arc = Pin::into_inner(self.shared.clone());
180        let weak = Arc::downgrade(&arc);
181        Signaller::new(weak)
182    }
183
184    /// `true` if stopping was already signalled.
185    pub fn is_stopped(&self) -> bool {
186        self.shared.stopped.load(Ordering::SeqCst)
187    }
188
189    /// Number of tasks currently belonging to this task group.
190    pub fn num_tasks(&self) -> usize {
191        self.shared.handles.lock().handles.len()
192    }
193
194    fn mark_finished(&self) {
195        self.shared.finished_clones.fetch_add(1, Ordering::SeqCst);
196        self.shared.handles.lock().wake();
197    }
198
199    /// Join all the tasks in the group.
200    ///
201    /// **This function will panic if any of the tasks panicked**.
202    /// Use [`try_join()`][Tasker::try_join()] if you need to handle task panics yourself.
203    ///
204    /// Note that `join()` will only return once all other `Tasker` clones are finished
205    /// (via [`finish()`][Tasker::finish()] or drop).
206    pub async fn join(self) {
207        let mut join_stream = self.join_stream();
208        while let Some(handle) = join_stream.next().await {
209            handle.expect("Join error");
210        }
211    }
212
213    /// Join all the tasks in the group, returning a vector of join results
214    /// (ie. results of tokio's [`JoinHandle`][tokio::task::JoinHandle]).
215    ///
216    /// Note that `try_join()` will only return once all other `Tasker` clones are finished
217    /// (via [`finish()`][Tasker::finish()] or drop).
218    pub async fn try_join(self) -> Vec<Result<(), JoinError>> {
219        self.join_stream().collect().await
220    }
221
222    /// Returns a `Stream` which yields join results as they become available,
223    /// ie. as the tasks terminate.
224    ///
225    /// If any of the tasks terminates earlier than others, such as due to a panic,
226    /// its result should be readily avialable through this stream.
227    ///
228    /// The join stream stops after all tasks are joined *and* all other `Tasker`
229    /// clones are finished (via [`finish()`][Tasker::finish()] or drop).
230    pub fn join_stream(self) -> JoinStream {
231        JoinStream::new(self.shared.clone())
232        // self is dropped and marked finished
233    }
234
235    /// Mark this `Tasker` clone as finished.
236    ///
237    /// This has the same effect as dropping the clone, but is more explicit.
238    /// This lets the task group know that no new tasks will be added through this clone.
239    ///
240    /// All `Tasker` clones need to be finished/dropped in order for [`.join()`][Tasker::join()]
241    /// to be able to join all the tasks.
242    ///
243    /// Use [`.signaller()`][Tasker::signaller()] to get a special `Tasker` clone that doesn't
244    /// need to be dropped and can be used to `.stop()` the task group.
245    pub fn finish(self) {
246        mem::drop(self); // Drop impl will call mark_finished()
247    }
248
249    /// Poll tasks once and join those that are finished.
250    ///
251    /// Handles to tasks that are finished executing will be
252    /// removed from the internal storage.
253    ///
254    /// Returns the number of tasks that were joined.
255    ///
256    /// **This function will panic if any of the tasks panicked**.
257    /// Use [`try_poll_join()`][Tasker::try_poll_join()] if you need to handle task panics yourself.
258    pub fn poll_join(&self) -> usize {
259        let mut handles = self.shared.handles.lock();
260        let mut cx_noop = Context::from_waker(noop_waker_ref());
261
262        let mut num_joined = 0;
263        while let Poll::Ready(Some(result)) = handles.handles.poll_next_unpin(&mut cx_noop) {
264            result.expect("Join error");
265            num_joined += 1;
266        }
267
268        num_joined
269    }
270
271    /// Poll tasks once and join those that are already done.
272    ///
273    /// Handles to tasks that are already finished executing will be joined
274    /// and removed from the internal storage.
275    ///
276    /// Returns vector of join results of tasks that were joined
277    /// (may be empty if no tasks could've been joined).
278    pub fn try_poll_join(&self) -> Vec<Result<(), JoinError>> {
279        let mut handles = self.shared.handles.lock();
280
281        let mut cx_noop = Context::from_waker(noop_waker_ref());
282        let mut ready_results = vec![];
283
284        while let Poll::Ready(Some(result)) = handles.handles.poll_next_unpin(&mut cx_noop) {
285            ready_results.push(result);
286        }
287
288        ready_results
289    }
290}
291
292impl Default for Tasker {
293    fn default() -> Self {
294        Self::new()
295    }
296}
297
298impl Clone for Tasker {
299    fn clone(&self) -> Self {
300        self.shared.num_clones.fetch_add(1, Ordering::SeqCst);
301
302        Self {
303            shared: self.shared.clone(),
304        }
305    }
306}
307
308impl Drop for Tasker {
309    fn drop(&mut self) {
310        self.mark_finished();
311    }
312}
313
314impl fmt::Debug for Tasker {
315    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
316        f.debug_tuple("Tasker").field(&self.shared.ptr()).finish()
317    }
318}