singleton_task/
context.rs1use std::{
2 sync::{
3 Arc, Mutex,
4 atomic::{AtomicU32, Ordering},
5 },
6 task::{Poll, Waker},
7};
8
9use log::trace;
10use tokio::select;
11use tokio_util::sync::CancellationToken;
12
13use crate::{TError, TaskError};
14
15#[derive(Clone)]
16pub struct Context<E: TError> {
17 id: u32,
18 inner: Arc<Mutex<ContextInner<E>>>,
19 cancel: CancellationToken,
20}
21
22impl<E: TError> Context<E> {
23 pub fn id(&self) -> u32 {
24 self.id
25 }
26
27 pub(crate) fn set_state(&self, state: State) -> Result<(), &'static str> {
28 self.inner.lock().unwrap().set_state(state)
29 }
30
31 pub fn wait_for(&self, state: State) -> FutureTaskState<E> {
32 FutureTaskState::new(self.clone(), state)
33 }
34
35 pub fn stop(&self) -> FutureTaskState<E> {
36 self.stop_with_result(Some(TaskError::Cancelled))
37 }
38
39 pub fn is_active(&self) -> bool {
40 !self.cancel.is_cancelled()
41 }
42
43 pub fn stop_with_result(&self, res: Option<TaskError<E>>) -> FutureTaskState<E> {
44 let fur = self.wait_for(State::Stopped);
45 let mut g = self.inner.lock().unwrap();
46 if g.state >= State::Stopping {
47 return fur;
48 }
49 let _ = g.set_state(State::Stopping);
50 g.error = res;
51 g.wake_all();
52 drop(g);
53 self.cancel.cancel();
54 fur
55 }
56
57 pub fn spawn<F>(&self, fut: F)
58 where
59 F: Future + Send + 'static,
60 {
61 let mut g = self.inner.lock().unwrap();
62 g.spawn(self, fut);
63 }
64
65 pub(crate) fn work_done(&self) {
66 let mut g = self.inner.lock().unwrap();
67 g.work_count -= 1;
68 trace!("[{:>6}] work count {}", self.id, g.work_count);
69 if g.work_count == 1 && g.state == State::Running {
70 let _ = g.set_state(State::Stopping);
71 }
72
73 if g.work_count == 0 {
74 let _ = g.set_state(State::Stopped);
75 }
76 }
77}
78
79impl<E: TError> Default for Context<E> {
80 fn default() -> Self {
81 static TASK_ID: AtomicU32 = AtomicU32::new(1);
82 let id = TASK_ID.fetch_add(1, Ordering::SeqCst);
83
84 Self {
85 id,
86 inner: Arc::new(Mutex::new(ContextInner {
87 id,
88 work_count: 1,
89 ..Default::default()
90 })),
91 cancel: CancellationToken::new(),
92 }
93 }
94}
95
96struct ContextInner<E: TError> {
97 error: Option<TaskError<E>>,
98 state: State,
99 wakers: Vec<Waker>,
100 work_count: u32,
101 id: u32,
102}
103
104impl<E: TError> ContextInner<E> {
105 fn wake_all(&mut self) {
106 for waker in self.wakers.iter() {
107 waker.wake_by_ref();
108 }
109 self.wakers.clear();
110 }
111
112 fn set_state(&mut self, state: State) -> Result<(), &'static str> {
113 if state < self.state {
114 return Err("state is not allowed");
115 }
116 trace!("[{:>6}] [{:?}]=>[{:?}]", self.id, self.state, state);
117 self.state = state;
118 self.wake_all();
119 Ok(())
120 }
121
122 fn spawn<F>(&mut self, ctx: &Context<E>, fur: F)
123 where
124 F: Future + Send + 'static,
125 {
126 let ctx = ctx.clone();
127 if matches!(self.state, State::Stopping | State::Stopped) {
128 return;
129 }
130
131 self.work_count += 1;
132 trace!("[{:>6}] work count {}", ctx.id, self.work_count);
133
134 tokio::spawn(async move {
135 select! {
136 _ = fur =>{}
137 _ = ctx.cancel.cancelled() => {}
138 _ = ctx.wait_for(State::Stopping) => {}
139 }
140 ctx.work_done();
141 });
142 }
143}
144
145impl<E: TError> Default for ContextInner<E> {
146 fn default() -> Self {
147 Self {
148 id: 0,
149 error: None,
150 state: State::default(),
151 wakers: Default::default(),
152 work_count: 0,
153 }
154 }
155}
156
/// Lifecycle of a task context.
///
/// Variants are declared in lifecycle order. The derived `Ord` follows that
/// order, and state transitions rely on it to forbid moving backwards.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum State {
    /// Initial state.
    Idle,
    /// Set outside this file; presumably task start-up — confirm.
    Preparing,
    /// Set outside this file; presumably normal operation — confirm.
    Running,
    /// Entered on a stop request, or when only one unit of work remains while
    /// running. Spawned futures are dropped on entry.
    Stopping,
    /// Entered when no work remains.
    Stopped,
}
165
166impl Default for State {
167 fn default() -> Self {
168 Self::Idle
169 }
170}
171
172pub struct FutureTaskState<E: TError> {
173 ctx: Context<E>,
174 want: State,
175}
176impl<E: TError> FutureTaskState<E> {
177 fn new(ctx: Context<E>, want: State) -> Self {
178 Self { ctx, want }
179 }
180}
181
182impl<E: TError> Future for FutureTaskState<E> {
183 type Output = Result<(), TaskError<E>>;
184
185 fn poll(
186 self: std::pin::Pin<&mut Self>,
187 cx: &mut std::task::Context<'_>,
188 ) -> std::task::Poll<Self::Output> {
189 let mut g = self.ctx.inner.lock().unwrap();
190 if g.state >= self.want {
191 Poll::Ready(match g.error.clone() {
192 Some(e) => Err(e),
193 None => Ok(()),
194 })
195 } else {
196 g.wakers.push(cx.waker().clone());
197 Poll::Pending
198 }
199 }
200}