1use futures::Future;
3use std::marker::PhantomData;
4use std::pin::Pin;
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::sync::{Arc, Mutex};
7use thiserror::Error;
8use workflow_core::channel::{
9 oneshot, Receiver, RecvError, SendError, Sender, TryRecvError, TrySendError,
10};
11pub use workflow_task_macros::{set_task, task};
12
/// Errors reported by task operations and by the channels that back them.
#[derive(Debug, Error)]
pub enum TaskError {
    /// The operation requires a running task, but the running flag was not set.
    #[error("The task is not running")]
    NotRunning,
    /// Indicates that the task is already running.
    #[error("The task is already running")]
    AlreadyRunning,
    /// A channel send failed; holds the channel error rendered as a string
    /// (the generic payload type of the channel error is not preserved).
    #[error("Task channel send error {0}")]
    SendError(String),
    /// Receiving from a channel failed; converted automatically from [`RecvError`].
    #[error("Task channel receive error: {0:?}")]
    RecvError(#[from] RecvError),
    /// A non-blocking channel send failed; holds the error rendered as a string.
    #[error("Task channel try send error: {0}")]
    TrySendError(String),
    /// A non-blocking channel receive failed; converted automatically from [`TryRecvError`].
    #[error("Task channel try receive {0:?}")]
    TryRecvError(#[from] TryRecvError),
}
29
30impl<T> From<SendError<T>> for TaskError {
31 fn from(err: SendError<T>) -> Self {
32 TaskError::SendError(err.to_string())
33 }
34}
35
36impl<T> From<TrySendError<T>> for TaskError {
37 fn from(err: TrySendError<T>) -> Self {
38 TaskError::SendError(err.to_string())
39 }
40}
41
42pub type TaskResult<T> = std::result::Result<T, TaskError>;
44
45pub type TaskFn<A, T> = Arc<Box<dyn Send + Sync + Fn(A, Receiver<()>) -> FnReturn<T> + 'static>>;
46pub type FnReturn<T> = Pin<Box<(dyn Send + Sync + 'static + Future<Output = T>)>>;
47
/// Shared state behind a task handle: the signalling channels, the running
/// flag and the (optional) task function.
struct TaskInner<A, T>
where
    A: Send,
    T: 'static,
{
    /// Termination channel pair. The sender is used by `stop` / `stop_and_join`;
    /// a clone of the receiver is passed to the task function by `run`.
    termination: (Sender<()>, Receiver<()>),
    /// Completion channel pair. The spawned task sends the task function's
    /// return value here; `join` / `stop_and_join` receive it.
    completion: (Sender<T>, Receiver<T>),
    /// Running flag maintained by `run` and read by `stop`, `join`,
    /// `stop_and_join` and `is_running`.
    running: Arc<AtomicBool>,
    /// The task function; `None` for a task created with `blank()` until one is set.
    task_fn: Arc<Mutex<Option<TaskFn<A, T>>>>,
    /// Marker tying the argument type `A` to the struct without storing a value.
    args: PhantomData<A>,
}
59
60impl<A, T> TaskInner<A, T>
61where
62 A: Send + Sync + 'static,
63 T: Send + 'static,
64{
65 fn new_with_boxed_task_fn<FN>(task_fn: Box<FN>) -> Self
66 where
68 FN: Send + Sync + Fn(A, Receiver<()>) -> FnReturn<T> + 'static,
69 {
70 let termination = oneshot();
71 let completion = oneshot();
72
73 TaskInner {
74 termination,
75 completion,
76 running: Arc::new(AtomicBool::new(false)),
77 task_fn: Arc::new(Mutex::new(Some(Arc::new(task_fn)))),
78 args: PhantomData,
79 }
80 }
81
82 pub fn blank() -> Self {
83 let termination = oneshot();
84 let completion = oneshot();
85 TaskInner {
86 termination,
87 completion,
88 running: Arc::new(AtomicBool::new(false)),
89 task_fn: Arc::new(Mutex::new(None)),
90 args: PhantomData,
91 }
92 }
93
94 fn task_fn(&self) -> TaskFn<A, T> {
95 self.task_fn
96 .lock()
97 .unwrap()
98 .as_ref()
99 .expect("Task::task_fn is not initialized")
100 .clone()
101 }
102
103 fn set_boxed_task_fn(
106 &self,
107 task_fn: Box<dyn Send + Sync + Fn(A, Receiver<()>) -> FnReturn<T> + 'static>,
108 ) {
109 let task_fn = Arc::new(task_fn);
110 *self.task_fn.lock().unwrap() = Some(task_fn);
111 }
112
113 pub fn run<'l>(self: &'l Arc<Self>, args: A) -> TaskResult<&'l Arc<Self>> {
114 if !self.completion.1.is_empty() {
115 panic!("Task::run(): task completion channel is not empty");
116 }
117
118 if !self.termination.1.is_empty() {
119 panic!("Task::run(): task termination channel is not empty");
120 }
121
122 let this = self.clone();
123 let cb = self.task_fn();
124 workflow_core::task::spawn(async move {
125 this.running.store(true, Ordering::SeqCst);
126
127 let result = cb(args, this.termination.1.clone()).await;
128 this.running.store(false, Ordering::SeqCst);
129 this.completion
130 .0
131 .send(result)
132 .await
133 .expect("Error signaling task completion");
134 });
135
136 Ok(self)
137 }
138
139 pub fn stop(&self) -> TaskResult<()> {
140 if self.running.load(Ordering::SeqCst) {
141 self.termination.0.try_send(())?;
142 }
143 Ok(())
144 }
145
146 pub async fn join(&self) -> TaskResult<T> {
149 if self.running.load(Ordering::SeqCst) {
150 Ok(self.completion.1.recv().await?)
151 } else {
152 Err(TaskError::NotRunning)
153 }
154 }
155
156 pub async fn stop_and_join(&self) -> TaskResult<T> {
159 if self.running.load(Ordering::SeqCst) {
160 self.termination.0.send(()).await?;
161 Ok(self.completion.1.recv().await?)
162 } else {
163 Err(TaskError::NotRunning)
164 }
165 }
166
167 pub fn is_running(&self) -> bool {
168 self.running.load(Ordering::SeqCst)
169 }
170}
171
172#[derive(Clone)]
223pub struct Task<A, T>
224where
225 A: Send,
226 T: 'static,
227{
228 inner: Arc<TaskInner<A, T>>,
229}
230
231impl<A, T> Default for Task<A, T>
232where
233 A: Send + Sync + 'static,
234 T: Send + Sync + 'static,
235{
236 fn default() -> Self {
237 Task::blank()
238 }
239}
240
241impl<A, T> Task<A, T>
242where
243 A: Send + Sync + 'static,
244 T: Send + 'static,
245{
246 pub fn new<FN>(task_fn: FN) -> Task<A, T>
258 where
259 FN: Send + Sync + Fn(A, Receiver<()>) -> FnReturn<T> + 'static,
260 {
261 Self::new_with_boxed_task_fn(Box::new(task_fn))
262 }
263
264 fn new_with_boxed_task_fn<FN>(task_fn: Box<FN>) -> Task<A, T>
265 where
266 FN: Send + Sync + Fn(A, Receiver<()>) -> FnReturn<T> + 'static,
267 {
268 Task {
269 inner: Arc::new(TaskInner::new_with_boxed_task_fn(task_fn)),
270 }
271 }
272
273 pub fn blank() -> Self {
276 Task {
277 inner: Arc::new(TaskInner::blank()),
278 }
279 }
280
281 pub fn set_task_fn<FN>(&self, task_fn: FN)
286 where
287 FN: Send + Sync + Fn(A, Receiver<()>) -> FnReturn<T> + 'static,
288 {
289 self.inner.set_boxed_task_fn(Box::new(task_fn))
290 }
291
292 pub fn run(&self, args: A) -> TaskResult<&Self> {
295 self.inner.run(args)?;
296 Ok(self)
297 }
298
299 pub fn stop(&self) -> TaskResult<()> {
304 self.inner.stop()
305 }
306
307 pub async fn join(&self) -> TaskResult<T> {
310 self.inner.join().await
311 }
312
313 pub async fn stop_and_join(&self) -> TaskResult<T> {
316 self.inner.stop_and_join().await
317 }
318
319 pub fn is_running(&self) -> bool {
322 self.inner.is_running()
323 }
324}
325
#[cfg(not(target_arch = "wasm32"))]
#[cfg(test)]
mod test {

    use super::*;
    use std::time::Duration;

    /// Caller side of the test: prints five ticks, sleeping 500 ms after each.
    async fn tick_main() {
        for tick in 0..5 {
            println!("m: {}", tick);
            workflow_core::task::sleep(Duration::from_millis(500)).await;
        }
    }

    /// Runs the same task twice: the first run is joined after it finishes on
    /// its own, the second run is stopped early and then joined.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    pub async fn test_task() {
        let worker = Task::new(|args, stop_rx| -> FnReturn<String> {
            Box::pin(async move {
                println!("starting task... {}", args);
                for step in 0..10 {
                    // A pending termination signal ends the loop early.
                    if stop_rx.try_recv().is_ok() {
                        println!("stopping task...");
                        break;
                    }
                    println!("t: {}", step);
                    workflow_core::task::sleep(Duration::from_millis(500)).await;
                }
                println!("exiting task...");
                format!("finished {args}")
            })
        });

        // First run: let the task run to completion.
        worker.run("- first -").ok();
        tick_main().await;

        let ret1 = worker.join().await.expect("[ret1] task wait failed");
        println!("ret1: {:?}", ret1);

        worker.stop().ok();

        // Second run: stop the task part-way through, then collect its result.
        worker.run("- second -").ok();
        tick_main().await;

        worker.stop().ok();
        let ret2 = worker.join().await.expect("[ret2] task wait failed");
        println!("ret2: {:?}", ret2);

        println!("done");
    }
}