1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
//
// Copyright (c) 2024 ZettaScale Technology
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License 2.0 which is available at
// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
//
// Contributors:
// ZettaScale Zenoh Team, <zenoh@zettascale.tech>
//
//! ⚠️ WARNING ⚠️
//!
//! This module is intended for Zenoh's internal use.
//!
//! [Click here for Zenoh's documentation](../zenoh/index.html)
use std::{future::Future, time::Duration};
use futures::future::FutureExt;
use tokio::task::JoinHandle;
use tokio_util::{sync::CancellationToken, task::TaskTracker};
use zenoh_core::{ResolveFuture, Wait};
use zenoh_runtime::ZRuntime;
/// Groups spawned tasks under a single [`TaskTracker`] and a single root
/// [`CancellationToken`], so that all of them can be terminated together via
/// [`TaskController::terminate_all()`].
///
/// Cloning shares the underlying tracker and token (both are handle types in `tokio_util`).
#[derive(Clone)]
pub struct TaskController {
    /// Tracks every task spawned through this controller.
    tracker: TaskTracker,
    /// Root token; spawned tasks observe child tokens derived from it.
    token: CancellationToken,
}
impl Default for TaskController {
    /// Builds a controller with a newly created task tracker and a newly
    /// created (not yet cancelled) cancellation token.
    fn default() -> Self {
        Self {
            tracker: TaskTracker::new(),
            token: CancellationToken::new(),
        }
    }
}
impl TaskController {
/// Spawns a task that can be later terminated by call to [`TaskController::terminate_all()`].
/// Task output is ignored.
pub fn spawn_abortable<F, T>(&self, future: F) -> JoinHandle<()>
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
let token = self.token.child_token();
let task = async move {
tokio::select! {
_ = token.cancelled() => {},
_ = future => {}
}
};
self.tracker.spawn(task)
}
/// Spawns a task using a specified runtime that can be later terminated by call to [`TaskController::terminate_all()`].
/// Task output is ignored.
pub fn spawn_abortable_with_rt<F, T>(&self, rt: ZRuntime, future: F) -> JoinHandle<()>
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
let token = self.token.child_token();
let task = async move {
tokio::select! {
_ = token.cancelled() => {},
_ = future => {}
}
};
self.tracker.spawn_on(task, &rt)
}
pub fn get_cancellation_token(&self) -> CancellationToken {
self.token.child_token()
}
/// Spawns a task that can be cancelled via cancellation of a token obtained by [`TaskController::get_cancellation_token()`],
/// or that can run to completion in finite amount of time.
/// It can be later terminated by call to [`TaskController::terminate_all()`].
pub fn spawn<F, T>(&self, future: F) -> JoinHandle<()>
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
self.tracker.spawn(future.map(|_f| ()))
}
/// Spawns a task that can be cancelled via cancellation of a token obtained by [`TaskController::get_cancellation_token()`],
/// or that can run to completion in finite amount of time, using a specified runtime.
/// It can be later aborted by call to [`TaskController::terminate_all()`].
pub fn spawn_with_rt<F, T>(&self, rt: ZRuntime, future: F) -> JoinHandle<()>
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
self.tracker.spawn_on(future.map(|_f| ()), &rt)
}
/// Attempts tp terminate all previously spawned tasks
/// The caller must ensure that all tasks spawned with [`TaskController::spawn()`]
/// or [`TaskController::spawn_with_rt()`] can yield in finite amount of time either because they will run to completion
/// or due to cancellation of token acquired via [`TaskController::get_cancellation_token()`].
/// Tasks spawned with [`TaskController::spawn_abortable()`] or [`TaskController::spawn_abortable_with_rt()`] will be aborted (i.e. terminated upon next await call).
/// The call blocks until all tasks yield or timeout duration expires.
/// Returns 0 in case of success, number of non terminated tasks otherwise.
pub fn terminate_all(&self, timeout: Duration) -> usize {
ResolveFuture::new(async move { self.terminate_all_async(timeout).await }).wait()
}
/// Async version of [`TaskController::terminate_all()`].
pub async fn terminate_all_async(&self, timeout: Duration) -> usize {
self.tracker.close();
self.token.cancel();
if tokio::time::timeout(timeout, self.tracker.wait())
.await
.is_err()
{
tracing::error!("Failed to terminate {} tasks", self.tracker.len());
return self.tracker.len();
}
0
}
}
/// A single spawned task together with the cancellation token that
/// [`TerminatableTask::terminate()`] cancels before waiting for the task to finish.
pub struct TerminatableTask {
    /// Join handle of the spawned task (its output is mapped to `()`).
    handle: JoinHandle<()>,
    /// Token cancelled at the start of every termination attempt.
    token: CancellationToken,
}
impl TerminatableTask {
    /// Creates a fresh, not yet cancelled token, suitable for passing to
    /// [`TerminatableTask::spawn()`].
    pub fn create_cancellation_token() -> CancellationToken {
        CancellationToken::new()
    }

    /// Spawns `future` on `rt` as a task that can be later terminated by
    /// [`TerminatableTask::terminate()`].
    /// `token` is cancelled right before the termination attempt; the future is expected
    /// to observe it (or to complete on its own) so that it finishes in time.
    /// Task output is ignored.
    pub fn spawn<F, T>(rt: ZRuntime, future: F, token: CancellationToken) -> TerminatableTask
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        let handle = rt.spawn(future.map(|_f| ()));
        Self { handle, token }
    }

    /// Spawns `future` on `rt` as a task that can be later aborted by
    /// [`TerminatableTask::terminate()`]: once the internal token is cancelled, the
    /// future is dropped at its next await point. Task output is ignored.
    pub fn spawn_abortable<F, T>(rt: ZRuntime, future: F) -> TerminatableTask
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        let token = CancellationToken::new();
        let task_token = token.clone();
        let handle = rt.spawn(async move {
            tokio::select! {
                _ = task_token.cancelled() => {},
                _ = future => {}
            }
        });
        Self { handle, token }
    }

    /// Attempts to terminate the task, blocking until it finishes or `timeout` expires.
    /// Returns true if the task completed / aborted within the timeout duration, false otherwise.
    pub fn terminate(self, timeout: Duration) -> bool {
        ResolveFuture::new(async move { self.terminate_async(timeout).await }).wait()
    }

    /// Async version of [`TerminatableTask::terminate()`].
    ///
    /// Cancels the token, then waits up to `timeout` for the task to finish.
    pub async fn terminate_async(self, timeout: Duration) -> bool {
        self.token.cancel();
        match tokio::time::timeout(timeout, self.handle).await {
            Ok(_) => true,
            Err(_) => {
                tracing::error!("Failed to terminate the task");
                false
            }
        }
    }
}