singleton_task/
context.rs1use std::{
2 sync::{
3 Arc, Mutex,
4 atomic::{AtomicU32, Ordering},
5 },
6 task::{Poll, Waker},
7 thread,
8};
9
10use log::trace;
11use tokio::{runtime::Handle, select};
12use tokio_util::sync::CancellationToken;
13
14use crate::{TError, TaskError};
15
16#[derive(Clone)]
17pub struct Context<E: TError> {
18 id: u32,
19 inner: Arc<Mutex<ContextInner<E>>>,
20 cancel: CancellationToken,
21}
22
23impl<E: TError> Context<E> {
24 pub fn id(&self) -> u32 {
25 self.id
26 }
27
28 pub(crate) fn set_state(&self, state: State) -> Result<(), &'static str> {
29 self.inner.lock().unwrap().set_state(state)
30 }
31
32 pub fn wait_for(&self, state: State) -> FutureTaskState<E> {
33 FutureTaskState::new(self.clone(), state)
34 }
35
36 pub fn stop(&self) -> FutureTaskState<E> {
37 self.stop_with_result(Some(TaskError::Cancelled))
38 }
39
40 pub fn is_active(&self) -> bool {
41 !self.cancel.is_cancelled()
42 }
43
44 pub fn stop_with_result(&self, res: Option<TaskError<E>>) -> FutureTaskState<E> {
45 let fur = self.wait_for(State::Stopped);
46 let mut g = self.inner.lock().unwrap();
47 if g.state >= State::Stopping {
48 return fur;
49 }
50 let _ = g.set_state(State::Stopping);
51 g.error = res;
52 g.wake_all();
53 drop(g);
54 self.cancel.cancel();
55 fur
56 }
57
58 pub fn spawn<F>(&self, fut: F)
59 where
60 F: Future + Send + 'static,
61 {
62 let mut g = self.inner.lock().unwrap();
63 g.spawn(self, fut);
64 }
65
66 pub(crate) fn work_done(&self) {
67 let mut g = self.inner.lock().unwrap();
68 g.work_count -= 1;
69 trace!("[{:>6}] work count {}", self.id, g.work_count);
70 if g.work_count == 1 && g.state == State::Running {
71 let _ = g.set_state(State::Stopping);
72 }
73
74 if g.work_count == 0 {
75 let _ = g.set_state(State::Stopped);
76 }
77 }
78}
79
80impl<E: TError> Default for Context<E> {
81 fn default() -> Self {
82 static TASK_ID: AtomicU32 = AtomicU32::new(1);
83 let id = TASK_ID.fetch_add(1, Ordering::SeqCst);
84
85 Self {
86 id,
87 inner: Arc::new(Mutex::new(ContextInner {
88 id,
89 work_count: 1,
90 ..Default::default()
91 })),
92 cancel: CancellationToken::new(),
93 }
94 }
95}
96
97struct ContextInner<E: TError> {
98 error: Option<TaskError<E>>,
99 state: State,
100 wakers: Vec<Waker>,
101 work_count: u32,
102 id: u32,
103}
104
105impl<E: TError> ContextInner<E> {
106 fn wake_all(&mut self) {
107 for waker in self.wakers.iter() {
108 waker.wake_by_ref();
109 }
110 self.wakers.clear();
111 }
112
113 fn set_state(&mut self, state: State) -> Result<(), &'static str> {
114 if state < self.state {
115 return Err("state is not allowed");
116 }
117 trace!("[{:>6}] [{:?}]=>[{:?}]", self.id, self.state, state);
118 self.state = state;
119 self.wake_all();
120 Ok(())
121 }
122
123 fn spawn<F>(&mut self, ctx: &Context<E>, fur: F)
124 where
125 F: Future + Send + 'static,
126 {
127 let ctx = ctx.clone();
128 if matches!(self.state, State::Stopping | State::Stopped) {
129 return;
130 }
131
132 self.work_count += 1;
133 trace!("[{:>6}] work count {}", ctx.id, self.work_count);
134 let handle = Handle::current();
135 thread::spawn(move || {
136 handle.block_on(async move {
137 select! {
138 _ = fur =>{
139 trace!("[{:>6}] exit: finish", ctx.id);
140 }
141 _ = ctx.cancel.cancelled() => {
142 trace!("[{:>6}] exit: cancel token", ctx.id);
143 }
144 _ = ctx.wait_for(State::Stopping) => {
145 trace!("[{:>6}] exit: stopping", ctx.id);
146 }
147 }
148 ctx.work_done();
149 })
150 });
151 }
152}
153
154impl<E: TError> Default for ContextInner<E> {
155 fn default() -> Self {
156 Self {
157 id: 0,
158 error: None,
159 state: State::default(),
160 wakers: Default::default(),
161 work_count: 0,
162 }
163 }
164}
165
/// Lifecycle phase of a task.
///
/// Variants are listed in lifecycle order. The derived `Ord` follows declaration order, and
/// state transitions are validated by comparing states, so the order here is significant.
/// The default is `Idle`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum State {
    #[default]
    Idle,
    Preparing,
    Running,
    Stopping,
    Stopped,
}
180
181pub struct FutureTaskState<E: TError> {
182 ctx: Context<E>,
183 want: State,
184}
185impl<E: TError> FutureTaskState<E> {
186 fn new(ctx: Context<E>, want: State) -> Self {
187 Self { ctx, want }
188 }
189}
190
191impl<E: TError> Future for FutureTaskState<E> {
192 type Output = Result<(), TaskError<E>>;
193
194 fn poll(
195 self: std::pin::Pin<&mut Self>,
196 cx: &mut std::task::Context<'_>,
197 ) -> std::task::Poll<Self::Output> {
198 let mut g = self.ctx.inner.lock().unwrap();
199 if g.state >= self.want {
200 Poll::Ready(match g.error.clone() {
201 Some(e) => Err(e),
202 None => Ok(()),
203 })
204 } else {
205 g.wakers.push(cx.waker().clone());
206 Poll::Pending
207 }
208 }
209}