super_visor/
lib.rs

1mod ordered_select_all;
2
3use crate::ordered_select_all::ordered_select_all;
4use anyhow::Result;
5use futures::{future::LocalBoxFuture, Future, FutureExt, StreamExt};
6use std::{
7    pin::{pin, Pin},
8    task::{Context, Poll},
9};
10use tokio::signal;
11use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned};
12
13pub struct ShutdownSignal {
14    future: Pin<Box<WaitForCancellationFutureOwned>>,
15}
16
17impl ShutdownSignal {
18    pub fn new(token: CancellationToken) -> Self {
19        Self {
20            future: Box::pin(token.cancelled_owned()),
21        }
22    }
23}
24
25impl Future for ShutdownSignal {
26    type Output = ();
27
28    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
29        self.future.as_mut().poll(cx)
30    }
31}
32
33pub trait ManagedProc {
34    fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> LocalBoxFuture<'static, Result<()>>;
35}
36
37pub struct Supervisor {
38    procs: Vec<Box<dyn ManagedProc>>,
39}
40
41impl ManagedProc for Supervisor {
42    fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> LocalBoxFuture<'static, Result<()>> {
43        Box::pin(self.do_run(Box::pin(shutdown)))
44    }
45}
46
47pub struct SupervisorBuilder {
48    procs: Vec<Box<dyn ManagedProc>>,
49}
50
51struct CancelableLocalFuture {
52    cancel_token: CancellationToken,
53    future: LocalBoxFuture<'static, Result<()>>,
54}
55
56impl Future for CancelableLocalFuture {
57    type Output = Result<()>;
58
59    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
60        pin!(&mut self.future).poll(ctx)
61    }
62}
63
64impl<O, P> ManagedProc for P
65where
66    O: Future<Output = Result<()>> + 'static,
67    P: FnOnce(ShutdownSignal) -> O,
68{
69    fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> LocalBoxFuture<'static, Result<()>> {
70        Box::pin(self(shutdown))
71    }
72}
73
74impl Default for Supervisor {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
80impl Supervisor {
81    pub fn new() -> Self {
82        Self { procs: Vec::new() }
83    }
84
85    pub fn builder() -> SupervisorBuilder {
86        SupervisorBuilder { procs: Vec::new() }
87    }
88
89    pub fn add(&mut self, proc: impl ManagedProc + 'static) {
90        self.procs.push(Box::new(proc));
91    }
92
93    pub async fn start(self) -> Result<()> {
94        let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())?;
95        let shutdown = Box::pin(
96            futures::future::select(
97                Box::pin(async move { sigterm.recv().await }),
98                Box::pin(signal::ctrl_c()),
99            )
100            .map(|_| ()),
101        );
102        self.do_run(shutdown).await
103    }
104
105    async fn do_run(self, mut shutdown: LocalBoxFuture<'static, ()>) -> Result<()> {
106        let mut futures = start_futures(self.procs);
107
108        loop {
109            if futures.is_empty() {
110                break;
111            }
112
113            let mut select = ordered_select_all(futures);
114
115            tokio::select! {
116                biased;
117                _ = &mut shutdown => return stop_all(select.into_inner()).await,
118                (result, _index, remaining) = &mut select => match result {
119                    Ok(_) => futures = remaining,
120                    Err(err) => {
121                        let _ = stop_all(remaining).await;
122                        return Err(err);
123                    }
124                }
125            }
126        }
127
128        Ok(())
129    }
130}
131
132impl SupervisorBuilder {
133    pub fn add_proc(mut self, proc: impl ManagedProc + 'static) -> Self {
134        self.procs.push(Box::new(proc));
135        self
136    }
137
138    pub fn build(self) -> Supervisor {
139        Supervisor { procs: self.procs }
140    }
141}
142
143fn start_futures(procs: Vec<Box<dyn ManagedProc>>) -> Vec<CancelableLocalFuture> {
144    procs
145        .into_iter()
146        .map(|proc| {
147            let cancel_token = CancellationToken::new();
148            let child_token = cancel_token.child_token();
149            CancelableLocalFuture {
150                cancel_token,
151                future: proc.run_proc(ShutdownSignal::new(child_token)),
152            }
153        })
154        .collect()
155}
156
157async fn stop_all(procs: Vec<CancelableLocalFuture>) -> Result<()> {
158    futures::stream::iter(procs.into_iter().rev())
159        .then(|proc| async move {
160            proc.cancel_token.cancel();
161            proc.future.await
162        })
163        .collect::<Vec<_>>()
164        .await
165        .into_iter()
166        .collect()
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use futures::TryFutureExt;
    use tokio::sync::mpsc;

    /// Test double: once `delay` milliseconds pass or its shutdown signal fires, it sends
    /// `name` on `sender` and resolves to `result`.
    struct TestProc {
        name: &'static str,
        delay: u64,
        result: Result<()>,
        sender: mpsc::Sender<&'static str>,
    }

    impl TestProc {
        /// Builds a proc that reports on its own clone of `sender`.
        fn new(
            name: &'static str,
            delay: u64,
            result: Result<()>,
            sender: &mpsc::Sender<&'static str>,
        ) -> Self {
            Self {
                name,
                delay,
                result,
                sender: sender.clone(),
            }
        }
    }

    impl ManagedProc for TestProc {
        fn run_proc(
            self: Box<Self>,
            shutdown: ShutdownSignal,
        ) -> LocalBoxFuture<'static, Result<()>> {
            let TestProc {
                name,
                delay,
                result,
                sender,
            } = *self;

            // The work runs as a spawned task; the returned future only awaits its handle.
            let handle = tokio::spawn(async move {
                let timer = tokio::time::sleep(std::time::Duration::from_millis(delay));
                tokio::select! {
                    _ = shutdown => (),
                    _ = timer => (),
                }
                sender.send(name).await.expect("unable to send");
                result
            });

            // Flatten `Result<Result<()>, JoinError>` into `Result<()>`.
            Box::pin(
                handle
                    .err_into::<anyhow::Error>()
                    .and_then(|result| async move { result }),
            )
        }
    }

    #[tokio::test]
    async fn stop_when_all_tasks_have_completed() {
        let (sender, mut receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(TestProc::new("1", 50, Ok(()), &sender))
            .add_proc(TestProc::new("2", 100, Ok(()), &sender))
            .build()
            .start()
            .await;

        assert_eq!(Some("1"), receiver.recv().await);
        assert_eq!(Some("2"), receiver.recv().await);
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn will_stop_all_in_reverse_order_after_error() {
        let (sender, mut receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(TestProc::new("1", 1000, Ok(()), &sender))
            .add_proc(TestProc::new("2", 50, Err(anyhow!("error")), &sender))
            .add_proc(TestProc::new("3", 1000, Ok(()), &sender))
            .build()
            .start()
            .await;

        assert_eq!(Some("2"), receiver.recv().await);
        assert_eq!(Some("3"), receiver.recv().await);
        assert_eq!(Some("1"), receiver.recv().await);
        assert_eq!("error", result.unwrap_err().to_string());
    }

    #[tokio::test]
    async fn will_return_first_error_returned() {
        let (sender, mut receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(TestProc::new("1", 1000, Ok(()), &sender))
            .add_proc(TestProc::new("2", 50, Err(anyhow!("error")), &sender))
            .add_proc(TestProc::new("3", 200, Err(anyhow!("second error")), &sender))
            .build()
            .start()
            .await;

        assert_eq!(Some("2"), receiver.recv().await);
        assert_eq!(Some("3"), receiver.recv().await);
        assert_eq!(Some("1"), receiver.recv().await);
        assert_eq!("error", result.unwrap_err().to_string());
    }

    #[tokio::test]
    async fn nested_procs_will_stop_parent_then_move_up() {
        let (sender, mut receiver) = mpsc::channel(10);

        let failing_group = Supervisor::builder()
            .add_proc(TestProc::new("proc-2-1", 500, Ok(()), &sender))
            .add_proc(TestProc::new("proc-2-2", 100, Err(anyhow!("error")), &sender))
            .add_proc(TestProc::new("proc-2-3", 500, Ok(()), &sender))
            .add_proc(TestProc::new("proc-2-4", 500, Ok(()), &sender))
            .build();

        let healthy_group = Supervisor::builder()
            .add_proc(TestProc::new("proc-3-1", 1000, Ok(()), &sender))
            .add_proc(TestProc::new("proc-3-2", 1000, Ok(()), &sender))
            .add_proc(TestProc::new("proc-3-3", 1000, Ok(()), &sender))
            .build();

        let result = Supervisor::builder()
            .add_proc(TestProc::new("proc-1", 500, Ok(()), &sender))
            .add_proc(failing_group)
            .add_proc(healthy_group)
            .build()
            .start()
            .await;

        assert_eq!(Some("proc-2-2"), receiver.recv().await);
        assert_eq!(Some("proc-2-4"), receiver.recv().await);
        assert_eq!(Some("proc-2-3"), receiver.recv().await);
        assert_eq!(Some("proc-2-1"), receiver.recv().await);
        assert_eq!(Some("proc-3-3"), receiver.recv().await);
        assert_eq!(Some("proc-3-2"), receiver.recv().await);
        assert_eq!(Some("proc-3-1"), receiver.recv().await);
        assert_eq!(Some("proc-1"), receiver.recv().await);
        assert!(result.is_err());
    }
}