// task_tracker/inner_task.rs

1use tokio::select;
2use tokio::sync::oneshot;
3use tokio::task::{JoinError, JoinHandle};
4use tracing::debug;
5use uuid::Uuid;
6
7pub struct InnerTask<K, R>
8where
9    K: Clone + Send + std::fmt::Debug + 'static,
10    R: Send + 'static,
11{
12    pub id: Uuid,
13    pub key: K,
14    stop_tx: oneshot::Sender<()>,
15    join_handle: JoinHandle<TaskResult<R>>,
16}
17
18impl<K, R> InnerTask<K, R>
19where
20    K: Clone + Send + std::fmt::Debug + 'static,
21    R: Send + 'static,
22{
23    pub fn new<Fut>(key: K, fut: Fut) -> Self
24    where
25        Fut: std::future::Future<Output = R> + Send + 'static,
26    {
27        let id = Uuid::new_v4();
28        let (stop_tx, stop_rx) = oneshot::channel::<()>();
29        debug!(?key, %id, "creating new task");
30        let key_ = key.clone();
31        let fut = fut;
32        let task = async move {
33            debug!(key = ?key_, task_id = %id, "task started");
34            let result = select! {
35                task_result = fut => {
36                    debug!(key = ?key_, task_id = %id, "task finished");
37                    TaskResult::Done(task_result)
38                },
39                _ = stop_rx => {
40                    debug!(key = ?key_, task_id = %id, "task cancelled");
41                    TaskResult::Cancelled
42                },
43            };
44            result
45        };
46        Self {
47            id,
48            key,
49            stop_tx,
50            join_handle: tokio::task::spawn(task),
51        }
52    }
53
54    pub async fn cancel_and_wait(self) -> TaskResult<R> {
55        debug!(key = ?self.key, task_id = %self.id, "waiting for task to finish");
56        self.stop_tx.send(()).unwrap();
57        match self.join_handle.await {
58            Ok(task_result) => task_result,
59            Err(join_error) => TaskResult::JoinError(join_error),
60        }
61    }
62
63    pub fn is_finished(&self) -> bool {
64        self.join_handle.is_finished()
65    }
66
67    pub async fn wait(self) -> TaskResult<R> {
68        match self.join_handle.await {
69            Ok(task_result) => task_result,
70            Err(join_error) => TaskResult::JoinError(join_error),
71        }
72    }
73}
74
75pub enum TaskResult<R>
76where
77    R: Send,
78{
79    Done(R),
80    Cancelled,
81    JoinError(JoinError),
82}
83
84impl<R> std::fmt::Debug for TaskResult<R>
85where
86    R: Send,
87{
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        write!(
90            f,
91            "{}",
92            match self {
93                Self::Done(_) => "Done(..)",
94                Self::Cancelled => "Cancelled",
95                Self::JoinError(_) => "JoinError",
96            }
97        )
98    }
99}