// task_executor/exec.rs

1use std::fmt::Debug;
2use std::future::Future;
3use std::hash::Hash;
4use std::marker::Unpin;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::task::Poll;
8
9use futures::{Sink, SinkExt, Stream, StreamExt};
10use futures::channel::mpsc;
11use futures::task::AtomicWaker;
12use parking_lot::Mutex;
13use parking_lot::RwLock;
14#[cfg(feature = "rate")]
15use update_rate::{DiscreteRateCounter, RateCounter};
16
17use queue_ext::{Action, QueueExt, Reply};
18
19use super::{
20    assert_future, close::Close, Counter, Error, ErrorType, flush::Flush, GroupTaskQueue, IndexSet,
21    PendingOnce, Spawner,
22};
23
24type DashMap<K, V> = dashmap::DashMap<K, V, ahash::RandomState>;
25type GroupChannels<G> = Arc<DashMap<G, Arc<Mutex<GroupTaskQueue<TaskType>>>>>;
26
/// Type-erased unit of work: a boxed, `Send`, `Unpin` future producing `()`.
pub type TaskType = Box<dyn Future<Output = ()> + Send + Unpin + 'static>;
28
29pub struct Executor<Tx = mpsc::Sender<((), TaskType)>, G = (), D = ()> {
30    pub(crate) tx: Tx,
31    workers: usize,
32    queue_max: isize,
33    active_count: Counter,
34    pub(crate) waiting_count: Counter,
35    completed_count: Counter,
36    #[cfg(feature = "rate")]
37    rate_counter: Arc<RwLock<DiscreteRateCounter>>,
38    flush_waker: Arc<AtomicWaker>,
39    is_flushing: Arc<AtomicBool>,
40    is_closed: Arc<AtomicBool>,
41
42    //group
43    group_channels: GroupChannels<G>,
44    _d: std::marker::PhantomData<D>,
45}
46
47impl<Tx, G, D> Clone for Executor<Tx, G, D>
48    where
49        Tx: Clone,
50{
51    #[inline]
52    fn clone(&self) -> Self {
53        Self {
54            tx: self.tx.clone(),
55            workers: self.workers,
56            queue_max: self.queue_max,
57            active_count: self.active_count.clone(),
58            waiting_count: self.waiting_count.clone(),
59            completed_count: self.completed_count.clone(),
60            #[cfg(feature = "rate")]
61            rate_counter: self.rate_counter.clone(),
62            flush_waker: self.flush_waker.clone(),
63            is_flushing: self.is_flushing.clone(),
64            is_closed: self.is_closed.clone(),
65            group_channels: self.group_channels.clone(),
66            _d: std::marker::PhantomData,
67        }
68    }
69}
70
71impl<Tx, G, D> Executor<Tx, G, D>
72    where
73        Tx: Clone + Sink<(D, TaskType)> + Unpin + Send + Sync + 'static,
74        G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
75{
76    #[inline]
77    pub(crate) fn with_channel<Rx>(
78        workers: usize,
79        queue_max: usize,
80        tx: Tx,
81        rx: Rx,
82    ) -> (Self, impl Future<Output=()>)
83        where
84            Rx: Stream<Item=(D, TaskType)> + Unpin,
85    {
86        let exec = Self {
87            tx,
88            workers,
89            queue_max: queue_max as isize,
90            active_count: Counter::new(),
91            waiting_count: Counter::new(),
92            completed_count: Counter::new(),
93            #[cfg(feature = "rate")]
94            rate_counter: Arc::new(RwLock::new(DiscreteRateCounter::new(100))),
95            flush_waker: Arc::new(AtomicWaker::new()),
96            is_flushing: Arc::new(AtomicBool::new(false)),
97            is_closed: Arc::new(AtomicBool::new(false)),
98            group_channels: Arc::new(DashMap::default()),
99            _d: std::marker::PhantomData,
100        };
101        let runner = exec.clone().run(rx);
102        (exec, runner)
103    }
104
105    #[inline]
106    pub fn spawn_with<T>(&mut self, msg: T, name: D) -> Spawner<'_, T, Tx, G, D>
107        where
108            D: Clone,
109            T: Future + Send + 'static,
110            T::Output: Send + 'static,
111    {
112        let fut = Spawner::new(self, msg, name);
113        assert_future::<Result<(), _>, _>(fut)
114    }
115
116    #[inline]
117    pub fn flush(&self) -> Flush<Tx, D> {
118        self.is_flushing.store(true, Ordering::SeqCst);
119        Flush::new(
120            self.tx.clone(),
121            self.waiting_count.clone(),
122            self.active_count.clone(),
123            self.is_flushing.clone(),
124            self.flush_waker.clone(),
125        )
126    }
127
128    #[inline]
129    pub fn close(&self) -> Close<Tx, D> {
130        self.is_flushing.store(true, Ordering::SeqCst);
131        self.is_closed.store(true, Ordering::SeqCst);
132        Close::new(
133            self.tx.clone(),
134            self.waiting_count.clone(),
135            self.active_count.clone(),
136            self.is_flushing.clone(),
137            self.flush_waker.clone(),
138        )
139    }
140
141    #[inline]
142    pub fn workers(&self) -> usize {
143        self.workers
144    }
145
146    #[inline]
147    pub fn active_count(&self) -> isize {
148        self.active_count.value()
149    }
150
151    #[inline]
152    pub fn waiting_count(&self) -> isize {
153        self.waiting_count.value()
154    }
155
156    #[inline]
157    pub fn completed_count(&self) -> isize {
158        self.completed_count.value()
159    }
160
161    #[inline]
162    #[cfg(feature = "rate")]
163    pub fn rate(&self) -> f64 {
164        self.rate_counter.read().rate()
165    }
166
167    #[inline]
168    pub fn is_full(&self) -> bool {
169        self.waiting_count() >= self.queue_max
170    }
171
172    #[inline]
173    pub fn is_closed(&self) -> bool {
174        self.is_closed.load(Ordering::SeqCst)
175    }
176
177    #[inline]
178    pub fn is_flushing(&self) -> bool {
179        self.is_flushing.load(Ordering::SeqCst)
180    }
181
182    async fn run<Rx>(self, mut task_rx: Rx)
183        where
184            Rx: Stream<Item=(D, TaskType)> + Unpin,
185    {
186        let exec = self;
187        let idle_waker = Arc::new(AtomicWaker::new());
188
189        let channel = || {
190            let rx = OneValue::new().queue_stream(|s, _| match s.take() {
191                None => Poll::Pending,
192                Some(m) => Poll::Ready(Some(m)),
193            });
194
195            let tx = rx.clone().queue_sender(|s, action| match action {
196                Action::Send(item) => Reply::Send(s.set(item)),
197                Action::IsFull => Reply::IsFull(s.is_full()),
198                Action::IsEmpty => Reply::IsEmpty(s.is_empty()),
199            });
200
201            (tx, rx)
202        };
203
204        let idle_idxs = IndexSet::new();
205        let mut txs = Vec::new();
206        let mut rxs = Vec::new();
207        for i in 0..exec.workers {
208            let (tx, mut rx) = channel();
209            let idle_waker = idle_waker.clone();
210            let idle_idxs = idle_idxs.clone();
211            idle_idxs.insert(i);
212            let exec = exec.clone();
213            let rx_fut = async move {
214                loop {
215                    match rx.next().await {
216                        Some(task) => {
217                            exec.active_count.inc();
218                            task.await;
219                            exec.completed_count.inc();
220                            exec.active_count.dec();
221                            #[cfg(feature = "rate")]
222                            exec.rate_counter.write().update();
223                        }
224                        None => break,
225                    }
226
227                    if !rx.is_full() {
228                        idle_idxs.insert(i);
229                        idle_waker.wake();
230                    }
231
232                    if exec.is_flushing() && rx.is_empty() {
233                        exec.flush_waker.wake();
234                    }
235                }
236            };
237
238            txs.push(tx);
239            rxs.push(rx_fut);
240        }
241
242        let tasks_bus = async move {
243            while let Some((_, task)) = task_rx.next().await {
244                loop {
245                    if idle_idxs.is_empty() {
246                        //sleep ...
247                        PendingOnce::new(idle_waker.clone()).await;
248                    } else if let Some(idx) = idle_idxs.pop() {
249                        //select ...
250                        if let Some(tx) = txs.get_mut(idx) {
251                            if let Err(_t) = tx.send(task).await {
252                                log::error!("send error ...");
253                                // task = t.into_inner();
254                            }
255                        }
256                        break;
257                    };
258                }
259            }
260        };
261
262        futures::future::join(tasks_bus, futures::future::join_all(rxs)).await;
263        log::info!("exit task executor");
264    }
265}
266
267impl<Tx, G> Executor<Tx, G, ()>
268    where
269        Tx: Clone + Sink<((), TaskType)> + Unpin + Send + Sync + 'static,
270        G: Hash + Eq + Clone + Debug + Send + Sync + 'static,
271{
272    #[inline]
273    pub fn spawn<T>(&mut self, msg: T) -> Spawner<'_, T, Tx, G, ()>
274        where
275            T: Future + Send + 'static,
276            T::Output: Send + 'static,
277    {
278        let fut = Spawner::new(self, msg, ());
279        assert_future::<Result<(), _>, _>(fut)
280    }
281
282    #[inline]
283    pub(crate) async fn group_send(&self, name: G, task: TaskType) -> Result<(), Error<TaskType>> {
284        if self.is_closed() {
285            return Err(Error::SendError(ErrorType::Closed(Some(task))));
286        }
287
288        let gt_queue = self
289            .group_channels
290            .entry(name.clone())
291            .or_insert_with(|| Arc::new(Mutex::new(GroupTaskQueue::new())))
292            .value()
293            .clone();
294
295        let exec = self.clone();
296        let group_channels = self.group_channels.clone();
297        let runner_task = {
298            let mut task_tx = gt_queue.lock();
299            if task_tx.is_running() {
300                task_tx.push(task);
301                drop(task_tx);
302                drop(gt_queue);
303                None
304            } else {
305                task_tx.set_running(true);
306                drop(task_tx);
307                let task_rx = gt_queue; //.clone();
308                let runner_task = async move {
309                    exec.active_count.inc();
310                    task.await;
311                    exec.active_count.dec();
312                    loop {
313                        let task: Option<TaskType> = task_rx.lock().pop();
314                        if let Some(task) = task {
315                            exec.active_count.inc();
316                            task.await;
317                            exec.completed_count.inc();
318                            exec.active_count.dec();
319                        } else {
320                            group_channels.remove(&name);
321                            break;
322                        }
323                    }
324                };
325                Some(runner_task)
326            }
327        };
328
329        if let Some(runner_task) = runner_task {
330            if (self
331                .tx
332                .clone()
333                .send(((), Box::new(Box::pin(runner_task))))
334                .await)
335                .is_err()
336            {
337                Err(Error::SendError(ErrorType::Closed(None)))
338            } else {
339                Ok(())
340            }
341        } else {
342            Ok(())
343        }
344    }
345}
346
347#[derive(Clone)]
348struct OneValue(Arc<RwLock<Option<TaskType>>>);
349
350unsafe impl Sync for OneValue {}
351
352unsafe impl Send for OneValue {}
353
354impl OneValue {
355    #[inline]
356    fn new() -> Self {
357        Self(Arc::new(RwLock::new(None)))
358    }
359
360    #[inline]
361    fn set(&self, val: TaskType) -> Option<TaskType> {
362        self.0.write().replace(val)
363    }
364
365    #[inline]
366    fn take(&self) -> Option<TaskType> {
367        self.0.write().take()
368    }
369
370    #[inline]
371    fn is_full(&self) -> bool {
372        self.0.read().is_some()
373    }
374
375    fn is_empty(&self) -> bool {
376        self.0.read().is_none()
377    }
378}