task_group/
lib.rs

1use std::any::Any;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use tokio::sync::mpsc;
6use tokio::task::JoinHandle;
7
/// A TaskGroup is used to spawn a collection of tasks. The collection has two properties:
/// * if any task returns an error or panics, all tasks are terminated.
/// * if the `TaskManager` returned by `TaskGroup::new` is dropped, all tasks are terminated.
pub struct TaskGroup<E> {
    // Sending half of the channel over which handles of newly spawned tasks are passed to the
    // `TaskManager` created alongside this group.
    new_task: mpsc::Sender<ChildHandle<E>>,
}
14// not the derived impl: E does not need to be Clone
15impl<E> Clone for TaskGroup<E> {
16    fn clone(&self) -> Self {
17        Self {
18            new_task: self.new_task.clone(),
19        }
20    }
21}
22
23impl<E: Send + 'static> TaskGroup<E> {
24    pub fn new() -> (Self, TaskManager<E>) {
25        let (new_task, reciever) = mpsc::channel(64);
26        let group = TaskGroup { new_task };
27        let manager = TaskManager::new(reciever);
28        (group, manager)
29    }
30
31    pub fn spawn(
32        &self,
33        name: impl AsRef<str>,
34        f: impl Future<Output = Result<(), E>> + Send + 'static,
35    ) -> impl Future<Output = Result<(), SpawnError>> + '_ {
36        let name = name.as_ref().to_string();
37        let join = tokio::task::spawn(f);
38        async move {
39            match self.new_task.send(ChildHandle { name, join }).await {
40                Ok(()) => Ok(()),
41                // If there is no receiver alive to manage the new task, drop the child in error to
42                // cancel it:
43                Err(_child) => Err(SpawnError::GroupDied),
44            }
45        }
46    }
47
48    pub fn spawn_on(
49        &self,
50        name: impl AsRef<str>,
51        runtime: tokio::runtime::Handle,
52        f: impl Future<Output = Result<(), E>> + Send + 'static,
53    ) -> impl Future<Output = Result<(), SpawnError>> + '_ {
54        let name = name.as_ref().to_string();
55        let join = runtime.spawn(f);
56        async move {
57            match self.new_task.send(ChildHandle { name, join }).await {
58                Ok(()) => Ok(()),
59                // If there is no receiver alive to manage the new task, drop the child in error to
60                // cancel it:
61                Err(_child) => Err(SpawnError::GroupDied),
62            }
63        }
64    }
65
66    pub fn spawn_local(
67        &self,
68        name: impl AsRef<str>,
69        f: impl Future<Output = Result<(), E>> + 'static,
70    ) -> impl Future<Output = Result<(), SpawnError>> + '_ {
71        let name = name.as_ref().to_string();
72        let join = tokio::task::spawn_local(f);
73        async move {
74            match self.new_task.send(ChildHandle { name, join }).await {
75                Ok(()) => Ok(()),
76                // If there is no receiver alive to manage the new task, drop the child in error to
77                // cancel it:
78                Err(_child) => Err(SpawnError::GroupDied),
79            }
80        }
81    }
82
83    /// Returns `true` if the task group has been shut down.
84    pub fn is_closed(&self) -> bool {
85        self.new_task.is_closed()
86    }
87}
88
// Handle to a single spawned task, as passed from a `TaskGroup` to its `TaskManager`.
struct ChildHandle<E> {
    // Name given to the task when it was spawned; reported in `RuntimeError`.
    name: String,
    // Join handle of the spawned task; polled by the manager and aborted when this handle is
    // dropped.
    join: JoinHandle<Result<(), E>>,
}
93
94impl<E> ChildHandle<E> {
95    // Pin projection. Since there is only this one required, avoid pulling in the proc macro.
96    pub fn pin_join(self: Pin<&mut Self>) -> Pin<&mut JoinHandle<Result<(), E>>> {
97        unsafe { self.map_unchecked_mut(|s| &mut s.join) }
98    }
99    fn cancel(&mut self) {
100        self.join.abort();
101    }
102}
103
104// As a consequence of this Drop impl, when a TaskManager is dropped, all of its children will be
105// canceled.
106impl<E> Drop for ChildHandle<E> {
107    fn drop(&mut self) {
108        self.cancel()
109    }
110}
111
/// A TaskManager is used to manage a collection of tasks. There are two
/// things you can do with it:
/// * TaskManager impls Future, so you can poll or await on it. It will be
/// Ready when all tasks return Ok(()) and the associated `TaskGroup` is
/// dropped (so no more tasks can be created), or when any task panics or
/// returns an Err(E).
/// * When a TaskManager is dropped, all tasks it contains are canceled
/// (terminated). So, if you use a combinator like
/// `tokio::time::timeout(duration, task_manager).await`, all tasks will be
/// terminated if the timeout occurs.
pub struct TaskManager<E> {
    // Receiving half of the channel on which new children arrive. Set to `None` once the
    // channel has been closed and drained, or after a child has failed.
    channel: Option<mpsc::Receiver<ChildHandle<E>>>,
    // Handles of the children that have been received and have not yet finished.
    children: Vec<Pin<Box<ChildHandle<E>>>>,
}
126
127impl<E> TaskManager<E> {
128    fn new(channel: mpsc::Receiver<ChildHandle<E>>) -> Self {
129        Self {
130            channel: Some(channel),
131            children: Vec::new(),
132        }
133    }
134}
135
136impl<E> Future for TaskManager<E> {
137    type Output = Result<(), RuntimeError<E>>;
138    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
139        let mut s = self.as_mut();
140
141        // If the channel is still open, take it out of s to satisfy the borrow checker.
142        // We'll put it right back once we're done polling it.
143        if let Some(mut channel) = s.channel.take() {
144            // This loop processes each message in the channel until it is either empty
145            // or closed.
146            s.channel = loop {
147                match channel.poll_recv(ctx) {
148                    Poll::Pending => {
149                        // No more messages, but channel still open
150                        break Some(channel);
151                    }
152                    Poll::Ready(Some(new_child)) => {
153                        // Put element from channel into the children
154                        s.children.push(Box::pin(new_child));
155                    }
156                    Poll::Ready(None) => {
157                        // Channel has closed and all messages have been recieved. No
158                        // longer need channel.
159                        break None;
160                    }
161                }
162            };
163        }
164
165        // Need to mutate s after discovering error: store here temporarily
166        let mut err = None;
167        // Need to iterate through vec, possibly removing via swap_remove, so we cant use
168        // a normal iterator:
169        let mut child_ix = 0;
170        while s.children.get(child_ix).is_some() {
171            let child = s
172                .children
173                .get_mut(child_ix)
174                .expect("precondition: child exists at index");
175            match child.as_mut().pin_join().poll(ctx) {
176                // Pending children get retained - move to next
177                Poll::Pending => child_ix += 1,
178                // Child returns successfully: remove it from children.
179                // Then execute the loop body again with ix unchanged, because
180                // last element was swapped into child_ix.
181                Poll::Ready(Ok(Ok(()))) => {
182                    let _ = s.children.swap_remove(child_ix);
183                }
184                // Child returns with error: yield the error
185                Poll::Ready(Ok(Err(error))) => {
186                    err = Some(RuntimeError::Application {
187                        name: child.name.clone(),
188                        error,
189                    });
190                    break;
191                }
192                // Child join error: it either panicked or was canceled
193                Poll::Ready(Err(e)) => {
194                    err = Some(match e.try_into_panic() {
195                        Ok(panic) => RuntimeError::Panic {
196                            name: child.name.clone(),
197                            panic,
198                        },
199                        Err(_) => unreachable!("impossible to cancel tasks in TaskGroup"),
200                    });
201                    break;
202                }
203            }
204        }
205
206        if let Some(err) = err {
207            // Drop all children, and the channel reciever, current tasks are destroyed
208            // and new tasks cannot be created:
209            s.children.truncate(0);
210            s.channel.take();
211            // Return the error:
212            Poll::Ready(Err(err))
213        } else if s.children.is_empty() {
214            if s.channel.is_none() {
215                // Task manager is complete when there are no more children, and
216                // no more channel to get more children:
217                Poll::Ready(Ok(()))
218            } else {
219                // Channel is still pending, so we are not done:
220                Poll::Pending
221            }
222        } else {
223            Poll::Pending
224        }
225    }
226}
227
/// The reason a [`TaskManager`] finished unsuccessfully.
#[derive(Debug)]
pub enum RuntimeError<E> {
    /// The task called `name` panicked; `panic` is the payload of that panic.
    Panic {
        name: String,
        panic: Box<dyn Any + Send + 'static>,
    },
    /// The task called `name` returned the error `error`.
    Application {
        name: String,
        error: E,
    },
}

impl<E> std::error::Error for RuntimeError<E> where E: std::fmt::Display + std::error::Error {}

impl<E> std::fmt::Display for RuntimeError<E>
where
    E: std::fmt::Display,
{
    /// Writes a one-line description naming the task; for application errors the error's own
    /// `Display` output is appended.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            Self::Panic { name, panic: _ } => {
                f.write_fmt(format_args!("Task `{}` panicked", name))
            }
            Self::Application { name, error } => {
                f.write_fmt(format_args!("Task `{}` errored: {}", name, error))
            }
        }
    }
}
252
/// The reason a task could not be added to a [`TaskGroup`].
#[derive(Debug)]
pub enum SpawnError {
    /// There was no manager left to hand the task to.
    GroupDied,
}

impl std::error::Error for SpawnError {}

impl std::fmt::Display for SpawnError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let message = match self {
            Self::GroupDied => "Task group died",
        };
        f.write_str(message)
    }
}
265
266#[cfg(test)]
267mod test {
268    use super::*;
269    use anyhow::{anyhow, Error};
270    use std::sync::Arc;
271    use tokio::sync::Mutex;
272    use tokio::time::{sleep, Duration};
273
274    #[tokio::test]
275    async fn no_task() {
276        let (tg, tm): (TaskGroup<Error>, TaskManager<Error>) = TaskGroup::new();
277        drop(tg); // Must drop the ability to spawn for the taskmanager to be finished
278        assert!(tm.await.is_ok());
279    }
280
281    #[tokio::test]
282    async fn one_empty_task() {
283        let (tg, tm): (TaskGroup<Error>, TaskManager<Error>) = TaskGroup::new();
284        tg.spawn("empty", async move { Ok(()) }).await.unwrap();
285        drop(tg); // Must drop the ability to spawn for the taskmanager to be finished
286        assert!(tm.await.is_ok());
287    }
288
289    #[tokio::test]
290    async fn empty_child() {
291        let (tg, tm): (TaskGroup<Error>, TaskManager<Error>) = TaskGroup::new();
292        tg.clone()
293            .spawn("parent", async move {
294                tg.spawn("child", async move { Ok(()) }).await.unwrap();
295                Ok(())
296            })
297            .await
298            .unwrap();
299        assert!(tm.await.is_ok());
300    }
301
    // Four generations of tasks, each spawning the next one into the same group. Every task
    // appends to a shared log, and the manager must only finish after all of them have run.
    #[tokio::test]
    async fn many_nested_children() {
        // Record a side-effect to demonstrate that all of these children executed
        let log = Arc::new(Mutex::new(vec![0usize]));
        let l = log.clone();
        let (tg, tm): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
        // Spawn through a temporary clone: `tg` itself is moved into the root task.
        tg.clone()
            .spawn("root", async move {
                let log = log.clone();
                let tg2 = tg.clone();
                log.lock().await.push(1);
                tg.spawn("child", async move {
                    let tg3 = tg2.clone();
                    log.lock().await.push(2);
                    tg2.spawn("grandchild", async move {
                        log.lock().await.push(3);
                        tg3.spawn("great grandchild", async move {
                            log.lock().await.push(4);
                            Ok(())
                        })
                        .await
                        .unwrap();
                        Ok(())
                    })
                    .await
                    .unwrap();
                    Ok(())
                })
                .await
                .unwrap();
                Ok(())
            })
            .await
            .unwrap();
        assert!(tm.await.is_ok());
        // Each generation pushes its number before spawning the next one.
        assert_eq!(*l.lock().await, vec![0usize, 1, 2, 3, 4]);
    }
    // Four generations of tasks where the youngest one fails: the manager must report that
    // error, and the still-sleeping grandchild must be terminated instead of waking up.
    #[tokio::test]
    async fn many_nested_children_error() {
        // Record a side-effect to demonstrate that all of these children executed
        let log: Arc<Mutex<Vec<&'static str>>> = Arc::new(Mutex::new(vec![]));
        let l = log.clone();

        let (tg, tm): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
        let tg2 = tg.clone();
        tg.spawn("root", async move {
            log.lock().await.push("in root");
            let tg3 = tg2.clone();
            tg2.spawn("child", async move {
                log.lock().await.push("in child");
                let tg4 = tg3.clone();
                tg3.spawn("grandchild", async move {
                    log.lock().await.push("in grandchild");
                    tg4.spawn("great grandchild", async move {
                        log.lock().await.push("in great grandchild");
                        Err(anyhow!("sooner or later you get a failson"))
                    })
                    .await
                    .unwrap();
                    sleep(Duration::from_secs(1)).await;
                    // The great-grandchild returning error should terminate this task.
                    unreachable!("sleepy grandchild should never wake");
                })
                .await
                .unwrap();
                Ok(())
            })
            .await
            .unwrap();
            Ok(())
        })
        .await
        .unwrap();
        drop(tg);
        // The error is attributed to the task that returned it.
        assert_eq!(format!("{:?}", tm.await),
            "Err(Application { name: \"great grandchild\", error: sooner or later you get a failson })");
        // Every generation got to run before the failure was reported.
        assert_eq!(
            *l.lock().await,
            vec![
                "in root",
                "in child",
                "in grandchild",
                "in great grandchild"
            ]
        );
    }
388    #[tokio::test]
389    async fn root_task_errors() {
390        let (tg, tm): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
391        tg.spawn("root", async move { Err(anyhow!("idk!")) })
392            .await
393            .unwrap();
394        let res = tm.await;
395        assert!(res.is_err());
396        assert_eq!(
397            format!("{:?}", res),
398            "Err(Application { name: \"root\", error: idk! })"
399        );
400    }
401
402    #[tokio::test]
403    async fn child_task_errors() {
404        let (tg, tm): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
405        tg.clone()
406            .spawn("parent", async move {
407                tg.spawn("child", async move { Err(anyhow!("whelp")) })
408                    .await?;
409                Ok(())
410            })
411            .await
412            .unwrap();
413        let res = tm.await;
414        assert!(res.is_err());
415        assert_eq!(
416            format!("{:?}", res),
417            "Err(Application { name: \"child\", error: whelp })"
418        );
419    }
420
421    #[tokio::test]
422    async fn root_task_panics() {
423        let (tg, tm): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
424        tg.spawn("root", async move { panic!("idk!") })
425            .await
426            .unwrap();
427
428        let res = tm.await;
429        assert!(res.is_err());
430        match res.err().unwrap() {
431            RuntimeError::Panic { name, panic } => {
432                assert_eq!(name, "root");
433                assert_eq!(*panic.downcast_ref::<&'static str>().unwrap(), "idk!");
434            }
435            e => panic!("wrong error variant! {:?}", e),
436        }
437    }
438
439    #[tokio::test]
440    async fn child_task_panics() {
441        let (tg, tm): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
442        let tg2 = tg.clone();
443        tg.spawn("root", async move {
444            tg2.spawn("child", async move { panic!("whelp") }).await?;
445            Ok(())
446        })
447        .await
448        .unwrap();
449
450        let res = tm.await;
451        assert!(res.is_err());
452        match res.err().unwrap() {
453            RuntimeError::Panic { name, panic } => {
454                assert_eq!(name, "child");
455                assert_eq!(*panic.downcast_ref::<&'static str>().unwrap(), "whelp");
456            }
457            e => panic!("wrong error variant! {:?}", e),
458        }
459    }
460
    // A child that sleeps for less time than the timeout wrapped around the manager is allowed
    // to run to completion.
    #[tokio::test]
    async fn child_sleep_no_timeout() {
        // Record a side-effect to demonstrate that all of these children executed
        let log: Arc<Mutex<Vec<&'static str>>> = Arc::new(Mutex::new(vec![]));
        let l = log.clone();
        let (tg, tm): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
        let tg2 = tg.clone();
        tg.spawn("parent", async move {
            tg2.spawn("child", async move {
                log.lock().await.push("child gonna nap");
                sleep(Duration::from_secs(1)).await; // 1 sec sleep, 2 sec timeout
                log.lock().await.push("child woke up happy");
                Ok(())
            })
            .await?;
            Ok(())
        })
        .await
        .unwrap();

        drop(tg); // Not going to launch anymore tasks
        let res = tokio::time::timeout(Duration::from_secs(2), tm).await;
        assert!(res.is_ok(), "no timeout");
        assert!(res.unwrap().is_ok(), "returned successfully");
        // Both log entries are present, so the child ran past its sleep.
        assert_eq!(
            *l.lock().await,
            vec!["child gonna nap", "child woke up happy"]
        );
    }
490
    // A child that sleeps for longer than the timeout wrapped around the manager never wakes
    // up: when the timeout fires the manager is dropped, which terminates the child.
    #[tokio::test]
    async fn child_sleep_timeout() {
        // Record a side-effect to show how far the child got before being terminated
        let log: Arc<Mutex<Vec<&'static str>>> = Arc::new(Mutex::new(vec![]));
        let l = log.clone();

        let (tg, tm): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
        let tg2 = tg.clone();
        tg.spawn("parent", async move {
            tg2.spawn("child", async move {
                log.lock().await.push("child gonna nap");
                sleep(Duration::from_secs(2)).await; // 2 sec sleep, 1 sec timeout
                unreachable!("child should not wake from this nap");
            })
            .await?;
            Ok(())
        })
        .await
        .unwrap();

        let res = tokio::time::timeout(Duration::from_secs(1), tm).await;
        assert!(res.is_err(), "timed out");
        // Only the entry written before the sleep is present.
        assert_eq!(*l.lock().await, vec!["child gonna nap"]);
    }
514
    // This serves as a regression check for https://github.com/pchickey/task-group/issues/3
    // I'm not sure how fragile this test will be, since it may
    // depend on Rust and LLVM's ability to optimize out unused
    // allocations like big_object.
    //
    // What is checked: a future that spawns a task through a `TaskGroup` must have the same
    // size whether the spawned future is big or small, just like one that uses `tokio::spawn`
    // directly.
    #[test]
    fn sizes_of_futures() {
        use std::mem::size_of_val;
        // Sanity check: the "big" future really is bigger than the empty one.
        assert!(size_of_val(&big_future()) > size_of_val(&empty_future()));
        // Baseline: spawning through tokio directly.
        assert_eq!(
            size_of_val(&spawns_big_future_using_tokio()),
            size_of_val(&spawns_empty_future_using_tokio())
        );

        // The actual regression check: spawning through a task group.
        assert_eq!(
            size_of_val(&spawns_big_future_using_task_group()),
            size_of_val(&spawns_empty_future_using_task_group())
        );

        async fn spawns_big_future_using_task_group() {
            let (task_group, task_manager) = TaskGroup::new();
            task_group.spawn("big future", big_future()).await.unwrap();
            drop(task_group);
            task_manager.await.unwrap();
        }

        async fn spawns_empty_future_using_task_group() {
            let (task_group, task_manager) = TaskGroup::new();
            task_group
                .spawn("empty future", empty_future())
                .await
                .unwrap();
            drop(task_group);
            task_manager.await.unwrap();
        }

        async fn spawns_big_future_using_tokio() {
            tokio::spawn(big_future()).await.unwrap().unwrap();
        }

        async fn spawns_empty_future_using_tokio() {
            tokio::spawn(empty_future()).await.unwrap().unwrap();
        }

        // A future that keeps a 4096-byte array alive across an await point.
        async fn big_future() -> Result<(), ()> {
            let big_object = [0_u8; 4096];
            // Hold big_object across an await point
            async { () }.await;
            println!(
                "printing big_object to keep value from being optimized out: {:?}",
                big_object
            );
            drop(big_object);
            Ok(())
        }

        // A future with no state to hold.
        async fn empty_future() -> Result<(), ()> {
            Ok(())
        }
    }
574}