singleton_task/
context.rs

1use std::{
2    sync::{
3        Arc, Mutex,
4        atomic::{AtomicU32, Ordering},
5    },
6    task::{Poll, Waker},
7};
8
9use log::trace;
10use tokio::{runtime::Handle, select, task::JoinHandle};
11use tokio_util::sync::CancellationToken;
12
13use crate::{TError, TaskError};
14
15#[derive(Clone)]
16pub struct Context<E: TError> {
17    id: u32,
18    inner: Arc<Mutex<ContextInner<E>>>,
19    cancel: CancellationToken,
20}
21
22impl<E: TError> Context<E> {
23    pub fn id(&self) -> u32 {
24        self.id
25    }
26
27    pub(crate) fn set_state(&self, state: State) -> Result<(), &'static str> {
28        self.inner.lock().unwrap().set_state(state)
29    }
30
31    pub fn wait_for(&self, state: State) -> FutureTaskState<E> {
32        FutureTaskState::new(self.clone(), state)
33    }
34
35    pub fn stop(&self) -> FutureTaskState<E> {
36        self._stop(Some(TaskError::Cancelled))
37    }
38
39    pub fn is_active(&self) -> bool {
40        !self.cancel.is_cancelled()
41    }
42
43    fn _stop(&self, err: Option<TaskError<E>>) -> FutureTaskState<E> {
44        let fur = self.wait_for(State::Stopped);
45        let mut g = self.inner.lock().unwrap();
46        if g.state >= State::Stopping {
47            return fur;
48        }
49        let _ = g.set_state(State::Stopping);
50        g.error = err;
51        g.wake_all();
52        drop(g);
53        self.cancel.cancel();
54        fur
55    }
56
57    pub(crate) fn stop_with_terr(&self, err: TaskError<E>) -> FutureTaskState<E> {
58        self._stop(Some(err))
59    }
60
61    pub fn stop_with_err(&self, err: E) -> FutureTaskState<E> {
62        self._stop(Some(TaskError::Error(err)))
63    }
64
65    pub fn spawn<F>(&self, fut: F) -> JoinHandle<Result<F::Output, TaskError<E>>>
66    where
67        F: Future + Send + 'static,
68        F::Output: Send + 'static,
69    {
70        let mut g = self.inner.lock().unwrap();
71        g.spawn(self, fut)
72    }
73
74    pub fn spawn_blocking<F, R>(&self, f: F) -> JoinHandle<Result<R, TaskError<E>>>
75    where
76        F: FnOnce(&Context<E>) -> R + Send + 'static,
77        R: Send + 'static,
78    {
79        let mut g = self.inner.lock().unwrap();
80        g.spawn_blocking(self, f)
81    }
82
83    pub(crate) fn work_done(&self) {
84        let mut g = self.inner.lock().unwrap();
85        g.work_count -= 1;
86        trace!("[{:>6}] work count {}", self.id, g.work_count);
87        if g.work_count == 1 && g.state == State::Running {
88            let _ = g.set_state(State::Stopping);
89        }
90
91        if g.work_count == 0 {
92            let _ = g.set_state(State::Stopped);
93        }
94    }
95}
96
97impl<E: TError> Default for Context<E> {
98    fn default() -> Self {
99        static TASK_ID: AtomicU32 = AtomicU32::new(1);
100        let id = TASK_ID.fetch_add(1, Ordering::SeqCst);
101
102        Self {
103            id,
104            inner: Arc::new(Mutex::new(ContextInner {
105                id,
106                work_count: 1,
107                ..Default::default()
108            })),
109            cancel: CancellationToken::new(),
110        }
111    }
112}
113
114struct ContextInner<E: TError> {
115    error: Option<TaskError<E>>,
116    state: State,
117    wakers: Vec<Waker>,
118    work_count: u32,
119    id: u32,
120}
121
122impl<E: TError> ContextInner<E> {
123    fn wake_all(&mut self) {
124        for waker in self.wakers.iter() {
125            waker.wake_by_ref();
126        }
127        self.wakers.clear();
128    }
129
130    fn set_state(&mut self, state: State) -> Result<(), &'static str> {
131        if state < self.state {
132            return Err("state is not allowed");
133        }
134        trace!("[{:>6}] [{:?}]=>[{:?}]", self.id, self.state, state);
135        self.state = state;
136        self.wake_all();
137        Ok(())
138    }
139
140    fn spawn<F>(&mut self, ctx: &Context<E>, fur: F) -> JoinHandle<Result<F::Output, TaskError<E>>>
141    where
142        F: Future + Send + 'static,
143        F::Output: Send + 'static,
144    {
145        let ctx = ctx.clone();
146
147        self.work_count += 1;
148        trace!("[{:>6}] work count {}", ctx.id, self.work_count);
149        let handle = Handle::current();
150
151        handle.spawn(async move {
152            let mut res = Err(TaskError::Cancelled);
153            select! {
154                r = fur =>{
155                    trace!("[{:>6}] exit: finish", ctx.id);
156                    res = Ok(r);
157                }
158                _ = ctx.cancel.cancelled() => {
159                    trace!("[{:>6}] exit: cancel token", ctx.id);
160                }
161                _ = ctx.wait_for(State::Stopping) => {
162                    trace!("[{:>6}] exit: stopping", ctx.id);
163                }
164            }
165            ctx.work_done();
166            res
167        })
168    }
169
170    fn spawn_blocking<F, R>(
171        &mut self,
172        ctx: &Context<E>,
173        fur: F,
174    ) -> JoinHandle<Result<R, TaskError<E>>>
175    where
176        F: FnOnce(&Context<E>) -> R + Send + 'static,
177        R: Send + 'static,
178    {
179        let ctx = ctx.clone();
180
181        self.work_count += 1;
182        trace!("[{:>6}] work count {}", ctx.id, self.work_count);
183        let handle = Handle::current();
184
185        handle.spawn_blocking(move || {
186            if !ctx.is_active() {
187                return Err(TaskError::Cancelled);
188            }
189            let r = fur(&ctx);
190            ctx.work_done();
191            Ok(r)
192        })
193    }
194}
195
196impl<E: TError> Default for ContextInner<E> {
197    fn default() -> Self {
198        Self {
199            id: 0,
200            error: None,
201            state: State::default(),
202            wakers: Default::default(),
203            work_count: 0,
204        }
205    }
206}
207
/// Lifecycle of a [`Context`], listed in the order it progresses.
///
/// The derived ordering follows declaration order. Contexts only accept transitions to the
/// same or a later state.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum State {
    /// Initial state of a new context.
    #[default]
    Idle,
    Preparing,
    Running,
    /// Entered on a stop request, or by `work_done` when tracked work winds down.
    /// Spawned async work exits once this state is reached.
    Stopping,
    /// Terminal state; set by `work_done` when the work count reaches 0.
    Stopped,
}
222
223pub struct FutureTaskState<E: TError> {
224    ctx: Context<E>,
225    want: State,
226}
227impl<E: TError> FutureTaskState<E> {
228    fn new(ctx: Context<E>, want: State) -> Self {
229        Self { ctx, want }
230    }
231}
232
233impl<E: TError> Future for FutureTaskState<E> {
234    type Output = Result<(), TaskError<E>>;
235
236    fn poll(
237        self: std::pin::Pin<&mut Self>,
238        cx: &mut std::task::Context<'_>,
239    ) -> std::task::Poll<Self::Output> {
240        let mut g = self.ctx.inner.lock().unwrap();
241        if g.state >= self.want {
242            Poll::Ready(match g.error.clone() {
243                Some(e) => Err(e),
244                None => Ok(()),
245            })
246        } else {
247            g.wakers.push(cx.waker().clone());
248            Poll::Pending
249        }
250    }
251}