sync_utils/worker_pool/unbounded.rs

use std::{
    sync::{
        Arc,
        atomic::{AtomicBool, AtomicUsize, Ordering},
    },
    thread,
    time::Duration,
};

use crossfire::*;
use tokio::time::{sleep, timeout};

use super::*;

/// An unbounded, optionally auto-scaling worker pool. Messages are queued on
/// an unbounded MPMC channel, with `None` used internally as the close
/// signal. Cloning the handle is cheap: it clones the inner `Arc`.
#[allow(unused_must_use)]
pub struct WorkerPoolUnbounded<
    M: Send + Sized + Unpin + 'static,
    W: Worker<M>,
    S: WorkerPoolImpl<M, W>,
>(Arc<WorkerPoolUnboundedInner<M, W, S>>);
21
22struct WorkerPoolUnboundedInner<M, W, S>
23where
24    M: Send + Sized + Unpin + 'static,
25    W: Worker<M>,
26    S: WorkerPoolImpl<M, W>,
27{
28    worker_count: AtomicUsize,
29    sender: MTx<Option<M>>,
30    recv: MAsyncRx<Option<M>>,
31    min_workers: usize,
32    max_workers: usize,
33    worker_timeout: Duration,
34    inner: S,
35    phantom: std::marker::PhantomData<W>, // to avoid complaining unused param
36    closing: AtomicBool,
37    notify_sender: MTx<Option<()>>,
38    notify_recv: MAsyncRx<Option<()>>, // Only use in monitor()
39    water: AtomicUsize,
40    auto: bool,
41    real_thread: AtomicBool,
42}

impl<M, W, S> Clone for WorkerPoolUnbounded<M, W, S>
where
    M: Send + Sized + Unpin + 'static,
    W: Worker<M>,
    S: WorkerPoolImpl<M, W>,
{
    #[inline]
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

impl<M, W, S> WorkerPoolUnbounded<M, W, S>
where
    M: Send + Sized + Unpin + 'static,
    W: Worker<M>,
    S: WorkerPoolImpl<M, W>,
{
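    /// Builds a pool that keeps between `min_workers` and `max_workers`
    /// workers alive. When `min_workers < max_workers`, auto-scaling is
    /// enabled and an idle worker above the minimum exits after
    /// `worker_timeout`.
    ///
    /// A minimal construction sketch, borrowing the `MyWorkerPoolImpl`
    /// factory from the tests below:
    ///
    /// ```ignore
    /// let pool = WorkerPoolUnbounded::new(
    ///     MyWorkerPoolImpl(),     // factory that spawns workers
    ///     1,                      // min_workers
    ///     4,                      // max_workers
    ///     Duration::from_secs(1), // idle timeout before scale-down
    /// );
    /// pool.start().await;
    /// ```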
    pub fn new(inner: S, min_workers: usize, max_workers: usize, worker_timeout: Duration) -> Self {
        assert!(min_workers > 0);
        assert!(max_workers >= min_workers);

        let auto: bool = min_workers < max_workers;
        if auto {
            assert!(worker_timeout != Duration::ZERO);
        }
        let (sender, recv) = mpmc::unbounded_async();
        let (noti_sender, noti_recv) = mpmc::bounded_tx_blocking_rx_async(1);
        let pool = Arc::new(WorkerPoolUnboundedInner {
            sender,
            recv,
            inner,
            worker_count: AtomicUsize::new(0),
            min_workers,
            max_workers,
            worker_timeout,
            phantom: Default::default(),
            closing: AtomicBool::new(false),
            notify_sender: noti_sender,
            notify_recv: noti_recv,
            water: AtomicUsize::new(0),
            auto,
            real_thread: AtomicBool::new(false),
        });
        Self(pool)
    }

    pub fn get_inner(&self) -> &S {
        &self.0.inner
    }

    // If the workers contain blocking logic, run each worker on a dedicated
    // thread (takes effect for workers spawned after this call).
    pub fn set_use_thread(&mut self, ok: bool) {
        self.0.real_thread.store(ok, Ordering::Release);
    }

    // Spawns `min_workers` workers; when auto-scaling is enabled, also
    // starts the monitor task that handles scale-up notifications.
    pub async fn start(&self) {
        let _self = self.0.as_ref();
        for _ in 0.._self.min_workers {
            self.0.clone().spawn();
        }
        if _self.auto {
            let _pool = self.0.clone();
            tokio::spawn(async move {
                _pool.monitor().await;
            });
        }
    }

    // Pre-warms up to `num` extra workers, never exceeding `max_workers`.
    // No-op unless auto-scaling is enabled.
    pub async fn try_spawn(&self, num: usize) {
        let _self = self.0.as_ref();
        if !_self.auto {
            return;
        }
        for _ in 0..num {
            if _self.get_worker_count() >= _self.max_workers {
                return;
            }
            self.0.clone().spawn();
        }
    }

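    /// Shuts the pool down: sends one `None` close signal per live worker
    /// and polls once a second until all workers have exited, then wakes
    /// the monitor task so it can return.
    ///
    /// A typical lifecycle sketch (mirroring the assertions in the tests
    /// below):
    ///
    /// ```ignore
    /// pool.start().await;
    /// // ... submit work ...
    /// pool.close().await; // returns once every worker has exited
    /// assert_eq!(pool.get_worker_count(), 0);
    /// ```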
    pub async fn close(&self) {
        let _self = self.0.as_ref();
        if _self.closing.swap(true, Ordering::SeqCst) {
            return;
        }
        loop {
            let cur = self.get_worker_count();
            if cur == 0 {
                break;
            }
            debug!("worker pool closing: cur workers {}", cur);
            for _ in 0..cur {
                let _ = _self.sender.send(None);
            }
            sleep(Duration::from_secs(1)).await;
            // TODO: gracefully execute all remaining messages
        }
        let _ = _self.notify_sender.try_send(None);
    }

    pub fn get_worker_count(&self) -> usize {
        self.0.get_worker_count()
    }
}

impl<M, W, S> WorkerPoolInf<M> for WorkerPoolUnbounded<M, W, S>
where
    M: Send + Sized + Unpin + 'static,
    W: Worker<M>,
    S: WorkerPoolImpl<M, W>,
{
    // Returns `None` on success; returns `Some(msg)` with the original
    // message when the pool is closing or the channel has been closed.
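    // A minimal handling sketch (the same pattern the tests below use):
    //
    //     if let Some(rejected) = pool.submit(msg) {
    //         // pool is closing; `rejected` was never enqueued
    //     }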
    #[inline]
    fn submit(&self, msg: M) -> Option<M> {
        let _self = self.0.as_ref();
        if _self.closing.load(Ordering::Acquire) {
            return Some(msg);
        }
        if _self.auto {
            // Track the backlog; when it outgrows the worker count, nudge the
            // monitor to scale up (try_send: a dropped nudge is harmless).
            let worker_count = _self.get_worker_count();
            let water = _self.water.fetch_add(1, Ordering::SeqCst);
            if worker_count < _self.max_workers && water > worker_count + 1 {
                let _ = _self.notify_sender.try_send(Some(()));
            }
        }
        match _self.sender.send(Some(msg)) {
            Ok(_) => None,
            Err(SendError(_msg)) => {
                return Some(_msg.unwrap());
            }
        }
    }
}

impl<M, W, S> WorkerPoolUnboundedInner<M, W, S>
where
    M: Send + Sized + Unpin + 'static,
    W: Worker<M>,
    S: WorkerPoolImpl<M, W>,
{
    // Worker loop for fixed-size pools: block on recv until a `None` close
    // signal arrives or the channel is closed.
    async fn run_worker_simple(&self, mut worker: W) {
        if worker.init().await.is_err() {
            let _ = self.try_exit();
            worker.on_exit();
            return;
        }

        let recv = &self.recv;
        'WORKER_LOOP: loop {
            match recv.recv().await {
                Ok(item) => {
                    if item.is_none() {
                        let _ = self.try_exit();
                        break 'WORKER_LOOP;
                    }
                    worker.run(item.unwrap()).await;
                }
                Err(_) => {
                    // channel closed, worker exits
                    let _ = self.try_exit();
                    break 'WORKER_LOOP;
                }
            }
        }
        worker.on_exit();
        trace!("worker pool {} workers", self.get_worker_count());
    }

    // Worker loop for auto-scaling pools: drain available messages, then
    // wait up to `worker_timeout`; on an idle timeout, exit if the pool is
    // above `min_workers`.
    async fn run_worker_adjust(&self, mut worker: W) {
        if worker.init().await.is_err() {
            let _ = self.try_exit();
            worker.on_exit();
            return;
        }

        let worker_timeout = self.worker_timeout;
        let recv = &self.recv;
        let mut is_idle = false;
        'WORKER_LOOP: loop {
            if is_idle {
                match timeout(worker_timeout, recv.recv()).await {
                    Ok(res) => {
                        match res {
                            Ok(item) => {
                                if item.is_none() {
                                    let _ = self.try_exit();
                                    break 'WORKER_LOOP;
                                }
                                worker.run(item.unwrap()).await;
                                is_idle = false;
                                self.water.fetch_sub(1, Ordering::SeqCst);
                            }
                            Err(_) => {
                                // channel closed, worker exits
                                let _ = self.try_exit();
                                break 'WORKER_LOOP;
                            }
                        }
                    }
                    Err(_) => {
                        // idle timeout
                        if self.try_exit() {
                            break 'WORKER_LOOP;
                        }
                    }
                }
            } else {
                match recv.try_recv() {
                    Err(e) => {
                        if e.is_empty() {
                            is_idle = true;
                        } else {
                            let _ = self.try_exit();
                            break 'WORKER_LOOP;
                        }
                    }
                    Ok(Some(item)) => {
                        worker.run(item).await;
                        self.water.fetch_sub(1, Ordering::SeqCst);
                        is_idle = false;
                    }
                    Ok(None) => {
                        let _ = self.try_exit();
                        break 'WORKER_LOOP;
                    }
                }
            }
        }
        worker.on_exit();
        trace!("worker pool {} workers", self.get_worker_count());
    }

    #[inline(always)]
    pub fn get_worker_count(&self) -> usize {
        self.worker_count.load(Ordering::Acquire)
    }

    // Increments the worker count and launches one worker, either on a
    // dedicated thread with its own runtime or as a tokio task.
    #[inline(always)]
    fn spawn(self: Arc<Self>) {
        let cur_count = self.worker_count.fetch_add(1, Ordering::SeqCst) + 1;
        let worker = self.inner.spawn();
        let _self = self.clone();
        if self.real_thread.load(Ordering::Acquire) {
            thread::spawn(move || {
                let rt = tokio::runtime::Builder::new_current_thread()
                    .enable_all()
                    .build()
                    .expect("runtime");
                rt.block_on(async move {
                    trace!("worker pool started worker {}", cur_count);
                    if _self.auto {
                        _self.run_worker_adjust(worker).await
                    } else {
                        _self.run_worker_simple(worker).await
                    }
                });
            });
        } else {
            tokio::spawn(async move {
                trace!("worker pool started worker {}", cur_count);
                if _self.auto {
                    _self.run_worker_adjust(worker).await
                } else {
                    _self.run_worker_simple(worker).await
                }
            });
        }
    }

    // Decides whether an idle worker should exit. During close, always
    // decrement and exit. Otherwise exit only while above `min_workers`;
    // if a concurrent decrement raced us below the minimum, roll back.
    #[inline(always)]
    fn try_exit(&self) -> bool {
        if self.closing.load(Ordering::Acquire) {
            self.worker_count.fetch_sub(1, Ordering::SeqCst);
            return true;
        }
        if self.get_worker_count() > self.min_workers {
            if self.worker_count.fetch_sub(1, Ordering::SeqCst) <= self.min_workers {
                self.worker_count.fetch_add(1, Ordering::SeqCst); // roll back
            } else {
                return true; // worker exits
            }
        }
        false
    }

    // Scale-up loop: woken by `submit` when the backlog outgrows the worker
    // count, or every second to check for shutdown. Spawns at most enough
    // workers to cover the pending messages, capped at `max_workers`.
    async fn monitor(self: Arc<Self>) {
        let _self = self.as_ref();
        loop {
            match timeout(Duration::from_secs(1), _self.notify_recv.recv()).await {
                Err(_) => {
                    if _self.closing.load(Ordering::Acquire) {
                        return;
                    }
                    continue;
                }
                Ok(Ok(Some(_))) => {
                    if _self.closing.load(Ordering::Acquire) {
                        return;
                    }
                    let worker_count = _self.get_worker_count();
                    if worker_count >= _self.max_workers {
                        continue;
                    }
                    let mut pending_msg = _self.sender.len();
                    if pending_msg > worker_count {
                        pending_msg -= worker_count;
                        if pending_msg > _self.max_workers - worker_count {
                            pending_msg = _self.max_workers - worker_count;
                        }
                        for _ in 0..pending_msg {
                            self.clone().spawn();
                        }
                    }
                }
                // `None` close signal or channel closed: shut down
                _ => return,
            }
        }
    }
}

#[cfg(test)]
mod tests {

    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    use crossfire::*;
    use tokio::time::{Duration, sleep};

    use super::*;
    use atomic_waitgroup::WaitGroup;

    #[allow(dead_code)]
    struct MyWorkerPoolImpl();

    struct MyWorker();

    struct MyMsg(i64, MAsyncTx<()>);

    impl WorkerPoolImpl<MyMsg, MyWorker> for MyWorkerPoolImpl {
        fn spawn(&self) -> MyWorker {
            MyWorker()
        }
    }

    #[async_trait]
    impl Worker<MyMsg> for MyWorker {
        async fn init(&mut self) -> Result<(), ()> {
            println!("init done");
            Ok(())
        }

        async fn run(&mut self, msg: MyMsg) {
            sleep(Duration::from_millis(1)).await;
            println!("done {}", msg.0);
            let _ = msg.1.send(()).await;
        }
    }

    type MyWorkerPool = WorkerPoolUnbounded<MyMsg, MyWorker, MyWorkerPoolImpl>;

    // Verifies auto-scaling: the pool grows to max_workers under load,
    // shrinks back to min_workers after the idle timeout, and drains
    // cleanly on close.
    #[test]
    fn unbounded_workerpool_adjust() {
        let _ = captains_log::recipe::stderr_test_logger(log::Level::Debug).build();
        let min_workers = 1;
        let max_workers = 4;
        let worker_timeout = Duration::from_secs(1);
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .unwrap();
        let worker_pool =
            MyWorkerPool::new(MyWorkerPoolImpl(), min_workers, max_workers, worker_timeout);
        rt.block_on(async move {
            worker_pool.start().await;
            let mut th_s = Vec::new();
            for i in 0..5 {
                let _pool = worker_pool.clone();
                th_s.push(tokio::task::spawn(async move {
                    let (done_tx, done_rx) = mpsc::bounded_async(10);
                    for j in 0..2 {
                        _pool.submit(MyMsg(i * 10 + j, done_tx.clone()));
                    }
                    for _j in 0..2 {
                        //println!("sender {} recv {}", i, _j);
                        let _ = done_rx.recv().await;
                    }
                }));
            }
            for th in th_s {
                let _ = th.await;
            }
            let workers = worker_pool.get_worker_count();
            println!("cur workers {} should reach max", workers);
            assert_eq!(workers, max_workers);

            worker_pool.try_spawn(5).await;
            let workers = worker_pool.get_worker_count();
            println!("cur workers {} should stay at max", workers);
            assert_eq!(workers, max_workers);

            sleep(worker_timeout * 2).await;
            let workers = worker_pool.get_worker_count();
            println!("cur workers: {}, extra should exit due to timeout", workers);
            assert_eq!(workers, min_workers);

            let (done_tx, done_rx) = mpsc::bounded_async(1);
            for j in 0..10 {
                worker_pool.submit(MyMsg(80 + j, done_tx.clone()));
                let _ = done_rx.recv().await;
            }
            println!("closing");
            worker_pool.close().await;
            assert_eq!(worker_pool.get_worker_count(), 0);
            assert_eq!(worker_pool.0.water.load(Ordering::Acquire), 0);
        });
    }

    #[allow(dead_code)]
    struct TestWorkerPoolImpl {
        id: AtomicUsize,
    }

    struct TestWorker {
        id: usize,
    }

    #[allow(dead_code)]
    struct TestMsg(usize, WaitGroup);

    impl WorkerPoolImpl<TestMsg, TestWorker> for TestWorkerPoolImpl {
        fn spawn(&self) -> TestWorker {
            let _id = self.id.fetch_add(1, Ordering::SeqCst);
            TestWorker { id: _id }
        }
    }

    #[async_trait]
    impl Worker<TestMsg> for TestWorker {
        async fn init(&mut self) -> Result<(), ()> {
            log::info!("worker {} init done", self.id);
            Ok(())
        }

        async fn run(&mut self, msg: TestMsg) {
            // Simulate 0-9 ms of work derived from the current wall clock.
            let run_time = (SystemTime::now().duration_since(UNIX_EPOCH).ok().unwrap().as_millis()
                % 10) as u64;
            sleep(Duration::from_millis(run_time)).await;
            msg.1.done();
        }
    }

    type TestWorkerPool = WorkerPoolUnbounded<TestMsg, TestWorker, TestWorkerPoolImpl>;

    // Stress test: 10 producer tasks each submit 10,000 messages with
    // jittered pacing; a WaitGroup confirms every message is processed.
    #[test]
    fn unbounded_workerpool_run() {
        let _ = captains_log::recipe::stderr_test_logger(log::Level::Debug).build();

        log::info!("unbounded_workerpool test start");
        let min_workers = 8;
        let max_workers = 128;
        let worker_timeout = Duration::from_secs(5);
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(4)
            .build()
            .unwrap();
        let worker_pool = TestWorkerPool::new(
            TestWorkerPoolImpl { id: AtomicUsize::new(0) },
            min_workers,
            max_workers,
            worker_timeout,
        );
        rt.block_on(async move {
            worker_pool.start().await;
            let total_threads = 10;
            let batch_msgs: usize = 10000;
            let wg = WaitGroup::new();
            wg.add(batch_msgs * total_threads);
            for thread in 0..total_threads {
                let _wg = wg.clone();
                let _pool = worker_pool.clone();
                tokio::spawn(async move {
                    log::info!("thread:{} run start", thread);
                    let batch_msg_start = thread * batch_msgs;
                    let mut submit_steps: u64 =
                        (SystemTime::now().duration_since(UNIX_EPOCH).ok().unwrap().as_millis()
                            % 100) as u64;
                    let mut current_submit_step = 0;
                    for i in batch_msg_start..(batch_msg_start + batch_msgs) {
                        let msg = TestMsg(i, _wg.clone());
                        if let Some(_msg) = _pool.submit(msg) {
                            // Rejected (pool closing): mark done so wait() can finish.
                            _msg.1.done();
                        }
                        current_submit_step += 1;
                        if current_submit_step >= submit_steps {
                            sleep(Duration::from_millis(submit_steps % 100)).await;
                            current_submit_step = 0;
                            submit_steps = (SystemTime::now()
                                .duration_since(UNIX_EPOCH)
                                .ok()
                                .unwrap()
                                .as_millis()
                                % 100) as u64;
                        }
                    }
                    log::info!("thread:{} run over", thread);
                });
            }
            wg.wait().await;
        });
        log::info!("unbounded_workerpool test over");
    }
}