1use std::{future::Future, time::Duration};
2
3use tokio::runtime::Handle;
4use tokio::task::{JoinHandle, LocalSet};
5use tokio_util::sync::{CancellationToken, DropGuardRef};
6
7pub use tokio_util::task::task_tracker::{
8 TaskTracker, TaskTrackerToken, TaskTrackerWaitFuture, TrackedFuture,
9};
10
/// Outcome of a future that was raced against a cancellation token.
///
/// Produced by the `*_with_cancel` spawn methods of `TaskSupervisor`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CancelOutcome<T> {
    /// The future ran to completion and produced this value.
    Completed(T),
    /// The token was cancelled before the future completed; the future was
    /// dropped without producing a value.
    Cancelled,
}

impl<T> CancelOutcome<T> {
    /// Converts the `Option` returned by
    /// [`CancellationToken::run_until_cancelled`] into a `CancelOutcome`.
    ///
    /// `Some(value)` becomes [`CancelOutcome::Completed`] and `None` becomes
    /// [`CancelOutcome::Cancelled`].
    #[inline]
    pub fn outcome(result: Option<T>) -> CancelOutcome<T> {
        match result {
            Some(value) => CancelOutcome::Completed(value),
            None => CancelOutcome::Cancelled,
        }
    }
}
37
38#[derive(Clone)]
69pub struct TaskSupervisor {
70 tracker: TaskTracker,
71 shutdown: CancellationToken,
72}
73
74impl TaskSupervisor {
75 #[must_use]
79 pub fn new() -> Self {
80 Self {
81 tracker: TaskTracker::new(),
82 shutdown: CancellationToken::new(),
83 }
84 }
85
86 #[inline]
90 pub fn tracker(&self) -> &TaskTracker {
91 &self.tracker
92 }
93
94 #[inline]
96 pub fn token(&self) -> CancellationToken {
97 self.shutdown.clone()
98 }
99
100 #[must_use]
102 #[inline]
103 pub fn cancel_on_drop(&self) -> DropGuardRef<'_> {
104 self.shutdown.drop_guard_ref()
105 }
106
107 #[inline]
111 pub fn is_cancelled(&self) -> bool {
112 self.shutdown.is_cancelled()
113 }
114
115 #[inline]
117 pub fn is_closed(&self) -> bool {
118 self.tracker.is_closed()
119 }
120
121 #[inline]
123 pub fn len(&self) -> usize {
124 self.tracker.len()
125 }
126
127 #[inline]
129 pub fn wait(&self) -> TaskTrackerWaitFuture<'_> {
130 self.tracker.wait()
131 }
132
133 #[inline]
139 pub fn cancel(&self) {
140 self.shutdown.cancel();
141 }
142
143 pub async fn shutdown(&self) {
150 self.tracker.close();
151 self.shutdown.cancel();
152 self.tracker.wait().await;
153 }
154
155 #[inline]
166 pub async fn shutdown_with_timeout(
167 &self,
168 timeout: Duration,
169 ) -> Result<(), tokio::time::error::Elapsed> {
170 tokio::time::timeout(timeout, self.shutdown()).await
171 }
172
173 #[must_use]
189 pub fn spawn_with_cancel<F, Fut>(&self, task: F) -> JoinHandle<CancelOutcome<Fut::Output>>
190 where
191 F: FnOnce() -> Fut + Send + 'static,
192 Fut: Future + Send + 'static,
193 Fut::Output: Send + 'static,
194 {
195 let token = self.token();
196 self.tracker
197 .spawn(async move { CancelOutcome::outcome(token.run_until_cancelled(task()).await) })
198 }
199
200 #[must_use]
211 pub fn spawn_on_with_cancel<F, Fut>(
212 &self,
213 task: F,
214 handle: &Handle,
215 ) -> JoinHandle<CancelOutcome<Fut::Output>>
216 where
217 F: FnOnce() -> Fut + Send + 'static,
218 Fut: Future + Send + 'static,
219 Fut::Output: Send + 'static,
220 {
221 let token = self.token();
222 self.tracker.spawn_on(
223 async move { CancelOutcome::outcome(token.run_until_cancelled(task()).await) },
224 handle,
225 )
226 }
227
228 #[must_use]
238 pub fn spawn_local_with_cancel<F, Fut>(&self, task: F) -> JoinHandle<CancelOutcome<Fut::Output>>
239 where
240 F: FnOnce() -> Fut + 'static,
241 Fut: Future + 'static,
242 Fut::Output: 'static,
243 {
244 let token = self.token();
245 self.tracker.spawn_local(async move {
246 CancelOutcome::outcome(token.run_until_cancelled(task()).await)
247 })
248 }
249
250 #[must_use]
261 pub fn spawn_local_on_with_cancel<F, Fut>(
262 &self,
263 task: F,
264 local_set: &LocalSet,
265 ) -> JoinHandle<CancelOutcome<Fut::Output>>
266 where
267 F: FnOnce() -> Fut + 'static,
268 Fut: Future + 'static,
269 Fut::Output: 'static,
270 {
271 let token = self.token();
272 self.tracker.spawn_local_on(
273 async move { CancelOutcome::outcome(token.run_until_cancelled(task()).await) },
274 local_set,
275 )
276 }
277
278 #[must_use]
290 pub fn spawn_with_token<F, Fut>(&self, task: F) -> JoinHandle<Fut::Output>
291 where
292 F: FnOnce(CancellationToken) -> Fut + Send + 'static,
293 Fut: Future + Send + 'static,
294 Fut::Output: Send + 'static,
295 {
296 let token = self.shutdown.child_token();
297 self.tracker.spawn(async move { task(token).await })
298 }
299
300 #[must_use]
311 pub fn spawn_on_with_token<F, Fut>(&self, task: F, handle: &Handle) -> JoinHandle<Fut::Output>
312 where
313 F: FnOnce(CancellationToken) -> Fut + Send + 'static,
314 Fut: Future + Send + 'static,
315 Fut::Output: Send + 'static,
316 {
317 let token = self.shutdown.child_token();
318 self.tracker
319 .spawn_on(async move { task(token).await }, handle)
320 }
321
322 #[must_use]
332 pub fn spawn_local_with_token<F, Fut>(&self, task: F) -> JoinHandle<Fut::Output>
333 where
334 F: FnOnce(CancellationToken) -> Fut + 'static,
335 Fut: Future + 'static,
336 Fut::Output: 'static,
337 {
338 let token = self.shutdown.child_token();
339 self.tracker.spawn_local(async move { task(token).await })
340 }
341
342 #[must_use]
353 pub fn spawn_local_on_with_token<F, Fut>(
354 &self,
355 task: F,
356 local_set: &LocalSet,
357 ) -> JoinHandle<Fut::Output>
358 where
359 F: FnOnce(CancellationToken) -> Fut + 'static,
360 Fut: Future + 'static,
361 Fut::Output: 'static,
362 {
363 let token = self.shutdown.child_token();
364 self.tracker
365 .spawn_local_on(async move { task(token).await }, local_set)
366 }
367
368 #[cfg(not(target_family = "wasm"))]
378 #[must_use]
379 pub fn spawn_blocking_with_token<F, T>(&self, task: F) -> JoinHandle<T>
380 where
381 F: FnOnce(CancellationToken) -> T + Send + 'static,
382 T: Send + 'static,
383 {
384 let token = self.shutdown.child_token();
385 self.tracker.spawn_blocking(move || task(token))
386 }
387
388 #[cfg(not(target_family = "wasm"))]
399 #[must_use]
400 pub fn spawn_blocking_on_with_token<F, T>(&self, task: F, handle: &Handle) -> JoinHandle<T>
401 where
402 F: FnOnce(CancellationToken) -> T + Send + 'static,
403 T: Send + 'static,
404 {
405 let token = self.shutdown.child_token();
406 self.tracker.spawn_blocking_on(move || task(token), handle)
407 }
408}
409
410impl Default for TaskSupervisor {
411 fn default() -> Self {
412 Self::new()
413 }
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::Arc;
    use tokio::time::{sleep, Duration};

    /// Upper bound for waits that are expected to finish promptly; hitting it
    /// indicates a real hang rather than scheduler jitter.
    const HANG_TIMEOUT: Duration = Duration::from_secs(5);

    #[tokio::test]
    async fn test_spawn_with_token_provides_token() {
        let supervisor = TaskSupervisor::new();

        let handle = supervisor.spawn_with_token(|token| async move { token.is_cancelled() });

        let result = handle.await.unwrap();
        assert!(!result);
    }

    #[tokio::test]
    #[cfg(feature = "rt")]
    async fn test_spawn_on_with_token_provides_token() {
        let supervisor = TaskSupervisor::new();
        let handle = tokio::runtime::Handle::current();

        let result = supervisor
            .spawn_on_with_token(|token| async move { token.is_cancelled() }, &handle)
            .await
            .unwrap();

        assert!(!result);
    }

    #[tokio::test]
    #[cfg(all(feature = "rt", not(target_family = "wasm")))]
    async fn test_spawn_blocking_with_token_provides_token() {
        let supervisor = TaskSupervisor::new();

        let handle = supervisor.spawn_blocking_with_token(move |token| token.is_cancelled());

        let result = handle.await.unwrap();
        assert!(!result);
    }

    #[tokio::test]
    async fn test_cancel_sets_cancelled_state() {
        let supervisor = TaskSupervisor::new();
        assert!(!supervisor.is_cancelled());

        supervisor.cancel();
        assert!(supervisor.is_cancelled());
    }

    #[tokio::test]
    async fn test_cancel_propagates_to_all_tasks() {
        let supervisor = TaskSupervisor::new();
        let count = Arc::new(AtomicUsize::new(0));

        for _ in 0..3 {
            let count_clone = count.clone();
            let _ = supervisor.spawn_with_token(|token| async move {
                token.cancelled().await;
                count_clone.fetch_add(1, Ordering::SeqCst);
            });
        }

        supervisor.cancel();
        // Wait for the tasks to actually exit instead of sleeping for a fixed
        // duration, so the assertion does not depend on scheduling delays.
        supervisor.tracker().close();
        tokio::time::timeout(HANG_TIMEOUT, supervisor.wait())
            .await
            .expect("tasks did not exit after cancellation");

        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_shutdown_cancels_and_waits() {
        let supervisor = TaskSupervisor::new();
        let task_finished = Arc::new(AtomicBool::new(false));
        let task_finished_clone = task_finished.clone();

        let _ = supervisor.spawn_with_token(|_token| async move {
            sleep(Duration::from_millis(100)).await;
            task_finished_clone.store(true, Ordering::SeqCst);
        });

        assert!(!supervisor.is_cancelled());
        assert!(!supervisor.is_closed());

        supervisor.shutdown().await;

        assert!(supervisor.is_cancelled());
        assert!(supervisor.is_closed());
        assert!(task_finished.load(Ordering::SeqCst));
        assert_eq!(supervisor.len(), 0);
    }

    #[tokio::test]
    async fn test_shutdown_with_timeout_completes_in_time() {
        let supervisor = TaskSupervisor::new();

        let _ = supervisor.spawn_with_token(|_token| async move {
            sleep(Duration::from_millis(50)).await;
        });

        let result = supervisor
            .shutdown_with_timeout(Duration::from_secs(1))
            .await;
        assert!(result.is_ok());
        assert!(supervisor.is_cancelled());
        assert!(supervisor.is_closed());
    }

    #[tokio::test]
    async fn test_shutdown_with_timeout_times_out() {
        let supervisor = TaskSupervisor::new();

        // Spawned directly on the tracker, so it ignores cancellation.
        let _ = supervisor.tracker().spawn(async {
            sleep(Duration::from_secs(10)).await;
        });

        let result = supervisor
            .shutdown_with_timeout(Duration::from_millis(50))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_cooperative_cancellation_in_loop() {
        let supervisor = TaskSupervisor::new();
        let iterations = Arc::new(AtomicUsize::new(0));
        let iterations_clone = iterations.clone();

        let _ = supervisor.spawn_with_token(|token| async move {
            loop {
                if token.is_cancelled() {
                    break;
                }
                iterations_clone.fetch_add(1, Ordering::SeqCst);
                sleep(Duration::from_millis(10)).await;
            }
        });

        sleep(Duration::from_millis(55)).await;
        supervisor.cancel();
        // The loop must notice the cancellation and exit on its own; waiting
        // on the tracker proves that directly instead of inferring it from an
        // iteration count after a fixed sleep.
        supervisor.tracker().close();
        tokio::time::timeout(HANG_TIMEOUT, supervisor.wait())
            .await
            .expect("loop did not observe cancellation");

        // The loop ran before cancellation. Exact counts depend on timer
        // granularity, so only a lower bound is asserted.
        let count = iterations.load(Ordering::SeqCst);
        assert!(count >= 1);
    }

    #[tokio::test]
    async fn test_spawn_with_cancel_completes() {
        let supervisor = TaskSupervisor::new();

        let handle = supervisor.spawn_with_cancel(|| async move {
            sleep(Duration::from_millis(30)).await;
            42
        });

        match handle.await.unwrap() {
            CancelOutcome::Completed(value) => assert_eq!(value, 42),
            CancelOutcome::Cancelled => panic!("task should have completed"),
        }
    }

    #[tokio::test]
    async fn test_spawn_with_cancel_reports_cancellation() {
        let supervisor = TaskSupervisor::new();

        let handle = supervisor.spawn_with_cancel(|| async move {
            loop {
                sleep(Duration::from_millis(10)).await;
            }
        });

        sleep(Duration::from_millis(35)).await;
        supervisor.cancel();

        match handle.await.unwrap() {
            CancelOutcome::Completed(_) => panic!("task should have been cancelled"),
            CancelOutcome::Cancelled => {}
        }
    }
}
596}