task_executor/
local.rs

1use std::fmt::Debug;
2use std::future::Future;
3use std::hash::Hash;
4use std::marker::Unpin;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::task::Poll;
8
9use futures::{Sink, SinkExt, Stream, StreamExt};
10use futures::task::AtomicWaker;
11use parking_lot::Mutex;
12use parking_lot::RwLock;
13#[cfg(feature = "rate")]
14use update_rate::{DiscreteRateCounter, RateCounter};
15
16use queue_ext::{Action, QueueExt, Reply};
17
18use super::{
19    assert_future, close::Close, Counter, Error, ErrorType, flush::Flush,
20    GroupTaskQueue, IndexSet, local_builder::SyncSender, LocalSpawner, PendingOnce,
21};
22
23type DashMap<K, V> = dashmap::DashMap<K, V, ahash::RandomState>;
24type GroupChannels<G> = Arc<DashMap<G, Arc<Mutex<GroupTaskQueue<LocalTaskType>>>>>;
25
26pub type LocalTaskType = Box<dyn std::future::Future<Output=()> + 'static + Unpin>;
27
28pub struct LocalExecutor<Tx = SyncSender, G = (), D = ()> {
29    pub(crate) tx: Tx,
30    workers: usize,
31    queue_max: isize,
32    active_count: Counter,
33    pub(crate) waiting_count: Counter,
34    completed_count: Counter,
35    #[cfg(feature = "rate")]
36    rate_counter: Arc<RwLock<DiscreteRateCounter>>,
37    flush_waker: Arc<AtomicWaker>,
38    is_flushing: Arc<AtomicBool>,
39    is_closed: Arc<AtomicBool>,
40
41    //group
42    group_channels: GroupChannels<G>,
43    _d: std::marker::PhantomData<D>,
44}
45
46impl<Tx, G, D> Clone for LocalExecutor<Tx, G, D>
47    where
48        Tx: Clone,
49{
50    #[inline]
51    fn clone(&self) -> Self {
52        Self {
53            tx: self.tx.clone(),
54            workers: self.workers,
55            queue_max: self.queue_max,
56            active_count: self.active_count.clone(),
57            waiting_count: self.waiting_count.clone(),
58            completed_count: self.completed_count.clone(),
59            #[cfg(feature = "rate")]
60            rate_counter: self.rate_counter.clone(),
61            flush_waker: self.flush_waker.clone(),
62            is_flushing: self.is_flushing.clone(),
63            is_closed: self.is_closed.clone(),
64            group_channels: self.group_channels.clone(),
65            _d: std::marker::PhantomData,
66        }
67    }
68}
69
70impl<Tx, G, D> LocalExecutor<Tx, G, D>
71    where
72        Tx: Clone + Sink<(D, LocalTaskType)> + Unpin + Sync + 'static,
73        G: Hash + Eq + Clone + Debug + Sync + 'static,
74{
75    #[inline]
76    pub(crate) fn with_channel<Rx>(
77        workers: usize,
78        queue_max: usize,
79        tx: Tx,
80        rx: Rx,
81    ) -> (Self, impl Future<Output=()>)
82        where
83            Rx: Stream<Item=(D, LocalTaskType)> + Unpin,
84    {
85        let exec = Self {
86            tx,
87            workers,
88            queue_max: queue_max as isize,
89            active_count: Counter::new(),
90            waiting_count: Counter::new(),
91            completed_count: Counter::new(),
92            #[cfg(feature = "rate")]
93            rate_counter: Arc::new(RwLock::new(DiscreteRateCounter::new(100))),
94            flush_waker: Arc::new(AtomicWaker::new()),
95            is_flushing: Arc::new(AtomicBool::new(false)),
96            is_closed: Arc::new(AtomicBool::new(false)),
97            group_channels: Arc::new(DashMap::default()),
98            _d: std::marker::PhantomData,
99        };
100        let runner = exec.clone().run(rx);
101        (exec, runner)
102    }
103
104    #[inline]
105    pub fn spawn_with<T>(&mut self, msg: T, name: D) -> LocalSpawner<'_, T, Tx, G, D>
106        where
107            D: Clone,
108            T: Future + 'static,
109            T::Output: 'static,
110    {
111        let fut = LocalSpawner::new(self, msg, name);
112        assert_future::<Result<(), _>, _>(fut)
113    }
114
115    #[inline]
116    pub fn flush(&self) -> Flush<Tx, D> {
117        self.is_flushing.store(true, Ordering::SeqCst);
118        Flush::new(
119            self.tx.clone(),
120            self.waiting_count.clone(),
121            self.active_count.clone(),
122            self.is_flushing.clone(),
123            self.flush_waker.clone(),
124        )
125    }
126
127    #[inline]
128    pub fn close(&self) -> Close<Tx, D> {
129        self.is_flushing.store(true, Ordering::SeqCst);
130        self.is_closed.store(true, Ordering::SeqCst);
131        Close::new(
132            self.tx.clone(),
133            self.waiting_count.clone(),
134            self.active_count.clone(),
135            self.is_flushing.clone(),
136            self.flush_waker.clone(),
137        )
138    }
139
140    #[inline]
141    pub fn workers(&self) -> usize {
142        self.workers
143    }
144
145    #[inline]
146    pub fn active_count(&self) -> isize {
147        self.active_count.value()
148    }
149
150    #[inline]
151    pub fn waiting_count(&self) -> isize {
152        self.waiting_count.value()
153    }
154
155    #[inline]
156    pub fn completed_count(&self) -> isize {
157        self.completed_count.value()
158    }
159
160    #[inline]
161    #[cfg(feature = "rate")]
162    pub fn rate(&self) -> f64 {
163        self.rate_counter.read().rate()
164    }
165
166    #[inline]
167    pub fn is_full(&self) -> bool {
168        self.waiting_count() >= self.queue_max
169    }
170
171    #[inline]
172    pub fn is_closed(&self) -> bool {
173        self.is_closed.load(Ordering::SeqCst)
174    }
175
176    #[inline]
177    pub fn is_flushing(&self) -> bool {
178        self.is_flushing.load(Ordering::SeqCst)
179    }
180
181    async fn run<Rx>(self, mut task_rx: Rx)
182        where
183            Rx: Stream<Item=(D, LocalTaskType)> + Unpin,
184    {
185        let exec = self;
186        let idle_waker = Arc::new(AtomicWaker::new());
187
188        let channel = || {
189            let rx = OneValue::new().queue_stream(|s, _| match s.take() {
190                None => Poll::Pending,
191                Some(m) => Poll::Ready(Some(m)),
192            });
193
194            let tx = rx.clone().queue_sender(|s, action| match action {
195                Action::Send(item) => Reply::Send(s.set(item)),
196                Action::IsFull => Reply::IsFull(s.is_full()),
197                Action::IsEmpty => Reply::IsEmpty(s.is_empty()),
198            });
199
200            (tx, rx)
201        };
202
203        let idle_idxs = IndexSet::new();
204        let mut txs = Vec::new();
205        let mut rxs = Vec::new();
206        for i in 0..exec.workers {
207            let (tx, mut rx) = channel();
208            let idle_waker = idle_waker.clone();
209            let idle_idxs = idle_idxs.clone();
210            idle_idxs.insert(i);
211            let exec = exec.clone();
212            let rx_fut = async move {
213                loop {
214                    match rx.next().await {
215                        Some(task) => {
216                            exec.active_count.inc();
217                            task.await;
218                            exec.completed_count.inc();
219                            exec.active_count.dec();
220                            #[cfg(feature = "rate")]
221                            exec.rate_counter.write().update();
222                        }
223                        None => break,
224                    }
225
226                    if !rx.is_full() {
227                        idle_idxs.insert(i);
228                        idle_waker.wake();
229                    }
230
231                    if exec.is_flushing() && rx.is_empty() {
232                        exec.flush_waker.wake();
233                    }
234                }
235            };
236
237            txs.push(tx);
238            rxs.push(rx_fut);
239        }
240
241        let tasks_bus = async move {
242            while let Some((_, task)) = task_rx.next().await {
243                loop {
244                    if idle_idxs.is_empty() {
245                        //sleep ...
246                        PendingOnce::new(idle_waker.clone()).await;
247                    } else if let Some(idx) = idle_idxs.pop() {
248                        //select ...
249                        if let Some(tx) = txs.get_mut(idx) {
250                            if let Err(_t) = tx.send(task).await {
251                                log::error!("send error ...");
252                                // task = t.into_inner();
253                            }
254                        }
255                        break;
256                    };
257                }
258            }
259        };
260
261        futures::future::join(tasks_bus, futures::future::join_all(rxs)).await;
262        log::info!("exit local task executor");
263    }
264}
265
266impl<Tx, G> LocalExecutor<Tx, G, ()>
267    where
268        Tx: Clone + Sink<((), LocalTaskType)> + Unpin + Sync + 'static,
269        G: Hash + Eq + Clone + Debug + Sync + 'static,
270{
271    #[inline]
272    pub fn spawn<T>(&mut self, msg: T) -> LocalSpawner<'_, T, Tx, G, ()>
273        where
274            T: Future + 'static,
275            T::Output: 'static,
276    {
277        let fut = LocalSpawner::new(self, msg, ());
278        assert_future::<Result<(), _>, _>(fut)
279    }
280
281    #[inline]
282    pub(crate) async fn group_send(
283        &self,
284        name: G,
285        task: LocalTaskType,
286    ) -> Result<(), Error<LocalTaskType>> {
287        if self.is_closed() {
288            return Err(Error::SendError(ErrorType::Closed(Some(task))));
289        }
290
291        let gt_queue = self
292            .group_channels
293            .entry(name.clone())
294            .or_insert_with(|| Arc::new(Mutex::new(GroupTaskQueue::new())))
295            .value()
296            .clone();
297
298        let exec = self.clone();
299        let group_channels = self.group_channels.clone();
300        let runner_task = {
301            let mut task_tx = gt_queue.lock();
302            if task_tx.is_running() {
303                task_tx.push(task);
304                drop(task_tx);
305                drop(gt_queue);
306                None
307            } else {
308                task_tx.set_running(true);
309                drop(task_tx);
310                let task_rx = gt_queue; //.clone();
311                let runner_task = async move {
312                    exec.active_count.inc();
313                    task.await;
314                    exec.active_count.dec();
315                    loop {
316                        let task: Option<LocalTaskType> = task_rx.lock().pop();
317                        if let Some(task) = task {
318                            exec.active_count.inc();
319                            task.await;
320                            exec.completed_count.inc();
321                            exec.active_count.dec();
322                        } else {
323                            group_channels.remove(&name);
324                            break;
325                        }
326                    }
327                };
328                Some(runner_task)
329            }
330        };
331
332        if let Some(runner_task) = runner_task {
333            if (self
334                .tx
335                .clone()
336                .send(((), Box::new(Box::pin(runner_task))))
337                .await)
338                .is_err()
339            {
340                Err(Error::SendError(ErrorType::Closed(None)))
341            } else {
342                Ok(())
343            }
344        } else {
345            Ok(())
346        }
347    }
348}
349
350#[derive(Clone)]
351struct OneValue(Arc<RwLock<Option<LocalTaskType>>>);
352
353unsafe impl Sync for OneValue {}
354
355unsafe impl Send for OneValue {}
356
357impl OneValue {
358    #[inline]
359    fn new() -> Self {
360        Self(Arc::new(RwLock::new(None)))
361    }
362
363    #[inline]
364    fn set(&self, val: LocalTaskType) -> Option<LocalTaskType> {
365        self.0.write().replace(val)
366    }
367
368    #[inline]
369    fn take(&self) -> Option<LocalTaskType> {
370        self.0.write().take()
371    }
372
373    #[inline]
374    fn is_full(&self) -> bool {
375        self.0.read().is_some()
376    }
377
378    fn is_empty(&self) -> bool {
379        self.0.read().is_none()
380    }
381}