singleton_task/
lib.rs

1pub use std::sync::mpsc::{Receiver, SyncSender};
2use std::{
3    error::Error,
4    fmt::Display,
5    sync::{Arc, OnceLock, mpsc::sync_channel},
6    thread,
7};
8
9pub use async_trait::async_trait;
10use tokio::{runtime::Runtime, select};
11
12use context::{FutureTaskState, State};
13use log::{trace, warn};
14
15mod context;
16mod task_chan;
17
18pub use context::Context;
19use task_chan::{TaskReceiver, TaskSender, task_channel};
20
21static RT: OnceLock<Runtime> = OnceLock::new();
22
/// Marker trait for error types a task may report.
///
/// Implementors must be `Error + Clone + Send + 'static`. There is no blanket impl, so
/// each error type opts in explicitly with `impl TError for MyError {}`.
pub trait TError: Error + Clone + Send + 'static {}

/// Reason a task did not run to a normal completion.
#[derive(Debug, Clone)]
pub enum TaskError<E: TError> {
    /// The task was stopped before it could proceed (for example, it never reached `Running`).
    Cancelled,
    /// The task itself reported an error.
    Error(E),
}

impl<E: TError> From<E> for TaskError<E> {
    fn from(value: E) -> Self {
        Self::Error(value)
    }
}

impl<E: TError> Display for TaskError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Cancelled => write!(f, "Cancelled"),
            // Forward the caller's formatter so width/alignment/alternate flags reach the
            // inner error instead of being dropped by a fresh `"{}"` format.
            Self::Error(e) => Display::fmt(e, f),
        }
    }
}

/// `TaskError` is a transparent wrapper. `Display` already prints the inner error, so
/// `source` skips it and forwards to the inner error's own source. This keeps
/// error-chain printers from reporting the same message twice.
impl<E: TError> Error for TaskError<E> {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::Cancelled => None,
            Self::Error(e) => e.source(),
        }
    }
}
45
46pub trait TaskBuilder {
47    type Output: Send + 'static;
48    type Error: TError;
49    type Task: Task<Self::Error>;
50
51    fn build(self, tx: SyncSender<Self::Output>) -> Self::Task;
52    fn channel_size(&self) -> usize {
53        10
54    }
55}
56
57#[async_trait]
58pub trait Task<E: TError>: Send + 'static {
59    async fn on_start(&mut self, ctx: Context<E>) -> Result<(), E> {
60        drop(ctx);
61        trace!("on_start");
62        Ok(())
63    }
64    async fn on_stop(&mut self, ctx: Context<E>) -> Result<(), E> {
65        drop(ctx);
66        trace!("on_stop");
67        Ok(())
68    }
69}
70
71struct TaskBox<E: TError> {
72    task: Box<dyn Task<E>>,
73    ctx: Context<E>,
74}
75
76struct WaitingTask<E: TError> {
77    task: TaskBox<E>,
78}
79
80#[derive(Clone)]
81pub struct SingletonTask<E: TError> {
82    tx: TaskSender<E>,
83}
84
85impl<E: TError> SingletonTask<E> {
86    pub fn new() -> Self {
87        let (tx, rx) = task_channel::<E>();
88
89        tokio::spawn(Self::work_deal_start(rx));
90
91        Self { tx }
92    }
93
94    async fn work_deal_start(rx: TaskReceiver<E>) {
95        while let Some(next) = rx.recv().await {
96            let id = next.task.ctx.id();
97            if let Err(e) = Self::work_start_task(next).await {
98                warn!("task [{}] error: {}", id, e);
99            }
100        }
101        trace!("task work done");
102    }
103
104    async fn work_start_task(next: WaitingTask<E>) -> Result<(), TaskError<E>> {
105        trace!("run task {}", next.task.ctx.id());
106        let ctx = next.task.ctx.clone();
107        let mut task = next.task.task;
108        match select! {
109            res = task.on_start(ctx.clone()) => res.map_err(|e|e.into()),
110            res = ctx.wait_for(State::Stopping) => res
111        } {
112            Ok(_) => {
113                if ctx.set_state(State::Running).is_err() {
114                    return Err(TaskError::Cancelled);
115                };
116            }
117            Err(e) => {
118                ctx.stop_with_result(Some(e));
119            }
120        }
121
122        let _ = ctx.wait_for(State::Stopping).await;
123        let _ = task.on_stop(ctx.clone()).await;
124        ctx.work_done();
125        let _ = ctx.wait_for(State::Stopped).await;
126        trace!("task {} stopped", ctx.id());
127        Ok(())
128    }
129
130    pub async fn start<T: TaskBuilder<Error = E>>(
131        &self,
132        task_builder: T,
133    ) -> Result<TaskHandle<T::Output, E>, TaskError<E>> {
134        let channel_size = task_builder.channel_size();
135        let (tx, rx) = sync_channel::<T::Output>(channel_size);
136        let task = Box::new(task_builder.build(tx));
137        let task_box = TaskBox {
138            task,
139            ctx: Context::default(),
140        };
141        let ctx = task_box.ctx.clone();
142
143        self.tx.send(WaitingTask { task: task_box });
144
145        ctx.wait_for(State::Running).await?;
146
147        Ok(TaskHandle { rx, ctx })
148    }
149}
150
151impl<E: TError> Default for SingletonTask<E> {
152    fn default() -> Self {
153        Self::new()
154    }
155}
156
157pub struct TaskHandle<T, E: TError> {
158    pub rx: Receiver<T>,
159    pub ctx: Context<E>,
160}
161
162impl<T, E: TError> TaskHandle<T, E> {
163    pub fn stop(self) -> FutureTaskState<E> {
164        self.ctx.stop()
165    }
166    pub fn wait_for_stopped(self) -> impl Future<Output = Result<(), TaskError<E>>> {
167        self.ctx.wait_for(State::Stopped)
168    }
169
170    pub fn recv(&self) -> Result<T, std::sync::mpsc::RecvError> {
171        self.rx.recv()
172    }
173}