sync_utils/worker_pool/
bounded_blocking.rs

use std::{
    sync::{
        Arc,
        atomic::{AtomicBool, AtomicUsize, Ordering},
    },
    time::Duration,
};

use crossfire::mpmc::*;
use tokio::{runtime::Runtime, time::timeout};

use super::*;

/*
 * A WorkerPool whose submit side requires no tokio context; a single monitor
 * task runs inside the tokio runtime to spawn and adjust long-lived workers.
 */

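// A minimal usage sketch (hedged; `MyWorkerPoolImpl` and the message are
// placeholders, see the tests at the bottom of this file for a working pair):
//
//     let rt = tokio::runtime::Runtime::new().unwrap();
//     let pool = WorkerPool::new(MyWorkerPoolImpl(), 1, 4, 8,
//                                Duration::from_secs(1), &rt);
//     pool.submit(msg); // callable from any plain thread, no async context
//     pool.close();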
#[allow(unused_must_use)]
pub struct WorkerPool<M: Send + Sized + Unpin + 'static, W: Worker<M>, S: WorkerPoolImpl<M, W>>(
    Arc<WorkerPoolInner<M, W, S>>,
);

struct WorkerPoolInner<M, W, S>
where
    M: Send + Sized + Unpin + 'static,
    W: Worker<M>,
    S: WorkerPoolImpl<M, W>,
{
    worker_count: AtomicUsize,
    sender: TxBlocking<Option<M>, SharedSenderBRecvF>,
    recv: RxFuture<Option<M>, SharedSenderBRecvF>,
    min_workers: usize,
    max_workers: usize,
    worker_timeout: Duration,
    inner: S,
    water: AtomicUsize, // messages submitted but not yet processed (auto mode)
    phantom: std::marker::PhantomData<W>, // silences the unused-type-parameter warning
    closing: AtomicBool,
    notify_sender: TxBlocking<Option<()>, SharedSenderBRecvF>, // scale-up hints to the monitor
    auto: bool, // true when min_workers < max_workers, i.e. the pool can scale
    buffer_size: usize,
}

impl<M, W, S> Clone for WorkerPool<M, W, S>
where
    M: Send + Sized + Unpin + 'static,
    W: Worker<M>,
    S: WorkerPoolImpl<M, W>,
{
    #[inline]
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}

impl<M, W, S> WorkerPool<M, W, S>
where
    M: Send + Sized + Unpin + 'static,
    W: Worker<M>,
    S: WorkerPoolImpl<M, W>,
{
    pub fn new(
        inner: S,
        min_workers: usize,
        max_workers: usize,
        mut buffer_size: usize,
        worker_timeout: Duration,
        rt: &Runtime,
    ) -> Self {
        // clamp the channel capacity to at most twice max_workers
        if buffer_size > max_workers * 2 {
            buffer_size = max_workers * 2;
        }
        let (sender, recv) = bounded_tx_blocking_rx_future(buffer_size);
        let (noti_sender, noti_recv) = bounded_tx_blocking_rx_future(1);
        assert!(min_workers > 0);
        assert!(max_workers >= min_workers);

        let auto: bool = min_workers < max_workers;
        if auto {
            // auto-scaling relies on recv timeouts, so a zero timeout is invalid
            assert!(worker_timeout != Duration::ZERO);
        }

        let pool = Arc::new(WorkerPoolInner {
            sender,
            recv,
            inner,
            worker_count: AtomicUsize::new(0),
            min_workers,
            max_workers,
            buffer_size,
            worker_timeout,
            phantom: Default::default(),
            closing: AtomicBool::new(false),
            water: AtomicUsize::new(0),
            notify_sender: noti_sender,
            auto,
        });
        let _pool = pool.clone();
        rt.spawn(async move {
            _pool.monitor(noti_recv).await;
        });
        Self(pool)
    }

    pub fn get_inner(&self) -> &S {
        &self.0.inner
    }

    // A None on the message channel tells a worker to exit (used by close()).
    // Returns Some(msg) when the message could not be sent.
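    // Hedged caller-side sketch: send() blocks while the channel is full, so
    // under normal operation submit() only fails once the pool is closing:
    //
    //     if let Some(unsent) = pool.submit(msg) {
    //         // the pool is shutting down; handle or drop `unsent`
    //     }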
    pub fn submit(&self, msg: M) -> Option<M> {
        let _self = self.0.as_ref();
        if _self.closing.load(Ordering::Acquire) {
            return Some(msg);
        }
        if _self.auto {
            let worker_count = _self.get_worker_count();
            let water = _self.water.fetch_add(1, Ordering::SeqCst);
            // when in-flight messages outrun the current workers or the
            // buffer, hint the monitor to spawn more (up to max_workers)
            if worker_count < _self.max_workers {
                if water > worker_count + 1 || water > _self.buffer_size {
                    let _ = _self.notify_sender.try_send(Some(()));
                }
            }
        }
        match _self.sender.send(Some(msg)) {
            Ok(_) => None,
            Err(SendError(t)) => t,
        }
    }

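    // Shutdown protocol: flip `closing` so new submits are rejected, then send
    // one None per live worker (each worker exits on receiving None), sleeping
    // one worker_timeout between rounds until the worker count reaches zero.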
    pub fn close(&self) {
        let _self = self.0.as_ref();
        if _self.closing.swap(true, Ordering::SeqCst) {
            return;
        }
        loop {
            let cur = self.get_worker_count();
            if cur == 0 {
                break;
            }
            debug!("worker pool closing: cur workers {}", cur);
            for _ in 0..cur {
                let _ = _self.sender.send(None);
            }
            std::thread::sleep(_self.worker_timeout);
            // TODO: gracefully execute all remaining messages
        }
        // The notify channel is bounded at 1, so send twice to ensure the
        // monitor task wakes up and exits
        let _ = _self.notify_sender.send(None);
        let _ = _self.notify_sender.send(None);
    }

    pub fn get_worker_count(&self) -> usize {
        self.0.get_worker_count()
    }
}

impl<M, W, S> WorkerPoolInner<M, W, S>
where
    M: Send + Sized + Unpin + 'static,
    W: Worker<M>,
    S: WorkerPoolImpl<M, W>,
{
    // Fixed-size worker loop: block on recv until a None message or a channel
    // error signals shutdown.
    async fn run_worker_simple(&self, mut worker: W) {
        if let Err(_) = worker.init().await {
            let _ = self.try_exit();
            worker.on_exit();
            return;
        }

        let recv = &self.recv;
        'WORKER_LOOP: loop {
            match recv.recv().await {
                Ok(item) => {
                    if item.is_none() {
                        let _ = self.try_exit();
                        break 'WORKER_LOOP;
                    }
                    worker.run(item.unwrap()).await;
                }
                Err(_) => {
                    // channel closed, worker exits
                    let _ = self.try_exit();
                    break 'WORKER_LOOP;
                }
            }
        }
        worker.on_exit();
    }

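    // Auto-scaling worker loop: while busy, drain the queue with non-blocking
    // try_recv; once the queue runs empty, fall back to recv bounded by
    // worker_timeout. A timeout marks this worker as surplus, and it offers to
    // exit through try_exit().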
    async fn run_worker_adjust(&self, mut worker: W) {
        if let Err(_) = worker.init().await {
            let _ = self.try_exit();
            worker.on_exit();
            return;
        }

        let worker_timeout = self.worker_timeout;
        let recv = &self.recv;
        let mut is_idle = false;
        'WORKER_LOOP: loop {
            if is_idle {
                match timeout(worker_timeout, recv.recv()).await {
                    Ok(res) => {
                        match res {
                            Ok(item) => {
                                if item.is_none() {
                                    let _ = self.try_exit();
                                    break 'WORKER_LOOP;
                                }
                                worker.run(item.unwrap()).await;
                                is_idle = false;
                                self.water.fetch_sub(1, Ordering::SeqCst);
                            }
                            Err(_) => {
                                // channel closed, worker exits; break so that
                                // on_exit() runs exactly once after the loop
                                let _ = self.try_exit();
                                break 'WORKER_LOOP;
                            }
                        }
                    }
                    Err(_) => {
                        // recv timed out: exit if we are above min_workers
                        if self.try_exit() {
                            break 'WORKER_LOOP;
                        }
                    }
                }
            } else {
                match recv.try_recv() {
                    Err(e) => {
                        if e.is_empty() {
                            is_idle = true;
                        } else {
                            let _ = self.try_exit();
                            break 'WORKER_LOOP;
                        }
                    }
                    Ok(Some(item)) => {
                        worker.run(item).await;
                        self.water.fetch_sub(1, Ordering::SeqCst);
                        is_idle = false;
                    }
                    Ok(None) => {
                        let _ = self.try_exit();
                        break 'WORKER_LOOP;
                    }
                }
            }
        }
        worker.on_exit();
    }

    #[inline(always)]
    pub fn get_worker_count(&self) -> usize {
        self.worker_count.load(Ordering::Acquire)
    }

    #[inline(always)]
    fn spawn(self: Arc<Self>) {
        // count the worker before its task starts so concurrent submits and
        // the monitor observe it immediately
        self.worker_count.fetch_add(1, Ordering::SeqCst);
        let worker = self.inner.spawn();
        let _self = self.clone();
        tokio::spawn(async move {
            if _self.auto {
                _self.run_worker_adjust(worker).await
            } else {
                _self.run_worker_simple(worker).await
            }
        });
    }

    // Decides whether the calling worker should exit: always when the pool is
    // closing, otherwise only while the count stays above min_workers.
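    // Worked example of the decrement-and-rollback below, assuming
    // min_workers = 1 and two idle workers timing out at once (count = 2):
    // worker A's fetch_sub returns 2 (> 1), so A exits, leaving count = 1;
    // worker B's fetch_sub returns 1 (<= 1), so B rolls the count back to 1
    // and keeps running. The pool never drops below min_workers.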
    #[inline(always)]
    fn try_exit(&self) -> bool {
        if self.closing.load(Ordering::Acquire) {
            self.worker_count.fetch_sub(1, Ordering::SeqCst);
            return true;
        }
        if self.get_worker_count() > self.min_workers {
            if self.worker_count.fetch_sub(1, Ordering::SeqCst) <= self.min_workers {
                self.worker_count.fetch_add(1, Ordering::SeqCst); // rollback
            } else {
                return true; // worker exit
            }
        }
        return false;
    }

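    // Monitor task: spawns the initial min_workers, then waits for scale-up
    // hints from submit(). Worked example of the top-up arithmetic below:
    // with worker_count = 2, max_workers = 4 and 7 queued messages,
    // pending_msg = 7 - 2 = 5, clamped to 4 - 2 = 2, so two workers spawn.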
    async fn monitor(self: Arc<Self>, noti_recv: RxFuture<Option<()>, SharedSenderBRecvF>) {
        for _ in 0..self.min_workers {
            self.clone().spawn();
        }
        loop {
            if let Ok(Some(_)) = noti_recv.recv().await {
                if self.auto {
                    let worker_count = self.get_worker_count();
                    if worker_count > self.max_workers {
                        continue;
                    }
                    let mut pending_msg = self.sender.len();
                    if pending_msg > worker_count {
                        pending_msg -= worker_count;
                        if pending_msg > self.max_workers - worker_count {
                            pending_msg = self.max_workers - worker_count;
                        }
                        for _ in 0..pending_msg {
                            self.clone().spawn();
                        }
                    }
                } else {
                    continue;
                }
            } else {
                // a None (or a closed channel) means the pool is closing
                return;
            }
        }
    }
}

#[cfg(test)]
mod tests {

    use std::thread;

    use crossbeam::channel::{Sender, bounded};
    use tokio::time::{Duration, sleep};

    use super::*;

    #[allow(dead_code)]
    struct MyWorkerPoolImpl();

    struct MyWorker();

    struct MyMsg(i64, Sender<()>);

    impl WorkerPoolImpl<MyMsg, MyWorker> for MyWorkerPoolImpl {
        fn spawn(&self) -> MyWorker {
            MyWorker()
        }
    }

    #[async_trait]
    impl Worker<MyMsg> for MyWorker {
        async fn run(&mut self, msg: MyMsg) {
            sleep(Duration::from_millis(1)).await;
            println!("done {}", msg.0);
            let _ = msg.1.send(());
        }
    }

    type MyWorkerPool = WorkerPool<MyMsg, MyWorker, MyWorkerPoolImpl>;

    #[test]
    fn blocking_workerpool_adjust() {
        let min_workers = 1;
        let max_workers = 4;
        let worker_timeout = Duration::from_secs(1);
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .unwrap();
        let worker_pool = MyWorkerPool::new(
            MyWorkerPoolImpl(),
            min_workers,
            max_workers,
            10,
            worker_timeout,
            &rt,
        );

        let mut ths = Vec::new();
        for i in 0..8 {
            let _pool = worker_pool.clone();
            ths.push(thread::spawn(move || {
                let (done_tx, done_rx) = bounded(10);
                for j in 0..10 {
                    _pool.submit(MyMsg(i * 10 + j, done_tx.clone()));
                }
                for _j in 0..10 {
                    let _ = done_rx.recv();
                }
            }));
        }
        for th in ths {
            let _ = th.join();
        }
        let workers = worker_pool.get_worker_count();
        println!("cur workers {} should reach max", workers);
        assert_eq!(workers, max_workers);

        thread::sleep(worker_timeout * 2);
        let workers = worker_pool.get_worker_count();
        println!("cur workers: {}, extra should exit due to timeout", workers);
        assert_eq!(workers, min_workers);

        let (done_tx, done_rx) = bounded(2);
        for j in 0..10 {
            worker_pool.submit(MyMsg(80 + j, done_tx.clone()));
            println!("send {}", j);
            let _ = done_rx.recv();
        }
        println!("closing");
        worker_pool.close();
        assert_eq!(worker_pool.get_worker_count(), 0);
        assert_eq!(worker_pool.0.water.load(Ordering::Acquire), 0);
    }

    #[test]
    fn blocking_workerpool_fixed() {
        let min_workers = 4;
        let max_workers = 4;
        let worker_timeout = Duration::from_secs(1);
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .unwrap();
        let worker_pool = MyWorkerPool::new(
            MyWorkerPoolImpl(),
            min_workers,
            max_workers,
            10,
            worker_timeout,
            &rt,
        );

        let mut ths = Vec::new();
        for i in 0..8 {
            let _pool = worker_pool.clone();
            ths.push(thread::spawn(move || {
                let (done_tx, done_rx) = bounded(10);
                for j in 0..10 {
                    _pool.submit(MyMsg(i * 10 + j, done_tx.clone()));
                }
                for _j in 0..10 {
                    let _ = done_rx.recv();
                }
            }));
        }
        for th in ths {
            let _ = th.join();
        }
        let workers = worker_pool.get_worker_count();
        println!("cur workers {} should reach max", workers);
        assert_eq!(workers, max_workers);

        thread::sleep(worker_timeout * 2);
        let workers = worker_pool.get_worker_count();
        println!("cur workers {}, fixed pool should stay at max", workers);
        assert_eq!(workers, max_workers);

        let (done_tx, done_rx) = bounded(2);
        for j in 0..10 {
            worker_pool.submit(MyMsg(80 + j, done_tx.clone()));
            println!("send {}", j);
            let _ = done_rx.recv();
        }
        println!("closing");
        worker_pool.close();
        assert_eq!(worker_pool.get_worker_count(), 0);
        assert_eq!(worker_pool.0.water.load(Ordering::Acquire), 0);
    }
}
467}