tokio_context/
task.rs

1use std::{future::Future, time::Duration};
2use tokio::sync::broadcast::Sender;
3use tokio::{sync::broadcast, time::Instant};
4
5/// Handles spawning tasks which can also be cancelled by calling `cancel` on the task controller.
6/// If a [`std::time::Duration`] is supplied using the
7/// [`with_timeout`](fn@TaskController::with_timeout) constructor, then any tasks spawned by the
8/// TaskController will automatically be cancelled after the supplied duration has elapsed.
9///
10/// This provides a different API from Context for the same end result. It's nicer to use when you
11/// don't need child futures to gracefully shutdown. In cases that you do require graceful shutdown
12/// of child futures, you will need to pass a Context down, and incorporate the context into normal
13/// program flow for the child function so that they can react to it as needed and perform custom
14/// asynchronous cleanup logic.
15///
16/// # Examples
17///
18/// ```rust
19/// use std::time::Duration;
20/// use tokio::time;
21/// use tokio_context::task::TaskController;
22///
23/// async fn task_that_takes_too_long() {
24///     time::sleep(time::Duration::from_secs(60)).await;
25///     println!("done");
26/// }
27///
28/// #[tokio::main]
29/// async fn main() {
30///     let mut controller = TaskController::new();
31///
32///     let mut join_handles = vec![];
33///
34///     for i in 0..10 {
35///         let handle = controller.spawn(async { task_that_takes_too_long().await });
36///         join_handles.push(handle);
37///     }
38///
39///     // Will cancel all spawned contexts.
40///     controller.cancel();
41///
42///     // Now all join handles should gracefully close.
43///     for join in join_handles {
44///         join.await.unwrap();
45///     }
46/// }
47/// ```
48pub struct TaskController {
49    timeout: Option<Instant>,
50    cancel_sender: Sender<()>,
51}
52
53impl TaskController {
54    /// Call cancel() to cancel any tasks spawned by this TaskController. You can also simply drop
55    /// the TaskController to achieve the same result.
56    pub fn cancel(self) {}
57
58    /// Constructs a new TaskController, which can be used to spawn tasks. Tasks spawned from the
59    /// task controller will be cancelled if `cancel()` gets called.
60    pub fn new() -> TaskController {
61        let (tx, _) = broadcast::channel(1);
62        TaskController {
63            timeout: None,
64            cancel_sender: tx,
65        }
66    }
67
68    /// Constructs a new TaskController, which can be used to spawn tasks. Tasks spawned from the
69    /// task controller will be cancelled if `cancel()` gets called. They will also be cancelled if
70    /// a supplied timeout elapses.
71    pub fn with_timeout(timeout: Duration) -> TaskController {
72        let (tx, _) = broadcast::channel(1);
73        TaskController {
74            timeout: Some(Instant::now() + timeout),
75            cancel_sender: tx,
76        }
77    }
78
79    /// Spawns tasks using an identical API to tokio::task::spawn. Tasks spawned from this
80    /// TaskController will obey the optional timeout that may have been supplied during
81    /// construction of the TaskController. They will also be cancelled if `cancel()` is ever
82    /// called. Returns a JoinHandle from the internally generated task.
83    pub fn spawn<T>(&mut self, future: T) -> tokio::task::JoinHandle<Option<T::Output>>
84    where
85        T: Future + Send + 'static,
86        T::Output: Send + 'static,
87    {
88        let mut rx = self.cancel_sender.subscribe();
89        if let Some(instant) = self.timeout {
90            tokio::task::spawn(async move {
91                tokio::select! {
92                    res = future => Some(res),
93                    _ = rx.recv() => None,
94                    _ = tokio::time::sleep_until(instant) => None,
95                }
96            })
97        } else {
98            tokio::task::spawn(async move {
99                tokio::select! {
100                    res = future => Some(res),
101                    _ = rx.recv() => None,
102                }
103            })
104        }
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Upper bound on how long a test waits for a join handle. It is generous so that slow CI
    /// machines do not cause spurious failures, yet far shorter than the 60 s the spawned futures
    /// would sleep if cancellation were broken.
    const JOIN_LIMIT: Duration = Duration::from_secs(5);

    #[tokio::test]
    async fn cancel_handle_cancels_task() {
        let mut controller = TaskController::new();
        let join = controller.spawn(async { tokio::time::sleep(Duration::from_secs(60)).await });
        controller.cancel();

        let output = tokio::time::timeout(JOIN_LIMIT, join)
            .await
            .expect("cancelled task did not finish in time")
            .expect("cancelled task panicked");
        // `None` proves the task was cancelled rather than running to completion.
        assert_eq!(output, None);
    }

    #[tokio::test]
    async fn duration_cancels_task() {
        let start = tokio::time::Instant::now();
        let mut controller = TaskController::with_timeout(Duration::from_millis(10));
        let join = controller.spawn(async { tokio::time::sleep(Duration::from_secs(60)).await });

        let output = tokio::time::timeout(JOIN_LIMIT, join)
            .await
            .expect("timed-out task did not finish in time")
            .expect("timed-out task panicked");
        assert_eq!(output, None);
        // The deadline must not fire early.
        assert!(start.elapsed() >= Duration::from_millis(10));
    }

    #[tokio::test]
    async fn completed_task_returns_output() {
        // `controller` stays alive for the whole test, so only the future itself can finish.
        let mut controller = TaskController::new();
        let join = controller.spawn(async { 42 });

        assert_eq!(join.await.expect("task panicked"), Some(42));
    }
}