tokio_context/task.rs
1use std::{future::Future, time::Duration};
2use tokio::sync::broadcast::Sender;
3use tokio::{sync::broadcast, time::Instant};
4
5/// Handles spawning tasks which can also be cancelled by calling `cancel` on the task controller.
6/// If a [`std::time::Duration`] is supplied using the
7/// [`with_timeout`](fn@TaskController::with_timeout) constructor, then any tasks spawned by the
8/// TaskController will automatically be cancelled after the supplied duration has elapsed.
9///
10/// This provides a different API from Context for the same end result. It's nicer to use when you
11/// don't need child futures to gracefully shutdown. In cases that you do require graceful shutdown
12/// of child futures, you will need to pass a Context down, and incorporate the context into normal
13/// program flow for the child function so that they can react to it as needed and perform custom
14/// asynchronous cleanup logic.
15///
16/// # Examples
17///
18/// ```rust
19/// use std::time::Duration;
20/// use tokio::time;
21/// use tokio_context::task::TaskController;
22///
23/// async fn task_that_takes_too_long() {
24/// time::sleep(time::Duration::from_secs(60)).await;
25/// println!("done");
26/// }
27///
28/// #[tokio::main]
29/// async fn main() {
30/// let mut controller = TaskController::new();
31///
32/// let mut join_handles = vec![];
33///
34/// for i in 0..10 {
35/// let handle = controller.spawn(async { task_that_takes_too_long().await });
36/// join_handles.push(handle);
37/// }
38///
39/// // Will cancel all spawned contexts.
40/// controller.cancel();
41///
42/// // Now all join handles should gracefully close.
43/// for join in join_handles {
44/// join.await.unwrap();
45/// }
46/// }
47/// ```
48pub struct TaskController {
49 timeout: Option<Instant>,
50 cancel_sender: Sender<()>,
51}
52
53impl TaskController {
54 /// Call cancel() to cancel any tasks spawned by this TaskController. You can also simply drop
55 /// the TaskController to achieve the same result.
56 pub fn cancel(self) {}
57
58 /// Constructs a new TaskController, which can be used to spawn tasks. Tasks spawned from the
59 /// task controller will be cancelled if `cancel()` gets called.
60 pub fn new() -> TaskController {
61 let (tx, _) = broadcast::channel(1);
62 TaskController {
63 timeout: None,
64 cancel_sender: tx,
65 }
66 }
67
68 /// Constructs a new TaskController, which can be used to spawn tasks. Tasks spawned from the
69 /// task controller will be cancelled if `cancel()` gets called. They will also be cancelled if
70 /// a supplied timeout elapses.
71 pub fn with_timeout(timeout: Duration) -> TaskController {
72 let (tx, _) = broadcast::channel(1);
73 TaskController {
74 timeout: Some(Instant::now() + timeout),
75 cancel_sender: tx,
76 }
77 }
78
79 /// Spawns tasks using an identical API to tokio::task::spawn. Tasks spawned from this
80 /// TaskController will obey the optional timeout that may have been supplied during
81 /// construction of the TaskController. They will also be cancelled if `cancel()` is ever
82 /// called. Returns a JoinHandle from the internally generated task.
83 pub fn spawn<T>(&mut self, future: T) -> tokio::task::JoinHandle<Option<T::Output>>
84 where
85 T: Future + Send + 'static,
86 T::Output: Send + 'static,
87 {
88 let mut rx = self.cancel_sender.subscribe();
89 if let Some(instant) = self.timeout {
90 tokio::task::spawn(async move {
91 tokio::select! {
92 res = future => Some(res),
93 _ = rx.recv() => None,
94 _ = tokio::time::sleep_until(instant) => None,
95 }
96 })
97 } else {
98 tokio::task::spawn(async move {
99 tokio::select! {
100 res = future => Some(res),
101 _ = rx.recv() => None,
102 }
103 })
104 }
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Cancelling the controller must stop a long-running task almost immediately, and the
    /// task's join handle must report `None` because the future never completed.
    #[tokio::test]
    async fn cancel_handle_cancels_task() {
        let mut controller = TaskController::new();
        let join = controller.spawn(async { tokio::time::sleep(Duration::from_secs(60)).await });
        controller.cancel();

        tokio::select! {
            res = join => {
                let output = res.expect("spawned task should not panic");
                assert!(output.is_none(), "cancelled task should yield None");
            }
            _ = tokio::time::sleep(Duration::from_millis(1)) => {
                panic!("task was not cancelled after cancel() was called");
            }
        }
    }

    /// A controller built with a timeout must cancel its tasks once that timeout elapses, and
    /// the task's join handle must report `None`.
    #[tokio::test]
    async fn duration_cancels_task() {
        let mut controller = TaskController::with_timeout(Duration::from_millis(10));
        let join = controller.spawn(async { tokio::time::sleep(Duration::from_secs(60)).await });

        tokio::select! {
            res = join => {
                let output = res.expect("spawned task should not panic");
                assert!(output.is_none(), "timed-out task should yield None");
            }
            _ = tokio::time::sleep(Duration::from_millis(15)) => {
                panic!("task was not cancelled after the timeout elapsed");
            }
        }
    }
}
135}