1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
//
// Copyright (c) 2024 ZettaScale Technology
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License 2.0 which is available at
// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
//
// Contributors:
//   ZettaScale Zenoh Team, <zenoh@zettascale.tech>
//

//! ⚠️ WARNING ⚠️
//!
//! This module is intended for Zenoh's internal use.
//!
//! [Click here for Zenoh's documentation](../zenoh/index.html)

use std::{future::Future, time::Duration};

use futures::future::FutureExt;
use tokio::task::JoinHandle;
use tokio_util::{sync::CancellationToken, task::TaskTracker};
use zenoh_core::{ResolveFuture, Wait};
use zenoh_runtime::ZRuntime;

#[derive(Clone)]
pub struct TaskController {
    tracker: TaskTracker,
    token: CancellationToken,
}

impl Default for TaskController {
    fn default() -> Self {
        TaskController {
            tracker: TaskTracker::new(),
            token: CancellationToken::new(),
        }
    }
}

impl TaskController {
    /// Spawns a task that can be later terminated by call to [`TaskController::terminate_all()`].
    /// Task output is ignored.
    pub fn spawn_abortable<F, T>(&self, future: F) -> JoinHandle<()>
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        let token = self.token.child_token();
        let task = async move {
            tokio::select! {
                _ = token.cancelled() => {},
                _ = future => {}
            }
        };
        self.tracker.spawn(task)
    }

    /// Spawns a task using a specified runtime that can be later terminated by call to [`TaskController::terminate_all()`].
    /// Task output is ignored.
    pub fn spawn_abortable_with_rt<F, T>(&self, rt: ZRuntime, future: F) -> JoinHandle<()>
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        let token = self.token.child_token();
        let task = async move {
            tokio::select! {
                _ = token.cancelled() => {},
                _ = future => {}
            }
        };
        self.tracker.spawn_on(task, &rt)
    }

    pub fn get_cancellation_token(&self) -> CancellationToken {
        self.token.child_token()
    }

    /// Spawns a task that can be cancelled via cancellation of a token obtained by [`TaskController::get_cancellation_token()`],
    /// or that can run to completion in finite amount of time.
    /// It can be later terminated by call to [`TaskController::terminate_all()`].
    pub fn spawn<F, T>(&self, future: F) -> JoinHandle<()>
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        self.tracker.spawn(future.map(|_f| ()))
    }

    /// Spawns a task that can be cancelled via cancellation of a token obtained by [`TaskController::get_cancellation_token()`],
    /// or that can run to completion in finite amount of time, using a specified runtime.
    /// It can be later aborted by call to [`TaskController::terminate_all()`].
    pub fn spawn_with_rt<F, T>(&self, rt: ZRuntime, future: F) -> JoinHandle<()>
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        self.tracker.spawn_on(future.map(|_f| ()), &rt)
    }

    /// Attempts tp terminate all previously spawned tasks
    /// The caller must ensure that all tasks spawned with [`TaskController::spawn()`]
    /// or [`TaskController::spawn_with_rt()`] can yield in finite amount of time either because they will run to completion
    /// or due to cancellation of token acquired via [`TaskController::get_cancellation_token()`].
    /// Tasks spawned with [`TaskController::spawn_abortable()`] or [`TaskController::spawn_abortable_with_rt()`] will be aborted (i.e. terminated upon next await call).
    /// The call blocks until all tasks yield or timeout duration expires.
    /// Returns 0 in case of success, number of non terminated tasks otherwise.
    pub fn terminate_all(&self, timeout: Duration) -> usize {
        ResolveFuture::new(async move { self.terminate_all_async(timeout).await }).wait()
    }

    /// Async version of [`TaskController::terminate_all()`].
    pub async fn terminate_all_async(&self, timeout: Duration) -> usize {
        self.tracker.close();
        self.token.cancel();
        if tokio::time::timeout(timeout, self.tracker.wait())
            .await
            .is_err()
        {
            tracing::error!("Failed to terminate {} tasks", self.tracker.len());
            return self.tracker.len();
        }
        0
    }
}

/// Handle to a single spawned task that can be terminated on demand.
///
/// Termination first cancels the associated [`CancellationToken`] and then waits, up to a
/// timeout, for the task to finish.
pub struct TerminatableTask {
    /// Join handle of the spawned task.
    handle: JoinHandle<()>,
    /// Token cancelled at the start of [`TerminatableTask::terminate()`].
    token: CancellationToken,
}

impl TerminatableTask {
    /// Creates a new, not-yet-cancelled token suitable for passing to [`TerminatableTask::spawn()`].
    pub fn create_cancellation_token() -> CancellationToken {
        CancellationToken::new()
    }

    /// Spawns a task that can be later terminated by [`TerminatableTask::terminate()`].
    /// Prior to termination attempt the specified cancellation token will be cancelled.
    /// The task is expected to observe `token` (or finish on its own); it is not interrupted otherwise.
    pub fn spawn<F, T>(rt: ZRuntime, future: F, token: CancellationToken) -> TerminatableTask
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        TerminatableTask {
            handle: rt.spawn(future.map(|_f| ())),
            token,
        }
    }

    /// Spawns a task that can be later aborted by [`TerminatableTask::terminate()`].
    /// The future is raced against an internal token, so it stops at its next await point
    /// once termination is requested. Task output is ignored.
    pub fn spawn_abortable<F, T>(rt: ZRuntime, future: F) -> TerminatableTask
    where
        F: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        let token = CancellationToken::new();
        let task_token = token.clone();
        let task = async move {
            tokio::select! {
                _ = task_token.cancelled() => {},
                _ = future => {}
            }
        };

        TerminatableTask {
            handle: rt.spawn(task),
            token,
        }
    }

    /// Attempts to terminate the task.
    /// Returns true if task completed / aborted within timeout duration, false otherwise.
    pub fn terminate(self, timeout: Duration) -> bool {
        ResolveFuture::new(async move { self.terminate_async(timeout).await }).wait()
    }

    /// Async version of [`TerminatableTask::terminate()`].
    ///
    /// On timeout the join handle is dropped, which detaches the task rather than aborting it.
    pub async fn terminate_async(self, timeout: Duration) -> bool {
        self.token.cancel();
        match tokio::time::timeout(timeout, self.handle).await {
            Ok(Ok(())) => true,
            Ok(Err(e)) => {
                // The task has ended, but abnormally (it panicked or was cancelled by the
                // runtime). It is no longer running, so report it instead of hiding it.
                tracing::warn!("Terminated task did not complete normally: {}", e);
                true
            }
            Err(_elapsed) => {
                tracing::error!("Failed to terminate the task");
                false
            }
        }
    }
}