1use std::fmt::Debug;
2use std::future::Future;
3use std::hash::Hash;
4use std::marker::Unpin;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::task::Poll;
8
9use futures::{Sink, SinkExt, Stream, StreamExt};
10use futures::channel::mpsc;
11use futures::task::AtomicWaker;
12use parking_lot::Mutex;
13use parking_lot::RwLock;
14#[cfg(feature = "rate")]
15use update_rate::{DiscreteRateCounter, RateCounter};
16
17use queue_ext::{Action, QueueExt, Reply};
18
19use super::{
20 assert_future, close::Close, Counter, Error, ErrorType, flush::Flush, GroupTaskQueue, IndexSet,
21 PendingOnce, Spawner,
22};
23
24type DashMap<K, V> = dashmap::DashMap<K, V, ahash::RandomState>;
25type GroupChannels<G> = Arc<DashMap<G, Arc<Mutex<GroupTaskQueue<TaskType>>>>>;
26
27pub type TaskType = Box<dyn std::future::Future<Output=()> + Send + 'static + Unpin>;
28
29pub struct Executor<Tx = mpsc::Sender<((), TaskType)>, G = (), D = ()> {
30 pub(crate) tx: Tx,
31 workers: usize,
32 queue_max: isize,
33 active_count: Counter,
34 pub(crate) waiting_count: Counter,
35 completed_count: Counter,
36 #[cfg(feature = "rate")]
37 rate_counter: Arc<RwLock<DiscreteRateCounter>>,
38 flush_waker: Arc<AtomicWaker>,
39 is_flushing: Arc<AtomicBool>,
40 is_closed: Arc<AtomicBool>,
41
42 group_channels: GroupChannels<G>,
44 _d: std::marker::PhantomData<D>,
45}
46
47impl<Tx, G, D> Clone for Executor<Tx, G, D>
48 where
49 Tx: Clone,
50{
51 #[inline]
52 fn clone(&self) -> Self {
53 Self {
54 tx: self.tx.clone(),
55 workers: self.workers,
56 queue_max: self.queue_max,
57 active_count: self.active_count.clone(),
58 waiting_count: self.waiting_count.clone(),
59 completed_count: self.completed_count.clone(),
60 #[cfg(feature = "rate")]
61 rate_counter: self.rate_counter.clone(),
62 flush_waker: self.flush_waker.clone(),
63 is_flushing: self.is_flushing.clone(),
64 is_closed: self.is_closed.clone(),
65 group_channels: self.group_channels.clone(),
66 _d: std::marker::PhantomData,
67 }
68 }
69}
70
71impl<Tx, G, D> Executor<Tx, G, D>
72 where
73 Tx: Clone + Sink<(D, TaskType)> + Unpin + Send + Sync + 'static,
74 G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
75{
76 #[inline]
77 pub(crate) fn with_channel<Rx>(
78 workers: usize,
79 queue_max: usize,
80 tx: Tx,
81 rx: Rx,
82 ) -> (Self, impl Future<Output=()>)
83 where
84 Rx: Stream<Item=(D, TaskType)> + Unpin,
85 {
86 let exec = Self {
87 tx,
88 workers,
89 queue_max: queue_max as isize,
90 active_count: Counter::new(),
91 waiting_count: Counter::new(),
92 completed_count: Counter::new(),
93 #[cfg(feature = "rate")]
94 rate_counter: Arc::new(RwLock::new(DiscreteRateCounter::new(100))),
95 flush_waker: Arc::new(AtomicWaker::new()),
96 is_flushing: Arc::new(AtomicBool::new(false)),
97 is_closed: Arc::new(AtomicBool::new(false)),
98 group_channels: Arc::new(DashMap::default()),
99 _d: std::marker::PhantomData,
100 };
101 let runner = exec.clone().run(rx);
102 (exec, runner)
103 }
104
105 #[inline]
106 pub fn spawn_with<T>(&mut self, msg: T, name: D) -> Spawner<'_, T, Tx, G, D>
107 where
108 D: Clone,
109 T: Future + Send + 'static,
110 T::Output: Send + 'static,
111 {
112 let fut = Spawner::new(self, msg, name);
113 assert_future::<Result<(), _>, _>(fut)
114 }
115
116 #[inline]
117 pub fn flush(&self) -> Flush<Tx, D> {
118 self.is_flushing.store(true, Ordering::SeqCst);
119 Flush::new(
120 self.tx.clone(),
121 self.waiting_count.clone(),
122 self.active_count.clone(),
123 self.is_flushing.clone(),
124 self.flush_waker.clone(),
125 )
126 }
127
128 #[inline]
129 pub fn close(&self) -> Close<Tx, D> {
130 self.is_flushing.store(true, Ordering::SeqCst);
131 self.is_closed.store(true, Ordering::SeqCst);
132 Close::new(
133 self.tx.clone(),
134 self.waiting_count.clone(),
135 self.active_count.clone(),
136 self.is_flushing.clone(),
137 self.flush_waker.clone(),
138 )
139 }
140
141 #[inline]
142 pub fn workers(&self) -> usize {
143 self.workers
144 }
145
146 #[inline]
147 pub fn active_count(&self) -> isize {
148 self.active_count.value()
149 }
150
151 #[inline]
152 pub fn waiting_count(&self) -> isize {
153 self.waiting_count.value()
154 }
155
156 #[inline]
157 pub fn completed_count(&self) -> isize {
158 self.completed_count.value()
159 }
160
161 #[inline]
162 #[cfg(feature = "rate")]
163 pub fn rate(&self) -> f64 {
164 self.rate_counter.read().rate()
165 }
166
167 #[inline]
168 pub fn is_full(&self) -> bool {
169 self.waiting_count() >= self.queue_max
170 }
171
172 #[inline]
173 pub fn is_closed(&self) -> bool {
174 self.is_closed.load(Ordering::SeqCst)
175 }
176
177 #[inline]
178 pub fn is_flushing(&self) -> bool {
179 self.is_flushing.load(Ordering::SeqCst)
180 }
181
182 async fn run<Rx>(self, mut task_rx: Rx)
183 where
184 Rx: Stream<Item=(D, TaskType)> + Unpin,
185 {
186 let exec = self;
187 let idle_waker = Arc::new(AtomicWaker::new());
188
189 let channel = || {
190 let rx = OneValue::new().queue_stream(|s, _| match s.take() {
191 None => Poll::Pending,
192 Some(m) => Poll::Ready(Some(m)),
193 });
194
195 let tx = rx.clone().queue_sender(|s, action| match action {
196 Action::Send(item) => Reply::Send(s.set(item)),
197 Action::IsFull => Reply::IsFull(s.is_full()),
198 Action::IsEmpty => Reply::IsEmpty(s.is_empty()),
199 });
200
201 (tx, rx)
202 };
203
204 let idle_idxs = IndexSet::new();
205 let mut txs = Vec::new();
206 let mut rxs = Vec::new();
207 for i in 0..exec.workers {
208 let (tx, mut rx) = channel();
209 let idle_waker = idle_waker.clone();
210 let idle_idxs = idle_idxs.clone();
211 idle_idxs.insert(i);
212 let exec = exec.clone();
213 let rx_fut = async move {
214 loop {
215 match rx.next().await {
216 Some(task) => {
217 exec.active_count.inc();
218 task.await;
219 exec.completed_count.inc();
220 exec.active_count.dec();
221 #[cfg(feature = "rate")]
222 exec.rate_counter.write().update();
223 }
224 None => break,
225 }
226
227 if !rx.is_full() {
228 idle_idxs.insert(i);
229 idle_waker.wake();
230 }
231
232 if exec.is_flushing() && rx.is_empty() {
233 exec.flush_waker.wake();
234 }
235 }
236 };
237
238 txs.push(tx);
239 rxs.push(rx_fut);
240 }
241
242 let tasks_bus = async move {
243 while let Some((_, task)) = task_rx.next().await {
244 loop {
245 if idle_idxs.is_empty() {
246 PendingOnce::new(idle_waker.clone()).await;
248 } else if let Some(idx) = idle_idxs.pop() {
249 if let Some(tx) = txs.get_mut(idx) {
251 if let Err(_t) = tx.send(task).await {
252 log::error!("send error ...");
253 }
255 }
256 break;
257 };
258 }
259 }
260 };
261
262 futures::future::join(tasks_bus, futures::future::join_all(rxs)).await;
263 log::info!("exit task executor");
264 }
265}
266
267impl<Tx, G> Executor<Tx, G, ()>
268 where
269 Tx: Clone + Sink<((), TaskType)> + Unpin + Send + Sync + 'static,
270 G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
271{
272 #[inline]
273 pub fn spawn<T>(&mut self, msg: T) -> Spawner<'_, T, Tx, G, ()>
274 where
275 T: Future + Send + 'static,
276 T::Output: Send + 'static,
277 {
278 let fut = Spawner::new(self, msg, ());
279 assert_future::<Result<(), _>, _>(fut)
280 }
281
282 #[inline]
283 pub(crate) async fn group_send(&self, name: G, task: TaskType) -> Result<(), Error<TaskType>> {
284 if self.is_closed() {
285 return Err(Error::SendError(ErrorType::Closed(Some(task))));
286 }
287
288 let gt_queue = self
289 .group_channels
290 .entry(name.clone())
291 .or_insert_with(|| Arc::new(Mutex::new(GroupTaskQueue::new())))
292 .value()
293 .clone();
294
295 let exec = self.clone();
296 let group_channels = self.group_channels.clone();
297 let runner_task = {
298 let mut task_tx = gt_queue.lock();
299 if task_tx.is_running() {
300 task_tx.push(task);
301 drop(task_tx);
302 drop(gt_queue);
303 None
304 } else {
305 task_tx.set_running(true);
306 drop(task_tx);
307 let task_rx = gt_queue; let runner_task = async move {
309 exec.active_count.inc();
310 task.await;
311 exec.active_count.dec();
312 loop {
313 let task: Option<TaskType> = task_rx.lock().pop();
314 if let Some(task) = task {
315 exec.active_count.inc();
316 task.await;
317 exec.completed_count.inc();
318 exec.active_count.dec();
319 } else {
320 group_channels.remove(&name);
321 break;
322 }
323 }
324 };
325 Some(runner_task)
326 }
327 };
328
329 if let Some(runner_task) = runner_task {
330 if (self
331 .tx
332 .clone()
333 .send(((), Box::new(Box::pin(runner_task))))
334 .await)
335 .is_err()
336 {
337 Err(Error::SendError(ErrorType::Closed(None)))
338 } else {
339 Ok(())
340 }
341 } else {
342 Ok(())
343 }
344 }
345}
346
347#[derive(Clone)]
348struct OneValue(Arc<RwLock<Option<TaskType>>>);
349
350unsafe impl Sync for OneValue {}
351
352unsafe impl Send for OneValue {}
353
354impl OneValue {
355 #[inline]
356 fn new() -> Self {
357 Self(Arc::new(RwLock::new(None)))
358 }
359
360 #[inline]
361 fn set(&self, val: TaskType) -> Option<TaskType> {
362 self.0.write().replace(val)
363 }
364
365 #[inline]
366 fn take(&self) -> Option<TaskType> {
367 self.0.write().take()
368 }
369
370 #[inline]
371 fn is_full(&self) -> bool {
372 self.0.read().is_some()
373 }
374
375 fn is_empty(&self) -> bool {
376 self.0.read().is_none()
377 }
378}