// twinkle_client/task.rs

1use std::{
2    future::{self, Future},
3    ops::{Deref, DerefMut},
4    sync::Arc,
5    time::Duration,
6};
7
8use serde::{Deserialize, Serialize};
9use tokio::{
10    select,
11    sync::mpsc::{self, Receiver, Sender},
12};
13use tokio_stream::{wrappers::errors::BroadcastStreamRecvError, StreamExt};
14
15use crate::{
16    notify::{AsyncLockable, Notify, NotifyArc},
17    MaybeSend,
18};
19
20pub trait Joinable<T> {
21    fn join(&mut self) -> impl std::future::Future<Output = Result<T, Error>>;
22}
23
/// Reports whether the implementor currently considers itself running.
pub trait IsRunning {
    /// Returns `true` while the implementor is running.
    fn running(&self) -> bool;
}
27
/// A handle that can be asked to abort and that carries an
/// "abort when dropped" setting.
pub trait Abortable {
    /// Requests an abort of whatever the implementor represents.
    fn abort(&self);

    /// Stores the "abort when dropped" setting; how the flag is honored is up
    /// to the implementor.
    fn set_abort_on_drop(&mut self, abort: bool);

    /// Builder-style variant of [`Abortable::set_abort_on_drop`]: applies the
    /// setting to `self` and hands the value back.
    fn abort_on_drop(self, abort: bool) -> Self
    where
        Self: Sized,
    {
        let mut configured = self;
        configured.set_abort_on_drop(abort);
        configured
    }
}
40
/// Error produced while waiting for a task's status updates.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum TaskStatusError {
    /// Mirrors `BroadcastStreamRecvError::Lagged`, carrying the same count.
    Lagged(u64),
    /// The status subscription ended before a running status was observed.
    Finished,
}
46
47impl From<BroadcastStreamRecvError> for TaskStatusError {
48    fn from(value: BroadcastStreamRecvError) -> Self {
49        match value {
50            BroadcastStreamRecvError::Lagged(n) => TaskStatusError::Lagged(n),
51        }
52    }
53}
// `running_status` is declared as an `async fn` in this trait, which the
// `async_fn_in_trait` lint would otherwise flag.
#[allow(async_fn_in_trait)]
/// A task that exposes its [`Status`] through a shared, lockable handle.
pub trait Task<S> {
    /// Lock type guarding the task's `Status<S>`.
    type AsyncLock: AsyncLockable<Status<S>>;

    /// Returns the shared handle to the task's status lock.
    fn status(&self) -> &Arc<Self::AsyncLock>;
    /// Resolves to a handle on the status, or to a [`TaskStatusError`] when
    /// the implementor cannot provide one.
    async fn running_status(&self) -> Result<NotifyArc<Status<S>>, TaskStatusError>;
}
61
/// Reasons a task's state or output is not available.
#[derive(Debug)]
pub enum Error {
    /// The status was `Status::Pending`.
    Pending,
    /// The task was aborted, or its output channel closed without a value.
    Aborted,
    /// The status was `Status::Completed`.
    Completed,
}
68
/// Lifecycle state of a task, carrying the task state `S` while running.
#[derive(Serialize, Deserialize, derive_more::Debug, PartialEq, Eq, Clone)]
#[serde(bound(serialize = "S: Serialize", deserialize = "S: Deserialize<'de>"))]
pub enum Status<S> {
    /// Initial value of a default-constructed `AsyncTask`.
    Pending,
    /// Set by the spawned task when it starts; holds the task's state.
    Running(S),
    /// Final value written by the spawned task after its future finished.
    Completed,
    /// Final value written by the spawned task when it did not complete normally.
    Aborted,
}
77
78impl<S> Status<S> {
79    pub fn map<F, U>(self, f: F) -> Status<U>
80    where
81        F: FnOnce(S) -> U,
82    {
83        match self {
84            Status::Pending => Status::Pending,
85            Status::Running(v) => Status::Running(f(v)),
86            Status::Completed => Status::Completed,
87            Status::Aborted => Status::Aborted,
88        }
89    }
90}
91
92impl<S> Status<S> {
93    pub async fn with_state<'a, V, R, F>(&'a self, func: F) -> Result<V, Error>
94    where
95        F: FnOnce(&S) -> R + 'a,
96        R: Future<Output = V> + 'a,
97    {
98        let future = {
99            match self {
100                Status::Pending => Err(Error::Pending),
101                Status::Running(state) => Ok(func(state)),
102                Status::Completed => Err(Error::Completed),
103                Status::Aborted => Err(Error::Aborted),
104            }
105        };
106        match future {
107            Ok(future) => Ok(future.await),
108            Err(e) => Err(e),
109        }
110    }
111
112    pub fn running(&self) -> bool {
113        match self {
114            Status::Running(_) => true,
115            _ => false,
116        }
117    }
118
119    pub fn pending(&self) -> bool {
120        match self {
121            Status::Pending => true,
122            _ => false,
123        }
124    }
125}
126
/// Handle to a spawned asynchronous task producing a `T`, with state `S`
/// published through a shared status.
pub struct AsyncTask<T, S> {
    /// Sending side of the abort signal consumed by the spawned task.
    abort_tx: Sender<()>,
    /// Receiving side of the task's result; read by `join`.
    output_rx: Receiver<Result<T, Error>>,
    /// Status shared between this handle and the spawned task.
    status: Arc<Notify<Status<S>>>,
    /// When `true`, dropping the handle calls `abort`.
    abort_on_drop: bool,
}
133
134impl<T, S> Default for AsyncTask<T, S> {
135    fn default() -> Self {
136        let (abort_tx, _) = mpsc::channel::<()>(1);
137        let (_, output_rx) = mpsc::channel::<Result<T, Error>>(1);
138
139        let status = Arc::new(Notify::new(Status::Pending));
140
141        AsyncTask {
142            abort_tx,
143            output_rx,
144            status,
145            abort_on_drop: true,
146        }
147    }
148}
149
150impl<T, S: 'static> AsyncTask<T, S> {
151    pub fn with_timeout(mut self, timeout: Duration) -> Self {
152        let status = Arc::get_mut(&mut self.status).unwrap();
153        status.set_timeout(timeout);
154        self
155    }
156}
157
158impl<T: MaybeSend + 'static, S: MaybeSend + Sync + 'static> AsyncTask<T, S> {
159    pub fn spawn<F: FnOnce(&S) -> U, U: Future<Output = T> + MaybeSend + 'static>(
160        &mut self,
161        state: S,
162        func: F,
163    ) {
164        let (abort_tx, mut abort_rx) = mpsc::channel::<()>(1);
165        let (output_tx, output_rx) = mpsc::channel::<Result<T, Error>>(1);
166        let future = func(&state);
167
168        self.abort_tx = abort_tx;
169        self.output_rx = output_rx;
170
171        spawn_platform({
172            let status = self.status.clone();
173            async move {
174                {
175                    let mut lock = status.write().await;
176                    *lock = Status::Running(state);
177                    lock.notify();
178                    if let Err(e) = lock.not_cloned(status.get_timeout()).await {
179                        tracing::error!(
180                            "Timeout waiting for writeable lock when starting task: {:?}",
181                            e
182                        );
183                    }
184                };
185
186                let abort = async move {
187                    if let None = abort_rx.recv().await {
188                        future::pending::<()>().await;
189                    }
190                };
191                let result = select! {
192                    r = future => {
193                        if let Ok(_) = output_tx.try_send(Ok(r)) {
194                            Status::Completed
195                        } else {
196                            Status::Aborted
197                        }
198                    },
199                    _ = abort => {
200                        if let Ok(_) =  output_tx.try_send(Err(Error::Aborted))  {
201                            Status::Aborted
202                        } else {
203                            Status::Completed
204                        }
205                     },
206                };
207                {
208                    *status.write().await = result
209                };
210            }
211        });
212    }
213}
214
215impl<T, S> Drop for AsyncTask<T, S> {
216    fn drop(&mut self) {
217        if self.abort_on_drop {
218            self.abort();
219        }
220    }
221}
222impl<A: Abortable, D: Deref<Target = A> + DerefMut> Abortable for D {
223    fn abort(&self) {
224        self.deref().abort()
225    }
226
227    fn set_abort_on_drop(&mut self, abort: bool) {
228        self.deref_mut().set_abort_on_drop(abort);
229    }
230}
231impl<T, S> Abortable for AsyncTask<T, S> {
232    fn abort(&self) {
233        let _ = self.abort_tx.try_send(());
234    }
235
236    fn set_abort_on_drop(&mut self, abort: bool) {
237        self.abort_on_drop = abort;
238    }
239}
240impl<T, S: Send + Sync + 'static> Task<S> for AsyncTask<T, S> {
241    type AsyncLock = crate::notify::Notify<Status<S>>;
242
243    fn status(&self) -> &Arc<Self::AsyncLock> {
244        &self.status
245    }
246
247    async fn running_status(&self) -> Result<NotifyArc<Status<S>>, TaskStatusError> {
248        let mut sub = self.status.subscribe().await;
249        while let Some(next) = sub.next().await {
250            let next = next?;
251            if next.running() {
252                return Ok(next);
253            }
254        }
255        Err(TaskStatusError::Finished)
256    }
257}
258
259impl<T: Send, S: Send + Sync + 'static> IsRunning for AsyncTask<T, S> {
260    fn running(&self) -> bool {
261        !self.output_rx.is_closed()
262    }
263}
264
265impl<T, S> Joinable<T> for AsyncTask<T, S> {
266    async fn join(&mut self) -> Result<T, Error> {
267        match self.output_rx.recv().await {
268            Some(r) => r,
269            None => Err(Error::Aborted),
270        }
271    }
272}
273
274pub fn spawn_with_state<
275    S: MaybeSend + Sync + 'static,
276    F: FnOnce(&S) -> U,
277    U: Future<Output = ()> + MaybeSend + 'static,
278>(
279    state: S,
280    func: F,
281) -> AsyncTask<(), S> {
282    spawn(state, func)
283}
284
285pub fn spawn_with_value<T: MaybeSend + 'static, U: Future<Output = T> + MaybeSend + 'static>(
286    future: U,
287) -> AsyncTask<T, ()> {
288    spawn((), |_| future)
289}
290
291pub fn spawn<
292    T: MaybeSend + 'static,
293    S: MaybeSend + Sync + 'static,
294    F: FnOnce(&S) -> U,
295    U: Future<Output = T> + MaybeSend + 'static,
296>(
297    state: S,
298    func: F,
299) -> AsyncTask<T, S> {
300    let mut task: AsyncTask<T, S> = Default::default();
301    task.spawn(state, func);
302    task
303}
304
305#[cfg(not(target_family = "wasm"))]
306fn spawn_platform<F: Future<Output = ()> + MaybeSend + 'static>(future: F) {
307    tokio::task::spawn(future);
308}
309
/// Wasm targets: hands `future` to `wasm_bindgen_futures::spawn_local`.
#[cfg(target_family = "wasm")]
fn spawn_platform<F>(future: F)
where
    F: Future<Output = ()> + MaybeSend + 'static,
{
    wasm_bindgen_futures::spawn_local(future)
}
314
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;
    use tracing_test::traced_test;

    /// Spawns a task, joins it, then reuses the same handle for a second
    /// spawn; each run must yield its value and end in `Status::Completed`.
    #[tokio::test]
    #[traced_test]
    async fn test_simple() {
        // First run: state is the sleep duration in milliseconds.
        let mut task = spawn(10, |delay_ms| {
            let delay = Duration::from_millis(*delay_ms);
            async move {
                tokio::time::sleep(delay).await;
                11
            }
        });
        let first = task.join().await.unwrap();
        assert_eq!(first, 11);
        assert_eq!(*task.status().read().await.deref(), Status::Completed);

        // Second run on the same handle.
        task.spawn(12, |delay_ms| {
            let delay = Duration::from_millis(*delay_ms);
            async move {
                tokio::time::sleep(delay).await;
                13
            }
        });
        let second = task.join().await.unwrap();
        assert_eq!(second, 13);
        assert_eq!(*task.status().read().await.deref(), Status::Completed);
        drop(task);
    }
}