tokio_util/task/
join_map.rs

1use hashbrown::hash_table::Entry;
2use hashbrown::{HashMap, HashTable};
3use std::borrow::Borrow;
4use std::collections::hash_map::RandomState;
5use std::fmt;
6use std::future::Future;
7use std::hash::{BuildHasher, Hash};
8use std::marker::PhantomData;
9use tokio::runtime::Handle;
10use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet};
11
12/// A collection of tasks spawned on a Tokio runtime, associated with hash map
13/// keys.
14///
15/// This type is very similar to the [`JoinSet`] type in `tokio::task`, with the
16/// addition of a  set of keys associated with each task. These keys allow
17/// [cancelling a task][abort] or [multiple tasks][abort_matching] in the
18/// `JoinMap` based on   their keys, or [test whether a task corresponding to a
19/// given key exists][contains] in the `JoinMap`.
20///
21/// In addition, when tasks in the `JoinMap` complete, they will return the
22/// associated key along with the value returned by the task, if any.
23///
24/// A `JoinMap` can be used to await the completion of some or all of the tasks
25/// in the map. The map is not ordered, and the tasks will be returned in the
26/// order they complete.
27///
28/// All of the tasks must have the same return type `V`.
29///
30/// When the `JoinMap` is dropped, all tasks in the `JoinMap` are immediately aborted.
31///
32/// # Examples
33///
34/// Spawn multiple tasks and wait for them:
35///
36/// ```
37/// use tokio_util::task::JoinMap;
38///
39/// # #[tokio::main(flavor = "current_thread")]
40/// # async fn main() {
41/// let mut map = JoinMap::new();
42///
43/// for i in 0..10 {
44///     // Spawn a task on the `JoinMap` with `i` as its key.
45///     map.spawn(i, async move { /* ... */ });
46/// }
47///
48/// let mut seen = [false; 10];
49///
50/// // When a task completes, `join_next` returns the task's key along
51/// // with its output.
52/// while let Some((key, res)) = map.join_next().await {
53///     seen[key] = true;
54///     assert!(res.is_ok(), "task {} completed successfully!", key);
55/// }
56///
57/// for i in 0..10 {
58///     assert!(seen[i]);
59/// }
60/// # }
61/// ```
62///
63/// Cancel tasks based on their keys:
64///
65/// ```
66/// use tokio_util::task::JoinMap;
67///
68/// # #[tokio::main(flavor = "current_thread")]
69/// # async fn main() {
70/// let mut map = JoinMap::new();
71///
72/// map.spawn("hello world", std::future::ready(1));
73/// map.spawn("goodbye world", std::future::pending());
74///
75/// // Look up the "goodbye world" task in the map and abort it.
76/// let aborted = map.abort("goodbye world");
77///
78/// // `JoinMap::abort` returns `true` if a task existed for the
79/// // provided key.
80/// assert!(aborted);
81///
82/// while let Some((key, res)) = map.join_next().await {
83///     if key == "goodbye world" {
84///         // The aborted task should complete with a cancelled `JoinError`.
85///         assert!(res.unwrap_err().is_cancelled());
86///     } else {
87///         // Other tasks should complete normally.
88///         assert_eq!(res.unwrap(), 1);
89///     }
90/// }
91/// # }
92/// ```
93///
94/// [`JoinSet`]: tokio::task::JoinSet
95/// [abort]: fn@Self::abort
96/// [abort_matching]: fn@Self::abort_matching
97/// [contains]: fn@Self::contains_key
98pub struct JoinMap<K, V, S = RandomState> {
99    /// A map of the [`AbortHandle`]s of the tasks spawned on this `JoinMap`,
100    /// indexed by their keys.
101    tasks_by_key: HashTable<(K, AbortHandle)>,
102
103    /// A map from task IDs to the hash of the key associated with that task.
104    ///
105    /// This map is used to perform reverse lookups of tasks in the
106    /// `tasks_by_key` map based on their task IDs. When a task terminates, the
107    /// ID is provided to us by the `JoinSet`, so we can look up the hash value
108    /// of that task's key, and then remove it from the `tasks_by_key` map using
109    /// the raw hash code, resolving collisions by comparing task IDs.
110    hashes_by_task: HashMap<Id, u64, S>,
111
112    /// The [`JoinSet`] that awaits the completion of tasks spawned on this
113    /// `JoinMap`.
114    tasks: JoinSet<V>,
115}
116
117impl<K, V> JoinMap<K, V> {
118    /// Creates a new empty `JoinMap`.
119    ///
120    /// The `JoinMap` is initially created with a capacity of 0, so it will not
121    /// allocate until a task is first spawned on it.
122    ///
123    /// # Examples
124    ///
125    /// ```
126    /// use tokio_util::task::JoinMap;
127    /// let map: JoinMap<&str, i32> = JoinMap::new();
128    /// ```
129    #[inline]
130    #[must_use]
131    pub fn new() -> Self {
132        Self::with_hasher(RandomState::new())
133    }
134
135    /// Creates an empty `JoinMap` with the specified capacity.
136    ///
137    /// The `JoinMap` will be able to hold at least `capacity` tasks without
138    /// reallocating.
139    ///
140    /// # Examples
141    ///
142    /// ```
143    /// use tokio_util::task::JoinMap;
144    /// let map: JoinMap<&str, i32> = JoinMap::with_capacity(10);
145    /// ```
146    #[inline]
147    #[must_use]
148    pub fn with_capacity(capacity: usize) -> Self {
149        JoinMap::with_capacity_and_hasher(capacity, Default::default())
150    }
151}
152
153impl<K, V, S> JoinMap<K, V, S> {
154    /// Creates an empty `JoinMap` which will use the given hash builder to hash
155    /// keys.
156    ///
157    /// The created map has the default initial capacity.
158    ///
159    /// Warning: `hash_builder` is normally randomly generated, and
160    /// is designed to allow `JoinMap` to be resistant to attacks that
161    /// cause many collisions and very poor performance. Setting it
162    /// manually using this function can expose a DoS attack vector.
163    ///
164    /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
165    /// the `JoinMap` to be useful, see its documentation for details.
166    #[inline]
167    #[must_use]
168    pub fn with_hasher(hash_builder: S) -> Self {
169        Self::with_capacity_and_hasher(0, hash_builder)
170    }
171
172    /// Creates an empty `JoinMap` with the specified capacity, using `hash_builder`
173    /// to hash the keys.
174    ///
175    /// The `JoinMap` will be able to hold at least `capacity` elements without
176    /// reallocating. If `capacity` is 0, the `JoinMap` will not allocate.
177    ///
178    /// Warning: `hash_builder` is normally randomly generated, and
179    /// is designed to allow HashMaps to be resistant to attacks that
180    /// cause many collisions and very poor performance. Setting it
181    /// manually using this function can expose a DoS attack vector.
182    ///
183    /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
184    /// the `JoinMap`to be useful, see its documentation for details.
185    ///
186    /// # Examples
187    ///
188    /// ```
189    /// # #[tokio::main(flavor = "current_thread")]
190    /// # async fn main() {
191    /// use tokio_util::task::JoinMap;
192    /// use std::collections::hash_map::RandomState;
193    ///
194    /// let s = RandomState::new();
195    /// let mut map = JoinMap::with_capacity_and_hasher(10, s);
196    /// map.spawn(1, async move { "hello world!" });
197    /// # }
198    /// ```
199    #[inline]
200    #[must_use]
201    pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
202        Self {
203            tasks_by_key: HashTable::with_capacity(capacity),
204            hashes_by_task: HashMap::with_capacity_and_hasher(capacity, hash_builder),
205            tasks: JoinSet::new(),
206        }
207    }
208
209    /// Returns the number of tasks currently in the `JoinMap`.
210    pub fn len(&self) -> usize {
211        let len = self.tasks_by_key.len();
212        debug_assert_eq!(len, self.hashes_by_task.len());
213        len
214    }
215
216    /// Returns whether the `JoinMap` is empty.
217    pub fn is_empty(&self) -> bool {
218        let empty = self.tasks_by_key.is_empty();
219        debug_assert_eq!(empty, self.hashes_by_task.is_empty());
220        empty
221    }
222
223    /// Returns the number of tasks the map can hold without reallocating.
224    ///
225    /// This number is a lower bound; the `JoinMap` might be able to hold
226    /// more, but is guaranteed to be able to hold at least this many.
227    ///
228    /// # Examples
229    ///
230    /// ```
231    /// use tokio_util::task::JoinMap;
232    ///
233    /// let map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
234    /// assert!(map.capacity() >= 100);
235    /// ```
236    #[inline]
237    pub fn capacity(&self) -> usize {
238        let capacity = self.tasks_by_key.capacity();
239        debug_assert_eq!(capacity, self.hashes_by_task.capacity());
240        capacity
241    }
242}
243
244impl<K, V, S> JoinMap<K, V, S>
245where
246    K: Hash + Eq,
247    V: 'static,
248    S: BuildHasher,
249{
250    /// Spawn the provided task and store it in this `JoinMap` with the provided
251    /// key.
252    ///
253    /// If a task previously existed in the `JoinMap` for this key, that task
254    /// will be cancelled and replaced with the new one. The previous task will
255    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
256    /// *not* return a cancelled [`JoinError`] for that task.
257    ///
258    /// # Panics
259    ///
260    /// This method panics if called outside of a Tokio runtime.
261    ///
262    /// [`join_next`]: Self::join_next
263    #[track_caller]
264    pub fn spawn<F>(&mut self, key: K, task: F)
265    where
266        F: Future<Output = V>,
267        F: Send + 'static,
268        V: Send,
269    {
270        let task = self.tasks.spawn(task);
271        self.insert(key, task)
272    }
273
274    /// Spawn the provided task on the provided runtime and store it in this
275    /// `JoinMap` with the provided key.
276    ///
277    /// If a task previously existed in the `JoinMap` for this key, that task
278    /// will be cancelled and replaced with the new one. The previous task will
279    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
280    /// *not* return a cancelled [`JoinError`] for that task.
281    ///
282    /// [`join_next`]: Self::join_next
283    #[track_caller]
284    pub fn spawn_on<F>(&mut self, key: K, task: F, handle: &Handle)
285    where
286        F: Future<Output = V>,
287        F: Send + 'static,
288        V: Send,
289    {
290        let task = self.tasks.spawn_on(task, handle);
291        self.insert(key, task);
292    }
293
294    /// Spawn the blocking code on the blocking threadpool and store it in this `JoinMap` with the provided
295    /// key.
296    ///
297    /// If a task previously existed in the `JoinMap` for this key, that task
298    /// will be cancelled and replaced with the new one. The previous task will
299    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
300    /// *not* return a cancelled [`JoinError`] for that task.
301    ///
302    /// Note that blocking tasks cannot be cancelled after execution starts.
303    /// Replaced blocking tasks will still run to completion if the task has begun
304    /// to execute when it is replaced. A blocking task which is replaced before
305    /// it has been scheduled on a blocking worker thread will be cancelled.
306    ///
307    /// # Panics
308    ///
309    /// This method panics if called outside of a Tokio runtime.
310    ///
311    /// [`join_next`]: Self::join_next
312    #[track_caller]
313    pub fn spawn_blocking<F>(&mut self, key: K, f: F)
314    where
315        F: FnOnce() -> V,
316        F: Send + 'static,
317        V: Send,
318    {
319        let task = self.tasks.spawn_blocking(f);
320        self.insert(key, task)
321    }
322
323    /// Spawn the blocking code on the blocking threadpool of the provided runtime and store it in this
324    /// `JoinMap` with the provided key.
325    ///
326    /// If a task previously existed in the `JoinMap` for this key, that task
327    /// will be cancelled and replaced with the new one. The previous task will
328    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
329    /// *not* return a cancelled [`JoinError`] for that task.
330    ///
331    /// Note that blocking tasks cannot be cancelled after execution starts.
332    /// Replaced blocking tasks will still run to completion if the task has begun
333    /// to execute when it is replaced. A blocking task which is replaced before
334    /// it has been scheduled on a blocking worker thread will be cancelled.
335    ///
336    /// [`join_next`]: Self::join_next
337    #[track_caller]
338    pub fn spawn_blocking_on<F>(&mut self, key: K, f: F, handle: &Handle)
339    where
340        F: FnOnce() -> V,
341        F: Send + 'static,
342        V: Send,
343    {
344        let task = self.tasks.spawn_blocking_on(f, handle);
345        self.insert(key, task);
346    }
347
348    /// Spawn the provided task on the current [`LocalSet`] or [`LocalRuntime`]
349    /// and store it in this `JoinMap` with the provided key.
350    ///
351    /// If a task previously existed in the `JoinMap` for this key, that task
352    /// will be cancelled and replaced with the new one. The previous task will
353    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
354    /// *not* return a cancelled [`JoinError`] for that task.
355    ///
356    /// # Panics
357    ///
358    /// This method panics if it is called outside of a `LocalSet` or `LocalRuntime`.
359    ///
360    /// [`LocalSet`]: tokio::task::LocalSet
361    /// [`LocalRuntime`]: tokio::runtime::LocalRuntime
362    /// [`join_next`]: Self::join_next
363    #[track_caller]
364    pub fn spawn_local<F>(&mut self, key: K, task: F)
365    where
366        F: Future<Output = V>,
367        F: 'static,
368    {
369        let task = self.tasks.spawn_local(task);
370        self.insert(key, task);
371    }
372
373    /// Spawn the provided task on the provided [`LocalSet`] and store it in
374    /// this `JoinMap` with the provided key.
375    ///
376    /// If a task previously existed in the `JoinMap` for this key, that task
377    /// will be cancelled and replaced with the new one. The previous task will
378    /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
379    /// *not* return a cancelled [`JoinError`] for that task.
380    ///
381    /// [`LocalSet`]: tokio::task::LocalSet
382    /// [`join_next`]: Self::join_next
383    #[track_caller]
384    pub fn spawn_local_on<F>(&mut self, key: K, task: F, local_set: &LocalSet)
385    where
386        F: Future<Output = V>,
387        F: 'static,
388    {
389        let task = self.tasks.spawn_local_on(task, local_set);
390        self.insert(key, task)
391    }
392
393    fn insert(&mut self, mut key: K, mut abort: AbortHandle) {
394        let hash_builder = self.hashes_by_task.hasher();
395        let hash = hash_builder.hash_one(&key);
396        let id = abort.id();
397
398        // Insert the new key into the map of tasks by keys.
399        let entry =
400            self.tasks_by_key
401                .entry(hash, |(k, _)| *k == key, |(k, _)| hash_builder.hash_one(k));
402        match entry {
403            Entry::Occupied(occ) => {
404                // There was a previous task spawned with the same key! Cancel
405                // that task, and remove its ID from the map of hashes by task IDs.
406                (key, abort) = std::mem::replace(occ.into_mut(), (key, abort));
407
408                // Remove the old task ID.
409                let _prev_hash = self.hashes_by_task.remove(&abort.id());
410                debug_assert_eq!(Some(hash), _prev_hash);
411
412                // Associate the key's hash with the new task's ID, for looking up tasks by ID.
413                let _prev = self.hashes_by_task.insert(id, hash);
414                debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
415
416                // Note: it's important to drop `key` and abort the task here.
417                // This defends against any panics during drop handling for causing inconsistent state.
418                abort.abort();
419                drop(key);
420            }
421            Entry::Vacant(vac) => {
422                vac.insert((key, abort));
423
424                // Associate the key's hash with this task's ID, for looking up tasks by ID.
425                let _prev = self.hashes_by_task.insert(id, hash);
426                debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
427            }
428        };
429    }
430
431    /// Waits until one of the tasks in the map completes and returns its
432    /// output, along with the key corresponding to that task.
433    ///
434    /// Returns `None` if the map is empty.
435    ///
436    /// # Cancel Safety
437    ///
438    /// This method is cancel safe. If `join_next` is used as the event in a [`tokio::select!`]
439    /// statement and some other branch completes first, it is guaranteed that no tasks were
440    /// removed from this `JoinMap`.
441    ///
442    /// # Returns
443    ///
444    /// This function returns:
445    ///
446    ///  * `Some((key, Ok(value)))` if one of the tasks in this `JoinMap` has
447    ///    completed. The `value` is the return value of that ask, and `key` is
448    ///    the key associated with the task.
449    ///  * `Some((key, Err(err))` if one of the tasks in this `JoinMap` has
450    ///    panicked or been aborted. `key` is the key associated  with the task
451    ///    that panicked or was aborted.
452    ///  * `None` if the `JoinMap` is empty.
453    ///
454    /// [`tokio::select!`]: tokio::select
455    pub async fn join_next(&mut self) -> Option<(K, Result<V, JoinError>)> {
456        loop {
457            let (res, id) = match self.tasks.join_next_with_id().await {
458                Some(Ok((id, output))) => (Ok(output), id),
459                Some(Err(e)) => {
460                    let id = e.id();
461                    (Err(e), id)
462                }
463                None => return None,
464            };
465            if let Some(key) = self.remove_by_id(id) {
466                break Some((key, res));
467            }
468        }
469    }
470
471    /// Aborts all tasks and waits for them to finish shutting down.
472    ///
473    /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
474    /// a loop until it returns `None`.
475    ///
476    /// This method ignores any panics in the tasks shutting down. When this call returns, the
477    /// `JoinMap` will be empty.
478    ///
479    /// [`abort_all`]: fn@Self::abort_all
480    /// [`join_next`]: fn@Self::join_next
481    pub async fn shutdown(&mut self) {
482        self.abort_all();
483        while self.join_next().await.is_some() {}
484    }
485
486    /// Abort the task corresponding to the provided `key`.
487    ///
488    /// If this `JoinMap` contains a task corresponding to `key`, this method
489    /// will abort that task and return `true`. Otherwise, if no task exists for
490    /// `key`, this method returns `false`.
491    ///
492    /// # Examples
493    ///
494    /// Aborting a task by key:
495    ///
496    /// ```
497    /// use tokio_util::task::JoinMap;
498    ///
499    /// # #[tokio::main(flavor = "current_thread")]
500    /// # async fn main() {
501    /// let mut map = JoinMap::new();
502    ///
503    /// map.spawn("hello world", std::future::ready(1));
504    /// map.spawn("goodbye world", std::future::pending());
505    ///
506    /// // Look up the "goodbye world" task in the map and abort it.
507    /// map.abort("goodbye world");
508    ///
509    /// while let Some((key, res)) = map.join_next().await {
510    ///     if key == "goodbye world" {
511    ///         // The aborted task should complete with a cancelled `JoinError`.
512    ///         assert!(res.unwrap_err().is_cancelled());
513    ///     } else {
514    ///         // Other tasks should complete normally.
515    ///         assert_eq!(res.unwrap(), 1);
516    ///     }
517    /// }
518    /// # }
519    /// ```
520    ///
521    /// `abort` returns `true` if a task was aborted:
522    /// ```
523    /// use tokio_util::task::JoinMap;
524    ///
525    /// # #[tokio::main(flavor = "current_thread")]
526    /// # async fn main() {
527    /// let mut map = JoinMap::new();
528    ///
529    /// map.spawn("hello world", async move { /* ... */ });
530    /// map.spawn("goodbye world", async move { /* ... */});
531    ///
532    /// // A task for the key "goodbye world" should exist in the map:
533    /// assert!(map.abort("goodbye world"));
534    ///
535    /// // Aborting a key that does not exist will return `false`:
536    /// assert!(!map.abort("goodbye universe"));
537    /// # }
538    /// ```
539    pub fn abort<Q>(&mut self, key: &Q) -> bool
540    where
541        Q: ?Sized + Hash + Eq,
542        K: Borrow<Q>,
543    {
544        match self.get_by_key(key) {
545            Some((_, handle)) => {
546                handle.abort();
547                true
548            }
549            None => false,
550        }
551    }
552
553    /// Aborts all tasks with keys matching `predicate`.
554    ///
555    /// `predicate` is a function called with a reference to each key in the
556    /// map. If it returns `true` for a given key, the corresponding task will
557    /// be cancelled.
558    ///
559    /// # Examples
560    /// ```
561    /// use tokio_util::task::JoinMap;
562    ///
563    /// # // use the current thread rt so that spawned tasks don't
564    /// # // complete in the background before they can be aborted.
565    /// # #[tokio::main(flavor = "current_thread")]
566    /// # async fn main() {
567    /// let mut map = JoinMap::new();
568    ///
569    /// map.spawn("hello world", async move {
570    ///     // ...
571    ///     # tokio::task::yield_now().await; // don't complete immediately, get aborted!
572    /// });
573    /// map.spawn("goodbye world", async move {
574    ///     // ...
575    ///     # tokio::task::yield_now().await; // don't complete immediately, get aborted!
576    /// });
577    /// map.spawn("hello san francisco", async move {
578    ///     // ...
579    ///     # tokio::task::yield_now().await; // don't complete immediately, get aborted!
580    /// });
581    /// map.spawn("goodbye universe", async move {
582    ///     // ...
583    ///     # tokio::task::yield_now().await; // don't complete immediately, get aborted!
584    /// });
585    ///
586    /// // Abort all tasks whose keys begin with "goodbye"
587    /// map.abort_matching(|key| key.starts_with("goodbye"));
588    ///
589    /// let mut seen = 0;
590    /// while let Some((key, res)) = map.join_next().await {
591    ///     seen += 1;
592    ///     if key.starts_with("goodbye") {
593    ///         // The aborted task should complete with a cancelled `JoinError`.
594    ///         assert!(res.unwrap_err().is_cancelled());
595    ///     } else {
596    ///         // Other tasks should complete normally.
597    ///         assert!(key.starts_with("hello"));
598    ///         assert!(res.is_ok());
599    ///     }
600    /// }
601    ///
602    /// // All spawned tasks should have completed.
603    /// assert_eq!(seen, 4);
604    /// # }
605    /// ```
606    pub fn abort_matching(&mut self, mut predicate: impl FnMut(&K) -> bool) {
607        // Note: this method iterates over the tasks and keys *without* removing
608        // any entries, so that the keys from aborted tasks can still be
609        // returned when calling `join_next` in the future.
610        for (key, task) in &self.tasks_by_key {
611            if predicate(key) {
612                task.abort();
613            }
614        }
615    }
616
617    /// Returns an iterator visiting all keys in this `JoinMap` in arbitrary order.
618    ///
619    /// If a task has completed, but its output hasn't yet been consumed by a
620    /// call to [`join_next`], this method will still return its key.
621    ///
622    /// [`join_next`]: fn@Self::join_next
623    pub fn keys(&self) -> JoinMapKeys<'_, K, V> {
624        JoinMapKeys {
625            iter: self.tasks_by_key.iter(),
626            _value: PhantomData,
627        }
628    }
629
630    /// Returns `true` if this `JoinMap` contains a task for the provided key.
631    ///
632    /// If the task has completed, but its output hasn't yet been consumed by a
633    /// call to [`join_next`], this method will still return `true`.
634    ///
635    /// [`join_next`]: fn@Self::join_next
636    pub fn contains_key<Q>(&self, key: &Q) -> bool
637    where
638        Q: ?Sized + Hash + Eq,
639        K: Borrow<Q>,
640    {
641        self.get_by_key(key).is_some()
642    }
643
644    /// Returns `true` if this `JoinMap` contains a task with the provided
645    /// [task ID].
646    ///
647    /// If the task has completed, but its output hasn't yet been consumed by a
648    /// call to [`join_next`], this method will still return `true`.
649    ///
650    /// [`join_next`]: fn@Self::join_next
651    /// [task ID]: tokio::task::Id
652    pub fn contains_task(&self, task: &Id) -> bool {
653        self.hashes_by_task.contains_key(task)
654    }
655
656    /// Reserves capacity for at least `additional` more tasks to be spawned
657    /// on this `JoinMap` without reallocating for the map of task keys. The
658    /// collection may reserve more space to avoid frequent reallocations.
659    ///
660    /// Note that spawning a task will still cause an allocation for the task
661    /// itself.
662    ///
663    /// # Panics
664    ///
665    /// Panics if the new allocation size overflows [`usize`].
666    ///
667    /// # Examples
668    ///
669    /// ```
670    /// use tokio_util::task::JoinMap;
671    ///
672    /// let mut map: JoinMap<&str, i32> = JoinMap::new();
673    /// map.reserve(10);
674    /// ```
675    #[inline]
676    pub fn reserve(&mut self, additional: usize) {
677        self.tasks_by_key.reserve(additional, |(k, _)| {
678            self.hashes_by_task.hasher().hash_one(k)
679        });
680        self.hashes_by_task.reserve(additional);
681    }
682
683    /// Shrinks the capacity of the `JoinMap` as much as possible. It will drop
684    /// down as much as possible while maintaining the internal rules
685    /// and possibly leaving some space in accordance with the resize policy.
686    ///
687    /// # Examples
688    ///
689    /// ```
690    /// # #[tokio::main(flavor = "current_thread")]
691    /// # async fn main() {
692    /// use tokio_util::task::JoinMap;
693    ///
694    /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
695    /// map.spawn(1, async move { 2 });
696    /// map.spawn(3, async move { 4 });
697    /// assert!(map.capacity() >= 100);
698    /// map.shrink_to_fit();
699    /// assert!(map.capacity() >= 2);
700    /// # }
701    /// ```
702    #[inline]
703    pub fn shrink_to_fit(&mut self) {
704        self.hashes_by_task.shrink_to_fit();
705        self.tasks_by_key
706            .shrink_to_fit(|(k, _)| self.hashes_by_task.hasher().hash_one(k));
707    }
708
709    /// Shrinks the capacity of the map with a lower limit. It will drop
710    /// down no lower than the supplied limit while maintaining the internal rules
711    /// and possibly leaving some space in accordance with the resize policy.
712    ///
713    /// If the current capacity is less than the lower limit, this is a no-op.
714    ///
715    /// # Examples
716    ///
717    /// ```
718    /// # #[tokio::main(flavor = "current_thread")]
719    /// # async fn main() {
720    /// use tokio_util::task::JoinMap;
721    ///
722    /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
723    /// map.spawn(1, async move { 2 });
724    /// map.spawn(3, async move { 4 });
725    /// assert!(map.capacity() >= 100);
726    /// map.shrink_to(10);
727    /// assert!(map.capacity() >= 10);
728    /// map.shrink_to(0);
729    /// assert!(map.capacity() >= 2);
730    /// # }
731    /// ```
732    #[inline]
733    pub fn shrink_to(&mut self, min_capacity: usize) {
734        self.hashes_by_task.shrink_to(min_capacity);
735        self.tasks_by_key.shrink_to(min_capacity, |(k, _)| {
736            self.hashes_by_task.hasher().hash_one(k)
737        })
738    }
739
740    /// Look up a task in the map by its key, returning the key and abort handle.
741    fn get_by_key<'map, Q>(&'map self, key: &Q) -> Option<&'map (K, AbortHandle)>
742    where
743        Q: ?Sized + Hash + Eq,
744        K: Borrow<Q>,
745    {
746        let hash = self.hashes_by_task.hasher().hash_one(key);
747        self.tasks_by_key.find(hash, |(k, _)| k.borrow() == key)
748    }
749
750    /// Remove a task from the map by ID, returning the key for that task.
751    fn remove_by_id(&mut self, id: Id) -> Option<K> {
752        // Get the hash for the given ID.
753        let hash = self.hashes_by_task.remove(&id)?;
754
755        // Remove the entry for that hash.
756        let entry = self
757            .tasks_by_key
758            .find_entry(hash, |(_, abort)| abort.id() == id);
759        let (key, _) = match entry {
760            Ok(entry) => entry.remove().0,
761            _ => return None,
762        };
763        Some(key)
764    }
765}
766
767impl<K, V, S> JoinMap<K, V, S>
768where
769    V: 'static,
770{
771    /// Aborts all tasks on this `JoinMap`.
772    ///
773    /// This does not remove the tasks from the `JoinMap`. To wait for the tasks to complete
774    /// cancellation, you should call `join_next` in a loop until the `JoinMap` is empty.
775    pub fn abort_all(&mut self) {
776        self.tasks.abort_all()
777    }
778
779    /// Removes all tasks from this `JoinMap` without aborting them.
780    ///
781    /// The tasks removed by this call will continue to run in the background even if the `JoinMap`
782    /// is dropped. They may still be aborted by key.
783    pub fn detach_all(&mut self) {
784        self.tasks.detach_all();
785        self.tasks_by_key.clear();
786        self.hashes_by_task.clear();
787    }
788}
789
790// Hand-written `fmt::Debug` implementation in order to avoid requiring `V:
791// Debug`, since no value is ever actually stored in the map.
792impl<K: fmt::Debug, V, S> fmt::Debug for JoinMap<K, V, S> {
793    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
794        // format the task keys and abort handles a little nicer by just
795        // printing the key and task ID pairs, without format the `Key` struct
796        // itself or the `AbortHandle`, which would just format the task's ID
797        // again.
798        struct KeySet<'a, K: fmt::Debug>(&'a HashTable<(K, AbortHandle)>);
799        impl<K: fmt::Debug> fmt::Debug for KeySet<'_, K> {
800            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
801                f.debug_map()
802                    .entries(self.0.iter().map(|(key, abort)| (key, abort.id())))
803                    .finish()
804            }
805        }
806
807        f.debug_struct("JoinMap")
808            // The `tasks_by_key` map is the only one that contains information
809            // that's really worth formatting for the user, since it contains
810            // the tasks' keys and IDs. The other fields are basically
811            // implementation details.
812            .field("tasks", &KeySet(&self.tasks_by_key))
813            .finish()
814    }
815}
816
817impl<K, V> Default for JoinMap<K, V> {
818    fn default() -> Self {
819        Self::new()
820    }
821}
822
823/// An iterator over the keys of a [`JoinMap`].
824#[derive(Debug, Clone)]
825pub struct JoinMapKeys<'a, K, V> {
826    iter: hashbrown::hash_table::Iter<'a, (K, AbortHandle)>,
827    /// To make it easier to change `JoinMap` in the future, keep V as a generic
828    /// parameter.
829    _value: PhantomData<&'a V>,
830}
831
832impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> {
833    type Item = &'a K;
834
835    fn next(&mut self) -> Option<&'a K> {
836        self.iter.next().map(|(key, _)| key)
837    }
838
839    fn size_hint(&self) -> (usize, Option<usize>) {
840        self.iter.size_hint()
841    }
842}
843
844impl<'a, K, V> ExactSizeIterator for JoinMapKeys<'a, K, V> {
845    fn len(&self) -> usize {
846        self.iter.len()
847    }
848}
849
850impl<'a, K, V> std::iter::FusedIterator for JoinMapKeys<'a, K, V> {}