1pub use std::sync::mpsc::{Receiver, SyncSender};
2use std::{
3 error::Error,
4 fmt::Display,
5 sync::{Arc, OnceLock, mpsc::sync_channel},
6 thread,
7};
8
9pub use async_trait::async_trait;
10use tokio::{runtime::Runtime, select};
11
12use context::{FutureTaskState, State};
13use log::{trace, warn};
14
15mod context;
16mod task_chan;
17
18pub use context::Context;
19use task_chan::{TaskReceiver, TaskSender, task_channel};
20
/// Process-wide slot for a Tokio [`Runtime`]; it starts out empty and can be filled at most once.
// NOTE(review): nothing in the visible part of this file reads or initialises `RT` —
// confirm it is still needed before relying on it.
static RT: OnceLock<Runtime> = OnceLock::new();
22
/// Marker trait for the error type a task may fail with.
///
/// It has no methods of its own; implementing it only asserts that the type is a
/// standard [`Error`] that can be cloned, sent across threads and holds no
/// non-`'static` borrows. Types opt in explicitly with `impl TError for MyError {}`.
pub trait TError
where
    Self: Error + Clone + Send + 'static,
{
}
24
25#[derive(Debug, Clone)]
26pub enum TaskError<E: TError> {
27 Cancelled,
28 Error(E),
29}
30
31impl<E: TError> From<E> for TaskError<E> {
32 fn from(value: E) -> Self {
33 Self::Error(value)
34 }
35}
36
37impl<E: TError> Display for TaskError<E> {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 match self {
40 Self::Cancelled => write!(f, "Cancelled"),
41 Self::Error(e) => write!(f, "{}", e),
42 }
43 }
44}
45
46pub trait TaskBuilder {
47 type Output: Send + 'static;
48 type Error: TError;
49 type Task: Task<Self::Error>;
50
51 fn build(self, tx: SyncSender<Self::Output>) -> Self::Task;
52 fn channel_size(&self) -> usize {
53 10
54 }
55}
56
/// Lifecycle hooks of a task; both hooks have default implementations.
///
/// `E` is the task's error type. Implementors must be `Send + 'static`.
#[async_trait]
pub trait Task<E: TError>: Send + 'static {
    /// Start hook.
    ///
    /// The default implementation drops `ctx`, emits a trace log and returns `Ok(())`.
    async fn on_start(&mut self, ctx: Context<E>) -> Result<(), E> {
        // Release the context handle before logging; the default hook has no use for it.
        drop(ctx);
        trace!("on_start");
        Ok(())
    }
    /// Stop hook.
    ///
    /// The default implementation drops `ctx`, emits a trace log and returns `Ok(())`.
    async fn on_stop(&mut self, ctx: Context<E>) -> Result<(), E> {
        // Release the context handle before logging; the default hook has no use for it.
        drop(ctx);
        trace!("on_stop");
        Ok(())
    }
}
70
71struct TaskBox<E: TError> {
72 task: Box<dyn Task<E>>,
73 ctx: Context<E>,
74}
75
76struct WaitingTask<E: TError> {
77 task: TaskBox<E>,
78}
79
80#[derive(Clone)]
81pub struct SingletonTask<E: TError> {
82 tx: TaskSender<E>,
83}
84
85impl<E: TError> SingletonTask<E> {
86 pub fn new() -> Self {
87 let (tx, rx) = task_channel::<E>();
88
89 tokio::spawn(Self::work_deal_start(rx));
90
91 Self { tx }
92 }
93
94 async fn work_deal_start(rx: TaskReceiver<E>) {
95 while let Some(next) = rx.recv().await {
96 let id = next.task.ctx.id();
97 if let Err(e) = Self::work_start_task(next).await {
98 warn!("task [{}] error: {}", id, e);
99 }
100 }
101 trace!("task work done");
102 }
103
104 async fn work_start_task(next: WaitingTask<E>) -> Result<(), TaskError<E>> {
105 trace!("run task {}", next.task.ctx.id());
106 let ctx = next.task.ctx.clone();
107 let mut task = next.task.task;
108 match select! {
109 res = task.on_start(ctx.clone()) => res.map_err(|e|e.into()),
110 res = ctx.wait_for(State::Stopping) => res
111 } {
112 Ok(_) => {
113 if ctx.set_state(State::Running).is_err() {
114 return Err(TaskError::Cancelled);
115 };
116 }
117 Err(e) => {
118 ctx.stop_with_result(Some(e));
119 }
120 }
121
122 let _ = ctx.wait_for(State::Stopping).await;
123 let _ = task.on_stop(ctx.clone()).await;
124 ctx.work_done();
125 let _ = ctx.wait_for(State::Stopped).await;
126 trace!("task {} stopped", ctx.id());
127 Ok(())
128 }
129
130 pub async fn start<T: TaskBuilder<Error = E>>(
131 &self,
132 task_builder: T,
133 ) -> Result<TaskHandle<T::Output, E>, TaskError<E>> {
134 let channel_size = task_builder.channel_size();
135 let (tx, rx) = sync_channel::<T::Output>(channel_size);
136 let task = Box::new(task_builder.build(tx));
137 let task_box = TaskBox {
138 task,
139 ctx: Context::default(),
140 };
141 let ctx = task_box.ctx.clone();
142
143 self.tx.send(WaitingTask { task: task_box });
144
145 ctx.wait_for(State::Running).await?;
146
147 Ok(TaskHandle { rx, ctx })
148 }
149}
150
151impl<E: TError> Default for SingletonTask<E> {
152 fn default() -> Self {
153 Self::new()
154 }
155}
156
157pub struct TaskHandle<T, E: TError> {
158 pub rx: Receiver<T>,
159 pub ctx: Context<E>,
160}
161
162impl<T, E: TError> TaskHandle<T, E> {
163 pub fn stop(self) -> FutureTaskState<E> {
164 self.ctx.stop()
165 }
166 pub fn wait_for_stopped(self) -> impl Future<Output = Result<(), TaskError<E>>> {
167 self.ctx.wait_for(State::Stopped)
168 }
169
170 pub fn recv(&self) -> Result<T, std::sync::mpsc::RecvError> {
171 self.rx.recv()
172 }
173}