tokio_util/task/join_map.rs
1use hashbrown::hash_table::Entry;
2use hashbrown::{HashMap, HashTable};
3use std::borrow::Borrow;
4use std::collections::hash_map::RandomState;
5use std::fmt;
6use std::future::Future;
7use std::hash::{BuildHasher, Hash, Hasher};
8use std::marker::PhantomData;
9use tokio::runtime::Handle;
10use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet};
11
12/// A collection of tasks spawned on a Tokio runtime, associated with hash map
13/// keys.
14///
15/// This type is very similar to the [`JoinSet`] type in `tokio::task`, with the
16/// addition of a set of keys associated with each task. These keys allow
17/// [cancelling a task][abort] or [multiple tasks][abort_matching] in the
18/// `JoinMap` based on their keys, or [test whether a task corresponding to a
19/// given key exists][contains] in the `JoinMap`.
20///
21/// In addition, when tasks in the `JoinMap` complete, they will return the
22/// associated key along with the value returned by the task, if any.
23///
24/// A `JoinMap` can be used to await the completion of some or all of the tasks
25/// in the map. The map is not ordered, and the tasks will be returned in the
26/// order they complete.
27///
28/// All of the tasks must have the same return type `V`.
29///
30/// When the `JoinMap` is dropped, all tasks in the `JoinMap` are immediately aborted.
31///
32/// # Examples
33///
34/// Spawn multiple tasks and wait for them:
35///
36/// ```
37/// use tokio_util::task::JoinMap;
38///
39/// #[tokio::main]
40/// async fn main() {
41/// let mut map = JoinMap::new();
42///
43/// for i in 0..10 {
44/// // Spawn a task on the `JoinMap` with `i` as its key.
45/// map.spawn(i, async move { /* ... */ });
46/// }
47///
48/// let mut seen = [false; 10];
49///
50/// // When a task completes, `join_next` returns the task's key along
51/// // with its output.
52/// while let Some((key, res)) = map.join_next().await {
53/// seen[key] = true;
54/// assert!(res.is_ok(), "task {} completed successfully!", key);
55/// }
56///
57/// for i in 0..10 {
58/// assert!(seen[i]);
59/// }
60/// }
61/// ```
62///
63/// Cancel tasks based on their keys:
64///
65/// ```
66/// use tokio_util::task::JoinMap;
67///
68/// #[tokio::main]
69/// async fn main() {
70/// let mut map = JoinMap::new();
71///
72/// map.spawn("hello world", async move { /* ... */ });
73/// map.spawn("goodbye world", async move { /* ... */});
74///
75/// // Look up the "goodbye world" task in the map and abort it.
76/// let aborted = map.abort("goodbye world");
77///
78/// // `JoinMap::abort` returns `true` if a task existed for the
79/// // provided key.
80/// assert!(aborted);
81///
82/// while let Some((key, res)) = map.join_next().await {
83/// if key == "goodbye world" {
84/// // The aborted task should complete with a cancelled `JoinError`.
85/// assert!(res.unwrap_err().is_cancelled());
86/// } else {
87/// // Other tasks should complete normally.
88/// assert!(res.is_ok());
89/// }
90/// }
91/// }
92/// ```
93///
94/// [`JoinSet`]: tokio::task::JoinSet
95/// [abort]: fn@Self::abort
96/// [abort_matching]: fn@Self::abort_matching
97/// [contains]: fn@Self::contains_key
98pub struct JoinMap<K, V, S = RandomState> {
99 /// A map of the [`AbortHandle`]s of the tasks spawned on this `JoinMap`,
100 /// indexed by their keys.
101 tasks_by_key: HashTable<(K, AbortHandle)>,
102
103 /// A map from task IDs to the hash of the key associated with that task.
104 ///
105 /// This map is used to perform reverse lookups of tasks in the
106 /// `tasks_by_key` map based on their task IDs. When a task terminates, the
107 /// ID is provided to us by the `JoinSet`, so we can look up the hash value
108 /// of that task's key, and then remove it from the `tasks_by_key` map using
109 /// the raw hash code, resolving collisions by comparing task IDs.
110 hashes_by_task: HashMap<Id, u64, S>,
111
112 /// The [`JoinSet`] that awaits the completion of tasks spawned on this
113 /// `JoinMap`.
114 tasks: JoinSet<V>,
115}
116
117impl<K, V> JoinMap<K, V> {
118 /// Creates a new empty `JoinMap`.
119 ///
120 /// The `JoinMap` is initially created with a capacity of 0, so it will not
121 /// allocate until a task is first spawned on it.
122 ///
123 /// # Examples
124 ///
125 /// ```
126 /// use tokio_util::task::JoinMap;
127 /// let map: JoinMap<&str, i32> = JoinMap::new();
128 /// ```
129 #[inline]
130 #[must_use]
131 pub fn new() -> Self {
132 Self::with_hasher(RandomState::new())
133 }
134
135 /// Creates an empty `JoinMap` with the specified capacity.
136 ///
137 /// The `JoinMap` will be able to hold at least `capacity` tasks without
138 /// reallocating.
139 ///
140 /// # Examples
141 ///
142 /// ```
143 /// use tokio_util::task::JoinMap;
144 /// let map: JoinMap<&str, i32> = JoinMap::with_capacity(10);
145 /// ```
146 #[inline]
147 #[must_use]
148 pub fn with_capacity(capacity: usize) -> Self {
149 JoinMap::with_capacity_and_hasher(capacity, Default::default())
150 }
151}
152
153impl<K, V, S> JoinMap<K, V, S> {
154 /// Creates an empty `JoinMap` which will use the given hash builder to hash
155 /// keys.
156 ///
157 /// The created map has the default initial capacity.
158 ///
159 /// Warning: `hash_builder` is normally randomly generated, and
160 /// is designed to allow `JoinMap` to be resistant to attacks that
161 /// cause many collisions and very poor performance. Setting it
162 /// manually using this function can expose a DoS attack vector.
163 ///
164 /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
165 /// the `JoinMap` to be useful, see its documentation for details.
166 #[inline]
167 #[must_use]
168 pub fn with_hasher(hash_builder: S) -> Self {
169 Self::with_capacity_and_hasher(0, hash_builder)
170 }
171
172 /// Creates an empty `JoinMap` with the specified capacity, using `hash_builder`
173 /// to hash the keys.
174 ///
175 /// The `JoinMap` will be able to hold at least `capacity` elements without
176 /// reallocating. If `capacity` is 0, the `JoinMap` will not allocate.
177 ///
178 /// Warning: `hash_builder` is normally randomly generated, and
179 /// is designed to allow HashMaps to be resistant to attacks that
180 /// cause many collisions and very poor performance. Setting it
181 /// manually using this function can expose a DoS attack vector.
182 ///
183 /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
184 /// the `JoinMap`to be useful, see its documentation for details.
185 ///
186 /// # Examples
187 ///
188 /// ```
189 /// # #[tokio::main]
190 /// # async fn main() {
191 /// use tokio_util::task::JoinMap;
192 /// use std::collections::hash_map::RandomState;
193 ///
194 /// let s = RandomState::new();
195 /// let mut map = JoinMap::with_capacity_and_hasher(10, s);
196 /// map.spawn(1, async move { "hello world!" });
197 /// # }
198 /// ```
199 #[inline]
200 #[must_use]
201 pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
202 Self {
203 tasks_by_key: HashTable::with_capacity(capacity),
204 hashes_by_task: HashMap::with_capacity_and_hasher(capacity, hash_builder),
205 tasks: JoinSet::new(),
206 }
207 }
208
209 /// Returns the number of tasks currently in the `JoinMap`.
210 pub fn len(&self) -> usize {
211 let len = self.tasks_by_key.len();
212 debug_assert_eq!(len, self.hashes_by_task.len());
213 len
214 }
215
216 /// Returns whether the `JoinMap` is empty.
217 pub fn is_empty(&self) -> bool {
218 let empty = self.tasks_by_key.is_empty();
219 debug_assert_eq!(empty, self.hashes_by_task.is_empty());
220 empty
221 }
222
223 /// Returns the number of tasks the map can hold without reallocating.
224 ///
225 /// This number is a lower bound; the `JoinMap` might be able to hold
226 /// more, but is guaranteed to be able to hold at least this many.
227 ///
228 /// # Examples
229 ///
230 /// ```
231 /// use tokio_util::task::JoinMap;
232 ///
233 /// let map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
234 /// assert!(map.capacity() >= 100);
235 /// ```
236 #[inline]
237 pub fn capacity(&self) -> usize {
238 let capacity = self.tasks_by_key.capacity();
239 debug_assert_eq!(capacity, self.hashes_by_task.capacity());
240 capacity
241 }
242}
243
244impl<K, V, S> JoinMap<K, V, S>
245where
246 K: Hash + Eq,
247 V: 'static,
248 S: BuildHasher,
249{
250 /// Spawn the provided task and store it in this `JoinMap` with the provided
251 /// key.
252 ///
253 /// If a task previously existed in the `JoinMap` for this key, that task
254 /// will be cancelled and replaced with the new one. The previous task will
255 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
256 /// *not* return a cancelled [`JoinError`] for that task.
257 ///
258 /// # Panics
259 ///
260 /// This method panics if called outside of a Tokio runtime.
261 ///
262 /// [`join_next`]: Self::join_next
263 #[track_caller]
264 pub fn spawn<F>(&mut self, key: K, task: F)
265 where
266 F: Future<Output = V>,
267 F: Send + 'static,
268 V: Send,
269 {
270 let task = self.tasks.spawn(task);
271 self.insert(key, task)
272 }
273
274 /// Spawn the provided task on the provided runtime and store it in this
275 /// `JoinMap` with the provided key.
276 ///
277 /// If a task previously existed in the `JoinMap` for this key, that task
278 /// will be cancelled and replaced with the new one. The previous task will
279 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
280 /// *not* return a cancelled [`JoinError`] for that task.
281 ///
282 /// [`join_next`]: Self::join_next
283 #[track_caller]
284 pub fn spawn_on<F>(&mut self, key: K, task: F, handle: &Handle)
285 where
286 F: Future<Output = V>,
287 F: Send + 'static,
288 V: Send,
289 {
290 let task = self.tasks.spawn_on(task, handle);
291 self.insert(key, task);
292 }
293
294 /// Spawn the blocking code on the blocking threadpool and store it in this `JoinMap` with the provided
295 /// key.
296 ///
297 /// If a task previously existed in the `JoinMap` for this key, that task
298 /// will be cancelled and replaced with the new one. The previous task will
299 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
300 /// *not* return a cancelled [`JoinError`] for that task.
301 ///
302 /// Note that blocking tasks cannot be cancelled after execution starts.
303 /// Replaced blocking tasks will still run to completion if the task has begun
304 /// to execute when it is replaced. A blocking task which is replaced before
305 /// it has been scheduled on a blocking worker thread will be cancelled.
306 ///
307 /// # Panics
308 ///
309 /// This method panics if called outside of a Tokio runtime.
310 ///
311 /// [`join_next`]: Self::join_next
312 #[track_caller]
313 pub fn spawn_blocking<F>(&mut self, key: K, f: F)
314 where
315 F: FnOnce() -> V,
316 F: Send + 'static,
317 V: Send,
318 {
319 let task = self.tasks.spawn_blocking(f);
320 self.insert(key, task)
321 }
322
323 /// Spawn the blocking code on the blocking threadpool of the provided runtime and store it in this
324 /// `JoinMap` with the provided key.
325 ///
326 /// If a task previously existed in the `JoinMap` for this key, that task
327 /// will be cancelled and replaced with the new one. The previous task will
328 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
329 /// *not* return a cancelled [`JoinError`] for that task.
330 ///
331 /// Note that blocking tasks cannot be cancelled after execution starts.
332 /// Replaced blocking tasks will still run to completion if the task has begun
333 /// to execute when it is replaced. A blocking task which is replaced before
334 /// it has been scheduled on a blocking worker thread will be cancelled.
335 ///
336 /// [`join_next`]: Self::join_next
337 #[track_caller]
338 pub fn spawn_blocking_on<F>(&mut self, key: K, f: F, handle: &Handle)
339 where
340 F: FnOnce() -> V,
341 F: Send + 'static,
342 V: Send,
343 {
344 let task = self.tasks.spawn_blocking_on(f, handle);
345 self.insert(key, task);
346 }
347
348 /// Spawn the provided task on the current [`LocalSet`] and store it in this
349 /// `JoinMap` with the provided key.
350 ///
351 /// If a task previously existed in the `JoinMap` for this key, that task
352 /// will be cancelled and replaced with the new one. The previous task will
353 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
354 /// *not* return a cancelled [`JoinError`] for that task.
355 ///
356 /// # Panics
357 ///
358 /// This method panics if it is called outside of a `LocalSet`.
359 ///
360 /// [`LocalSet`]: tokio::task::LocalSet
361 /// [`join_next`]: Self::join_next
362 #[track_caller]
363 pub fn spawn_local<F>(&mut self, key: K, task: F)
364 where
365 F: Future<Output = V>,
366 F: 'static,
367 {
368 let task = self.tasks.spawn_local(task);
369 self.insert(key, task);
370 }
371
372 /// Spawn the provided task on the provided [`LocalSet`] and store it in
373 /// this `JoinMap` with the provided key.
374 ///
375 /// If a task previously existed in the `JoinMap` for this key, that task
376 /// will be cancelled and replaced with the new one. The previous task will
377 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
378 /// *not* return a cancelled [`JoinError`] for that task.
379 ///
380 /// [`LocalSet`]: tokio::task::LocalSet
381 /// [`join_next`]: Self::join_next
382 #[track_caller]
383 pub fn spawn_local_on<F>(&mut self, key: K, task: F, local_set: &LocalSet)
384 where
385 F: Future<Output = V>,
386 F: 'static,
387 {
388 let task = self.tasks.spawn_local_on(task, local_set);
389 self.insert(key, task)
390 }
391
392 fn insert(&mut self, mut key: K, mut abort: AbortHandle) {
393 let hash_builder = self.hashes_by_task.hasher();
394 let hash = hash_one(hash_builder, &key);
395 let id = abort.id();
396
397 // Insert the new key into the map of tasks by keys.
398 let entry =
399 self.tasks_by_key
400 .entry(hash, |(k, _)| *k == key, |(k, _)| hash_one(hash_builder, k));
401 match entry {
402 Entry::Occupied(occ) => {
403 // There was a previous task spawned with the same key! Cancel
404 // that task, and remove its ID from the map of hashes by task IDs.
405 (key, abort) = std::mem::replace(occ.into_mut(), (key, abort));
406
407 // Remove the old task ID.
408 let _prev_hash = self.hashes_by_task.remove(&abort.id());
409 debug_assert_eq!(Some(hash), _prev_hash);
410
411 // Associate the key's hash with the new task's ID, for looking up tasks by ID.
412 let _prev = self.hashes_by_task.insert(id, hash);
413 debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
414
415 // Note: it's important to drop `key` and abort the task here.
416 // This defends against any panics during drop handling for causing inconsistent state.
417 abort.abort();
418 drop(key);
419 }
420 Entry::Vacant(vac) => {
421 vac.insert((key, abort));
422
423 // Associate the key's hash with this task's ID, for looking up tasks by ID.
424 let _prev = self.hashes_by_task.insert(id, hash);
425 debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
426 }
427 };
428 }
429
430 /// Waits until one of the tasks in the map completes and returns its
431 /// output, along with the key corresponding to that task.
432 ///
433 /// Returns `None` if the map is empty.
434 ///
435 /// # Cancel Safety
436 ///
437 /// This method is cancel safe. If `join_next` is used as the event in a [`tokio::select!`]
438 /// statement and some other branch completes first, it is guaranteed that no tasks were
439 /// removed from this `JoinMap`.
440 ///
441 /// # Returns
442 ///
443 /// This function returns:
444 ///
445 /// * `Some((key, Ok(value)))` if one of the tasks in this `JoinMap` has
446 /// completed. The `value` is the return value of that ask, and `key` is
447 /// the key associated with the task.
448 /// * `Some((key, Err(err))` if one of the tasks in this `JoinMap` has
449 /// panicked or been aborted. `key` is the key associated with the task
450 /// that panicked or was aborted.
451 /// * `None` if the `JoinMap` is empty.
452 ///
453 /// [`tokio::select!`]: tokio::select
454 pub async fn join_next(&mut self) -> Option<(K, Result<V, JoinError>)> {
455 loop {
456 let (res, id) = match self.tasks.join_next_with_id().await {
457 Some(Ok((id, output))) => (Ok(output), id),
458 Some(Err(e)) => {
459 let id = e.id();
460 (Err(e), id)
461 }
462 None => return None,
463 };
464 if let Some(key) = self.remove_by_id(id) {
465 break Some((key, res));
466 }
467 }
468 }
469
470 /// Aborts all tasks and waits for them to finish shutting down.
471 ///
472 /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
473 /// a loop until it returns `None`.
474 ///
475 /// This method ignores any panics in the tasks shutting down. When this call returns, the
476 /// `JoinMap` will be empty.
477 ///
478 /// [`abort_all`]: fn@Self::abort_all
479 /// [`join_next`]: fn@Self::join_next
480 pub async fn shutdown(&mut self) {
481 self.abort_all();
482 while self.join_next().await.is_some() {}
483 }
484
485 /// Abort the task corresponding to the provided `key`.
486 ///
487 /// If this `JoinMap` contains a task corresponding to `key`, this method
488 /// will abort that task and return `true`. Otherwise, if no task exists for
489 /// `key`, this method returns `false`.
490 ///
491 /// # Examples
492 ///
493 /// Aborting a task by key:
494 ///
495 /// ```
496 /// use tokio_util::task::JoinMap;
497 ///
498 /// # #[tokio::main]
499 /// # async fn main() {
500 /// let mut map = JoinMap::new();
501 ///
502 /// map.spawn("hello world", async move { /* ... */ });
503 /// map.spawn("goodbye world", async move { /* ... */});
504 ///
505 /// // Look up the "goodbye world" task in the map and abort it.
506 /// map.abort("goodbye world");
507 ///
508 /// while let Some((key, res)) = map.join_next().await {
509 /// if key == "goodbye world" {
510 /// // The aborted task should complete with a cancelled `JoinError`.
511 /// assert!(res.unwrap_err().is_cancelled());
512 /// } else {
513 /// // Other tasks should complete normally.
514 /// assert!(res.is_ok());
515 /// }
516 /// }
517 /// # }
518 /// ```
519 ///
520 /// `abort` returns `true` if a task was aborted:
521 /// ```
522 /// use tokio_util::task::JoinMap;
523 ///
524 /// # #[tokio::main]
525 /// # async fn main() {
526 /// let mut map = JoinMap::new();
527 ///
528 /// map.spawn("hello world", async move { /* ... */ });
529 /// map.spawn("goodbye world", async move { /* ... */});
530 ///
531 /// // A task for the key "goodbye world" should exist in the map:
532 /// assert!(map.abort("goodbye world"));
533 ///
534 /// // Aborting a key that does not exist will return `false`:
535 /// assert!(!map.abort("goodbye universe"));
536 /// # }
537 /// ```
538 pub fn abort<Q>(&mut self, key: &Q) -> bool
539 where
540 Q: ?Sized + Hash + Eq,
541 K: Borrow<Q>,
542 {
543 match self.get_by_key(key) {
544 Some((_, handle)) => {
545 handle.abort();
546 true
547 }
548 None => false,
549 }
550 }
551
552 /// Aborts all tasks with keys matching `predicate`.
553 ///
554 /// `predicate` is a function called with a reference to each key in the
555 /// map. If it returns `true` for a given key, the corresponding task will
556 /// be cancelled.
557 ///
558 /// # Examples
559 /// ```
560 /// use tokio_util::task::JoinMap;
561 ///
562 /// # // use the current thread rt so that spawned tasks don't
563 /// # // complete in the background before they can be aborted.
564 /// # #[tokio::main(flavor = "current_thread")]
565 /// # async fn main() {
566 /// let mut map = JoinMap::new();
567 ///
568 /// map.spawn("hello world", async move {
569 /// // ...
570 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
571 /// });
572 /// map.spawn("goodbye world", async move {
573 /// // ...
574 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
575 /// });
576 /// map.spawn("hello san francisco", async move {
577 /// // ...
578 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
579 /// });
580 /// map.spawn("goodbye universe", async move {
581 /// // ...
582 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
583 /// });
584 ///
585 /// // Abort all tasks whose keys begin with "goodbye"
586 /// map.abort_matching(|key| key.starts_with("goodbye"));
587 ///
588 /// let mut seen = 0;
589 /// while let Some((key, res)) = map.join_next().await {
590 /// seen += 1;
591 /// if key.starts_with("goodbye") {
592 /// // The aborted task should complete with a cancelled `JoinError`.
593 /// assert!(res.unwrap_err().is_cancelled());
594 /// } else {
595 /// // Other tasks should complete normally.
596 /// assert!(key.starts_with("hello"));
597 /// assert!(res.is_ok());
598 /// }
599 /// }
600 ///
601 /// // All spawned tasks should have completed.
602 /// assert_eq!(seen, 4);
603 /// # }
604 /// ```
605 pub fn abort_matching(&mut self, mut predicate: impl FnMut(&K) -> bool) {
606 // Note: this method iterates over the tasks and keys *without* removing
607 // any entries, so that the keys from aborted tasks can still be
608 // returned when calling `join_next` in the future.
609 for (key, task) in &self.tasks_by_key {
610 if predicate(key) {
611 task.abort();
612 }
613 }
614 }
615
616 /// Returns an iterator visiting all keys in this `JoinMap` in arbitrary order.
617 ///
618 /// If a task has completed, but its output hasn't yet been consumed by a
619 /// call to [`join_next`], this method will still return its key.
620 ///
621 /// [`join_next`]: fn@Self::join_next
622 pub fn keys(&self) -> JoinMapKeys<'_, K, V> {
623 JoinMapKeys {
624 iter: self.tasks_by_key.iter(),
625 _value: PhantomData,
626 }
627 }
628
629 /// Returns `true` if this `JoinMap` contains a task for the provided key.
630 ///
631 /// If the task has completed, but its output hasn't yet been consumed by a
632 /// call to [`join_next`], this method will still return `true`.
633 ///
634 /// [`join_next`]: fn@Self::join_next
635 pub fn contains_key<Q>(&self, key: &Q) -> bool
636 where
637 Q: ?Sized + Hash + Eq,
638 K: Borrow<Q>,
639 {
640 self.get_by_key(key).is_some()
641 }
642
643 /// Returns `true` if this `JoinMap` contains a task with the provided
644 /// [task ID].
645 ///
646 /// If the task has completed, but its output hasn't yet been consumed by a
647 /// call to [`join_next`], this method will still return `true`.
648 ///
649 /// [`join_next`]: fn@Self::join_next
650 /// [task ID]: tokio::task::Id
651 pub fn contains_task(&self, task: &Id) -> bool {
652 self.hashes_by_task.contains_key(task)
653 }
654
655 /// Reserves capacity for at least `additional` more tasks to be spawned
656 /// on this `JoinMap` without reallocating for the map of task keys. The
657 /// collection may reserve more space to avoid frequent reallocations.
658 ///
659 /// Note that spawning a task will still cause an allocation for the task
660 /// itself.
661 ///
662 /// # Panics
663 ///
664 /// Panics if the new allocation size overflows [`usize`].
665 ///
666 /// # Examples
667 ///
668 /// ```
669 /// use tokio_util::task::JoinMap;
670 ///
671 /// let mut map: JoinMap<&str, i32> = JoinMap::new();
672 /// map.reserve(10);
673 /// ```
674 #[inline]
675 pub fn reserve(&mut self, additional: usize) {
676 let hash_builder = self.hashes_by_task.hasher();
677 self.tasks_by_key
678 .reserve(additional, |(k, _)| hash_one(hash_builder, k));
679 self.hashes_by_task.reserve(additional);
680 }
681
682 /// Shrinks the capacity of the `JoinMap` as much as possible. It will drop
683 /// down as much as possible while maintaining the internal rules
684 /// and possibly leaving some space in accordance with the resize policy.
685 ///
686 /// # Examples
687 ///
688 /// ```
689 /// # #[tokio::main]
690 /// # async fn main() {
691 /// use tokio_util::task::JoinMap;
692 ///
693 /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
694 /// map.spawn(1, async move { 2 });
695 /// map.spawn(3, async move { 4 });
696 /// assert!(map.capacity() >= 100);
697 /// map.shrink_to_fit();
698 /// assert!(map.capacity() >= 2);
699 /// # }
700 /// ```
701 #[inline]
702 pub fn shrink_to_fit(&mut self) {
703 self.hashes_by_task.shrink_to_fit();
704 let hash_builder = self.hashes_by_task.hasher();
705 self.tasks_by_key
706 .shrink_to_fit(|(k, _)| hash_one(hash_builder, k));
707 }
708
709 /// Shrinks the capacity of the map with a lower limit. It will drop
710 /// down no lower than the supplied limit while maintaining the internal rules
711 /// and possibly leaving some space in accordance with the resize policy.
712 ///
713 /// If the current capacity is less than the lower limit, this is a no-op.
714 ///
715 /// # Examples
716 ///
717 /// ```
718 /// # #[tokio::main]
719 /// # async fn main() {
720 /// use tokio_util::task::JoinMap;
721 ///
722 /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
723 /// map.spawn(1, async move { 2 });
724 /// map.spawn(3, async move { 4 });
725 /// assert!(map.capacity() >= 100);
726 /// map.shrink_to(10);
727 /// assert!(map.capacity() >= 10);
728 /// map.shrink_to(0);
729 /// assert!(map.capacity() >= 2);
730 /// # }
731 /// ```
732 #[inline]
733 pub fn shrink_to(&mut self, min_capacity: usize) {
734 self.hashes_by_task.shrink_to(min_capacity);
735 let hash_builder = self.hashes_by_task.hasher();
736 self.tasks_by_key
737 .shrink_to(min_capacity, |(k, _)| hash_one(hash_builder, k))
738 }
739
740 /// Look up a task in the map by its key, returning the key and abort handle.
741 fn get_by_key<'map, Q>(&'map self, key: &Q) -> Option<&'map (K, AbortHandle)>
742 where
743 Q: ?Sized + Hash + Eq,
744 K: Borrow<Q>,
745 {
746 let hash_builder = self.hashes_by_task.hasher();
747 let hash = hash_one(hash_builder, key);
748 self.tasks_by_key.find(hash, |(k, _)| k.borrow() == key)
749 }
750
751 /// Remove a task from the map by ID, returning the key for that task.
752 fn remove_by_id(&mut self, id: Id) -> Option<K> {
753 // Get the hash for the given ID.
754 let hash = self.hashes_by_task.remove(&id)?;
755
756 // Remove the entry for that hash.
757 let entry = self
758 .tasks_by_key
759 .find_entry(hash, |(_, abort)| abort.id() == id);
760 let (key, _) = match entry {
761 Ok(entry) => entry.remove().0,
762 _ => return None,
763 };
764 self.hashes_by_task.remove(&id);
765 Some(key)
766 }
767}
768
769/// Returns the hash for a given key.
770#[inline]
771fn hash_one<S, Q>(hash_builder: &S, key: &Q) -> u64
772where
773 Q: ?Sized + Hash,
774 S: BuildHasher,
775{
776 let mut hasher = hash_builder.build_hasher();
777 key.hash(&mut hasher);
778 hasher.finish()
779}
780
781impl<K, V, S> JoinMap<K, V, S>
782where
783 V: 'static,
784{
785 /// Aborts all tasks on this `JoinMap`.
786 ///
787 /// This does not remove the tasks from the `JoinMap`. To wait for the tasks to complete
788 /// cancellation, you should call `join_next` in a loop until the `JoinMap` is empty.
789 pub fn abort_all(&mut self) {
790 self.tasks.abort_all()
791 }
792
793 /// Removes all tasks from this `JoinMap` without aborting them.
794 ///
795 /// The tasks removed by this call will continue to run in the background even if the `JoinMap`
796 /// is dropped. They may still be aborted by key.
797 pub fn detach_all(&mut self) {
798 self.tasks.detach_all();
799 self.tasks_by_key.clear();
800 self.hashes_by_task.clear();
801 }
802}
803
804// Hand-written `fmt::Debug` implementation in order to avoid requiring `V:
805// Debug`, since no value is ever actually stored in the map.
806impl<K: fmt::Debug, V, S> fmt::Debug for JoinMap<K, V, S> {
807 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
808 // format the task keys and abort handles a little nicer by just
809 // printing the key and task ID pairs, without format the `Key` struct
810 // itself or the `AbortHandle`, which would just format the task's ID
811 // again.
812 struct KeySet<'a, K: fmt::Debug>(&'a HashTable<(K, AbortHandle)>);
813 impl<K: fmt::Debug> fmt::Debug for KeySet<'_, K> {
814 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
815 f.debug_map()
816 .entries(self.0.iter().map(|(key, abort)| (key, abort.id())))
817 .finish()
818 }
819 }
820
821 f.debug_struct("JoinMap")
822 // The `tasks_by_key` map is the only one that contains information
823 // that's really worth formatting for the user, since it contains
824 // the tasks' keys and IDs. The other fields are basically
825 // implementation details.
826 .field("tasks", &KeySet(&self.tasks_by_key))
827 .finish()
828 }
829}
830
831impl<K, V> Default for JoinMap<K, V> {
832 fn default() -> Self {
833 Self::new()
834 }
835}
836
837/// An iterator over the keys of a [`JoinMap`].
838#[derive(Debug, Clone)]
839pub struct JoinMapKeys<'a, K, V> {
840 iter: hashbrown::hash_table::Iter<'a, (K, AbortHandle)>,
841 /// To make it easier to change `JoinMap` in the future, keep V as a generic
842 /// parameter.
843 _value: PhantomData<&'a V>,
844}
845
846impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> {
847 type Item = &'a K;
848
849 fn next(&mut self) -> Option<&'a K> {
850 self.iter.next().map(|(key, _)| key)
851 }
852
853 fn size_hint(&self) -> (usize, Option<usize>) {
854 self.iter.size_hint()
855 }
856}
857
858impl<'a, K, V> ExactSizeIterator for JoinMapKeys<'a, K, V> {
859 fn len(&self) -> usize {
860 self.iter.len()
861 }
862}
863
864impl<'a, K, V> std::iter::FusedIterator for JoinMapKeys<'a, K, V> {}