tokio_util/task/
join_queue.rs

1use super::AbortOnDropHandle;
2use std::{
3    collections::VecDeque,
4    future::Future,
5    pin::Pin,
6    task::{Context, Poll},
7};
8use tokio::{
9    runtime::Handle,
10    task::{AbortHandle, Id, JoinError, JoinHandle},
11};
12
13/// A FIFO queue for of tasks spawned on a Tokio runtime.
14///
15/// A [`JoinQueue`] can be used to await the completion of the tasks in FIFO
16/// order. That is, if tasks are spawned in the order A, B, C, then
17/// awaiting the next completed task will always return A first, then B,
18/// then C, regardless of the order in which the tasks actually complete.
19///
20/// All of the tasks must have the same return type `T`.
21///
22/// When the [`JoinQueue`] is dropped, all tasks in the [`JoinQueue`] are
23/// immediately aborted.
24#[derive(Debug)]
25pub struct JoinQueue<T>(VecDeque<AbortOnDropHandle<T>>);
26
27impl<T> JoinQueue<T> {
28    /// Create a new empty [`JoinQueue`].
29    pub const fn new() -> Self {
30        Self(VecDeque::new())
31    }
32
33    /// Creates an empty [`JoinQueue`] with space for at least `capacity` tasks.
34    pub fn with_capacity(capacity: usize) -> Self {
35        Self(VecDeque::with_capacity(capacity))
36    }
37
38    /// Returns the number of tasks currently in the [`JoinQueue`].
39    ///
40    /// This includes both tasks that are currently running and tasks that have
41    /// completed but not yet been removed from the queue because outputting of
42    /// them waits for FIFO order.
43    pub fn len(&self) -> usize {
44        self.0.len()
45    }
46
47    /// Returns whether the [`JoinQueue`] is empty.
48    pub fn is_empty(&self) -> bool {
49        self.0.is_empty()
50    }
51
52    /// Spawn the provided task on the [`JoinQueue`], returning an [`AbortHandle`]
53    /// that can be used to remotely cancel the task.
54    ///
55    /// The provided future will start running in the background immediately
56    /// when this method is called, even if you don't await anything on this
57    /// [`JoinQueue`].
58    ///
59    /// # Panics
60    ///
61    /// This method panics if called outside of a Tokio runtime.
62    ///
63    /// [`AbortHandle`]: tokio::task::AbortHandle
64    #[track_caller]
65    pub fn spawn<F>(&mut self, task: F) -> AbortHandle
66    where
67        F: Future<Output = T> + Send + 'static,
68        T: Send + 'static,
69    {
70        self.push_back(tokio::spawn(task))
71    }
72
73    /// Spawn the provided task on the provided runtime and store it in this
74    /// [`JoinQueue`] returning an [`AbortHandle`] that can be used to remotely
75    /// cancel the task.
76    ///
77    /// The provided future will start running in the background immediately
78    /// when this method is called, even if you don't await anything on this
79    /// [`JoinQueue`].
80    ///
81    /// [`AbortHandle`]: tokio::task::AbortHandle
82    #[track_caller]
83    pub fn spawn_on<F>(&mut self, task: F, handle: &Handle) -> AbortHandle
84    where
85        F: Future<Output = T> + Send + 'static,
86        T: Send + 'static,
87    {
88        self.push_back(handle.spawn(task))
89    }
90
91    /// Spawn the provided task on the current [`LocalSet`] or [`LocalRuntime`]
92    /// and store it in this [`JoinQueue`], returning an [`AbortHandle`] that
93    /// can be used to remotely cancel the task.
94    ///
95    /// The provided future will start running in the background immediately
96    /// when this method is called, even if you don't await anything on this
97    /// [`JoinQueue`].
98    ///
99    /// # Panics
100    ///
101    /// This method panics if it is called outside of a `LocalSet` or `LocalRuntime`.
102    ///
103    /// [`LocalSet`]: tokio::task::LocalSet
104    /// [`LocalRuntime`]: tokio::runtime::LocalRuntime
105    /// [`AbortHandle`]: tokio::task::AbortHandle
106    #[track_caller]
107    pub fn spawn_local<F>(&mut self, task: F) -> AbortHandle
108    where
109        F: Future<Output = T> + 'static,
110        T: 'static,
111    {
112        self.push_back(tokio::task::spawn_local(task))
113    }
114
115    /// Spawn the blocking code on the blocking threadpool and store
116    /// it in this [`JoinQueue`], returning an [`AbortHandle`] that can be
117    /// used to remotely cancel the task.
118    ///
119    /// # Panics
120    ///
121    /// This method panics if called outside of a Tokio runtime.
122    ///
123    /// [`AbortHandle`]: tokio::task::AbortHandle
124    #[track_caller]
125    pub fn spawn_blocking<F>(&mut self, f: F) -> AbortHandle
126    where
127        F: FnOnce() -> T + Send + 'static,
128        T: Send + 'static,
129    {
130        self.push_back(tokio::task::spawn_blocking(f))
131    }
132
133    /// Spawn the blocking code on the blocking threadpool of the
134    /// provided runtime and store it in this [`JoinQueue`], returning an
135    /// [`AbortHandle`] that can be used to remotely cancel the task.
136    ///
137    /// [`AbortHandle`]: tokio::task::AbortHandle
138    #[track_caller]
139    pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle) -> AbortHandle
140    where
141        F: FnOnce() -> T + Send + 'static,
142        T: Send + 'static,
143    {
144        self.push_back(handle.spawn_blocking(f))
145    }
146
147    fn push_back(&mut self, jh: JoinHandle<T>) -> AbortHandle {
148        let jh = AbortOnDropHandle::new(jh);
149        let abort_handle = jh.abort_handle();
150        self.0.push_back(jh);
151        abort_handle
152    }
153
154    /// Waits until the next task in FIFO order completes and returns its output.
155    ///
156    /// Returns `None` if the queue is empty.
157    ///
158    /// # Cancel Safety
159    ///
160    /// This method is cancel safe. If `join_next` is used as the event in a `tokio::select!`
161    /// statement and some other branch completes first, it is guaranteed that no tasks were
162    /// removed from this [`JoinQueue`].
163    pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
164        std::future::poll_fn(|cx| self.poll_join_next(cx)).await
165    }
166
167    /// Waits until the next task in FIFO order completes and returns its output,
168    /// along with the [task ID] of the completed task.
169    ///
170    /// Returns `None` if the queue is empty.
171    ///
172    /// When this method returns an error, then the id of the task that failed can be accessed
173    /// using the [`JoinError::id`] method.
174    ///
175    /// # Cancel Safety
176    ///
177    /// This method is cancel safe. If `join_next_with_id` is used as the event in a `tokio::select!`
178    /// statement and some other branch completes first, it is guaranteed that no tasks were
179    /// removed from this [`JoinQueue`].
180    ///
181    /// [task ID]: tokio::task::Id
182    /// [`JoinError::id`]: fn@tokio::task::JoinError::id
183    pub async fn join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
184        std::future::poll_fn(|cx| self.poll_join_next_with_id(cx)).await
185    }
186
187    /// Tries to poll an `AbortOnDropHandle` without blocking or yielding.
188    ///
189    /// Note that on success the handle will panic on subsequent polls
190    /// since it becomes consumed.
191    fn try_poll_handle(jh: &mut AbortOnDropHandle<T>) -> Option<Result<T, JoinError>> {
192        let waker = futures_util::task::noop_waker();
193        let mut cx = Context::from_waker(&waker);
194
195        // Since this function is not async and cannot be forced to yield, we should
196        // disable budgeting when we want to check for the `JoinHandle` readiness.
197        let jh = std::pin::pin!(tokio::task::coop::unconstrained(jh));
198        if let Poll::Ready(res) = jh.poll(&mut cx) {
199            Some(res)
200        } else {
201            None
202        }
203    }
204
205    /// Tries to join the next task in FIFO order if it has completed.
206    ///
207    /// Returns `None` if the queue is empty or if the next task is not yet ready.
208    pub fn try_join_next(&mut self) -> Option<Result<T, JoinError>> {
209        let jh = self.0.front_mut()?;
210        let res = Self::try_poll_handle(jh)?;
211        // Use `detach` to avoid calling `abort` on a task that has already completed.
212        // Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
213        // we only need to drop the `JoinHandle` for cleanup.
214        drop(self.0.pop_front().unwrap().detach());
215        Some(res)
216    }
217
218    /// Tries to join the next task in FIFO order if it has completed and return its output,
219    /// along with its [task ID].
220    ///
221    /// Returns `None` if the queue is empty or if the next task is not yet ready.
222    ///
223    /// When this method returns an error, then the id of the task that failed can be accessed
224    /// using the [`JoinError::id`] method.
225    ///
226    /// [task ID]: tokio::task::Id
227    /// [`JoinError::id`]: fn@tokio::task::JoinError::id
228    pub fn try_join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
229        let jh = self.0.front_mut()?;
230        let res = Self::try_poll_handle(jh)?;
231        // Use `detach` to avoid calling `abort` on a task that has already completed.
232        // Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
233        // we only need to drop the `JoinHandle` for cleanup.
234        let jh = self.0.pop_front().unwrap().detach();
235        let id = jh.id();
236        drop(jh);
237        Some(res.map(|output| (id, output)))
238    }
239
240    /// Aborts all tasks and waits for them to finish shutting down.
241    ///
242    /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
243    /// a loop until it returns `None`.
244    ///
245    /// This method ignores any panics in the tasks shutting down. When this call returns, the
246    /// [`JoinQueue`] will be empty.
247    ///
248    /// [`abort_all`]: fn@Self::abort_all
249    /// [`join_next`]: fn@Self::join_next
250    pub async fn shutdown(&mut self) {
251        self.abort_all();
252        while self.join_next().await.is_some() {}
253    }
254
255    /// Awaits the completion of all tasks in this [`JoinQueue`], returning a vector of their results.
256    ///
257    /// The results will be stored in the order they were spawned, not the order they completed.
258    /// This is a convenience method that is equivalent to calling [`join_next`] in
259    /// a loop. If any tasks on the [`JoinQueue`] fail with an [`JoinError`], then this call
260    /// to `join_all` will panic and all remaining tasks on the [`JoinQueue`] are
261    /// cancelled. To handle errors in any other way, manually call [`join_next`]
262    /// in a loop.
263    ///
264    /// # Cancel Safety
265    ///
266    /// This method is not cancel safe as it calls `join_next` in a loop. If you need
267    /// cancel safety, manually call `join_next` in a loop with `Vec` accumulator.
268    ///
269    /// [`join_next`]: fn@Self::join_next
270    /// [`JoinError::id`]: fn@tokio::task::JoinError::id
271    pub async fn join_all(mut self) -> Vec<T> {
272        let mut output = Vec::with_capacity(self.len());
273
274        while let Some(res) = self.join_next().await {
275            match res {
276                Ok(t) => output.push(t),
277                Err(err) if err.is_panic() => std::panic::resume_unwind(err.into_panic()),
278                Err(err) => panic!("{err}"),
279            }
280        }
281        output
282    }
283
284    /// Aborts all tasks on this [`JoinQueue`].
285    ///
286    /// This does not remove the tasks from the [`JoinQueue`]. To wait for the tasks to complete
287    /// cancellation, you should call `join_next` in a loop until the [`JoinQueue`] is empty.
288    pub fn abort_all(&mut self) {
289        self.0.iter().for_each(|jh| jh.abort());
290    }
291
292    /// Removes all tasks from this [`JoinQueue`] without aborting them.
293    ///
294    /// The tasks removed by this call will continue to run in the background even if the [`JoinQueue`]
295    /// is dropped.
296    pub fn detach_all(&mut self) {
297        self.0.drain(..).for_each(|jh| drop(jh.detach()));
298    }
299
300    /// Polls for the next task in [`JoinQueue`] to complete.
301    ///
302    /// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the queue.
303    ///
304    /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
305    /// to receive a wakeup when a task in the [`JoinQueue`] completes. Note that on multiple calls to
306    /// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
307    /// scheduled to receive a wakeup.
308    ///
309    /// # Returns
310    ///
311    /// This function returns:
312    ///
313    ///  * `Poll::Pending` if the [`JoinQueue`] is not empty but there is no task whose output is
314    ///    available right now.
315    ///  * `Poll::Ready(Some(Ok(value)))` if the next task in this [`JoinQueue`] has completed.
316    ///    The `value` is the return value that task.
317    ///  * `Poll::Ready(Some(Err(err)))` if the next task in this [`JoinQueue`] has panicked or been
318    ///    aborted. The `err` is the `JoinError` from the panicked/aborted task.
319    ///  * `Poll::Ready(None)` if the [`JoinQueue`] is empty.
320    pub fn poll_join_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<T, JoinError>>> {
321        let jh = match self.0.front_mut() {
322            None => return Poll::Ready(None),
323            Some(jh) => jh,
324        };
325        if let Poll::Ready(res) = Pin::new(jh).poll(cx) {
326            // Use `detach` to avoid calling `abort` on a task that has already completed.
327            // Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
328            // we only need to drop the `JoinHandle` for cleanup.
329            drop(self.0.pop_front().unwrap().detach());
330            Poll::Ready(Some(res))
331        } else {
332            Poll::Pending
333        }
334    }
335
336    /// Polls for the next task in [`JoinQueue`] to complete.
337    ///
338    /// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the queue.
339    ///
340    /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
341    /// to receive a wakeup when a task in the [`JoinQueue`] completes. Note that on multiple calls to
342    /// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
343    /// scheduled to receive a wakeup.
344    ///
345    /// # Returns
346    ///
347    /// This function returns:
348    ///
349    ///  * `Poll::Pending` if the [`JoinQueue`] is not empty but there is no task whose output is
350    ///    available right now.
351    ///  * `Poll::Ready(Some(Ok((id, value))))` if the next task in this [`JoinQueue`] has completed.
352    ///    The `value` is the return value that task, and `id` is its [task ID].
353    ///  * `Poll::Ready(Some(Err(err)))` if the next task in this [`JoinQueue`] has panicked or been
354    ///    aborted. The `err` is the `JoinError` from the panicked/aborted task.
355    ///  * `Poll::Ready(None)` if the [`JoinQueue`] is empty.
356    ///
357    /// [task ID]: tokio::task::Id
358    pub fn poll_join_next_with_id(
359        &mut self,
360        cx: &mut Context<'_>,
361    ) -> Poll<Option<Result<(Id, T), JoinError>>> {
362        let jh = match self.0.front_mut() {
363            None => return Poll::Ready(None),
364            Some(jh) => jh,
365        };
366        if let Poll::Ready(res) = Pin::new(jh).poll(cx) {
367            // Use `detach` to avoid calling `abort` on a task that has already completed.
368            // Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
369            // we only need to drop the `JoinHandle` for cleanup.
370            let jh = self.0.pop_front().unwrap().detach();
371            let id = jh.id();
372            drop(jh);
373            // If the task succeeded, add the task ID to the output. Otherwise, the
374            // `JoinError` will already have the task's ID.
375            Poll::Ready(Some(res.map(|output| (id, output))))
376        } else {
377            Poll::Pending
378        }
379    }
380}
381
382impl<T> Default for JoinQueue<T> {
383    fn default() -> Self {
384        Self::new()
385    }
386}
387
388/// Collect an iterator of futures into a [`JoinQueue`].
389///
390/// This is equivalent to calling [`JoinQueue::spawn`] on each element of the iterator.
391impl<T, F> std::iter::FromIterator<F> for JoinQueue<T>
392where
393    F: Future<Output = T> + Send + 'static,
394    T: Send + 'static,
395{
396    fn from_iter<I: IntoIterator<Item = F>>(iter: I) -> Self {
397        let mut set = Self::new();
398        iter.into_iter().for_each(|task| {
399            set.spawn(task);
400        });
401        set
402    }
403}