Skip to main content

zenoh_task/
lib.rs

1//
2// Copyright (c) 2024 ZettaScale Technology
3//
4// This program and the accompanying materials are made available under the
5// terms of the Eclipse Public License 2.0 which is available at
6// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
7// which is available at https://www.apache.org/licenses/LICENSE-2.0.
8//
9// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
10//
11// Contributors:
12//   ZettaScale Zenoh Team, <zenoh@zettascale.tech>
13//
14
15//! ⚠️ WARNING ⚠️
16//!
17//! This module is intended for Zenoh's internal use.
18//!
19//! [Click here for Zenoh's documentation](https://docs.rs/zenoh/latest/zenoh)
20
21use std::{future::Future, time::Duration};
22
23use futures::future::FutureExt;
24use tokio::task::JoinHandle;
25use tokio_util::{sync::CancellationToken, task::TaskTracker};
26use zenoh_core::{ResolveFuture, Wait};
27use zenoh_runtime::ZRuntime;
28
29#[derive(Clone)]
30pub struct TaskController {
31    tracker: TaskTracker,
32    token: CancellationToken,
33}
34
35impl Default for TaskController {
36    fn default() -> Self {
37        TaskController {
38            tracker: TaskTracker::new(),
39            token: CancellationToken::new(),
40        }
41    }
42}
43
44impl TaskController {
45    /// Converts a task to abortable one, which can later be terminated by call to [`TaskController::terminate_all()`].
46    pub fn into_abortable<'a, F, T>(&self, future: F) -> impl Future<Output = Option<T>> + Send + 'a
47    where
48        F: Future<Output = T> + Send + 'a,
49        T: Send + 'static,
50    {
51        self.token.child_token().run_until_cancelled_owned(future)
52    }
53
54    /// Spawns a task that can be later terminated by call to [`TaskController::terminate_all()`].
55    /// Task output is ignored.
56    pub fn spawn_abortable<F, T>(&self, future: F) -> JoinHandle<Option<T>>
57    where
58        F: Future<Output = T> + Send + 'static,
59        T: Send + 'static,
60    {
61        #[cfg(feature = "tracing-instrument")]
62        let future = tracing::Instrument::instrument(future, tracing::Span::current());
63
64        self.tracker.spawn(self.into_abortable(future))
65    }
66
67    /// Spawns a task using a specified runtime that can be later terminated by call to [`TaskController::terminate_all()`].
68    pub fn spawn_abortable_with_rt<F, T>(&self, rt: ZRuntime, future: F) -> JoinHandle<Option<T>>
69    where
70        F: Future<Output = T> + Send + 'static,
71        T: Send + 'static,
72    {
73        #[cfg(feature = "tracing-instrument")]
74        let future = tracing::Instrument::instrument(future, tracing::Span::current());
75
76        self.tracker.spawn_on(self.into_abortable(future), &rt)
77    }
78
79    pub fn get_cancellation_token(&self) -> CancellationToken {
80        self.token.child_token()
81    }
82
83    /// Spawns a task that can be cancelled cancellation of a token obtained by [`TaskController::get_cancellation_token()`],
84    /// was created via [`TaskController::into_abortable()`],
85    /// or can run to completion in finite amount of time, using a specified runtime.
86    /// It can be later terminated by call to [`TaskController::terminate_all()`].
87    pub fn spawn<F, T>(&self, future: F) -> JoinHandle<T>
88    where
89        F: Future<Output = T> + Send + 'static,
90        T: Send + 'static,
91    {
92        #[cfg(feature = "tracing-instrument")]
93        let future = tracing::Instrument::instrument(future, tracing::Span::current());
94
95        self.tracker.spawn(future)
96    }
97
98    /// Spawns a task which can be cancelled via cancellation of a token obtained by [`TaskController::get_cancellation_token()`],
99    /// was created via [`TaskController::into_abortable()`],
100    /// or can run to completion in finite amount of time, using a specified runtime.
101    /// It can be later aborted by call to [`TaskController::terminate_all()`].
102    pub fn spawn_with_rt<F, T>(&self, rt: ZRuntime, future: F) -> JoinHandle<T>
103    where
104        F: Future<Output = T> + Send + 'static,
105        T: Send + 'static,
106    {
107        #[cfg(feature = "tracing-instrument")]
108        let future = tracing::Instrument::instrument(future, tracing::Span::current());
109
110        self.tracker.spawn_on(future, &rt)
111    }
112
113    /// Attempts tp terminate all previously spawned tasks
114    /// The caller must ensure that all tasks spawned with [`TaskController::spawn()`]
115    /// or [`TaskController::spawn_with_rt()`] can yield in finite amount of time either because they will run to completion
116    /// or due to cancellation of token acquired via [`TaskController::get_cancellation_token()`].
117    /// Tasks spawned with [`TaskController::spawn_abortable()`] or [`TaskController::spawn_abortable_with_rt()`] will be aborted (i.e. terminated upon next await call).
118    /// The call blocks until all tasks yield or timeout duration expires.
119    /// Returns 0 in case of success, number of non terminated tasks otherwise.
120    pub fn terminate_all(&self, timeout: Duration) -> usize {
121        ResolveFuture::new(async move {
122            if tokio::time::timeout(timeout, self.terminate_all_async())
123                .await
124                .is_err()
125            {
126                tracing::error!("Failed to terminate {} tasks", self.tracker.len());
127            }
128            self.tracker.len()
129        })
130        .wait()
131    }
132
133    /// Async version of [`TaskController::terminate_all()`].
134    pub async fn terminate_all_async(&self) {
135        self.tracker.close();
136        self.token.cancel();
137        self.tracker.wait().await
138    }
139}
140
141pub struct TerminatableTask {
142    handle: Option<JoinHandle<()>>,
143    token: CancellationToken,
144}
145
146impl Drop for TerminatableTask {
147    fn drop(&mut self) {
148        self.terminate(std::time::Duration::from_secs(10));
149    }
150}
151
152impl TerminatableTask {
153    pub fn create_cancellation_token() -> CancellationToken {
154        CancellationToken::new()
155    }
156
157    /// Spawns a task that can be later terminated by [`TerminatableTask::terminate()`].
158    /// Prior to termination attempt the specified cancellation token will be cancelled.
159    pub fn spawn<F, T>(rt: ZRuntime, future: F, token: CancellationToken) -> TerminatableTask
160    where
161        F: Future<Output = T> + Send + 'static,
162        T: Send + 'static,
163    {
164        TerminatableTask {
165            handle: Some(rt.spawn(future.map(|_f| ()))),
166            token,
167        }
168    }
169
170    /// Spawns a task that can be later aborted by [`TerminatableTask::terminate()`].
171    pub fn spawn_abortable<F, T>(rt: ZRuntime, future: F) -> TerminatableTask
172    where
173        F: Future<Output = T> + Send + 'static,
174        T: Send + 'static,
175    {
176        let token = CancellationToken::new();
177        let token2 = token.clone();
178        let task = async move {
179            tokio::select! {
180                _ = token2.cancelled() => {},
181                _ = future => {}
182            }
183        };
184
185        TerminatableTask {
186            handle: Some(rt.spawn(task)),
187            token,
188        }
189    }
190
191    /// Attempts to terminate the task.
192    /// Returns true if task completed / aborted within timeout duration, false otherwise.
193    pub fn terminate(&mut self, timeout: Duration) -> bool {
194        ResolveFuture::new(async move {
195            if tokio::time::timeout(timeout, self.terminate_async())
196                .await
197                .is_err()
198            {
199                tracing::error!("Failed to terminate the task");
200                return false;
201            };
202            true
203        })
204        .wait()
205    }
206
207    /// Async version of [`TerminatableTask::terminate()`].
208    pub async fn terminate_async(&mut self) {
209        self.token.cancel();
210        if let Some(handle) = self.handle.take() {
211            let _ = handle.await;
212        }
213    }
214}