// super_visor/lib.rs

1mod ordered_select_all;
2
3use crate::ordered_select_all::ordered_select_all;
4use futures::{future::BoxFuture, Future, FutureExt, StreamExt, TryFutureExt};
5use std::{
6    pin::{pin, Pin},
7    task::{Context, Poll},
8};
9use tokio::signal;
10use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned};
11
/// Type-erased, thread-safe error raised by caller-supplied process code.
pub type BoxError = Box<dyn std::error::Error + Sync + Send>;
14
15/// Error type returned from task operations
16#[derive(Debug, thiserror::Error)]
17pub enum SupervisorError {
18    /// Error setting up signal listeners
19    #[error("signal listener setup failed: {0}")]
20    Signal(#[from] std::io::Error),
21    /// Error from within the managed process
22    #[error("task failed: {0}")]
23    Process(#[source] BoxError),
24}
25
26impl SupervisorError {
27    /// Creates a `SupervisorError::Process` from any error type that can be converted to [`BoxError`].
28    ///
29    /// This is useful in `map_err` calls to convert user process errors:
30    ///
31    /// ```ignore
32    /// my_future.map_err(SupervisorError::from_err)
33    /// ```
34    pub fn from_err<E: Into<BoxError>>(err: E) -> Self {
35        Self::Process(err.into())
36    }
37}
38
39impl From<BoxError> for SupervisorError {
40    fn from(err: BoxError) -> Self {
41        Self::Process(err)
42    }
43}
44
45/// Result type for supervised process operations
46pub type ProcResult = Result<(), SupervisorError>;
47
48/// The return type for managed processes
49///
50/// A boxed future that returns [`super_visor::Result`] when complete.
51/// Processes should return `Ok(())` on successful shutdown or an error if something went wrong
52pub type ManagedFuture = futures::future::BoxFuture<'static, ProcResult>;
53
54/// An awaitable construct for signalling shutdown to supervised processes
55/// to complete work and exit on the next available await point
56/// Lazily instantiates the future when the future is first polled to allow cloning the underlying
57/// token and propagating shutdown signals from a single source in a  one-to-many relationship
58pub struct ShutdownSignal {
59    token: CancellationToken,
60    future: Option<Pin<Box<WaitForCancellationFutureOwned>>>,
61}
62
63impl ShutdownSignal {
64    pub fn new(token: CancellationToken) -> Self {
65        Self {
66            token,
67            future: None,
68        }
69    }
70
71    pub fn is_cancelled(&self) -> bool {
72        self.token.is_cancelled()
73    }
74
75    pub fn token(&self) -> &CancellationToken {
76        &self.token
77    }
78}
79
80impl Future for ShutdownSignal {
81    type Output = ();
82
83    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
84        // Lazy init the future on first poll
85        if self.future.is_none() {
86            self.future = Some(Box::pin(self.token.clone().cancelled_owned()));
87        }
88
89        // Poll the cached future
90        self.future.as_mut().unwrap().as_mut().poll(cx)
91    }
92}
93
94impl Clone for ShutdownSignal {
95    fn clone(&self) -> Self {
96        Self {
97            token: self.token.clone(),
98            future: None,
99        }
100    }
101}
102
103impl Unpin for ShutdownSignal {}
104
105/// Spawns a future ito its own Tokio task.
106///
107/// Use this in [`ManagedProc::start_task`] implementations when your future
108/// is `Send + 'static`. This is the preferred approach as it allows the process
109/// to run independently on the Tokio runtime.
110///
111/// # Example
112///
113/// ```ignore
114/// impl ManagedProc for MyDaemon {
115///     fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> ManagedFuture {
116///         super_visor::spawn(self.run(shutdown))
117///     }
118/// }
119/// ```
120pub fn spawn<F, E>(fut: F) -> ManagedFuture
121where
122    F: Future<Output = Result<(), E>> + Send + 'static,
123    E: Into<BoxError> + Send + 'static,
124{
125    // tokio::spawn returns Result<Result<(), E>, JoinError>
126    Box::pin(tokio::spawn(fut).map(|result| match result {
127        Ok(Ok(())) => Ok(()),
128        Ok(Err(e)) => Err(SupervisorError::from_err(e)),
129        Err(e) => Err(SupervisorError::from_err(e)),
130    }))
131}
132
133/// Boxes a future without spawning a separate managed task
134///
135/// Use this in [`ManagedProc::start_task`] impls when you want to run the future
136/// directly rather than spawning it. Prefer [`spawn`] when possible, as spawned
137/// tasks can run more efficiently on the Tokio runtime
138///
139/// # Example
140///
141/// ```ignore
142/// impl ManagedProc for MyLocalDaemon {
143///     fn start_task(self: Box<Self>, shutdown: ShutdownSignal) -> ManagedFuture {
144///         super_visor::run(self.run(shutdown))
145///     }
146/// }
147/// ```
148pub fn run<F, E>(fut: F) -> ManagedFuture
149where
150    F: Future<Output = Result<(), E>> + Send + 'static,
151    E: Into<BoxError> + 'static,
152{
153    Box::pin(fut.map_err(SupervisorError::from_err))
154}
155
156/// A trait for types that can be managed as long-running async tasks.
157///
158/// Implement this trait to make your type usable with [`Supervisor`].
159/// The trait is also automatically implemented for closures of the form
160/// `FnOnce(ShutdownSignal) -> Future<Output = ProcResult>`.
161///
162/// # Example
163///
164/// ```ignore
165/// use super_visor::{ManagedProc, ManagedFuture};
166///
167/// struct MyDaemon { /* ... */ }
168///
169/// impl ManagedProc for MyDaemon {
170///     fn start_task(self: Box<Self>, shutdown: ShutdownSignal) -> ManaagedFuture {
171///         super_visor::spawn(self.run_task_logic_in_some_loop(shutdown))
172///     }
173/// }
174/// ```
175pub trait ManagedProc: Send + Sync {
176    /// Starts the process and returns a future that completes when the work is complete
177    /// or runs indefinitely in a continual loop.
178    ///
179    /// The `shutdown` listener will be triggered when the supervisor wants to shut down
180    /// the process. Implementations should listen for this signal and clean up gracefully.
181    /// Listening typically involves awaiting the shutdown signal alongside the primary operational
182    /// logic of the managed task in a select function or macro, or checking the signal has completed
183    /// or been cancelled at await points in the control loop
184    fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> ManagedFuture;
185}
186
187/// Manages the lifecycle of multiple async workers with coordinated, ordered shutdown.
188///
189/// `Supervisor` runs all registered processes concurrently and handles graceful shutdown
190/// when receiving SIGTERM or Ctrl+C. Tasks are shutdown in reverse order of their initial
191/// registration (LIFO), allowing dependent tasks to stop their dependencies. This mimics the
192/// Patterns of process management hierarchy and dependency familiar in the OTP Erlang runtime.
193///
194/// # Example
195///
196/// ```ignore
197/// use super_visor::Supervisor;
198///
199/// // Using the builder pattern
200/// Supervisor::builder()
201///     .add_proc(server)
202///     .add_proc(worker)
203///     .build()
204///     .start()
205///     .await?;
206///
207/// // Or using direct construction
208/// let mut supervisor = Supervisor::new();
209/// supervisor.add(server);
210/// supervisor.add(worker);
211/// supervisor.start().await?;
212/// ```
213pub struct Supervisor {
214    procs: Vec<Box<dyn ManagedProc>>,
215}
216
217impl ManagedProc for Supervisor {
218    fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> ManagedFuture {
219        crate::run(self.do_run(Box::pin(shutdown)))
220    }
221}
222
223/// Builder for constructing a [`Supervisor`].
224///
225/// # Example
226///
227/// ```ignore
228/// use super_visor::Supervisor;
229///
230/// let supervisor = Supervisor::builder()
231///     .add_proc(server)
232///     .add_proc(worker)
233///     .add_proc(sink)
234///     .build();
235/// ```
236pub struct SupervisorBuilder {
237    procs: Vec<Box<dyn ManagedProc>>,
238}
239
240struct CancelableLocalFuture {
241    cancel_token: CancellationToken,
242    future: ManagedFuture,
243}
244
245impl Future for CancelableLocalFuture {
246    type Output = ProcResult;
247
248    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
249        pin!(&mut self.future).poll(ctx)
250    }
251}
252
253impl<O, P> ManagedProc for P
254where
255    O: Future<Output = ProcResult> + Send + 'static,
256    P: FnOnce(ShutdownSignal) -> O + Send + Sync,
257{
258    fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> ManagedFuture {
259        Box::pin(self(shutdown))
260    }
261}
262
263impl Default for Supervisor {
264    fn default() -> Self {
265        Self::new()
266    }
267}
268
269impl Supervisor {
270    /// Creates a new empty supervisor.
271    pub fn new() -> Self {
272        Self { procs: Vec::new() }
273    }
274
275    /// Creates a new [`SupervisorBuilder`] for fluent task registration.
276    pub fn builder() -> SupervisorBuilder {
277        SupervisorBuilder { procs: Vec::new() }
278    }
279
280    /// Adds a task to the supervisor
281    ///
282    /// Tasks are started in the order they are added and shutdown in the reverse order.
283    pub fn add(&mut self, proc: impl ManagedProc + 'static) {
284        self.procs.push(Box::new(proc));
285    }
286
287    /// Starts all registered processes and waits for completion or shutdown.
288    ///
289    /// This method:
290    /// 1. Starts all processes concurrently
291    /// 2. Listens for SIGTERM or Ctrl+C signals
292    /// 3. On signal or error, shuts down all running processes in reverse order (LIFO)
293    /// 4. Returns the first error encountered or `Ok(())` if all tasks complete successfully
294    pub async fn start(self) -> ProcResult {
295        let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())?;
296        let shutdown = Box::pin(
297            futures::future::select(
298                Box::pin(async move { sigterm.recv().await }),
299                Box::pin(signal::ctrl_c()),
300            )
301            .map(|_| ()),
302        );
303        self.do_run(shutdown).await
304    }
305
306    async fn do_run(self, mut shutdown: BoxFuture<'static, ()>) -> ProcResult {
307        let mut futures = start_futures(self.procs);
308
309        loop {
310            if futures.is_empty() {
311                break;
312            }
313
314            let mut select = ordered_select_all(futures);
315
316            tokio::select! {
317                biased;
318                _ = &mut shutdown => return stop_all(select.into_inner()).await,
319                (result, _index, remaining) = &mut select => match result {
320                    Ok(_) => futures = remaining,
321                    Err(err) => {
322                        let _ = stop_all(remaining).await;
323                        return Err(err);
324                    }
325                }
326            }
327        }
328
329        Ok(())
330    }
331}
332
333impl SupervisorBuilder {
334    /// Adds a process to the builder
335    ///
336    /// Processes are started in the order they are registered and shut down in reverse order.
337    pub fn add_proc(mut self, proc: impl ManagedProc + 'static) -> Self {
338        self.procs.push(Box::new(proc));
339        self
340    }
341
342    /// Consumes the builder and returns a configured [`Supervisor`].
343    pub fn build(self) -> Supervisor {
344        Supervisor { procs: self.procs }
345    }
346}
347
348fn start_futures(procs: Vec<Box<dyn ManagedProc>>) -> Vec<CancelableLocalFuture> {
349    procs
350        .into_iter()
351        .map(|proc| {
352            let cancel_token = CancellationToken::new();
353            let child_token = cancel_token.child_token();
354            CancelableLocalFuture {
355                cancel_token,
356                future: proc.run_proc(ShutdownSignal::new(child_token)),
357            }
358        })
359        .collect()
360}
361
362async fn stop_all(procs: Vec<CancelableLocalFuture>) -> ProcResult {
363    futures::stream::iter(procs.into_iter().rev())
364        .then(|proc| async move {
365            proc.cancel_token.cancel();
366            proc.future.await
367        })
368        .collect::<Vec<_>>()
369        .await
370        .into_iter()
371        .collect()
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;

    /// Compile-time check that `Supervisor` is `Send + Sync`.
    ///
    /// Never called at runtime; it only has to type-check, hence `dead_code` is allowed.
    #[allow(dead_code)]
    fn assert_send_sync() {
        fn is_send<T: Send>() {}
        fn is_sync<T: Sync>() {}
        is_send::<Supervisor>();
        is_sync::<Supervisor>();
    }

    /// Minimal error type for building deterministic process failures.
    #[derive(Debug)]
    struct TestError(&'static str);

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str(self.0)
        }
    }

    impl std::error::Error for TestError {}

    /// Wraps `msg` in a `SupervisorError::Process`, as a failing process would report it.
    fn test_err(msg: &'static str) -> SupervisorError {
        SupervisorError::from_err(TestError(msg))
    }

    /// A scripted process.
    ///
    /// It waits for `delay` ms or for shutdown, whichever comes first. It then sends its
    /// `name` on `sender`, recording the order in which processes stopped, and returns
    /// `result`.
    struct TestProc {
        name: &'static str,
        delay: u64,
        result: ProcResult,
        sender: mpsc::Sender<&'static str>,
    }

    impl TestProc {
        /// Builds a `TestProc` that reports on its own clone of `sender`.
        fn new(
            name: &'static str,
            delay: u64,
            result: ProcResult,
            sender: &mpsc::Sender<&'static str>,
        ) -> Self {
            Self {
                name,
                delay,
                result,
                sender: sender.clone(),
            }
        }
    }

    impl ManagedProc for TestProc {
        fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> ManagedFuture {
            let TestProc {
                name,
                delay,
                result,
                sender,
            } = *self;

            let handle = tokio::spawn(async move {
                tokio::select! {
                    _ = shutdown => (),
                    _ = tokio::time::sleep(std::time::Duration::from_millis(delay)) => (),
                }
                sender.send(name).await.expect("unable to send");
                result
            });

            // Surface a panicked or cancelled task as a process error.
            Box::pin(handle.map(|joined| {
                joined.unwrap_or_else(|join_err| Err(SupervisorError::from_err(join_err)))
            }))
        }
    }

    #[tokio::test]
    async fn stop_when_all_tasks_have_completed() {
        let (sender, mut receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(TestProc::new("1", 50, Ok(()), &sender))
            .add_proc(TestProc::new("2", 100, Ok(()), &sender))
            .build()
            .start()
            .await;

        assert_eq!(Some("1"), receiver.recv().await);
        assert_eq!(Some("2"), receiver.recv().await);
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn will_stop_all_in_reverse_order_after_error() {
        let (sender, mut receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(TestProc::new("1", 1000, Ok(()), &sender))
            .add_proc(TestProc::new("2", 50, Err(test_err("error")), &sender))
            .add_proc(TestProc::new("3", 1000, Ok(()), &sender))
            .build()
            .start()
            .await;

        assert_eq!(Some("2"), receiver.recv().await);
        assert_eq!(Some("3"), receiver.recv().await);
        assert_eq!(Some("1"), receiver.recv().await);
        assert_eq!("task failed: error", result.unwrap_err().to_string());
    }

    #[tokio::test]
    async fn will_return_first_error_returned() {
        let (sender, mut receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(TestProc::new("1", 1000, Ok(()), &sender))
            .add_proc(TestProc::new("2", 50, Err(test_err("error")), &sender))
            .add_proc(TestProc::new("3", 200, Err(test_err("second error")), &sender))
            .build()
            .start()
            .await;

        assert_eq!(Some("2"), receiver.recv().await);
        assert_eq!(Some("3"), receiver.recv().await);
        assert_eq!(Some("1"), receiver.recv().await);
        assert_eq!("task failed: error", result.unwrap_err().to_string());
    }

    #[tokio::test]
    async fn nested_procs_will_stop_parent_then_move_up() {
        let (sender, mut receiver) = mpsc::channel(10);

        let second = Supervisor::builder()
            .add_proc(TestProc::new("proc-2-1", 500, Ok(()), &sender))
            .add_proc(TestProc::new("proc-2-2", 100, Err(test_err("error")), &sender))
            .add_proc(TestProc::new("proc-2-3", 500, Ok(()), &sender))
            .add_proc(TestProc::new("proc-2-4", 500, Ok(()), &sender))
            .build();

        let third = Supervisor::builder()
            .add_proc(TestProc::new("proc-3-1", 1000, Ok(()), &sender))
            .add_proc(TestProc::new("proc-3-2", 1000, Ok(()), &sender))
            .add_proc(TestProc::new("proc-3-3", 1000, Ok(()), &sender))
            .build();

        let result = Supervisor::builder()
            .add_proc(TestProc::new("proc-1", 500, Ok(()), &sender))
            .add_proc(second)
            .add_proc(third)
            .build()
            .start()
            .await;

        let expected = [
            "proc-2-2", "proc-2-4", "proc-2-3", "proc-2-1", "proc-3-3", "proc-3-2", "proc-3-1",
            "proc-1",
        ];
        for name in expected {
            assert_eq!(Some(name), receiver.recv().await);
        }
        assert!(result.is_err());
    }
}