1mod select_all;
2
3use crate::select_all::select_all;
4use anyhow::Result;
5use futures::{future::LocalBoxFuture, Future, FutureExt, StreamExt};
6use std::{
7 pin::{pin, Pin},
8 task::{Context, Poll},
9};
10use tokio::signal;
11use tokio_util::sync::CancellationToken;
12
13fn root_shutdown() -> Result<LocalBoxFuture<'static, ()>> {
14 let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())?;
15 Ok(Box::pin(
16 futures::future::select(
17 Box::pin(async move { sigterm.recv().await }),
18 Box::pin(signal::ctrl_c()),
19 )
20 .map(|_| ()),
21 ))
22}
23
24pub trait ManagedProc {
25 fn start_proc(
26 self: Box<Self>,
27 shutdown: CancellationToken,
28 ) -> LocalBoxFuture<'static, Result<()>>;
29}
30
31pub struct Supervisor {
32 procs: Vec<Box<dyn ManagedProc>>,
33}
34
35impl ManagedProc for Supervisor {
36 fn start_proc(
37 self: Box<Self>,
38 shutdown: CancellationToken,
39 ) -> LocalBoxFuture<'static, Result<()>> {
40 let cancel_listener = shutdown.cancelled_owned();
41 Box::pin(self.do_start(Box::pin(cancel_listener)))
42 }
43}
44
45pub struct SupervisorBuilder {
46 procs: Vec<Box<dyn ManagedProc>>,
47}
48
49struct CancelableLocalFuture {
50 cancel_token: CancellationToken,
51 future: LocalBoxFuture<'static, Result<()>>,
52}
53
54impl Future for CancelableLocalFuture {
55 type Output = Result<()>;
56
57 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
58 pin!(&mut self.future).poll(ctx)
59 }
60}
61
62impl<F, O> ManagedProc for F
63where
64 O: Future<Output = Result<()>> + 'static,
65 F: FnOnce(CancellationToken) -> O,
66{
67 fn start_proc(
68 self: Box<Self>,
69 shutdown: CancellationToken,
70 ) -> LocalBoxFuture<'static, Result<()>> {
71 Box::pin(self(shutdown))
72 }
73}
74
75impl Default for Supervisor {
76 fn default() -> Self {
77 Self::new()
78 }
79}
80
81impl Supervisor {
82 pub fn new() -> Self {
83 Self { procs: Vec::new() }
84 }
85
86 pub fn builder() -> SupervisorBuilder {
87 SupervisorBuilder { procs: Vec::new() }
88 }
89
90 pub fn add(&mut self, proc: impl ManagedProc + 'static) {
91 self.procs.push(Box::new(proc));
92 }
93
94 pub async fn start(self) -> Result<()> {
95 self.do_start(root_shutdown()?).await
96 }
97
98 async fn do_start(self, mut shutdown: LocalBoxFuture<'static, ()>) -> Result<()> {
99 let mut futures = start_futures(self.procs);
100
101 loop {
102 if futures.is_empty() {
103 break;
104 }
105
106 let mut select = select_all(futures);
107
108 tokio::select! {
109 biased;
110 _ = &mut shutdown => return stop_all(select.into_inner()).await,
111 (result, _index, remaining) = &mut select => match result {
112 Ok(_) => futures = remaining,
113 Err(err) => {
114 let _ = stop_all(remaining).await;
115 return Err(err);
116 }
117 }
118 }
119 }
120
121 Ok(())
122 }
123}
124
125impl SupervisorBuilder {
126 pub fn add_proc(mut self, proc: impl ManagedProc + 'static) -> Self {
127 self.procs.push(Box::new(proc));
128 self
129 }
130
131 pub fn build(self) -> Supervisor {
132 Supervisor { procs: self.procs }
133 }
134}
135
136fn start_futures(procs: Vec<Box<dyn ManagedProc>>) -> Vec<CancelableLocalFuture> {
137 procs
138 .into_iter()
139 .map(|proc| {
140 let cancel_token = CancellationToken::new();
141 let child_token = cancel_token.child_token();
142 CancelableLocalFuture {
143 cancel_token,
144 future: proc.start_proc(child_token),
145 }
146 })
147 .collect()
148}
149
150async fn stop_all(procs: Vec<CancelableLocalFuture>) -> Result<()> {
151 futures::stream::iter(procs.into_iter().rev())
152 .then(|proc| async move {
153 proc.cancel_token.cancel();
154 proc.future.await
155 })
156 .collect::<Vec<_>>()
157 .await
158 .into_iter()
159 .collect()
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use futures::TryFutureExt;
    use tokio::sync::mpsc;

    /// A proc that waits `delay` ms (or until cancelled), reports `name` on `sender`, and then
    /// returns `result`.
    struct TestProc {
        name: &'static str,
        delay: u64,
        result: Result<()>,
        sender: mpsc::Sender<&'static str>,
    }

    impl ManagedProc for TestProc {
        fn start_proc(
            self: Box<Self>,
            shutdown: CancellationToken,
        ) -> LocalBoxFuture<'static, Result<()>> {
            let handle = tokio::spawn(async move {
                tokio::select! {
                    _ = shutdown.cancelled() => (),
                    _ = tokio::time::sleep(std::time::Duration::from_millis(self.delay)) => (),
                }
                self.sender.send(self.name).await.expect("unable to send");
                self.result
            });

            Box::pin(
                handle
                    .map_err(|err| err.into())
                    .and_then(|result| async move { result }),
            )
        }
    }

    /// Shorthand for a [`TestProc`] that reports on a clone of `sender`.
    fn test_proc(
        name: &'static str,
        delay: u64,
        result: Result<()>,
        sender: &mpsc::Sender<&'static str>,
    ) -> TestProc {
        TestProc {
            name,
            delay,
            result,
            sender: sender.clone(),
        }
    }

    /// Drops the test's own sender, then drains every name the procs reported, in order.
    ///
    /// Each proc's sender is moved into its spawned task. Once the supervisor has awaited all
    /// procs, the channel closes, so a proc that failed to report makes the assertion fail
    /// instead of hanging the test.
    async fn reported(
        sender: mpsc::Sender<&'static str>,
        mut receiver: mpsc::Receiver<&'static str>,
    ) -> Vec<&'static str> {
        drop(sender);
        let mut names = Vec::new();
        while let Some(name) = receiver.recv().await {
            names.push(name);
        }
        names
    }

    #[tokio::test]
    async fn stop_when_all_tasks_have_completed() {
        let (sender, receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(test_proc("1", 50, Ok(()), &sender))
            .add_proc(test_proc("2", 100, Ok(()), &sender))
            .build()
            .start()
            .await;

        assert_eq!(vec!["1", "2"], reported(sender, receiver).await);
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn will_stop_all_in_reverse_order_after_error() {
        let (sender, receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(test_proc("1", 1000, Ok(()), &sender))
            .add_proc(test_proc("2", 50, Err(anyhow!("error")), &sender))
            .add_proc(test_proc("3", 1000, Ok(()), &sender))
            .build()
            .start()
            .await;

        assert_eq!(vec!["2", "3", "1"], reported(sender, receiver).await);
        assert_eq!("error", result.unwrap_err().to_string());
    }

    #[tokio::test]
    async fn will_return_first_error_returned() {
        let (sender, receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(test_proc("1", 1000, Ok(()), &sender))
            .add_proc(test_proc("2", 50, Err(anyhow!("error")), &sender))
            .add_proc(test_proc("3", 200, Err(anyhow!("second error")), &sender))
            .build()
            .start()
            .await;

        assert_eq!(vec!["2", "3", "1"], reported(sender, receiver).await);
        assert_eq!("error", result.unwrap_err().to_string());
    }

    #[tokio::test]
    async fn nested_procs_will_stop_parent_then_move_up() {
        let (sender, receiver) = mpsc::channel(10);

        let result = Supervisor::builder()
            .add_proc(test_proc("proc-1", 500, Ok(()), &sender))
            .add_proc(
                Supervisor::builder()
                    .add_proc(test_proc("proc-2-1", 500, Ok(()), &sender))
                    .add_proc(test_proc("proc-2-2", 100, Err(anyhow!("error")), &sender))
                    .add_proc(test_proc("proc-2-3", 500, Ok(()), &sender))
                    .add_proc(test_proc("proc-2-4", 500, Ok(()), &sender))
                    .build(),
            )
            .add_proc(
                Supervisor::builder()
                    .add_proc(test_proc("proc-3-1", 1000, Ok(()), &sender))
                    .add_proc(test_proc("proc-3-2", 1000, Ok(()), &sender))
                    .add_proc(test_proc("proc-3-3", 1000, Ok(()), &sender))
                    .build(),
            )
            .build()
            .start()
            .await;

        assert_eq!(
            vec![
                "proc-2-2", "proc-2-4", "proc-2-3", "proc-2-1", "proc-3-3", "proc-3-2", "proc-3-1",
                "proc-1",
            ],
            reported(sender, receiver).await
        );
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn cancelling_a_supervisor_proc_stops_children_in_reverse_order() {
        let (sender, receiver) = mpsc::channel(5);
        let shutdown = CancellationToken::new();

        let supervisor = Supervisor::builder()
            .add_proc(test_proc("1", 1000, Ok(()), &sender))
            .add_proc(test_proc("2", 1000, Ok(()), &sender))
            .build();
        let running = Box::new(supervisor).start_proc(shutdown.clone());
        shutdown.cancel();
        let result = running.await;

        assert_eq!(vec!["2", "1"], reported(sender, receiver).await);
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn closures_can_be_supervised() {
        let result = Supervisor::builder()
            .add_proc(|_shutdown: CancellationToken| async { Ok::<(), anyhow::Error>(()) })
            .build()
            .start()
            .await;

        assert!(result.is_ok());
    }
}