1pub use std::sync::mpsc::{Receiver, SyncSender};
2use std::{
3 error::Error,
4 fmt::Display,
5 sync::{Arc, OnceLock, mpsc::sync_channel},
6 thread,
7};
8
9use context::{FutureTaskState, State};
10pub use futures::{FutureExt, future::LocalBoxFuture};
11use log::{trace, warn};
12
13mod context;
14mod task_chan;
15
16pub use context::Context;
17use task_chan::{TaskReceiver, TaskSender, task_channel};
18use tokio::{runtime::Runtime, select};
19
20static RT: OnceLock<Runtime> = OnceLock::new();
21
/// Marker trait for the error type a task's lifecycle hooks may return.
///
/// A type opts in with an empty `impl TError for MyError {}` once it is an
/// [`Error`], [`Clone`], [`Send`] and `'static`.
pub trait TError
where
    Self: Error + Clone + Send + 'static,
{
}

/// Failure reported for a task.
///
/// It is either a cancellation or an error produced by the task itself.
#[derive(Debug, Clone)]
pub enum TaskError<E>
where
    E: TError,
{
    /// The task was cancelled before it could run to completion.
    Cancelled,
    /// The task's own error.
    Error(E),
}

/// Wraps a task error so `?` can lift `E` into `TaskError<E>`.
impl<E> From<E> for TaskError<E>
where
    E: TError,
{
    fn from(err: E) -> Self {
        TaskError::Error(err)
    }
}

/// Displays `Cancelled` literally and defers to the inner error's `Display`
/// otherwise.
impl<E> Display for TaskError<E>
where
    E: TError,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TaskError::Cancelled => f.write_str("Cancelled"),
            TaskError::Error(inner) => write!(f, "{inner}"),
        }
    }
}
44
45pub trait TaskBuilder {
46 type Output: Send + 'static;
47 type Error: TError;
48 type Task: Task<Self::Error>;
49
50 fn build(self, tx: SyncSender<Self::Output>) -> Self::Task;
51 fn channel_size(&self) -> usize {
52 10
53 }
54}
55
56pub trait Task<E: TError>: Send + 'static {
57 fn on_start(&mut self, ctx: Context<E>) -> LocalBoxFuture<'_, Result<(), E>> {
58 drop(ctx);
59 async {
60 trace!("on_start");
61 Ok(())
62 }
63 .boxed_local()
64 }
65 fn on_stop(&mut self, ctx: Context<E>) -> LocalBoxFuture<'_, Result<(), E>> {
66 drop(ctx);
67 async {
68 trace!("on_stop");
69 Ok(())
70 }
71 .boxed_local()
72 }
73}
74
75struct TaskBox<E: TError> {
76 task: Box<dyn Task<E>>,
77 ctx: Context<E>,
78}
79
80struct WaitingTask<E: TError> {
81 task: TaskBox<E>,
82}
83
84#[derive(Clone)]
85pub struct SingletonTask<E: TError> {
86 tx: TaskSender<E>,
87 _drop: Arc<TaskDrop<E>>,
88}
89
90impl<E: TError> SingletonTask<E> {
91 pub fn new() -> Self {
92 let (tx, rx) = task_channel::<E>();
93
94 thread::spawn(move || Self::work_deal_start(rx));
95
96 Self {
97 _drop: Arc::new(TaskDrop { tx: tx.clone() }),
98 tx,
99 }
100 }
101
102 fn work_deal_start(rx: TaskReceiver<E>) {
103 while let Some(next) = rx.recv() {
104 let id = next.task.ctx.id();
105 if let Err(e) = Self::work_start_task(next) {
106 warn!("task [{}] error: {}", id, e);
107 }
108 }
109 }
110
111 fn work_start_task(next: WaitingTask<E>) -> Result<(), TaskError<E>> {
112 trace!("run task {}", next.task.ctx.id());
113 let ctx = next.task.ctx.clone();
114 let mut task = next.task.task;
115 match rt().block_on(async {
116 select! {
117 res = task.on_start(ctx.clone()) => res.map_err(|e|e.into()),
118 res = ctx.wait_for(State::Stopping) => res
119 }
120 }) {
121 Ok(_) => {
122 if ctx.set_state(State::Running).is_err() {
123 return Err(TaskError::Cancelled);
124 };
125 }
126 Err(e) => {
127 ctx.stop_with_result(Some(e));
128 }
129 }
130
131 rt().block_on(async {
132 let _ = ctx.wait_for(State::Stopping).await;
133 let _ = task.on_stop(ctx.clone()).await;
134 ctx.work_done();
135 let _ = ctx.wait_for(State::Stopped).await;
136 });
137
138 Ok(())
139 }
140
141 pub async fn start<T: TaskBuilder<Error = E>>(
142 &self,
143 task_builder: T,
144 ) -> Result<TaskHandle<T::Output, E>, TaskError<E>> {
145 let channel_size = task_builder.channel_size();
146 let (tx, rx) = sync_channel::<T::Output>(channel_size);
147 let task = Box::new(task_builder.build(tx));
148 let task_box = TaskBox {
149 task,
150 ctx: Context::default(),
151 };
152 let ctx = task_box.ctx.clone();
153
154 self.tx.send(WaitingTask { task: task_box });
155
156 ctx.wait_for(State::Running).await?;
157
158 Ok(TaskHandle { rx, ctx })
159 }
160}
161
162struct TaskDrop<E: TError> {
163 tx: TaskSender<E>,
164}
165impl<E: TError> Drop for TaskDrop<E> {
166 fn drop(&mut self) {
167 self.tx.stop();
168 }
169}
170
171impl<E: TError> Default for SingletonTask<E> {
172 fn default() -> Self {
173 Self::new()
174 }
175}
176
177pub struct TaskHandle<T, E: TError> {
178 pub rx: Receiver<T>,
179 pub ctx: Context<E>,
180}
181
182impl<T, E: TError> TaskHandle<T, E> {
183 pub fn stop(self) -> FutureTaskState<E> {
184 self.ctx.stop()
185 }
186 pub fn wait_for_stopped(self) -> impl Future<Output = Result<(), TaskError<E>>> {
187 self.ctx.wait_for(State::Stopped)
188 }
189
190 pub fn recv(&self) -> Result<T, std::sync::mpsc::RecvError> {
191 self.rx.recv()
192 }
193}
194
195fn rt() -> &'static Runtime {
196 RT.get_or_init(|| {
197 tokio::runtime::Builder::new_current_thread()
198 .enable_all()
199 .build()
200 .unwrap()
201 })
202}
203
#[cfg(test)]
mod test {
    use log::LevelFilter;

    use super::*;

    #[derive(Debug, Clone)]
    enum Error1 {
        _A,
    }

    impl TError for Error1 {}
    impl Error for Error1 {}
    impl Display for Error1 {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{:?}", self)
        }
    }

    struct Task1 {
        _a: i32,
    }

    impl Task<Error1> for Task1 {
        fn on_start(&mut self, _ctx: Context<Error1>) -> LocalBoxFuture<'_, Result<(), Error1>> {
            async {
                trace!("on_start 1");
                Ok(())
            }
            .boxed_local()
        }
    }

    struct Task1Builder {}

    impl TaskBuilder for Task1Builder {
        type Output = u32;
        type Error = Error1;
        type Task = Task1;

        fn build(self, _tx: SyncSender<u32>) -> Self::Task {
            Task1 { _a: 1 }
        }
    }

    /// Installs the trace-level test logger.
    ///
    /// It is safe to call from every test: `init()` panics when a logger is
    /// already set, so `try_init()` is used and its error ignored.
    fn init_logger() {
        let _ = env_logger::builder()
            .is_test(true)
            .filter_level(LevelFilter::Trace)
            .try_init();
    }

    #[tokio::test]
    async fn test_task() {
        init_logger();

        let st = SingletonTask::<Error1>::new();
        let _handle = st
            .start(Task1Builder {})
            .await
            .expect("task with a succeeding on_start should reach Running");
    }
}