1use std::{error::Error, fmt::Display, pin::Pin};
2
3pub use async_trait::async_trait;
4pub use tokio::{
5 sync::mpsc::{Receiver, Sender},
6 task::JoinHandle,
7};
8
9use log::{trace, warn};
10use tokio::{select, sync::mpsc::channel};
11
12mod context;
13mod task_chan;
14
15pub use context::Context;
16
17use context::{FutureTaskState, State};
18use task_chan::{TaskReceiver, TaskSender, task_channel};
19use tokio_stream::Stream;
20
/// Marker trait for the error types a task may report.
///
/// It declares no methods; it only bundles the bounds every task error must
/// satisfy: [`Error`], [`Clone`], [`Send`] and `'static`.
pub trait TError
where
    Self: Error + Clone + Send + 'static,
{
}
22
23#[derive(Debug, Clone)]
24pub enum TaskError<E: TError> {
25 Cancelled,
26 Error(E),
27}
28
29impl<E: TError> From<E> for TaskError<E> {
30 fn from(value: E) -> Self {
31 Self::Error(value)
32 }
33}
34
35impl<E: TError> Display for TaskError<E> {
36 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37 match self {
38 Self::Cancelled => write!(f, "Cancelled"),
39 Self::Error(e) => write!(f, "{e}"),
40 }
41 }
42}
43
44pub trait TaskBuilder {
45 type Output: Send + 'static;
46 type Error: TError;
47 type Task: Task<Self::Error>;
48
49 fn build(self, tx: Sender<Self::Output>) -> Self::Task;
50 fn channel_size(&self) -> usize {
51 10
52 }
53}
54
55#[async_trait]
56pub trait Task<E: TError>: Send + 'static {
57 async fn on_start(&mut self, ctx: Context<E>) -> Result<(), E> {
58 drop(ctx);
59 trace!("on_start");
60 Ok(())
61 }
62 async fn on_stop(&mut self, ctx: Context<E>) -> Result<(), E> {
63 drop(ctx);
64 trace!("on_stop");
65 Ok(())
66 }
67}
68
69struct TaskBox<E: TError> {
70 task: Box<dyn Task<E>>,
71 ctx: Context<E>,
72}
73
74struct WaitingTask<E: TError> {
75 task: TaskBox<E>,
76}
77
78#[derive(Clone)]
79pub struct SingletonTask<E: TError> {
80 tx: TaskSender<E>,
81}
82
83impl<E: TError> SingletonTask<E> {
84 pub fn new() -> Self {
85 let (tx, rx) = task_channel::<E>();
86
87 tokio::spawn(Self::work_deal_start(rx));
88
89 Self { tx }
90 }
91
92 async fn work_deal_start(rx: TaskReceiver<E>) {
93 while let Some(next) = rx.recv().await {
94 let id = next.task.ctx.id();
95 if let Err(e) = Self::work_start_task(next).await {
96 warn!("task [{id}] error: {e}");
97 }
98 }
99 trace!("task work done");
100 }
101
102 async fn work_start_task(next: WaitingTask<E>) -> Result<(), TaskError<E>> {
103 trace!("run task {}", next.task.ctx.id());
104 let ctx = next.task.ctx.clone();
105 let mut task = next.task.task;
106 match select! {
107 res = task.on_start(ctx.clone()) => res.map_err(|e|e.into()),
108 res = ctx.wait_for(State::Stopping) => res
109 } {
110 Ok(_) => {
111 if ctx.set_state(State::Running).is_err() {
112 return Err(TaskError::Cancelled);
113 };
114 }
115 Err(e) => {
116 ctx.stop_with_terr(e);
117 }
118 }
119
120 let _ = ctx.wait_for(State::Stopping).await;
121 let _ = task.on_stop(ctx.clone()).await;
122 ctx.work_done();
123 let _ = ctx.wait_for(State::Stopped).await;
124 trace!("task {} stopped", ctx.id());
125 Ok(())
126 }
127
128 pub async fn start<T: TaskBuilder<Error = E>>(
129 &self,
130 task_builder: T,
131 ) -> Result<TaskHandle<T::Output, E>, TaskError<E>> {
132 let channel_size = task_builder.channel_size();
133 let (tx, rx) = channel::<T::Output>(channel_size);
134 let task = Box::new(task_builder.build(tx));
135 let task_box = TaskBox {
136 task,
137 ctx: Context::default(),
138 };
139 let ctx = task_box.ctx.clone();
140
141 self.tx.send(WaitingTask { task: task_box });
142
143 ctx.wait_for(State::Running).await?;
144
145 Ok(TaskHandle { rx, ctx })
146 }
147}
148
149impl<E: TError> Default for SingletonTask<E> {
150 fn default() -> Self {
151 Self::new()
152 }
153}
154
155pub struct TaskHandle<T, E: TError> {
156 pub rx: Receiver<T>,
157 pub ctx: Context<E>,
158}
159
160impl<T, E: TError> TaskHandle<T, E> {
161 pub fn stop(self) -> FutureTaskState<E> {
162 self.ctx.stop()
163 }
164 pub fn wait_for_stopped(self) -> impl Future<Output = Result<(), TaskError<E>>> {
165 self.ctx.wait_for(State::Stopped)
166 }
167
168 pub async fn recv(&mut self) -> Option<T> {
191 self.rx.recv().await
192 }
193
194 pub fn blocking_recv(&mut self) -> Option<T> {
217 self.rx.blocking_recv()
218 }
219}
220
221impl<T, E: TError> Stream for TaskHandle<T, E> {
222 type Item = T;
223
224 fn poll_next(
225 mut self: Pin<&mut Self>,
226 cx: &mut std::task::Context<'_>,
227 ) -> std::task::Poll<Option<Self::Item>> {
228 self.rx.poll_recv(cx)
229 }
230}