task_tracker/
inner_task.rs1use tokio::select;
2use tokio::sync::oneshot;
3use tokio::task::{JoinError, JoinHandle};
4use tracing::debug;
5use uuid::Uuid;
6
7pub struct InnerTask<K, R>
8where
9 K: Clone + Send + std::fmt::Debug + 'static,
10 R: Send + 'static,
11{
12 pub id: Uuid,
13 pub key: K,
14 stop_tx: oneshot::Sender<()>,
15 join_handle: JoinHandle<TaskResult<R>>,
16}
17
18impl<K, R> InnerTask<K, R>
19where
20 K: Clone + Send + std::fmt::Debug + 'static,
21 R: Send + 'static,
22{
23 pub fn new<Fut>(key: K, fut: Fut) -> Self
24 where
25 Fut: std::future::Future<Output = R> + Send + 'static,
26 {
27 let id = Uuid::new_v4();
28 let (stop_tx, stop_rx) = oneshot::channel::<()>();
29 debug!(?key, %id, "creating new task");
30 let key_ = key.clone();
31 let fut = fut;
32 let task = async move {
33 debug!(key = ?key_, task_id = %id, "task started");
34 let result = select! {
35 task_result = fut => {
36 debug!(key = ?key_, task_id = %id, "task finished");
37 TaskResult::Done(task_result)
38 },
39 _ = stop_rx => {
40 debug!(key = ?key_, task_id = %id, "task cancelled");
41 TaskResult::Cancelled
42 },
43 };
44 result
45 };
46 Self {
47 id,
48 key,
49 stop_tx,
50 join_handle: tokio::task::spawn(task),
51 }
52 }
53
54 pub async fn cancel_and_wait(self) -> TaskResult<R> {
55 debug!(key = ?self.key, task_id = %self.id, "waiting for task to finish");
56 self.stop_tx.send(()).unwrap();
57 match self.join_handle.await {
58 Ok(task_result) => task_result,
59 Err(join_error) => TaskResult::JoinError(join_error),
60 }
61 }
62
63 pub fn is_finished(&self) -> bool {
64 self.join_handle.is_finished()
65 }
66
67 pub async fn wait(self) -> TaskResult<R> {
68 match self.join_handle.await {
69 Ok(task_result) => task_result,
70 Err(join_error) => TaskResult::JoinError(join_error),
71 }
72 }
73}
74
75pub enum TaskResult<R>
76where
77 R: Send,
78{
79 Done(R),
80 Cancelled,
81 JoinError(JoinError),
82}
83
84impl<R> std::fmt::Debug for TaskResult<R>
85where
86 R: Send,
87{
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 write!(
90 f,
91 "{}",
92 match self {
93 Self::Done(_) => "Done(..)",
94 Self::Cancelled => "Cancelled",
95 Self::JoinError(_) => "JoinError",
96 }
97 )
98 }
99}