super_visor/
lib.rs

1mod ordered_select_all;
2
3use crate::ordered_select_all::ordered_select_all;
4use anyhow::Result;
5use futures::{future::LocalBoxFuture, Future, FutureExt, StreamExt};
6use std::{
7    pin::{pin, Pin},
8    task::{Context, Poll},
9};
10use tokio::signal;
11use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned};
12
13pub struct ShutdownSignal {
14    token: CancellationToken,
15    future: Option<Pin<Box<WaitForCancellationFutureOwned>>>,
16}
17
18impl ShutdownSignal {
19    pub fn new(token: CancellationToken) -> Self {
20        Self {
21            token,
22            future: None,
23        }
24    }
25
26    pub fn is_cancelled(&self) -> bool {
27        self.token.is_cancelled()
28    }
29
30    pub fn token(&self) -> &CancellationToken {
31        &self.token
32    }
33}
34
35impl Future for ShutdownSignal {
36    type Output = ();
37
38    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
39        // Lazy init the future on first poll
40        if self.future.is_none() {
41            self.future = Some(Box::pin(self.token.clone().cancelled_owned()));
42        }
43
44        // Poll the cached future
45        self.future.as_mut().unwrap().as_mut().poll(cx)
46    }
47}
48
49impl Clone for ShutdownSignal {
50    fn clone(&self) -> Self {
51        Self {
52            token: self.token.clone(),
53            future: None,
54        }
55    }
56}
57
58impl Unpin for ShutdownSignal {}
59
60pub trait ManagedProc {
61    fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> LocalBoxFuture<'static, Result<()>>;
62}
63
64pub struct Supervisor {
65    procs: Vec<Box<dyn ManagedProc>>,
66}
67
68impl ManagedProc for Supervisor {
69    fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> LocalBoxFuture<'static, Result<()>> {
70        Box::pin(self.do_run(Box::pin(shutdown)))
71    }
72}
73
74pub struct SupervisorBuilder {
75    procs: Vec<Box<dyn ManagedProc>>,
76}
77
78struct CancelableLocalFuture {
79    cancel_token: CancellationToken,
80    future: LocalBoxFuture<'static, Result<()>>,
81}
82
83impl Future for CancelableLocalFuture {
84    type Output = Result<()>;
85
86    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
87        pin!(&mut self.future).poll(ctx)
88    }
89}
90
91impl<O, P> ManagedProc for P
92where
93    O: Future<Output = Result<()>> + 'static,
94    P: FnOnce(ShutdownSignal) -> O,
95{
96    fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> LocalBoxFuture<'static, Result<()>> {
97        Box::pin(self(shutdown))
98    }
99}
100
101impl Default for Supervisor {
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
107impl Supervisor {
108    pub fn new() -> Self {
109        Self { procs: Vec::new() }
110    }
111
112    pub fn builder() -> SupervisorBuilder {
113        SupervisorBuilder { procs: Vec::new() }
114    }
115
116    pub fn add(&mut self, proc: impl ManagedProc + 'static) {
117        self.procs.push(Box::new(proc));
118    }
119
120    pub async fn start(self) -> Result<()> {
121        let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())?;
122        let shutdown = Box::pin(
123            futures::future::select(
124                Box::pin(async move { sigterm.recv().await }),
125                Box::pin(signal::ctrl_c()),
126            )
127            .map(|_| ()),
128        );
129        self.do_run(shutdown).await
130    }
131
132    async fn do_run(self, mut shutdown: LocalBoxFuture<'static, ()>) -> Result<()> {
133        let mut futures = start_futures(self.procs);
134
135        loop {
136            if futures.is_empty() {
137                break;
138            }
139
140            let mut select = ordered_select_all(futures);
141
142            tokio::select! {
143                biased;
144                _ = &mut shutdown => return stop_all(select.into_inner()).await,
145                (result, _index, remaining) = &mut select => match result {
146                    Ok(_) => futures = remaining,
147                    Err(err) => {
148                        let _ = stop_all(remaining).await;
149                        return Err(err);
150                    }
151                }
152            }
153        }
154
155        Ok(())
156    }
157}
158
159impl SupervisorBuilder {
160    pub fn add_proc(mut self, proc: impl ManagedProc + 'static) -> Self {
161        self.procs.push(Box::new(proc));
162        self
163    }
164
165    pub fn build(self) -> Supervisor {
166        Supervisor { procs: self.procs }
167    }
168}
169
170fn start_futures(procs: Vec<Box<dyn ManagedProc>>) -> Vec<CancelableLocalFuture> {
171    procs
172        .into_iter()
173        .map(|proc| {
174            let cancel_token = CancellationToken::new();
175            let child_token = cancel_token.child_token();
176            CancelableLocalFuture {
177                cancel_token,
178                future: proc.run_proc(ShutdownSignal::new(child_token)),
179            }
180        })
181        .collect()
182}
183
184async fn stop_all(procs: Vec<CancelableLocalFuture>) -> Result<()> {
185    futures::stream::iter(procs.into_iter().rev())
186        .then(|proc| async move {
187            proc.cancel_token.cancel();
188            proc.future.await
189        })
190        .collect::<Vec<_>>()
191        .await
192        .into_iter()
193        .collect()
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use futures::TryFutureExt;
    use tokio::sync::mpsc;

    /// Test process: after being cancelled or after `delay` milliseconds (whichever comes
    /// first), it sends `name` on `sender` and then finishes with `result`.
    struct TestProc {
        name: &'static str,
        delay: u64,
        result: Result<()>,
        sender: mpsc::Sender<&'static str>,
    }

    /// Shorthand constructor that gives the new process its own clone of `sender`.
    fn test_proc(
        name: &'static str,
        delay: u64,
        result: Result<()>,
        sender: &mpsc::Sender<&'static str>,
    ) -> TestProc {
        TestProc {
            name,
            delay,
            result,
            sender: sender.clone(),
        }
    }

    impl ManagedProc for TestProc {
        fn run_proc(
            self: Box<Self>,
            shutdown: ShutdownSignal,
        ) -> LocalBoxFuture<'static, Result<()>> {
            let TestProc {
                name,
                delay,
                result,
                sender,
            } = *self;

            let handle = tokio::spawn(async move {
                tokio::select! {
                    _ = shutdown => (),
                    _ = tokio::time::sleep(std::time::Duration::from_millis(delay)) => (),
                }
                sender.send(name).await.expect("unable to send");
                result
            });

            // Flatten `Result<Result<()>, JoinError>` into `Result<()>`.
            Box::pin(
                handle
                    .map_err(anyhow::Error::from)
                    .and_then(futures::future::ready),
            )
        }
    }

    #[tokio::test]
    async fn stop_when_all_tasks_have_completed() {
        let (sender, mut receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(test_proc("1", 50, Ok(()), &sender))
            .add_proc(test_proc("2", 100, Ok(()), &sender))
            .build()
            .start()
            .await;

        assert_eq!(Some("1"), receiver.recv().await);
        assert_eq!(Some("2"), receiver.recv().await);
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn will_stop_all_in_reverse_order_after_error() {
        let (sender, mut receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(test_proc("1", 1000, Ok(()), &sender))
            .add_proc(test_proc("2", 50, Err(anyhow!("error")), &sender))
            .add_proc(test_proc("3", 1000, Ok(()), &sender))
            .build()
            .start()
            .await;

        assert_eq!(Some("2"), receiver.recv().await);
        assert_eq!(Some("3"), receiver.recv().await);
        assert_eq!(Some("1"), receiver.recv().await);
        assert_eq!("error", result.unwrap_err().to_string());
    }

    #[tokio::test]
    async fn will_return_first_error_returned() {
        let (sender, mut receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(test_proc("1", 1000, Ok(()), &sender))
            .add_proc(test_proc("2", 50, Err(anyhow!("error")), &sender))
            .add_proc(test_proc("3", 200, Err(anyhow!("second error")), &sender))
            .build()
            .start()
            .await;

        assert_eq!(Some("2"), receiver.recv().await);
        assert_eq!(Some("3"), receiver.recv().await);
        assert_eq!(Some("1"), receiver.recv().await);
        assert_eq!("error", result.unwrap_err().to_string());
    }

    #[tokio::test]
    async fn nested_procs_will_stop_parent_then_move_up() {
        let (sender, mut receiver) = mpsc::channel(10);

        let result = Supervisor::builder()
            .add_proc(test_proc("proc-1", 500, Ok(()), &sender))
            .add_proc(
                Supervisor::builder()
                    .add_proc(test_proc("proc-2-1", 500, Ok(()), &sender))
                    .add_proc(test_proc("proc-2-2", 100, Err(anyhow!("error")), &sender))
                    .add_proc(test_proc("proc-2-3", 500, Ok(()), &sender))
                    .add_proc(test_proc("proc-2-4", 500, Ok(()), &sender))
                    .build(),
            )
            .add_proc(
                Supervisor::builder()
                    .add_proc(test_proc("proc-3-1", 1000, Ok(()), &sender))
                    .add_proc(test_proc("proc-3-2", 1000, Ok(()), &sender))
                    .add_proc(test_proc("proc-3-3", 1000, Ok(()), &sender))
                    .build(),
            )
            .build()
            .start()
            .await;

        assert_eq!(Some("proc-2-2"), receiver.recv().await);
        assert_eq!(Some("proc-2-4"), receiver.recv().await);
        assert_eq!(Some("proc-2-3"), receiver.recv().await);
        assert_eq!(Some("proc-2-1"), receiver.recv().await);
        assert_eq!(Some("proc-3-3"), receiver.recv().await);
        assert_eq!(Some("proc-3-2"), receiver.recv().await);
        assert_eq!(Some("proc-3-1"), receiver.recv().await);
        assert_eq!(Some("proc-1"), receiver.recv().await);
        assert!(result.is_err());
    }
}