1use std::{future::Future, time::Duration};
22
23use futures::future::FutureExt;
24use tokio::task::JoinHandle;
25use tokio_util::{sync::CancellationToken, task::TaskTracker};
26use zenoh_core::{ResolveFuture, Wait};
27use zenoh_runtime::ZRuntime;
28
29#[derive(Clone)]
30pub struct TaskController {
31 tracker: TaskTracker,
32 token: CancellationToken,
33}
34
35impl Default for TaskController {
36 fn default() -> Self {
37 TaskController {
38 tracker: TaskTracker::new(),
39 token: CancellationToken::new(),
40 }
41 }
42}
43
44impl TaskController {
45 pub fn into_abortable<'a, F, T>(&self, future: F) -> impl Future<Output = Option<T>> + Send + 'a
47 where
48 F: Future<Output = T> + Send + 'a,
49 T: Send + 'static,
50 {
51 self.token.child_token().run_until_cancelled_owned(future)
52 }
53
54 pub fn spawn_abortable<F, T>(&self, future: F) -> JoinHandle<Option<T>>
57 where
58 F: Future<Output = T> + Send + 'static,
59 T: Send + 'static,
60 {
61 #[cfg(feature = "tracing-instrument")]
62 let future = tracing::Instrument::instrument(future, tracing::Span::current());
63
64 self.tracker.spawn(self.into_abortable(future))
65 }
66
67 pub fn spawn_abortable_with_rt<F, T>(&self, rt: ZRuntime, future: F) -> JoinHandle<Option<T>>
69 where
70 F: Future<Output = T> + Send + 'static,
71 T: Send + 'static,
72 {
73 #[cfg(feature = "tracing-instrument")]
74 let future = tracing::Instrument::instrument(future, tracing::Span::current());
75
76 self.tracker.spawn_on(self.into_abortable(future), &rt)
77 }
78
79 pub fn get_cancellation_token(&self) -> CancellationToken {
80 self.token.child_token()
81 }
82
83 pub fn spawn<F, T>(&self, future: F) -> JoinHandle<T>
88 where
89 F: Future<Output = T> + Send + 'static,
90 T: Send + 'static,
91 {
92 #[cfg(feature = "tracing-instrument")]
93 let future = tracing::Instrument::instrument(future, tracing::Span::current());
94
95 self.tracker.spawn(future)
96 }
97
98 pub fn spawn_with_rt<F, T>(&self, rt: ZRuntime, future: F) -> JoinHandle<T>
103 where
104 F: Future<Output = T> + Send + 'static,
105 T: Send + 'static,
106 {
107 #[cfg(feature = "tracing-instrument")]
108 let future = tracing::Instrument::instrument(future, tracing::Span::current());
109
110 self.tracker.spawn_on(future, &rt)
111 }
112
113 pub fn terminate_all(&self, timeout: Duration) -> usize {
121 ResolveFuture::new(async move {
122 if tokio::time::timeout(timeout, self.terminate_all_async())
123 .await
124 .is_err()
125 {
126 tracing::error!("Failed to terminate {} tasks", self.tracker.len());
127 }
128 self.tracker.len()
129 })
130 .wait()
131 }
132
133 pub async fn terminate_all_async(&self) {
135 self.tracker.close();
136 self.token.cancel();
137 self.tracker.wait().await
138 }
139}
140
141pub struct TerminatableTask {
142 handle: Option<JoinHandle<()>>,
143 token: CancellationToken,
144}
145
146impl Drop for TerminatableTask {
147 fn drop(&mut self) {
148 self.terminate(std::time::Duration::from_secs(10));
149 }
150}
151
152impl TerminatableTask {
153 pub fn create_cancellation_token() -> CancellationToken {
154 CancellationToken::new()
155 }
156
157 pub fn spawn<F, T>(rt: ZRuntime, future: F, token: CancellationToken) -> TerminatableTask
160 where
161 F: Future<Output = T> + Send + 'static,
162 T: Send + 'static,
163 {
164 TerminatableTask {
165 handle: Some(rt.spawn(future.map(|_f| ()))),
166 token,
167 }
168 }
169
170 pub fn spawn_abortable<F, T>(rt: ZRuntime, future: F) -> TerminatableTask
172 where
173 F: Future<Output = T> + Send + 'static,
174 T: Send + 'static,
175 {
176 let token = CancellationToken::new();
177 let token2 = token.clone();
178 let task = async move {
179 tokio::select! {
180 _ = token2.cancelled() => {},
181 _ = future => {}
182 }
183 };
184
185 TerminatableTask {
186 handle: Some(rt.spawn(task)),
187 token,
188 }
189 }
190
191 pub fn terminate(&mut self, timeout: Duration) -> bool {
194 ResolveFuture::new(async move {
195 if tokio::time::timeout(timeout, self.terminate_async())
196 .await
197 .is_err()
198 {
199 tracing::error!("Failed to terminate the task");
200 return false;
201 };
202 true
203 })
204 .wait()
205 }
206
207 pub async fn terminate_async(&mut self) {
209 self.token.cancel();
210 if let Some(handle) = self.handle.take() {
211 let _ = handle.await;
212 }
213 }
214}