1use std::fmt::Debug;
2use std::future::Future;
3use std::hash::Hash;
4use std::marker::Unpin;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::task::Poll;
8
9use futures::{Sink, SinkExt, Stream, StreamExt};
10use futures::task::AtomicWaker;
11use parking_lot::Mutex;
12use parking_lot::RwLock;
13#[cfg(feature = "rate")]
14use update_rate::{DiscreteRateCounter, RateCounter};
15
16use queue_ext::{Action, QueueExt, Reply};
17
18use super::{
19 assert_future, close::Close, Counter, Error, ErrorType, flush::Flush,
20 GroupTaskQueue, IndexSet, local_builder::SyncSender, LocalSpawner, PendingOnce,
21};
22
23type DashMap<K, V> = dashmap::DashMap<K, V, ahash::RandomState>;
24type GroupChannels<G> = Arc<DashMap<G, Arc<Mutex<GroupTaskQueue<LocalTaskType>>>>>;
25
26pub type LocalTaskType = Box<dyn std::future::Future<Output=()> + 'static + Unpin>;
27
/// Executor for `LocalTaskType` futures (boxed futures without a `Send`
/// bound). Tasks are executed by `workers` worker loops driven by the runner
/// future returned from `with_channel`.
///
/// Cloning yields another handle onto the same shared state (counters, flags,
/// wakers and group table).
pub struct LocalExecutor<Tx = SyncSender, G = (), D = ()> {
    /// Sender half of the task channel; `group_send` submits group runners
    /// through it and `flush`/`close` hand a clone to `Flush`/`Close`.
    pub(crate) tx: Tx,
    /// Number of worker loops started by `run`.
    workers: usize,
    /// Threshold compared against `waiting_count` by `is_full`.
    queue_max: isize,
    /// Incremented before a task is awaited and decremented after it finishes.
    active_count: Counter,
    /// Read by `is_full` and passed to `Flush`/`Close`. It is not modified in
    /// this file — NOTE(review): presumably maintained by the spawner/sender
    /// side; confirm.
    pub(crate) waiting_count: Counter,
    /// Incremented each time a worker (or a group runner) finishes a task.
    completed_count: Counter,
    /// Completion-rate tracker, updated by workers and read by `rate()`.
    #[cfg(feature = "rate")]
    rate_counter: Arc<RwLock<DiscreteRateCounter>>,
    /// Woken by workers while flushing, so a pending `Flush`/`Close` re-checks.
    flush_waker: Arc<AtomicWaker>,
    /// Set by `flush()` and `close()`.
    is_flushing: Arc<AtomicBool>,
    /// Set by `close()`; `group_send` rejects new tasks once it is `true`.
    is_closed: Arc<AtomicBool>,

    /// Per-group serial queues used by `group_send`.
    group_channels: GroupChannels<G>,
    /// Marks the task-name type `D` that travels with each task on the channel.
    _d: std::marker::PhantomData<D>,
}
45
46impl<Tx, G, D> Clone for LocalExecutor<Tx, G, D>
47 where
48 Tx: Clone,
49{
50 #[inline]
51 fn clone(&self) -> Self {
52 Self {
53 tx: self.tx.clone(),
54 workers: self.workers,
55 queue_max: self.queue_max,
56 active_count: self.active_count.clone(),
57 waiting_count: self.waiting_count.clone(),
58 completed_count: self.completed_count.clone(),
59 #[cfg(feature = "rate")]
60 rate_counter: self.rate_counter.clone(),
61 flush_waker: self.flush_waker.clone(),
62 is_flushing: self.is_flushing.clone(),
63 is_closed: self.is_closed.clone(),
64 group_channels: self.group_channels.clone(),
65 _d: std::marker::PhantomData,
66 }
67 }
68}
69
70impl<Tx, G, D> LocalExecutor<Tx, G, D>
71 where
72 Tx: Clone + Sink<(D, LocalTaskType)> + Unpin + Sync + 'static,
73 G: Hash + Eq + Clone + Debug + Sync + 'static,
74{
75 #[inline]
76 pub(crate) fn with_channel<Rx>(
77 workers: usize,
78 queue_max: usize,
79 tx: Tx,
80 rx: Rx,
81 ) -> (Self, impl Future<Output=()>)
82 where
83 Rx: Stream<Item=(D, LocalTaskType)> + Unpin,
84 {
85 let exec = Self {
86 tx,
87 workers,
88 queue_max: queue_max as isize,
89 active_count: Counter::new(),
90 waiting_count: Counter::new(),
91 completed_count: Counter::new(),
92 #[cfg(feature = "rate")]
93 rate_counter: Arc::new(RwLock::new(DiscreteRateCounter::new(100))),
94 flush_waker: Arc::new(AtomicWaker::new()),
95 is_flushing: Arc::new(AtomicBool::new(false)),
96 is_closed: Arc::new(AtomicBool::new(false)),
97 group_channels: Arc::new(DashMap::default()),
98 _d: std::marker::PhantomData,
99 };
100 let runner = exec.clone().run(rx);
101 (exec, runner)
102 }
103
104 #[inline]
105 pub fn spawn_with<T>(&mut self, msg: T, name: D) -> LocalSpawner<'_, T, Tx, G, D>
106 where
107 D: Clone,
108 T: Future + 'static,
109 T::Output: 'static,
110 {
111 let fut = LocalSpawner::new(self, msg, name);
112 assert_future::<Result<(), _>, _>(fut)
113 }
114
115 #[inline]
116 pub fn flush(&self) -> Flush<Tx, D> {
117 self.is_flushing.store(true, Ordering::SeqCst);
118 Flush::new(
119 self.tx.clone(),
120 self.waiting_count.clone(),
121 self.active_count.clone(),
122 self.is_flushing.clone(),
123 self.flush_waker.clone(),
124 )
125 }
126
127 #[inline]
128 pub fn close(&self) -> Close<Tx, D> {
129 self.is_flushing.store(true, Ordering::SeqCst);
130 self.is_closed.store(true, Ordering::SeqCst);
131 Close::new(
132 self.tx.clone(),
133 self.waiting_count.clone(),
134 self.active_count.clone(),
135 self.is_flushing.clone(),
136 self.flush_waker.clone(),
137 )
138 }
139
140 #[inline]
141 pub fn workers(&self) -> usize {
142 self.workers
143 }
144
145 #[inline]
146 pub fn active_count(&self) -> isize {
147 self.active_count.value()
148 }
149
150 #[inline]
151 pub fn waiting_count(&self) -> isize {
152 self.waiting_count.value()
153 }
154
155 #[inline]
156 pub fn completed_count(&self) -> isize {
157 self.completed_count.value()
158 }
159
160 #[inline]
161 #[cfg(feature = "rate")]
162 pub fn rate(&self) -> f64 {
163 self.rate_counter.read().rate()
164 }
165
166 #[inline]
167 pub fn is_full(&self) -> bool {
168 self.waiting_count() >= self.queue_max
169 }
170
171 #[inline]
172 pub fn is_closed(&self) -> bool {
173 self.is_closed.load(Ordering::SeqCst)
174 }
175
176 #[inline]
177 pub fn is_flushing(&self) -> bool {
178 self.is_flushing.load(Ordering::SeqCst)
179 }
180
181 async fn run<Rx>(self, mut task_rx: Rx)
182 where
183 Rx: Stream<Item=(D, LocalTaskType)> + Unpin,
184 {
185 let exec = self;
186 let idle_waker = Arc::new(AtomicWaker::new());
187
188 let channel = || {
189 let rx = OneValue::new().queue_stream(|s, _| match s.take() {
190 None => Poll::Pending,
191 Some(m) => Poll::Ready(Some(m)),
192 });
193
194 let tx = rx.clone().queue_sender(|s, action| match action {
195 Action::Send(item) => Reply::Send(s.set(item)),
196 Action::IsFull => Reply::IsFull(s.is_full()),
197 Action::IsEmpty => Reply::IsEmpty(s.is_empty()),
198 });
199
200 (tx, rx)
201 };
202
203 let idle_idxs = IndexSet::new();
204 let mut txs = Vec::new();
205 let mut rxs = Vec::new();
206 for i in 0..exec.workers {
207 let (tx, mut rx) = channel();
208 let idle_waker = idle_waker.clone();
209 let idle_idxs = idle_idxs.clone();
210 idle_idxs.insert(i);
211 let exec = exec.clone();
212 let rx_fut = async move {
213 loop {
214 match rx.next().await {
215 Some(task) => {
216 exec.active_count.inc();
217 task.await;
218 exec.completed_count.inc();
219 exec.active_count.dec();
220 #[cfg(feature = "rate")]
221 exec.rate_counter.write().update();
222 }
223 None => break,
224 }
225
226 if !rx.is_full() {
227 idle_idxs.insert(i);
228 idle_waker.wake();
229 }
230
231 if exec.is_flushing() && rx.is_empty() {
232 exec.flush_waker.wake();
233 }
234 }
235 };
236
237 txs.push(tx);
238 rxs.push(rx_fut);
239 }
240
241 let tasks_bus = async move {
242 while let Some((_, task)) = task_rx.next().await {
243 loop {
244 if idle_idxs.is_empty() {
245 PendingOnce::new(idle_waker.clone()).await;
247 } else if let Some(idx) = idle_idxs.pop() {
248 if let Some(tx) = txs.get_mut(idx) {
250 if let Err(_t) = tx.send(task).await {
251 log::error!("send error ...");
252 }
254 }
255 break;
256 };
257 }
258 }
259 };
260
261 futures::future::join(tasks_bus, futures::future::join_all(rxs)).await;
262 log::info!("exit local task executor");
263 }
264}
265
266impl<Tx, G> LocalExecutor<Tx, G, ()>
267 where
268 Tx: Clone + Sink<((), LocalTaskType)> + Unpin + Sync + 'static,
269 G: Hash + Eq + Clone + Debug + Sync + 'static,
270{
271 #[inline]
272 pub fn spawn<T>(&mut self, msg: T) -> LocalSpawner<'_, T, Tx, G, ()>
273 where
274 T: Future + 'static,
275 T::Output: 'static,
276 {
277 let fut = LocalSpawner::new(self, msg, ());
278 assert_future::<Result<(), _>, _>(fut)
279 }
280
281 #[inline]
282 pub(crate) async fn group_send(
283 &self,
284 name: G,
285 task: LocalTaskType,
286 ) -> Result<(), Error<LocalTaskType>> {
287 if self.is_closed() {
288 return Err(Error::SendError(ErrorType::Closed(Some(task))));
289 }
290
291 let gt_queue = self
292 .group_channels
293 .entry(name.clone())
294 .or_insert_with(|| Arc::new(Mutex::new(GroupTaskQueue::new())))
295 .value()
296 .clone();
297
298 let exec = self.clone();
299 let group_channels = self.group_channels.clone();
300 let runner_task = {
301 let mut task_tx = gt_queue.lock();
302 if task_tx.is_running() {
303 task_tx.push(task);
304 drop(task_tx);
305 drop(gt_queue);
306 None
307 } else {
308 task_tx.set_running(true);
309 drop(task_tx);
310 let task_rx = gt_queue; let runner_task = async move {
312 exec.active_count.inc();
313 task.await;
314 exec.active_count.dec();
315 loop {
316 let task: Option<LocalTaskType> = task_rx.lock().pop();
317 if let Some(task) = task {
318 exec.active_count.inc();
319 task.await;
320 exec.completed_count.inc();
321 exec.active_count.dec();
322 } else {
323 group_channels.remove(&name);
324 break;
325 }
326 }
327 };
328 Some(runner_task)
329 }
330 };
331
332 if let Some(runner_task) = runner_task {
333 if (self
334 .tx
335 .clone()
336 .send(((), Box::new(Box::pin(runner_task))))
337 .await)
338 .is_err()
339 {
340 Err(Error::SendError(ErrorType::Closed(None)))
341 } else {
342 Ok(())
343 }
344 } else {
345 Ok(())
346 }
347 }
348}
349
/// Shared single-slot mailbox holding at most one task.
///
/// `run` uses it as the per-worker channel. Clones share the same slot (the
/// inner `Arc` is cloned), so a sender and a stream built from one
/// `OneValue` see each other's writes.
#[derive(Clone)]
struct OneValue(Arc<RwLock<Option<LocalTaskType>>>);

// SAFETY: NOTE(review) — `LocalTaskType` has no `Send` bound, so `OneValue`
// is not automatically `Sync`. This impl asserts it anyway, presumably to
// satisfy bounds required by `queue_stream`/`queue_sender`. It is only sound
// if the slot is never accessed from more than one thread — confirm that every
// clone stays on the thread polling the executor's runner.
unsafe impl Sync for OneValue {}

// SAFETY: NOTE(review) — same caveat as the `Sync` impl above. Moving a
// `OneValue` to another thread would move a non-`Send` future with it, so
// this is only sound if values never leave the executor's thread — confirm.
unsafe impl Send for OneValue {}
356
357impl OneValue {
358 #[inline]
359 fn new() -> Self {
360 Self(Arc::new(RwLock::new(None)))
361 }
362
363 #[inline]
364 fn set(&self, val: LocalTaskType) -> Option<LocalTaskType> {
365 self.0.write().replace(val)
366 }
367
368 #[inline]
369 fn take(&self) -> Option<LocalTaskType> {
370 self.0.write().take()
371 }
372
373 #[inline]
374 fn is_full(&self) -> bool {
375 self.0.read().is_some()
376 }
377
378 fn is_empty(&self) -> bool {
379 self.0.read().is_none()
380 }
381}