// sync_utils/worker_pool/bounded.rs

1use std::{
2    cell::UnsafeCell,
3    future::Future,
4    mem::transmute,
5    pin::Pin,
6    sync::{
7        Arc,
8        atomic::{AtomicBool, AtomicUsize, Ordering},
9    },
10    task::*,
11    thread,
12    time::Duration,
13};
14
15use crossfire::*;
16use tokio::time::{sleep, timeout};
17
18use super::*;
19
/// Bounded worker pool: a cheap-to-clone handle around the shared pool state.
///
/// Tasks of type `M` are queued on a bounded channel and consumed by workers
/// of type `W`, which are created on demand by the user-supplied factory `S`.
// NOTE(review): `unused_must_use` is allowed here presumably for ignored
// channel-send results inside the pool — confirm it is still needed.
#[allow(unused_must_use)]
pub struct WorkerPoolBounded<
    M: Send + Sized + Unpin + 'static,
    W: Worker<M>,
    S: WorkerPoolImpl<M, W>,
>(Arc<WorkerPoolBoundedInner<M, W, S>>);
26
/// Shared state behind `WorkerPoolBounded`; always held in an `Arc`.
struct WorkerPoolBoundedInner<M, W, S>
where
    M: Send + Sized + Unpin + 'static,
    W: Worker<M>,
    S: WorkerPoolImpl<M, W>,
{
    // Live worker count: incremented in spawn(), decremented via try_exit().
    worker_count: AtomicUsize,
    // Task-channel sender; `None` until start() installs it. Messages are
    // `Option<M>`: `Some` is a task, `None` is a per-worker shutdown signal.
    sender: UnsafeCell<Option<MAsyncTx<Option<M>>>>,
    // Permanent worker floor (asserted > 0 in new()).
    min_workers: usize,
    // Upper bound for auto-scaling.
    max_workers: usize,
    // Idle timeout after which a surplus worker retires (auto mode only).
    worker_timeout: Duration,
    // User-supplied worker factory.
    inner: S,
    phantom: std::marker::PhantomData<W>, // to avoid complaining unused param
    // Set once by close_async(); checked by submit paths and workers.
    closing: AtomicBool,
    // Capacity-1 channel used by submit() to nudge the monitor task;
    // `None` tells the monitor to exit.
    notify_sender: MAsyncTx<Option<()>>,
    // Receiver half, moved out of the cell into the monitor task by start().
    notify_recv: UnsafeCell<Option<AsyncRx<Option<()>>>>,
    // true when min_workers < max_workers (auto-scaling enabled).
    auto: bool,
    // Bound of the task channel.
    channel_size: usize,
    // When true, workers run on dedicated OS threads (for blocking work).
    real_thread: AtomicBool,
    // Next CPU core id to pin a real-thread worker to (round-robin).
    bind_cpu: AtomicUsize,
    // Number of CPUs, sampled once at construction.
    max_cpu: usize,
}
49
// SAFETY: required manually because of the UnsafeCell fields. Soundness
// relies on the usage protocol: both cells are written only during
// single-threaded setup in start() (and the notify receiver is taken
// exactly once) — NOTE(review): nothing enforces that start() is called
// once; confirm callers uphold this.
unsafe impl<M, W, S> Send for WorkerPoolBoundedInner<M, W, S>
where
    M: Send + Sized + Unpin + 'static,
    W: Worker<M>,
    S: WorkerPoolImpl<M, W>,
{
}
57
// SAFETY: same argument as the Send impl above — the UnsafeCell contents are
// only mutated during start(), before worker tasks and submitters observe
// them; all other shared state is atomic. NOTE(review): unguarded protocol,
// verify start() is never raced with submit()/close().
unsafe impl<M, W, S> Sync for WorkerPoolBoundedInner<M, W, S>
where
    M: Send + Sized + Unpin + 'static,
    W: Worker<M>,
    S: WorkerPoolImpl<M, W>,
{
}
65
66impl<M, W, S> Clone for WorkerPoolBounded<M, W, S>
67where
68    M: Send + Sized + Unpin + 'static,
69    W: Worker<M>,
70    S: WorkerPoolImpl<M, W>,
71{
72    #[inline]
73    fn clone(&self) -> Self {
74        Self(self.0.clone())
75    }
76}
77
impl<M, W, S> WorkerPoolBounded<M, W, S>
where
    M: Send + Sized + Unpin + 'static,
    W: Worker<M>,
    S: WorkerPoolImpl<M, W>,
{
    /// Creates a pool (not running yet; call `start()`).
    ///
    /// The pool keeps `min_workers` permanent workers. When
    /// `min_workers < max_workers` it auto-scales up to `max_workers`, and
    /// surplus workers retire after `worker_timeout` of inactivity (the
    /// timeout must be non-zero in that mode). `channel_size` bounds the
    /// task queue.
    ///
    /// Panics when `min_workers == 0` or `max_workers < min_workers`.
    pub fn new(
        inner: S, min_workers: usize, max_workers: usize, channel_size: usize,
        worker_timeout: Duration,
    ) -> Self {
        assert!(min_workers > 0);
        assert!(max_workers >= min_workers);

        // Auto-scaling only when the worker range is non-trivial.
        let auto: bool = min_workers < max_workers;
        if auto {
            assert!(worker_timeout != ZERO_DUARTION);
        }
        // Capacity-1 notification channel: submit() nudges the monitor task.
        let (noti_sender, noti_recv) = mpsc::bounded_async(1);
        let pool = Arc::new(WorkerPoolBoundedInner {
            sender: UnsafeCell::new(None), // installed by start()
            inner,
            worker_count: AtomicUsize::new(0),
            min_workers,
            max_workers,
            channel_size,
            worker_timeout,
            phantom: Default::default(),
            closing: AtomicBool::new(false),
            notify_sender: noti_sender,
            notify_recv: UnsafeCell::new(Some(noti_recv)),
            auto,
            real_thread: AtomicBool::new(false),
            bind_cpu: AtomicUsize::new(0),
            max_cpu: num_cpus::get(),
        });
        Self(pool)
    }

    // If worker contains blocking logic, run worker in separate threads
    // (each gets its own current-thread runtime; see spawn()). Call before
    // start().
    pub fn set_use_thread(&mut self, ok: bool) {
        self.0.real_thread.store(ok, Ordering::Release);
    }

    /// Installs the task channel, launches the `min_workers` permanent
    /// workers and, in auto-scaling mode, the monitor task.
    ///
    /// NOTE(review): must be called exactly once from within a tokio
    /// runtime; a second call would panic on the already-taken notify
    /// receiver and race on the UnsafeCell sender slot.
    pub fn start(&self) {
        let _self = self.0.as_ref();
        let (sender, rx) = mpmc::bounded_async(_self.channel_size);
        _self._sender().replace(sender);

        for _ in 0.._self.min_workers {
            self.0.clone().spawn(true, rx.clone());
        }
        if _self.auto {
            let _pool = self.0.clone();
            // Move the notify receiver out of its cell into the monitor task.
            let notify_recv: &mut Option<AsyncRx<Option<()>>> =
                unsafe { transmute(_self.notify_recv.get()) };
            let noti_rx = notify_recv.take().unwrap();
            tokio::spawn(async move {
                _pool.monitor(noti_rx, rx).await;
            });
        }
    }

    /// Gracefully shuts the pool down: marks it closing, wakes the monitor,
    /// then repeatedly sends one `None` shutdown message per live worker
    /// until the worker count drops to zero. Idempotent.
    pub async fn close_async(&self) {
        let _self = self.0.as_ref();
        // swap() ensures only the first caller runs the shutdown loop.
        if _self.closing.swap(true, Ordering::SeqCst) {
            return;
        }
        if _self.auto {
            let _ = _self.notify_sender.send(None).await;
        }
        let sender = _self._sender().as_ref().unwrap();
        loop {
            let cur = self.get_worker_count();
            if cur == 0 {
                break;
            }
            debug!("worker pool closing: cur workers {}", cur);
            for _ in 0..cur {
                let _ = sender.send(None).await;
            }
            sleep(Duration::from_secs(1)).await;
        }
    }

    // must not use in runtime
    /// Blocking shutdown. When called on a runtime thread it delegates to a
    /// freshly spawned OS thread (and returns without waiting for
    /// completion); otherwise it blocks on close_async() via a temporary
    /// current-thread runtime.
    pub fn close(&self) {
        if let Ok(_rt) = tokio::runtime::Handle::try_current() {
            warn!("close in runtime thread, spawn close thread");
            let _self = self.clone();
            std::thread::spawn(move || {
                let rt = tokio::runtime::Builder::new_current_thread()
                    .enable_all()
                    .build()
                    .expect("runtime");
                rt.block_on(async move {
                    _self.close_async().await;
                });
            });
        } else {
            let rt = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .expect("runtime");
            let _self = self.clone();
            rt.block_on(async move {
                _self.close_async().await;
            });
        }
    }

    /// Current number of live workers.
    pub fn get_worker_count(&self) -> usize {
        self.0.get_worker_count()
    }

    /// Borrows the user-supplied pool implementation.
    pub fn get_inner(&self) -> &S {
        &self.0.inner
    }

    /// Non-blocking submit. Returns the message back (`Some`) when the pool
    /// is closing, the queue is full, or the channel is disconnected;
    /// `None` on success.
    #[inline]
    pub fn try_submit(&self, msg: M) -> Option<M> {
        let _self = self.0.as_ref();
        if _self.closing.load(Ordering::Acquire) {
            return Some(msg);
        }
        match _self._sender().as_ref().unwrap().try_send(Some(msg)) {
            Err(TrySendError::Disconnected(m)) => {
                return m;
            }
            Err(TrySendError::Full(m)) => {
                return m;
            }
            Ok(_) => return None,
        }
    }

    // return non-null on send fail
    /// Async submit. In auto-scaling mode it first attempts a non-blocking
    /// send; when the queue is full it notifies the monitor (so an extra
    /// worker may be spawned) before falling back to an awaited send. The
    /// returned future resolves to `None` on success, `Some(msg)` on failure.
    #[inline]
    pub fn submit<'a>(&'a self, mut msg: M) -> SubmitFuture<'a, M> {
        let _self = self.0.as_ref();
        if _self.closing.load(Ordering::Acquire) {
            return SubmitFuture { send_f: None, res: Some(Err(msg)) };
        }
        let sender = _self._sender().as_ref().unwrap();
        if _self.auto {
            // Opportunistic fast path: resolve immediately when there is room.
            match sender.try_send(Some(msg)) {
                Err(TrySendError::Disconnected(m)) => {
                    return SubmitFuture { send_f: None, res: Some(Err(m.unwrap())) };
                }
                Err(TrySendError::Full(m)) => {
                    msg = m.unwrap();
                }
                Ok(_) => {
                    return SubmitFuture { send_f: None, res: Some(Ok(())) };
                }
            }
            // Queue full: nudge the monitor to add a worker while below max.
            let worker_count = _self.get_worker_count();
            if worker_count < _self.max_workers {
                let _ = _self.notify_sender.try_send(Some(()));
            }
        }
        let send_f = sender.send(Some(msg));
        return SubmitFuture { send_f: Some(send_f), res: None };
    }
}
242
/// Future returned by `WorkerPoolBounded::submit`.
///
/// Exactly one of the fields is `Some`: `res` carries a result decided
/// synchronously by submit() (fast path), otherwise `send_f` is the pending
/// channel send to drive. Resolves to `None` on success, `Some(msg)` when
/// the message could not be delivered.
pub struct SubmitFuture<'a, M: Send + Sized + Unpin + 'static> {
    // Pending channel send (slow path).
    send_f: Option<SendFuture<'a, Option<M>>>,
    // Pre-computed outcome (fast path); `Err(m)` returns the message.
    res: Option<Result<(), M>>,
}
247
248impl<'a, M: Send + Sized + Unpin + 'static> Future for SubmitFuture<'a, M> {
249    type Output = Option<M>;
250
251    fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
252        let _self = self.get_mut();
253        if _self.res.is_some() {
254            match _self.res.take().unwrap() {
255                Ok(()) => return Poll::Ready(None),
256                Err(m) => return Poll::Ready(Some(m)),
257            }
258        }
259        let send_f = _self.send_f.as_mut().unwrap();
260        if let Poll::Ready(r) = Pin::new(send_f).poll(ctx) {
261            match r {
262                Ok(()) => return Poll::Ready(None),
263                Err(SendError(e)) => {
264                    return Poll::Ready(e);
265                }
266            }
267        }
268        Poll::Pending
269    }
270}
271
// Trait adapter: exposes the inherent submit/try_submit methods through the
// object-safe async interface.
#[async_trait]
impl<M, W, S> WorkerPoolAsyncInf<M> for WorkerPoolBounded<M, W, S>
where
    M: Send + Sized + Unpin + 'static,
    W: Worker<M>,
    S: WorkerPoolImpl<M, W>,
{
    /// Delegates to the inherent `submit`; `None` = delivered.
    #[inline]
    async fn submit(&self, msg: M) -> Option<M> {
        self.submit(msg).await
    }

    /// Delegates to the inherent `try_submit`; `Some(msg)` = not delivered.
    #[inline]
    fn try_submit(&self, msg: M) -> Option<M> {
        self.try_submit(msg)
    }
}
289
290impl<M, W, S> WorkerPoolBoundedInner<M, W, S>
291where
292    M: Send + Sized + Unpin + 'static,
293    W: Worker<M>,
294    S: WorkerPoolImpl<M, W>,
295{
296    #[inline(always)]
297    fn _sender(&self) -> &mut Option<MAsyncTx<Option<M>>> {
298        unsafe { transmute(self.sender.get()) }
299    }
300
301    async fn run_worker_simple(&self, mut worker: W, rx: MAsyncRx<Option<M>>) {
302        if let Err(_) = worker.init().await {
303            let _ = self.try_exit();
304            worker.on_exit();
305            return;
306        }
307        loop {
308            match rx.recv().await {
309                Ok(item) => {
310                    if item.is_none() {
311                        let _ = self.try_exit();
312                        break;
313                    }
314                    worker.run(item.unwrap()).await;
315                }
316                Err(_) => {
317                    // channel closed worker exit
318                    let _ = self.try_exit();
319                    break;
320                }
321            }
322        }
323        worker.on_exit();
324    }
325
326    async fn run_worker_adjust(&self, mut worker: W, rx: MAsyncRx<Option<M>>) {
327        if let Err(_) = worker.init().await {
328            let _ = self.try_exit();
329            worker.on_exit();
330            return;
331        }
332
333        let worker_timeout = self.worker_timeout;
334        let mut is_idle = false;
335        'WORKER_LOOP: loop {
336            if is_idle {
337                match rx.recv_timeout(worker_timeout).await {
338                    Ok(item) => {
339                        if item.is_none() {
340                            let _ = self.try_exit();
341                            break 'WORKER_LOOP;
342                        }
343                        worker.run(item.unwrap()).await;
344                        is_idle = false;
345                    }
346                    Err(RecvTimeoutError::Disconnected) => {
347                        // channel closed worker exit
348                        let _ = self.try_exit();
349                        worker.on_exit();
350                    }
351                    Err(RecvTimeoutError::Timeout) => {
352                        // timeout
353                        if self.try_exit() {
354                            break 'WORKER_LOOP;
355                        }
356                    }
357                }
358            } else {
359                match rx.try_recv() {
360                    Err(e) => {
361                        if e.is_empty() {
362                            is_idle = true;
363                        } else {
364                            let _ = self.try_exit();
365                            break 'WORKER_LOOP;
366                        }
367                    }
368                    Ok(Some(item)) => {
369                        worker.run(item).await;
370                        is_idle = false;
371                    }
372                    Ok(None) => {
373                        let _ = self.try_exit();
374                        break 'WORKER_LOOP;
375                    }
376                }
377            }
378        }
379        worker.on_exit();
380    }
381
382    #[inline(always)]
383    pub fn get_worker_count(&self) -> usize {
384        self.worker_count.load(Ordering::Acquire)
385    }
386
387    #[inline(always)]
388    fn spawn(self: Arc<Self>, initial: bool, rx: MAsyncRx<Option<M>>) {
389        self.worker_count.fetch_add(1, Ordering::SeqCst);
390        let worker = self.inner.spawn();
391        let _self = self.clone();
392        if self.real_thread.load(Ordering::Acquire) {
393            let mut bind_cpu: Option<usize> = None;
394            if _self.bind_cpu.load(Ordering::Acquire) <= _self.max_cpu {
395                let cpu = _self.bind_cpu.fetch_add(1, Ordering::SeqCst);
396                if cpu < _self.max_cpu {
397                    bind_cpu = Some(cpu as usize);
398                }
399            }
400            thread::spawn(move || {
401                if let Some(cpu) = bind_cpu {
402                    core_affinity::set_for_current(core_affinity::CoreId { id: cpu });
403                }
404                let rt = tokio::runtime::Builder::new_current_thread()
405                    .enable_all()
406                    .build()
407                    .expect("runtime");
408                rt.block_on(async move {
409                    if initial || !_self.auto {
410                        _self.run_worker_simple(worker, rx).await
411                    } else {
412                        _self.run_worker_adjust(worker, rx).await
413                    }
414                });
415            });
416        } else {
417            tokio::spawn(async move {
418                if initial || !_self.auto {
419                    _self.run_worker_simple(worker, rx).await
420                } else {
421                    _self.run_worker_adjust(worker, rx).await
422                }
423            });
424        }
425    }
426
427    // check if idle worker should exit
428    #[inline(always)]
429    fn try_exit(&self) -> bool {
430        if self.closing.load(Ordering::Acquire) {
431            self.worker_count.fetch_sub(1, Ordering::SeqCst);
432            return true;
433        }
434        if self.get_worker_count() > self.min_workers {
435            if self.worker_count.fetch_sub(1, Ordering::SeqCst) <= self.min_workers {
436                self.worker_count.fetch_add(1, Ordering::SeqCst); // rollback
437            } else {
438                return true; // worker exit
439            }
440        }
441        return false;
442    }
443
444    async fn monitor(self: Arc<Self>, noti: AsyncRx<Option<()>>, rx: MAsyncRx<Option<M>>) {
445        let _self = self.as_ref();
446        loop {
447            match timeout(Duration::from_secs(1), noti.recv()).await {
448                Err(_) => {
449                    if _self.closing.load(Ordering::Acquire) {
450                        return;
451                    }
452                    continue;
453                }
454                Ok(Ok(Some(_))) => {
455                    if _self.closing.load(Ordering::Acquire) {
456                        return;
457                    }
458                    let worker_count = _self.get_worker_count();
459                    if worker_count > _self.max_workers {
460                        continue;
461                    }
462                    self.clone().spawn(false, rx.clone());
463                }
464                _ => {
465                    println!("monitor exit");
466                    return;
467                }
468            }
469        }
470    }
471}
472
// Integration-style tests: drive a real pool on a multi-thread tokio
// runtime and check that auto-scaling grows under load, shrinks back to
// min_workers after the idle timeout, and that close()/close_async() drain
// every worker.
#[cfg(test)]
mod tests {

    use crossfire::*;
    use tokio::time::{Duration, sleep};

    use super::*;

    // Trivial factory: stateless workers.
    #[allow(dead_code)]
    struct MyWorkerPoolImpl();

    struct MyWorker();

    // Test message: payload id + completion channel to ack processing.
    struct MyMsg(i64, MAsyncTx<()>);

    impl WorkerPoolImpl<MyMsg, MyWorker> for MyWorkerPoolImpl {
        fn spawn(&self) -> MyWorker {
            MyWorker()
        }
    }

    #[async_trait]
    impl Worker<MyMsg> for MyWorker {
        // Async-friendly worker: yields via tokio sleep, then acks.
        async fn run(&mut self, msg: MyMsg) {
            sleep(Duration::from_millis(1)).await;
            println!("done {}", msg.0);
            let _ = msg.1.send(()).await;
        }
    }

    type MyWorkerPool = WorkerPoolBounded<MyMsg, MyWorker, MyWorkerPoolImpl>;

    // Scale-up under load, scale-down on idle, then close() from inside the
    // runtime (close spawns a helper thread; we sleep before asserting 0).
    #[test]
    fn bounded_workerpool_adjust_close_async() {
        let min_workers = 1;
        let max_workers = 4;
        let worker_timeout = Duration::from_secs(1);
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .unwrap();
        let worker_pool =
            MyWorkerPool::new(MyWorkerPoolImpl(), min_workers, max_workers, 1, worker_timeout);
        let _worker_pool = worker_pool.clone();
        rt.block_on(async move {
            worker_pool.start();
            // 5 concurrent submitters x 2 messages each, over a size-1 queue,
            // to force backpressure and trigger auto-scaling.
            let mut th_s = Vec::new();
            for i in 0..5 {
                let _pool = worker_pool.clone();
                th_s.push(tokio::task::spawn(async move {
                    let (done_tx, done_rx) = mpsc::bounded_async(10);
                    for j in 0..2 {
                        let m = i * 10 + j;
                        println!("submit {} in {}/{}", m, j, i);
                        _pool.submit(MyMsg(m, done_tx.clone())).await;
                    }
                    for _j in 0..2 {
                        //println!("sender {} recv {}", i, _j);
                        let _ = done_rx.recv().await;
                    }
                }));
            }
            for th in th_s {
                let _ = th.await;
            }
            let workers = worker_pool.get_worker_count();
            println!("cur workers {} might reach max {}", workers, max_workers);
            //assert_eq!(workers, max_workers);

            // After 2x the idle timeout, surplus workers must have retired.
            sleep(worker_timeout * 2).await;
            let workers = worker_pool.get_worker_count();
            println!("cur workers: {}, extra should exit due to timeout", workers);
            assert_eq!(workers, min_workers);

            // Pool still serves requests after shrinking.
            let (done_tx, done_rx) = mpsc::bounded_async(1);
            for j in 0..10 {
                worker_pool.submit(MyMsg(80 + j, done_tx.clone())).await;
                let _ = done_rx.recv().await;
            }
            println!("closing");
            _worker_pool.close();
            sleep(Duration::from_secs(2)).await;
            assert_eq!(_worker_pool.get_worker_count(), 0);
        });
    }

    // Same scenario, but close() is called outside the runtime (blocking
    // path), so the count must be 0 as soon as close() returns.
    #[test]
    fn bounded_workerpool_adjust_close() {
        let min_workers = 1;
        let max_workers = 4;
        let worker_timeout = Duration::from_secs(1);
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .unwrap();
        let worker_pool =
            MyWorkerPool::new(MyWorkerPoolImpl(), min_workers, max_workers, 1, worker_timeout);
        let _worker_pool = worker_pool.clone();
        rt.block_on(async move {
            worker_pool.start();
            let mut th_s = Vec::new();
            for i in 0..5 {
                let _pool = worker_pool.clone();
                th_s.push(tokio::task::spawn(async move {
                    let (done_tx, done_rx) = mpsc::bounded_async(10);
                    for j in 0..2 {
                        let m = i * 10 + j;
                        println!("submit {} in {}/{}", m, j, i);
                        _pool.submit(MyMsg(m, done_tx.clone())).await;
                    }
                    for _j in 0..2 {
                        //println!("sender {} recv {}", i, _j);
                        let _ = done_rx.recv().await;
                    }
                }));
            }
            for th in th_s {
                let _ = th.await;
            }
            let workers = worker_pool.get_worker_count();
            println!("cur workers {} might reach max {}", workers, max_workers);
            //assert_eq!(workers, max_workers);

            sleep(worker_timeout * 2).await;
            let workers = worker_pool.get_worker_count();
            println!("cur workers: {}, extra should exit due to timeout", workers);
            assert_eq!(workers, min_workers);

            let (done_tx, done_rx) = mpsc::bounded_async(1);
            for j in 0..10 {
                worker_pool.submit(MyMsg(80 + j, done_tx.clone())).await;
                let _ = done_rx.recv().await;
            }
        });
        println!("closing");
        _worker_pool.close();
        assert_eq!(_worker_pool.get_worker_count(), 0);
    }

    // Variant with a worker that blocks the thread (std sleep), exercising
    // the real-thread mode enabled via set_use_thread(true).
    #[allow(dead_code)]
    struct MyBlockingWorkerPoolImpl();

    struct MyBlockingWorker();

    impl WorkerPoolImpl<MyMsg, MyBlockingWorker> for MyBlockingWorkerPoolImpl {
        fn spawn(&self) -> MyBlockingWorker {
            MyBlockingWorker()
        }
    }

    #[async_trait]
    impl Worker<MyMsg> for MyBlockingWorker {
        async fn run(&mut self, msg: MyMsg) {
            // Deliberately blocking: safe because each worker owns a thread.
            std::thread::sleep(Duration::from_millis(1));
            println!("done {}", msg.0);
            let _ = msg.1.send(()).await;
        }
    }

    type MyBlockingWorkerPool =
        WorkerPoolBounded<MyMsg, MyBlockingWorker, MyBlockingWorkerPoolImpl>;

    #[test]
    fn bounded_workerpool_adjust_real_thread() {
        let min_workers = 1;
        let max_workers = 4;
        let worker_timeout = Duration::from_secs(1);
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .unwrap();
        let mut worker_pool = MyBlockingWorkerPool::new(
            MyBlockingWorkerPoolImpl(),
            min_workers,
            max_workers,
            1,
            worker_timeout,
        );
        // Workers run on dedicated OS threads (must be set before start()).
        worker_pool.set_use_thread(true);
        let _worker_pool = worker_pool.clone();
        rt.block_on(async move {
            worker_pool.start();
            let mut th_s = Vec::new();
            for i in 0..5 {
                let _pool = worker_pool.clone();
                th_s.push(tokio::task::spawn(async move {
                    let (done_tx, done_rx) = mpsc::bounded_async(10);
                    for j in 0..2 {
                        let m = i * 10 + j;
                        println!("submit {} in {}/{}", m, j, i);
                        _pool.submit(MyMsg(m, done_tx.clone())).await;
                    }
                    for _j in 0..2 {
                        //println!("sender {} recv {}", i, _j);
                        let _ = done_rx.recv().await;
                    }
                }));
            }
            for th in th_s {
                let _ = th.await;
            }
            let workers = worker_pool.get_worker_count();
            println!("cur workers {} might reach max {}", workers, max_workers);
            //assert_eq!(workers, max_workers);

            sleep(worker_timeout * 2).await;
            let workers = worker_pool.get_worker_count();
            println!("cur workers: {}, extra should exit due to timeout", workers);
            assert_eq!(workers, min_workers);

            let (done_tx, done_rx) = mpsc::bounded_async(1);
            for j in 0..10 {
                worker_pool.submit(MyMsg(80 + j, done_tx.clone())).await;
                let _ = done_rx.recv().await;
            }
        });
        println!("closing");
        _worker_pool.close();
        assert_eq!(_worker_pool.get_worker_count(), 0);
    }
}