// sync_utils/worker_pool/bounded_blocking.rs

1use std::{
2    sync::{
3        Arc,
4        atomic::{AtomicBool, AtomicUsize, Ordering},
5    },
6    time::Duration,
7};
8
9use crossfire::*;
10use tokio::{runtime::Runtime, time::timeout};
11
12use super::*;
13
/*
 * A WorkerPool whose submit side runs without a tokio Context; a single
 * monitor task runs inside the tokio runtime to spawn and adjust
 * long-lived workers.
 */
18
/// Cloneable handle to a worker pool. Submission (`submit`) is blocking and
/// needs no tokio Context; workers and the scaling monitor run on a tokio
/// runtime supplied at construction.
#[allow(unused_must_use)]
pub struct WorkerPool<M: Send + Sized + Unpin + 'static, W: Worker<M>, S: WorkerPoolImpl<M, W>>(
    Arc<WorkerPoolInner<M, W, S>>,
);
23
/// Shared state behind the `WorkerPool` handle.
struct WorkerPoolInner<M, W, S>
where
    M: Send + Sized + Unpin + 'static,
    W: Worker<M>,
    S: WorkerPoolImpl<M, W>,
{
    // Number of currently live workers.
    worker_count: AtomicUsize,
    // Blocking send side of the message channel; `None` tells a worker to exit.
    sender: MTx<Option<M>>,
    // Async receive side shared by all workers.
    recv: MAsyncRx<Option<M>>,
    min_workers: usize,
    max_workers: usize,
    // Idle time after which an extra worker may exit (auto mode); also the
    // polling interval used by `close()`.
    worker_timeout: Duration,
    // User-supplied factory that creates new workers.
    inner: S,
    // Count of submitted-but-not-yet-processed messages (auto mode only).
    water: AtomicUsize,
    phantom: std::marker::PhantomData<W>, // to avoid complaining unused param
    // Set once `close()` starts; `submit` rejects messages afterwards.
    closing: AtomicBool,
    // Bounded(1) channel poking the monitor to consider scaling up;
    // `None` tells the monitor to stop.
    notify_sender: MTx<Option<()>>,
    // true when min_workers < max_workers (auto-scaling enabled).
    auto: bool,
    buffer_size: usize,
}
44
45impl<M, W, S> Clone for WorkerPool<M, W, S>
46where
47    M: Send + Sized + Unpin + 'static,
48    W: Worker<M>,
49    S: WorkerPoolImpl<M, W>,
50{
51    #[inline]
52    fn clone(&self) -> Self {
53        Self(self.0.clone())
54    }
55}
56
57impl<M, W, S> WorkerPool<M, W, S>
58where
59    M: Send + Sized + Unpin + 'static,
60    W: Worker<M>,
61    S: WorkerPoolImpl<M, W>,
62{
63    pub fn new(
64        inner: S, min_workers: usize, max_workers: usize, mut buffer_size: usize,
65        worker_timeout: Duration, rt: &Runtime,
66    ) -> Self {
67        if buffer_size > max_workers * 2 {
68            buffer_size = max_workers * 2;
69        }
70        let (sender, recv) = mpmc::bounded_tx_blocking_rx_async(buffer_size);
71        let (noti_sender, noti_recv) = mpmc::bounded_tx_blocking_rx_async(1);
72        assert!(min_workers > 0);
73        assert!(max_workers >= min_workers);
74
75        let auto: bool = min_workers < max_workers;
76        if auto {
77            assert!(worker_timeout != ZERO_DUARTION);
78        }
79
80        let pool = Arc::new(WorkerPoolInner {
81            sender,
82            recv,
83            inner,
84            worker_count: AtomicUsize::new(0),
85            min_workers,
86            max_workers,
87            buffer_size,
88            worker_timeout,
89            phantom: Default::default(),
90            closing: AtomicBool::new(false),
91            water: AtomicUsize::new(0),
92            notify_sender: noti_sender,
93            auto,
94        });
95        let _pool = pool.clone();
96        rt.spawn(async move {
97            _pool.monitor(noti_recv).await;
98        });
99        Self(pool)
100    }
101
102    pub fn get_inner(&self) -> &S {
103        &self.0.inner
104    }
105
106    // sending None will notify worker close
107    // return non-null on send fail
108    pub fn submit(&self, msg: M) -> Option<M> {
109        let _self = self.0.as_ref();
110        if _self.closing.load(Ordering::Acquire) {
111            return Some(msg);
112        }
113        if _self.auto {
114            let worker_count = _self.get_worker_count();
115            let water = _self.water.fetch_add(1, Ordering::SeqCst);
116            if worker_count < _self.max_workers {
117                if water > worker_count + 1 || water > _self.buffer_size {
118                    let _ = _self.notify_sender.try_send(Some(()));
119                }
120            }
121        }
122        match _self.sender.send(Some(msg)) {
123            Ok(_) => return None,
124            Err(SendError(t)) => return t,
125        }
126    }
127
128    pub fn close(&self) {
129        let _self = self.0.as_ref();
130        if _self.closing.swap(true, Ordering::SeqCst) {
131            return;
132        }
133        loop {
134            let cur = self.get_worker_count();
135            if cur == 0 {
136                break;
137            }
138            debug!("worker pool closing: cur workers {}", cur);
139            for _ in 0..cur {
140                let _ = _self.sender.send(None);
141            }
142            std::thread::sleep(_self.worker_timeout);
143            // TODO graceful exec all remaining
144        }
145        // Because bound is 1, send twice to unsure monitor coroutine exit
146        let _ = _self.notify_sender.send(None);
147        let _ = _self.notify_sender.send(None);
148    }
149
150    pub fn get_worker_count(&self) -> usize {
151        self.0.get_worker_count()
152    }
153}
154
155impl<M, W, S> WorkerPoolInner<M, W, S>
156where
157    M: Send + Sized + Unpin + 'static,
158    W: Worker<M>,
159    S: WorkerPoolImpl<M, W>,
160{
161    async fn run_worker_simple(&self, mut worker: W) {
162        if let Err(_) = worker.init().await {
163            let _ = self.try_exit();
164            worker.on_exit();
165            return;
166        }
167
168        let recv = &self.recv;
169        'WORKER_LOOP: loop {
170            match recv.recv().await {
171                Ok(item) => {
172                    if item.is_none() {
173                        let _ = self.try_exit();
174                        break 'WORKER_LOOP;
175                    }
176                    worker.run(item.unwrap()).await;
177                }
178                Err(_) => {
179                    // channel closed worker exit
180                    let _ = self.try_exit();
181                    break 'WORKER_LOOP;
182                }
183            }
184        }
185        worker.on_exit();
186    }
187
188    async fn run_worker_adjust(&self, mut worker: W) {
189        if let Err(_) = worker.init().await {
190            let _ = self.try_exit();
191            worker.on_exit();
192            return;
193        }
194
195        let worker_timeout = self.worker_timeout;
196        let recv = &self.recv;
197        let mut is_idle = false;
198        'WORKER_LOOP: loop {
199            if is_idle {
200                match timeout(worker_timeout, recv.recv()).await {
201                    Ok(res) => {
202                        match res {
203                            Ok(item) => {
204                                if item.is_none() {
205                                    let _ = self.try_exit();
206                                    break 'WORKER_LOOP;
207                                }
208                                worker.run(item.unwrap()).await;
209                                is_idle = false;
210                                self.water.fetch_sub(1, Ordering::SeqCst);
211                            }
212                            Err(_) => {
213                                // channel closed worker exit
214                                let _ = self.try_exit();
215                                worker.on_exit();
216                            }
217                        }
218                    }
219                    Err(_) => {
220                        // timeout
221                        if self.try_exit() {
222                            break 'WORKER_LOOP;
223                        }
224                    }
225                }
226            } else {
227                match recv.try_recv() {
228                    Err(e) => {
229                        if e.is_empty() {
230                            is_idle = true;
231                        } else {
232                            let _ = self.try_exit();
233                            break 'WORKER_LOOP;
234                        }
235                    }
236                    Ok(Some(item)) => {
237                        worker.run(item).await;
238                        self.water.fetch_sub(1, Ordering::SeqCst);
239                        is_idle = false;
240                    }
241                    Ok(None) => {
242                        let _ = self.try_exit();
243                        break 'WORKER_LOOP;
244                    }
245                }
246            }
247        }
248        worker.on_exit();
249    }
250
251    #[inline(always)]
252    pub fn get_worker_count(&self) -> usize {
253        self.worker_count.load(Ordering::Acquire)
254    }
255
256    #[inline(always)]
257    fn spawn(self: Arc<Self>) {
258        self.worker_count.fetch_add(1, Ordering::SeqCst);
259        let worker = self.inner.spawn();
260        let _self = self.clone();
261        tokio::spawn(async move {
262            if _self.auto {
263                _self.run_worker_adjust(worker).await
264            } else {
265                _self.run_worker_simple(worker).await
266            }
267        });
268    }
269
270    // check if idle worker should exit
271    #[inline(always)]
272    fn try_exit(&self) -> bool {
273        if self.closing.load(Ordering::Acquire) {
274            self.worker_count.fetch_sub(1, Ordering::SeqCst);
275            return true;
276        }
277        if self.get_worker_count() > self.min_workers {
278            if self.worker_count.fetch_sub(1, Ordering::SeqCst) <= self.min_workers {
279                self.worker_count.fetch_add(1, Ordering::SeqCst); // rollback
280            } else {
281                return true; // worker exit
282            }
283        }
284        return false;
285    }
286
287    async fn monitor(self: Arc<Self>, noti_recv: MAsyncRx<Option<()>>) {
288        for _ in 0..self.min_workers {
289            self.clone().spawn();
290        }
291        loop {
292            if let Ok(Some(_)) = noti_recv.recv().await {
293                if self.auto {
294                    let worker_count = self.get_worker_count();
295                    if worker_count > self.max_workers {
296                        continue;
297                    }
298                    let mut pending_msg = self.sender.len();
299                    if pending_msg > worker_count {
300                        pending_msg -= worker_count;
301                        if pending_msg > self.max_workers - worker_count {
302                            pending_msg = self.max_workers - worker_count;
303                        }
304                        for _ in 0..pending_msg {
305                            self.clone().spawn();
306                        }
307                    }
308                } else {
309                    continue;
310                }
311            } else {
312                return;
313            }
314        }
315    }
316}
317
#[cfg(test)]
mod tests {

    use std::thread;

    use crossbeam::channel::{Sender, bounded};
    use tokio::time::{Duration, sleep};

    use super::*;

    #[allow(dead_code)]
    struct MyWorkerPoolImpl();

    struct MyWorker();

    // (payload id, completion notifier back to the submitting thread)
    struct MyMsg(i64, Sender<()>);

    impl WorkerPoolImpl<MyMsg, MyWorker> for MyWorkerPoolImpl {
        fn spawn(&self) -> MyWorker {
            MyWorker()
        }
    }

    #[async_trait]
    impl Worker<MyMsg> for MyWorker {
        // Simulate a short job, then ack completion to the submitter.
        async fn run(&mut self, msg: MyMsg) {
            sleep(Duration::from_millis(1)).await;
            println!("done {}", msg.0);
            let _ = msg.1.send(());
        }
    }

    type MyWorkerPool = WorkerPool<MyMsg, MyWorker, MyWorkerPoolImpl>;

    // Auto-scaling pool (min < max): under load from 8 submitter threads the
    // pool should grow to max_workers, then shrink back to min_workers after
    // the workers sit idle past worker_timeout.
    #[test]
    fn blocking_workerpool_adjust() {
        let min_workers = 1;
        let max_workers = 4;
        let worker_timeout = Duration::from_secs(1);
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .unwrap();
        let worker_pool = MyWorkerPool::new(
            MyWorkerPoolImpl(),
            min_workers,
            max_workers,
            10,
            worker_timeout,
            &rt,
        );

        // 8 threads, 10 messages each; every thread waits for its own acks.
        let mut th_s = Vec::new();
        for i in 0..8 {
            let _pool = worker_pool.clone();
            th_s.push(thread::spawn(move || {
                let (done_tx, done_rx) = bounded(10);
                for j in 0..10 {
                    _pool.submit(MyMsg(i * 10 + j, done_tx.clone()));
                }
                for _j in 0..10 {
                    let _ = done_rx.recv();
                }
            }));
        }
        for th in th_s {
            let _ = th.join();
        }
        let workers = worker_pool.get_worker_count();
        println!("cur workers {} should reach max", workers);
        assert_eq!(workers, max_workers);

        // After two idle timeouts the extra workers should have exited.
        thread::sleep(worker_timeout * 2);
        let workers = worker_pool.get_worker_count();
        println!("cur workers: {}, extra should exit due to timeout", workers);
        assert_eq!(workers, min_workers);

        // Pool still serves traffic after shrinking.
        let (done_tx, done_rx) = bounded(2);
        for j in 0..10 {
            worker_pool.submit(MyMsg(80 + j, done_tx.clone()));
            println!("send {}", j);
            let _ = done_rx.recv();
        }
        println!("closing");
        worker_pool.close();
        assert_eq!(worker_pool.get_worker_count(), 0);
        // All submitted messages were processed: pending counter back to 0.
        assert_eq!(worker_pool.0.water.load(Ordering::Acquire), 0)
    }

    // Fixed-size pool (min == max): worker count should stay at max_workers
    // regardless of load or idle time.
    #[test]
    fn blocking_workerpool_fixed() {
        let min_workers = 4;
        let max_workers = 4;
        let worker_timeout = Duration::from_secs(1);
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .unwrap();
        let worker_pool = MyWorkerPool::new(
            MyWorkerPoolImpl(),
            min_workers,
            max_workers,
            10,
            worker_timeout,
            &rt,
        );

        // Same load pattern as the adjust test.
        let mut th_s = Vec::new();
        for i in 0..8 {
            let _pool = worker_pool.clone();
            th_s.push(thread::spawn(move || {
                let (done_tx, done_rx) = bounded(10);
                for j in 0..10 {
                    _pool.submit(MyMsg(i * 10 + j, done_tx.clone()));
                }
                for _j in 0..10 {
                    let _ = done_rx.recv();
                }
            }));
        }
        for th in th_s {
            let _ = th.join();
        }
        let workers = worker_pool.get_worker_count();
        println!("cur workers {} should reach max", workers);
        assert_eq!(workers, max_workers);

        // Fixed pool: idle timeouts must NOT shrink the worker count.
        thread::sleep(worker_timeout * 2);
        let workers = worker_pool.get_worker_count();
        println!("cur workers {} should reach max", workers);
        assert_eq!(workers, max_workers);

        let (done_tx, done_rx) = bounded(2);
        for j in 0..10 {
            worker_pool.submit(MyMsg(80 + j, done_tx.clone()));
            println!("send {}", j);
            let _ = done_rx.recv();
        }
        println!("closing");
        worker_pool.close();
        assert_eq!(worker_pool.get_worker_count(), 0);
        assert_eq!(worker_pool.0.water.load(Ordering::Acquire), 0)
    }
}