1mod ordered_select_all;
2
3use crate::ordered_select_all::ordered_select_all;
4use futures::{future::BoxFuture, Future, FutureExt, StreamExt, TryFutureExt};
5use std::{
6 pin::{pin, Pin},
7 task::{Context, Poll},
8};
9use tokio::signal;
10use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned};
11
/// Type-erased, thread-safe error used as the payload of [`SupervisorError::Process`].
///
/// Any `E: std::error::Error + Send + Sync + 'static` (and `&str`/`String`) converts into it.
pub type BoxError = Box<dyn std::error::Error + Sync + Send>;
14
15#[derive(Debug, thiserror::Error)]
17pub enum SupervisorError {
18 #[error("signal listener setup failed: {0}")]
20 Signal(#[from] std::io::Error),
21 #[error("task failed: {0}")]
23 Process(#[source] BoxError),
24}
25
26impl SupervisorError {
27 pub fn from_err<E: Into<BoxError>>(err: E) -> Self {
35 Self::Process(err.into())
36 }
37}
38
39impl From<BoxError> for SupervisorError {
40 fn from(err: BoxError) -> Self {
41 Self::Process(err)
42 }
43}
44
45pub type ProcResult = Result<(), SupervisorError>;
47
48pub type ManagedFuture = futures::future::BoxFuture<'static, ProcResult>;
53
54pub struct ShutdownSignal {
59 token: CancellationToken,
60 future: Option<Pin<Box<WaitForCancellationFutureOwned>>>,
61}
62
63impl ShutdownSignal {
64 pub fn new(token: CancellationToken) -> Self {
65 Self {
66 token,
67 future: None,
68 }
69 }
70
71 pub fn is_cancelled(&self) -> bool {
72 self.token.is_cancelled()
73 }
74
75 pub fn token(&self) -> &CancellationToken {
76 &self.token
77 }
78}
79
80impl Future for ShutdownSignal {
81 type Output = ();
82
83 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
84 if self.future.is_none() {
86 self.future = Some(Box::pin(self.token.clone().cancelled_owned()));
87 }
88
89 self.future.as_mut().unwrap().as_mut().poll(cx)
91 }
92}
93
94impl Clone for ShutdownSignal {
95 fn clone(&self) -> Self {
96 Self {
97 token: self.token.clone(),
98 future: None,
99 }
100 }
101}
102
103impl Unpin for ShutdownSignal {}
104
105pub fn spawn<F, E>(fut: F) -> ManagedFuture
121where
122 F: Future<Output = Result<(), E>> + Send + 'static,
123 E: Into<BoxError> + Send + 'static,
124{
125 Box::pin(tokio::spawn(fut).map(|result| match result {
127 Ok(Ok(())) => Ok(()),
128 Ok(Err(e)) => Err(SupervisorError::from_err(e)),
129 Err(e) => Err(SupervisorError::from_err(e)),
130 }))
131}
132
133pub fn run<F, E>(fut: F) -> ManagedFuture
149where
150 F: Future<Output = Result<(), E>> + Send + 'static,
151 E: Into<BoxError> + 'static,
152{
153 Box::pin(fut.map_err(SupervisorError::from_err))
154}
155
156pub trait ManagedProc: Send + Sync {
176 fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> ManagedFuture;
185}
186
187pub struct Supervisor {
214 procs: Vec<Box<dyn ManagedProc>>,
215}
216
217impl ManagedProc for Supervisor {
218 fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> ManagedFuture {
219 crate::run(self.do_run(Box::pin(shutdown)))
220 }
221}
222
223pub struct SupervisorBuilder {
237 procs: Vec<Box<dyn ManagedProc>>,
238}
239
240struct CancelableLocalFuture {
241 cancel_token: CancellationToken,
242 future: ManagedFuture,
243}
244
245impl Future for CancelableLocalFuture {
246 type Output = ProcResult;
247
248 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
249 pin!(&mut self.future).poll(ctx)
250 }
251}
252
253impl<O, P> ManagedProc for P
254where
255 O: Future<Output = ProcResult> + Send + 'static,
256 P: FnOnce(ShutdownSignal) -> O + Send + Sync,
257{
258 fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> ManagedFuture {
259 Box::pin(self(shutdown))
260 }
261}
262
263impl Default for Supervisor {
264 fn default() -> Self {
265 Self::new()
266 }
267}
268
269impl Supervisor {
270 pub fn new() -> Self {
272 Self { procs: Vec::new() }
273 }
274
275 pub fn builder() -> SupervisorBuilder {
277 SupervisorBuilder { procs: Vec::new() }
278 }
279
280 pub fn add(&mut self, proc: impl ManagedProc + 'static) {
284 self.procs.push(Box::new(proc));
285 }
286
287 pub async fn start(self) -> ProcResult {
295 let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())?;
296 let shutdown = Box::pin(
297 futures::future::select(
298 Box::pin(async move { sigterm.recv().await }),
299 Box::pin(signal::ctrl_c()),
300 )
301 .map(|_| ()),
302 );
303 self.do_run(shutdown).await
304 }
305
306 async fn do_run(self, mut shutdown: BoxFuture<'static, ()>) -> ProcResult {
307 let mut futures = start_futures(self.procs);
308
309 loop {
310 if futures.is_empty() {
311 break;
312 }
313
314 let mut select = ordered_select_all(futures);
315
316 tokio::select! {
317 biased;
318 _ = &mut shutdown => return stop_all(select.into_inner()).await,
319 (result, _index, remaining) = &mut select => match result {
320 Ok(_) => futures = remaining,
321 Err(err) => {
322 let _ = stop_all(remaining).await;
323 return Err(err);
324 }
325 }
326 }
327 }
328
329 Ok(())
330 }
331}
332
333impl SupervisorBuilder {
334 pub fn add_proc(mut self, proc: impl ManagedProc + 'static) -> Self {
338 self.procs.push(Box::new(proc));
339 self
340 }
341
342 pub fn build(self) -> Supervisor {
344 Supervisor { procs: self.procs }
345 }
346}
347
348fn start_futures(procs: Vec<Box<dyn ManagedProc>>) -> Vec<CancelableLocalFuture> {
349 procs
350 .into_iter()
351 .map(|proc| {
352 let cancel_token = CancellationToken::new();
353 let child_token = cancel_token.child_token();
354 CancelableLocalFuture {
355 cancel_token,
356 future: proc.run_proc(ShutdownSignal::new(child_token)),
357 }
358 })
359 .collect()
360}
361
362async fn stop_all(procs: Vec<CancelableLocalFuture>) -> ProcResult {
363 futures::stream::iter(procs.into_iter().rev())
364 .then(|proc| async move {
365 proc.cancel_token.cancel();
366 proc.future.await
367 })
368 .collect::<Vec<_>>()
369 .await
370 .into_iter()
371 .collect()
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;

    /// Compile-time assertion that `Supervisor` is `Send + Sync`. It is never called
    /// at runtime, hence the `dead_code` allowance.
    #[allow(dead_code)]
    fn assert_send_sync() {
        fn requires_send_sync<T: Send + Sync>() {}
        requires_send_sync::<Supervisor>();
    }

    /// Minimal error whose `Display` output is exactly its message.
    #[derive(Debug)]
    struct TestError(&'static str);

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str(self.0)
        }
    }

    impl std::error::Error for TestError {}

    /// Builds the error a failing `TestProc` finishes with.
    fn test_err(msg: &'static str) -> SupervisorError {
        SupervisorError::from_err(TestError(msg))
    }

    /// A process that waits for shutdown or `delay` ms (whichever comes first),
    /// reports `name` on `sender`, then finishes with `result`.
    struct TestProc {
        name: &'static str,
        delay: u64,
        result: ProcResult,
        sender: mpsc::Sender<&'static str>,
    }

    impl TestProc {
        fn new(
            name: &'static str,
            delay: u64,
            result: ProcResult,
            sender: &mpsc::Sender<&'static str>,
        ) -> Self {
            Self {
                name,
                delay,
                result,
                sender: sender.clone(),
            }
        }
    }

    impl ManagedProc for TestProc {
        fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> ManagedFuture {
            let Self {
                name,
                delay,
                result,
                sender,
            } = *self;
            let handle = tokio::spawn(async move {
                tokio::select! {
                    _ = shutdown => (),
                    _ = tokio::time::sleep(std::time::Duration::from_millis(delay)) => (),
                }
                sender.send(name).await.expect("unable to send");
                result
            });

            Box::pin(async move {
                match handle.await {
                    Ok(outcome) => outcome,
                    Err(join_error) => Err(SupervisorError::from_err(join_error)),
                }
            })
        }
    }

    #[tokio::test]
    async fn stop_when_all_tasks_have_completed() {
        let (sender, mut receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(TestProc::new("1", 50, Ok(()), &sender))
            .add_proc(TestProc::new("2", 100, Ok(()), &sender))
            .build()
            .start()
            .await;

        assert_eq!(Some("1"), receiver.recv().await);
        assert_eq!(Some("2"), receiver.recv().await);
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn will_stop_all_in_reverse_order_after_error() {
        let (sender, mut receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(TestProc::new("1", 1000, Ok(()), &sender))
            .add_proc(TestProc::new("2", 50, Err(test_err("error")), &sender))
            .add_proc(TestProc::new("3", 1000, Ok(()), &sender))
            .build()
            .start()
            .await;

        assert_eq!(Some("2"), receiver.recv().await);
        assert_eq!(Some("3"), receiver.recv().await);
        assert_eq!(Some("1"), receiver.recv().await);
        assert_eq!("task failed: error", result.unwrap_err().to_string());
    }

    #[tokio::test]
    async fn will_return_first_error_returned() {
        let (sender, mut receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(TestProc::new("1", 1000, Ok(()), &sender))
            .add_proc(TestProc::new("2", 50, Err(test_err("error")), &sender))
            .add_proc(TestProc::new("3", 200, Err(test_err("second error")), &sender))
            .build()
            .start()
            .await;

        assert_eq!(Some("2"), receiver.recv().await);
        assert_eq!(Some("3"), receiver.recv().await);
        assert_eq!(Some("1"), receiver.recv().await);
        assert_eq!("task failed: error", result.unwrap_err().to_string());
    }

    #[tokio::test]
    async fn nested_procs_will_stop_parent_then_move_up() {
        let (sender, mut receiver) = mpsc::channel(10);

        let failing_child = Supervisor::builder()
            .add_proc(TestProc::new("proc-2-1", 500, Ok(()), &sender))
            .add_proc(TestProc::new("proc-2-2", 100, Err(test_err("error")), &sender))
            .add_proc(TestProc::new("proc-2-3", 500, Ok(()), &sender))
            .add_proc(TestProc::new("proc-2-4", 500, Ok(()), &sender))
            .build();

        let healthy_child = Supervisor::builder()
            .add_proc(TestProc::new("proc-3-1", 1000, Ok(()), &sender))
            .add_proc(TestProc::new("proc-3-2", 1000, Ok(()), &sender))
            .add_proc(TestProc::new("proc-3-3", 1000, Ok(()), &sender))
            .build();

        let result = Supervisor::builder()
            .add_proc(TestProc::new("proc-1", 500, Ok(()), &sender))
            .add_proc(failing_child)
            .add_proc(healthy_child)
            .build()
            .start()
            .await;

        let expected_stop_order = [
            "proc-2-2", "proc-2-4", "proc-2-3", "proc-2-1", "proc-3-3", "proc-3-2",
            "proc-3-1", "proc-1",
        ];
        for expected in expected_stop_order {
            assert_eq!(Some(expected), receiver.recv().await);
        }
        assert!(result.is_err());
    }
}
594}