1mod ordered_select_all;
2
3use crate::ordered_select_all::ordered_select_all;
4use anyhow::Result;
5use futures::{future::LocalBoxFuture, Future, FutureExt, StreamExt};
6use std::{
7 pin::{pin, Pin},
8 task::{Context, Poll},
9};
10use tokio::signal;
11use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned};
12
13pub struct ShutdownSignal {
14 future: Pin<Box<WaitForCancellationFutureOwned>>,
15}
16
17impl ShutdownSignal {
18 pub fn new(token: CancellationToken) -> Self {
19 Self {
20 future: Box::pin(token.cancelled_owned()),
21 }
22 }
23}
24
25impl Future for ShutdownSignal {
26 type Output = ();
27
28 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
29 self.future.as_mut().poll(cx)
30 }
31}
32
/// A unit of work that a [`Supervisor`] starts, watches, and can ask to stop.
///
/// Implemented directly by `Supervisor` (allowing nesting) and, via a blanket
/// impl, by any `FnOnce(ShutdownSignal) -> impl Future<Output = Result<()>>`.
pub trait ManagedProc {
    /// Consumes the process and returns the future that runs it.
    ///
    /// `shutdown` resolves when the supervisor cancels this process; the
    /// supervisor then keeps awaiting the returned future until it completes.
    /// Returning `Err` makes the supervisor stop its remaining processes.
    fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> LocalBoxFuture<'static, Result<()>>;
}
36
37pub struct Supervisor {
38 procs: Vec<Box<dyn ManagedProc>>,
39}
40
41impl ManagedProc for Supervisor {
42 fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> LocalBoxFuture<'static, Result<()>> {
43 Box::pin(self.do_run(Box::pin(shutdown)))
44 }
45}
46
/// Fluent builder for a [`Supervisor`]; obtain one with [`Supervisor::builder`].
pub struct SupervisorBuilder {
    // Processes in the order they were added via `add_proc`.
    procs: Vec<Box<dyn ManagedProc>>,
}
50
51struct CancelableLocalFuture {
52 cancel_token: CancellationToken,
53 future: LocalBoxFuture<'static, Result<()>>,
54}
55
56impl Future for CancelableLocalFuture {
57 type Output = Result<()>;
58
59 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
60 pin!(&mut self.future).poll(ctx)
61 }
62}
63
64impl<O, P> ManagedProc for P
65where
66 O: Future<Output = Result<()>> + 'static,
67 P: FnOnce(ShutdownSignal) -> O,
68{
69 fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> LocalBoxFuture<'static, Result<()>> {
70 Box::pin(self(shutdown))
71 }
72}
73
74impl Default for Supervisor {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
80impl Supervisor {
81 pub fn new() -> Self {
82 Self { procs: Vec::new() }
83 }
84
85 pub fn builder() -> SupervisorBuilder {
86 SupervisorBuilder { procs: Vec::new() }
87 }
88
89 pub fn add(&mut self, proc: impl ManagedProc + 'static) {
90 self.procs.push(Box::new(proc));
91 }
92
93 pub async fn start(self) -> Result<()> {
94 let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())?;
95 let shutdown = Box::pin(
96 futures::future::select(
97 Box::pin(async move { sigterm.recv().await }),
98 Box::pin(signal::ctrl_c()),
99 )
100 .map(|_| ()),
101 );
102 self.do_run(shutdown).await
103 }
104
105 async fn do_run(self, mut shutdown: LocalBoxFuture<'static, ()>) -> Result<()> {
106 let mut futures = start_futures(self.procs);
107
108 loop {
109 if futures.is_empty() {
110 break;
111 }
112
113 let mut select = ordered_select_all(futures);
114
115 tokio::select! {
116 biased;
117 _ = &mut shutdown => return stop_all(select.into_inner()).await,
118 (result, _index, remaining) = &mut select => match result {
119 Ok(_) => futures = remaining,
120 Err(err) => {
121 let _ = stop_all(remaining).await;
122 return Err(err);
123 }
124 }
125 }
126 }
127
128 Ok(())
129 }
130}
131
132impl SupervisorBuilder {
133 pub fn add_proc(mut self, proc: impl ManagedProc + 'static) -> Self {
134 self.procs.push(Box::new(proc));
135 self
136 }
137
138 pub fn build(self) -> Supervisor {
139 Supervisor { procs: self.procs }
140 }
141}
142
143fn start_futures(procs: Vec<Box<dyn ManagedProc>>) -> Vec<CancelableLocalFuture> {
144 procs
145 .into_iter()
146 .map(|proc| {
147 let cancel_token = CancellationToken::new();
148 let child_token = cancel_token.child_token();
149 CancelableLocalFuture {
150 cancel_token,
151 future: proc.run_proc(ShutdownSignal::new(child_token)),
152 }
153 })
154 .collect()
155}
156
157async fn stop_all(procs: Vec<CancelableLocalFuture>) -> Result<()> {
158 futures::stream::iter(procs.into_iter().rev())
159 .then(|proc| async move {
160 proc.cancel_token.cancel();
161 proc.future.await
162 })
163 .collect::<Vec<_>>()
164 .await
165 .into_iter()
166 .collect()
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use futures::TryFutureExt;
    use tokio::sync::mpsc;

    /// Test process: waits `delay` ms or until shutdown (whichever comes
    /// first), reports `name` on `sender`, then yields `result`.
    struct TestProc {
        name: &'static str,
        delay: u64,
        result: Result<()>,
        sender: mpsc::Sender<&'static str>,
    }

    /// Shorthand constructor so each test reads as a list of processes.
    fn test_proc(
        name: &'static str,
        delay: u64,
        result: Result<()>,
        sender: &mpsc::Sender<&'static str>,
    ) -> TestProc {
        TestProc {
            name,
            delay,
            result,
            sender: sender.clone(),
        }
    }

    impl ManagedProc for TestProc {
        fn run_proc(
            self: Box<Self>,
            shutdown: ShutdownSignal,
        ) -> LocalBoxFuture<'static, Result<()>> {
            let handle = tokio::spawn(async move {
                tokio::select! {
                    _ = shutdown => (),
                    _ = tokio::time::sleep(std::time::Duration::from_millis(self.delay)) => (),
                }
                self.sender.send(self.name).await.expect("unable to send");
                self.result
            });

            Box::pin(
                handle
                    .map_err(|err| err.into())
                    .and_then(|result| async move { result }),
            )
        }
    }

    #[tokio::test]
    async fn stop_when_all_tasks_have_completed() {
        let (sender, mut receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(test_proc("1", 50, Ok(()), &sender))
            .add_proc(test_proc("2", 100, Ok(()), &sender))
            .build()
            .start()
            .await;

        assert_eq!(Some("1"), receiver.recv().await);
        assert_eq!(Some("2"), receiver.recv().await);
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn will_stop_all_in_reverse_order_after_error() {
        let (sender, mut receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(test_proc("1", 1000, Ok(()), &sender))
            .add_proc(test_proc("2", 50, Err(anyhow!("error")), &sender))
            .add_proc(test_proc("3", 1000, Ok(()), &sender))
            .build()
            .start()
            .await;

        assert_eq!(Some("2"), receiver.recv().await);
        assert_eq!(Some("3"), receiver.recv().await);
        assert_eq!(Some("1"), receiver.recv().await);
        assert_eq!("error", result.unwrap_err().to_string());
    }

    #[tokio::test]
    async fn will_return_first_error_returned() {
        let (sender, mut receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(test_proc("1", 1000, Ok(()), &sender))
            .add_proc(test_proc("2", 50, Err(anyhow!("error")), &sender))
            .add_proc(test_proc("3", 200, Err(anyhow!("second error")), &sender))
            .build()
            .start()
            .await;

        assert_eq!(Some("2"), receiver.recv().await);
        assert_eq!(Some("3"), receiver.recv().await);
        assert_eq!(Some("1"), receiver.recv().await);
        assert_eq!("error", result.unwrap_err().to_string());
    }

    #[tokio::test]
    async fn nested_procs_will_stop_parent_then_move_up() {
        let (sender, mut receiver) = mpsc::channel(10);

        let result = Supervisor::builder()
            .add_proc(test_proc("proc-1", 500, Ok(()), &sender))
            .add_proc(
                Supervisor::builder()
                    .add_proc(test_proc("proc-2-1", 500, Ok(()), &sender))
                    .add_proc(test_proc("proc-2-2", 100, Err(anyhow!("error")), &sender))
                    .add_proc(test_proc("proc-2-3", 500, Ok(()), &sender))
                    .add_proc(test_proc("proc-2-4", 500, Ok(()), &sender))
                    .build(),
            )
            .add_proc(
                Supervisor::builder()
                    .add_proc(test_proc("proc-3-1", 1000, Ok(()), &sender))
                    .add_proc(test_proc("proc-3-2", 1000, Ok(()), &sender))
                    .add_proc(test_proc("proc-3-3", 1000, Ok(()), &sender))
                    .build(),
            )
            .build()
            .start()
            .await;

        // The failing nested supervisor stops its own children first (in
        // reverse), then the parent stops its remaining children in reverse.
        assert_eq!(Some("proc-2-2"), receiver.recv().await);
        assert_eq!(Some("proc-2-4"), receiver.recv().await);
        assert_eq!(Some("proc-2-3"), receiver.recv().await);
        assert_eq!(Some("proc-2-1"), receiver.recv().await);
        assert_eq!(Some("proc-3-3"), receiver.recv().await);
        assert_eq!(Some("proc-3-2"), receiver.recv().await);
        assert_eq!(Some("proc-3-1"), receiver.recv().await);
        assert_eq!(Some("proc-1"), receiver.recv().await);
        assert!(result.is_err());
    }
}