Skip to main content

tosub/
lib.rs

1/*
2 *  Copyright 2026 Michael Bachmann
3 *
4 * Licensed under either the MIT or the Apache License, Version 2.0,
5 * as per the user's preference.
6 * You may not use this file except in compliance with at least one
7 * of these two licenses.
8 * You may obtain a copy of the Licenses at
9 *
10 *     https://www.apache.org/licenses/LICENSE-2.0
11 *     and
12 *     https://opensource.org/license/MIT
13 *
14 * Unless required by applicable law or agreed to in writing, software
15 * distributed under the License is distributed on an "AS IS" BASIS,
16 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 * See the License for the specific language governing permissions and
18 * limitations under the License.
19 */
20
21use miette::Diagnostic;
22use std::{
23    collections::HashMap,
24    fmt::{self, Debug, Display},
25    mem, process,
26    sync::{Arc, Mutex},
27    time::Duration,
28};
29use thiserror::Error;
30use tokio::{
31    select, spawn,
32    sync::{oneshot, watch},
33    task::JoinError,
34    time::timeout,
35};
36use tokio_util::sync::CancellationToken;
37use tracing::{error, info, warn};
38
39type ChildMap = Arc<Mutex<HashMap<String, (watch::Sender<bool>, watch::Receiver<bool>)>>>;
40
/// Builder for the root of a subsystem tree; obtain one via [`build_root`].
#[derive(Debug)]
pub struct RootBuilder {
    /// Name of the root subsystem; children are named `<root>/<child>`.
    name: String,
    /// Whether OS shutdown signals (or Ctrl+C) should trigger a global shutdown.
    catch_signals: bool,
    /// How long to wait for children once shutdown starts; `None` waits indefinitely.
    shutdown_timeout: Option<Duration>,
}
46
47struct CrashHolder {
48    crash: Arc<Mutex<SubsystemResult>>,
49    cancel: CancellationToken,
50}
51
52impl Clone for CrashHolder {
53    fn clone(&self) -> Self {
54        CrashHolder {
55            crash: self.crash.clone(),
56            cancel: self.cancel.clone(),
57        }
58    }
59}
60
61impl CrashHolder {
62    fn set_crash(&self, err: SubsystemError) {
63        let mut guard = self.crash.lock().expect("mutex is poisoned");
64        if guard.is_ok() {
65            *guard = Err(err);
66            self.cancel.cancel();
67        }
68    }
69
70    fn take_crash(&self) -> SubsystemResult {
71        let mut guard = self.crash.lock().expect("mutex is poisoned");
72        mem::replace(&mut *guard, Ok(()))
73    }
74}
75
76impl RootBuilder {
77    pub async fn start<E, F>(
78        self,
79        subsys: impl FnOnce(SubsystemHandle) -> F + Send + 'static,
80    ) -> SubsystemResult
81    where
82        F: std::future::Future<Output = Result<(), E>> + Send + 'static,
83        E: IntoGenericError + Display,
84    {
85        let global = CancellationToken::new();
86        let local = global.child_token();
87
88        if self.catch_signals {
89            self.register_signal_handlers(&global);
90        }
91
92        let crash = CrashHolder {
93            crash: Arc::new(Mutex::new(Ok(()))),
94            cancel: global.clone(),
95        };
96
97        let (res_tx, res_rx) = oneshot::channel();
98        let (join_tx, join_rx) = watch::channel(false);
99
100        let cancel_clean_shutdown = CancellationToken::new();
101
102        let children = Arc::new(Mutex::new(HashMap::new()));
103
104        let handle = SubsystemHandle {
105            name: self.name.clone(),
106            global: global.clone(),
107            local: local.clone(),
108            cancel_clean_shutdown: cancel_clean_shutdown.clone(),
109            children: children.clone(),
110            crash: crash.clone(),
111            join_handle: (join_tx.clone(), join_rx),
112        };
113
114        let glob = global.clone();
115        if let Some(to) = self.shutdown_timeout {
116            spawn(async move {
117                match subsys(handle).await {
118                    Ok(_) => info!("Root system '{}' terminated normally.", self.name),
119                    Err(e) => {
120                        error!("Root system '{}' terminated with error: {e}", self.name);
121                        crash.set_crash(SubsystemError::Error(
122                            self.name.clone(),
123                            e.into_generic_error(),
124                        ));
125                    }
126                }
127
128                glob.cancel();
129                info!(
130                    "Shutdown initiated, waiting up to {:?} for clean shutdown.",
131                    to
132                );
133
134                let children = {
135                    let children = children.lock().expect("mutex is poisoned");
136                    children.clone()
137                };
138                let children_shutdown = wait_for_children_shutdown(&children);
139
140                match timeout(to, children_shutdown).await {
141                    Ok(_) => {
142                        info!("All subsystems have shut down in time.");
143                    }
144                    Err(_) => {
145                        error!("Shutdown timeout reached, forcing shutdown …");
146                        cancel_clean_shutdown.cancel();
147                        crash.set_crash(SubsystemError::ForcedShutdown);
148                    }
149                }
150
151                res_tx.send(crash.take_crash()).ok();
152                join_tx.send(true).ok();
153            });
154        } else {
155            spawn(async move {
156                match subsys(handle).await {
157                    Ok(_) => info!("Root system '{}' terminated normally.", self.name),
158                    Err(e) => error!("Root system '{}' terminated with error: {e}", self.name),
159                }
160
161                if !global.is_cancelled() {
162                    glob.cancel();
163                }
164                info!("Shutdown initiated, waiting for clean shutdown.");
165
166                let children = {
167                    let children = children.lock().expect("mutex is poisoned");
168                    children.clone()
169                };
170                let children_shutdown = wait_for_children_shutdown(&children);
171                children_shutdown.await;
172                info!("All subsystems have shut down.");
173
174                res_tx.send(crash.take_crash()).ok();
175                join_tx.send(true).ok();
176            });
177        }
178
179        res_rx.await.unwrap_or(Err(SubsystemError::ForcedShutdown))
180    }
181
182    pub fn catch_signals(mut self) -> Self {
183        self.catch_signals = true;
184        self
185    }
186
187    pub fn with_timeout(mut self, shutdown_timeout: Duration) -> Self {
188        self.shutdown_timeout = Some(shutdown_timeout);
189        self
190    }
191
192    #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "freebsd")))]
193    fn register_signal_handlers(&self, global: &CancellationToken) {
194        use tokio::signal::ctrl_c;
195
196        let global = global.clone();
197        spawn(async move {
198            let mut counter = 0;
199            loop {
200                ctrl_c().await.expect("Ctrl+C handler not supported");
201                counter += 1;
202                if counter > 1 {
203                    break;
204                }
205                info!("Received Ctrl+C, initiating shutdown.");
206                global.cancel();
207            }
208            process::exit(1);
209        });
210    }
211
212    #[cfg(any(target_os = "linux", target_os = "macos", target_os = "freebsd"))]
213    fn register_signal_handlers(&self, global: &CancellationToken) {
214        use tokio::signal::unix::{SignalKind, signal};
215
216        if let Ok(signal) = signal(SignalKind::hangup()) {
217            handle_unix_signal(
218                global,
219                signal,
220                "SIGHUP",
221                SignalKind::hangup().as_raw_value(),
222            );
223        } else {
224            error!("Failed to register SIGHUP handler");
225        }
226
227        if let Ok(signal) = signal(SignalKind::interrupt()) {
228            handle_unix_signal(
229                global,
230                signal,
231                "SIGINT",
232                SignalKind::interrupt().as_raw_value(),
233            );
234        } else {
235            error!("Failed to register SIGINT handler");
236        }
237
238        if let Ok(signal) = signal(SignalKind::quit()) {
239            handle_unix_signal(global, signal, "SIGQUIT", SignalKind::quit().as_raw_value());
240        } else {
241            error!("Failed to register SIGQUIT handler");
242        }
243
244        if let Ok(signal) = signal(SignalKind::terminate()) {
245            handle_unix_signal(
246                global,
247                signal,
248                "SIGTERM",
249                SignalKind::terminate().as_raw_value(),
250            );
251        } else {
252            error!("Failed to register SIGTERM handler");
253        }
254    }
255}
256
257#[cfg(any(target_os = "linux", target_os = "macos", target_os = "freebsd"))]
258fn handle_unix_signal(
259    global: &CancellationToken,
260    mut signal: tokio::signal::unix::Signal,
261    signal_name: &'static str,
262    code: i32,
263) {
264    let global = global.clone();
265    spawn(async move {
266        let mut counter = 0;
267        loop {
268            signal.recv().await;
269            counter += 1;
270            if counter > 1 {
271                break;
272            }
273            info!("Received {signal_name} signal, initiating shutdown.");
274            global.cancel();
275        }
276        process::exit(128 + code);
277    });
278}
279
280#[derive(Debug, Error, Diagnostic)]
281pub enum SubsystemError {
282    #[error("Subsystem '{0}' terminated with error: {1}")]
283    Error(String, GenericError),
284    #[error("Subsystem '{0}' panicked: {1}")]
285    Panic(String, String),
286    #[error("Subsystem shutdown timed out")]
287    ForcedShutdown,
288}
289
/// Object-safe bound for any error value that can be stored in a [`GenericError`].
///
/// Implemented automatically for every `'static` type that is `Debug + Display + Send + Sync`.
pub trait GenErr: Debug + Display + Send + Sync + 'static {}

impl<T: Debug + Display + Send + Sync + 'static> GenErr for T {}
293
294pub struct GenericError(Box<dyn GenErr>);
295
296pub trait IntoGenericError {
297    fn into_generic_error(self) -> GenericError;
298}
299
300impl<E: GenErr> IntoGenericError for E {
301    fn into_generic_error(self) -> GenericError {
302        GenericError(Box::new(self))
303    }
304}
305
306impl fmt::Debug for GenericError {
307    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
308        write!(f, "{:?}", self.0)
309    }
310}
311
312impl fmt::Display for GenericError {
313    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
314        write!(f, "{}", self.0)
315    }
316}
317
318pub type SubsystemResult = Result<(), SubsystemError>;
319
320async fn wait_for_children_shutdown(
321    children: &HashMap<String, (watch::Sender<bool>, watch::Receiver<bool>)>,
322) {
323    for child in children.values() {
324        let mut rx = child.1.clone();
325        rx.wait_for(|it| *it).await.ok();
326    }
327}
328
329pub struct SubsystemHandle {
330    name: String,
331    local: CancellationToken,
332    global: CancellationToken,
333    cancel_clean_shutdown: CancellationToken,
334    children: ChildMap,
335    crash: CrashHolder,
336    join_handle: (watch::Sender<bool>, watch::Receiver<bool>),
337}
338
339impl Clone for SubsystemHandle {
340    fn clone(&self) -> Self {
341        SubsystemHandle {
342            name: self.name.clone(),
343            local: self.local.clone(),
344            global: self.global.clone(),
345            cancel_clean_shutdown: self.cancel_clean_shutdown.clone(),
346            children: self.children.clone(),
347            crash: self.crash.clone(),
348            join_handle: (self.join_handle.0.clone(), self.join_handle.1.clone()),
349        }
350    }
351}
352
353impl fmt::Debug for SubsystemHandle {
354    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
355        f.debug_struct("SubsystemHandle")
356            .field("name", &self.name)
357            .finish()
358    }
359}
360
361fn convert_result<Err>(res: Result<(), Err>) -> Result<(), GenericError>
362where
363    Err: IntoGenericError,
364{
365    match res {
366        Ok(_) => Ok(()),
367        Err(e) => Err(e.into_generic_error()),
368    }
369}
370
371impl SubsystemHandle {
372    pub fn name(&self) -> &str {
373        &self.name
374    }
375
376    pub fn spawn<Err, F>(
377        &self,
378        name: impl AsRef<str>,
379        subsys: impl FnOnce(SubsystemHandle) -> F + Send + 'static,
380    ) -> SubsystemHandle
381    where
382        F: Future<Output = Result<(), Err>> + Send + 'static,
383        Err: IntoGenericError,
384    {
385        let cancel_clean_shutdown = self.cancel_clean_shutdown.clone();
386
387        let handle = self.create_child(name, cancel_clean_shutdown.clone());
388        let full_name = handle.name().to_owned();
389
390        let fname = full_name.clone();
391        let children = self.children.clone();
392        let mut crash = self.crash.clone();
393        let glob = self.global.clone();
394        let h = handle.clone();
395        info!("Spawning subsystem '{}' …", fname);
396        tokio::spawn(async move {
397            let name = fname.clone();
398            let mut join_handle = tokio::spawn(async move {
399                info!("Subsystem '{}' started.", name);
400                let res = subsys(h).await;
401                convert_result(res)
402            });
403            select! {
404                res = &mut join_handle => Self::child_joined(res, children, &fname, &mut crash).await,
405                _ = cancel_clean_shutdown.cancelled() => Self::shutdown_timed_out(join_handle, &fname, &glob, &mut crash).await,
406            };
407        });
408
409        handle
410    }
411
412    pub fn request_global_shutdown(&self) {
413        self.global.cancel();
414    }
415
416    pub fn request_local_shutdown(&self) {
417        self.local.cancel();
418    }
419
420    pub async fn shutdown_requested(&self) {
421        self.local.cancelled().await
422    }
423
424    pub fn is_shut_down(&self) -> bool {
425        self.local.is_cancelled()
426    }
427
428    pub async fn join(&self) {
429        let mut join_handle = self.join_handle.clone();
430        join_handle.1.wait_for(|it| *it).await.ok();
431    }
432
433    fn create_child(
434        &self,
435        name: impl AsRef<str>,
436        cancel_clean_shutdown: CancellationToken,
437    ) -> SubsystemHandle {
438        let (res_tx, res_rx) = watch::channel(false);
439        let name = format!("{}/{}", self.name, name.as_ref());
440        let global = self.global.clone();
441        let local = self.local.child_token();
442        let children = self.children.clone();
443        let crash = self.crash.clone();
444
445        let mut gc = self.children.lock().expect("mutex is poisoned");
446        gc.insert(name.clone(), (res_tx.clone(), res_rx.clone()));
447
448        SubsystemHandle {
449            name,
450            global,
451            local,
452            cancel_clean_shutdown,
453            children,
454            crash,
455            join_handle: (res_tx, res_rx),
456        }
457    }
458
459    async fn child_joined(
460        res: Result<Result<(), GenericError>, JoinError>,
461        children: ChildMap,
462        child_name: &str,
463        crash: &mut CrashHolder,
464    ) {
465        let mut children = children.lock().expect("mutex is poisoned");
466        let Some(child) = children.remove(child_name) else {
467            warn!("Subsystem '{}' already removed from tracking.", child_name);
468            return;
469        };
470
471        match res {
472            Ok(Ok(_)) => {
473                info!("Subsystem '{}' terminated normally.", child_name);
474                child.0.send(true).ok();
475            }
476            Ok(Err(e)) => {
477                error!("Subsystem '{}' terminated with error: {}", child_name, e);
478                let err = SubsystemError::Error(child_name.to_owned(), e);
479                crash.set_crash(err);
480                child.0.send(true).ok();
481            }
482            Err(e) => {
483                if e.is_panic() {
484                    error!("Subsystem '{}' panicked: {}", child_name, e);
485                    let err = SubsystemError::Panic(child_name.to_owned(), e.to_string());
486                    crash.set_crash(err);
487                    child.0.send(true).ok();
488                } else {
489                    warn!("Subsystem '{}' was shut down forcefully.", child_name);
490                    let err = SubsystemError::ForcedShutdown;
491                    crash.set_crash(err);
492                    child.0.send(true).ok();
493                }
494            }
495        };
496    }
497
498    async fn shutdown_timed_out<Err>(
499        join_handle: tokio::task::JoinHandle<Result<(), Err>>,
500        child_name: &str,
501        global: &CancellationToken,
502        crash: &mut CrashHolder,
503    ) where
504        Err: Debug + Display + Send + Sync + 'static,
505    {
506        warn!("Subsystem '{}' is being shut down forcefully.", child_name);
507        join_handle.abort();
508        global.cancel();
509        crash.set_crash(SubsystemError::ForcedShutdown);
510    }
511}
512
513pub fn build_root(name: impl Into<String>) -> RootBuilder {
514    RootBuilder {
515        name: name.into(),
516        catch_signals: false,
517        shutdown_timeout: None,
518    }
519}