1mod ordered_select_all;
2
3use crate::ordered_select_all::ordered_select_all;
4use anyhow::Result;
5use futures::{future::LocalBoxFuture, Future, FutureExt, StreamExt};
6use std::{
7 pin::{pin, Pin},
8 task::{Context, Poll},
9};
10use tokio::signal;
11use tokio_util::sync::{CancellationToken, WaitForCancellationFutureOwned};
12
13pub struct ShutdownSignal {
14 token: CancellationToken,
15 future: Option<Pin<Box<WaitForCancellationFutureOwned>>>,
16}
17
18impl ShutdownSignal {
19 pub fn new(token: CancellationToken) -> Self {
20 Self {
21 token,
22 future: None,
23 }
24 }
25
26 pub fn is_cancelled(&self) -> bool {
27 self.token.is_cancelled()
28 }
29
30 pub fn token(&self) -> &CancellationToken {
31 &self.token
32 }
33}
34
35impl Future for ShutdownSignal {
36 type Output = ();
37
38 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
39 if self.future.is_none() {
41 self.future = Some(Box::pin(self.token.clone().cancelled_owned()));
42 }
43
44 self.future.as_mut().unwrap().as_mut().poll(cx)
46 }
47}
48
49impl Clone for ShutdownSignal {
50 fn clone(&self) -> Self {
51 Self {
52 token: self.token.clone(),
53 future: None,
54 }
55 }
56}
57
58impl Unpin for ShutdownSignal {}
59
/// Something a [`Supervisor`] can start and later ask to stop.
///
/// `run_proc` consumes the boxed proc and returns the future that drives it. The
/// `shutdown` signal completes when the supervisor wants the proc to stop. After
/// cancelling, the supervisor awaits the returned future, so shutdown waits until it
/// finishes. If the future resolves to `Err`, the supervisor stops the remaining procs
/// and returns that error.
///
/// Implemented for `Supervisor` itself (allowing nesting) and for any
/// `FnOnce(ShutdownSignal) -> impl Future<Output = Result<()>>`.
pub trait ManagedProc {
    fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> LocalBoxFuture<'static, Result<()>>;
}
63
/// Runs a set of [`ManagedProc`]s together and stops them as a group.
///
/// Procs are started in the order they were added. When one proc fails, or the shutdown
/// trigger fires, the remaining procs are cancelled and awaited one at a time in reverse
/// order.
pub struct Supervisor {
    // Procs to start, in insertion order.
    procs: Vec<Box<dyn ManagedProc>>,
}
67
68impl ManagedProc for Supervisor {
69 fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> LocalBoxFuture<'static, Result<()>> {
70 Box::pin(self.do_run(Box::pin(shutdown)))
71 }
72}
73
/// Fluent builder for a [`Supervisor`]; obtain one with `Supervisor::builder()`.
pub struct SupervisorBuilder {
    // Procs collected so far, in the order `add_proc` was called.
    procs: Vec<Box<dyn ManagedProc>>,
}
77
78struct CancelableLocalFuture {
79 cancel_token: CancellationToken,
80 future: LocalBoxFuture<'static, Result<()>>,
81}
82
83impl Future for CancelableLocalFuture {
84 type Output = Result<()>;
85
86 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
87 pin!(&mut self.future).poll(ctx)
88 }
89}
90
91impl<O, P> ManagedProc for P
92where
93 O: Future<Output = Result<()>> + 'static,
94 P: FnOnce(ShutdownSignal) -> O,
95{
96 fn run_proc(self: Box<Self>, shutdown: ShutdownSignal) -> LocalBoxFuture<'static, Result<()>> {
97 Box::pin(self(shutdown))
98 }
99}
100
101impl Default for Supervisor {
102 fn default() -> Self {
103 Self::new()
104 }
105}
106
107impl Supervisor {
108 pub fn new() -> Self {
109 Self { procs: Vec::new() }
110 }
111
112 pub fn builder() -> SupervisorBuilder {
113 SupervisorBuilder { procs: Vec::new() }
114 }
115
116 pub fn add(&mut self, proc: impl ManagedProc + 'static) {
117 self.procs.push(Box::new(proc));
118 }
119
120 pub async fn start(self) -> Result<()> {
121 let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())?;
122 let shutdown = Box::pin(
123 futures::future::select(
124 Box::pin(async move { sigterm.recv().await }),
125 Box::pin(signal::ctrl_c()),
126 )
127 .map(|_| ()),
128 );
129 self.do_run(shutdown).await
130 }
131
132 async fn do_run(self, mut shutdown: LocalBoxFuture<'static, ()>) -> Result<()> {
133 let mut futures = start_futures(self.procs);
134
135 loop {
136 if futures.is_empty() {
137 break;
138 }
139
140 let mut select = ordered_select_all(futures);
141
142 tokio::select! {
143 biased;
144 _ = &mut shutdown => return stop_all(select.into_inner()).await,
145 (result, _index, remaining) = &mut select => match result {
146 Ok(_) => futures = remaining,
147 Err(err) => {
148 let _ = stop_all(remaining).await;
149 return Err(err);
150 }
151 }
152 }
153 }
154
155 Ok(())
156 }
157}
158
159impl SupervisorBuilder {
160 pub fn add_proc(mut self, proc: impl ManagedProc + 'static) -> Self {
161 self.procs.push(Box::new(proc));
162 self
163 }
164
165 pub fn build(self) -> Supervisor {
166 Supervisor { procs: self.procs }
167 }
168}
169
170fn start_futures(procs: Vec<Box<dyn ManagedProc>>) -> Vec<CancelableLocalFuture> {
171 procs
172 .into_iter()
173 .map(|proc| {
174 let cancel_token = CancellationToken::new();
175 let child_token = cancel_token.child_token();
176 CancelableLocalFuture {
177 cancel_token,
178 future: proc.run_proc(ShutdownSignal::new(child_token)),
179 }
180 })
181 .collect()
182}
183
184async fn stop_all(procs: Vec<CancelableLocalFuture>) -> Result<()> {
185 futures::stream::iter(procs.into_iter().rev())
186 .then(|proc| async move {
187 proc.cancel_token.cancel();
188 proc.future.await
189 })
190 .collect::<Vec<_>>()
191 .await
192 .into_iter()
193 .collect()
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use futures::TryFutureExt;
    use tokio::sync::mpsc;

    /// Fake proc: runs until shutdown or until `delay` ms elapse, then reports `name`
    /// on `sender` and finishes with `result`.
    struct TestProc {
        name: &'static str,
        delay: u64,
        result: Result<()>,
        sender: mpsc::Sender<&'static str>,
    }

    /// Builds a [`TestProc`] that reports on its own clone of `sender`.
    fn test_proc(
        name: &'static str,
        delay: u64,
        result: Result<()>,
        sender: &mpsc::Sender<&'static str>,
    ) -> TestProc {
        TestProc {
            name,
            delay,
            result,
            sender: sender.clone(),
        }
    }

    impl ManagedProc for TestProc {
        fn run_proc(
            self: Box<Self>,
            shutdown: ShutdownSignal,
        ) -> LocalBoxFuture<'static, Result<()>> {
            let TestProc {
                name,
                delay,
                result,
                sender,
            } = *self;

            let handle = tokio::spawn(async move {
                tokio::select! {
                    _ = shutdown => (),
                    _ = tokio::time::sleep(std::time::Duration::from_millis(delay)) => (),
                }
                sender.send(name).await.expect("unable to send");
                result
            });

            // Flatten the join result and the proc's own result into one `Result<()>`.
            Box::pin(
                handle
                    .map_err(|err| err.into())
                    .and_then(|result| async move { result }),
            )
        }
    }

    #[tokio::test]
    async fn stop_when_all_tasks_have_completed() {
        let (sender, mut receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(test_proc("1", 50, Ok(()), &sender))
            .add_proc(test_proc("2", 100, Ok(()), &sender))
            .build()
            .start()
            .await;

        // Both procs run to their natural end, shortest delay first.
        assert_eq!(Some("1"), receiver.recv().await);
        assert_eq!(Some("2"), receiver.recv().await);
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn will_stop_all_in_reverse_order_after_error() {
        let (sender, mut receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(test_proc("1", 1000, Ok(()), &sender))
            .add_proc(test_proc("2", 50, Err(anyhow!("error")), &sender))
            .add_proc(test_proc("3", 1000, Ok(()), &sender))
            .build()
            .start()
            .await;

        // "2" fails first; the survivors are then stopped last-to-first.
        assert_eq!(Some("2"), receiver.recv().await);
        assert_eq!(Some("3"), receiver.recv().await);
        assert_eq!(Some("1"), receiver.recv().await);
        assert_eq!("error", result.unwrap_err().to_string());
    }

    #[tokio::test]
    async fn will_return_first_error_returned() {
        let (sender, mut receiver) = mpsc::channel(5);

        let result = Supervisor::builder()
            .add_proc(test_proc("1", 1000, Ok(()), &sender))
            .add_proc(test_proc("2", 50, Err(anyhow!("error")), &sender))
            .add_proc(test_proc("3", 200, Err(anyhow!("second error")), &sender))
            .build()
            .start()
            .await;

        // "3" also returns an error while being stopped, but only the first one is reported.
        assert_eq!(Some("2"), receiver.recv().await);
        assert_eq!(Some("3"), receiver.recv().await);
        assert_eq!(Some("1"), receiver.recv().await);
        assert_eq!("error", result.unwrap_err().to_string());
    }

    #[tokio::test]
    async fn nested_procs_will_stop_parent_then_move_up() {
        let (sender, mut receiver) = mpsc::channel(10);

        let failing_group = Supervisor::builder()
            .add_proc(test_proc("proc-2-1", 500, Ok(()), &sender))
            .add_proc(test_proc("proc-2-2", 100, Err(anyhow!("error")), &sender))
            .add_proc(test_proc("proc-2-3", 500, Ok(()), &sender))
            .add_proc(test_proc("proc-2-4", 500, Ok(()), &sender))
            .build();

        let idle_group = Supervisor::builder()
            .add_proc(test_proc("proc-3-1", 1000, Ok(()), &sender))
            .add_proc(test_proc("proc-3-2", 1000, Ok(()), &sender))
            .add_proc(test_proc("proc-3-3", 1000, Ok(()), &sender))
            .build();

        let result = Supervisor::builder()
            .add_proc(test_proc("proc-1", 500, Ok(()), &sender))
            .add_proc(failing_group)
            .add_proc(idle_group)
            .build()
            .start()
            .await;

        // The failing child group unwinds itself first (reverse order), then the parent
        // stops its remaining children last-to-first: the idle group, then proc-1.
        assert_eq!(Some("proc-2-2"), receiver.recv().await);
        assert_eq!(Some("proc-2-4"), receiver.recv().await);
        assert_eq!(Some("proc-2-3"), receiver.recv().await);
        assert_eq!(Some("proc-2-1"), receiver.recv().await);
        assert_eq!(Some("proc-3-3"), receiver.recv().await);
        assert_eq!(Some("proc-3-2"), receiver.recv().await);
        assert_eq!(Some("proc-3-1"), receiver.recv().await);
        assert_eq!(Some("proc-1"), receiver.recv().await);
        assert!(result.is_err());
    }
}