tokio_task_tracker/
lib.rs

1//! tokio-task-tracker is a simple graceful shutdown solution for tokio.
2//!
3//! The basic idea is to use a `TaskSpawner` to create `TaskTracker` objects, and hold
4//! on to them in spawned tasks. Inside the task, you can await `tracker.cancelled()`
5//! to wait for the task to be cancelled.
6//!
7//! The `TaskWaiter` can be used to wait for an interrupt and then wait for all
8//! `TaskTracker`s to be dropped.
9//!
10//! # Examples
11//!
12//! ```no_run
13//! # use std::time::Duration;
14//! #
15//! #[tokio::main(flavor = "current_thread")]
16//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
17//!     let (spawner, waiter) = tokio_task_tracker::new();
18//!
19//!     // Start a task
20//!     spawner.spawn(|tracker| async move {
21//!         tokio::select! {
22//!             _ = tracker.cancelled() => {
23//!                 // The token was cancelled, task should shut down.
24//!             }
25//!             _ = tokio::time::sleep(Duration::from_secs(9999)) => {
26//!                 // Long work has completed
27//!             }
28//!         }
29//!     });
30//!
31//!     // Wait for someone to hit ctrl-c, then wait for all tasks to complete.
32//!     // If tasks don't complete within 5 seconds, we'll quit anyway.
33//!     waiter.wait_for_shutdown(Duration::from_secs(5)).await?;
34//!
35//!     Ok(())
36//! }
37//! ```
38//!
39//! If you do not wish to allow a task to be aborted, you still need to make sure
40//! the task captures the tracker, because TaskWaiter will wait for all trackers to be dropped:
41//!
42//! ```no_run
43//! # use std::time::Duration;
44//! #
45//! # #[tokio::main(flavor = "current_thread")]
46//! # async fn main() {
47//! #     let (spawner, waiter) = tokio_task_tracker::new();
48//! #
49//!     // Start a task
50//!     spawner.spawn(|tracker| async move {
51//!         // Move the tracker into the task.
52//!         let _tracker = tracker;
53//!
54//!         // Do some work that we don't want to abort.
55//!         tokio::time::sleep(Duration::from_secs(9999)).await;
56//!     });
57//!
58//! # }
59//! ```
60//!
61//! You can also create a tracker via the `task` method:
62//!
63//! ```no_run
64//! # use std::time::Duration;
65//! #
66//! # #[tokio::main(flavor = "current_thread")]
67//! # async fn main() {
68//! #     let (spawner, waiter) = tokio_task_tracker::new();
69//! #
70//!     // Start a task
71//!     let tracker = spawner.task();
72//!     tokio::task::spawn(async move {
73//!         // Move the tracker into the task.
74//!         let _tracker = tracker;
75//!
76//!         // ...
77//!     });
78//!
79//! # }
80//! ```
81//!
82//! Trackers can be used to spawn subtasks via `tracker.subtask()` or
83//! `tracker.spawn()`.
84
85use std::{
86    future::Future,
87    sync::{Arc, Mutex},
88    time::Duration,
89};
90
91use shutdown::wait_for_shutdown_signal;
92use tokio::{select, sync::mpsc, task::JoinHandle};
93use tokio_util::sync::CancellationToken;
94
95mod shutdown;
96
97/// Builder is used to create a TaskSpawner and TaskWaiter.
98pub struct Builder {
99    token: Option<CancellationToken>,
100}
101
102/// TaskSpawner is used to spawn new task trackers.
103#[derive(Clone)]
104pub struct TaskSpawner {
105    token: CancellationToken,
106    stop_tx: Arc<Mutex<Option<mpsc::Sender<()>>>>,
107}
108
109/// TaskWaiter is used to wait until all task trackers have been dropped.
110pub struct TaskWaiter {
111    token: CancellationToken,
112    /// Shared stop_tx is shared between all TaskSpawners and the TaskWaiter, so that
113    /// when we call TaskWaiter::wait() we can drop the tx from all spawners.
114    stop_tx: Arc<Mutex<Option<mpsc::Sender<()>>>>,
115    stop_rx: mpsc::Receiver<()>,
116}
117
118/// A TaskTracker is used both as a token to keep track of active tasks, and
119/// as a cancellation token to check to see if the current task should quit.
120#[derive(Clone)]
121pub struct TaskTracker {
122    token: CancellationToken,
123    // Hang on to an instance of tx. We do this so we can know when all tasks
124    // have been completed.
125    _stop_tx: Option<mpsc::Sender<()>>,
126}
127
/// Errors returned while waiting for tasks to shut down.
///
/// The enum carries no data, so it is cheap to copy and supports full
/// equality comparison.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// Returned when we timeout waiting for all tasks to shut down.
    Timeout,
    /// Returned when we cannot bind to the interrupt/terminate signals.
    CouldNotBindInterrupt,
    /// Returned when we were waiting for graceful shutdown, but received a
    /// second interrupt signal.
    ShutdownEarly,
}

impl std::error::Error for Error {}

impl std::fmt::Display for Error {
    /// Write the human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive match on purpose: adding a variant must force a message
        // to be chosen for it.
        let message = match self {
            Error::Timeout => "Not all tasks finished before timeout",
            Error::CouldNotBindInterrupt => "Could not bind interrupt handler",
            Error::ShutdownEarly => "Skipping graceful shutdown due to second interrupt",
        };
        f.write_str(message)
    }
}
150
151/// Create a new TaskSpawner and TaskWaiter.
152pub fn new() -> (TaskSpawner, TaskWaiter) {
153    Builder::default().build()
154}
155
156impl Builder {
157    /// Create a new Builder.
158    pub fn new() -> Self {
159        Builder { token: None }
160    }
161
162    /// Use an existing CancellationToken for the returned TaskWaiter and TaskSpawner.
163    /// If the given token is cancelled, all associated TaskTrackers will be cancelled
164    /// as well.
165    pub fn set_cancellation_token(mut self, token: CancellationToken) -> Self {
166        self.token = Some(token);
167        self
168    }
169
170    /// Create a new TaskSpawner and TaskWaiter.
171    pub fn build(self) -> (TaskSpawner, TaskWaiter) {
172        let (stop_tx, stop_rx) = mpsc::channel(1);
173        let stop_tx = Arc::new(Mutex::new(Some(stop_tx)));
174        let token = self.token.unwrap_or(CancellationToken::new());
175
176        (
177            TaskSpawner {
178                token: token.clone(),
179                stop_tx: stop_tx.clone(),
180            },
181            TaskWaiter {
182                token,
183                stop_tx,
184                stop_rx,
185            },
186        )
187    }
188}
189
190impl Default for Builder {
191    fn default() -> Self {
192        Self::new()
193    }
194}
195
196impl TaskSpawner {
197    /// Create a new TaskTracker.
198    pub fn task(&self) -> TaskTracker {
199        TaskTracker {
200            token: self.token.clone(),
201            _stop_tx: self.stop_tx.lock().unwrap().as_ref().cloned(),
202        }
203    }
204
205    /// Spawn a task.
206    ///
207    /// The given closure will be called, passing in a task tracker.
208    pub fn spawn<T, F: FnOnce(TaskTracker) -> T>(&self, f: F) -> JoinHandle<T::Output>
209    where
210        T: Future + Send + 'static,
211        T::Output: Send + 'static,
212    {
213        let tracker = self.task();
214        tokio::task::spawn(f(tracker))
215    }
216
217    /// Notify all tasks created by this TaskSpawner that they should abort.
218    pub fn cancel(&self) {
219        self.token.cancel();
220    }
221}
222
223impl TaskWaiter {
224    /// Notify all tasks this TaskWaiter is waiting on that they should abort.
225    pub fn cancel(&self) {
226        self.token.cancel();
227    }
228
229    /// Wait for the application to be interrupted, and then gracefully shutdown
230    /// allowing a timeout for all tasks to quit.  A second interrupt will cause
231    /// an immediate shutdown.
232    ///
233    /// On Unix systems, "interrupt" means a SIGINT or SIGTERM. On all other
234    /// platforms the current implementation uses `tokio::signal::ctrl_c()`
235    /// to wait for an interrupt.
236    pub async fn wait_for_shutdown(self, timeout: Duration) -> Result<(), Error> {
237        // Wait for the ctrl-c.
238        match wait_for_shutdown_signal().await {
239            Ok(()) => {
240                // time to shut down...
241            }
242            Err(_) => return Err(Error::CouldNotBindInterrupt),
243        }
244
245        // Let tasks know they should shut down.
246        self.token.cancel();
247
248        // Wait for everything to finish.
249        select! {
250            res = self.wait_with_timeout(timeout) => res,
251            _ = wait_for_shutdown_signal() => Err(Error::ShutdownEarly),
252        }
253    }
254
255    /// Wait for all tasks to finish.  If tasks do not finish before the timeout,
256    /// `Error::Timeout` will be returned.
257    pub async fn wait_with_timeout(self, timeout: Duration) -> Result<(), Error> {
258        // Wait for all tasks to be dropped.
259        tokio::time::timeout(timeout, self.wait())
260            .await
261            .map_err(|_| Error::Timeout {})?;
262
263        Ok(())
264    }
265
266    /// Wait for all tasks to finish.
267    pub async fn wait(mut self) {
268        // Drop the tx half of the channel.
269        drop(self.stop_tx.lock().unwrap().take());
270
271        // Wait for all tasks to be dropped.
272        let _ = self.stop_rx.recv().await;
273    }
274}
275
276impl TaskTracker {
277    /// Create a new subtask from this TaskTracker.
278    pub fn subtask(&self) -> Self {
279        self.clone()
280    }
281
282    /// Spawn a subtask.
283    ///
284    /// The given closure will be called, passing in a task tracker.
285    pub fn spawn<T, F: FnOnce(TaskTracker) -> T>(&self, f: F) -> JoinHandle<T::Output>
286    where
287        T: Future + Send + 'static,
288        T::Output: Send + 'static,
289    {
290        let tracker = self.subtask();
291        tokio::task::spawn(f(tracker))
292    }
293
294    /// Check to see if this task has been cancelled.
295    pub async fn cancelled(&self) {
296        self.token.cancelled().await;
297    }
298
299    /// Returns true if this token has been cancelled.
300    pub fn is_cancelled(&self) -> bool {
301        self.token.is_cancelled()
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use std::{
        sync::atomic::{AtomicBool, Ordering},
        time::Duration,
    };

    /// Result type shared by the fallible tests below.
    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Spawn a tracked task that races cancellation against a sleep of
    /// `work`. `done` is set only if the sleep finishes first.
    fn spawn_cancellable(spawner: &TaskSpawner, done: &Arc<AtomicBool>, work: Duration) {
        let done = Arc::clone(done);
        spawner.spawn(move |tracker| async move {
            tokio::select! {
                _ = tracker.cancelled() => {
                    // The token was cancelled, task should shut down.
                }
                _ = tokio::time::sleep(work) => {
                    // The work ran to completion.
                    done.store(true, Ordering::SeqCst);
                }
            }
        });
    }

    #[tokio::test]
    async fn tracker_should_be_cancelled() {
        let (spawner, waiter) = super::new();
        let task = spawner.task();

        // Cancelling through the waiter must be visible on the tracker.
        waiter.cancel();
        assert!(task.is_cancelled());
    }

    #[tokio::test]
    async fn should_work_with_existing_cancellation_token() {
        let token = CancellationToken::new();
        let builder = super::Builder::new().set_cancellation_token(token.clone());
        let (spawner, _) = builder.build();
        let task = spawner.task();

        // Cancelling the token should cancel the task.
        token.cancel();
        assert!(task.is_cancelled());
    }

    #[tokio::test]
    async fn should_wait_for_tasks_to_complete() -> TestResult {
        let (spawner, waiter) = super::new();
        let done = Arc::new(AtomicBool::new(false));

        // Start a short task that is never cancelled.
        spawn_cancellable(&spawner, &done, Duration::from_millis(100));

        // Wait for all tasks to complete.
        waiter.wait().await;

        // Should have completed.
        assert!(done.load(Ordering::SeqCst));

        Ok(())
    }

    #[tokio::test]
    async fn should_cancel_tasks() -> TestResult {
        let (spawner, waiter) = super::new();
        let done = Arc::new(AtomicBool::new(false));

        // Start a task that would run far longer than the test.
        spawn_cancellable(&spawner, &done, Duration::from_secs(9999));

        // Cancel the task after a short while.
        tokio::time::sleep(Duration::from_millis(100)).await;
        waiter.cancel();

        // Wait for all tasks to complete.
        waiter.wait().await;

        // The task was cancelled, so its long work must not have completed.
        assert!(!done.load(Ordering::SeqCst));

        Ok(())
    }

    #[tokio::test]
    async fn interrupt_tests() -> TestResult {
        // Interrupt tests rely on global state in shutdown.rs to simulate
        // SIGINT.  Need to run these serially.
        should_wait_for_tasks_on_interrupt().await?;
        should_stop_immediately_on_second_interrupt().await?;

        Ok(())
    }

    async fn should_wait_for_tasks_on_interrupt() -> TestResult {
        shutdown::reset_before_test();

        let (spawner, waiter) = super::new();
        let done = Arc::new(AtomicBool::new(false));

        // Start a long running task that honours cancellation.
        spawn_cancellable(&spawner, &done, Duration::from_secs(9999));

        // Send a fake shutdown signal.
        tokio::spawn(async {
            shutdown::send_shutdown().await;
        });

        // Wait for the signal, then for all tasks to complete.
        waiter.wait_for_shutdown(Duration::from_secs(10)).await?;

        // Task should have been aborted.
        assert!(!done.load(Ordering::SeqCst));

        Ok(())
    }

    async fn should_stop_immediately_on_second_interrupt() -> TestResult {
        shutdown::reset_before_test();

        let (spawner, waiter) = super::new();
        let done = Arc::new(AtomicBool::new(false));

        // Start a task that holds its tracker but ignores cancellation.
        {
            let done = Arc::clone(&done);
            spawner.spawn(|tracker| async move {
                let _tracker = tracker;

                // Long running task that can't be cancelled.
                tokio::time::sleep(Duration::from_secs(99)).await;
                done.store(true, Ordering::SeqCst);
            });
        }

        // Send two shutdown signals. The second should cause us to die immediately.
        tokio::spawn(async move {
            shutdown::send_shutdown().await;
            shutdown::send_shutdown().await;
        });

        // We shouldn't wait here, because of the second interrupt.
        let outcome = waiter.wait_for_shutdown(Duration::from_secs(99)).await;
        let err = outcome.unwrap_err();
        assert_eq!(err, Error::ShutdownEarly);

        // The task's work must not have completed.
        assert!(!done.load(Ordering::SeqCst));

        Ok(())
    }
}