tokio_util/task/join_map.rs
1use hashbrown::hash_table::Entry;
2use hashbrown::{HashMap, HashTable};
3use std::borrow::Borrow;
4use std::collections::hash_map::RandomState;
5use std::fmt;
6use std::future::Future;
7use std::hash::{BuildHasher, Hash};
8use std::marker::PhantomData;
9use tokio::runtime::Handle;
10use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet};
11
12/// A collection of tasks spawned on a Tokio runtime, associated with hash map
13/// keys.
14///
15/// This type is very similar to the [`JoinSet`] type in `tokio::task`, with the
16/// addition of a set of keys associated with each task. These keys allow
17/// [cancelling a task][abort] or [multiple tasks][abort_matching] in the
18/// `JoinMap` based on their keys, or [test whether a task corresponding to a
19/// given key exists][contains] in the `JoinMap`.
20///
21/// In addition, when tasks in the `JoinMap` complete, they will return the
22/// associated key along with the value returned by the task, if any.
23///
24/// A `JoinMap` can be used to await the completion of some or all of the tasks
25/// in the map. The map is not ordered, and the tasks will be returned in the
26/// order they complete.
27///
28/// All of the tasks must have the same return type `V`.
29///
30/// When the `JoinMap` is dropped, all tasks in the `JoinMap` are immediately aborted.
31///
32/// # Examples
33///
34/// Spawn multiple tasks and wait for them:
35///
36/// ```
37/// use tokio_util::task::JoinMap;
38///
39/// # #[tokio::main(flavor = "current_thread")]
40/// # async fn main() {
41/// let mut map = JoinMap::new();
42///
43/// for i in 0..10 {
44/// // Spawn a task on the `JoinMap` with `i` as its key.
45/// map.spawn(i, async move { /* ... */ });
46/// }
47///
48/// let mut seen = [false; 10];
49///
50/// // When a task completes, `join_next` returns the task's key along
51/// // with its output.
52/// while let Some((key, res)) = map.join_next().await {
53/// seen[key] = true;
54/// assert!(res.is_ok(), "task {} completed successfully!", key);
55/// }
56///
57/// for i in 0..10 {
58/// assert!(seen[i]);
59/// }
60/// # }
61/// ```
62///
63/// Cancel tasks based on their keys:
64///
65/// ```
66/// use tokio_util::task::JoinMap;
67///
68/// # #[tokio::main(flavor = "current_thread")]
69/// # async fn main() {
70/// let mut map = JoinMap::new();
71///
72/// map.spawn("hello world", std::future::ready(1));
73/// map.spawn("goodbye world", std::future::pending());
74///
75/// // Look up the "goodbye world" task in the map and abort it.
76/// let aborted = map.abort("goodbye world");
77///
78/// // `JoinMap::abort` returns `true` if a task existed for the
79/// // provided key.
80/// assert!(aborted);
81///
82/// while let Some((key, res)) = map.join_next().await {
83/// if key == "goodbye world" {
84/// // The aborted task should complete with a cancelled `JoinError`.
85/// assert!(res.unwrap_err().is_cancelled());
86/// } else {
87/// // Other tasks should complete normally.
88/// assert_eq!(res.unwrap(), 1);
89/// }
90/// }
91/// # }
92/// ```
93///
94/// [`JoinSet`]: tokio::task::JoinSet
95/// [abort]: fn@Self::abort
96/// [abort_matching]: fn@Self::abort_matching
97/// [contains]: fn@Self::contains_key
98pub struct JoinMap<K, V, S = RandomState> {
99 /// A map of the [`AbortHandle`]s of the tasks spawned on this `JoinMap`,
100 /// indexed by their keys.
101 tasks_by_key: HashTable<(K, AbortHandle)>,
102
103 /// A map from task IDs to the hash of the key associated with that task.
104 ///
105 /// This map is used to perform reverse lookups of tasks in the
106 /// `tasks_by_key` map based on their task IDs. When a task terminates, the
107 /// ID is provided to us by the `JoinSet`, so we can look up the hash value
108 /// of that task's key, and then remove it from the `tasks_by_key` map using
109 /// the raw hash code, resolving collisions by comparing task IDs.
110 hashes_by_task: HashMap<Id, u64, S>,
111
112 /// The [`JoinSet`] that awaits the completion of tasks spawned on this
113 /// `JoinMap`.
114 tasks: JoinSet<V>,
115}
116
117impl<K, V> JoinMap<K, V> {
118 /// Creates a new empty `JoinMap`.
119 ///
120 /// The `JoinMap` is initially created with a capacity of 0, so it will not
121 /// allocate until a task is first spawned on it.
122 ///
123 /// # Examples
124 ///
125 /// ```
126 /// use tokio_util::task::JoinMap;
127 /// let map: JoinMap<&str, i32> = JoinMap::new();
128 /// ```
129 #[inline]
130 #[must_use]
131 pub fn new() -> Self {
132 Self::with_hasher(RandomState::new())
133 }
134
135 /// Creates an empty `JoinMap` with the specified capacity.
136 ///
137 /// The `JoinMap` will be able to hold at least `capacity` tasks without
138 /// reallocating.
139 ///
140 /// # Examples
141 ///
142 /// ```
143 /// use tokio_util::task::JoinMap;
144 /// let map: JoinMap<&str, i32> = JoinMap::with_capacity(10);
145 /// ```
146 #[inline]
147 #[must_use]
148 pub fn with_capacity(capacity: usize) -> Self {
149 JoinMap::with_capacity_and_hasher(capacity, Default::default())
150 }
151}
152
153impl<K, V, S> JoinMap<K, V, S> {
154 /// Creates an empty `JoinMap` which will use the given hash builder to hash
155 /// keys.
156 ///
157 /// The created map has the default initial capacity.
158 ///
159 /// Warning: `hash_builder` is normally randomly generated, and
160 /// is designed to allow `JoinMap` to be resistant to attacks that
161 /// cause many collisions and very poor performance. Setting it
162 /// manually using this function can expose a DoS attack vector.
163 ///
164 /// The `hash_builder` passed should implement the [`BuildHasher`] trait for
165 /// the `JoinMap` to be useful, see its documentation for details.
166 #[inline]
167 #[must_use]
168 pub fn with_hasher(hash_builder: S) -> Self {
169 Self::with_capacity_and_hasher(0, hash_builder)
170 }
171
172 /// Creates an empty `JoinMap` with the specified capacity, using `hash_builder`
173 /// to hash the keys.
174 ///
175 /// The `JoinMap` will be able to hold at least `capacity` elements without
176 /// reallocating. If `capacity` is 0, the `JoinMap` will not allocate.
177 ///
/// Warning: `hash_builder` is normally randomly generated, and
/// is designed to allow `JoinMap` to be resistant to attacks that
/// cause many collisions and very poor performance. Setting it
/// manually using this function can expose a DoS attack vector.
182 ///
/// The `hash_builder` passed should implement the [`BuildHasher`] trait for
/// the `JoinMap` to be useful, see its documentation for details.
185 ///
186 /// # Examples
187 ///
188 /// ```
189 /// # #[tokio::main(flavor = "current_thread")]
190 /// # async fn main() {
191 /// use tokio_util::task::JoinMap;
192 /// use std::collections::hash_map::RandomState;
193 ///
194 /// let s = RandomState::new();
195 /// let mut map = JoinMap::with_capacity_and_hasher(10, s);
196 /// map.spawn(1, async move { "hello world!" });
197 /// # }
198 /// ```
199 #[inline]
200 #[must_use]
201 pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
202 Self {
203 tasks_by_key: HashTable::with_capacity(capacity),
204 hashes_by_task: HashMap::with_capacity_and_hasher(capacity, hash_builder),
205 tasks: JoinSet::new(),
206 }
207 }
208
209 /// Returns the number of tasks currently in the `JoinMap`.
210 pub fn len(&self) -> usize {
211 let len = self.tasks_by_key.len();
212 debug_assert_eq!(len, self.hashes_by_task.len());
213 len
214 }
215
216 /// Returns whether the `JoinMap` is empty.
217 pub fn is_empty(&self) -> bool {
218 let empty = self.tasks_by_key.is_empty();
219 debug_assert_eq!(empty, self.hashes_by_task.is_empty());
220 empty
221 }
222
223 /// Returns the number of tasks the map can hold without reallocating.
224 ///
225 /// This number is a lower bound; the `JoinMap` might be able to hold
226 /// more, but is guaranteed to be able to hold at least this many.
227 ///
228 /// # Examples
229 ///
230 /// ```
231 /// use tokio_util::task::JoinMap;
232 ///
233 /// let map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
234 /// assert!(map.capacity() >= 100);
235 /// ```
236 #[inline]
237 pub fn capacity(&self) -> usize {
238 let capacity = self.tasks_by_key.capacity();
239 debug_assert_eq!(capacity, self.hashes_by_task.capacity());
240 capacity
241 }
242}
243
244impl<K, V, S> JoinMap<K, V, S>
245where
246 K: Hash + Eq,
247 V: 'static,
248 S: BuildHasher,
249{
250 /// Spawn the provided task and store it in this `JoinMap` with the provided
251 /// key.
252 ///
253 /// If a task previously existed in the `JoinMap` for this key, that task
254 /// will be cancelled and replaced with the new one. The previous task will
255 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
256 /// *not* return a cancelled [`JoinError`] for that task.
257 ///
258 /// # Panics
259 ///
260 /// This method panics if called outside of a Tokio runtime.
261 ///
262 /// [`join_next`]: Self::join_next
263 #[track_caller]
264 pub fn spawn<F>(&mut self, key: K, task: F)
265 where
266 F: Future<Output = V>,
267 F: Send + 'static,
268 V: Send,
269 {
270 let task = self.tasks.spawn(task);
271 self.insert(key, task)
272 }
273
274 /// Spawn the provided task on the provided runtime and store it in this
275 /// `JoinMap` with the provided key.
276 ///
277 /// If a task previously existed in the `JoinMap` for this key, that task
278 /// will be cancelled and replaced with the new one. The previous task will
279 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
280 /// *not* return a cancelled [`JoinError`] for that task.
281 ///
282 /// [`join_next`]: Self::join_next
283 #[track_caller]
284 pub fn spawn_on<F>(&mut self, key: K, task: F, handle: &Handle)
285 where
286 F: Future<Output = V>,
287 F: Send + 'static,
288 V: Send,
289 {
290 let task = self.tasks.spawn_on(task, handle);
291 self.insert(key, task);
292 }
293
294 /// Spawn the blocking code on the blocking threadpool and store it in this `JoinMap` with the provided
295 /// key.
296 ///
297 /// If a task previously existed in the `JoinMap` for this key, that task
298 /// will be cancelled and replaced with the new one. The previous task will
299 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
300 /// *not* return a cancelled [`JoinError`] for that task.
301 ///
302 /// Note that blocking tasks cannot be cancelled after execution starts.
303 /// Replaced blocking tasks will still run to completion if the task has begun
304 /// to execute when it is replaced. A blocking task which is replaced before
305 /// it has been scheduled on a blocking worker thread will be cancelled.
306 ///
307 /// # Panics
308 ///
309 /// This method panics if called outside of a Tokio runtime.
310 ///
311 /// [`join_next`]: Self::join_next
312 #[track_caller]
313 pub fn spawn_blocking<F>(&mut self, key: K, f: F)
314 where
315 F: FnOnce() -> V,
316 F: Send + 'static,
317 V: Send,
318 {
319 let task = self.tasks.spawn_blocking(f);
320 self.insert(key, task)
321 }
322
323 /// Spawn the blocking code on the blocking threadpool of the provided runtime and store it in this
324 /// `JoinMap` with the provided key.
325 ///
326 /// If a task previously existed in the `JoinMap` for this key, that task
327 /// will be cancelled and replaced with the new one. The previous task will
328 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
329 /// *not* return a cancelled [`JoinError`] for that task.
330 ///
331 /// Note that blocking tasks cannot be cancelled after execution starts.
332 /// Replaced blocking tasks will still run to completion if the task has begun
333 /// to execute when it is replaced. A blocking task which is replaced before
334 /// it has been scheduled on a blocking worker thread will be cancelled.
335 ///
336 /// [`join_next`]: Self::join_next
337 #[track_caller]
338 pub fn spawn_blocking_on<F>(&mut self, key: K, f: F, handle: &Handle)
339 where
340 F: FnOnce() -> V,
341 F: Send + 'static,
342 V: Send,
343 {
344 let task = self.tasks.spawn_blocking_on(f, handle);
345 self.insert(key, task);
346 }
347
348 /// Spawn the provided task on the current [`LocalSet`] or [`LocalRuntime`]
349 /// and store it in this `JoinMap` with the provided key.
350 ///
351 /// If a task previously existed in the `JoinMap` for this key, that task
352 /// will be cancelled and replaced with the new one. The previous task will
353 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
354 /// *not* return a cancelled [`JoinError`] for that task.
355 ///
356 /// # Panics
357 ///
358 /// This method panics if it is called outside of a `LocalSet` or `LocalRuntime`.
359 ///
360 /// [`LocalSet`]: tokio::task::LocalSet
361 /// [`LocalRuntime`]: tokio::runtime::LocalRuntime
362 /// [`join_next`]: Self::join_next
363 #[track_caller]
364 pub fn spawn_local<F>(&mut self, key: K, task: F)
365 where
366 F: Future<Output = V>,
367 F: 'static,
368 {
369 let task = self.tasks.spawn_local(task);
370 self.insert(key, task);
371 }
372
373 /// Spawn the provided task on the provided [`LocalSet`] and store it in
374 /// this `JoinMap` with the provided key.
375 ///
376 /// If a task previously existed in the `JoinMap` for this key, that task
377 /// will be cancelled and replaced with the new one. The previous task will
378 /// be removed from the `JoinMap`; a subsequent call to [`join_next`] will
379 /// *not* return a cancelled [`JoinError`] for that task.
380 ///
381 /// [`LocalSet`]: tokio::task::LocalSet
382 /// [`join_next`]: Self::join_next
383 #[track_caller]
384 pub fn spawn_local_on<F>(&mut self, key: K, task: F, local_set: &LocalSet)
385 where
386 F: Future<Output = V>,
387 F: 'static,
388 {
389 let task = self.tasks.spawn_local_on(task, local_set);
390 self.insert(key, task)
391 }
392
393 fn insert(&mut self, mut key: K, mut abort: AbortHandle) {
394 let hash_builder = self.hashes_by_task.hasher();
395 let hash = hash_builder.hash_one(&key);
396 let id = abort.id();
397
398 // Insert the new key into the map of tasks by keys.
399 let entry =
400 self.tasks_by_key
401 .entry(hash, |(k, _)| *k == key, |(k, _)| hash_builder.hash_one(k));
402 match entry {
403 Entry::Occupied(occ) => {
404 // There was a previous task spawned with the same key! Cancel
405 // that task, and remove its ID from the map of hashes by task IDs.
406 (key, abort) = std::mem::replace(occ.into_mut(), (key, abort));
407
408 // Remove the old task ID.
409 let _prev_hash = self.hashes_by_task.remove(&abort.id());
410 debug_assert_eq!(Some(hash), _prev_hash);
411
412 // Associate the key's hash with the new task's ID, for looking up tasks by ID.
413 let _prev = self.hashes_by_task.insert(id, hash);
414 debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
415
416 // Note: it's important to drop `key` and abort the task here.
417 // This defends against any panics during drop handling for causing inconsistent state.
418 abort.abort();
419 drop(key);
420 }
421 Entry::Vacant(vac) => {
422 vac.insert((key, abort));
423
424 // Associate the key's hash with this task's ID, for looking up tasks by ID.
425 let _prev = self.hashes_by_task.insert(id, hash);
426 debug_assert!(_prev.is_none(), "no prior task should have had the same ID");
427 }
428 };
429 }
430
431 /// Waits until one of the tasks in the map completes and returns its
432 /// output, along with the key corresponding to that task.
433 ///
434 /// Returns `None` if the map is empty.
435 ///
436 /// # Cancel Safety
437 ///
438 /// This method is cancel safe. If `join_next` is used as the event in a [`tokio::select!`]
439 /// statement and some other branch completes first, it is guaranteed that no tasks were
440 /// removed from this `JoinMap`.
441 ///
442 /// # Returns
443 ///
444 /// This function returns:
445 ///
/// * `Some((key, Ok(value)))` if one of the tasks in this `JoinMap` has
///   completed. The `value` is the return value of that task, and `key` is
///   the key associated with the task.
/// * `Some((key, Err(err)))` if one of the tasks in this `JoinMap` has
///   panicked or been aborted. `key` is the key associated with the task
///   that panicked or was aborted.
452 /// * `None` if the `JoinMap` is empty.
453 ///
454 /// [`tokio::select!`]: tokio::select
455 pub async fn join_next(&mut self) -> Option<(K, Result<V, JoinError>)> {
456 loop {
457 let (res, id) = match self.tasks.join_next_with_id().await {
458 Some(Ok((id, output))) => (Ok(output), id),
459 Some(Err(e)) => {
460 let id = e.id();
461 (Err(e), id)
462 }
463 None => return None,
464 };
465 if let Some(key) = self.remove_by_id(id) {
466 break Some((key, res));
467 }
468 }
469 }
470
471 /// Aborts all tasks and waits for them to finish shutting down.
472 ///
473 /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_next`] in
474 /// a loop until it returns `None`.
475 ///
476 /// This method ignores any panics in the tasks shutting down. When this call returns, the
477 /// `JoinMap` will be empty.
478 ///
479 /// [`abort_all`]: fn@Self::abort_all
480 /// [`join_next`]: fn@Self::join_next
481 pub async fn shutdown(&mut self) {
482 self.abort_all();
483 while self.join_next().await.is_some() {}
484 }
485
486 /// Abort the task corresponding to the provided `key`.
487 ///
488 /// If this `JoinMap` contains a task corresponding to `key`, this method
489 /// will abort that task and return `true`. Otherwise, if no task exists for
490 /// `key`, this method returns `false`.
491 ///
492 /// # Examples
493 ///
494 /// Aborting a task by key:
495 ///
496 /// ```
497 /// use tokio_util::task::JoinMap;
498 ///
499 /// # #[tokio::main(flavor = "current_thread")]
500 /// # async fn main() {
501 /// let mut map = JoinMap::new();
502 ///
503 /// map.spawn("hello world", std::future::ready(1));
504 /// map.spawn("goodbye world", std::future::pending());
505 ///
506 /// // Look up the "goodbye world" task in the map and abort it.
507 /// map.abort("goodbye world");
508 ///
509 /// while let Some((key, res)) = map.join_next().await {
510 /// if key == "goodbye world" {
511 /// // The aborted task should complete with a cancelled `JoinError`.
512 /// assert!(res.unwrap_err().is_cancelled());
513 /// } else {
514 /// // Other tasks should complete normally.
515 /// assert_eq!(res.unwrap(), 1);
516 /// }
517 /// }
518 /// # }
519 /// ```
520 ///
521 /// `abort` returns `true` if a task was aborted:
522 /// ```
523 /// use tokio_util::task::JoinMap;
524 ///
525 /// # #[tokio::main(flavor = "current_thread")]
526 /// # async fn main() {
527 /// let mut map = JoinMap::new();
528 ///
529 /// map.spawn("hello world", async move { /* ... */ });
530 /// map.spawn("goodbye world", async move { /* ... */});
531 ///
532 /// // A task for the key "goodbye world" should exist in the map:
533 /// assert!(map.abort("goodbye world"));
534 ///
535 /// // Aborting a key that does not exist will return `false`:
536 /// assert!(!map.abort("goodbye universe"));
537 /// # }
538 /// ```
539 pub fn abort<Q>(&mut self, key: &Q) -> bool
540 where
541 Q: ?Sized + Hash + Eq,
542 K: Borrow<Q>,
543 {
544 match self.get_by_key(key) {
545 Some((_, handle)) => {
546 handle.abort();
547 true
548 }
549 None => false,
550 }
551 }
552
553 /// Aborts all tasks with keys matching `predicate`.
554 ///
555 /// `predicate` is a function called with a reference to each key in the
556 /// map. If it returns `true` for a given key, the corresponding task will
557 /// be cancelled.
558 ///
559 /// # Examples
560 /// ```
561 /// use tokio_util::task::JoinMap;
562 ///
563 /// # // use the current thread rt so that spawned tasks don't
564 /// # // complete in the background before they can be aborted.
565 /// # #[tokio::main(flavor = "current_thread")]
566 /// # async fn main() {
567 /// let mut map = JoinMap::new();
568 ///
569 /// map.spawn("hello world", async move {
570 /// // ...
571 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
572 /// });
573 /// map.spawn("goodbye world", async move {
574 /// // ...
575 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
576 /// });
577 /// map.spawn("hello san francisco", async move {
578 /// // ...
579 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
580 /// });
581 /// map.spawn("goodbye universe", async move {
582 /// // ...
583 /// # tokio::task::yield_now().await; // don't complete immediately, get aborted!
584 /// });
585 ///
586 /// // Abort all tasks whose keys begin with "goodbye"
587 /// map.abort_matching(|key| key.starts_with("goodbye"));
588 ///
589 /// let mut seen = 0;
590 /// while let Some((key, res)) = map.join_next().await {
591 /// seen += 1;
592 /// if key.starts_with("goodbye") {
593 /// // The aborted task should complete with a cancelled `JoinError`.
594 /// assert!(res.unwrap_err().is_cancelled());
595 /// } else {
596 /// // Other tasks should complete normally.
597 /// assert!(key.starts_with("hello"));
598 /// assert!(res.is_ok());
599 /// }
600 /// }
601 ///
602 /// // All spawned tasks should have completed.
603 /// assert_eq!(seen, 4);
604 /// # }
605 /// ```
606 pub fn abort_matching(&mut self, mut predicate: impl FnMut(&K) -> bool) {
607 // Note: this method iterates over the tasks and keys *without* removing
608 // any entries, so that the keys from aborted tasks can still be
609 // returned when calling `join_next` in the future.
610 for (key, task) in &self.tasks_by_key {
611 if predicate(key) {
612 task.abort();
613 }
614 }
615 }
616
617 /// Returns an iterator visiting all keys in this `JoinMap` in arbitrary order.
618 ///
619 /// If a task has completed, but its output hasn't yet been consumed by a
620 /// call to [`join_next`], this method will still return its key.
621 ///
622 /// [`join_next`]: fn@Self::join_next
623 pub fn keys(&self) -> JoinMapKeys<'_, K, V> {
624 JoinMapKeys {
625 iter: self.tasks_by_key.iter(),
626 _value: PhantomData,
627 }
628 }
629
630 /// Returns `true` if this `JoinMap` contains a task for the provided key.
631 ///
632 /// If the task has completed, but its output hasn't yet been consumed by a
633 /// call to [`join_next`], this method will still return `true`.
634 ///
635 /// [`join_next`]: fn@Self::join_next
636 pub fn contains_key<Q>(&self, key: &Q) -> bool
637 where
638 Q: ?Sized + Hash + Eq,
639 K: Borrow<Q>,
640 {
641 self.get_by_key(key).is_some()
642 }
643
644 /// Returns `true` if this `JoinMap` contains a task with the provided
645 /// [task ID].
646 ///
647 /// If the task has completed, but its output hasn't yet been consumed by a
648 /// call to [`join_next`], this method will still return `true`.
649 ///
650 /// [`join_next`]: fn@Self::join_next
651 /// [task ID]: tokio::task::Id
652 pub fn contains_task(&self, task: &Id) -> bool {
653 self.hashes_by_task.contains_key(task)
654 }
655
656 /// Reserves capacity for at least `additional` more tasks to be spawned
657 /// on this `JoinMap` without reallocating for the map of task keys. The
658 /// collection may reserve more space to avoid frequent reallocations.
659 ///
660 /// Note that spawning a task will still cause an allocation for the task
661 /// itself.
662 ///
663 /// # Panics
664 ///
665 /// Panics if the new allocation size overflows [`usize`].
666 ///
667 /// # Examples
668 ///
669 /// ```
670 /// use tokio_util::task::JoinMap;
671 ///
672 /// let mut map: JoinMap<&str, i32> = JoinMap::new();
673 /// map.reserve(10);
674 /// ```
675 #[inline]
676 pub fn reserve(&mut self, additional: usize) {
677 self.tasks_by_key.reserve(additional, |(k, _)| {
678 self.hashes_by_task.hasher().hash_one(k)
679 });
680 self.hashes_by_task.reserve(additional);
681 }
682
683 /// Shrinks the capacity of the `JoinMap` as much as possible. It will drop
684 /// down as much as possible while maintaining the internal rules
685 /// and possibly leaving some space in accordance with the resize policy.
686 ///
687 /// # Examples
688 ///
689 /// ```
690 /// # #[tokio::main(flavor = "current_thread")]
691 /// # async fn main() {
692 /// use tokio_util::task::JoinMap;
693 ///
694 /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
695 /// map.spawn(1, async move { 2 });
696 /// map.spawn(3, async move { 4 });
697 /// assert!(map.capacity() >= 100);
698 /// map.shrink_to_fit();
699 /// assert!(map.capacity() >= 2);
700 /// # }
701 /// ```
702 #[inline]
703 pub fn shrink_to_fit(&mut self) {
704 self.hashes_by_task.shrink_to_fit();
705 self.tasks_by_key
706 .shrink_to_fit(|(k, _)| self.hashes_by_task.hasher().hash_one(k));
707 }
708
709 /// Shrinks the capacity of the map with a lower limit. It will drop
710 /// down no lower than the supplied limit while maintaining the internal rules
711 /// and possibly leaving some space in accordance with the resize policy.
712 ///
713 /// If the current capacity is less than the lower limit, this is a no-op.
714 ///
715 /// # Examples
716 ///
717 /// ```
718 /// # #[tokio::main(flavor = "current_thread")]
719 /// # async fn main() {
720 /// use tokio_util::task::JoinMap;
721 ///
722 /// let mut map: JoinMap<i32, i32> = JoinMap::with_capacity(100);
723 /// map.spawn(1, async move { 2 });
724 /// map.spawn(3, async move { 4 });
725 /// assert!(map.capacity() >= 100);
726 /// map.shrink_to(10);
727 /// assert!(map.capacity() >= 10);
728 /// map.shrink_to(0);
729 /// assert!(map.capacity() >= 2);
730 /// # }
731 /// ```
732 #[inline]
733 pub fn shrink_to(&mut self, min_capacity: usize) {
734 self.hashes_by_task.shrink_to(min_capacity);
735 self.tasks_by_key.shrink_to(min_capacity, |(k, _)| {
736 self.hashes_by_task.hasher().hash_one(k)
737 })
738 }
739
740 /// Look up a task in the map by its key, returning the key and abort handle.
741 fn get_by_key<'map, Q>(&'map self, key: &Q) -> Option<&'map (K, AbortHandle)>
742 where
743 Q: ?Sized + Hash + Eq,
744 K: Borrow<Q>,
745 {
746 let hash = self.hashes_by_task.hasher().hash_one(key);
747 self.tasks_by_key.find(hash, |(k, _)| k.borrow() == key)
748 }
749
750 /// Remove a task from the map by ID, returning the key for that task.
751 fn remove_by_id(&mut self, id: Id) -> Option<K> {
752 // Get the hash for the given ID.
753 let hash = self.hashes_by_task.remove(&id)?;
754
755 // Remove the entry for that hash.
756 let entry = self
757 .tasks_by_key
758 .find_entry(hash, |(_, abort)| abort.id() == id);
759 let (key, _) = match entry {
760 Ok(entry) => entry.remove().0,
761 _ => return None,
762 };
763 Some(key)
764 }
765}
766
767impl<K, V, S> JoinMap<K, V, S>
768where
769 V: 'static,
770{
771 /// Aborts all tasks on this `JoinMap`.
772 ///
773 /// This does not remove the tasks from the `JoinMap`. To wait for the tasks to complete
774 /// cancellation, you should call `join_next` in a loop until the `JoinMap` is empty.
775 pub fn abort_all(&mut self) {
776 self.tasks.abort_all()
777 }
778
779 /// Removes all tasks from this `JoinMap` without aborting them.
780 ///
/// The tasks removed by this call will continue to run in the background even if the `JoinMap`
/// is dropped. Because their keys and abort handles are discarded as well, they can no longer
/// be aborted or looked up through this `JoinMap`.
783 pub fn detach_all(&mut self) {
784 self.tasks.detach_all();
785 self.tasks_by_key.clear();
786 self.hashes_by_task.clear();
787 }
788}
789
790// Hand-written `fmt::Debug` implementation in order to avoid requiring `V:
791// Debug`, since no value is ever actually stored in the map.
792impl<K: fmt::Debug, V, S> fmt::Debug for JoinMap<K, V, S> {
793 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
794 // format the task keys and abort handles a little nicer by just
795 // printing the key and task ID pairs, without format the `Key` struct
796 // itself or the `AbortHandle`, which would just format the task's ID
797 // again.
798 struct KeySet<'a, K: fmt::Debug>(&'a HashTable<(K, AbortHandle)>);
799 impl<K: fmt::Debug> fmt::Debug for KeySet<'_, K> {
800 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
801 f.debug_map()
802 .entries(self.0.iter().map(|(key, abort)| (key, abort.id())))
803 .finish()
804 }
805 }
806
807 f.debug_struct("JoinMap")
808 // The `tasks_by_key` map is the only one that contains information
809 // that's really worth formatting for the user, since it contains
810 // the tasks' keys and IDs. The other fields are basically
811 // implementation details.
812 .field("tasks", &KeySet(&self.tasks_by_key))
813 .finish()
814 }
815}
816
817impl<K, V> Default for JoinMap<K, V> {
818 fn default() -> Self {
819 Self::new()
820 }
821}
822
823/// An iterator over the keys of a [`JoinMap`].
824#[derive(Debug, Clone)]
825pub struct JoinMapKeys<'a, K, V> {
826 iter: hashbrown::hash_table::Iter<'a, (K, AbortHandle)>,
827 /// To make it easier to change `JoinMap` in the future, keep V as a generic
828 /// parameter.
829 _value: PhantomData<&'a V>,
830}
831
832impl<'a, K, V> Iterator for JoinMapKeys<'a, K, V> {
833 type Item = &'a K;
834
835 fn next(&mut self) -> Option<&'a K> {
836 self.iter.next().map(|(key, _)| key)
837 }
838
839 fn size_hint(&self) -> (usize, Option<usize>) {
840 self.iter.size_hint()
841 }
842}
843
844impl<'a, K, V> ExactSizeIterator for JoinMapKeys<'a, K, V> {
845 fn len(&self) -> usize {
846 self.iter.len()
847 }
848}
849
850impl<'a, K, V> std::iter::FusedIterator for JoinMapKeys<'a, K, V> {}