1use miette::Diagnostic;
22use std::{
23 collections::HashMap,
24 fmt::{self, Debug, Display},
25 mem, process,
26 sync::{Arc, Mutex},
27 time::Duration,
28};
29use thiserror::Error;
30use tokio::{
31 select, spawn,
32 sync::{oneshot, watch},
33 task::JoinError,
34 time::timeout,
35};
36use tokio_util::sync::CancellationToken;
37use tracing::{error, info, warn};
38
/// Shared registry of tracked subsystems, keyed by the subsystem's full name.
///
/// Each entry stores both halves of a `watch` channel carrying a `bool`; the value is
/// switched to `true` once the subsystem's task has been joined.
type ChildMap = Arc<Mutex<HashMap<String, (watch::Sender<bool>, watch::Receiver<bool>)>>>;
40
/// Configuration for the root subsystem, created by `build_root` and consumed by
/// `RootBuilder::start`.
pub struct RootBuilder {
    /// Name of the root subsystem; child names are derived from it.
    name: String,
    /// When `true`, `start` registers OS signal handlers that trigger a global shutdown.
    catch_signals: bool,
    /// Upper bound on how long `start` waits for subsystems after shutdown has been
    /// initiated; `None` waits indefinitely.
    shutdown_timeout: Option<Duration>,
}
46
/// Shared slot for a recorded subsystem failure, paired with the token that is cancelled
/// when a failure is stored.
struct CrashHolder {
    /// `Ok(())` while nothing has been recorded, otherwise the stored error.
    crash: Arc<Mutex<SubsystemResult>>,
    /// Token cancelled by `set_crash` when it stores an error.
    cancel: CancellationToken,
}
51
52impl Clone for CrashHolder {
53 fn clone(&self) -> Self {
54 CrashHolder {
55 crash: self.crash.clone(),
56 cancel: self.cancel.clone(),
57 }
58 }
59}
60
61impl CrashHolder {
62 fn set_crash(&self, err: SubsystemError) {
63 let mut guard = self.crash.lock().expect("mutex is poisoned");
64 if guard.is_ok() {
65 *guard = Err(err);
66 self.cancel.cancel();
67 }
68 }
69
70 fn take_crash(&self) -> SubsystemResult {
71 let mut guard = self.crash.lock().expect("mutex is poisoned");
72 mem::replace(&mut *guard, Ok(()))
73 }
74}
75
76impl RootBuilder {
77 pub async fn start<E, F>(
78 self,
79 subsys: impl FnOnce(SubsystemHandle) -> F + Send + 'static,
80 ) -> SubsystemResult
81 where
82 F: std::future::Future<Output = Result<(), E>> + Send + 'static,
83 E: IntoGenericError + Display,
84 {
85 let global = CancellationToken::new();
86 let local = global.child_token();
87
88 if self.catch_signals {
89 self.register_signal_handlers(&global);
90 }
91
92 let crash = CrashHolder {
93 crash: Arc::new(Mutex::new(Ok(()))),
94 cancel: global.clone(),
95 };
96
97 let (res_tx, res_rx) = oneshot::channel();
98 let (join_tx, join_rx) = watch::channel(false);
99
100 let cancel_clean_shutdown = CancellationToken::new();
101
102 let children = Arc::new(Mutex::new(HashMap::new()));
103
104 let handle = SubsystemHandle {
105 name: self.name.clone(),
106 global: global.clone(),
107 local: local.clone(),
108 cancel_clean_shutdown: cancel_clean_shutdown.clone(),
109 children: children.clone(),
110 crash: crash.clone(),
111 join_handle: (join_tx.clone(), join_rx),
112 };
113
114 let glob = global.clone();
115 if let Some(to) = self.shutdown_timeout {
116 spawn(async move {
117 match subsys(handle).await {
118 Ok(_) => info!("Root system '{}' terminated normally.", self.name),
119 Err(e) => {
120 error!("Root system '{}' terminated with error: {e}", self.name);
121 crash.set_crash(SubsystemError::Error(
122 self.name.clone(),
123 e.into_generic_error(),
124 ));
125 }
126 }
127
128 glob.cancel();
129 info!(
130 "Shutdown initiated, waiting up to {:?} for clean shutdown.",
131 to
132 );
133
134 let children = {
135 let children = children.lock().expect("mutex is poisoned");
136 children.clone()
137 };
138 let children_shutdown = wait_for_children_shutdown(&children);
139
140 match timeout(to, children_shutdown).await {
141 Ok(_) => {
142 info!("All subsystems have shut down in time.");
143 }
144 Err(_) => {
145 error!("Shutdown timeout reached, forcing shutdown …");
146 cancel_clean_shutdown.cancel();
147 crash.set_crash(SubsystemError::ForcedShutdown);
148 }
149 }
150
151 res_tx.send(crash.take_crash()).ok();
152 join_tx.send(true).ok();
153 });
154 } else {
155 spawn(async move {
156 match subsys(handle).await {
157 Ok(_) => info!("Root system '{}' terminated normally.", self.name),
158 Err(e) => error!("Root system '{}' terminated with error: {e}", self.name),
159 }
160
161 if !global.is_cancelled() {
162 glob.cancel();
163 }
164 info!("Shutdown initiated, waiting for clean shutdown.");
165
166 let children = {
167 let children = children.lock().expect("mutex is poisoned");
168 children.clone()
169 };
170 let children_shutdown = wait_for_children_shutdown(&children);
171 children_shutdown.await;
172 info!("All subsystems have shut down.");
173
174 res_tx.send(crash.take_crash()).ok();
175 join_tx.send(true).ok();
176 });
177 }
178
179 res_rx.await.unwrap_or(Err(SubsystemError::ForcedShutdown))
180 }
181
182 pub fn catch_signals(mut self) -> Self {
183 self.catch_signals = true;
184 self
185 }
186
187 pub fn with_timeout(mut self, shutdown_timeout: Duration) -> Self {
188 self.shutdown_timeout = Some(shutdown_timeout);
189 self
190 }
191
192 #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "freebsd")))]
193 fn register_signal_handlers(&self, global: &CancellationToken) {
194 use tokio::signal::ctrl_c;
195
196 let global = global.clone();
197 spawn(async move {
198 let mut counter = 0;
199 loop {
200 ctrl_c().await.expect("Ctrl+C handler not supported");
201 counter += 1;
202 if counter > 1 {
203 break;
204 }
205 info!("Received Ctrl+C, initiating shutdown.");
206 global.cancel();
207 }
208 process::exit(1);
209 });
210 }
211
212 #[cfg(any(target_os = "linux", target_os = "macos", target_os = "freebsd"))]
213 fn register_signal_handlers(&self, global: &CancellationToken) {
214 use tokio::signal::unix::{SignalKind, signal};
215
216 if let Ok(signal) = signal(SignalKind::hangup()) {
217 handle_unix_signal(
218 global,
219 signal,
220 "SIGHUP",
221 SignalKind::hangup().as_raw_value(),
222 );
223 } else {
224 error!("Failed to register SIGHUP handler");
225 }
226
227 if let Ok(signal) = signal(SignalKind::interrupt()) {
228 handle_unix_signal(
229 global,
230 signal,
231 "SIGINT",
232 SignalKind::interrupt().as_raw_value(),
233 );
234 } else {
235 error!("Failed to register SIGINT handler");
236 }
237
238 if let Ok(signal) = signal(SignalKind::quit()) {
239 handle_unix_signal(global, signal, "SIGQUIT", SignalKind::quit().as_raw_value());
240 } else {
241 error!("Failed to register SIGQUIT handler");
242 }
243
244 if let Ok(signal) = signal(SignalKind::terminate()) {
245 handle_unix_signal(
246 global,
247 signal,
248 "SIGTERM",
249 SignalKind::terminate().as_raw_value(),
250 );
251 } else {
252 error!("Failed to register SIGTERM handler");
253 }
254 }
255}
256
257#[cfg(any(target_os = "linux", target_os = "macos", target_os = "freebsd"))]
258fn handle_unix_signal(
259 global: &CancellationToken,
260 mut signal: tokio::signal::unix::Signal,
261 signal_name: &'static str,
262 code: i32,
263) {
264 let global = global.clone();
265 spawn(async move {
266 let mut counter = 0;
267 loop {
268 signal.recv().await;
269 counter += 1;
270 if counter > 1 {
271 break;
272 }
273 info!("Received {signal_name} signal, initiating shutdown.");
274 global.cancel();
275 }
276 process::exit(128 + code);
277 });
278}
279
/// Failure reported for a subsystem tree.
#[derive(Debug, Error, Diagnostic)]
pub enum SubsystemError {
    /// A subsystem returned an error; carries the subsystem name and the error.
    #[error("Subsystem '{0}' terminated with error: {1}")]
    Error(String, GenericError),
    /// A subsystem's task panicked; carries the subsystem name and a description.
    #[error("Subsystem '{0}' panicked: {1}")]
    Panic(String, String),
    /// Subsystems did not finish in time and were shut down forcefully.
    #[error("Subsystem shutdown timed out")]
    ForcedShutdown,
}
289
/// Marker trait bundling the bounds required of an error stored in a `GenericError`.
pub trait GenErr: Debug + Display + Send + Sync + 'static {}

// Blanket implementation: every type that satisfies the bounds is a `GenErr`.
impl<E> GenErr for E where E: Debug + Display + Send + Sync + 'static {}
293
/// Type-erased subsystem error; wraps any boxed `GenErr` value.
pub struct GenericError(Box<dyn GenErr>);
295
/// Conversion of an error value into a type-erased `GenericError`.
pub trait IntoGenericError {
    /// Consumes `self` and returns it as a `GenericError`.
    fn into_generic_error(self) -> GenericError;
}
299
300impl<E: GenErr> IntoGenericError for E {
301 fn into_generic_error(self) -> GenericError {
302 GenericError(Box::new(self))
303 }
304}
305
306impl fmt::Debug for GenericError {
307 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
308 write!(f, "{:?}", self.0)
309 }
310}
311
312impl fmt::Display for GenericError {
313 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
314 write!(f, "{}", self.0)
315 }
316}
317
/// Outcome of running a subsystem tree: `Ok(())` or the recorded `SubsystemError`.
pub type SubsystemResult = Result<(), SubsystemError>;
319
320async fn wait_for_children_shutdown(
321 children: &HashMap<String, (watch::Sender<bool>, watch::Receiver<bool>)>,
322) {
323 for child in children.values() {
324 let mut rx = child.1.clone();
325 rx.wait_for(|it| *it).await.ok();
326 }
327}
328
/// Handle to a subsystem; used to spawn children, request shutdown and await completion.
pub struct SubsystemHandle {
    /// Full name of the subsystem; children are named `<parent name>/<child name>`.
    name: String,
    /// Token observed by `shutdown_requested`; children receive child tokens of it.
    local: CancellationToken,
    /// Token shared by every handle of the tree; cancelled to shut everything down.
    global: CancellationToken,
    /// Token whose cancellation makes spawned subsystem tasks abort.
    cancel_clean_shutdown: CancellationToken,
    /// Registry of tracked subsystems, shared between the handles of the tree.
    children: ChildMap,
    /// Shared slot holding the recorded failure of the tree, if any.
    crash: CrashHolder,
    /// Watch channel that is set to `true` once this subsystem has finished.
    join_handle: (watch::Sender<bool>, watch::Receiver<bool>),
}
338
339impl Clone for SubsystemHandle {
340 fn clone(&self) -> Self {
341 SubsystemHandle {
342 name: self.name.clone(),
343 local: self.local.clone(),
344 global: self.global.clone(),
345 cancel_clean_shutdown: self.cancel_clean_shutdown.clone(),
346 children: self.children.clone(),
347 crash: self.crash.clone(),
348 join_handle: (self.join_handle.0.clone(), self.join_handle.1.clone()),
349 }
350 }
351}
352
353impl fmt::Debug for SubsystemHandle {
354 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
355 f.debug_struct("SubsystemHandle")
356 .field("name", &self.name)
357 .finish()
358 }
359}
360
361fn convert_result<Err>(res: Result<(), Err>) -> Result<(), GenericError>
362where
363 Err: IntoGenericError,
364{
365 match res {
366 Ok(_) => Ok(()),
367 Err(e) => Err(e.into_generic_error()),
368 }
369}
370
371impl SubsystemHandle {
372 pub fn name(&self) -> &str {
373 &self.name
374 }
375
376 pub fn spawn<Err, F>(
377 &self,
378 name: impl AsRef<str>,
379 subsys: impl FnOnce(SubsystemHandle) -> F + Send + 'static,
380 ) -> SubsystemHandle
381 where
382 F: Future<Output = Result<(), Err>> + Send + 'static,
383 Err: IntoGenericError,
384 {
385 let cancel_clean_shutdown = self.cancel_clean_shutdown.clone();
386
387 let handle = self.create_child(name, cancel_clean_shutdown.clone());
388 let full_name = handle.name().to_owned();
389
390 let fname = full_name.clone();
391 let children = self.children.clone();
392 let mut crash = self.crash.clone();
393 let glob = self.global.clone();
394 let h = handle.clone();
395 info!("Spawning subsystem '{}' …", fname);
396 tokio::spawn(async move {
397 let name = fname.clone();
398 let mut join_handle = tokio::spawn(async move {
399 info!("Subsystem '{}' started.", name);
400 let res = subsys(h).await;
401 convert_result(res)
402 });
403 select! {
404 res = &mut join_handle => Self::child_joined(res, children, &fname, &mut crash).await,
405 _ = cancel_clean_shutdown.cancelled() => Self::shutdown_timed_out(join_handle, &fname, &glob, &mut crash).await,
406 };
407 });
408
409 handle
410 }
411
    /// Requests shutdown of the whole subsystem tree by cancelling the global token.
    pub fn request_global_shutdown(&self) {
        self.global.cancel();
    }
415
    /// Requests shutdown of this subsystem by cancelling its local token; the local
    /// tokens of its children are child tokens of it.
    pub fn request_local_shutdown(&self) {
        self.local.cancel();
    }
419
    /// Resolves once this subsystem's local token has been cancelled.
    pub async fn shutdown_requested(&self) {
        self.local.cancelled().await
    }
423
    /// Returns `true` once this subsystem's local token has been cancelled.
    ///
    /// This reports that shutdown was requested; it does not consult the join channel,
    /// so it says nothing about whether the subsystem's task has finished.
    pub fn is_shut_down(&self) -> bool {
        self.local.is_cancelled()
    }
427
428 pub async fn join(&self) {
429 let mut join_handle = self.join_handle.clone();
430 join_handle.1.wait_for(|it| *it).await.ok();
431 }
432
433 fn create_child(
434 &self,
435 name: impl AsRef<str>,
436 cancel_clean_shutdown: CancellationToken,
437 ) -> SubsystemHandle {
438 let (res_tx, res_rx) = watch::channel(false);
439 let name = format!("{}/{}", self.name, name.as_ref());
440 let global = self.global.clone();
441 let local = self.local.child_token();
442 let children = self.children.clone();
443 let crash = self.crash.clone();
444
445 let mut gc = self.children.lock().expect("mutex is poisoned");
446 gc.insert(name.clone(), (res_tx.clone(), res_rx.clone()));
447
448 SubsystemHandle {
449 name,
450 global,
451 local,
452 cancel_clean_shutdown,
453 children,
454 crash,
455 join_handle: (res_tx, res_rx),
456 }
457 }
458
459 async fn child_joined(
460 res: Result<Result<(), GenericError>, JoinError>,
461 children: ChildMap,
462 child_name: &str,
463 crash: &mut CrashHolder,
464 ) {
465 let mut children = children.lock().expect("mutex is poisoned");
466 let Some(child) = children.remove(child_name) else {
467 warn!("Subsystem '{}' already removed from tracking.", child_name);
468 return;
469 };
470
471 match res {
472 Ok(Ok(_)) => {
473 info!("Subsystem '{}' terminated normally.", child_name);
474 child.0.send(true).ok();
475 }
476 Ok(Err(e)) => {
477 error!("Subsystem '{}' terminated with error: {}", child_name, e);
478 let err = SubsystemError::Error(child_name.to_owned(), e);
479 crash.set_crash(err);
480 child.0.send(true).ok();
481 }
482 Err(e) => {
483 if e.is_panic() {
484 error!("Subsystem '{}' panicked: {}", child_name, e);
485 let err = SubsystemError::Panic(child_name.to_owned(), e.to_string());
486 crash.set_crash(err);
487 child.0.send(true).ok();
488 } else {
489 warn!("Subsystem '{}' was shut down forcefully.", child_name);
490 let err = SubsystemError::ForcedShutdown;
491 crash.set_crash(err);
492 child.0.send(true).ok();
493 }
494 }
495 };
496 }
497
    /// Forcefully stops a subsystem whose task is still running.
    ///
    /// Aborts the task behind `join_handle`, cancels the global token and records
    /// `SubsystemError::ForcedShutdown` in `crash` (kept only if no earlier failure was
    /// recorded).
    ///
    /// NOTE(review): this path neither signals the subsystem's join channel nor removes
    /// the subsystem from the registry, so `join()` on a force-stopped subsystem may never
    /// resolve — confirm whether that is intended.
    async fn shutdown_timed_out<Err>(
        join_handle: tokio::task::JoinHandle<Result<(), Err>>,
        child_name: &str,
        global: &CancellationToken,
        crash: &mut CrashHolder,
    ) where
        Err: Debug + Display + Send + Sync + 'static,
    {
        warn!("Subsystem '{}' is being shut down forcefully.", child_name);
        join_handle.abort();
        global.cancel();
        crash.set_crash(SubsystemError::ForcedShutdown);
    }
511}
512
513pub fn build_root(name: impl Into<String>) -> RootBuilder {
514 RootBuilder {
515 name: name.into(),
516 catch_signals: false,
517 shutdown_timeout: None,
518 }
519}