1pub use std::sync::mpsc::{Receiver, SyncSender};
2use std::{error::Error, fmt::Display, sync::mpsc::sync_channel};
3
4pub use async_trait::async_trait;
5use tokio::select;
6
7use context::{FutureTaskState, State};
8use log::{trace, warn};
9
10mod context;
11mod task_chan;
12
13pub use context::Context;
14use task_chan::{TaskReceiver, TaskSender, task_channel};
15
/// Marker trait for task-specific error types.
///
/// Implementors must be standard errors that can be cloned and moved across
/// threads; the `'static` bound lets them be stored inside boxed tasks and
/// contexts. Implement it explicitly: `impl TError for MyError {}`.
pub trait TError: Clone + Send + Error + 'static {}
17
18#[derive(Debug, Clone)]
19pub enum TaskError<E: TError> {
20 Cancelled,
21 Error(E),
22}
23
24impl<E: TError> From<E> for TaskError<E> {
25 fn from(value: E) -> Self {
26 Self::Error(value)
27 }
28}
29
30impl<E: TError> Display for TaskError<E> {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 match self {
33 Self::Cancelled => write!(f, "Cancelled"),
34 Self::Error(e) => write!(f, "{}", e),
35 }
36 }
37}
38
39pub trait TaskBuilder {
40 type Output: Send + 'static;
41 type Error: TError;
42 type Task: Task<Self::Error>;
43
44 fn build(self, tx: SyncSender<Self::Output>) -> Self::Task;
45 fn channel_size(&self) -> usize {
46 10
47 }
48}
49
50#[async_trait]
51pub trait Task<E: TError>: Send + 'static {
52 async fn on_start(&mut self, ctx: Context<E>) -> Result<(), E> {
53 drop(ctx);
54 trace!("on_start");
55 Ok(())
56 }
57 async fn on_stop(&mut self, ctx: Context<E>) -> Result<(), E> {
58 drop(ctx);
59 trace!("on_stop");
60 Ok(())
61 }
62}
63
64struct TaskBox<E: TError> {
65 task: Box<dyn Task<E>>,
66 ctx: Context<E>,
67}
68
69struct WaitingTask<E: TError> {
70 task: TaskBox<E>,
71}
72
73#[derive(Clone)]
74pub struct SingletonTask<E: TError> {
75 tx: TaskSender<E>,
76}
77
78impl<E: TError> SingletonTask<E> {
79 pub fn new() -> Self {
80 let (tx, rx) = task_channel::<E>();
81
82 tokio::spawn(Self::work_deal_start(rx));
83
84 Self { tx }
85 }
86
87 async fn work_deal_start(rx: TaskReceiver<E>) {
88 while let Some(next) = rx.recv().await {
89 let id = next.task.ctx.id();
90 if let Err(e) = Self::work_start_task(next).await {
91 warn!("task [{}] error: {}", id, e);
92 }
93 }
94 trace!("task work done");
95 }
96
97 async fn work_start_task(next: WaitingTask<E>) -> Result<(), TaskError<E>> {
98 trace!("run task {}", next.task.ctx.id());
99 let ctx = next.task.ctx.clone();
100 let mut task = next.task.task;
101 match select! {
102 res = task.on_start(ctx.clone()) => res.map_err(|e|e.into()),
103 res = ctx.wait_for(State::Stopping) => res
104 } {
105 Ok(_) => {
106 if ctx.set_state(State::Running).is_err() {
107 return Err(TaskError::Cancelled);
108 };
109 }
110 Err(e) => {
111 ctx.stop_with_result(Some(e));
112 }
113 }
114
115 let _ = ctx.wait_for(State::Stopping).await;
116 let _ = task.on_stop(ctx.clone()).await;
117 ctx.work_done();
118 let _ = ctx.wait_for(State::Stopped).await;
119 trace!("task {} stopped", ctx.id());
120 Ok(())
121 }
122
123 pub async fn start<T: TaskBuilder<Error = E>>(
124 &self,
125 task_builder: T,
126 ) -> Result<TaskHandle<T::Output, E>, TaskError<E>> {
127 let channel_size = task_builder.channel_size();
128 let (tx, rx) = sync_channel::<T::Output>(channel_size);
129 let task = Box::new(task_builder.build(tx));
130 let task_box = TaskBox {
131 task,
132 ctx: Context::default(),
133 };
134 let ctx = task_box.ctx.clone();
135
136 self.tx.send(WaitingTask { task: task_box });
137
138 ctx.wait_for(State::Running).await?;
139
140 Ok(TaskHandle { rx, ctx })
141 }
142}
143
144impl<E: TError> Default for SingletonTask<E> {
145 fn default() -> Self {
146 Self::new()
147 }
148}
149
150pub struct TaskHandle<T, E: TError> {
151 pub rx: Receiver<T>,
152 pub ctx: Context<E>,
153}
154
155impl<T, E: TError> TaskHandle<T, E> {
156 pub fn stop(self) -> FutureTaskState<E> {
157 self.ctx.stop()
158 }
159 pub fn wait_for_stopped(self) -> impl Future<Output = Result<(), TaskError<E>>> {
160 self.ctx.wait_for(State::Stopped)
161 }
162
163 pub fn recv(&self) -> Result<T, std::sync::mpsc::RecvError> {
164 self.rx.recv()
165 }
166}