tokio_util/task/join_queue.rs
1use super::AbortOnDropHandle;
2use std::{
3 collections::VecDeque,
4 future::Future,
5 pin::Pin,
6 task::{Context, Poll},
7};
8use tokio::{
9 runtime::Handle,
10 task::{AbortHandle, Id, JoinError, JoinHandle},
11};
12
/// A FIFO queue of tasks spawned on a Tokio runtime.
///
/// A [`JoinQueue`] can be used to await the completion of the tasks in FIFO
/// order. That is, if tasks are spawned in the order A, B, C, then
/// awaiting the next completed task will always return A first, then B,
/// then C, regardless of the order in which the tasks actually complete.
///
/// All of the tasks must have the same return type `T`.
///
/// When the [`JoinQueue`] is dropped, all tasks in the [`JoinQueue`] are
/// immediately aborted.
#[derive(Debug)]
pub struct JoinQueue<T>(VecDeque<AbortOnDropHandle<T>>);
26
27impl<T> JoinQueue<T> {
28 /// Create a new empty [`JoinQueue`].
29 pub const fn new() -> Self {
30 Self(VecDeque::new())
31 }
32
33 /// Creates an empty [`JoinQueue`] with space for at least `capacity` tasks.
34 pub fn with_capacity(capacity: usize) -> Self {
35 Self(VecDeque::with_capacity(capacity))
36 }
37
38 /// Returns the number of tasks currently in the [`JoinQueue`].
39 ///
40 /// This includes both tasks that are currently running and tasks that have
41 /// completed but not yet been removed from the queue because outputting of
42 /// them waits for FIFO order.
43 pub fn len(&self) -> usize {
44 self.0.len()
45 }
46
47 /// Returns whether the [`JoinQueue`] is empty.
48 pub fn is_empty(&self) -> bool {
49 self.0.is_empty()
50 }
51
52 /// Spawn the provided task on the [`JoinQueue`], returning an [`AbortHandle`]
53 /// that can be used to remotely cancel the task.
54 ///
55 /// The provided future will start running in the background immediately
56 /// when this method is called, even if you don't await anything on this
57 /// [`JoinQueue`].
58 ///
59 /// # Panics
60 ///
61 /// This method panics if called outside of a Tokio runtime.
62 ///
63 /// [`AbortHandle`]: tokio::task::AbortHandle
64 #[track_caller]
65 pub fn spawn<F>(&mut self, task: F) -> AbortHandle
66 where
67 F: Future<Output = T> + Send + 'static,
68 T: Send + 'static,
69 {
70 self.push_back(tokio::spawn(task))
71 }
72
73 /// Spawn the provided task on the provided runtime and store it in this
74 /// [`JoinQueue`] returning an [`AbortHandle`] that can be used to remotely
75 /// cancel the task.
76 ///
77 /// The provided future will start running in the background immediately
78 /// when this method is called, even if you don't await anything on this
79 /// [`JoinQueue`].
80 ///
81 /// [`AbortHandle`]: tokio::task::AbortHandle
82 #[track_caller]
83 pub fn spawn_on<F>(&mut self, task: F, handle: &Handle) -> AbortHandle
84 where
85 F: Future<Output = T> + Send + 'static,
86 T: Send + 'static,
87 {
88 self.push_back(handle.spawn(task))
89 }
90
91 /// Spawn the provided task on the current [`LocalSet`] or [`LocalRuntime`]
92 /// and store it in this [`JoinQueue`], returning an [`AbortHandle`] that
93 /// can be used to remotely cancel the task.
94 ///
95 /// The provided future will start running in the background immediately
96 /// when this method is called, even if you don't await anything on this
97 /// [`JoinQueue`].
98 ///
99 /// # Panics
100 ///
101 /// This method panics if it is called outside of a `LocalSet` or `LocalRuntime`.
102 ///
103 /// [`LocalSet`]: tokio::task::LocalSet
104 /// [`LocalRuntime`]: tokio::runtime::LocalRuntime
105 /// [`AbortHandle`]: tokio::task::AbortHandle
106 #[track_caller]
107 pub fn spawn_local<F>(&mut self, task: F) -> AbortHandle
108 where
109 F: Future<Output = T> + 'static,
110 T: 'static,
111 {
112 self.push_back(tokio::task::spawn_local(task))
113 }
114
115 /// Spawn the blocking code on the blocking threadpool and store
116 /// it in this [`JoinQueue`], returning an [`AbortHandle`] that can be
117 /// used to remotely cancel the task.
118 ///
119 /// # Panics
120 ///
121 /// This method panics if called outside of a Tokio runtime.
122 ///
123 /// [`AbortHandle`]: tokio::task::AbortHandle
124 #[track_caller]
125 pub fn spawn_blocking<F>(&mut self, f: F) -> AbortHandle
126 where
127 F: FnOnce() -> T + Send + 'static,
128 T: Send + 'static,
129 {
130 self.push_back(tokio::task::spawn_blocking(f))
131 }
132
133 /// Spawn the blocking code on the blocking threadpool of the
134 /// provided runtime and store it in this [`JoinQueue`], returning an
135 /// [`AbortHandle`] that can be used to remotely cancel the task.
136 ///
137 /// [`AbortHandle`]: tokio::task::AbortHandle
138 #[track_caller]
139 pub fn spawn_blocking_on<F>(&mut self, f: F, handle: &Handle) -> AbortHandle
140 where
141 F: FnOnce() -> T + Send + 'static,
142 T: Send + 'static,
143 {
144 self.push_back(handle.spawn_blocking(f))
145 }
146
147 fn push_back(&mut self, jh: JoinHandle<T>) -> AbortHandle {
148 let jh = AbortOnDropHandle::new(jh);
149 let abort_handle = jh.abort_handle();
150 self.0.push_back(jh);
151 abort_handle
152 }
153
154 /// Waits until the next task in FIFO order completes and returns its output.
155 ///
156 /// Returns `None` if the queue is empty.
157 ///
158 /// # Cancel Safety
159 ///
160 /// This method is cancel safe. If `join_next` is used as the event in a `tokio::select!`
161 /// statement and some other branch completes first, it is guaranteed that no tasks were
162 /// removed from this [`JoinQueue`].
163 pub async fn join_next(&mut self) -> Option<Result<T, JoinError>> {
164 std::future::poll_fn(|cx| self.poll_join_next(cx)).await
165 }
166
167 /// Waits until the next task in FIFO order completes and returns its output,
168 /// along with the [task ID] of the completed task.
169 ///
170 /// Returns `None` if the queue is empty.
171 ///
172 /// When this method returns an error, then the id of the task that failed can be accessed
173 /// using the [`JoinError::id`] method.
174 ///
175 /// # Cancel Safety
176 ///
177 /// This method is cancel safe. If `join_next_with_id` is used as the event in a `tokio::select!`
178 /// statement and some other branch completes first, it is guaranteed that no tasks were
179 /// removed from this [`JoinQueue`].
180 ///
181 /// [task ID]: tokio::task::Id
182 /// [`JoinError::id`]: fn@tokio::task::JoinError::id
183 pub async fn join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
184 std::future::poll_fn(|cx| self.poll_join_next_with_id(cx)).await
185 }
186
187 /// Tries to poll an `AbortOnDropHandle` without blocking or yielding.
188 ///
189 /// Note that on success the handle will panic on subsequent polls
190 /// since it becomes consumed.
191 fn try_poll_handle(jh: &mut AbortOnDropHandle<T>) -> Option<Result<T, JoinError>> {
192 let waker = futures_util::task::noop_waker();
193 let mut cx = Context::from_waker(&waker);
194
195 // Since this function is not async and cannot be forced to yield, we should
196 // disable budgeting when we want to check for the `JoinHandle` readiness.
197 let jh = std::pin::pin!(tokio::task::coop::unconstrained(jh));
198 if let Poll::Ready(res) = jh.poll(&mut cx) {
199 Some(res)
200 } else {
201 None
202 }
203 }
204
205 /// Tries to join the next task in FIFO order if it has completed.
206 ///
207 /// Returns `None` if the queue is empty or if the next task is not yet ready.
208 pub fn try_join_next(&mut self) -> Option<Result<T, JoinError>> {
209 let jh = self.0.front_mut()?;
210 let res = Self::try_poll_handle(jh)?;
211 // Use `detach` to avoid calling `abort` on a task that has already completed.
212 // Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
213 // we only need to drop the `JoinHandle` for cleanup.
214 drop(self.0.pop_front().unwrap().detach());
215 Some(res)
216 }
217
218 /// Tries to join the next task in FIFO order if it has completed and return its output,
219 /// along with its [task ID].
220 ///
221 /// Returns `None` if the queue is empty or if the next task is not yet ready.
222 ///
223 /// When this method returns an error, then the id of the task that failed can be accessed
224 /// using the [`JoinError::id`] method.
225 ///
226 /// [task ID]: tokio::task::Id
227 /// [`JoinError::id`]: fn@tokio::task::JoinError::id
228 pub fn try_join_next_with_id(&mut self) -> Option<Result<(Id, T), JoinError>> {
229 let jh = self.0.front_mut()?;
230 let res = Self::try_poll_handle(jh)?;
231 // Use `detach` to avoid calling `abort` on a task that has already completed.
232 // Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
233 // we only need to drop the `JoinHandle` for cleanup.
234 let jh = self.0.pop_front().unwrap().detach();
235 let id = jh.id();
236 drop(jh);
237 Some(res.map(|output| (id, output)))
238 }
239
240 /// Aborts all tasks and waits for them to finish shutting down.
241 ///
242 /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
243 /// a loop until it returns `None`.
244 ///
245 /// This method ignores any panics in the tasks shutting down. When this call returns, the
246 /// [`JoinQueue`] will be empty.
247 ///
248 /// [`abort_all`]: fn@Self::abort_all
249 /// [`join_next`]: fn@Self::join_next
250 pub async fn shutdown(&mut self) {
251 self.abort_all();
252 while self.join_next().await.is_some() {}
253 }
254
255 /// Awaits the completion of all tasks in this [`JoinQueue`], returning a vector of their results.
256 ///
257 /// The results will be stored in the order they were spawned, not the order they completed.
258 /// This is a convenience method that is equivalent to calling [`join_next`] in
259 /// a loop. If any tasks on the [`JoinQueue`] fail with an [`JoinError`], then this call
260 /// to `join_all` will panic and all remaining tasks on the [`JoinQueue`] are
261 /// cancelled. To handle errors in any other way, manually call [`join_next`]
262 /// in a loop.
263 ///
264 /// # Cancel Safety
265 ///
266 /// This method is not cancel safe as it calls `join_next` in a loop. If you need
267 /// cancel safety, manually call `join_next` in a loop with `Vec` accumulator.
268 ///
269 /// [`join_next`]: fn@Self::join_next
270 /// [`JoinError::id`]: fn@tokio::task::JoinError::id
271 pub async fn join_all(mut self) -> Vec<T> {
272 let mut output = Vec::with_capacity(self.len());
273
274 while let Some(res) = self.join_next().await {
275 match res {
276 Ok(t) => output.push(t),
277 Err(err) if err.is_panic() => std::panic::resume_unwind(err.into_panic()),
278 Err(err) => panic!("{err}"),
279 }
280 }
281 output
282 }
283
284 /// Aborts all tasks on this [`JoinQueue`].
285 ///
286 /// This does not remove the tasks from the [`JoinQueue`]. To wait for the tasks to complete
287 /// cancellation, you should call `join_next` in a loop until the [`JoinQueue`] is empty.
288 pub fn abort_all(&mut self) {
289 self.0.iter().for_each(|jh| jh.abort());
290 }
291
292 /// Removes all tasks from this [`JoinQueue`] without aborting them.
293 ///
294 /// The tasks removed by this call will continue to run in the background even if the [`JoinQueue`]
295 /// is dropped.
296 pub fn detach_all(&mut self) {
297 self.0.drain(..).for_each(|jh| drop(jh.detach()));
298 }
299
300 /// Polls for the next task in [`JoinQueue`] to complete.
301 ///
302 /// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the queue.
303 ///
304 /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
305 /// to receive a wakeup when a task in the [`JoinQueue`] completes. Note that on multiple calls to
306 /// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
307 /// scheduled to receive a wakeup.
308 ///
309 /// # Returns
310 ///
311 /// This function returns:
312 ///
313 /// * `Poll::Pending` if the [`JoinQueue`] is not empty but there is no task whose output is
314 /// available right now.
315 /// * `Poll::Ready(Some(Ok(value)))` if the next task in this [`JoinQueue`] has completed.
316 /// The `value` is the return value that task.
317 /// * `Poll::Ready(Some(Err(err)))` if the next task in this [`JoinQueue`] has panicked or been
318 /// aborted. The `err` is the `JoinError` from the panicked/aborted task.
319 /// * `Poll::Ready(None)` if the [`JoinQueue`] is empty.
320 pub fn poll_join_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<T, JoinError>>> {
321 let jh = match self.0.front_mut() {
322 None => return Poll::Ready(None),
323 Some(jh) => jh,
324 };
325 if let Poll::Ready(res) = Pin::new(jh).poll(cx) {
326 // Use `detach` to avoid calling `abort` on a task that has already completed.
327 // Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
328 // we only need to drop the `JoinHandle` for cleanup.
329 drop(self.0.pop_front().unwrap().detach());
330 Poll::Ready(Some(res))
331 } else {
332 Poll::Pending
333 }
334 }
335
336 /// Polls for the next task in [`JoinQueue`] to complete.
337 ///
338 /// If this returns `Poll::Ready(Some(_))`, then the task that completed is removed from the queue.
339 ///
340 /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled
341 /// to receive a wakeup when a task in the [`JoinQueue`] completes. Note that on multiple calls to
342 /// `poll_join_next`, only the `Waker` from the `Context` passed to the most recent call is
343 /// scheduled to receive a wakeup.
344 ///
345 /// # Returns
346 ///
347 /// This function returns:
348 ///
349 /// * `Poll::Pending` if the [`JoinQueue`] is not empty but there is no task whose output is
350 /// available right now.
351 /// * `Poll::Ready(Some(Ok((id, value))))` if the next task in this [`JoinQueue`] has completed.
352 /// The `value` is the return value that task, and `id` is its [task ID].
353 /// * `Poll::Ready(Some(Err(err)))` if the next task in this [`JoinQueue`] has panicked or been
354 /// aborted. The `err` is the `JoinError` from the panicked/aborted task.
355 /// * `Poll::Ready(None)` if the [`JoinQueue`] is empty.
356 ///
357 /// [task ID]: tokio::task::Id
358 pub fn poll_join_next_with_id(
359 &mut self,
360 cx: &mut Context<'_>,
361 ) -> Poll<Option<Result<(Id, T), JoinError>>> {
362 let jh = match self.0.front_mut() {
363 None => return Poll::Ready(None),
364 Some(jh) => jh,
365 };
366 if let Poll::Ready(res) = Pin::new(jh).poll(cx) {
367 // Use `detach` to avoid calling `abort` on a task that has already completed.
368 // Dropping `AbortOnDropHandle` would abort the task, but since it is finished,
369 // we only need to drop the `JoinHandle` for cleanup.
370 let jh = self.0.pop_front().unwrap().detach();
371 let id = jh.id();
372 drop(jh);
373 // If the task succeeded, add the task ID to the output. Otherwise, the
374 // `JoinError` will already have the task's ID.
375 Poll::Ready(Some(res.map(|output| (id, output))))
376 } else {
377 Poll::Pending
378 }
379 }
380}
381
382impl<T> Default for JoinQueue<T> {
383 fn default() -> Self {
384 Self::new()
385 }
386}
387
388/// Collect an iterator of futures into a [`JoinQueue`].
389///
390/// This is equivalent to calling [`JoinQueue::spawn`] on each element of the iterator.
391impl<T, F> std::iter::FromIterator<F> for JoinQueue<T>
392where
393 F: Future<Output = T> + Send + 'static,
394 T: Send + 'static,
395{
396 fn from_iter<I: IntoIterator<Item = F>>(iter: I) -> Self {
397 let mut set = Self::new();
398 iter.into_iter().for_each(|task| {
399 set.spawn(task);
400 });
401 set
402 }
403}