tokio_task_supervisor/
lib.rs

1use std::{future::Future, ops::Deref, time::Duration};
2
3use tokio::runtime::Handle;
4use tokio::task::{JoinHandle, LocalSet};
5use tokio_util::sync::{CancellationToken, DropGuardRef};
6
7pub use tokio_util::task::task_tracker::{
8    TaskTracker, TaskTrackerToken, TaskTrackerWaitFuture, TrackedFuture,
9};
10
/// The outcome of a task that races against cancellation.
///
/// This enum is what the `JoinHandle`s returned by the
/// [`spawn_with_cancel`](TaskSupervisor::spawn_with_cancel) family of methods resolve to. It
/// indicates whether the task completed normally or was cancelled.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CancelOutcome<T> {
    /// The task future ran to completion before cancellation was observed.
    Completed(T),
    /// Cancellation won the race; the task future was dropped before completing.
    Cancelled,
}
22
23impl<T> CancelOutcome<T> {
24    /// Creates a `CancelOutcome` from an optional result.
25    ///
26    /// # Arguments
27    ///
28    /// * `result` - `Some(value)` indicates completion, `None` indicates cancellation.
29    #[inline]
30    pub fn outcome(result: Option<T>) -> CancelOutcome<T> {
31        match result {
32            Some(value) => CancelOutcome::Completed(value),
33            None => CancelOutcome::Cancelled,
34        }
35    }
36}
37
38/// Manages a collection of asynchronous tasks and coordinates their shutdown.
39///
40/// `TaskSupervisor` wraps [`TaskTracker`] to keep count of outstanding tasks while also exposing a
41/// process-wide [`CancellationToken`] that can be used to request cooperative shutdown.
42///
43/// # Examples
44///
45/// ```rust
46/// use tokio_task_supervisor::TaskSupervisor;
47/// use tokio::time::{sleep, Duration};
48///
49/// #[tokio::main]
50/// async fn main() {
51///     let supervisor = TaskSupervisor::new();
52///     
53///     // Spawn a task that cooperatively handles cancellation
54///     let handle = supervisor.spawn_with_token(|token| async move {
55///         loop {
56///             if token.is_cancelled() {
57///                 break;
58///             }
59///             // Do work...
60///             sleep(Duration::from_millis(100)).await;
61///         }
62///     });
63///     
64///     // Later, request shutdown
65///     supervisor.shutdown().await;
66/// }
67/// ```
68///
69/// # Deref Implementation
70///
71/// `TaskSupervisor` implements [`Deref`] targeting [`TaskTracker`], allowing you to call
72/// `TaskTracker` methods directly on a `TaskSupervisor` instance:
73///
74/// ```rust
75/// use tokio_task_supervisor::TaskSupervisor;
76///
77/// #[tokio::main]
78/// async fn main() {
79///     let supervisor = TaskSupervisor::new();
80///     
81///     // These calls work through deref coercion:
82///     let handle = supervisor.spawn(async { 42 });
83///     let count = supervisor.len();
84///     let is_closed = supervisor.is_closed();
85///     
86///     // Equivalent to:
87///     let handle = supervisor.tracker().spawn(async { 42 });
88///     let count = supervisor.tracker().len();
89///     let is_closed = supervisor.tracker().is_closed();
90/// }
91/// ```
92#[derive(Clone)]
93pub struct TaskSupervisor {
94    tracker: TaskTracker,
95    shutdown: CancellationToken,
96}
97
98impl TaskSupervisor {
99    // === Construction ===
100
101    /// Creates a new task manager.
102    #[must_use]
103    pub fn new() -> Self {
104        Self {
105            tracker: TaskTracker::new(),
106            shutdown: CancellationToken::new(),
107        }
108    }
109
110    // === Accessors ===
111
112    /// Returns a reference to the underlying [`TaskTracker`].
113    #[inline]
114    pub fn tracker(&self) -> &TaskTracker {
115        &self.tracker
116    }
117
118    /// Returns a clone of the shared cancellation token.
119    #[inline]
120    pub fn token(&self) -> CancellationToken {
121        self.shutdown.clone()
122    }
123
124    /// Returns a guard that cancels the shutdown token when dropped.
125    #[must_use]
126    #[inline]
127    pub fn cancel_on_drop(&self) -> DropGuardRef<'_> {
128        self.shutdown.drop_guard_ref()
129    }
130
131    // === State Queries ===
132
133    /// Returns `true` if the shutdown token has been cancelled.
134    #[inline]
135    pub fn is_cancelled(&self) -> bool {
136        self.shutdown.is_cancelled()
137    }
138
139    /// Returns `true` if the task tracker is closed.
140    #[inline]
141    pub fn is_closed(&self) -> bool {
142        self.tracker.is_closed()
143    }
144
145    /// Returns the number of outstanding tasks.
146    #[inline]
147    pub fn len(&self) -> usize {
148        self.tracker.len()
149    }
150
151    /// Returns a future that completes when all tasks finish.
152    #[inline]
153    pub fn wait(&self) -> TaskTrackerWaitFuture<'_> {
154        self.tracker.wait()
155    }
156
157    // === Control Operations ===
158
159    /// Cancels the shared shutdown token.
160    ///
161    /// Tasks spawned through the managed API can observe this and exit cooperatively.
162    #[inline]
163    pub fn cancel(&self) {
164        self.shutdown.cancel();
165    }
166
167    /// Initiates graceful shutdown by closing the tracker and cancelling all tasks.
168    ///
169    /// This method will:
170    /// 1. Close the task tracker to prevent new tasks from being spawned
171    /// 2. Cancel the shutdown token to signal all existing tasks
172    /// 3. Wait for all tasks to complete
173    pub async fn shutdown(&self) {
174        self.tracker.close();
175        self.shutdown.cancel();
176        self.tracker.wait().await;
177    }
178
179    /// Initiates graceful shutdown with a timeout.
180    ///
181    /// # Arguments
182    ///
183    /// * `timeout` - Maximum time to wait for shutdown to complete
184    ///
185    /// # Returns
186    ///
187    /// * `Ok(())` if shutdown completed within the timeout
188    /// * `Err(Elapsed)` if the timeout was exceeded
189    #[inline]
190    pub async fn shutdown_with_timeout(
191        &self,
192        timeout: Duration,
193    ) -> Result<(), tokio::time::error::Elapsed> {
194        tokio::time::timeout(timeout, self.shutdown()).await
195    }
196
197    // === Spawn Methods with Cancellation Race ===
198
199    /// Spawns a task that races against the shared cancellation token.
200    ///
201    /// The returned future resolves with [`CancelOutcome`], indicating whether the task finished
202    /// normally or was cancelled. When cancellation wins the race, the task future is dropped, so it
203    /// should not rely on `Drop` for cleanup.
204    ///
205    /// # Arguments
206    ///
207    /// * `task` - A closure that returns the future to execute
208    ///
209    /// # Returns
210    ///
211    /// A `JoinHandle` that resolves to the task's outcome
212    #[must_use]
213    pub fn spawn_with_cancel<F, Fut>(&self, task: F) -> JoinHandle<CancelOutcome<Fut::Output>>
214    where
215        F: FnOnce() -> Fut + Send + 'static,
216        Fut: Future + Send + 'static,
217        Fut::Output: Send + 'static,
218    {
219        let token = self.token();
220        self.tracker
221            .spawn(async move { CancelOutcome::outcome(token.run_until_cancelled(task()).await) })
222    }
223
224    /// Spawns a task with cancellation handling on a specific runtime handle.
225    ///
226    /// # Arguments
227    ///
228    /// * `task` - A closure that returns the future to execute
229    /// * `handle` - The runtime handle to spawn the task on
230    ///
231    /// # Returns
232    ///
233    /// A `JoinHandle` that resolves to the task's outcome
234    #[must_use]
235    pub fn spawn_on_with_cancel<F, Fut>(
236        &self,
237        task: F,
238        handle: &Handle,
239    ) -> JoinHandle<CancelOutcome<Fut::Output>>
240    where
241        F: FnOnce() -> Fut + Send + 'static,
242        Fut: Future + Send + 'static,
243        Fut::Output: Send + 'static,
244    {
245        let token = self.token();
246        self.tracker.spawn_on(
247            async move { CancelOutcome::outcome(token.run_until_cancelled(task()).await) },
248            handle,
249        )
250    }
251
252    /// Spawns a !Send task that races against the shared cancellation token.
253    ///
254    /// # Arguments
255    ///
256    /// * `task` - A closure that returns the future to execute
257    ///
258    /// # Returns
259    ///
260    /// A `JoinHandle` that resolves to the task's outcome
261    #[must_use]
262    pub fn spawn_local_with_cancel<F, Fut>(&self, task: F) -> JoinHandle<CancelOutcome<Fut::Output>>
263    where
264        F: FnOnce() -> Fut + 'static,
265        Fut: Future + 'static,
266        Fut::Output: 'static,
267    {
268        let token = self.token();
269        self.tracker.spawn_local(async move {
270            CancelOutcome::outcome(token.run_until_cancelled(task()).await)
271        })
272    }
273
274    /// Spawns a !Send task on a [`LocalSet`] with cancellation handling.
275    ///
276    /// # Arguments
277    ///
278    /// * `task` - A closure that returns the future to execute
279    /// * `local_set` - The local set to spawn the task on
280    ///
281    /// # Returns
282    ///
283    /// A `JoinHandle` that resolves to the task's outcome
284    #[must_use]
285    pub fn spawn_local_on_with_cancel<F, Fut>(
286        &self,
287        task: F,
288        local_set: &LocalSet,
289    ) -> JoinHandle<CancelOutcome<Fut::Output>>
290    where
291        F: FnOnce() -> Fut + 'static,
292        Fut: Future + 'static,
293        Fut::Output: 'static,
294    {
295        let token = self.token();
296        self.tracker.spawn_local_on(
297            async move { CancelOutcome::outcome(token.run_until_cancelled(task()).await) },
298            local_set,
299        )
300    }
301
302    // === Spawn Methods with Token ===
303
304    /// Spawns a task that receives the shared cancellation token.
305    ///
306    /// # Arguments
307    ///
308    /// * `task` - A closure that takes a cancellation token and returns a future
309    ///
310    /// # Returns
311    ///
312    /// A `JoinHandle` that resolves to the task's output
313    #[must_use]
314    pub fn spawn_with_token<F, Fut>(&self, task: F) -> JoinHandle<Fut::Output>
315    where
316        F: FnOnce(CancellationToken) -> Fut + Send + 'static,
317        Fut: Future + Send + 'static,
318        Fut::Output: Send + 'static,
319    {
320        let token = self.shutdown.child_token();
321        self.tracker.spawn(async move { task(token).await })
322    }
323
324    /// Spawns a task with the shared cancellation token on a specific runtime handle.
325    ///
326    /// # Arguments
327    ///
328    /// * `task` - A closure that takes a cancellation token and returns a future
329    /// * `handle` - The runtime handle to spawn the task on
330    ///
331    /// # Returns
332    ///
333    /// A `JoinHandle` that resolves to the task's output
334    #[must_use]
335    pub fn spawn_on_with_token<F, Fut>(&self, task: F, handle: &Handle) -> JoinHandle<Fut::Output>
336    where
337        F: FnOnce(CancellationToken) -> Fut + Send + 'static,
338        Fut: Future + Send + 'static,
339        Fut::Output: Send + 'static,
340    {
341        let token = self.shutdown.child_token();
342        self.tracker
343            .spawn_on(async move { task(token).await }, handle)
344    }
345
346    /// Spawns a local task that receives the shared cancellation token.
347    ///
348    /// # Arguments
349    ///
350    /// * `task` - A closure that takes a cancellation token and returns a future
351    ///
352    /// # Returns
353    ///
354    /// A `JoinHandle` that resolves to the task's output
355    #[must_use]
356    pub fn spawn_local_with_token<F, Fut>(&self, task: F) -> JoinHandle<Fut::Output>
357    where
358        F: FnOnce(CancellationToken) -> Fut + 'static,
359        Fut: Future + 'static,
360        Fut::Output: 'static,
361    {
362        let token = self.shutdown.child_token();
363        self.tracker.spawn_local(async move { task(token).await })
364    }
365
366    /// Spawns a local task with a cancellation token on a specific local set.
367    ///
368    /// # Arguments
369    ///
370    /// * `task` - A closure that takes a cancellation token and returns a future
371    /// * `local_set` - The local set to spawn the task on
372    ///
373    /// # Returns
374    ///
375    /// A `JoinHandle` that resolves to the task's output
376    #[must_use]
377    pub fn spawn_local_on_with_token<F, Fut>(
378        &self,
379        task: F,
380        local_set: &LocalSet,
381    ) -> JoinHandle<Fut::Output>
382    where
383        F: FnOnce(CancellationToken) -> Fut + 'static,
384        Fut: Future + 'static,
385        Fut::Output: 'static,
386    {
387        let token = self.shutdown.child_token();
388        self.tracker
389            .spawn_local_on(async move { task(token).await }, local_set)
390    }
391
392    /// Spawns a blocking task that receives a cancellation context.
393    ///
394    /// # Arguments
395    ///
396    /// * `task` - A closure that takes a cancellation token and returns a value
397    ///
398    /// # Returns
399    ///
400    /// A `JoinHandle` that resolves to the task's output
401    #[cfg(not(target_family = "wasm"))]
402    #[must_use]
403    pub fn spawn_blocking_with_token<F, T>(&self, task: F) -> JoinHandle<T>
404    where
405        F: FnOnce(CancellationToken) -> T + Send + 'static,
406        T: Send + 'static,
407    {
408        let token = self.shutdown.child_token();
409        self.tracker.spawn_blocking(move || task(token))
410    }
411
412    /// Spawns a blocking task with context on a specific runtime handle.
413    ///
414    /// # Arguments
415    ///
416    /// * `task` - A closure that takes a cancellation token and returns a value
417    /// * `handle` - The runtime handle to spawn the task on
418    ///
419    /// # Returns
420    ///
421    /// A `JoinHandle` that resolves to the task's output
422    #[cfg(not(target_family = "wasm"))]
423    #[must_use]
424    pub fn spawn_blocking_on_with_token<F, T>(&self, task: F, handle: &Handle) -> JoinHandle<T>
425    where
426        F: FnOnce(CancellationToken) -> T + Send + 'static,
427        T: Send + 'static,
428    {
429        let token = self.shutdown.child_token();
430        self.tracker.spawn_blocking_on(move || task(token), handle)
431    }
432}
433
434impl Default for TaskSupervisor {
435    fn default() -> Self {
436        Self::new()
437    }
438}
439
440impl Deref for TaskSupervisor {
441    type Target = TaskTracker;
442
443    fn deref(&self) -> &Self::Target {
444        &self.tracker
445    }
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::Arc;
    use tokio::time::{sleep, Duration};

    #[tokio::test]
    async fn test_spawn_with_token_provides_token() {
        let supervisor = TaskSupervisor::new();

        let handle = supervisor.spawn_with_token(|token| async move { token.is_cancelled() });

        let result = handle.await.unwrap();
        assert!(!result);
    }

    #[tokio::test]
    #[cfg(feature = "rt")]
    async fn test_spawn_on_with_token_provides_token() {
        let supervisor = TaskSupervisor::new();
        let handle = tokio::runtime::Handle::current();

        let result = supervisor
            .spawn_on_with_token(|token| async move { token.is_cancelled() }, &handle)
            .await
            .unwrap();

        assert!(!result);
    }

    #[tokio::test]
    #[cfg(all(feature = "rt", not(target_family = "wasm")))]
    async fn test_spawn_blocking_with_token_provides_token() {
        let supervisor = TaskSupervisor::new();

        let handle = supervisor.spawn_blocking_with_token(move |token| token.is_cancelled());

        let result = handle.await.unwrap();
        assert!(!result);
    }

    #[tokio::test]
    async fn test_cancel_sets_cancelled_state() {
        let supervisor = TaskSupervisor::new();
        assert!(!supervisor.is_cancelled());

        supervisor.cancel();
        assert!(supervisor.is_cancelled());
    }

    #[tokio::test]
    async fn test_child_token_cancellation_does_not_cancel_supervisor() {
        let supervisor = TaskSupervisor::new();

        // Tasks receive a child token: cancelling it must stay local to that task.
        let handle = supervisor.spawn_with_token(|token| async move {
            token.cancel();
            token.is_cancelled()
        });

        assert!(handle.await.unwrap());
        assert!(!supervisor.is_cancelled());
    }

    #[tokio::test]
    async fn test_cancel_propagates_to_all_tasks() {
        let supervisor = TaskSupervisor::new();
        let count = Arc::new(AtomicUsize::new(0));

        let handles: Vec<_> = (0..3)
            .map(|_| {
                let count = Arc::clone(&count);
                supervisor.spawn_with_token(|token| async move {
                    token.cancelled().await;
                    count.fetch_add(1, Ordering::SeqCst);
                })
            })
            .collect();

        // Let the tasks start waiting so cancellation reaches already-suspended waiters.
        sleep(Duration::from_millis(50)).await;
        supervisor.cancel();

        // Await the tasks instead of sleeping for a fixed time, which is racy under load.
        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_shutdown_cancels_and_waits() {
        let supervisor = TaskSupervisor::new();
        let task_finished = Arc::new(AtomicBool::new(false));
        let task_finished_clone = task_finished.clone();

        let _ = supervisor.spawn_with_token(|_token| async move {
            sleep(Duration::from_millis(100)).await;
            task_finished_clone.store(true, Ordering::SeqCst);
        });

        assert!(!supervisor.is_cancelled());
        assert!(!supervisor.is_closed());

        supervisor.shutdown().await;

        assert!(supervisor.is_cancelled());
        assert!(supervisor.is_closed());
        assert!(task_finished.load(Ordering::SeqCst));
        assert_eq!(supervisor.len(), 0);
    }

    #[tokio::test]
    async fn test_shutdown_with_timeout_completes_in_time() {
        let supervisor = TaskSupervisor::new();

        let _ = supervisor.spawn_with_token(|_token| async move {
            sleep(Duration::from_millis(50)).await;
        });

        let result = supervisor
            .shutdown_with_timeout(Duration::from_secs(1))
            .await;
        assert!(result.is_ok());
        assert!(supervisor.is_cancelled());
        assert!(supervisor.is_closed());
    }

    #[tokio::test]
    async fn test_shutdown_with_timeout_times_out() {
        let supervisor = TaskSupervisor::new();

        // This task ignores cancellation, so shutdown cannot finish in time.
        let _ = supervisor.tracker().spawn(async {
            sleep(Duration::from_secs(10)).await;
        });

        let result = supervisor
            .shutdown_with_timeout(Duration::from_millis(50))
            .await;
        assert!(result.is_err());

        // Close and cancel happen before waiting, so they stay in effect after the timeout.
        assert!(supervisor.is_cancelled());
        assert!(supervisor.is_closed());
    }

    #[tokio::test]
    async fn test_cooperative_cancellation_in_loop() {
        let supervisor = TaskSupervisor::new();
        let iterations = Arc::new(AtomicUsize::new(0));
        let iterations_clone = iterations.clone();

        let handle = supervisor.spawn_with_token(|token| async move {
            while !token.is_cancelled() {
                iterations_clone.fetch_add(1, Ordering::SeqCst);
                sleep(Duration::from_millis(10)).await;
            }
        });

        sleep(Duration::from_millis(55)).await;
        supervisor.cancel();

        // The loop must notice cancellation and return on its own. Bound the wait so a
        // regression fails the test instead of hanging it.
        tokio::time::timeout(Duration::from_secs(1), handle)
            .await
            .expect("task should exit soon after cancellation")
            .unwrap();

        let count = iterations.load(Ordering::SeqCst);
        assert!(count >= 1, "loop should run before cancellation, ran {count} times");
    }

    #[tokio::test]
    async fn test_spawn_with_cancel_completes() {
        let supervisor = TaskSupervisor::new();

        let handle = supervisor.spawn_with_cancel(|| async move {
            sleep(Duration::from_millis(30)).await;
            42
        });

        assert_eq!(handle.await.unwrap(), CancelOutcome::Completed(42));
    }

    #[tokio::test]
    async fn test_spawn_with_cancel_reports_cancellation() {
        let supervisor = TaskSupervisor::new();

        let handle = supervisor.spawn_with_cancel(|| async move {
            loop {
                sleep(Duration::from_millis(10)).await;
            }
        });

        sleep(Duration::from_millis(35)).await;
        supervisor.cancel();

        assert!(
            matches!(handle.await.unwrap(), CancelOutcome::Cancelled),
            "task should have been cancelled"
        );
    }

    #[tokio::test]
    async fn test_deref_allows_direct_tracker_access() {
        let supervisor = TaskSupervisor::new();

        // `spawn`, `close` and `reopen` have no inherent counterpart on `TaskSupervisor`,
        // so these calls can only resolve through `Deref<Target = TaskTracker>`.
        let handle = supervisor.spawn(async {
            sleep(Duration::from_millis(50)).await;
            42
        });
        assert_eq!(supervisor.tracker().len(), 1);

        supervisor.close();
        assert!(supervisor.tracker().is_closed());
        supervisor.reopen();
        assert!(!supervisor.tracker().is_closed());

        let result = handle.await.unwrap();
        assert_eq!(result, 42);
    }
}