1use std::{error::Error, fmt::Display};
2
3pub use async_trait::async_trait;
4pub use tokio::{
5 sync::mpsc::{Receiver, Sender},
6 task::JoinHandle,
7};
8
9use log::{trace, warn};
10use tokio::{select, sync::mpsc::channel};
11
12mod context;
13mod task_chan;
14
15pub use context::Context;
16
17use context::{FutureTaskState, State};
18use task_chan::{TaskReceiver, TaskSender, task_channel};
19
/// Marker trait for the error types a task may report.
///
/// An implementor must be a standard [`Error`], be cloneable, be sendable to
/// other threads, and own all of its data (`'static`).
pub trait TError
where
    Self: Error + Clone + Send + 'static,
{
}
21
22#[derive(Debug, Clone)]
23pub enum TaskError<E: TError> {
24 Cancelled,
25 Error(E),
26}
27
28impl<E: TError> From<E> for TaskError<E> {
29 fn from(value: E) -> Self {
30 Self::Error(value)
31 }
32}
33
34impl<E: TError> Display for TaskError<E> {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 match self {
37 Self::Cancelled => write!(f, "Cancelled"),
38 Self::Error(e) => write!(f, "{e}"),
39 }
40 }
41}
42
43pub trait TaskBuilder {
44 type Output: Send + 'static;
45 type Error: TError;
46 type Task: Task<Self::Error>;
47
48 fn build(self, tx: Sender<Self::Output>) -> Self::Task;
49 fn channel_size(&self) -> usize {
50 10
51 }
52}
53
54#[async_trait]
55pub trait Task<E: TError>: Send + 'static {
56 async fn on_start(&mut self, ctx: Context<E>) -> Result<(), E> {
57 drop(ctx);
58 trace!("on_start");
59 Ok(())
60 }
61 async fn on_stop(&mut self, ctx: Context<E>) -> Result<(), E> {
62 drop(ctx);
63 trace!("on_stop");
64 Ok(())
65 }
66}
67
68struct TaskBox<E: TError> {
69 task: Box<dyn Task<E>>,
70 ctx: Context<E>,
71}
72
73struct WaitingTask<E: TError> {
74 task: TaskBox<E>,
75}
76
77#[derive(Clone)]
78pub struct SingletonTask<E: TError> {
79 tx: TaskSender<E>,
80}
81
82impl<E: TError> SingletonTask<E> {
83 pub fn new() -> Self {
84 let (tx, rx) = task_channel::<E>();
85
86 tokio::spawn(Self::work_deal_start(rx));
87
88 Self { tx }
89 }
90
91 async fn work_deal_start(rx: TaskReceiver<E>) {
92 while let Some(next) = rx.recv().await {
93 let id = next.task.ctx.id();
94 if let Err(e) = Self::work_start_task(next).await {
95 warn!("task [{id}] error: {e}");
96 }
97 }
98 trace!("task work done");
99 }
100
101 async fn work_start_task(next: WaitingTask<E>) -> Result<(), TaskError<E>> {
102 trace!("run task {}", next.task.ctx.id());
103 let ctx = next.task.ctx.clone();
104 let mut task = next.task.task;
105 match select! {
106 res = task.on_start(ctx.clone()) => res.map_err(|e|e.into()),
107 res = ctx.wait_for(State::Stopping) => res
108 } {
109 Ok(_) => {
110 if ctx.set_state(State::Running).is_err() {
111 return Err(TaskError::Cancelled);
112 };
113 }
114 Err(e) => {
115 ctx.stop_with_terr(e);
116 }
117 }
118
119 let _ = ctx.wait_for(State::Stopping).await;
120 let _ = task.on_stop(ctx.clone()).await;
121 ctx.work_done();
122 let _ = ctx.wait_for(State::Stopped).await;
123 trace!("task {} stopped", ctx.id());
124 Ok(())
125 }
126
127 pub async fn start<T: TaskBuilder<Error = E>>(
128 &self,
129 task_builder: T,
130 ) -> Result<TaskHandle<T::Output, E>, TaskError<E>> {
131 let channel_size = task_builder.channel_size();
132 let (tx, rx) = channel::<T::Output>(channel_size);
133 let task = Box::new(task_builder.build(tx));
134 let task_box = TaskBox {
135 task,
136 ctx: Context::default(),
137 };
138 let ctx = task_box.ctx.clone();
139
140 self.tx.send(WaitingTask { task: task_box });
141
142 ctx.wait_for(State::Running).await?;
143
144 Ok(TaskHandle { rx, ctx })
145 }
146}
147
148impl<E: TError> Default for SingletonTask<E> {
149 fn default() -> Self {
150 Self::new()
151 }
152}
153
154pub struct TaskHandle<T, E: TError> {
155 pub rx: Receiver<T>,
156 pub ctx: Context<E>,
157}
158
159impl<T, E: TError> TaskHandle<T, E> {
160 pub fn stop(self) -> FutureTaskState<E> {
161 self.ctx.stop()
162 }
163 pub fn wait_for_stopped(self) -> impl Future<Output = Result<(), TaskError<E>>> {
164 self.ctx.wait_for(State::Stopped)
165 }
166
167 pub async fn recv(&mut self) -> Option<T> {
190 self.rx.recv().await
191 }
192
193 pub fn blocking_recv(&mut self) -> Option<T> {
216 self.rx.blocking_recv()
217 }
218}