use futures::Future;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use thiserror::Error;
use workflow_core::channel::{
oneshot, Receiver, RecvError, SendError, Sender, TryRecvError, TrySendError,
};
pub use workflow_task_macros::{set_task, task};
/// Errors reported by the task API in this module.
///
/// `Display` messages are generated by `thiserror` from the `#[error(...)]`
/// attributes on each variant.
#[derive(Debug, Error)]
pub enum TaskError {
/// The task's running flag was not set when `join()` or `stop_and_join()` was called.
#[error("The task is not running")]
NotRunning,
/// The task is already running.
#[error("The task is already running")]
AlreadyRunning,
/// A channel send failed; carries the underlying error rendered as a string.
#[error("Task channel send error {0}")]
SendError(String),
/// A channel receive failed; converted automatically from `RecvError` via `#[from]`.
#[error("Task channel receive error: {0:?}")]
RecvError(#[from] RecvError),
/// A non-blocking channel send failed; carries the error rendered as a string.
#[error("Task channel try send error: {0}")]
TrySendError(String),
/// A non-blocking channel receive failed; converted automatically from `TryRecvError` via `#[from]`.
#[error("Task channel try receive {0:?}")]
TryRecvError(#[from] TryRecvError),
}
impl<T> From<SendError<T>> for TaskError {
fn from(err: SendError<T>) -> Self {
TaskError::SendError(err.to_string())
}
}
/// Lets `?` turn a channel `TrySendError<T>` into a [`TaskError`].
///
/// The error is rendered to a string and stored in the dedicated
/// [`TaskError::TrySendError`] variant, so a failed non-blocking send is
/// reported as such and is distinguishable from a failed awaited send
/// (`TaskError::SendError`).
impl<T> From<TrySendError<T>> for TaskError {
    fn from(err: TrySendError<T>) -> Self {
        TaskError::TrySendError(err.to_string())
    }
}
/// Result alias used by the task API; the error type is always [`TaskError`].
pub type TaskResult<T> = std::result::Result<T, TaskError>;
/// Shared, type-erased task function.
///
/// It is called with the task argument `A` and a `Receiver<()>` (`TaskInner::run`
/// passes a clone of the termination receiver) and returns the boxed future
/// ([`FnReturn`]) that produces the task result `T`.
pub type TaskFn<A, T> = Arc<Box<dyn Send + Sync + Fn(A, Receiver<()>) -> FnReturn<T> + 'static>>;
/// Pinned, boxed future returned by a task function; it must be
/// `Send + Sync + 'static` and resolves to `T`.
pub type FnReturn<T> = Pin<Box<(dyn Send + Sync + 'static + Future<Output = T>)>>;
/// Shared state behind a [`Task`] handle.
struct TaskInner<A, T>
where
A: Send,
T: 'static,
{
/// Termination channel pair: `stop()` / `stop_and_join()` send on `.0`,
/// and a clone of `.1` is handed to the task function by `run()`.
termination: (Sender<()>, Receiver<()>),
/// Completion channel pair: the spawned task sends its result on `.0`,
/// and `join()` / `stop_and_join()` receive it from `.1`.
completion: (Sender<T>, Receiver<T>),
/// Flag reported by `is_running()`; it gates `stop()`, `join()` and `stop_and_join()`.
running: Arc<AtomicBool>,
/// Currently installed task function; `None` for a blank task until one is set.
task_fn: Arc<Mutex<Option<TaskFn<A, T>>>>,
/// Marker tying the argument type `A` to the struct (no `A` value is stored).
args: PhantomData<A>,
}
/// Implementation of the task life cycle (`run` / `stop` / `join`) on the
/// state shared by all clones of a task handle.
impl<A, T> TaskInner<A, T>
where
    A: Send + Sync + 'static,
    T: Send + 'static,
{
    /// Creates the shared state with `task_fn` installed as the task function.
    ///
    /// Fresh termination and completion channels are created with `oneshot()`
    /// and the running flag starts out cleared.
    fn new_with_boxed_task_fn<FN>(task_fn: Box<FN>) -> Self
    where
        FN: Send + Sync + Fn(A, Receiver<()>) -> FnReturn<T> + 'static,
    {
        let termination = oneshot();
        let completion = oneshot();
        TaskInner {
            termination,
            completion,
            running: Arc::new(AtomicBool::new(false)),
            task_fn: Arc::new(Mutex::new(Some(Arc::new(task_fn)))),
            args: PhantomData,
        }
    }

    /// Creates the shared state without a task function.
    ///
    /// A function must be installed with `set_boxed_task_fn` before `run()`
    /// is called, otherwise `run()` panics.
    pub fn blank() -> Self {
        let termination = oneshot();
        let completion = oneshot();
        TaskInner {
            termination,
            completion,
            running: Arc::new(AtomicBool::new(false)),
            task_fn: Arc::new(Mutex::new(None)),
            args: PhantomData,
        }
    }

    /// Returns a clone of the installed task function.
    ///
    /// # Panics
    ///
    /// Panics if no task function has been installed, or if the mutex
    /// guarding it is poisoned.
    fn task_fn(&self) -> TaskFn<A, T> {
        self.task_fn
            .lock()
            .unwrap()
            .as_ref()
            .expect("Task::task_fn is not initialized")
            .clone()
    }

    /// Installs (or replaces) the task function used by subsequent `run()` calls.
    ///
    /// # Panics
    ///
    /// Panics if the mutex guarding the task function is poisoned.
    fn set_boxed_task_fn(
        &self,
        task_fn: Box<dyn Send + Sync + Fn(A, Receiver<()>) -> FnReturn<T> + 'static>,
    ) {
        let task_fn = Arc::new(task_fn);
        *self.task_fn.lock().unwrap() = Some(task_fn);
    }

    /// Spawns the task function with `args` and returns immediately.
    ///
    /// The running flag is claimed here, *before* the future is spawned, so
    /// that `stop()`, `join()` and `stop_and_join()` called right after
    /// `run()` already observe the task as running, even if the spawned
    /// future has not been polled yet. The spawned future clears the flag
    /// once the task function's future resolves and then publishes the
    /// result on the completion channel.
    ///
    /// # Errors
    ///
    /// Returns [`TaskError::AlreadyRunning`] if the running flag is already
    /// set, i.e. a previous run has not finished yet.
    ///
    /// # Panics
    ///
    /// Panics if the completion or termination channel still holds a message
    /// (for example the result of a previous run that was never joined), or
    /// if no task function has been installed.
    pub fn run<'l>(self: &'l Arc<Self>, args: A) -> TaskResult<&'l Arc<Self>> {
        if !self.completion.1.is_empty() {
            panic!("Task::run(): task completion channel is not empty");
        }
        if !self.termination.1.is_empty() {
            panic!("Task::run(): task termination channel is not empty");
        }

        // Resolve the task function first: `task_fn()` may panic, and that
        // must not happen after the running flag has been claimed.
        let cb = self.task_fn();

        // Atomically claim the flag; a previous value of `true` means another
        // run is still in progress and this call must not start a second one.
        if self.running.swap(true, Ordering::SeqCst) {
            return Err(TaskError::AlreadyRunning);
        }

        let this = Arc::clone(self);
        workflow_core::task::spawn(async move {
            let result = cb(args, this.termination.1.clone()).await;
            // Clear the flag before publishing the result so that a caller
            // returning from `join()` can immediately `run()` again.
            this.running.store(false, Ordering::SeqCst);
            this.completion
                .0
                .send(result)
                .await
                .expect("Error signaling task completion");
        });

        Ok(self)
    }

    /// Requests termination without waiting for the task to finish.
    ///
    /// Does nothing (and returns `Ok`) when the running flag is not set.
    ///
    /// # Errors
    ///
    /// Returns the [`TaskError`] converted from the channel's `TrySendError`
    /// if the termination signal cannot be queued without blocking.
    pub fn stop(&self) -> TaskResult<()> {
        if self.running.load(Ordering::SeqCst) {
            self.termination.0.try_send(())?;
        }
        Ok(())
    }

    /// Waits for the running task to finish and returns its result.
    ///
    /// NOTE(review): a task that has already finished by the time `join()` is
    /// called reports `NotRunning` here while its result stays queued in the
    /// completion channel, which makes the next `run()` panic.
    ///
    /// # Errors
    ///
    /// Returns [`TaskError::NotRunning`] if the running flag is not set, or
    /// [`TaskError::RecvError`] if receiving the result fails.
    pub async fn join(&self) -> TaskResult<T> {
        if !self.running.load(Ordering::SeqCst) {
            return Err(TaskError::NotRunning);
        }
        Ok(self.completion.1.recv().await?)
    }

    /// Sends the termination signal and then waits for the task result.
    ///
    /// # Errors
    ///
    /// Returns [`TaskError::NotRunning`] if the running flag is not set,
    /// [`TaskError::SendError`] if the termination signal cannot be sent, or
    /// [`TaskError::RecvError`] if receiving the result fails.
    pub async fn stop_and_join(&self) -> TaskResult<T> {
        if !self.running.load(Ordering::SeqCst) {
            return Err(TaskError::NotRunning);
        }
        self.termination.0.send(()).await?;
        Ok(self.completion.1.recv().await?)
    }

    /// Returns the current value of the running flag.
    pub fn is_running(&self) -> bool {
        self.running.load(Ordering::SeqCst)
    }
}
/// Handle to a task.
///
/// The handle only holds an `Arc` to the shared `TaskInner` state, so clones
/// of a `Task` refer to the same underlying state.
///
/// NOTE(review): `#[derive(Clone)]` adds `A: Clone` and `T: Clone` bounds to
/// the generated impl even though only the `Arc` is cloned.
#[derive(Clone)]
pub struct Task<A, T>
where
A: Send,
T: 'static,
{
/// Shared task state; all public methods delegate to it.
inner: Arc<TaskInner<A, T>>,
}
impl<A, T> Default for Task<A, T>
where
A: Send + Sync + 'static,
T: Send + Sync + 'static,
{
fn default() -> Self {
Task::blank()
}
}
/// Public task API; every method forwards to the shared `TaskInner` state.
impl<A, T> Task<A, T>
where
    A: Send + Sync + 'static,
    T: Send + 'static,
{
    /// Creates a task around `task_fn`.
    ///
    /// `task_fn` receives the argument passed to `run()` plus a `Receiver<()>`
    /// and returns the boxed future that produces the task result.
    pub fn new<FN>(task_fn: FN) -> Task<A, T>
    where
        FN: Send + Sync + Fn(A, Receiver<()>) -> FnReturn<T> + 'static,
    {
        let boxed = Box::new(task_fn);
        Self::new_with_boxed_task_fn(boxed)
    }

    /// Creates a task from an already boxed task function.
    fn new_with_boxed_task_fn<FN>(task_fn: Box<FN>) -> Task<A, T>
    where
        FN: Send + Sync + Fn(A, Receiver<()>) -> FnReturn<T> + 'static,
    {
        let inner = Arc::new(TaskInner::new_with_boxed_task_fn(task_fn));
        Self { inner }
    }

    /// Creates a task without a task function; install one with
    /// `set_task_fn` before calling `run()`.
    pub fn blank() -> Self {
        let inner = Arc::new(TaskInner::blank());
        Self { inner }
    }

    /// Installs (or replaces) the task function used by subsequent runs.
    pub fn set_task_fn<FN>(&self, task_fn: FN)
    where
        FN: Send + Sync + Fn(A, Receiver<()>) -> FnReturn<T> + 'static,
    {
        let boxed = Box::new(task_fn);
        self.inner.set_boxed_task_fn(boxed);
    }

    /// Starts the task with `args` and returns this handle on success.
    ///
    /// Any error reported by the inner `run` is propagated unchanged.
    pub fn run(&self, args: A) -> TaskResult<&Self> {
        self.inner.run(args).map(|_| self)
    }

    /// Requests termination without waiting for the task to finish.
    pub fn stop(&self) -> TaskResult<()> {
        let inner = &self.inner;
        inner.stop()
    }

    /// Waits for the task to finish and returns its result.
    pub async fn join(&self) -> TaskResult<T> {
        let inner = &self.inner;
        inner.join().await
    }

    /// Requests termination, then waits for the task result.
    pub async fn stop_and_join(&self) -> TaskResult<T> {
        let inner = &self.inner;
        inner.stop_and_join().await
    }

    /// Reports whether the task's running flag is currently set.
    pub fn is_running(&self) -> bool {
        let inner = &self.inner;
        inner.is_running()
    }
}
#[cfg(not(target_arch = "wasm32"))]
#[cfg(test)]
mod test {
    use super::*;
    use std::time::Duration;

    /// Runs the same task twice: the first run is joined while it finishes on
    /// its own, the second run is stopped early and then joined.
    ///
    /// Results of `run()` and `stop()` are checked with `expect` instead of
    /// being discarded, so a failure to start or stop the task fails the test
    /// at the call that went wrong rather than at a later `join()`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    pub async fn test_task() {
        let task = Task::new(|args, stop| -> FnReturn<String> {
            Box::pin(async move {
                println!("starting task... {}", args);
                for i in 0..10 {
                    // Poll the termination receiver once per iteration.
                    if stop.try_recv().is_ok() {
                        println!("stopping task...");
                        break;
                    }
                    println!("t: {}", i);
                    workflow_core::task::sleep(Duration::from_millis(500)).await;
                }
                println!("exiting task...");
                format!("finished {args}")
            })
        });

        // First run: let the task run to completion and join it.
        task.run("- first -").expect("[run1] task start failed");
        for i in 0..5 {
            println!("m: {}", i);
            workflow_core::task::sleep(Duration::from_millis(500)).await;
        }
        let ret1 = task.join().await.expect("[ret1] task wait failed");
        println!("ret1: {:?}", ret1);

        // The task has already been joined, so this stop is expected to be a no-op.
        task.stop().expect("[stop1] task stop failed");

        // Second run: stop the task part-way through, then join it.
        task.run("- second -").expect("[run2] task start failed");
        for i in 0..5 {
            println!("m: {}", i);
            workflow_core::task::sleep(Duration::from_millis(500)).await;
        }
        task.stop().expect("[stop2] task stop failed");
        let ret2 = task.join().await.expect("[ret2] task wait failed");
        println!("ret2: {:?}", ret2);
        println!("done");
    }
}