super_visor/
lib.rs

1mod select_all;
2
3use crate::select_all::select_all;
4use anyhow::Result;
5use futures::{future::LocalBoxFuture, Future, FutureExt, StreamExt};
6use std::{
7    pin::{pin, Pin},
8    task::{Context, Poll},
9};
10use tokio::signal;
11use tokio_util::sync::CancellationToken;
12
13fn root_shutdown() -> Result<LocalBoxFuture<'static, ()>> {
14    let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())?;
15    Ok(Box::pin(
16        futures::future::select(
17            Box::pin(async move { sigterm.recv().await }),
18            Box::pin(signal::ctrl_c()),
19        )
20        .map(|_| ()),
21    ))
22}
23
24pub trait ManagedProc {
25    fn start_proc(
26        self: Box<Self>,
27        shutdown: CancellationToken,
28    ) -> LocalBoxFuture<'static, Result<()>>;
29}
30
31pub struct Supervisor {
32    procs: Vec<Box<dyn ManagedProc>>,
33}
34
35impl ManagedProc for Supervisor {
36    fn start_proc(
37        self: Box<Self>,
38        shutdown: CancellationToken,
39    ) -> LocalBoxFuture<'static, Result<()>> {
40        let cancel_listener = shutdown.cancelled_owned();
41        Box::pin(self.do_start(Box::pin(cancel_listener)))
42    }
43}
44
45pub struct SupervisorBuilder {
46    procs: Vec<Box<dyn ManagedProc>>,
47}
48
49struct CancelableLocalFuture {
50    cancel_token: CancellationToken,
51    future: LocalBoxFuture<'static, Result<()>>,
52}
53
54impl Future for CancelableLocalFuture {
55    type Output = Result<()>;
56
57    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
58        pin!(&mut self.future).poll(ctx)
59    }
60}
61
62impl<F, O> ManagedProc for F
63where
64    O: Future<Output = Result<()>> + 'static,
65    F: FnOnce(CancellationToken) -> O,
66{
67    fn start_proc(
68        self: Box<Self>,
69        shutdown: CancellationToken,
70    ) -> LocalBoxFuture<'static, Result<()>> {
71        Box::pin(self(shutdown))
72    }
73}
74
75impl Default for Supervisor {
76    fn default() -> Self {
77        Self::new()
78    }
79}
80
81impl Supervisor {
82    pub fn new() -> Self {
83        Self { procs: Vec::new() }
84    }
85
86    pub fn builder() -> SupervisorBuilder {
87        SupervisorBuilder { procs: Vec::new() }
88    }
89
90    pub fn add(&mut self, proc: impl ManagedProc + 'static) {
91        self.procs.push(Box::new(proc));
92    }
93
94    pub async fn start(self) -> Result<()> {
95        self.do_start(root_shutdown()?).await
96    }
97
98    async fn do_start(self, mut shutdown: LocalBoxFuture<'static, ()>) -> Result<()> {
99        let mut futures = start_futures(self.procs);
100
101        loop {
102            if futures.is_empty() {
103                break;
104            }
105
106            let mut select = select_all(futures);
107
108            tokio::select! {
109                biased;
110                _ = &mut shutdown => return stop_all(select.into_inner()).await,
111                (result, _index, remaining) = &mut select => match result {
112                    Ok(_) => futures = remaining,
113                    Err(err) => {
114                        let _ = stop_all(remaining).await;
115                        return Err(err);
116                    }
117                }
118            }
119        }
120
121        Ok(())
122    }
123}
124
125impl SupervisorBuilder {
126    pub fn add_proc(mut self, proc: impl ManagedProc + 'static) -> Self {
127        self.procs.push(Box::new(proc));
128        self
129    }
130
131    pub fn build(self) -> Supervisor {
132        Supervisor { procs: self.procs }
133    }
134}
135
136fn start_futures(procs: Vec<Box<dyn ManagedProc>>) -> Vec<CancelableLocalFuture> {
137    procs
138        .into_iter()
139        .map(|proc| {
140            let cancel_token = CancellationToken::new();
141            let child_token = cancel_token.child_token();
142            CancelableLocalFuture {
143                cancel_token,
144                future: proc.start_proc(child_token),
145            }
146        })
147        .collect()
148}
149
150async fn stop_all(procs: Vec<CancelableLocalFuture>) -> Result<()> {
151    futures::stream::iter(procs.into_iter().rev())
152        .then(|proc| async move {
153            proc.cancel_token.cancel();
154            proc.future.await
155        })
156        .collect::<Vec<_>>()
157        .await
158        .into_iter()
159        .collect()
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use tokio::sync::mpsc;

    /// Test proc that waits for either cancellation or `delay` milliseconds,
    /// then reports its `name` on `sender` and finishes with `result`.
    struct TestProc {
        name: &'static str,
        delay: u64,
        result: Result<()>,
        sender: mpsc::Sender<&'static str>,
    }

    /// Builds a [`TestProc`] that reports on a clone of `sender`.
    fn test_proc(
        name: &'static str,
        delay: u64,
        result: Result<()>,
        sender: &mpsc::Sender<&'static str>,
    ) -> TestProc {
        TestProc {
            name,
            delay,
            result,
            sender: sender.clone(),
        }
    }

    impl ManagedProc for TestProc {
        fn start_proc(
            self: Box<Self>,
            shutdown: CancellationToken,
        ) -> LocalBoxFuture<'static, Result<()>> {
            let Self {
                name,
                delay,
                result,
                sender,
            } = *self;

            let task = tokio::spawn(async move {
                tokio::select! {
                    _ = shutdown.cancelled() => (),
                    _ = tokio::time::sleep(std::time::Duration::from_millis(delay)) => (),
                }
                sender.send(name).await.expect("unable to send");
                result
            });

            // A join failure is converted into an error; otherwise the
            // task's own result is passed through.
            Box::pin(async move { task.await? })
        }
    }

    /// With no failures, the supervisor returns `Ok` once both procs finish.
    #[tokio::test]
    async fn stop_when_all_tasks_have_completed() {
        let (sender, mut receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(test_proc("1", 50, Ok(()), &sender))
            .add_proc(test_proc("2", 100, Ok(()), &sender))
            .build()
            .start()
            .await;

        for expected in ["1", "2"] {
            assert_eq!(Some(expected), receiver.recv().await);
        }
        assert!(result.is_ok());
    }

    /// After proc "2" fails, the others report in reverse insertion order.
    #[tokio::test]
    async fn will_stop_all_in_reverse_order_after_error() {
        let (sender, mut receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(test_proc("1", 1000, Ok(()), &sender))
            .add_proc(test_proc("2", 50, Err(anyhow!("error")), &sender))
            .add_proc(test_proc("3", 1000, Ok(()), &sender))
            .build()
            .start()
            .await;

        for expected in ["2", "3", "1"] {
            assert_eq!(Some(expected), receiver.recv().await);
        }
        assert_eq!("error", result.unwrap_err().to_string());
    }

    /// The error that triggered the stop is returned, not a later error
    /// produced by a proc while it is being stopped.
    #[tokio::test]
    async fn will_return_first_error_returned() {
        let (sender, mut receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(test_proc("1", 1000, Ok(()), &sender))
            .add_proc(test_proc("2", 50, Err(anyhow!("error")), &sender))
            .add_proc(test_proc("3", 200, Err(anyhow!("second error")), &sender))
            .build()
            .start()
            .await;

        for expected in ["2", "3", "1"] {
            assert_eq!(Some(expected), receiver.recv().await);
        }
        assert_eq!("error", result.unwrap_err().to_string());
    }

    /// A failure inside a nested supervisor first stops that supervisor's
    /// other procs in reverse order, then the outer supervisor's remaining
    /// procs in reverse order.
    #[tokio::test]
    async fn nested_procs_will_stop_parent_then_move_up() {
        let (sender, mut receiver) = mpsc::channel(10);

        let failing_group = Supervisor::builder()
            .add_proc(test_proc("proc-2-1", 500, Ok(()), &sender))
            .add_proc(test_proc("proc-2-2", 100, Err(anyhow!("error")), &sender))
            .add_proc(test_proc("proc-2-3", 500, Ok(()), &sender))
            .add_proc(test_proc("proc-2-4", 500, Ok(()), &sender))
            .build();

        let healthy_group = Supervisor::builder()
            .add_proc(test_proc("proc-3-1", 1000, Ok(()), &sender))
            .add_proc(test_proc("proc-3-2", 1000, Ok(()), &sender))
            .add_proc(test_proc("proc-3-3", 1000, Ok(()), &sender))
            .build();

        let result = Supervisor::builder()
            .add_proc(test_proc("proc-1", 500, Ok(()), &sender))
            .add_proc(failing_group)
            .add_proc(healthy_group)
            .build()
            .start()
            .await;

        let expected_order = [
            "proc-2-2", "proc-2-4", "proc-2-3", "proc-2-1", "proc-3-3", "proc-3-2", "proc-3-1",
            "proc-1",
        ];
        for expected in expected_order {
            assert_eq!(Some(expected), receiver.recv().await);
        }
        assert!(result.is_err());
    }
}