1use std::any::Any;
2use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use tokio::sync::mpsc;
6use tokio::task::JoinHandle;
7
8pub struct TaskGroup<E> {
12 new_task: mpsc::Sender<ChildHandle<E>>,
13}
14impl<E> Clone for TaskGroup<E> {
16 fn clone(&self) -> Self {
17 Self {
18 new_task: self.new_task.clone(),
19 }
20 }
21}
22
23impl<E: Send + 'static> TaskGroup<E> {
24 pub fn new() -> (Self, TaskManager<E>) {
25 let (new_task, reciever) = mpsc::channel(64);
26 let group = TaskGroup { new_task };
27 let manager = TaskManager::new(reciever);
28 (group, manager)
29 }
30
31 pub fn spawn(
32 &self,
33 name: impl AsRef<str>,
34 f: impl Future<Output = Result<(), E>> + Send + 'static,
35 ) -> impl Future<Output = Result<(), SpawnError>> + '_ {
36 let name = name.as_ref().to_string();
37 let join = tokio::task::spawn(f);
38 async move {
39 match self.new_task.send(ChildHandle { name, join }).await {
40 Ok(()) => Ok(()),
41 Err(_child) => Err(SpawnError::GroupDied),
44 }
45 }
46 }
47
48 pub fn spawn_on(
49 &self,
50 name: impl AsRef<str>,
51 runtime: tokio::runtime::Handle,
52 f: impl Future<Output = Result<(), E>> + Send + 'static,
53 ) -> impl Future<Output = Result<(), SpawnError>> + '_ {
54 let name = name.as_ref().to_string();
55 let join = runtime.spawn(f);
56 async move {
57 match self.new_task.send(ChildHandle { name, join }).await {
58 Ok(()) => Ok(()),
59 Err(_child) => Err(SpawnError::GroupDied),
62 }
63 }
64 }
65
66 pub fn spawn_local(
67 &self,
68 name: impl AsRef<str>,
69 f: impl Future<Output = Result<(), E>> + 'static,
70 ) -> impl Future<Output = Result<(), SpawnError>> + '_ {
71 let name = name.as_ref().to_string();
72 let join = tokio::task::spawn_local(f);
73 async move {
74 match self.new_task.send(ChildHandle { name, join }).await {
75 Ok(()) => Ok(()),
76 Err(_child) => Err(SpawnError::GroupDied),
79 }
80 }
81 }
82
83 pub fn is_closed(&self) -> bool {
85 self.new_task.is_closed()
86 }
87}
88
89struct ChildHandle<E> {
90 name: String,
91 join: JoinHandle<Result<(), E>>,
92}
93
94impl<E> ChildHandle<E> {
95 pub fn pin_join(self: Pin<&mut Self>) -> Pin<&mut JoinHandle<Result<(), E>>> {
97 unsafe { self.map_unchecked_mut(|s| &mut s.join) }
98 }
99 fn cancel(&mut self) {
100 self.join.abort();
101 }
102}
103
104impl<E> Drop for ChildHandle<E> {
107 fn drop(&mut self) {
108 self.cancel()
109 }
110}
111
112pub struct TaskManager<E> {
123 channel: Option<mpsc::Receiver<ChildHandle<E>>>,
124 children: Vec<Pin<Box<ChildHandle<E>>>>,
125}
126
127impl<E> TaskManager<E> {
128 fn new(channel: mpsc::Receiver<ChildHandle<E>>) -> Self {
129 Self {
130 channel: Some(channel),
131 children: Vec::new(),
132 }
133 }
134}
135
136impl<E> Future for TaskManager<E> {
137 type Output = Result<(), RuntimeError<E>>;
138 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
139 let mut s = self.as_mut();
140
141 if let Some(mut channel) = s.channel.take() {
144 s.channel = loop {
147 match channel.poll_recv(ctx) {
148 Poll::Pending => {
149 break Some(channel);
151 }
152 Poll::Ready(Some(new_child)) => {
153 s.children.push(Box::pin(new_child));
155 }
156 Poll::Ready(None) => {
157 break None;
160 }
161 }
162 };
163 }
164
165 let mut err = None;
167 let mut child_ix = 0;
170 while s.children.get(child_ix).is_some() {
171 let child = s
172 .children
173 .get_mut(child_ix)
174 .expect("precondition: child exists at index");
175 match child.as_mut().pin_join().poll(ctx) {
176 Poll::Pending => child_ix += 1,
178 Poll::Ready(Ok(Ok(()))) => {
182 let _ = s.children.swap_remove(child_ix);
183 }
184 Poll::Ready(Ok(Err(error))) => {
186 err = Some(RuntimeError::Application {
187 name: child.name.clone(),
188 error,
189 });
190 break;
191 }
192 Poll::Ready(Err(e)) => {
194 err = Some(match e.try_into_panic() {
195 Ok(panic) => RuntimeError::Panic {
196 name: child.name.clone(),
197 panic,
198 },
199 Err(_) => unreachable!("impossible to cancel tasks in TaskGroup"),
200 });
201 break;
202 }
203 }
204 }
205
206 if let Some(err) = err {
207 s.children.truncate(0);
210 s.channel.take();
211 Poll::Ready(Err(err))
213 } else if s.children.is_empty() {
214 if s.channel.is_none() {
215 Poll::Ready(Ok(()))
218 } else {
219 Poll::Pending
221 }
222 } else {
223 Poll::Pending
224 }
225 }
226}
227
/// Failure reported by a [`TaskManager`] when one of its tasks goes wrong.
#[derive(Debug)]
pub enum RuntimeError<E> {
    /// The task named `name` panicked; `panic` is the panic payload.
    Panic {
        name: String,
        panic: Box<dyn Any + Send + 'static>,
    },
    /// The task named `name` finished with the application error `error`.
    Application { name: String, error: E },
}

impl<E> std::error::Error for RuntimeError<E> where E: std::fmt::Display + std::error::Error {}

impl<E> std::fmt::Display for RuntimeError<E>
where
    E: std::fmt::Display,
{
    /// Writes a one-line description naming the failed task; the panic
    /// payload is not included in the output.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Panic { name: task, .. } => write!(out, "Task `{}` panicked", task),
            Self::Application {
                name: task,
                error: cause,
            } => write!(out, "Task `{}` errored: {}", task, cause),
        }
    }
}
252
/// Error returned when a task could not be registered with its group.
#[derive(Debug)]
pub enum SpawnError {
    /// The spawned task's handle could not be delivered to the manager.
    GroupDied,
}

impl std::error::Error for SpawnError {}

impl std::fmt::Display for SpawnError {
    /// Writes a fixed, human-readable description of the error.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            Self::GroupDied => "Task group died",
        };
        out.write_str(description)
    }
}
265
#[cfg(test)]
mod test {
    use super::*;
    use anyhow::{anyhow, Error};
    use std::sync::Arc;
    use tokio::sync::Mutex;
    use tokio::time::{sleep, Duration};

    /// A manager with no tasks finishes once the only group handle is gone.
    #[tokio::test]
    async fn no_task() {
        let (group, manager): (TaskGroup<Error>, TaskManager<Error>) = TaskGroup::new();
        drop(group);
        let outcome = manager.await;
        assert!(outcome.is_ok());
    }

    /// A single task that succeeds immediately lets the manager succeed.
    #[tokio::test]
    async fn one_empty_task() {
        let (group, manager): (TaskGroup<Error>, TaskManager<Error>) = TaskGroup::new();
        group.spawn("empty", async move { Ok(()) }).await.unwrap();
        drop(group);
        let outcome = manager.await;
        assert!(outcome.is_ok());
    }

    /// A task may spawn a sibling through the group handle it owns.
    #[tokio::test]
    async fn empty_child() {
        let (group, manager): (TaskGroup<Error>, TaskManager<Error>) = TaskGroup::new();
        // The clone is a temporary on purpose: it is dropped at the end of
        // this statement, leaving the parent task as the only handle owner.
        group
            .clone()
            .spawn("parent", async move {
                group.spawn("child", async move { Ok(()) }).await.unwrap();
                Ok(())
            })
            .await
            .unwrap();
        let outcome = manager.await;
        assert!(outcome.is_ok());
    }

    /// Four generations of tasks each append to a shared log, in order.
    #[tokio::test]
    async fn many_nested_children() {
        let events = Arc::new(Mutex::new(vec![0usize]));
        let observed = events.clone();
        let (group, manager): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
        // Temporary clone, dropped at the end of the statement (see
        // `empty_child`).
        group
            .clone()
            .spawn("root", async move {
                let events = events.clone();
                let child_group = group.clone();
                events.lock().await.push(1);
                group
                    .spawn("child", async move {
                        let grandchild_group = child_group.clone();
                        events.lock().await.push(2);
                        child_group
                            .spawn("grandchild", async move {
                                events.lock().await.push(3);
                                grandchild_group
                                    .spawn("great grandchild", async move {
                                        events.lock().await.push(4);
                                        Ok(())
                                    })
                                    .await
                                    .unwrap();
                                Ok(())
                            })
                            .await
                            .unwrap();
                        Ok(())
                    })
                    .await
                    .unwrap();
                Ok(())
            })
            .await
            .unwrap();
        let outcome = manager.await;
        assert!(outcome.is_ok());
        assert_eq!(*observed.lock().await, vec![0usize, 1, 2, 3, 4]);
    }

    /// An error in the innermost task is what the manager reports, and the
    /// sleeping grandchild never gets to wake up.
    #[tokio::test]
    async fn many_nested_children_error() {
        let events: Arc<Mutex<Vec<&'static str>>> = Arc::new(Mutex::new(vec![]));
        let observed = events.clone();

        let (group, manager): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
        let root_group = group.clone();
        group
            .spawn("root", async move {
                events.lock().await.push("in root");
                let child_group = root_group.clone();
                root_group
                    .spawn("child", async move {
                        events.lock().await.push("in child");
                        let grandchild_group = child_group.clone();
                        child_group
                            .spawn("grandchild", async move {
                                events.lock().await.push("in grandchild");
                                grandchild_group
                                    .spawn("great grandchild", async move {
                                        events.lock().await.push("in great grandchild");
                                        Err(anyhow!("sooner or later you get a failson"))
                                    })
                                    .await
                                    .unwrap();
                                sleep(Duration::from_secs(1)).await;
                                unreachable!("sleepy grandchild should never wake");
                            })
                            .await
                            .unwrap();
                        Ok(())
                    })
                    .await
                    .unwrap();
                Ok(())
            })
            .await
            .unwrap();
        drop(group);

        let rendered = format!("{:?}", manager.await);
        assert_eq!(
            rendered,
            "Err(Application { name: \"great grandchild\", error: sooner or later you get a failson })"
        );
        assert_eq!(
            *observed.lock().await,
            vec![
                "in root",
                "in child",
                "in grandchild",
                "in great grandchild"
            ]
        );
    }

    /// An error returned by the root task is surfaced under its name.
    #[tokio::test]
    async fn root_task_errors() {
        let (group, manager): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
        group
            .spawn("root", async move { Err(anyhow!("idk!")) })
            .await
            .unwrap();
        let outcome = manager.await;
        assert!(outcome.is_err());
        assert_eq!(
            format!("{:?}", outcome),
            "Err(Application { name: \"root\", error: idk! })"
        );
    }

    /// An error returned by a child task is surfaced under the child's name.
    #[tokio::test]
    async fn child_task_errors() {
        let (group, manager): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
        // Temporary clone, dropped at the end of the statement.
        group
            .clone()
            .spawn("parent", async move {
                group
                    .spawn("child", async move { Err(anyhow!("whelp")) })
                    .await?;
                Ok(())
            })
            .await
            .unwrap();
        let outcome = manager.await;
        assert!(outcome.is_err());
        assert_eq!(
            format!("{:?}", outcome),
            "Err(Application { name: \"child\", error: whelp })"
        );
    }

    /// A panic in the root task is reported with its name and payload.
    #[tokio::test]
    async fn root_task_panics() {
        let (group, manager): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
        group
            .spawn("root", async move { panic!("idk!") })
            .await
            .unwrap();

        let outcome = manager.await;
        assert!(outcome.is_err());
        match outcome.err().unwrap() {
            RuntimeError::Panic {
                name,
                panic: payload,
            } => {
                assert_eq!(name, "root");
                assert_eq!(*payload.downcast_ref::<&'static str>().unwrap(), "idk!");
            }
            e => panic!("wrong error variant! {:?}", e),
        }
    }

    /// A panic in a child task is reported with the child's name and payload.
    #[tokio::test]
    async fn child_task_panics() {
        let (group, manager): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
        let root_group = group.clone();
        group
            .spawn("root", async move {
                root_group
                    .spawn("child", async move { panic!("whelp") })
                    .await?;
                Ok(())
            })
            .await
            .unwrap();

        let outcome = manager.await;
        assert!(outcome.is_err());
        match outcome.err().unwrap() {
            RuntimeError::Panic {
                name,
                panic: payload,
            } => {
                assert_eq!(name, "child");
                assert_eq!(*payload.downcast_ref::<&'static str>().unwrap(), "whelp");
            }
            e => panic!("wrong error variant! {:?}", e),
        }
    }

    /// A child sleeping for one second completes within a two-second timeout.
    #[tokio::test]
    async fn child_sleep_no_timeout() {
        let events: Arc<Mutex<Vec<&'static str>>> = Arc::new(Mutex::new(vec![]));
        let observed = events.clone();
        let (group, manager): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
        let parent_group = group.clone();
        group
            .spawn("parent", async move {
                parent_group
                    .spawn("child", async move {
                        events.lock().await.push("child gonna nap");
                        sleep(Duration::from_secs(1)).await;
                        events.lock().await.push("child woke up happy");
                        Ok(())
                    })
                    .await?;
                Ok(())
            })
            .await
            .unwrap();

        drop(group);
        let outcome = tokio::time::timeout(Duration::from_secs(2), manager).await;
        assert!(outcome.is_ok(), "no timeout");
        assert!(outcome.unwrap().is_ok(), "returned successfully");
        assert_eq!(
            *observed.lock().await,
            vec!["child gonna nap", "child woke up happy"]
        );
    }

    /// A child sleeping for two seconds is cut short by a one-second timeout
    /// on the manager.
    #[tokio::test]
    async fn child_sleep_timeout() {
        let events: Arc<Mutex<Vec<&'static str>>> = Arc::new(Mutex::new(vec![]));
        let observed = events.clone();

        let (group, manager): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
        let parent_group = group.clone();
        group
            .spawn("parent", async move {
                parent_group
                    .spawn("child", async move {
                        events.lock().await.push("child gonna nap");
                        sleep(Duration::from_secs(2)).await;
                        unreachable!("child should not wake from this nap");
                    })
                    .await?;
                Ok(())
            })
            .await
            .unwrap();

        let outcome = tokio::time::timeout(Duration::from_secs(1), manager).await;
        assert!(outcome.is_err(), "timed out");
        assert_eq!(*observed.lock().await, vec!["child gonna nap"]);
    }

    /// Spawning a large future must not inflate the size of the future that
    /// does the spawning, for plain tokio and for the task group alike.
    #[test]
    fn sizes_of_futures() {
        use std::mem::size_of_val;

        // Sanity check: the two payload futures really differ in size.
        assert!(size_of_val(&big_future()) > size_of_val(&empty_future()));

        assert_eq!(
            size_of_val(&spawns_big_future_using_tokio()),
            size_of_val(&spawns_empty_future_using_tokio())
        );

        assert_eq!(
            size_of_val(&spawns_big_future_using_task_group()),
            size_of_val(&spawns_empty_future_using_task_group())
        );

        async fn spawns_big_future_using_task_group() {
            let (task_group, task_manager) = TaskGroup::new();
            task_group.spawn("big future", big_future()).await.unwrap();
            drop(task_group);
            task_manager.await.unwrap();
        }

        async fn spawns_empty_future_using_task_group() {
            let (task_group, task_manager) = TaskGroup::new();
            task_group
                .spawn("empty future", empty_future())
                .await
                .unwrap();
            drop(task_group);
            task_manager.await.unwrap();
        }

        async fn spawns_big_future_using_tokio() {
            tokio::spawn(big_future()).await.unwrap().unwrap();
        }

        async fn spawns_empty_future_using_tokio() {
            tokio::spawn(empty_future()).await.unwrap().unwrap();
        }

        async fn big_future() -> Result<(), ()> {
            let big_object = [0_u8; 4096];
            // Await point placed so that `big_object` lives across it and is
            // therefore stored inside the future.
            async {}.await;
            println!(
                "printing big_object to keep value from being optimized out: {:?}",
                big_object
            );
            let _ = big_object;
            Ok(())
        }

        async fn empty_future() -> Result<(), ()> {
            Ok(())
        }
    }
}