1use std::{
2 sync::{
3 Arc,
4 atomic::{AtomicBool, AtomicUsize, Ordering},
5 },
6 time::Duration,
7};
8
9use crossfire::mpmc::*;
10use tokio::{runtime::Runtime, time::timeout};
11
12use super::*;
13
14#[allow(unused_must_use)]
20pub struct WorkerPool<M: Send + Sized + Unpin + 'static, W: Worker<M>, S: WorkerPoolImpl<M, W>>(
21 Arc<WorkerPoolInner<M, W, S>>,
22);
23
24struct WorkerPoolInner<M, W, S>
25where
26 M: Send + Sized + Unpin + 'static,
27 W: Worker<M>,
28 S: WorkerPoolImpl<M, W>,
29{
30 worker_count: AtomicUsize,
31 sender: TxBlocking<Option<M>, SharedSenderBRecvF>,
32 recv: RxFuture<Option<M>, SharedSenderBRecvF>,
33 min_workers: usize,
34 max_workers: usize,
35 worker_timeout: Duration,
36 inner: S,
37 water: AtomicUsize,
38 phantom: std::marker::PhantomData<W>, closing: AtomicBool,
40 notify_sender: TxBlocking<Option<()>, SharedSenderBRecvF>,
41 auto: bool,
42 buffer_size: usize,
43}
44
45impl<M, W, S> Clone for WorkerPool<M, W, S>
46where
47 M: Send + Sized + Unpin + 'static,
48 W: Worker<M>,
49 S: WorkerPoolImpl<M, W>,
50{
51 #[inline]
52 fn clone(&self) -> Self {
53 Self(self.0.clone())
54 }
55}
56
57impl<M, W, S> WorkerPool<M, W, S>
58where
59 M: Send + Sized + Unpin + 'static,
60 W: Worker<M>,
61 S: WorkerPoolImpl<M, W>,
62{
63 pub fn new(
64 inner: S,
65 min_workers: usize,
66 max_workers: usize,
67 mut buffer_size: usize,
68 worker_timeout: Duration,
69 rt: &Runtime,
70 ) -> Self {
71 if buffer_size > max_workers * 2 {
72 buffer_size = max_workers * 2;
73 }
74 let (sender, recv) = bounded_tx_blocking_rx_future(buffer_size);
75 let (noti_sender, noti_recv) = bounded_tx_blocking_rx_future(1);
76 assert!(min_workers > 0);
77 assert!(max_workers >= min_workers);
78
79 let auto: bool = min_workers < max_workers;
80 if auto {
81 assert!(worker_timeout != ZERO_DUARTION);
82 }
83
84 let pool = Arc::new(WorkerPoolInner {
85 sender,
86 recv,
87 inner,
88 worker_count: AtomicUsize::new(0),
89 min_workers,
90 max_workers,
91 buffer_size,
92 worker_timeout,
93 phantom: Default::default(),
94 closing: AtomicBool::new(false),
95 water: AtomicUsize::new(0),
96 notify_sender: noti_sender,
97 auto,
98 });
99 let _pool = pool.clone();
100 rt.spawn(async move {
101 _pool.monitor(noti_recv).await;
102 });
103 Self(pool)
104 }
105
106 pub fn get_inner(&self) -> &S {
107 &self.0.inner
108 }
109
110 pub fn submit(&self, msg: M) -> Option<M> {
113 let _self = self.0.as_ref();
114 if _self.closing.load(Ordering::Acquire) {
115 return Some(msg);
116 }
117 if _self.auto {
118 let worker_count = _self.get_worker_count();
119 let water = _self.water.fetch_add(1, Ordering::SeqCst);
120 if worker_count < _self.max_workers {
121 if water > worker_count + 1 || water > _self.buffer_size {
122 let _ = _self.notify_sender.try_send(Some(()));
123 }
124 }
125 }
126 match _self.sender.send(Some(msg)) {
127 Ok(_) => return None,
128 Err(SendError(t)) => return t,
129 }
130 }
131
132 pub fn close(&self) {
133 let _self = self.0.as_ref();
134 if _self.closing.swap(true, Ordering::SeqCst) {
135 return;
136 }
137 loop {
138 let cur = self.get_worker_count();
139 if cur == 0 {
140 break;
141 }
142 debug!("worker pool closing: cur workers {}", cur);
143 for _ in 0..cur {
144 let _ = _self.sender.send(None);
145 }
146 std::thread::sleep(_self.worker_timeout);
147 }
149 let _ = _self.notify_sender.send(None);
151 let _ = _self.notify_sender.send(None);
152 }
153
154 pub fn get_worker_count(&self) -> usize {
155 self.0.get_worker_count()
156 }
157}
158
159impl<M, W, S> WorkerPoolInner<M, W, S>
160where
161 M: Send + Sized + Unpin + 'static,
162 W: Worker<M>,
163 S: WorkerPoolImpl<M, W>,
164{
165 async fn run_worker_simple(&self, mut worker: W) {
166 if let Err(_) = worker.init().await {
167 let _ = self.try_exit();
168 worker.on_exit();
169 return;
170 }
171
172 let recv = &self.recv;
173 'WORKER_LOOP: loop {
174 match recv.recv().await {
175 Ok(item) => {
176 if item.is_none() {
177 let _ = self.try_exit();
178 break 'WORKER_LOOP;
179 }
180 worker.run(item.unwrap()).await;
181 }
182 Err(_) => {
183 let _ = self.try_exit();
185 break 'WORKER_LOOP;
186 }
187 }
188 }
189 worker.on_exit();
190 }
191
192 async fn run_worker_adjust(&self, mut worker: W) {
193 if let Err(_) = worker.init().await {
194 let _ = self.try_exit();
195 worker.on_exit();
196 return;
197 }
198
199 let worker_timeout = self.worker_timeout;
200 let recv = &self.recv;
201 let mut is_idle = false;
202 'WORKER_LOOP: loop {
203 if is_idle {
204 match timeout(worker_timeout, recv.recv()).await {
205 Ok(res) => {
206 match res {
207 Ok(item) => {
208 if item.is_none() {
209 let _ = self.try_exit();
210 break 'WORKER_LOOP;
211 }
212 worker.run(item.unwrap()).await;
213 is_idle = false;
214 self.water.fetch_sub(1, Ordering::SeqCst);
215 }
216 Err(_) => {
217 let _ = self.try_exit();
219 worker.on_exit();
220 }
221 }
222 }
223 Err(_) => {
224 if self.try_exit() {
226 break 'WORKER_LOOP;
227 }
228 }
229 }
230 } else {
231 match recv.try_recv() {
232 Err(e) => {
233 if e.is_empty() {
234 is_idle = true;
235 } else {
236 let _ = self.try_exit();
237 break 'WORKER_LOOP;
238 }
239 }
240 Ok(Some(item)) => {
241 worker.run(item).await;
242 self.water.fetch_sub(1, Ordering::SeqCst);
243 is_idle = false;
244 }
245 Ok(None) => {
246 let _ = self.try_exit();
247 break 'WORKER_LOOP;
248 }
249 }
250 }
251 }
252 worker.on_exit();
253 }
254
255 #[inline(always)]
256 pub fn get_worker_count(&self) -> usize {
257 self.worker_count.load(Ordering::Acquire)
258 }
259
260 #[inline(always)]
261 fn spawn(self: Arc<Self>) {
262 self.worker_count.fetch_add(1, Ordering::SeqCst);
263 let worker = self.inner.spawn();
264 let _self = self.clone();
265 tokio::spawn(async move {
266 if _self.auto {
267 _self.run_worker_adjust(worker).await
268 } else {
269 _self.run_worker_simple(worker).await
270 }
271 });
272 }
273
274 #[inline(always)]
276 fn try_exit(&self) -> bool {
277 if self.closing.load(Ordering::Acquire) {
278 self.worker_count.fetch_sub(1, Ordering::SeqCst);
279 return true;
280 }
281 if self.get_worker_count() > self.min_workers {
282 if self.worker_count.fetch_sub(1, Ordering::SeqCst) <= self.min_workers {
283 self.worker_count.fetch_add(1, Ordering::SeqCst); } else {
285 return true; }
287 }
288 return false;
289 }
290
291 async fn monitor(self: Arc<Self>, noti_recv: RxFuture<Option<()>, SharedSenderBRecvF>) {
292 for _ in 0..self.min_workers {
293 self.clone().spawn();
294 }
295 loop {
296 if let Ok(Some(_)) = noti_recv.recv().await {
297 if self.auto {
298 let worker_count = self.get_worker_count();
299 if worker_count > self.max_workers {
300 continue;
301 }
302 let mut pending_msg = self.sender.len();
303 if pending_msg > worker_count {
304 pending_msg -= worker_count;
305 if pending_msg > self.max_workers - worker_count {
306 pending_msg = self.max_workers - worker_count;
307 }
308 for _ in 0..pending_msg {
309 self.clone().spawn();
310 }
311 }
312 } else {
313 continue;
314 }
315 } else {
316 return;
317 }
318 }
319 }
320}
321
#[cfg(test)]
mod tests {

    use std::thread;

    use crossbeam::channel::{Sender, bounded};
    use tokio::time::{Duration, sleep};

    use super::*;

    // Worker factory for the tests; `spawn` just returns a fresh `MyWorker`.
    #[allow(dead_code)]
    struct MyWorkerPoolImpl();

    // Stateless test worker: sleeps 1 ms per message, then reports completion.
    struct MyWorker();

    // Test message: an id plus the sender on which the worker signals completion.
    struct MyMsg(i64, Sender<()>);

    impl WorkerPoolImpl<MyMsg, MyWorker> for MyWorkerPoolImpl {
        fn spawn(&self) -> MyWorker {
            MyWorker()
        }
    }

    #[async_trait]
    impl Worker<MyMsg> for MyWorker {
        async fn run(&mut self, msg: MyMsg) {
            sleep(Duration::from_millis(1)).await;
            println!("done {}", msg.0);
            // Ignore send errors: the test may already have stopped listening.
            let _ = msg.1.send(());
        }
    }

    type MyWorkerPool = WorkerPool<MyMsg, MyWorker, MyWorkerPoolImpl>;

    // Auto-scaling pool (min 1, max 4): scales up under load, shrinks back to the minimum after
    // being idle, and drains to zero workers on close.
    // NOTE(review): the worker-count assertions depend on timing and may be flaky on slow hosts.
    #[test]
    fn blocking_workerpool_adjust() {
        let min_workers = 1;
        let max_workers = 4;
        let worker_timeout = Duration::from_secs(1);
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .unwrap();
        let worker_pool = MyWorkerPool::new(
            MyWorkerPoolImpl(),
            min_workers,
            max_workers,
            10,
            worker_timeout,
            &rt,
        );

        // 8 producer threads each submit 10 messages, then wait for all 10 completions.
        let mut ths = Vec::new();
        for i in 0..8 {
            let _pool = worker_pool.clone();
            ths.push(thread::spawn(move || {
                let (done_tx, done_rx) = bounded(10);
                for j in 0..10 {
                    _pool.submit(MyMsg(i * 10 + j, done_tx.clone()));
                }
                for _j in 0..10 {
                    let _ = done_rx.recv();
                }
            }));
        }
        for th in ths {
            let _ = th.join();
        }
        let workers = worker_pool.get_worker_count();
        println!("cur workers {} should reach max", workers);
        assert_eq!(workers, max_workers);

        // Idle for twice the timeout: workers above the minimum are expected to retire.
        thread::sleep(worker_timeout * 2);
        let workers = worker_pool.get_worker_count();
        println!("cur workers: {}, extra should exit due to timeout", workers);
        assert_eq!(workers, min_workers);

        // Strictly sequential submissions: each one waits for its completion before the next.
        let (done_tx, done_rx) = bounded(2);
        for j in 0..10 {
            worker_pool.submit(MyMsg(80 + j, done_tx.clone()));
            println!("send {}", j);
            let _ = done_rx.recv();
        }
        println!("closing");
        worker_pool.close();
        assert_eq!(worker_pool.get_worker_count(), 0);
        // Every submitted job has finished, so the pending-job counter is back to zero.
        assert_eq!(worker_pool.0.water.load(Ordering::Acquire), 0)
    }

    // Fixed-size pool (min == max == 4): the worker count stays at 4 through load and idle
    // periods, and drops to zero on close.
    #[test]
    fn blocking_workerpool_fixed() {
        let min_workers = 4;
        let max_workers = 4;
        let worker_timeout = Duration::from_secs(1);
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .unwrap();
        let worker_pool = MyWorkerPool::new(
            MyWorkerPoolImpl(),
            min_workers,
            max_workers,
            10,
            worker_timeout,
            &rt,
        );

        let mut ths = Vec::new();
        for i in 0..8 {
            let _pool = worker_pool.clone();
            ths.push(thread::spawn(move || {
                let (done_tx, done_rx) = bounded(10);
                for j in 0..10 {
                    _pool.submit(MyMsg(i * 10 + j, done_tx.clone()));
                }
                for _j in 0..10 {
                    let _ = done_rx.recv();
                }
            }));
        }
        for th in ths {
            let _ = th.join();
        }
        let workers = worker_pool.get_worker_count();
        println!("cur workers {} should reach max", workers);
        assert_eq!(workers, max_workers);

        // A fixed pool does not shrink when idle.
        thread::sleep(worker_timeout * 2);
        let workers = worker_pool.get_worker_count();
        println!("cur workers {} should reach max", workers);
        assert_eq!(workers, max_workers);

        let (done_tx, done_rx) = bounded(2);
        for j in 0..10 {
            worker_pool.submit(MyMsg(80 + j, done_tx.clone()));
            println!("send {}", j);
            let _ = done_rx.recv();
        }
        println!("closing");
        worker_pool.close();
        assert_eq!(worker_pool.get_worker_count(), 0);
        // `submit` only tracks `water` in auto mode, so it is expected to stay zero here.
        assert_eq!(worker_pool.0.water.load(Ordering::Acquire), 0)
    }
}