tokio_task_supervisor/
lib.rs

1use std::{future::Future, time::Duration};
2
3use tokio::runtime::Handle;
4use tokio::task::{JoinHandle, LocalSet};
5use tokio_util::sync::{CancellationToken, DropGuardRef};
6
7pub use tokio_util::task::task_tracker::{
8    TaskTracker, TaskTrackerToken, TaskTrackerWaitFuture, TrackedFuture,
9};
10
/// The outcome of a task that races against cancellation.
///
/// This enum is returned (inside a `JoinHandle`) by the
/// [`spawn_with_cancel`](TaskSupervisor::spawn_with_cancel) family of methods to indicate
/// whether the task completed normally or was cancelled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CancelOutcome<T> {
    /// The task future completed before cancellation was requested.
    Completed(T),
    /// Cancellation won the race; the task future was dropped.
    Cancelled,
}

impl<T> CancelOutcome<T> {
    /// Creates a `CancelOutcome` from an optional result.
    ///
    /// # Arguments
    ///
    /// * `result` - `Some(value)` indicates completion, `None` indicates cancellation.
    ///
    /// The same conversion is also available through [`From<Option<T>>`](From).
    #[inline]
    #[must_use]
    pub fn outcome(result: Option<T>) -> CancelOutcome<T> {
        match result {
            Some(value) => CancelOutcome::Completed(value),
            None => CancelOutcome::Cancelled,
        }
    }

    /// Returns `true` if the task completed before cancellation.
    #[inline]
    #[must_use]
    pub fn is_completed(&self) -> bool {
        matches!(self, CancelOutcome::Completed(_))
    }

    /// Returns `true` if cancellation won the race.
    #[inline]
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        matches!(self, CancelOutcome::Cancelled)
    }

    /// Converts the outcome back into an `Option`: `Some(value)` on completion, `None` on
    /// cancellation. This is the inverse of [`outcome`](Self::outcome).
    #[inline]
    #[must_use]
    pub fn into_option(self) -> Option<T> {
        match self {
            CancelOutcome::Completed(value) => Some(value),
            CancelOutcome::Cancelled => None,
        }
    }
}

impl<T> From<Option<T>> for CancelOutcome<T> {
    /// Equivalent to [`CancelOutcome::outcome`].
    #[inline]
    fn from(result: Option<T>) -> Self {
        Self::outcome(result)
    }
}

impl<T> From<CancelOutcome<T>> for Option<T> {
    /// Equivalent to [`CancelOutcome::into_option`].
    #[inline]
    fn from(outcome: CancelOutcome<T>) -> Self {
        outcome.into_option()
    }
}
37
38/// Manages a collection of asynchronous tasks and coordinates their shutdown.
39///
40/// `TaskSupervisor` wraps [`TaskTracker`] to keep count of outstanding tasks while also exposing a
41/// process-wide [`CancellationToken`] that can be used to request cooperative shutdown.
42///
43/// # Examples
44///
45/// ```rust
46/// use tokio_task_supervisor::TaskSupervisor;
47/// use tokio::time::{sleep, Duration};
48///
49/// #[tokio::main]
50/// async fn main() {
51///     let supervisor = TaskSupervisor::new();
52///     
53///     // Spawn a task that cooperatively handles cancellation
54///     let handle = supervisor.spawn_with_token(|token| async move {
55///         loop {
56///             if token.is_cancelled() {
57///                 break;
58///             }
59///             // Do work...
60///             sleep(Duration::from_millis(100)).await;
61///         }
62///     });
63///     
64///     // Later, request shutdown
65///     supervisor.shutdown().await;
66/// }
67/// ```
68#[derive(Clone)]
69pub struct TaskSupervisor {
70    tracker: TaskTracker,
71    shutdown: CancellationToken,
72}
73
74impl TaskSupervisor {
75    // === Construction ===
76
77    /// Creates a new task manager.
78    #[must_use]
79    pub fn new() -> Self {
80        Self {
81            tracker: TaskTracker::new(),
82            shutdown: CancellationToken::new(),
83        }
84    }
85
86    // === Accessors ===
87
88    /// Returns a reference to the underlying [`TaskTracker`].
89    #[inline]
90    pub fn tracker(&self) -> &TaskTracker {
91        &self.tracker
92    }
93
94    /// Returns a clone of the shared cancellation token.
95    #[inline]
96    pub fn token(&self) -> CancellationToken {
97        self.shutdown.clone()
98    }
99
100    /// Returns a guard that cancels the shutdown token when dropped.
101    #[must_use]
102    #[inline]
103    pub fn cancel_on_drop(&self) -> DropGuardRef<'_> {
104        self.shutdown.drop_guard_ref()
105    }
106
107    // === State Queries ===
108
109    /// Returns `true` if the shutdown token has been cancelled.
110    #[inline]
111    pub fn is_cancelled(&self) -> bool {
112        self.shutdown.is_cancelled()
113    }
114
115    /// Returns `true` if the task tracker is closed.
116    #[inline]
117    pub fn is_closed(&self) -> bool {
118        self.tracker.is_closed()
119    }
120
121    /// Returns the number of outstanding tasks.
122    #[inline]
123    pub fn len(&self) -> usize {
124        self.tracker.len()
125    }
126
127    /// Returns a future that completes when all tasks finish.
128    #[inline]
129    pub fn wait(&self) -> TaskTrackerWaitFuture<'_> {
130        self.tracker.wait()
131    }
132
133    // === Control Operations ===
134
135    /// Cancels the shared shutdown token.
136    ///
137    /// Tasks spawned through the managed API can observe this and exit cooperatively.
138    #[inline]
139    pub fn cancel(&self) {
140        self.shutdown.cancel();
141    }
142
143    /// Initiates graceful shutdown by closing the tracker and cancelling all tasks.
144    ///
145    /// This method will:
146    /// 1. Close the task tracker to prevent new tasks from being spawned
147    /// 2. Cancel the shutdown token to signal all existing tasks
148    /// 3. Wait for all tasks to complete
149    pub async fn shutdown(&self) {
150        self.tracker.close();
151        self.shutdown.cancel();
152        self.tracker.wait().await;
153    }
154
155    /// Initiates graceful shutdown with a timeout.
156    ///
157    /// # Arguments
158    ///
159    /// * `timeout` - Maximum time to wait for shutdown to complete
160    ///
161    /// # Returns
162    ///
163    /// * `Ok(())` if shutdown completed within the timeout
164    /// * `Err(Elapsed)` if the timeout was exceeded
165    #[inline]
166    pub async fn shutdown_with_timeout(
167        &self,
168        timeout: Duration,
169    ) -> Result<(), tokio::time::error::Elapsed> {
170        tokio::time::timeout(timeout, self.shutdown()).await
171    }
172
173    // === Spawn Methods with Cancellation Race ===
174
175    /// Spawns a task that races against the shared cancellation token.
176    ///
177    /// The returned future resolves with [`CancelOutcome`], indicating whether the task finished
178    /// normally or was cancelled. When cancellation wins the race, the task future is dropped, so it
179    /// should not rely on `Drop` for cleanup.
180    ///
181    /// # Arguments
182    ///
183    /// * `task` - A closure that returns the future to execute
184    ///
185    /// # Returns
186    ///
187    /// A `JoinHandle` that resolves to the task's outcome
188    #[must_use]
189    pub fn spawn_with_cancel<F, Fut>(&self, task: F) -> JoinHandle<CancelOutcome<Fut::Output>>
190    where
191        F: FnOnce() -> Fut + Send + 'static,
192        Fut: Future + Send + 'static,
193        Fut::Output: Send + 'static,
194    {
195        let token = self.token();
196        self.tracker
197            .spawn(async move { CancelOutcome::outcome(token.run_until_cancelled(task()).await) })
198    }
199
200    /// Spawns a task with cancellation handling on a specific runtime handle.
201    ///
202    /// # Arguments
203    ///
204    /// * `task` - A closure that returns the future to execute
205    /// * `handle` - The runtime handle to spawn the task on
206    ///
207    /// # Returns
208    ///
209    /// A `JoinHandle` that resolves to the task's outcome
210    #[must_use]
211    pub fn spawn_on_with_cancel<F, Fut>(
212        &self,
213        task: F,
214        handle: &Handle,
215    ) -> JoinHandle<CancelOutcome<Fut::Output>>
216    where
217        F: FnOnce() -> Fut + Send + 'static,
218        Fut: Future + Send + 'static,
219        Fut::Output: Send + 'static,
220    {
221        let token = self.token();
222        self.tracker.spawn_on(
223            async move { CancelOutcome::outcome(token.run_until_cancelled(task()).await) },
224            handle,
225        )
226    }
227
228    /// Spawns a !Send task that races against the shared cancellation token.
229    ///
230    /// # Arguments
231    ///
232    /// * `task` - A closure that returns the future to execute
233    ///
234    /// # Returns
235    ///
236    /// A `JoinHandle` that resolves to the task's outcome
237    #[must_use]
238    pub fn spawn_local_with_cancel<F, Fut>(&self, task: F) -> JoinHandle<CancelOutcome<Fut::Output>>
239    where
240        F: FnOnce() -> Fut + 'static,
241        Fut: Future + 'static,
242        Fut::Output: 'static,
243    {
244        let token = self.token();
245        self.tracker.spawn_local(async move {
246            CancelOutcome::outcome(token.run_until_cancelled(task()).await)
247        })
248    }
249
250    /// Spawns a !Send task on a [`LocalSet`] with cancellation handling.
251    ///
252    /// # Arguments
253    ///
254    /// * `task` - A closure that returns the future to execute
255    /// * `local_set` - The local set to spawn the task on
256    ///
257    /// # Returns
258    ///
259    /// A `JoinHandle` that resolves to the task's outcome
260    #[must_use]
261    pub fn spawn_local_on_with_cancel<F, Fut>(
262        &self,
263        task: F,
264        local_set: &LocalSet,
265    ) -> JoinHandle<CancelOutcome<Fut::Output>>
266    where
267        F: FnOnce() -> Fut + 'static,
268        Fut: Future + 'static,
269        Fut::Output: 'static,
270    {
271        let token = self.token();
272        self.tracker.spawn_local_on(
273            async move { CancelOutcome::outcome(token.run_until_cancelled(task()).await) },
274            local_set,
275        )
276    }
277
278    // === Spawn Methods with Token ===
279
280    /// Spawns a task that receives the shared cancellation token.
281    ///
282    /// # Arguments
283    ///
284    /// * `task` - A closure that takes a cancellation token and returns a future
285    ///
286    /// # Returns
287    ///
288    /// A `JoinHandle` that resolves to the task's output
289    #[must_use]
290    pub fn spawn_with_token<F, Fut>(&self, task: F) -> JoinHandle<Fut::Output>
291    where
292        F: FnOnce(CancellationToken) -> Fut + Send + 'static,
293        Fut: Future + Send + 'static,
294        Fut::Output: Send + 'static,
295    {
296        let token = self.shutdown.child_token();
297        self.tracker.spawn(async move { task(token).await })
298    }
299
300    /// Spawns a task with the shared cancellation token on a specific runtime handle.
301    ///
302    /// # Arguments
303    ///
304    /// * `task` - A closure that takes a cancellation token and returns a future
305    /// * `handle` - The runtime handle to spawn the task on
306    ///
307    /// # Returns
308    ///
309    /// A `JoinHandle` that resolves to the task's output
310    #[must_use]
311    pub fn spawn_on_with_token<F, Fut>(&self, task: F, handle: &Handle) -> JoinHandle<Fut::Output>
312    where
313        F: FnOnce(CancellationToken) -> Fut + Send + 'static,
314        Fut: Future + Send + 'static,
315        Fut::Output: Send + 'static,
316    {
317        let token = self.shutdown.child_token();
318        self.tracker
319            .spawn_on(async move { task(token).await }, handle)
320    }
321
322    /// Spawns a local task that receives the shared cancellation token.
323    ///
324    /// # Arguments
325    ///
326    /// * `task` - A closure that takes a cancellation token and returns a future
327    ///
328    /// # Returns
329    ///
330    /// A `JoinHandle` that resolves to the task's output
331    #[must_use]
332    pub fn spawn_local_with_token<F, Fut>(&self, task: F) -> JoinHandle<Fut::Output>
333    where
334        F: FnOnce(CancellationToken) -> Fut + 'static,
335        Fut: Future + 'static,
336        Fut::Output: 'static,
337    {
338        let token = self.shutdown.child_token();
339        self.tracker.spawn_local(async move { task(token).await })
340    }
341
342    /// Spawns a local task with a cancellation token on a specific local set.
343    ///
344    /// # Arguments
345    ///
346    /// * `task` - A closure that takes a cancellation token and returns a future
347    /// * `local_set` - The local set to spawn the task on
348    ///
349    /// # Returns
350    ///
351    /// A `JoinHandle` that resolves to the task's output
352    #[must_use]
353    pub fn spawn_local_on_with_token<F, Fut>(
354        &self,
355        task: F,
356        local_set: &LocalSet,
357    ) -> JoinHandle<Fut::Output>
358    where
359        F: FnOnce(CancellationToken) -> Fut + 'static,
360        Fut: Future + 'static,
361        Fut::Output: 'static,
362    {
363        let token = self.shutdown.child_token();
364        self.tracker
365            .spawn_local_on(async move { task(token).await }, local_set)
366    }
367
368    /// Spawns a blocking task that receives a cancellation context.
369    ///
370    /// # Arguments
371    ///
372    /// * `task` - A closure that takes a cancellation token and returns a value
373    ///
374    /// # Returns
375    ///
376    /// A `JoinHandle` that resolves to the task's output
377    #[cfg(not(target_family = "wasm"))]
378    #[must_use]
379    pub fn spawn_blocking_with_token<F, T>(&self, task: F) -> JoinHandle<T>
380    where
381        F: FnOnce(CancellationToken) -> T + Send + 'static,
382        T: Send + 'static,
383    {
384        let token = self.shutdown.child_token();
385        self.tracker.spawn_blocking(move || task(token))
386    }
387
388    /// Spawns a blocking task with context on a specific runtime handle.
389    ///
390    /// # Arguments
391    ///
392    /// * `task` - A closure that takes a cancellation token and returns a value
393    /// * `handle` - The runtime handle to spawn the task on
394    ///
395    /// # Returns
396    ///
397    /// A `JoinHandle` that resolves to the task's output
398    #[cfg(not(target_family = "wasm"))]
399    #[must_use]
400    pub fn spawn_blocking_on_with_token<F, T>(&self, task: F, handle: &Handle) -> JoinHandle<T>
401    where
402        F: FnOnce(CancellationToken) -> T + Send + 'static,
403        T: Send + 'static,
404    {
405        let token = self.shutdown.child_token();
406        self.tracker.spawn_blocking_on(move || task(token), handle)
407    }
408}
409
410impl Default for TaskSupervisor {
411    fn default() -> Self {
412        Self::new()
413    }
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::Arc;
    use tokio::time::{sleep, timeout, Duration};

    /// Upper bound for awaiting a task that is expected to finish promptly, so that a regression
    /// fails the test instead of hanging it.
    const TASK_DEADLINE: Duration = Duration::from_secs(5);

    #[tokio::test]
    async fn test_spawn_with_token_provides_token() {
        let supervisor = TaskSupervisor::new();

        let handle = supervisor.spawn_with_token(|token| async move { token.is_cancelled() });

        let result = handle.await.unwrap();
        assert!(!result);
    }

    #[tokio::test]
    async fn test_task_token_is_isolated_child() {
        let supervisor = TaskSupervisor::new();

        // Cancelling the per-task token must not cancel the supervisor.
        let handle = supervisor.spawn_with_token(|token| async move {
            token.cancel();
            token.is_cancelled()
        });

        assert!(handle.await.unwrap());
        assert!(!supervisor.is_cancelled());
    }

    // NOTE(review): gated on a crate feature named `rt`; confirm Cargo.toml declares it,
    // otherwise this test is silently compiled out.
    #[cfg(feature = "rt")]
    #[tokio::test]
    async fn test_spawn_on_with_token_provides_token() {
        let supervisor = TaskSupervisor::new();
        let handle = tokio::runtime::Handle::current();

        let result = supervisor
            .spawn_on_with_token(|token| async move { token.is_cancelled() }, &handle)
            .await
            .unwrap();

        assert!(!result);
    }

    // NOTE(review): see the `rt` feature note above.
    #[cfg(all(feature = "rt", not(target_family = "wasm")))]
    #[tokio::test]
    async fn test_spawn_blocking_with_token_provides_token() {
        let supervisor = TaskSupervisor::new();

        let handle = supervisor.spawn_blocking_with_token(move |token| token.is_cancelled());

        let result = handle.await.unwrap();
        assert!(!result);
    }

    #[tokio::test]
    async fn test_cancel_sets_cancelled_state() {
        let supervisor = TaskSupervisor::new();
        assert!(!supervisor.is_cancelled());

        supervisor.cancel();
        assert!(supervisor.is_cancelled());
    }

    #[tokio::test]
    async fn test_cancel_propagates_to_all_tasks() {
        let supervisor = TaskSupervisor::new();
        let count = Arc::new(AtomicUsize::new(0));

        let handles: Vec<_> = (0..3)
            .map(|_| {
                let count = Arc::clone(&count);
                supervisor.spawn_with_token(move |token| async move {
                    token.cancelled().await;
                    count.fetch_add(1, Ordering::SeqCst);
                })
            })
            .collect();

        // Give the tasks a chance to park on `cancelled()` first; correctness does not depend
        // on it, because each child token is created at spawn time.
        sleep(Duration::from_millis(50)).await;
        supervisor.cancel();

        // Await every task rather than sleeping, so the assertion is not timing-dependent.
        for handle in handles {
            timeout(TASK_DEADLINE, handle)
                .await
                .expect("task did not observe cancellation")
                .unwrap();
        }

        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_shutdown_cancels_and_waits() {
        let supervisor = TaskSupervisor::new();
        let task_finished = Arc::new(AtomicBool::new(false));
        let task_finished_clone = task_finished.clone();

        let _ = supervisor.spawn_with_token(|_token| async move {
            sleep(Duration::from_millis(100)).await;
            task_finished_clone.store(true, Ordering::SeqCst);
        });

        assert!(!supervisor.is_cancelled());
        assert!(!supervisor.is_closed());

        supervisor.shutdown().await;

        assert!(supervisor.is_cancelled());
        assert!(supervisor.is_closed());
        assert!(task_finished.load(Ordering::SeqCst));
        assert_eq!(supervisor.len(), 0);
    }

    #[tokio::test]
    async fn test_shutdown_with_timeout_completes_in_time() {
        let supervisor = TaskSupervisor::new();

        let _ = supervisor.spawn_with_token(|_token| async move {
            sleep(Duration::from_millis(50)).await;
        });

        let result = supervisor
            .shutdown_with_timeout(Duration::from_secs(1))
            .await;
        assert!(result.is_ok());
        assert!(supervisor.is_cancelled());
        assert!(supervisor.is_closed());
    }

    #[tokio::test]
    async fn test_shutdown_with_timeout_times_out() {
        let supervisor = TaskSupervisor::new();

        // Spawned directly on the tracker, so it never sees the token and outlives the timeout.
        let _ = supervisor.tracker().spawn(async {
            sleep(Duration::from_secs(10)).await;
        });

        let result = supervisor
            .shutdown_with_timeout(Duration::from_millis(50))
            .await;
        assert!(result.is_err());

        // The shutdown signal must still have been delivered, and the task is left running.
        assert!(supervisor.is_cancelled());
        assert!(supervisor.is_closed());
        assert_eq!(supervisor.len(), 1);
    }

    #[tokio::test]
    async fn test_cooperative_cancellation_in_loop() {
        let supervisor = TaskSupervisor::new();
        let iterations = Arc::new(AtomicUsize::new(0));
        let iterations_clone = iterations.clone();

        let handle = supervisor.spawn_with_token(|token| async move {
            loop {
                if token.is_cancelled() {
                    break;
                }
                iterations_clone.fetch_add(1, Ordering::SeqCst);
                sleep(Duration::from_millis(10)).await;
            }
        });

        sleep(Duration::from_millis(55)).await;
        supervisor.cancel();

        // Awaiting the handle proves the loop exited; a hang here means cancellation was missed.
        timeout(TASK_DEADLINE, handle)
            .await
            .expect("loop did not observe cancellation")
            .unwrap();

        // The exact count is scheduler-dependent; only require that the loop actually ran.
        assert!(iterations.load(Ordering::SeqCst) >= 1);
    }

    #[tokio::test]
    async fn test_spawn_with_cancel_completes() {
        let supervisor = TaskSupervisor::new();

        let handle = supervisor.spawn_with_cancel(|| async move {
            sleep(Duration::from_millis(30)).await;
            42
        });

        match handle.await.unwrap() {
            CancelOutcome::Completed(value) => assert_eq!(value, 42),
            CancelOutcome::Cancelled => panic!("task should have completed"),
        }
    }

    #[tokio::test]
    async fn test_spawn_with_cancel_reports_cancellation() {
        let supervisor = TaskSupervisor::new();

        let handle = supervisor.spawn_with_cancel(|| async move {
            loop {
                sleep(Duration::from_millis(10)).await;
            }
        });

        sleep(Duration::from_millis(35)).await;
        supervisor.cancel();

        let outcome = timeout(TASK_DEADLINE, handle)
            .await
            .expect("cancellation did not win the race")
            .unwrap();
        match outcome {
            CancelOutcome::Completed(_) => panic!("task should have been cancelled"),
            CancelOutcome::Cancelled => {}
        }
    }
}