singleton_task/
context.rs1use std::{
2 sync::{
3 Arc, Mutex,
4 atomic::{AtomicU32, Ordering},
5 },
6 task::{Poll, Waker},
7 thread,
8};
9
10use log::trace;
11use tokio::{runtime::Handle, select};
12use tokio_util::sync::CancellationToken;
13
14use crate::{TError, TaskError};
15
16#[derive(Clone)]
17pub struct Context<E: TError> {
18 id: u32,
19 inner: Arc<Mutex<ContextInner<E>>>,
20 cancel: CancellationToken,
21}
22
23impl<E: TError> Context<E> {
24 pub fn id(&self) -> u32 {
25 self.id
26 }
27
28 pub(crate) fn set_state(&self, state: State) -> Result<(), &'static str> {
29 self.inner.lock().unwrap().set_state(state)
30 }
31
32 pub fn wait_for(&self, state: State) -> FutureTaskState<E> {
33 FutureTaskState::new(self.clone(), state)
34 }
35
36 pub fn stop(&self) -> FutureTaskState<E> {
37 self.stop_with_result(Some(TaskError::Cancelled))
38 }
39
40 pub fn is_active(&self) -> bool {
41 !self.cancel.is_cancelled()
42 }
43
44 pub fn stop_with_result(&self, res: Option<TaskError<E>>) -> FutureTaskState<E> {
45 let fur = self.wait_for(State::Stopped);
46 let mut g = self.inner.lock().unwrap();
47 if g.state >= State::Stopping {
48 return fur;
49 }
50 let _ = g.set_state(State::Stopping);
51 g.error = res;
52 g.wake_all();
53 drop(g);
54 self.cancel.cancel();
55 fur
56 }
57
58 pub fn spawn<F>(&self, fut: F)
59 where
60 F: Future + Send + 'static,
61 {
62 let mut g = self.inner.lock().unwrap();
63 g.spawn(self, fut);
64 }
65
66 pub(crate) fn work_done(&self) {
67 let mut g = self.inner.lock().unwrap();
68 g.work_count -= 1;
69 trace!("[{:>6}] work count {}", self.id, g.work_count);
70 if g.work_count == 1 && g.state == State::Running {
71 let _ = g.set_state(State::Stopping);
72 }
73
74 if g.work_count == 0 {
75 let _ = g.set_state(State::Stopped);
76 }
77 }
78}
79
80impl<E: TError> Default for Context<E> {
81 fn default() -> Self {
82 static TASK_ID: AtomicU32 = AtomicU32::new(1);
83 let id = TASK_ID.fetch_add(1, Ordering::SeqCst);
84
85 Self {
86 id,
87 inner: Arc::new(Mutex::new(ContextInner {
88 id,
89 work_count: 1,
90 ..Default::default()
91 })),
92 cancel: CancellationToken::new(),
93 }
94 }
95}
96
97struct ContextInner<E: TError> {
98 error: Option<TaskError<E>>,
99 state: State,
100 wakers: Vec<Waker>,
101 work_count: u32,
102 id: u32,
103}
104
105impl<E: TError> ContextInner<E> {
106 fn wake_all(&mut self) {
107 for waker in self.wakers.iter() {
108 waker.wake_by_ref();
109 }
110 self.wakers.clear();
111 }
112
113 fn set_state(&mut self, state: State) -> Result<(), &'static str> {
114 if state < self.state {
115 return Err("state is not allowed");
116 }
117 trace!("[{:>6}] [{:?}]=>[{:?}]", self.id, self.state, state);
118 self.state = state;
119 self.wake_all();
120 Ok(())
121 }
122
123 fn spawn<F>(&mut self, ctx: &Context<E>, fur: F)
124 where
125 F: Future + Send + 'static,
126 {
127 let ctx = ctx.clone();
128 if matches!(self.state, State::Stopping | State::Stopped) {
129 return;
130 }
131
132 self.work_count += 1;
133 trace!("[{:>6}] work count {}", ctx.id, self.work_count);
134 let handle = Handle::current();
135 thread::spawn(move || {
136 handle.block_on(async move {
137 select! {
138 _ = fur =>{}
139 _ = ctx.cancel.cancelled() => {}
140 _ = ctx.wait_for(State::Stopping) => {}
141 }
142 ctx.work_done();
143 })
144 });
145 }
146}
147
148impl<E: TError> Default for ContextInner<E> {
149 fn default() -> Self {
150 Self {
151 id: 0,
152 error: None,
153 state: State::default(),
154 wakers: Default::default(),
155 work_count: 0,
156 }
157 }
158}
159
/// Lifecycle stages of a task, declared in the order they are entered.
///
/// The derived `Ord` follows declaration order
/// (`Idle < Preparing < Running < Stopping < Stopped`). State changes are rejected
/// if they would move backwards in that order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum State {
    /// Initial stage of a new task; the `Default` value.
    #[default]
    Idle,
    Preparing,
    Running,
    /// A stop has begun; outstanding work is being wound down.
    Stopping,
    /// Final stage, entered once no work remains.
    Stopped,
}
174
175pub struct FutureTaskState<E: TError> {
176 ctx: Context<E>,
177 want: State,
178}
179impl<E: TError> FutureTaskState<E> {
180 fn new(ctx: Context<E>, want: State) -> Self {
181 Self { ctx, want }
182 }
183}
184
185impl<E: TError> Future for FutureTaskState<E> {
186 type Output = Result<(), TaskError<E>>;
187
188 fn poll(
189 self: std::pin::Pin<&mut Self>,
190 cx: &mut std::task::Context<'_>,
191 ) -> std::task::Poll<Self::Output> {
192 let mut g = self.ctx.inner.lock().unwrap();
193 if g.state >= self.want {
194 Poll::Ready(match g.error.clone() {
195 Some(e) => Err(e),
196 None => Ok(()),
197 })
198 } else {
199 g.wakers.push(cx.waker().clone());
200 Poll::Pending
201 }
202 }
203}