1use std::{future::Future, ops::Deref, time::Duration};
2
3use tokio::runtime::Handle;
4use tokio::task::{JoinHandle, LocalSet};
5use tokio_util::sync::{CancellationToken, DropGuardRef};
6
7pub use tokio_util::task::task_tracker::{
8 TaskTracker, TaskTrackerToken, TaskTrackerWaitFuture, TrackedFuture,
9};
10
/// Result of a task that was raced against a cancellation token.
///
/// Produced by the `TaskSupervisor::spawn*_with_cancel` family of methods:
/// [`CancelOutcome::Completed`] carries the task's output, while
/// [`CancelOutcome::Cancelled`] means the token fired before the task finished.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CancelOutcome<T> {
    /// The task ran to completion and produced this value.
    Completed(T),
    /// Cancellation was observed before the task completed.
    Cancelled,
}

impl<T> CancelOutcome<T> {
    /// Maps the `Option` returned by `CancellationToken::run_until_cancelled` onto a
    /// [`CancelOutcome`].
    ///
    /// `Some(value)` becomes [`CancelOutcome::Completed`]; `None` becomes
    /// [`CancelOutcome::Cancelled`].
    #[inline]
    #[must_use]
    pub fn outcome(result: Option<T>) -> CancelOutcome<T> {
        match result {
            Some(value) => CancelOutcome::Completed(value),
            None => CancelOutcome::Cancelled,
        }
    }

    /// Returns `true` if this is [`CancelOutcome::Completed`].
    #[inline]
    #[must_use]
    pub fn is_completed(&self) -> bool {
        matches!(self, CancelOutcome::Completed(_))
    }

    /// Returns `true` if this is [`CancelOutcome::Cancelled`].
    #[inline]
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        matches!(self, CancelOutcome::Cancelled)
    }

    /// Converts back into an `Option`: `Completed(v)` becomes `Some(v)` and
    /// `Cancelled` becomes `None`.
    #[inline]
    #[must_use]
    pub fn into_option(self) -> Option<T> {
        match self {
            CancelOutcome::Completed(value) => Some(value),
            CancelOutcome::Cancelled => None,
        }
    }
}

/// Same mapping as [`CancelOutcome::outcome`], available through `.into()` / `From`.
impl<T> From<Option<T>> for CancelOutcome<T> {
    #[inline]
    fn from(result: Option<T>) -> Self {
        Self::outcome(result)
    }
}
37
38#[derive(Clone)]
93pub struct TaskSupervisor {
94 tracker: TaskTracker,
95 shutdown: CancellationToken,
96}
97
98impl TaskSupervisor {
99 #[must_use]
103 pub fn new() -> Self {
104 Self {
105 tracker: TaskTracker::new(),
106 shutdown: CancellationToken::new(),
107 }
108 }
109
110 #[inline]
114 pub fn tracker(&self) -> &TaskTracker {
115 &self.tracker
116 }
117
118 #[inline]
120 pub fn token(&self) -> CancellationToken {
121 self.shutdown.clone()
122 }
123
124 #[must_use]
126 #[inline]
127 pub fn cancel_on_drop(&self) -> DropGuardRef<'_> {
128 self.shutdown.drop_guard_ref()
129 }
130
131 #[inline]
135 pub fn is_cancelled(&self) -> bool {
136 self.shutdown.is_cancelled()
137 }
138
139 #[inline]
141 pub fn is_closed(&self) -> bool {
142 self.tracker.is_closed()
143 }
144
145 #[inline]
147 pub fn len(&self) -> usize {
148 self.tracker.len()
149 }
150
151 #[inline]
153 pub fn wait(&self) -> TaskTrackerWaitFuture<'_> {
154 self.tracker.wait()
155 }
156
157 #[inline]
163 pub fn cancel(&self) {
164 self.shutdown.cancel();
165 }
166
167 pub async fn shutdown(&self) {
174 self.tracker.close();
175 self.shutdown.cancel();
176 self.tracker.wait().await;
177 }
178
179 #[inline]
190 pub async fn shutdown_with_timeout(
191 &self,
192 timeout: Duration,
193 ) -> Result<(), tokio::time::error::Elapsed> {
194 tokio::time::timeout(timeout, self.shutdown()).await
195 }
196
197 #[must_use]
213 pub fn spawn_with_cancel<F, Fut>(&self, task: F) -> JoinHandle<CancelOutcome<Fut::Output>>
214 where
215 F: FnOnce() -> Fut + Send + 'static,
216 Fut: Future + Send + 'static,
217 Fut::Output: Send + 'static,
218 {
219 let token = self.token();
220 self.tracker
221 .spawn(async move { CancelOutcome::outcome(token.run_until_cancelled(task()).await) })
222 }
223
224 #[must_use]
235 pub fn spawn_on_with_cancel<F, Fut>(
236 &self,
237 task: F,
238 handle: &Handle,
239 ) -> JoinHandle<CancelOutcome<Fut::Output>>
240 where
241 F: FnOnce() -> Fut + Send + 'static,
242 Fut: Future + Send + 'static,
243 Fut::Output: Send + 'static,
244 {
245 let token = self.token();
246 self.tracker.spawn_on(
247 async move { CancelOutcome::outcome(token.run_until_cancelled(task()).await) },
248 handle,
249 )
250 }
251
252 #[must_use]
262 pub fn spawn_local_with_cancel<F, Fut>(&self, task: F) -> JoinHandle<CancelOutcome<Fut::Output>>
263 where
264 F: FnOnce() -> Fut + 'static,
265 Fut: Future + 'static,
266 Fut::Output: 'static,
267 {
268 let token = self.token();
269 self.tracker.spawn_local(async move {
270 CancelOutcome::outcome(token.run_until_cancelled(task()).await)
271 })
272 }
273
274 #[must_use]
285 pub fn spawn_local_on_with_cancel<F, Fut>(
286 &self,
287 task: F,
288 local_set: &LocalSet,
289 ) -> JoinHandle<CancelOutcome<Fut::Output>>
290 where
291 F: FnOnce() -> Fut + 'static,
292 Fut: Future + 'static,
293 Fut::Output: 'static,
294 {
295 let token = self.token();
296 self.tracker.spawn_local_on(
297 async move { CancelOutcome::outcome(token.run_until_cancelled(task()).await) },
298 local_set,
299 )
300 }
301
302 #[must_use]
314 pub fn spawn_with_token<F, Fut>(&self, task: F) -> JoinHandle<Fut::Output>
315 where
316 F: FnOnce(CancellationToken) -> Fut + Send + 'static,
317 Fut: Future + Send + 'static,
318 Fut::Output: Send + 'static,
319 {
320 let token = self.shutdown.child_token();
321 self.tracker.spawn(async move { task(token).await })
322 }
323
324 #[must_use]
335 pub fn spawn_on_with_token<F, Fut>(&self, task: F, handle: &Handle) -> JoinHandle<Fut::Output>
336 where
337 F: FnOnce(CancellationToken) -> Fut + Send + 'static,
338 Fut: Future + Send + 'static,
339 Fut::Output: Send + 'static,
340 {
341 let token = self.shutdown.child_token();
342 self.tracker
343 .spawn_on(async move { task(token).await }, handle)
344 }
345
346 #[must_use]
356 pub fn spawn_local_with_token<F, Fut>(&self, task: F) -> JoinHandle<Fut::Output>
357 where
358 F: FnOnce(CancellationToken) -> Fut + 'static,
359 Fut: Future + 'static,
360 Fut::Output: 'static,
361 {
362 let token = self.shutdown.child_token();
363 self.tracker.spawn_local(async move { task(token).await })
364 }
365
366 #[must_use]
377 pub fn spawn_local_on_with_token<F, Fut>(
378 &self,
379 task: F,
380 local_set: &LocalSet,
381 ) -> JoinHandle<Fut::Output>
382 where
383 F: FnOnce(CancellationToken) -> Fut + 'static,
384 Fut: Future + 'static,
385 Fut::Output: 'static,
386 {
387 let token = self.shutdown.child_token();
388 self.tracker
389 .spawn_local_on(async move { task(token).await }, local_set)
390 }
391
392 #[cfg(not(target_family = "wasm"))]
402 #[must_use]
403 pub fn spawn_blocking_with_token<F, T>(&self, task: F) -> JoinHandle<T>
404 where
405 F: FnOnce(CancellationToken) -> T + Send + 'static,
406 T: Send + 'static,
407 {
408 let token = self.shutdown.child_token();
409 self.tracker.spawn_blocking(move || task(token))
410 }
411
412 #[cfg(not(target_family = "wasm"))]
423 #[must_use]
424 pub fn spawn_blocking_on_with_token<F, T>(&self, task: F, handle: &Handle) -> JoinHandle<T>
425 where
426 F: FnOnce(CancellationToken) -> T + Send + 'static,
427 T: Send + 'static,
428 {
429 let token = self.shutdown.child_token();
430 self.tracker.spawn_blocking_on(move || task(token), handle)
431 }
432}
433
434impl Default for TaskSupervisor {
435 fn default() -> Self {
436 Self::new()
437 }
438}
439
440impl Deref for TaskSupervisor {
441 type Target = TaskTracker;
442
443 fn deref(&self) -> &Self::Target {
444 &self.tracker
445 }
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::Arc;
    use tokio::time::{sleep, Duration};

    #[test]
    fn test_cancel_outcome_from_option() {
        assert_eq!(CancelOutcome::outcome(Some(7)), CancelOutcome::Completed(7));
        assert_eq!(CancelOutcome::<i32>::outcome(None), CancelOutcome::Cancelled);
    }

    #[tokio::test]
    async fn test_spawn_with_token_provides_token() {
        let supervisor = TaskSupervisor::new();

        let handle = supervisor.spawn_with_token(|token| async move { token.is_cancelled() });

        let result = handle.await.unwrap();
        assert!(!result);
    }

    #[tokio::test]
    #[cfg(feature = "rt")]
    async fn test_spawn_on_with_token_provides_token() {
        let supervisor = TaskSupervisor::new();
        let handle = tokio::runtime::Handle::current();

        let result = supervisor
            .spawn_on_with_token(|token| async move { token.is_cancelled() }, &handle)
            .await
            .unwrap();

        assert!(!result);
    }

    #[tokio::test]
    #[cfg(all(feature = "rt", not(target_family = "wasm")))]
    async fn test_spawn_blocking_with_token_provides_token() {
        let supervisor = TaskSupervisor::new();

        let handle = supervisor.spawn_blocking_with_token(move |token| token.is_cancelled());

        let result = handle.await.unwrap();
        assert!(!result);
    }

    #[tokio::test]
    async fn test_cancel_sets_cancelled_state() {
        let supervisor = TaskSupervisor::new();
        assert!(!supervisor.is_cancelled());

        supervisor.cancel();
        assert!(supervisor.is_cancelled());
    }

    #[tokio::test]
    async fn test_cancel_propagates_to_all_tasks() {
        let supervisor = TaskSupervisor::new();
        let count = Arc::new(AtomicUsize::new(0));

        // Keep the handles so the test awaits completion deterministically instead of
        // guessing how long the tasks need with fixed sleeps.
        let handles: Vec<_> = (0..3)
            .map(|_| {
                let count = Arc::clone(&count);
                supervisor.spawn_with_token(|token| async move {
                    token.cancelled().await;
                    count.fetch_add(1, Ordering::SeqCst);
                })
            })
            .collect();

        supervisor.cancel();
        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_shutdown_cancels_and_waits() {
        let supervisor = TaskSupervisor::new();
        let task_finished = Arc::new(AtomicBool::new(false));
        let task_finished_clone = task_finished.clone();

        let _ = supervisor.spawn_with_token(|_token| async move {
            sleep(Duration::from_millis(100)).await;
            task_finished_clone.store(true, Ordering::SeqCst);
        });

        assert!(!supervisor.is_cancelled());
        assert!(!supervisor.is_closed());

        supervisor.shutdown().await;

        assert!(supervisor.is_cancelled());
        assert!(supervisor.is_closed());
        assert!(task_finished.load(Ordering::SeqCst));
        assert_eq!(supervisor.len(), 0);
        assert!(supervisor.is_empty());
    }

    #[tokio::test]
    async fn test_shutdown_with_timeout_completes_in_time() {
        let supervisor = TaskSupervisor::new();

        let _ = supervisor.spawn_with_token(|_token| async move {
            sleep(Duration::from_millis(50)).await;
        });

        let result = supervisor
            .shutdown_with_timeout(Duration::from_secs(1))
            .await;
        assert!(result.is_ok());
        assert!(supervisor.is_cancelled());
        assert!(supervisor.is_closed());
    }

    #[tokio::test]
    async fn test_shutdown_with_timeout_times_out() {
        let supervisor = TaskSupervisor::new();

        // Spawned straight on the tracker, so this task never observes cancellation.
        let _ = supervisor.tracker().spawn(async {
            sleep(Duration::from_secs(10)).await;
        });

        let result = supervisor
            .shutdown_with_timeout(Duration::from_millis(50))
            .await;
        assert!(result.is_err());
        // Even on timeout, the shutdown sequence has already closed and cancelled.
        assert!(supervisor.is_cancelled());
        assert!(supervisor.is_closed());
    }

    #[tokio::test]
    async fn test_cooperative_cancellation_in_loop() {
        let supervisor = TaskSupervisor::new();
        let iterations = Arc::new(AtomicUsize::new(0));
        let iterations_clone = iterations.clone();

        let handle = supervisor.spawn_with_token(|token| async move {
            loop {
                if token.is_cancelled() {
                    break;
                }
                iterations_clone.fetch_add(1, Ordering::SeqCst);
                sleep(Duration::from_millis(10)).await;
            }
        });

        sleep(Duration::from_millis(55)).await;
        supervisor.cancel();
        // Awaiting the handle proves the loop observed cancellation and exited; this replaces
        // a timing-dependent upper bound on the iteration count.
        handle.await.unwrap();

        let count = iterations.load(Ordering::SeqCst);
        assert!(count >= 1, "loop should have run before cancellation, ran {count} times");
    }

    #[tokio::test]
    async fn test_spawn_with_cancel_completes() {
        let supervisor = TaskSupervisor::new();

        let handle = supervisor.spawn_with_cancel(|| async move {
            sleep(Duration::from_millis(30)).await;
            42
        });

        match handle.await.unwrap() {
            CancelOutcome::Completed(value) => assert_eq!(value, 42),
            CancelOutcome::Cancelled => panic!("task should have completed"),
        }
    }

    #[tokio::test]
    async fn test_spawn_with_cancel_reports_cancellation() {
        let supervisor = TaskSupervisor::new();

        let handle = supervisor.spawn_with_cancel(|| async move {
            loop {
                sleep(Duration::from_millis(10)).await;
            }
        });

        sleep(Duration::from_millis(35)).await;
        supervisor.cancel();

        match handle.await.unwrap() {
            CancelOutcome::Completed(_) => panic!("task should have been cancelled"),
            CancelOutcome::Cancelled => {}
        }
    }

    #[tokio::test]
    async fn test_deref_allows_direct_tracker_access() {
        let supervisor = TaskSupervisor::new();
        assert!(supervisor.is_empty());

        // `spawn` resolves to `TaskTracker::spawn` through `Deref`.
        let handle = supervisor.spawn(async {
            sleep(Duration::from_millis(50)).await;
            42
        });

        assert_eq!(supervisor.len(), 1);
        assert!(!supervisor.is_empty());
        assert!(!supervisor.is_closed());

        let result = handle.await.unwrap();
        assert_eq!(result, 42);
    }
}