sync_utils/worker_pool/
bounded_blocking.rs

use std::{
    sync::{
        Arc,
        atomic::{AtomicBool, AtomicUsize, Ordering},
    },
    time::Duration,
};

use crossfire::{SendError, mpmc};
use tokio::{runtime::Runtime, time::timeout};

use super::*;

/*
 * A WorkerPool whose submit side runs without a tokio Context; a single
 * monitor task runs inside the tokio runtime to spawn and adjust the
 * long-lived workers.
 */

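// Typical usage, sketched (`MyFactory`/`MyWorker`/`msg` are placeholder names;
// see the tests at the bottom of this file for a runnable version):
//
//     let pool = WorkerPool::new(MyFactory(), 1, 4, 8, Duration::from_secs(1), &rt);
//     pool.submit(msg);  // blocking send, no tokio Context needed
//     pool.close();      // drain workers and stop the monitor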
#[allow(unused_must_use)]
pub struct WorkerPool<M: Send + Sized + Unpin + 'static, W: Worker<M>, S: WorkerPoolImpl<M, W>>(
    Arc<WorkerPoolInner<M, W, S>>,
);

struct WorkerPoolInner<M, W, S>
where
    M: Send + Sized + Unpin + 'static,
    W: Worker<M>,
    S: WorkerPoolImpl<M, W>,
{
    worker_count: AtomicUsize,
    sender: mpmc::TxBlocking<Option<M>, mpmc::SharedSenderBRecvF>,
    recv: mpmc::RxFuture<Option<M>, mpmc::SharedSenderBRecvF>,
    min_workers: usize,
    max_workers: usize,
    worker_timeout: Duration,
    inner: S,
    water: AtomicUsize, // in-flight message count (auto mode), used to detect backlog
    phantom: std::marker::PhantomData<W>, // silences the unused-type-parameter error
    closing: AtomicBool,
    notify_sender: mpmc::TxBlocking<Option<()>, mpmc::SharedSenderBRecvF>,
    auto: bool, // true when min_workers < max_workers (auto-scaling enabled)
    buffer_size: usize,
}

impl<M, W, S> Clone for WorkerPool<M, W, S>
where
    M: Send + Sized + Unpin + 'static,
    W: Worker<M>,
    S: WorkerPoolImpl<M, W>,
{
    #[inline]
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

impl<M, W, S> WorkerPool<M, W, S>
where
    M: Send + Sized + Unpin + 'static,
    W: Worker<M>,
    S: WorkerPoolImpl<M, W>,
{
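    /// Creates the pool and spawns its monitor task on `rt`.
    ///
    /// `buffer_size` is clamped to `max_workers * 2`. When
    /// `min_workers < max_workers` the pool auto-scales; that mode requires a
    /// non-zero `worker_timeout` so idle workers can retire.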
    pub fn new(
        inner: S, min_workers: usize, max_workers: usize, mut buffer_size: usize,
        worker_timeout: Duration, rt: &Runtime,
    ) -> Self {
        if buffer_size > max_workers * 2 {
            buffer_size = max_workers * 2;
        }
        let (sender, recv) = mpmc::bounded_tx_blocking_rx_future(buffer_size);
        let (noti_sender, noti_recv) = mpmc::bounded_tx_blocking_rx_future(1);
        assert!(min_workers > 0);
        assert!(max_workers >= min_workers);

        let auto: bool = min_workers < max_workers;
        if auto {
            // Auto-scaling relies on the idle timeout to retire surplus workers.
            assert!(!worker_timeout.is_zero());
        }

        let pool = Arc::new(WorkerPoolInner {
            sender,
            recv,
            inner,
            worker_count: AtomicUsize::new(0),
            min_workers,
            max_workers,
            buffer_size,
            worker_timeout,
            phantom: Default::default(),
            closing: AtomicBool::new(false),
            water: AtomicUsize::new(0),
            notify_sender: noti_sender,
            auto,
        });
        let _pool = pool.clone();
        rt.spawn(async move {
            _pool.monitor(noti_recv).await;
        });
        Self(pool)
    }

    pub fn get_inner(&self) -> &S {
        &self.0.inner
    }

    // `None` on the channel is reserved as the worker close signal.
    // Returns `None` on success, or gives the message back as `Some(msg)` when
    // the send fails (pool closing or channel closed).
    pub fn submit(&self, msg: M) -> Option<M> {
        let _self = self.0.as_ref();
        if _self.closing.load(Ordering::Acquire) {
            return Some(msg);
        }
        if _self.auto {
            let worker_count = _self.get_worker_count();
            let water = _self.water.fetch_add(1, Ordering::SeqCst);
            if worker_count < _self.max_workers
                && (water > worker_count + 1 || water > _self.buffer_size)
            {
                // Backlog is building up: nudge the monitor to add a worker.
                let _ = _self.notify_sender.try_send(Some(()));
            }
        }
        match _self.sender.send(Some(msg)) {
            Ok(_) => None,
            Err(SendError(t)) => t,
        }
    }

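    /// Shuts the pool down: broadcasts `None` close signals until every
    /// worker has exited, then wakes the monitor task so it can return.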
    pub fn close(&self) {
        let _self = self.0.as_ref();
        if _self.closing.swap(true, Ordering::SeqCst) {
            return;
        }
        loop {
            let cur = self.get_worker_count();
            if cur == 0 {
                break;
            }
            debug!("worker pool closing: cur workers {}", cur);
            for _ in 0..cur {
                let _ = _self.sender.send(None);
            }
            std::thread::sleep(_self.worker_timeout);
            // TODO: gracefully execute all remaining messages before exiting
        }
        // The notify channel is bounded at 1, so send twice to ensure the
        // monitor task wakes up and exits.
        let _ = _self.notify_sender.send(None);
        let _ = _self.notify_sender.send(None);
    }

    pub fn get_worker_count(&self) -> usize {
        self.0.get_worker_count()
    }
}

impl<M, W, S> WorkerPoolInner<M, W, S>
where
    M: Send + Sized + Unpin + 'static,
    W: Worker<M>,
    S: WorkerPoolImpl<M, W>,
{
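    /// Worker loop for fixed-size pools: block on the channel until a `None`
    /// close signal arrives or the channel itself is closed.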
    async fn run_worker_simple(&self, mut worker: W) {
        if worker.init().await.is_err() {
            let _ = self.try_exit();
            worker.on_exit();
            return;
        }

        let recv = &self.recv;
        loop {
            match recv.recv().await {
                Ok(Some(item)) => worker.run(item).await,
                Ok(None) => {
                    // Received the close signal.
                    let _ = self.try_exit();
                    break;
                }
                Err(_) => {
                    // Channel closed: worker exits.
                    let _ = self.try_exit();
                    break;
                }
            }
        }
        worker.on_exit();
    }

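    /// Worker loop for auto-scaling pools: drain the channel eagerly, and
    /// once idle wait with `worker_timeout` so that surplus workers above
    /// `min_workers` can retire via `try_exit`.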
    async fn run_worker_adjust(&self, mut worker: W) {
        if worker.init().await.is_err() {
            let _ = self.try_exit();
            worker.on_exit();
            return;
        }

        let worker_timeout = self.worker_timeout;
        let recv = &self.recv;
        let mut is_idle = false;
        loop {
            if is_idle {
                match timeout(worker_timeout, recv.recv()).await {
                    Ok(Ok(Some(item))) => {
                        worker.run(item).await;
                        is_idle = false;
                        self.water.fetch_sub(1, Ordering::SeqCst);
                    }
                    Ok(Ok(None)) => {
                        // Received the close signal.
                        let _ = self.try_exit();
                        break;
                    }
                    Ok(Err(_)) => {
                        // Channel closed: worker exits.
                        let _ = self.try_exit();
                        break;
                    }
                    Err(_) => {
                        // Idle timeout: retire this worker if above min_workers.
                        if self.try_exit() {
                            break;
                        }
                    }
                }
            } else {
                match recv.try_recv() {
                    Ok(Some(item)) => {
                        worker.run(item).await;
                        self.water.fetch_sub(1, Ordering::SeqCst);
                    }
                    Ok(None) => {
                        // Received the close signal.
                        let _ = self.try_exit();
                        break;
                    }
                    Err(e) => {
                        if e.is_empty() {
                            // Nothing pending: switch to the timed wait above.
                            is_idle = true;
                        } else {
                            // Channel closed: worker exits.
                            let _ = self.try_exit();
                            break;
                        }
                    }
                }
            }
        }
        worker.on_exit();
    }

    #[inline(always)]
    pub fn get_worker_count(&self) -> usize {
        self.worker_count.load(Ordering::Acquire)
    }

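    /// Registers one new worker and runs its loop on the tokio runtime.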
    #[inline(always)]
    fn spawn(self: Arc<Self>) {
        self.worker_count.fetch_add(1, Ordering::SeqCst);
        let worker = self.inner.spawn();
        let _self = self.clone();
        tokio::spawn(async move {
            if _self.auto {
                _self.run_worker_adjust(worker).await
            } else {
                _self.run_worker_simple(worker).await
            }
        });
    }

    // Decides whether the calling worker should exit, decrementing
    // worker_count if so. Used on idle timeout, close, and channel-close paths.
    #[inline(always)]
    fn try_exit(&self) -> bool {
        if self.closing.load(Ordering::Acquire) {
            self.worker_count.fetch_sub(1, Ordering::SeqCst);
            return true;
        }
        if self.get_worker_count() > self.min_workers {
            if self.worker_count.fetch_sub(1, Ordering::SeqCst) <= self.min_workers {
                // A concurrent exit got there first; roll back the decrement.
                self.worker_count.fetch_add(1, Ordering::SeqCst);
            } else {
                return true; // worker exits
            }
        }
        false
    }

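    /// Spawns the initial `min_workers`, then adds workers on demand, up to
    /// `max_workers`, whenever `submit` reports backlog through `noti_recv`.
    /// Returns once `close` sends the `None` shutdown notification.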
    async fn monitor(
        self: Arc<Self>, noti_recv: mpmc::RxFuture<Option<()>, mpmc::SharedSenderBRecvF>,
    ) {
        for _ in 0..self.min_workers {
            self.clone().spawn();
        }
        loop {
            match noti_recv.recv().await {
                Ok(Some(_)) => {
                    if !self.auto {
                        continue;
                    }
                    let worker_count = self.get_worker_count();
                    if worker_count > self.max_workers {
                        continue;
                    }
                    // Spawn one worker per excess pending message, capped so
                    // the total never exceeds max_workers.
                    let mut pending_msg = self.sender.len();
                    if pending_msg > worker_count {
                        pending_msg -= worker_count;
                        if pending_msg > self.max_workers - worker_count {
                            pending_msg = self.max_workers - worker_count;
                        }
                        for _ in 0..pending_msg {
                            self.clone().spawn();
                        }
                    }
                }
                // `None` (or a closed channel) is the shutdown signal.
                _ => return,
            }
        }
    }
}

#[cfg(test)]
mod tests {

    use std::thread;

    use crossbeam::channel::{Sender, bounded};
    use tokio::time::{Duration, sleep};

    use super::*;

    #[allow(dead_code)]
    struct MyWorkerPoolImpl();

    struct MyWorker();

    struct MyMsg(i64, Sender<()>);

    impl WorkerPoolImpl<MyMsg, MyWorker> for MyWorkerPoolImpl {
        fn spawn(&self) -> MyWorker {
            MyWorker()
        }
    }

    #[async_trait]
    impl Worker<MyMsg> for MyWorker {
        async fn run(&mut self, msg: MyMsg) {
            sleep(Duration::from_millis(1)).await;
            println!("done {}", msg.0);
            let _ = msg.1.send(());
        }
    }

    type MyWorkerPool = WorkerPool<MyMsg, MyWorker, MyWorkerPoolImpl>;

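    // Auto-scaling pool: heavy load from 8 threads should grow the pool to
    // max_workers; after the idle timeout it should shrink back to min_workers.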
    #[test]
    fn blocking_workerpool_adjust() {
        let min_workers = 1;
        let max_workers = 4;
        let worker_timeout = Duration::from_secs(1);
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .unwrap();
        let worker_pool = MyWorkerPool::new(
            MyWorkerPoolImpl(),
            min_workers,
            max_workers,
            10,
            worker_timeout,
            &rt,
        );

        let mut ths = Vec::new();
        for i in 0..8 {
            let _pool = worker_pool.clone();
            ths.push(thread::spawn(move || {
                let (done_tx, done_rx) = bounded(10);
                for j in 0..10 {
                    _pool.submit(MyMsg(i * 10 + j, done_tx.clone()));
                }
                for _j in 0..10 {
                    let _ = done_rx.recv();
                }
            }));
        }
        for th in ths {
            let _ = th.join();
        }
        let workers = worker_pool.get_worker_count();
        println!("cur workers {} should reach max", workers);
        assert_eq!(workers, max_workers);

        thread::sleep(worker_timeout * 2);
        let workers = worker_pool.get_worker_count();
        println!("cur workers: {}, extra should exit due to timeout", workers);
        assert_eq!(workers, min_workers);

        let (done_tx, done_rx) = bounded(2);
        for j in 0..10 {
            worker_pool.submit(MyMsg(80 + j, done_tx.clone()));
            println!("send {}", j);
            let _ = done_rx.recv();
        }
        println!("closing");
        worker_pool.close();
        assert_eq!(worker_pool.get_worker_count(), 0);
        assert_eq!(worker_pool.0.water.load(Ordering::Acquire), 0);
    }

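    // Fixed-size pool (min_workers == max_workers): the worker count should
    // stay at max_workers regardless of load or idle time.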
    #[test]
    fn blocking_workerpool_fixed() {
        let min_workers = 4;
        let max_workers = 4;
        let worker_timeout = Duration::from_secs(1);
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .unwrap();
        let worker_pool = MyWorkerPool::new(
            MyWorkerPoolImpl(),
            min_workers,
            max_workers,
            10,
            worker_timeout,
            &rt,
        );

        let mut ths = Vec::new();
        for i in 0..8 {
            let _pool = worker_pool.clone();
            ths.push(thread::spawn(move || {
                let (done_tx, done_rx) = bounded(10);
                for j in 0..10 {
                    _pool.submit(MyMsg(i * 10 + j, done_tx.clone()));
                }
                for _j in 0..10 {
                    let _ = done_rx.recv();
                }
            }));
        }
        for th in ths {
            let _ = th.join();
        }
        let workers = worker_pool.get_worker_count();
        println!("cur workers {} should reach max", workers);
        assert_eq!(workers, max_workers);

        thread::sleep(worker_timeout * 2);
        let workers = worker_pool.get_worker_count();
        println!("cur workers {} should stay at max", workers);
        assert_eq!(workers, max_workers);

        let (done_tx, done_rx) = bounded(2);
        for j in 0..10 {
            worker_pool.submit(MyMsg(80 + j, done_tx.clone()));
            println!("send {}", j);
            let _ = done_rx.recv();
        }
        println!("closing");
        worker_pool.close();
        assert_eq!(worker_pool.get_worker_count(), 0);
        assert_eq!(worker_pool.0.water.load(Ordering::Acquire), 0);
    }
}