workflow_task/
lib.rs

1// use workflow_core::task::*;
2use futures::Future;
3use std::marker::PhantomData;
4use std::pin::Pin;
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::sync::{Arc, Mutex};
7use thiserror::Error;
8use workflow_core::channel::{
9    oneshot, Receiver, RecvError, SendError, Sender, TryRecvError, TrySendError,
10};
11pub use workflow_task_macros::{set_task, task};
12
13/// Errors produced by the [`Task`] implementation
14#[derive(Debug, Error)]
15pub enum TaskError {
16    #[error("The task is not running")]
17    NotRunning,
18    #[error("The task is already running")]
19    AlreadyRunning,
20    #[error("Task channel send error {0}")]
21    SendError(String),
22    #[error("Task channel receive error: {0:?}")]
23    RecvError(#[from] RecvError),
24    #[error("Task channel try send error: {0}")]
25    TrySendError(String),
26    #[error("Task channel try receive {0:?}")]
27    TryRecvError(#[from] TryRecvError),
28}
29
30impl<T> From<SendError<T>> for TaskError {
31    fn from(err: SendError<T>) -> Self {
32        TaskError::SendError(err.to_string())
33    }
34}
35
36impl<T> From<TrySendError<T>> for TaskError {
37    fn from(err: TrySendError<T>) -> Self {
38        TaskError::SendError(err.to_string())
39    }
40}
41
42/// Result type used by the [`Task`] implementation
43pub type TaskResult<T> = std::result::Result<T, TaskError>;
44
45pub type TaskFn<A, T> = Arc<Box<dyn Send + Sync + Fn(A, Receiver<()>) -> FnReturn<T> + 'static>>;
46pub type FnReturn<T> = Pin<Box<(dyn Send + Sync + 'static + Future<Output = T>)>>;
47
48struct TaskInner<A, T>
49where
50    A: Send,
51    T: 'static,
52{
53    termination: (Sender<()>, Receiver<()>),
54    completion: (Sender<T>, Receiver<T>),
55    running: Arc<AtomicBool>,
56    task_fn: Arc<Mutex<Option<TaskFn<A, T>>>>,
57    args: PhantomData<A>,
58}
59
60impl<A, T> TaskInner<A, T>
61where
62    A: Send + Sync + 'static,
63    T: Send + 'static,
64{
65    fn new_with_boxed_task_fn<FN>(task_fn: Box<FN>) -> Self
66    //TaskInner<A, T>
67    where
68        FN: Send + Sync + Fn(A, Receiver<()>) -> FnReturn<T> + 'static,
69    {
70        let termination = oneshot();
71        let completion = oneshot();
72
73        TaskInner {
74            termination,
75            completion,
76            running: Arc::new(AtomicBool::new(false)),
77            task_fn: Arc::new(Mutex::new(Some(Arc::new(task_fn)))),
78            args: PhantomData,
79        }
80    }
81
82    pub fn blank() -> Self {
83        let termination = oneshot();
84        let completion = oneshot();
85        TaskInner {
86            termination,
87            completion,
88            running: Arc::new(AtomicBool::new(false)),
89            task_fn: Arc::new(Mutex::new(None)),
90            args: PhantomData,
91        }
92    }
93
94    fn task_fn(&self) -> TaskFn<A, T> {
95        self.task_fn
96            .lock()
97            .unwrap()
98            .as_ref()
99            .expect("Task::task_fn is not initialized")
100            .clone()
101    }
102
103    /// Replace task fn with an alternate function.
104    /// The passed function must be boxed.
105    fn set_boxed_task_fn(
106        &self,
107        task_fn: Box<dyn Send + Sync + Fn(A, Receiver<()>) -> FnReturn<T> + 'static>,
108    ) {
109        let task_fn = Arc::new(task_fn);
110        *self.task_fn.lock().unwrap() = Some(task_fn);
111    }
112
113    pub fn run<'l>(self: &'l Arc<Self>, args: A) -> TaskResult<&'l Arc<Self>> {
114        if !self.completion.1.is_empty() {
115            panic!("Task::run(): task completion channel is not empty");
116        }
117
118        if !self.termination.1.is_empty() {
119            panic!("Task::run(): task termination channel is not empty");
120        }
121
122        let this = self.clone();
123        let cb = self.task_fn();
124        workflow_core::task::spawn(async move {
125            this.running.store(true, Ordering::SeqCst);
126
127            let result = cb(args, this.termination.1.clone()).await;
128            this.running.store(false, Ordering::SeqCst);
129            this.completion
130                .0
131                .send(result)
132                .await
133                .expect("Error signaling task completion");
134        });
135
136        Ok(self)
137    }
138
139    pub fn stop(&self) -> TaskResult<()> {
140        if self.running.load(Ordering::SeqCst) {
141            self.termination.0.try_send(())?;
142        }
143        Ok(())
144    }
145
146    /// Blocks until the task exits. Resolves immediately
147    /// if the task is not running.
148    pub async fn join(&self) -> TaskResult<T> {
149        if self.running.load(Ordering::SeqCst) {
150            Ok(self.completion.1.recv().await?)
151        } else {
152            Err(TaskError::NotRunning)
153        }
154    }
155
156    /// Signals termination and blocks until the
157    /// task exits.
158    pub async fn stop_and_join(&self) -> TaskResult<T> {
159        if self.running.load(Ordering::SeqCst) {
160            self.termination.0.send(()).await?;
161            Ok(self.completion.1.recv().await?)
162        } else {
163            Err(TaskError::NotRunning)
164        }
165    }
166
167    pub fn is_running(&self) -> bool {
168        self.running.load(Ordering::SeqCst)
169    }
170}
171
172/// [`Task`]{self::Task} struct allows you to spawn an async fn that can run
173/// in a loop as a task (similar to a thread), checking for a
174/// termination signal (so that execution can be aborted),
175/// upon completion returning a value to the creator.
176///
177/// You can pass a [`channel`](workflow_core::channel::Receiver) as an argument to the async
178/// function if you wish to communicate with the task.
179///
180/// NOTE: You should always call `task.join().await` to await
181/// for the task completion if re-using the task.
182///
183/// ```rust
184/// use workflow_task::{task, TaskResult};
185///
186/// # #[tokio::test]
187/// # async fn test()->TaskResult<()>{
188///
189/// let task = task!(
190///     |args : (), stop : Receiver<()>| async move {
191///         let mut index = args;
192///         loop {
193///             if stop.try_recv().is_ok() {
194///                 break;
195///             }
196///             // ... do something ...
197///             index += 1;
198///         }
199///         return index;
200///     }
201/// );
202///
203/// // spawn the task instance ...
204/// // passing 256 as the `args` argument
205/// task.run(256)?;
206///
207/// // signal termination ...
208/// task.stop()?;
209///
210/// // await for the task completion ...
211/// // the `result` is the returned `index` value
212/// let result = task.join().await?;
213///
214/// // rinse and repeat if needed
215/// task.run(256)?;
216///
217/// # Ok(())
218/// # }
219///
220/// ```
221///
222#[derive(Clone)]
223pub struct Task<A, T>
224where
225    A: Send,
226    T: 'static,
227{
228    inner: Arc<TaskInner<A, T>>,
229}
230
231impl<A, T> Default for Task<A, T>
232where
233    A: Send + Sync + 'static,
234    T: Send + Sync + 'static,
235{
236    fn default() -> Self {
237        Task::blank()
238    }
239}
240
241impl<A, T> Task<A, T>
242where
243    A: Send + Sync + 'static,
244    T: Send + 'static,
245{
246    ///
247    /// Create a new [`Task`](self::Task) instance by supplying it with
248    /// an async closure that has 2 arguments:
249    /// ```rust
250    /// use workflow_task::task;
251    ///
252    /// task!(|args:bool, signal| async move {
253    ///     // ...
254    ///     return true;
255    /// });
256    /// ```
257    pub fn new<FN>(task_fn: FN) -> Task<A, T>
258    where
259        FN: Send + Sync + Fn(A, Receiver<()>) -> FnReturn<T> + 'static,
260    {
261        Self::new_with_boxed_task_fn(Box::new(task_fn))
262    }
263
264    fn new_with_boxed_task_fn<FN>(task_fn: Box<FN>) -> Task<A, T>
265    where
266        FN: Send + Sync + Fn(A, Receiver<()>) -> FnReturn<T> + 'static,
267    {
268        Task {
269            inner: Arc::new(TaskInner::new_with_boxed_task_fn(task_fn)),
270        }
271    }
272
273    /// Create an instance of the task without any task function.
274    /// The task function can be passed later via [`Task::set_task_fn()`].
275    pub fn blank() -> Self {
276        Task {
277            inner: Arc::new(TaskInner::blank()),
278        }
279    }
280
281    /// Replace task fn with an alternate function.
282    /// The task must be restarted for the replacement
283    /// to take effect.  The function passed does not
284    /// need to be boxed.
285    pub fn set_task_fn<FN>(&self, task_fn: FN)
286    where
287        FN: Send + Sync + Fn(A, Receiver<()>) -> FnReturn<T> + 'static,
288    {
289        self.inner.set_boxed_task_fn(Box::new(task_fn))
290    }
291
292    /// Run the task supplying the provided argument to the
293    /// closure supplied at creation.
294    pub fn run(&self, args: A) -> TaskResult<&Self> {
295        self.inner.run(args)?;
296        Ok(self)
297    }
298
299    /// Signal termination on the channel supplied
300    /// to the task closure; The task has to check
301    /// for the signal periodically or await on
302    /// the future of the signal.
303    pub fn stop(&self) -> TaskResult<()> {
304        self.inner.stop()
305    }
306
307    /// Blocks until the task exits. Resolves immediately
308    /// if the task is not running.
309    pub async fn join(&self) -> TaskResult<T> {
310        self.inner.join().await
311    }
312
313    /// Signals termination and blocks until the
314    /// task exits.
315    pub async fn stop_and_join(&self) -> TaskResult<T> {
316        self.inner.stop_and_join().await
317    }
318
319    /// Returns `true` if the task is running, otherwise
320    /// returns `false`.
321    pub fn is_running(&self) -> bool {
322        self.inner.is_running()
323    }
324}
325
#[cfg(not(target_arch = "wasm32"))]
#[cfg(test)]
mod test {

    use super::*;
    use std::time::Duration;

    /// Runs the same task twice: first to natural completion (joined while
    /// still running), then stopped half way through. Checks the value
    /// returned by each run.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    pub async fn test_task() {
        let task = Task::new(|args, stop| -> FnReturn<String> {
            Box::pin(async move {
                println!("starting task... {}", args);
                for i in 0..10 {
                    if stop.try_recv().is_ok() {
                        println!("stopping task...");
                        break;
                    }
                    println!("t: {}", i);
                    workflow_core::task::sleep(Duration::from_millis(500)).await;
                }
                println!("exiting task...");
                format!("finished {args}")
            })
        });

        // First run: never stopped, so `join()` waits for the loop to end.
        task.run("- first -").expect("[run1] failed to start task");

        for i in 0..5 {
            println!("m: {}", i);
            workflow_core::task::sleep(Duration::from_millis(500)).await;
        }

        let ret1 = task.join().await.expect("[ret1] task wait failed");
        println!("ret1: {:?}", ret1);
        assert_eq!(ret1, "finished - first -");
        assert!(!task.is_running(), "task must not be running after join");

        // The first run has completed, so stopping is a no-op.
        task.stop().expect("[stop1] stopping a finished task failed");

        // Second run: stopped part way through, then joined.
        task.run("- second -").expect("[run2] failed to start task");

        for i in 0..5 {
            println!("m: {}", i);
            workflow_core::task::sleep(Duration::from_millis(500)).await;
        }

        task.stop().expect("[stop2] failed to signal termination");
        let ret2 = task.join().await.expect("[ret2] task wait failed");
        println!("ret2: {:?}", ret2);
        assert_eq!(ret2, "finished - second -");

        println!("done");
    }
}