1use std::{
2 future::{self, Future},
3 ops::{Deref, DerefMut},
4 sync::Arc,
5 time::Duration,
6};
7
8use serde::{Deserialize, Serialize};
9use tokio::{
10 select,
11 sync::mpsc::{self, Receiver, Sender},
12};
13use tokio_stream::{wrappers::errors::BroadcastStreamRecvError, StreamExt};
14
15use crate::{
16 notify::{AsyncLockable, Notify, NotifyArc},
17 MaybeSend,
18};
19
20pub trait Joinable<T> {
21 fn join(&mut self) -> impl std::future::Future<Output = Result<T, Error>>;
22}
23
/// Synchronous liveness probe for a task-like handle.
pub trait IsRunning {
    /// Returns `true` while the implementor considers its work to be running.
    fn running(&self) -> bool;
}
27
/// Cancellation interface for a task-like handle.
pub trait Abortable {
    /// Requests that the underlying work be aborted.
    fn abort(&self);

    /// Stores whether the implementor should abort its work when dropped.
    fn set_abort_on_drop(&mut self, abort: bool);

    /// Builder-style counterpart of [`Abortable::set_abort_on_drop`]: applies
    /// the flag and hands the value back so calls can be chained.
    fn abort_on_drop(mut self, abort: bool) -> Self
    where
        Self: Sized,
    {
        self.set_abort_on_drop(abort);
        self
    }
}
40
41#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
42pub enum TaskStatusError {
43 Lagged(u64),
44 Finished,
45}
46
47impl From<BroadcastStreamRecvError> for TaskStatusError {
48 fn from(value: BroadcastStreamRecvError) -> Self {
49 match value {
50 BroadcastStreamRecvError::Lagged(n) => TaskStatusError::Lagged(n),
51 }
52 }
53}
54#[allow(async_fn_in_trait)]
55pub trait Task<S> {
56 type AsyncLock: AsyncLockable<Status<S>>;
57
58 fn status(&self) -> &Arc<Self::AsyncLock>;
59 async fn running_status(&self) -> Result<NotifyArc<Status<S>>, TaskStatusError>;
60}
61
/// Reasons a task could not provide a value or its running state.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be
/// propagated with `?` into boxed or dynamic error types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// The task has not started running yet.
    Pending,
    /// The task was aborted, or no result could be received from it.
    Aborted,
    /// The task has already run to completion.
    Completed,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Error::Pending => "task is pending",
            Error::Aborted => "task was aborted",
            Error::Completed => "task has completed",
        };
        f.write_str(message)
    }
}

impl std::error::Error for Error {}
68
69#[derive(Serialize, Deserialize, derive_more::Debug, PartialEq, Eq, Clone)]
70#[serde(bound(serialize = "S: Serialize", deserialize = "S: Deserialize<'de>"))]
71pub enum Status<S> {
72 Pending,
73 Running(S),
74 Completed,
75 Aborted,
76}
77
78impl<S> Status<S> {
79 pub fn map<F, U>(self, f: F) -> Status<U>
80 where
81 F: FnOnce(S) -> U,
82 {
83 match self {
84 Status::Pending => Status::Pending,
85 Status::Running(v) => Status::Running(f(v)),
86 Status::Completed => Status::Completed,
87 Status::Aborted => Status::Aborted,
88 }
89 }
90}
91
92impl<S> Status<S> {
93 pub async fn with_state<'a, V, R, F>(&'a self, func: F) -> Result<V, Error>
94 where
95 F: FnOnce(&S) -> R + 'a,
96 R: Future<Output = V> + 'a,
97 {
98 let future = {
99 match self {
100 Status::Pending => Err(Error::Pending),
101 Status::Running(state) => Ok(func(state)),
102 Status::Completed => Err(Error::Completed),
103 Status::Aborted => Err(Error::Aborted),
104 }
105 };
106 match future {
107 Ok(future) => Ok(future.await),
108 Err(e) => Err(e),
109 }
110 }
111
112 pub fn running(&self) -> bool {
113 match self {
114 Status::Running(_) => true,
115 _ => false,
116 }
117 }
118
119 pub fn pending(&self) -> bool {
120 match self {
121 Status::Pending => true,
122 _ => false,
123 }
124 }
125}
126
127pub struct AsyncTask<T, S> {
128 abort_tx: Sender<()>,
129 output_rx: Receiver<Result<T, Error>>,
130 status: Arc<Notify<Status<S>>>,
131 abort_on_drop: bool,
132}
133
134impl<T, S> Default for AsyncTask<T, S> {
135 fn default() -> Self {
136 let (abort_tx, _) = mpsc::channel::<()>(1);
137 let (_, output_rx) = mpsc::channel::<Result<T, Error>>(1);
138
139 let status = Arc::new(Notify::new(Status::Pending));
140
141 AsyncTask {
142 abort_tx,
143 output_rx,
144 status,
145 abort_on_drop: true,
146 }
147 }
148}
149
150impl<T, S: 'static> AsyncTask<T, S> {
151 pub fn with_timeout(mut self, timeout: Duration) -> Self {
152 let status = Arc::get_mut(&mut self.status).unwrap();
153 status.set_timeout(timeout);
154 self
155 }
156}
157
158impl<T: MaybeSend + 'static, S: MaybeSend + Sync + 'static> AsyncTask<T, S> {
159 pub fn spawn<F: FnOnce(&S) -> U, U: Future<Output = T> + MaybeSend + 'static>(
160 &mut self,
161 state: S,
162 func: F,
163 ) {
164 let (abort_tx, mut abort_rx) = mpsc::channel::<()>(1);
165 let (output_tx, output_rx) = mpsc::channel::<Result<T, Error>>(1);
166 let future = func(&state);
167
168 self.abort_tx = abort_tx;
169 self.output_rx = output_rx;
170
171 spawn_platform({
172 let status = self.status.clone();
173 async move {
174 {
175 let mut lock = status.write().await;
176 *lock = Status::Running(state);
177 lock.notify();
178 if let Err(e) = lock.not_cloned(status.get_timeout()).await {
179 tracing::error!(
180 "Timeout waiting for writeable lock when starting task: {:?}",
181 e
182 );
183 }
184 };
185
186 let abort = async move {
187 if let None = abort_rx.recv().await {
188 future::pending::<()>().await;
189 }
190 };
191 let result = select! {
192 r = future => {
193 if let Ok(_) = output_tx.try_send(Ok(r)) {
194 Status::Completed
195 } else {
196 Status::Aborted
197 }
198 },
199 _ = abort => {
200 if let Ok(_) = output_tx.try_send(Err(Error::Aborted)) {
201 Status::Aborted
202 } else {
203 Status::Completed
204 }
205 },
206 };
207 {
208 *status.write().await = result
209 };
210 }
211 });
212 }
213}
214
215impl<T, S> Drop for AsyncTask<T, S> {
216 fn drop(&mut self) {
217 if self.abort_on_drop {
218 self.abort();
219 }
220 }
221}
222impl<A: Abortable, D: Deref<Target = A> + DerefMut> Abortable for D {
223 fn abort(&self) {
224 self.deref().abort()
225 }
226
227 fn set_abort_on_drop(&mut self, abort: bool) {
228 self.deref_mut().set_abort_on_drop(abort);
229 }
230}
231impl<T, S> Abortable for AsyncTask<T, S> {
232 fn abort(&self) {
233 let _ = self.abort_tx.try_send(());
234 }
235
236 fn set_abort_on_drop(&mut self, abort: bool) {
237 self.abort_on_drop = abort;
238 }
239}
240impl<T, S: Send + Sync + 'static> Task<S> for AsyncTask<T, S> {
241 type AsyncLock = crate::notify::Notify<Status<S>>;
242
243 fn status(&self) -> &Arc<Self::AsyncLock> {
244 &self.status
245 }
246
247 async fn running_status(&self) -> Result<NotifyArc<Status<S>>, TaskStatusError> {
248 let mut sub = self.status.subscribe().await;
249 while let Some(next) = sub.next().await {
250 let next = next?;
251 if next.running() {
252 return Ok(next);
253 }
254 }
255 Err(TaskStatusError::Finished)
256 }
257}
258
259impl<T: Send, S: Send + Sync + 'static> IsRunning for AsyncTask<T, S> {
260 fn running(&self) -> bool {
261 !self.output_rx.is_closed()
262 }
263}
264
265impl<T, S> Joinable<T> for AsyncTask<T, S> {
266 async fn join(&mut self) -> Result<T, Error> {
267 match self.output_rx.recv().await {
268 Some(r) => r,
269 None => Err(Error::Aborted),
270 }
271 }
272}
273
274pub fn spawn_with_state<
275 S: MaybeSend + Sync + 'static,
276 F: FnOnce(&S) -> U,
277 U: Future<Output = ()> + MaybeSend + 'static,
278>(
279 state: S,
280 func: F,
281) -> AsyncTask<(), S> {
282 spawn(state, func)
283}
284
285pub fn spawn_with_value<T: MaybeSend + 'static, U: Future<Output = T> + MaybeSend + 'static>(
286 future: U,
287) -> AsyncTask<T, ()> {
288 spawn((), |_| future)
289}
290
291pub fn spawn<
292 T: MaybeSend + 'static,
293 S: MaybeSend + Sync + 'static,
294 F: FnOnce(&S) -> U,
295 U: Future<Output = T> + MaybeSend + 'static,
296>(
297 state: S,
298 func: F,
299) -> AsyncTask<T, S> {
300 let mut task: AsyncTask<T, S> = Default::default();
301 task.spawn(state, func);
302 task
303}
304
305#[cfg(not(target_family = "wasm"))]
306fn spawn_platform<F: Future<Output = ()> + MaybeSend + 'static>(future: F) {
307 tokio::task::spawn(future);
308}
309
/// Runs `future` through `wasm_bindgen_futures::spawn_local` on wasm targets.
#[cfg(target_family = "wasm")]
fn spawn_platform<Fut: Future<Output = ()> + MaybeSend + 'static>(future: Fut) {
    wasm_bindgen_futures::spawn_local(future);
}
314
#[cfg(test)]
mod test {
    use std::time::Duration;

    use super::*;
    use tracing_test::traced_test;

    /// Spawns a task, joins it, then spawns again on the same handle and
    /// checks the result and final status both times.
    #[tokio::test]
    #[traced_test]
    async fn test_simple() {
        // First run: the state is the sleep duration in milliseconds.
        let mut task = spawn(10, |delay_ms| {
            let delay_ms = *delay_ms;
            async move {
                tokio::time::sleep(Duration::from_millis(delay_ms)).await;
                11
            }
        });
        let first = task.join().await.unwrap();
        assert_eq!(first, 11);
        assert_eq!(*task.status().read().await.deref(), Status::Completed);

        // Second run on the same handle.
        task.spawn(12, |delay_ms| {
            let delay_ms = *delay_ms;
            async move {
                tokio::time::sleep(Duration::from_millis(delay_ms)).await;
                13
            }
        });
        let second = task.join().await.unwrap();
        assert_eq!(second, 13);
        assert_eq!(*task.status().read().await.deref(), Status::Completed);
        drop(task);
    }
}