singleton_task/
lib.rs

1use std::{error::Error, fmt::Display, pin::Pin};
2
3pub use async_trait::async_trait;
4pub use tokio::{
5    sync::mpsc::{Receiver, Sender},
6    task::JoinHandle,
7};
8
9use log::{trace, warn};
10use tokio::{select, sync::mpsc::channel};
11
12mod context;
13mod task_chan;
14
15pub use context::Context;
16
17use context::{FutureTaskState, State};
18use task_chan::{TaskReceiver, TaskSender, task_channel};
19use tokio_stream::Stream;
20
/// Marker trait for the error type a task can fail with.
///
/// An implementor must be a standard [`Error`] that is also `Clone`, `Send`
/// and `'static`; the trait itself adds no methods.
pub trait TError: Error + Send + Clone + 'static {}
22
23#[derive(Debug, Clone)]
24pub enum TaskError<E: TError> {
25    Cancelled,
26    Error(E),
27}
28
29impl<E: TError> From<E> for TaskError<E> {
30    fn from(value: E) -> Self {
31        Self::Error(value)
32    }
33}
34
35impl<E: TError> Display for TaskError<E> {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        match self {
38            Self::Cancelled => write!(f, "Cancelled"),
39            Self::Error(e) => write!(f, "{e}"),
40        }
41    }
42}
43
44pub trait TaskBuilder {
45    type Output: Send + 'static;
46    type Error: TError;
47    type Task: Task<Self::Error>;
48
49    fn build(self, tx: Sender<Self::Output>) -> Self::Task;
50    fn channel_size(&self) -> usize {
51        10
52    }
53}
54
55#[async_trait]
56pub trait Task<E: TError>: Send + 'static {
57    async fn on_start(&mut self, ctx: Context<E>) -> Result<(), E> {
58        drop(ctx);
59        trace!("on_start");
60        Ok(())
61    }
62    async fn on_stop(&mut self, ctx: Context<E>) -> Result<(), E> {
63        drop(ctx);
64        trace!("on_stop");
65        Ok(())
66    }
67}
68
69struct TaskBox<E: TError> {
70    task: Box<dyn Task<E>>,
71    ctx: Context<E>,
72}
73
74struct WaitingTask<E: TError> {
75    task: TaskBox<E>,
76}
77
78#[derive(Clone)]
79pub struct SingletonTask<E: TError> {
80    tx: TaskSender<E>,
81}
82
83impl<E: TError> SingletonTask<E> {
84    pub fn new() -> Self {
85        let (tx, rx) = task_channel::<E>();
86
87        tokio::spawn(Self::work_deal_start(rx));
88
89        Self { tx }
90    }
91
92    async fn work_deal_start(rx: TaskReceiver<E>) {
93        while let Some(next) = rx.recv().await {
94            let id = next.task.ctx.id();
95            if let Err(e) = Self::work_start_task(next).await {
96                warn!("task [{id}] error: {e}");
97            }
98        }
99        trace!("task work done");
100    }
101
102    async fn work_start_task(next: WaitingTask<E>) -> Result<(), TaskError<E>> {
103        trace!("run task {}", next.task.ctx.id());
104        let ctx = next.task.ctx.clone();
105        let mut task = next.task.task;
106        match select! {
107            res = task.on_start(ctx.clone()) => res.map_err(|e|e.into()),
108            res = ctx.wait_for(State::Stopping) => res
109        } {
110            Ok(_) => {
111                if ctx.set_state(State::Running).is_err() {
112                    return Err(TaskError::Cancelled);
113                };
114            }
115            Err(e) => {
116                ctx.stop_with_terr(e);
117            }
118        }
119
120        let _ = ctx.wait_for(State::Stopping).await;
121        let _ = task.on_stop(ctx.clone()).await;
122        ctx.work_done();
123        let _ = ctx.wait_for(State::Stopped).await;
124        trace!("task {} stopped", ctx.id());
125        Ok(())
126    }
127
128    pub async fn start<T: TaskBuilder<Error = E>>(
129        &self,
130        task_builder: T,
131    ) -> Result<TaskHandle<T::Output, E>, TaskError<E>> {
132        let channel_size = task_builder.channel_size();
133        let (tx, rx) = channel::<T::Output>(channel_size);
134        let task = Box::new(task_builder.build(tx));
135        let task_box = TaskBox {
136            task,
137            ctx: Context::default(),
138        };
139        let ctx = task_box.ctx.clone();
140
141        self.tx.send(WaitingTask { task: task_box });
142
143        ctx.wait_for(State::Running).await?;
144
145        Ok(TaskHandle { rx, ctx })
146    }
147}
148
149impl<E: TError> Default for SingletonTask<E> {
150    fn default() -> Self {
151        Self::new()
152    }
153}
154
155pub struct TaskHandle<T, E: TError> {
156    pub rx: Receiver<T>,
157    pub ctx: Context<E>,
158}
159
160impl<T, E: TError> TaskHandle<T, E> {
161    pub fn stop(self) -> FutureTaskState<E> {
162        self.ctx.stop()
163    }
164    pub fn wait_for_stopped(self) -> impl Future<Output = Result<(), TaskError<E>>> {
165        self.ctx.wait_for(State::Stopped)
166    }
167
168    /// Receives the next value for this receiver.
169    ///
170    /// This method returns `None` if the channel has been closed and there are
171    /// no remaining messages in the channel's buffer. This indicates that no
172    /// further values can ever be received from this `Receiver`. The channel is
173    /// closed when all senders have been dropped, or when [`close`] is called.
174    ///
175    /// If there are no messages in the channel's buffer, but the channel has
176    /// not yet been closed, this method will sleep until a message is sent or
177    /// the channel is closed.  Note that if [`close`] is called, but there are
178    /// still outstanding [`Permits`] from before it was closed, the channel is
179    /// not considered closed by `recv` until the permits are released.
180    ///
181    /// # Cancel safety
182    ///
183    /// This method is cancel safe. If `recv` is used as the event in a
184    /// [`tokio::select!`](crate::select) statement and some other branch
185    /// completes first, it is guaranteed that no messages were received on this
186    /// channel.
187    ///
188    /// [`close`]: Self::close
189    /// [`Permits`]: struct@crate::sync::mpsc::Permit
190    pub async fn recv(&mut self) -> Option<T> {
191        self.rx.recv().await
192    }
193
194    /// Blocking receive to call outside of asynchronous contexts.
195    ///
196    /// This method returns `None` if the channel has been closed and there are
197    /// no remaining messages in the channel's buffer. This indicates that no
198    /// further values can ever be received from this `Receiver`. The channel is
199    /// closed when all senders have been dropped, or when [`close`] is called.
200    ///
201    /// If there are no messages in the channel's buffer, but the channel has
202    /// not yet been closed, this method will block until a message is sent or
203    /// the channel is closed.
204    ///
205    /// This method is intended for use cases where you are sending from
206    /// asynchronous code to synchronous code, and will work even if the sender
207    /// is not using [`blocking_send`] to send the message.
208    ///
209    /// Note that if [`close`] is called, but there are still outstanding
210    /// [`Permits`] from before it was closed, the channel is not considered
211    /// closed by `blocking_recv` until the permits are released.
212    ///
213    /// [`close`]: Self::close
214    /// [`Permits`]: struct@crate::sync::mpsc::Permit
215    /// [`blocking_send`]: fn@crate::sync::mpsc::Sender::blocking_send
216    pub fn blocking_recv(&mut self) -> Option<T> {
217        self.rx.blocking_recv()
218    }
219}
220
221impl<T, E: TError> Stream for TaskHandle<T, E> {
222    type Item = T;
223
224    fn poll_next(
225        mut self: Pin<&mut Self>,
226        cx: &mut std::task::Context<'_>,
227    ) -> std::task::Poll<Option<Self::Item>> {
228        self.rx.poll_recv(cx)
229    }
230}