singleton_task/context.rs

1use std::{
2    sync::{
3        Arc, Mutex,
4        atomic::{AtomicU32, Ordering},
5    },
6    task::{Poll, Waker},
7    thread,
8};
9
10use log::trace;
11use tokio::select;
12use tokio_util::sync::CancellationToken;
13
14use crate::{TError, TaskError, rt};
15
16#[derive(Clone)]
17pub struct Context<E: TError> {
18    id: u32,
19    inner: Arc<Mutex<ContextInner<E>>>,
20    cancel: CancellationToken,
21}
22
23impl<E: TError> Context<E> {
24    pub fn id(&self) -> u32 {
25        self.id
26    }
27
28    pub(crate) fn set_state(&self, state: State) -> Result<(), &'static str> {
29        self.inner.lock().unwrap().set_state(state)
30    }
31
32    pub fn wait_for(&self, state: State) -> FutureTaskState<E> {
33        FutureTaskState::new(self.clone(), state)
34    }
35
36    pub fn stop(&self) -> FutureTaskState<E> {
37        self.stop_with_result(Some(TaskError::Cancelled))
38    }
39
40    pub fn is_active(&self) -> bool {
41        !self.cancel.is_cancelled()
42    }
43
44    pub fn stop_with_result(&self, res: Option<TaskError<E>>) -> FutureTaskState<E> {
45        let fur = self.wait_for(State::Stopped);
46        let mut g = self.inner.lock().unwrap();
47        if g.state >= State::Stopping {
48            return fur;
49        }
50        let _ = g.set_state(State::Stopping);
51        g.error = res;
52        g.wake_all();
53        drop(g);
54        self.cancel.cancel();
55        fur
56    }
57
58    pub fn spawn<F>(&self, fut: F)
59    where
60        F: Future + Send + 'static,
61    {
62        let mut g = self.inner.lock().unwrap();
63        g.spawn(self, fut);
64    }
65
66    pub(crate) fn work_done(&self) {
67        let mut g = self.inner.lock().unwrap();
68        g.work_count -= 1;
69        trace!("[{:>6}] work count {}", self.id, g.work_count);
70        if g.work_count == 0 {
71            let _ = g.set_state(State::Stopped);
72        }
73    }
74}
75
76impl<E: TError> Default for Context<E> {
77    fn default() -> Self {
78        static TASK_ID: AtomicU32 = AtomicU32::new(1);
79        let id = TASK_ID.fetch_add(1, Ordering::SeqCst);
80
81        Self {
82            id,
83            inner: Arc::new(Mutex::new(ContextInner {
84                id,
85                work_count: 1,
86                ..Default::default()
87            })),
88            cancel: CancellationToken::new(),
89        }
90    }
91}
92
93struct ContextInner<E: TError> {
94    error: Option<TaskError<E>>,
95    state: State,
96    wakers: Vec<Waker>,
97    work_count: u32,
98    id: u32,
99}
100
101impl<E: TError> ContextInner<E> {
102    fn wake_all(&mut self) {
103        for waker in self.wakers.iter() {
104            waker.wake_by_ref();
105        }
106        self.wakers.clear();
107    }
108
109    fn set_state(&mut self, state: State) -> Result<(), &'static str> {
110        if state < self.state {
111            return Err("state is not allowed");
112        }
113        trace!("[{:>6}] [{:?}]=>[{:?}]", self.id, self.state, state);
114        self.state = state;
115        self.wake_all();
116        Ok(())
117    }
118
119    fn spawn<F>(&mut self, ctx: &Context<E>, fur: F)
120    where
121        F: Future + Send + 'static,
122    {
123        let ctx = ctx.clone();
124        if matches!(self.state, State::Stopping | State::Stopped) {
125            return;
126        }
127
128        self.work_count += 1;
129        trace!("[{:>6}] work count {}", ctx.id, self.work_count);
130        thread::spawn(move || {
131            rt().block_on(async move {
132                select! {
133                    _ = fur =>{}
134                    _ = ctx.cancel.cancelled() => {}
135                    _ = ctx.wait_for(State::Stopping) => {}
136                }
137                ctx.work_done();
138            });
139        });
140    }
141}
142
143impl<E: TError> Default for ContextInner<E> {
144    fn default() -> Self {
145        Self {
146            id: 0,
147            error: None,
148            state: State::default(),
149            wakers: Default::default(),
150            work_count: 0,
151        }
152    }
153}
154
/// Lifecycle stage of a task context.
///
/// Variants are declared in progression order and derive `Ord`. A comparison
/// such as `state >= State::Stopping` therefore means "has reached or passed
/// `Stopping`". State transitions rely on this ordering to reject backward moves.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)]
pub enum State {
    /// Initial state of a freshly created context.
    #[default]
    Idle,
    Preparing,
    Running,
    /// Entered when a stop is requested.
    Stopping,
    /// Final state; entered once the outstanding work count drops to zero.
    Stopped,
}
169
170pub struct FutureTaskState<E: TError> {
171    ctx: Context<E>,
172    want: State,
173}
174impl<E: TError> FutureTaskState<E> {
175    fn new(ctx: Context<E>, want: State) -> Self {
176        Self { ctx, want }
177    }
178}
179
180impl<E: TError> Future for FutureTaskState<E> {
181    type Output = Result<(), TaskError<E>>;
182
183    fn poll(
184        self: std::pin::Pin<&mut Self>,
185        cx: &mut std::task::Context<'_>,
186    ) -> std::task::Poll<Self::Output> {
187        let mut g = self.ctx.inner.lock().unwrap();
188        if g.state >= self.want {
189            Poll::Ready(match g.error.clone() {
190                Some(e) => Err(e),
191                None => Ok(()),
192            })
193        } else {
194            g.wakers.push(cx.waker().clone());
195            Poll::Pending
196        }
197    }
198}