use std::any::Any;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
/// Cloneable handle used to spawn tasks into a group supervised by a [`TaskManager`].
///
/// Every clone feeds the same manager. The manager only stops accepting new tasks
/// once every `TaskGroup` handle has been dropped, so all handles must be dropped
/// for the manager to complete successfully.
pub struct TaskGroup<E> {
    /// Sends the handle of each spawned task to the `TaskManager`.
    new_task: mpsc::Sender<ChildHandle<E>>,
}
impl<E> Clone for TaskGroup<E> {
fn clone(&self) -> Self {
Self {
new_task: self.new_task.clone(),
}
}
}
impl<E: Send + 'static> TaskGroup<E> {
pub fn new() -> (Self, TaskManager<E>) {
let (new_task, reciever) = mpsc::channel(64);
let group = TaskGroup { new_task };
let manager = TaskManager::new(reciever);
(group, manager)
}
pub async fn spawn(
&self,
name: impl AsRef<str>,
f: impl Future<Output = Result<(), E>> + Send + 'static,
) -> Result<(), SpawnError> {
let name = name.as_ref().to_string();
let join = tokio::task::spawn(f);
match self.new_task.send(ChildHandle { name, join }).await {
Ok(()) => Ok(()),
Err(_child) => Err(SpawnError::GroupDied),
}
}
pub async fn spawn_on(
&self,
name: impl AsRef<str>,
runtime: tokio::runtime::Handle,
f: impl Future<Output = Result<(), E>> + Send + 'static,
) -> Result<(), SpawnError> {
let name = name.as_ref().to_string();
let join = runtime.spawn(f);
match self.new_task.send(ChildHandle { name, join }).await {
Ok(()) => Ok(()),
Err(_child) => Err(SpawnError::GroupDied),
}
}
pub async fn spawn_local(
&self,
name: impl AsRef<str>,
f: impl Future<Output = Result<(), E>> + 'static,
) -> Result<(), SpawnError> {
let name = name.as_ref().to_string();
let join = tokio::task::spawn_local(f);
match self.new_task.send(ChildHandle { name, join }).await {
Ok(()) => Ok(()),
Err(_child) => Err(SpawnError::GroupDied),
}
}
}
/// A spawned task tracked by the `TaskManager`.
///
/// Dropping a `ChildHandle` aborts its task (see its `Drop` impl).
struct ChildHandle<E> {
    /// Name given at spawn time; reported in `RuntimeError`.
    name: String,
    /// Join handle of the spawned task.
    join: JoinHandle<Result<(), E>>,
}
impl<E> ChildHandle<E> {
    /// Returns the task's `JoinHandle`, pinned, so it can be polled as a future.
    ///
    /// `ChildHandle` is `Unpin` (a `String` plus a `JoinHandle`, which Tokio
    /// declares `Unpin` for every output type), so the projection can go through
    /// `Pin::get_mut` and needs no `unsafe`.
    pub fn pin_join(self: Pin<&mut Self>) -> Pin<&mut JoinHandle<Result<(), E>>> {
        Pin::new(&mut self.get_mut().join)
    }

    /// Requests cancellation of the task via `JoinHandle::abort`.
    fn cancel(&mut self) {
        self.join.abort();
    }
}
/// Requests cancellation of the task whenever its handle is dropped, so tasks the
/// manager stops tracking (e.g. after a failure) do not keep running unsupervised.
impl<E> Drop for ChildHandle<E> {
    fn drop(&mut self) {
        Self::cancel(self);
    }
}
/// Future that supervises every task spawned through the matching [`TaskGroup`]s.
///
/// Resolves to `Ok(())` when all `TaskGroup` handles are dropped and every task has
/// finished successfully, or to the first [`RuntimeError`] it observes.
pub struct TaskManager<E> {
    /// Receiver for newly spawned tasks. `None` once the channel has closed (every
    /// sender dropped) or after the manager has failed.
    channel: Option<mpsc::Receiver<ChildHandle<E>>>,
    /// Tasks that have been received but have not yet finished.
    children: Vec<Pin<Box<ChildHandle<E>>>>,
}
impl<E> TaskManager<E> {
fn new(channel: mpsc::Receiver<ChildHandle<E>>) -> Self {
Self {
channel: Some(channel),
children: Vec::new(),
}
}
}
impl<E> Future for TaskManager<E> {
    type Output = Result<(), RuntimeError<E>>;

    /// Drives the task group.
    ///
    /// Each poll first takes every newly spawned child out of the channel, then
    /// polls every tracked child once:
    ///
    /// * a child that finished with `Ok(())` is forgotten;
    /// * a child that returned `Err` or panicked stops the whole group: all other
    ///   children are dropped (which aborts them, see `ChildHandle`'s `Drop`), the
    ///   receiver is dropped so further `spawn*` calls fail, and that failure is
    ///   returned.
    ///
    /// Resolves to `Ok(())` once the channel has closed (every `TaskGroup` handle
    /// dropped) and no children remain.
    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
        // `TaskManager` only holds an `Option<Receiver>` and a `Vec` of boxed
        // handles, all of which are `Unpin`, so a plain `&mut Self` is fine and
        // lets us borrow the fields independently.
        let s = self.get_mut();

        // Accept new children until the channel is empty (`Pending`, which also
        // registers our waker) or closed (`None`: every `TaskGroup` was dropped).
        if let Some(mut channel) = s.channel.take() {
            s.channel = loop {
                match channel.poll_recv(ctx) {
                    Poll::Pending => break Some(channel),
                    Poll::Ready(Some(new_child)) => s.children.push(Box::pin(new_child)),
                    Poll::Ready(None) => break None,
                }
            };
        }

        let mut err = None;
        let mut child_ix = 0;
        while let Some(child) = s.children.get_mut(child_ix) {
            match child.as_mut().pin_join().poll(ctx) {
                Poll::Pending => child_ix += 1,
                Poll::Ready(Ok(Ok(()))) => {
                    // `swap_remove` moves the last child into `child_ix`, so the
                    // index is not advanced: that child is polled next.
                    let _ = s.children.swap_remove(child_ix);
                }
                Poll::Ready(Ok(Err(error))) => {
                    err = Some(RuntimeError::Application {
                        name: child.name.clone(),
                        error,
                    });
                    break;
                }
                Poll::Ready(Err(join_err)) => {
                    let panic = match join_err.try_into_panic() {
                        Ok(panic) => panic,
                        // NOTE(review): this group only cancels tasks by dropping
                        // their handles, but Tokio also cancels tasks when their
                        // runtime shuts down (e.g. a runtime passed to `spawn_on`
                        // being dropped). That would reach this branch and panic
                        // the manager. Reporting it properly needs a new
                        // `RuntimeError` variant.
                        Err(_) => unreachable!("impossible to cancel tasks in TaskGroup"),
                    };
                    err = Some(RuntimeError::Panic {
                        name: child.name.clone(),
                        panic,
                    });
                    break;
                }
            }
        }

        if let Some(err) = err {
            // Dropping every remaining handle aborts those tasks; dropping the
            // receiver closes the channel so later `spawn*` calls report
            // `SpawnError::GroupDied`.
            s.children.clear();
            s.channel = None;
            Poll::Ready(Err(err))
        } else if s.children.is_empty() && s.channel.is_none() {
            Poll::Ready(Ok(()))
        } else {
            Poll::Pending
        }
    }
}
#[derive(Debug)]
pub enum RuntimeError<E> {
Panic {
name: String,
panic: Box<dyn Any + Send + 'static>,
},
Application {
name: String,
error: E,
},
}
impl<E: std::fmt::Display + std::error::Error> std::error::Error for RuntimeError<E> {}
impl<E: std::fmt::Display> std::fmt::Display for RuntimeError<E> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
RuntimeError::Panic { name, .. } => {
write!(f, "Task `{}` panicked", name)
}
RuntimeError::Application { name, error } => {
write!(f, "Task `{}` errored: {}", name, error)
}
}
}
}
/// Error returned by the `TaskGroup::spawn*` methods.
#[derive(Debug)]
pub enum SpawnError {
    /// The `TaskManager` has gone away (it was dropped, or it stopped after a task
    /// failed), so the group can no longer accept tasks.
    GroupDied,
}

impl std::fmt::Display for SpawnError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let text = match self {
            SpawnError::GroupDied => "Task group died",
        };
        f.write_str(text)
    }
}

impl std::error::Error for SpawnError {}
#[cfg(test)]
mod test {
use super::*;
use anyhow::{anyhow, Error};
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::time::{sleep, Duration};
#[tokio::test]
async fn no_task() {
let (tg, tm): (TaskGroup<Error>, TaskManager<Error>) = TaskGroup::new();
drop(tg);
assert!(tm.await.is_ok());
}
#[tokio::test]
async fn one_empty_task() {
let (tg, tm): (TaskGroup<Error>, TaskManager<Error>) = TaskGroup::new();
tg.spawn("empty", async move { Ok(()) }).await.unwrap();
drop(tg);
assert!(tm.await.is_ok());
}
#[tokio::test]
async fn empty_child() {
let (tg, tm): (TaskGroup<Error>, TaskManager<Error>) = TaskGroup::new();
tg.clone()
.spawn("parent", async move {
tg.spawn("child", async move { Ok(()) }).await.unwrap();
Ok(())
})
.await
.unwrap();
assert!(tm.await.is_ok());
}
#[tokio::test]
async fn many_nested_children() {
let log = Arc::new(Mutex::new(vec![0usize]));
let l = log.clone();
let (tg, tm): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
tg.clone()
.spawn("root", async move {
let log = log.clone();
let tg2 = tg.clone();
log.lock().await.push(1);
tg.spawn("child", async move {
let tg3 = tg2.clone();
log.lock().await.push(2);
tg2.spawn("grandchild", async move {
log.lock().await.push(3);
tg3.spawn("great grandchild", async move {
log.lock().await.push(4);
Ok(())
})
.await
.unwrap();
Ok(())
})
.await
.unwrap();
Ok(())
})
.await
.unwrap();
Ok(())
})
.await
.unwrap();
assert!(tm.await.is_ok());
assert_eq!(*l.lock().await, vec![0usize, 1, 2, 3, 4]);
}
#[tokio::test]
async fn many_nested_children_error() {
let log: Arc<Mutex<Vec<&'static str>>> = Arc::new(Mutex::new(vec![]));
let l = log.clone();
let (tg, tm): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
let tg2 = tg.clone();
tg.spawn("root", async move {
log.lock().await.push("in root");
let tg3 = tg2.clone();
tg2.spawn("child", async move {
log.lock().await.push("in child");
let tg4 = tg3.clone();
tg3.spawn("grandchild", async move {
log.lock().await.push("in grandchild");
tg4.spawn("great grandchild", async move {
log.lock().await.push("in great grandchild");
Err(anyhow!("sooner or later you get a failson"))
})
.await
.unwrap();
sleep(Duration::from_secs(1)).await;
unreachable!("sleepy grandchild should never wake");
})
.await
.unwrap();
Ok(())
})
.await
.unwrap();
Ok(())
})
.await
.unwrap();
drop(tg);
assert_eq!(format!("{:?}", tm.await),
"Err(Application { name: \"great grandchild\", error: sooner or later you get a failson })");
assert_eq!(
*l.lock().await,
vec![
"in root",
"in child",
"in grandchild",
"in great grandchild"
]
);
}
#[tokio::test]
async fn root_task_errors() {
let (tg, tm): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
tg.spawn("root", async move { Err(anyhow!("idk!")) })
.await
.unwrap();
let res = tm.await;
assert!(res.is_err());
assert_eq!(
format!("{:?}", res),
"Err(Application { name: \"root\", error: idk! })"
);
}
#[tokio::test]
async fn child_task_errors() {
let (tg, tm): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
tg.clone()
.spawn("parent", async move {
tg.spawn("child", async move { Err(anyhow!("whelp")) })
.await?;
Ok(())
})
.await
.unwrap();
let res = tm.await;
assert!(res.is_err());
assert_eq!(
format!("{:?}", res),
"Err(Application { name: \"child\", error: whelp })"
);
}
#[tokio::test]
async fn root_task_panics() {
let (tg, tm): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
tg.spawn("root", async move { panic!("idk!") })
.await
.unwrap();
let res = tm.await;
assert!(res.is_err());
match res.err().unwrap() {
RuntimeError::Panic { name, panic } => {
assert_eq!(name, "root");
assert_eq!(*panic.downcast_ref::<&'static str>().unwrap(), "idk!");
}
e => panic!("wrong error variant! {:?}", e),
}
}
#[tokio::test]
async fn child_task_panics() {
let (tg, tm): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
let tg2 = tg.clone();
tg.spawn("root", async move {
tg2.spawn("child", async move { panic!("whelp") }).await?;
Ok(())
})
.await
.unwrap();
let res = tm.await;
assert!(res.is_err());
match res.err().unwrap() {
RuntimeError::Panic { name, panic } => {
assert_eq!(name, "child");
assert_eq!(*panic.downcast_ref::<&'static str>().unwrap(), "whelp");
}
e => panic!("wrong error variant! {:?}", e),
}
}
#[tokio::test]
async fn child_sleep_no_timeout() {
let log: Arc<Mutex<Vec<&'static str>>> = Arc::new(Mutex::new(vec![]));
let l = log.clone();
let (tg, tm): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
let tg2 = tg.clone();
tg.spawn("parent", async move {
tg2.spawn("child", async move {
log.lock().await.push("child gonna nap");
sleep(Duration::from_secs(1)).await;
log.lock().await.push("child woke up happy");
Ok(())
})
.await?;
Ok(())
})
.await
.unwrap();
drop(tg);
let res = tokio::time::timeout(Duration::from_secs(2), tm).await;
assert!(res.is_ok(), "no timeout");
assert!(res.unwrap().is_ok(), "returned successfully");
assert_eq!(
*l.lock().await,
vec!["child gonna nap", "child woke up happy"]
);
}
#[tokio::test]
async fn child_sleep_timeout() {
let log: Arc<Mutex<Vec<&'static str>>> = Arc::new(Mutex::new(vec![]));
let l = log.clone();
let (tg, tm): (TaskGroup<Error>, TaskManager<_>) = TaskGroup::new();
let tg2 = tg.clone();
tg.spawn("parent", async move {
tg2.spawn("child", async move {
log.lock().await.push("child gonna nap");
sleep(Duration::from_secs(2)).await;
unreachable!("child should not wake from this nap");
})
.await?;
Ok(())
})
.await
.unwrap();
let res = tokio::time::timeout(Duration::from_secs(1), tm).await;
assert!(res.is_err(), "timed out");
assert_eq!(*l.lock().await, vec!["child gonna nap"]);
}
}