// singleton_task/lib.rs

1pub use std::sync::mpsc::{Receiver, SyncSender};
2use std::{error::Error, fmt::Display, sync::mpsc::sync_channel};
3
4pub use async_trait::async_trait;
5use tokio::select;
6
7use context::{FutureTaskState, State};
8use log::{trace, warn};
9
10mod context;
11mod task_chan;
12
13pub use context::Context;
14use task_chan::{TaskReceiver, TaskSender, task_channel};
15
/// Marker trait for the error type carried by tasks in this crate.
///
/// A type qualifies when it is a standard [`Error`], is `Clone`, can be sent
/// across threads, and owns all of its data. The trait has no items, so an
/// empty `impl TError for MyError {}` is all that is needed.
pub trait TError
where
    Self: Error + Clone + Send + 'static,
{
}
17
18#[derive(Debug, Clone)]
19pub enum TaskError<E: TError> {
20    Cancelled,
21    Error(E),
22}
23
24impl<E: TError> From<E> for TaskError<E> {
25    fn from(value: E) -> Self {
26        Self::Error(value)
27    }
28}
29
30impl<E: TError> Display for TaskError<E> {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        match self {
33            Self::Cancelled => write!(f, "Cancelled"),
34            Self::Error(e) => write!(f, "{}", e),
35        }
36    }
37}
38
39pub trait TaskBuilder {
40    type Output: Send + 'static;
41    type Error: TError;
42    type Task: Task<Self::Error>;
43
44    fn build(self, tx: SyncSender<Self::Output>) -> Self::Task;
45    fn channel_size(&self) -> usize {
46        10
47    }
48}
49
50#[async_trait]
51pub trait Task<E: TError>: Send + 'static {
52    async fn on_start(&mut self, ctx: Context<E>) -> Result<(), E> {
53        drop(ctx);
54        trace!("on_start");
55        Ok(())
56    }
57    async fn on_stop(&mut self, ctx: Context<E>) -> Result<(), E> {
58        drop(ctx);
59        trace!("on_stop");
60        Ok(())
61    }
62}
63
/// A boxed, type-erased task together with the context that tracks its
/// lifecycle.
struct TaskBox<E: TError> {
    /// The user's task behind a trait object, so tasks built by different
    /// `TaskBuilder` types can travel through the same worker queue.
    task: Box<dyn Task<E>>,
    /// Lifecycle context; `SingletonTask::start` hands a clone of it to the
    /// returned `TaskHandle`.
    ctx: Context<E>,
}
68
/// A task queued by `SingletonTask::start`, waiting for the background worker
/// to run it.
struct WaitingTask<E: TError> {
    /// The task and its lifecycle context.
    task: TaskBox<E>,
}
72
/// Runs tasks one at a time on a background tokio worker.
///
/// Tasks are submitted with `SingletonTask::start`. Cloning produces another
/// handle that submits into the same queue. This assumes `TaskSender`'s `Clone`
/// shares the underlying channel — confirm in `task_chan`.
#[derive(Clone)]
pub struct SingletonTask<E: TError> {
    /// Sending half of the queue drained by the background worker.
    tx: TaskSender<E>,
}
77
impl<E: TError> SingletonTask<E> {
    /// Creates a new task runner and spawns its background worker.
    ///
    /// # Panics
    ///
    /// Calls `tokio::spawn`, which panics when called outside a tokio runtime.
    pub fn new() -> Self {
        let (tx, rx) = task_channel::<E>();

        // The spawned worker owns the receiving half of the queue; `Self` keeps the sender.
        tokio::spawn(Self::work_deal_start(rx));

        Self { tx }
    }

    /// Background worker loop: receives queued tasks and runs them strictly
    /// one after another.
    ///
    /// `work_start_task` drives each task through its whole lifecycle before
    /// the next one is received. Errors are logged with the task id and do not
    /// end the loop. The loop ends when `rx.recv()` yields `None`, presumably
    /// once every sender has been dropped — confirm in `task_chan`.
    async fn work_deal_start(rx: TaskReceiver<E>) {
        while let Some(next) = rx.recv().await {
            // Read the id before `next` is moved into `work_start_task`, for the log line below.
            let id = next.task.ctx.id();
            if let Err(e) = Self::work_start_task(next).await {
                warn!("task [{}] error: {}", id, e);
            }
        }
        trace!("task work done");
    }

    /// Runs one task from start to finish.
    ///
    /// Lifecycle as implemented here:
    /// 1. `on_start` is raced against `ctx.wait_for(State::Stopping)` with
    ///    `select!`. Whichever completes first wins and the other future is
    ///    dropped, so `on_start` may be cancelled mid-flight.
    /// 2. An `Ok` result moves the context to `State::Running`. An `Err(e)`
    ///    stops the context with `e`.
    /// 3. The worker then waits for `State::Stopping`, calls `on_stop`, calls
    ///    `ctx.work_done()`, and waits for `State::Stopped`. The results of
    ///    these steps are ignored.
    ///
    /// # Errors
    ///
    /// Returns `TaskError::Cancelled` if `ctx.set_state(State::Running)` fails.
    /// In that case step 3 is skipped entirely.
    async fn work_start_task(next: WaitingTask<E>) -> Result<(), TaskError<E>> {
        trace!("run task {}", next.task.ctx.id());
        let ctx = next.task.ctx.clone();
        let mut task = next.task.task;
        match select! {
            res = task.on_start(ctx.clone()) => res.map_err(|e|e.into()),
            res = ctx.wait_for(State::Stopping) => res
        } {
            Ok(_) => {
                // NOTE(review): this arm is also taken when the `Stopping` wait
                // resolves with `Ok` before `on_start` finishes. If the move to
                // `Running` is then rejected, the early return below skips
                // `on_stop` and `ctx.work_done()`. Confirm that `Context` still
                // reaches `State::Stopped` on that path; otherwise waiters on
                // `Stopped` may never be woken.
                if ctx.set_state(State::Running).is_err() {
                    return Err(TaskError::Cancelled);
                };
            }
            Err(e) => {
                ctx.stop_with_result(Some(e));
            }
        }

        // Park until the context reaches `Stopping`, presumably via a stop
        // request or the `stop_with_result` call above.
        let _ = ctx.wait_for(State::Stopping).await;
        let _ = task.on_stop(ctx.clone()).await;
        ctx.work_done();
        // Wait for the final state before returning, so the worker loop cannot
        // pick up the next queued task while this one is still shutting down.
        let _ = ctx.wait_for(State::Stopped).await;
        trace!("task {} stopped", ctx.id());
        Ok(())
    }

    /// Builds a task from `task_builder`, queues it for the worker, and waits
    /// until it reaches `State::Running`.
    ///
    /// The task receives the sending half of a bounded `std::sync::mpsc`
    /// channel with capacity `task_builder.channel_size()`. The returned
    /// [`TaskHandle`] holds the receiving half and a clone of the task's
    /// context.
    ///
    /// The worker processes queued tasks one at a time, so this task cannot
    /// start until the worker has finished any task ahead of it.
    ///
    /// # Errors
    ///
    /// Propagates the error from `ctx.wait_for(State::Running)`. `Context`
    /// decides which errors are possible (presumably a failed `on_start` or a
    /// cancellation).
    pub async fn start<T: TaskBuilder<Error = E>>(
        &self,
        task_builder: T,
    ) -> Result<TaskHandle<T::Output, E>, TaskError<E>> {
        let channel_size = task_builder.channel_size();
        let (tx, rx) = sync_channel::<T::Output>(channel_size);
        let task = Box::new(task_builder.build(tx));
        let task_box = TaskBox {
            task,
            ctx: Context::default(),
        };
        // Keep a clone so the caller can observe the task after it moves into the queue.
        let ctx = task_box.ctx.clone();

        // NOTE(review): any value returned by `TaskSender::send` is discarded.
        // If the send can fail (e.g. the worker is gone), confirm that the
        // `wait_for` below still resolves.
        self.tx.send(WaitingTask { task: task_box });

        ctx.wait_for(State::Running).await?;

        Ok(TaskHandle { rx, ctx })
    }
}
143
144impl<E: TError> Default for SingletonTask<E> {
145    fn default() -> Self {
146        Self::new()
147    }
148}
149
/// Handle to a started task, returned by `SingletonTask::start`.
pub struct TaskHandle<T, E: TError> {
    /// Receiving half of the task's bounded output channel; its capacity comes
    /// from `TaskBuilder::channel_size`.
    pub rx: Receiver<T>,
    /// The task's lifecycle context, a clone of the one the worker drives.
    pub ctx: Context<E>,
}
154
155impl<T, E: TError> TaskHandle<T, E> {
156    pub fn stop(self) -> FutureTaskState<E> {
157        self.ctx.stop()
158    }
159    pub fn wait_for_stopped(self) -> impl Future<Output = Result<(), TaskError<E>>> {
160        self.ctx.wait_for(State::Stopped)
161    }
162
163    pub fn recv(&self) -> Result<T, std::sync::mpsc::RecvError> {
164        self.rx.recv()
165    }
166}