1pub use std::sync::mpsc::{Receiver, SyncSender};
2use std::{
3 error::Error,
4 fmt::Display,
5 sync::{Arc, OnceLock, mpsc::sync_channel},
6 thread,
7};
8
9pub use async_trait::async_trait;
10use tokio::{runtime::Runtime, select};
11
12use context::{FutureTaskState, State};
13use log::{trace, warn};
14
15mod context;
16mod task_chan;
17
18pub use context::Context;
19use task_chan::{TaskReceiver, TaskSender, task_channel};
20
21static RT: OnceLock<Runtime> = OnceLock::new();
22
/// Bound that every task error type must satisfy: a cloneable, `Send`,
/// `'static` [`std::error::Error`].
///
/// Opt in for your own error type with an empty `impl TError for MyError {}`.
pub trait TError
where
    Self: Error + Clone + Send + 'static,
{
}
24
25#[derive(Debug, Clone)]
26pub enum TaskError<E: TError> {
27 Cancelled,
28 Error(E),
29}
30
31impl<E: TError> From<E> for TaskError<E> {
32 fn from(value: E) -> Self {
33 Self::Error(value)
34 }
35}
36
37impl<E: TError> Display for TaskError<E> {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 match self {
40 Self::Cancelled => write!(f, "Cancelled"),
41 Self::Error(e) => write!(f, "{}", e),
42 }
43 }
44}
45
46pub trait TaskBuilder {
47 type Output: Send + 'static;
48 type Error: TError;
49 type Task: Task<Self::Error>;
50
51 fn build(self, tx: SyncSender<Self::Output>) -> Self::Task;
52 fn channel_size(&self) -> usize {
53 10
54 }
55}
56
57#[async_trait]
58pub trait Task<E: TError>: Send + 'static {
59 async fn on_start(&mut self, ctx: Context<E>) -> Result<(), E> {
60 drop(ctx);
61 trace!("on_start");
62 Ok(())
63 }
64 async fn on_stop(&mut self, ctx: Context<E>) -> Result<(), E> {
65 drop(ctx);
66 trace!("on_stop");
67 Ok(())
68 }
69}
70
71struct TaskBox<E: TError> {
72 task: Box<dyn Task<E>>,
73 ctx: Context<E>,
74}
75
76struct WaitingTask<E: TError> {
77 task: TaskBox<E>,
78}
79
80#[derive(Clone)]
81pub struct SingletonTask<E: TError> {
82 tx: TaskSender<E>,
83 _drop: Arc<TaskDrop<E>>,
84}
85
impl<E: TError> SingletonTask<E> {
    /// Creates a new handle and spawns its dedicated worker thread.
    ///
    /// The worker pulls queued tasks from an internal task channel and runs them
    /// one at a time (see `work_deal_start`). The thread's `JoinHandle` is not
    /// kept, so the thread is detached.
    ///
    /// # Panics
    /// Panics if the OS fails to create the thread (behaviour of `thread::spawn`).
    pub fn new() -> Self {
        let (tx, rx) = task_channel::<E>();

        // The worker owns the receiving half of the task channel.
        thread::spawn(move || Self::work_deal_start(rx));

        Self {
            // Shared by every clone of this handle; when the last clone is
            // dropped, `TaskDrop::drop` calls `stop()` on this sender.
            _drop: Arc::new(TaskDrop { tx: tx.clone() }),
            tx,
        }
    }

    /// Worker-thread loop: receives tasks until `rx.recv()` returns `None`,
    /// running each one to completion before receiving the next.
    ///
    /// Errors returned by `work_start_task` are logged with `warn!` and do not
    /// end the loop.
    ///
    /// NOTE(review): a panic inside a task hook unwinds out of `block_on` and ends
    /// this thread, and nothing here restarts it. Confirm that is acceptable.
    fn work_deal_start(rx: TaskReceiver<E>) {
        while let Some(next) = rx.recv() {
            // Read the id up front because `next` is moved into `work_start_task`.
            let id = next.task.ctx.id();
            if let Err(e) = Self::work_start_task(next) {
                warn!("task [{}] error: {}", id, e);
            }
        }
    }

    /// Drives one task through its whole lifecycle, blocking the worker thread
    /// on the shared runtime returned by `rt()`.
    ///
    /// 1. Start phase: `on_start` is raced against the context reaching
    ///    `State::Stopping`.
    ///    - On `Ok`, the context is moved to `State::Running`.
    ///    - On `Err` (from `on_start`, or from waiting for `Stopping`), the error
    ///      is handed to `ctx.stop_with_result`.
    /// 2. Stop phase: waits for `State::Stopping`, runs `on_stop` (its result is
    ///    discarded), calls `ctx.work_done()`, then waits for `State::Stopped`.
    ///
    /// # Errors
    /// Returns `TaskError::Cancelled` when `ctx.set_state(State::Running)` fails.
    /// The stop phase is skipped in that case.
    fn work_start_task(next: WaitingTask<E>) -> Result<(), TaskError<E>> {
        trace!("run task {}", next.task.ctx.id());
        // `next.task.ctx` is not moved out; it is dropped when this function returns.
        let ctx = next.task.ctx.clone();
        let mut task = next.task.task;
        // Whichever future completes first wins; the other one is dropped.
        // NOTE(review): when `Stopping` wins, the result is `Ok(())` as well, so the
        // `Ok` arm cannot tell a finished `on_start` from an interrupted one.
        // Presumably `set_state(Running)` is what fails then; confirm in `context`.
        match rt().block_on(async {
            select! {
                res = task.on_start(ctx.clone()) => res.map_err(|e|e.into()),
                res = ctx.wait_for(State::Stopping) => res
            }
        }) {
            Ok(_) => {
                if ctx.set_state(State::Running).is_err() {
                    // NOTE(review): this early return skips `on_stop` and
                    // `ctx.work_done()`. If `Context` relies on `work_done` to reach
                    // `Stopped`, waiters on `Stopped` for this task may never wake.
                    return Err(TaskError::Cancelled);
                };
            }
            Err(e) => {
                // Record the error, then fall through to the stop phase below
                // (presumably `stop_with_result` also requests `Stopping`; verify).
                ctx.stop_with_result(Some(e));
            }
        }

        // Stop phase: blocks the worker thread until `State::Stopped` is observed.
        rt().block_on(async {
            let _ = ctx.wait_for(State::Stopping).await;
            let _ = task.on_stop(ctx.clone()).await;
            ctx.work_done();
            let _ = ctx.wait_for(State::Stopped).await;
        });

        Ok(())
    }

    /// Builds a task from `task_builder`, queues it on the worker thread and
    /// waits until its context reaches `State::Running`.
    ///
    /// A bounded output channel of `task_builder.channel_size()` slots is
    /// created. The sender goes to `TaskBuilder::build`, and the receiver goes
    /// into the returned [`TaskHandle`].
    ///
    /// NOTE(review): the worker runs tasks one at a time. Unless
    /// `TaskSender::send` interrupts it, a previously started task that is still
    /// running delays this future until that task has stopped.
    ///
    /// # Errors
    /// Returns the error produced while waiting for `State::Running`.
    pub async fn start<T: TaskBuilder<Error = E>>(
        &self,
        task_builder: T,
    ) -> Result<TaskHandle<T::Output, E>, TaskError<E>> {
        let channel_size = task_builder.channel_size();
        let (tx, rx) = sync_channel::<T::Output>(channel_size);
        let task = Box::new(task_builder.build(tx));
        let task_box = TaskBox {
            task,
            ctx: Context::default(),
        };
        // Keep a handle to the new task's context before the box is moved away.
        let ctx = task_box.ctx.clone();

        // NOTE(review): the outcome of `send` (defined in `task_chan`) is not checked.
        self.tx.send(WaitingTask { task: task_box });

        ctx.wait_for(State::Running).await?;

        Ok(TaskHandle { rx, ctx })
    }
}
157
158struct TaskDrop<E: TError> {
159 tx: TaskSender<E>,
160}
161impl<E: TError> Drop for TaskDrop<E> {
162 fn drop(&mut self) {
163 self.tx.stop();
164 }
165}
166
167impl<E: TError> Default for SingletonTask<E> {
168 fn default() -> Self {
169 Self::new()
170 }
171}
172
173pub struct TaskHandle<T, E: TError> {
174 pub rx: Receiver<T>,
175 pub ctx: Context<E>,
176}
177
178impl<T, E: TError> TaskHandle<T, E> {
179 pub fn stop(self) -> FutureTaskState<E> {
180 self.ctx.stop()
181 }
182 pub fn wait_for_stopped(self) -> impl Future<Output = Result<(), TaskError<E>>> {
183 self.ctx.wait_for(State::Stopped)
184 }
185
186 pub fn recv(&self) -> Result<T, std::sync::mpsc::RecvError> {
187 self.rx.recv()
188 }
189}
190
191fn rt() -> &'static Runtime {
192 RT.get_or_init(|| {
193 tokio::runtime::Builder::new_current_thread()
194 .enable_all()
195 .build()
196 .unwrap()
197 })
198}