singleton_task/
context.rs

1use std::{
2    sync::{
3        Arc, Mutex,
4        atomic::{AtomicU32, Ordering},
5    },
6    task::{Poll, Waker},
7    thread,
8};
9
10use log::trace;
11use tokio::{runtime::Handle, select};
12use tokio_util::sync::CancellationToken;
13
14use crate::{TError, TaskError};
15
16#[derive(Clone)]
17pub struct Context<E: TError> {
18    id: u32,
19    inner: Arc<Mutex<ContextInner<E>>>,
20    cancel: CancellationToken,
21}
22
23impl<E: TError> Context<E> {
24    pub fn id(&self) -> u32 {
25        self.id
26    }
27
28    pub(crate) fn set_state(&self, state: State) -> Result<(), &'static str> {
29        self.inner.lock().unwrap().set_state(state)
30    }
31
32    pub fn wait_for(&self, state: State) -> FutureTaskState<E> {
33        FutureTaskState::new(self.clone(), state)
34    }
35
36    pub fn stop(&self) -> FutureTaskState<E> {
37        self.stop_with_result(Some(TaskError::Cancelled))
38    }
39
40    pub fn is_active(&self) -> bool {
41        !self.cancel.is_cancelled()
42    }
43
44    pub fn stop_with_result(&self, res: Option<TaskError<E>>) -> FutureTaskState<E> {
45        let fur = self.wait_for(State::Stopped);
46        let mut g = self.inner.lock().unwrap();
47        if g.state >= State::Stopping {
48            return fur;
49        }
50        let _ = g.set_state(State::Stopping);
51        g.error = res;
52        g.wake_all();
53        drop(g);
54        self.cancel.cancel();
55        fur
56    }
57
58    pub fn spawn<F>(&self, fut: F)
59    where
60        F: Future + Send + 'static,
61    {
62        let mut g = self.inner.lock().unwrap();
63        g.spawn(self, fut);
64    }
65
66    pub(crate) fn work_done(&self) {
67        let mut g = self.inner.lock().unwrap();
68        g.work_count -= 1;
69        trace!("[{:>6}] work count {}", self.id, g.work_count);
70        if g.work_count == 1 && g.state == State::Running {
71            let _ = g.set_state(State::Stopping);
72        }
73
74        if g.work_count == 0 {
75            let _ = g.set_state(State::Stopped);
76        }
77    }
78}
79
80impl<E: TError> Default for Context<E> {
81    fn default() -> Self {
82        static TASK_ID: AtomicU32 = AtomicU32::new(1);
83        let id = TASK_ID.fetch_add(1, Ordering::SeqCst);
84
85        Self {
86            id,
87            inner: Arc::new(Mutex::new(ContextInner {
88                id,
89                work_count: 1,
90                ..Default::default()
91            })),
92            cancel: CancellationToken::new(),
93        }
94    }
95}
96
97struct ContextInner<E: TError> {
98    error: Option<TaskError<E>>,
99    state: State,
100    wakers: Vec<Waker>,
101    work_count: u32,
102    id: u32,
103}
104
105impl<E: TError> ContextInner<E> {
106    fn wake_all(&mut self) {
107        for waker in self.wakers.iter() {
108            waker.wake_by_ref();
109        }
110        self.wakers.clear();
111    }
112
113    fn set_state(&mut self, state: State) -> Result<(), &'static str> {
114        if state < self.state {
115            return Err("state is not allowed");
116        }
117        trace!("[{:>6}] [{:?}]=>[{:?}]", self.id, self.state, state);
118        self.state = state;
119        self.wake_all();
120        Ok(())
121    }
122
123    fn spawn<F>(&mut self, ctx: &Context<E>, fur: F)
124    where
125        F: Future + Send + 'static,
126    {
127        let ctx = ctx.clone();
128        if matches!(self.state, State::Stopping | State::Stopped) {
129            return;
130        }
131
132        self.work_count += 1;
133        trace!("[{:>6}] work count {}", ctx.id, self.work_count);
134        let handle = Handle::current();
135        thread::spawn(move || {
136            handle.block_on(async move {
137                select! {
138                    _ = fur =>{
139                        trace!("[{:>6}] exit: finish", ctx.id);
140                    }
141                    _ = ctx.cancel.cancelled() => {
142                        trace!("[{:>6}] exit: cancel token", ctx.id);
143                    }
144                    _ = ctx.wait_for(State::Stopping) => {
145                        trace!("[{:>6}] exit: stopping", ctx.id);
146                    }
147                }
148                ctx.work_done();
149            })
150        });
151    }
152}
153
154impl<E: TError> Default for ContextInner<E> {
155    fn default() -> Self {
156        Self {
157            id: 0,
158            error: None,
159            state: State::default(),
160            wakers: Default::default(),
161            work_count: 0,
162        }
163    }
164}
165
/// Lifecycle stages of a task, declared in the order they are entered.
///
/// The derived `Ord` follows declaration order. That ordering lets state
/// transitions be checked as "never go backwards" and lets waits be expressed
/// as "reached this stage or later".
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)]
pub enum State {
    /// Initial stage of a freshly created context.
    #[default]
    Idle,
    Preparing,
    Running,
    /// A stop has begun; no new spawned work is accepted.
    Stopping,
    /// Terminal stage, reached once the outstanding-work count drops to zero.
    Stopped,
}
180
181pub struct FutureTaskState<E: TError> {
182    ctx: Context<E>,
183    want: State,
184}
185impl<E: TError> FutureTaskState<E> {
186    fn new(ctx: Context<E>, want: State) -> Self {
187        Self { ctx, want }
188    }
189}
190
191impl<E: TError> Future for FutureTaskState<E> {
192    type Output = Result<(), TaskError<E>>;
193
194    fn poll(
195        self: std::pin::Pin<&mut Self>,
196        cx: &mut std::task::Context<'_>,
197    ) -> std::task::Poll<Self::Output> {
198        let mut g = self.ctx.inner.lock().unwrap();
199        if g.state >= self.want {
200            Poll::Ready(match g.error.clone() {
201                Some(e) => Err(e),
202                None => Ok(()),
203            })
204        } else {
205            g.wakers.push(cx.waker().clone());
206            Poll::Pending
207        }
208    }
209}