// singleton_task/context.rs
use std::{
    sync::{
        Arc, Mutex,
        atomic::{AtomicU32, Ordering},
    },
    task::{Poll, Waker},
    thread,
};

use log::trace;
use tokio::select;

use crate::{TError, TaskError, rt};
14
15#[derive(Clone)]
16pub struct Context<E: TError> {
17 id: u32,
18 inner: Arc<Mutex<ContextInner<E>>>,
19}
20
21impl<E: TError> Context<E> {
22 pub fn id(&self) -> u32 {
23 self.id
24 }
25
26 pub(crate) fn set_state(&self, state: State) -> Result<(), &'static str> {
27 self.inner.lock().unwrap().set_state(state)
28 }
29
30 pub fn wait_for(&self, state: State) -> FutureTaskState<E> {
31 FutureTaskState::new(self.clone(), state)
32 }
33
34 pub fn stop(&self) -> FutureTaskState<E> {
35 self.stop_with_result(Some(TaskError::Cancelled))
36 }
37
38 pub fn stop_with_result(&self, res: Option<TaskError<E>>) -> FutureTaskState<E> {
39 let fur = self.wait_for(State::Stopped);
40 let mut g = self.inner.lock().unwrap();
41 if g.state >= State::Stopping {
42 return fur;
43 }
44 let _ = g.set_state(State::Stopping);
45 g.error = res;
46 g.wake_all();
47 fur
48 }
49
50 pub fn spawn<F>(&self, fut: F)
51 where
52 F: Future + Send + 'static,
53 {
54 let mut g = self.inner.lock().unwrap();
55 g.spawn(self, fut);
56 }
57
58 pub(crate) fn work_done(&self) {
59 let mut g = self.inner.lock().unwrap();
60 g.work_count -= 1;
61 trace!("[{:>6}] work count {}", self.id, g.work_count);
62 if g.work_count == 0 {
63 let _ = g.set_state(State::Stopped);
64 }
65 }
66}
67
68impl<E: TError> Default for Context<E> {
69 fn default() -> Self {
70 static TASK_ID: AtomicU32 = AtomicU32::new(1);
71 let id = TASK_ID.fetch_add(1, Ordering::SeqCst);
72
73 Self {
74 id,
75 inner: Arc::new(Mutex::new(ContextInner {
76 id,
77 work_count: 1,
78 ..Default::default()
79 })),
80 }
81 }
82}
83
84struct ContextInner<E: TError> {
85 error: Option<TaskError<E>>,
86 state: State,
87 wakers: Vec<Waker>,
88 work_count: u32,
89 id: u32,
90}
91
92impl<E: TError> ContextInner<E> {
93 fn wake_all(&mut self) {
94 for waker in self.wakers.iter() {
95 waker.wake_by_ref();
96 }
97 self.wakers.clear();
98 }
99
100 fn set_state(&mut self, state: State) -> Result<(), &'static str> {
101 if state < self.state {
102 return Err("state is not allowed");
103 }
104 trace!("[{:>6}] [{:?}]=>[{:?}]", self.id, self.state, state);
105 self.state = state;
106 self.wake_all();
107 Ok(())
108 }
109
110 fn spawn<F>(&mut self, ctx: &Context<E>, fur: F)
111 where
112 F: Future + Send + 'static,
113 {
114 let ctx = ctx.clone();
115 if matches!(self.state, State::Stopping | State::Stopped) {
116 return;
117 }
118
119 self.work_count += 1;
120 trace!("[{:>6}] work count {}", ctx.id, self.work_count);
121 thread::spawn(move || {
122 rt().block_on(async move {
123 select! {
124 _ = fur =>{}
125 _ = ctx.wait_for(State::Stopping) => {}
126 }
127 ctx.work_done();
128 });
129 });
130 }
131}
132
133impl<E: TError> Default for ContextInner<E> {
134 fn default() -> Self {
135 Self {
136 id: 0,
137 error: None,
138 state: State::default(),
139 wakers: Default::default(),
140 work_count: 0,
141 }
142 }
143}
144
/// Lifecycle states of a task context, declared in lifecycle order.
///
/// The derived ordering follows declaration order. `set_state` relies on it to
/// refuse moving backwards, and `FutureTaskState` treats any state at or past
/// its target as reached.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default)]
pub enum State {
    /// Initial state of a freshly created context.
    #[default]
    Idle,
    Preparing,
    Running,
    /// Stop requested. Work started via `spawn` is raced against reaching
    /// this state.
    Stopping,
    /// Terminal state. `work_done` enters it when the work count reaches zero.
    Stopped,
}
159
160pub struct FutureTaskState<E: TError> {
161 ctx: Context<E>,
162 want: State,
163}
164impl<E: TError> FutureTaskState<E> {
165 fn new(ctx: Context<E>, want: State) -> Self {
166 Self { ctx, want }
167 }
168}
169
170impl<E: TError> Future for FutureTaskState<E> {
171 type Output = Result<(), TaskError<E>>;
172
173 fn poll(
174 self: std::pin::Pin<&mut Self>,
175 cx: &mut std::task::Context<'_>,
176 ) -> std::task::Poll<Self::Output> {
177 let mut g = self.ctx.inner.lock().unwrap();
178 if g.state >= self.want {
179 Poll::Ready(match g.error.clone() {
180 Some(e) => Err(e),
181 None => Ok(()),
182 })
183 } else {
184 g.wakers.push(cx.waker().clone());
185 Poll::Pending
186 }
187 }
188}