Skip to main content

scion_sdk_utils/
task_handler.rs

1// Copyright 2025 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! Utilities for managing tasks and subprocesses.
15
16use tokio::{process::Child, task::JoinSet};
17use tokio_util::sync::CancellationToken;
18
19/// A in-process task set that is cancelled when dropped.
20pub struct InProcess {
21    /// Cancellable task set.
22    pub task_set: CancelTaskSet,
23}
24
25impl InProcess {
26    /// Creates a new in-process task set.
27    pub fn new(task_set: CancelTaskSet) -> Self {
28        Self { task_set }
29    }
30}
31
32impl Drop for InProcess {
33    fn drop(&mut self) {
34        self.task_set.cancellation_token().cancel();
35    }
36}
37
38/// A subprocess that is killed when dropped.
39pub struct Subprocess {
40    /// The child process.
41    pub child: Child,
42}
43
44impl Subprocess {
45    /// Creates a new subprocess.
46    pub fn new(child: Child) -> Self {
47        Self { child }
48    }
49}
50
51impl Drop for Subprocess {
52    fn drop(&mut self) {
53        let _ = self.child.start_kill();
54        let _ = self.child.try_wait();
55    }
56}
57
58/// A combination of a [tokio::task::JoinSet] and
59/// [tokio_util::sync::CancellationToken].
60///
61/// Provides methods that are commonly used in conjunction with those two data
62/// structures.
63pub struct CancelTaskSet {
64    /// Task set join set.
65    pub join_set: JoinSet<Result<(), std::io::Error>>,
66    cancellation_token: CancellationToken,
67}
68
69impl CancelTaskSet {
70    /// Creates a new task set.
71    #[allow(clippy::new_without_default)]
72    pub fn new() -> Self {
73        let cancellation_token = CancellationToken::new();
74        Self::from_cancel_token(false, cancellation_token)
75    }
76
77    /// Creates a task set and registers a signal handler that calls `cancel()`
78    /// on the cancellation token upon receiving `SIGINT` and `SIGTERM`.
79    pub fn new_with_signal_handler() -> Self {
80        let cancellation_token = CancellationToken::new();
81        Self::from_cancel_token(true, cancellation_token)
82    }
83
84    /// Creates a task set from an existing cancellation token.
85    ///
86    /// # Arguments
87    /// * `register_signal_handler`: If true, a signal handler is registered that calls `cancel()`
88    ///   on the cancellation token upon receiving `SIGINT` and `SIGTERM`.
89    /// * `cancellation_token`: The cancellation token to use.
90    pub fn from_cancel_token(
91        register_signal_handler: bool,
92        cancellation_token: CancellationToken,
93    ) -> Self {
94        let mut join_set = JoinSet::new();
95        if register_signal_handler {
96            Self::spawn_shutdown_handler(&mut join_set, cancellation_token.clone());
97        }
98        CancelTaskSet {
99            join_set,
100            cancellation_token,
101        }
102    }
103
104    /// Returns a clone of the cancellation token.
105    pub fn cancellation_token(&self) -> CancellationToken {
106        self.cancellation_token.clone()
107    }
108
109    fn spawn_shutdown_handler(
110        join_set: &mut JoinSet<Result<(), std::io::Error>>,
111        cancellation_token: CancellationToken,
112    ) {
113        join_set.spawn(async move {
114            #[cfg(target_family = "unix")]
115            {
116                use tokio::signal::unix::{SignalKind, signal};
117
118                let mut sigint =
119                    signal(SignalKind::interrupt()).expect("failed to register SIGINT handler");
120                let mut sigterm =
121                    signal(SignalKind::terminate()).expect("failed to register SIGTERM handler");
122                tokio::select! {
123                    _ = sigint.recv() => {
124                        tracing::debug!("Received SIGINT, cancelling token");
125                        cancellation_token.cancel();
126                    },
127                    _ = sigterm.recv() => {
128                        tracing::debug!("Received SIGTERM, cancelling token");
129                        cancellation_token.cancel();
130                    },
131                    _ = cancellation_token.cancelled() => {
132                        tracing::debug!("Cancellation token cancelled, exiting shutdown handler");
133                    },
134                }
135            }
136
137            #[cfg(target_family = "windows")]
138            {
139                use tokio::signal::windows;
140
141                let mut ctrl_c = windows::ctrl_c().expect("failed to register CTRL-C handler");
142                let mut ctrl_break =
143                    windows::ctrl_break().expect("failed to register CTRL-BREAK handler");
144
145                tokio::select! {
146                    _ = ctrl_c.recv() => {
147                        tracing::debug!("Received CTRL-C, cancelling token");
148                        cancellation_token.cancel();
149                    },
150                    _ = ctrl_break.recv() => {
151                        tracing::debug!("Received CTRL-BREAK, cancelling token");
152                        cancellation_token.cancel();
153                    },
154                    _ = cancellation_token.cancelled() => {
155                        tracing::debug!("Cancellation token cancelled, exiting shutdown handler");
156                    },
157                }
158            }
159
160            Ok(())
161        });
162    }
163
164    /// Spawns a task that will run until it is cancelled or completes.
165    pub fn spawn_cancellable_task<Fut>(&mut self, task: Fut)
166    where
167        Fut: Future<Output = Result<(), std::io::Error>> + Send + 'static,
168    {
169        let token = self.cancellation_token();
170        self.join_set.spawn(async move {
171            match token.run_until_cancelled(task).await {
172                Some(Ok(_)) => Ok(()),  // task completed successfully
173                Some(Err(e)) => Err(e), // task failed
174                None => Ok(()),         // task was successfully cancelled
175            }
176        });
177    }
178
179    /// Joins all tasks in the set. If any task fails to join or returns an error, cancel the token
180    /// to signal a graceful shutdown to the remaining tasks.
181    pub async fn join_all(&mut self) {
182        while let Some(result) = self.join_set.join_next().await {
183            match result {
184                Ok(Ok(())) => {} // Task completed successfully
185                Ok(Err(e)) => {
186                    tracing::error!(error=%e, "Task failed");
187                    self.cancellation_token.cancel();
188                }
189                Err(e) => {
190                    tracing::error!(error=%e, "Task join failed");
191                    self.cancellation_token.cancel();
192                }
193            }
194        }
195    }
196}
197
198impl Drop for CancelTaskSet {
199    fn drop(&mut self) {
200        self.cancellation_token.cancel();
201        self.join_set.abort_all();
202    }
203}