1use std::{
2 cell::UnsafeCell,
3 future::Future,
4 mem::transmute,
5 pin::Pin,
6 sync::{
7 Arc,
8 atomic::{AtomicBool, AtomicUsize, Ordering},
9 },
10 task::*,
11 thread,
12 time::Duration,
13};
14
15use crossfire::*;
16use tokio::time::{sleep, timeout};
17
18use super::*;
19
20#[allow(unused_must_use)]
21pub struct WorkerPoolBounded<
22 M: Send + Sized + Unpin + 'static,
23 W: Worker<M>,
24 S: WorkerPoolImpl<M, W>,
25>(Arc<WorkerPoolBoundedInner<M, W, S>>);
26
27struct WorkerPoolBoundedInner<M, W, S>
28where
29 M: Send + Sized + Unpin + 'static,
30 W: Worker<M>,
31 S: WorkerPoolImpl<M, W>,
32{
33 worker_count: AtomicUsize,
34 sender: UnsafeCell<Option<MAsyncTx<Option<M>>>>,
35 min_workers: usize,
36 max_workers: usize,
37 worker_timeout: Duration,
38 inner: S,
39 phantom: std::marker::PhantomData<W>, closing: AtomicBool,
41 notify_sender: MAsyncTx<Option<()>>,
42 notify_recv: UnsafeCell<Option<AsyncRx<Option<()>>>>,
43 auto: bool,
44 channel_size: usize,
45 real_thread: AtomicBool,
46 bind_cpu: AtomicUsize,
47 max_cpu: usize,
48}
49
50unsafe impl<M, W, S> Send for WorkerPoolBoundedInner<M, W, S>
51where
52 M: Send + Sized + Unpin + 'static,
53 W: Worker<M>,
54 S: WorkerPoolImpl<M, W>,
55{
56}
57
58unsafe impl<M, W, S> Sync for WorkerPoolBoundedInner<M, W, S>
59where
60 M: Send + Sized + Unpin + 'static,
61 W: Worker<M>,
62 S: WorkerPoolImpl<M, W>,
63{
64}
65
66impl<M, W, S> Clone for WorkerPoolBounded<M, W, S>
67where
68 M: Send + Sized + Unpin + 'static,
69 W: Worker<M>,
70 S: WorkerPoolImpl<M, W>,
71{
72 #[inline]
73 fn clone(&self) -> Self {
74 Self(self.0.clone())
75 }
76}
77
78impl<M, W, S> WorkerPoolBounded<M, W, S>
79where
80 M: Send + Sized + Unpin + 'static,
81 W: Worker<M>,
82 S: WorkerPoolImpl<M, W>,
83{
84 pub fn new(
85 inner: S, min_workers: usize, max_workers: usize, channel_size: usize,
86 worker_timeout: Duration,
87 ) -> Self {
88 assert!(min_workers > 0);
89 assert!(max_workers >= min_workers);
90
91 let auto: bool = min_workers < max_workers;
92 if auto {
93 assert!(worker_timeout != ZERO_DUARTION);
94 }
95 let (noti_sender, noti_recv) = mpsc::bounded_async(1);
96 let pool = Arc::new(WorkerPoolBoundedInner {
97 sender: UnsafeCell::new(None),
98 inner,
99 worker_count: AtomicUsize::new(0),
100 min_workers,
101 max_workers,
102 channel_size,
103 worker_timeout,
104 phantom: Default::default(),
105 closing: AtomicBool::new(false),
106 notify_sender: noti_sender,
107 notify_recv: UnsafeCell::new(Some(noti_recv)),
108 auto,
109 real_thread: AtomicBool::new(false),
110 bind_cpu: AtomicUsize::new(0),
111 max_cpu: num_cpus::get(),
112 });
113 Self(pool)
114 }
115
116 pub fn set_use_thread(&mut self, ok: bool) {
118 self.0.real_thread.store(ok, Ordering::Release);
119 }
120
121 pub fn start(&self) {
122 let _self = self.0.as_ref();
123 let (sender, rx) = mpmc::bounded_async(_self.channel_size);
124 _self._sender().replace(sender);
125
126 for _ in 0.._self.min_workers {
127 self.0.clone().spawn(true, rx.clone());
128 }
129 if _self.auto {
130 let _pool = self.0.clone();
131 let notify_recv: &mut Option<AsyncRx<Option<()>>> =
132 unsafe { transmute(_self.notify_recv.get()) };
133 let noti_rx = notify_recv.take().unwrap();
134 tokio::spawn(async move {
135 _pool.monitor(noti_rx, rx).await;
136 });
137 }
138 }
139
140 pub async fn close_async(&self) {
141 let _self = self.0.as_ref();
142 if _self.closing.swap(true, Ordering::SeqCst) {
143 return;
144 }
145 if _self.auto {
146 let _ = _self.notify_sender.send(None).await;
147 }
148 let sender = _self._sender().as_ref().unwrap();
149 loop {
150 let cur = self.get_worker_count();
151 if cur == 0 {
152 break;
153 }
154 debug!("worker pool closing: cur workers {}", cur);
155 for _ in 0..cur {
156 let _ = sender.send(None).await;
157 }
158 sleep(Duration::from_secs(1)).await;
159 }
160 }
161
162 pub fn close(&self) {
164 if let Ok(_rt) = tokio::runtime::Handle::try_current() {
165 warn!("close in runtime thread, spawn close thread");
166 let _self = self.clone();
167 std::thread::spawn(move || {
168 let rt = tokio::runtime::Builder::new_current_thread()
169 .enable_all()
170 .build()
171 .expect("runtime");
172 rt.block_on(async move {
173 _self.close_async().await;
174 });
175 });
176 } else {
177 let rt = tokio::runtime::Builder::new_current_thread()
178 .enable_all()
179 .build()
180 .expect("runtime");
181 let _self = self.clone();
182 rt.block_on(async move {
183 _self.close_async().await;
184 });
185 }
186 }
187
188 pub fn get_worker_count(&self) -> usize {
189 self.0.get_worker_count()
190 }
191
192 pub fn get_inner(&self) -> &S {
193 &self.0.inner
194 }
195
196 #[inline]
197 pub fn try_submit(&self, msg: M) -> Option<M> {
198 let _self = self.0.as_ref();
199 if _self.closing.load(Ordering::Acquire) {
200 return Some(msg);
201 }
202 match _self._sender().as_ref().unwrap().try_send(Some(msg)) {
203 Err(TrySendError::Disconnected(m)) => {
204 return m;
205 }
206 Err(TrySendError::Full(m)) => {
207 return m;
208 }
209 Ok(_) => return None,
210 }
211 }
212
213 #[inline]
215 pub fn submit<'a>(&'a self, mut msg: M) -> SubmitFuture<'a, M> {
216 let _self = self.0.as_ref();
217 if _self.closing.load(Ordering::Acquire) {
218 return SubmitFuture { send_f: None, res: Some(Err(msg)) };
219 }
220 let sender = _self._sender().as_ref().unwrap();
221 if _self.auto {
222 match sender.try_send(Some(msg)) {
223 Err(TrySendError::Disconnected(m)) => {
224 return SubmitFuture { send_f: None, res: Some(Err(m.unwrap())) };
225 }
226 Err(TrySendError::Full(m)) => {
227 msg = m.unwrap();
228 }
229 Ok(_) => {
230 return SubmitFuture { send_f: None, res: Some(Ok(())) };
231 }
232 }
233 let worker_count = _self.get_worker_count();
234 if worker_count < _self.max_workers {
235 let _ = _self.notify_sender.try_send(Some(()));
236 }
237 }
238 let send_f = sender.send(Some(msg));
239 return SubmitFuture { send_f: Some(send_f), res: None };
240 }
241}
242
243pub struct SubmitFuture<'a, M: Send + Sized + Unpin + 'static> {
244 send_f: Option<SendFuture<'a, Option<M>>>,
245 res: Option<Result<(), M>>,
246}
247
248impl<'a, M: Send + Sized + Unpin + 'static> Future for SubmitFuture<'a, M> {
249 type Output = Option<M>;
250
251 fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
252 let _self = self.get_mut();
253 if _self.res.is_some() {
254 match _self.res.take().unwrap() {
255 Ok(()) => return Poll::Ready(None),
256 Err(m) => return Poll::Ready(Some(m)),
257 }
258 }
259 let send_f = _self.send_f.as_mut().unwrap();
260 if let Poll::Ready(r) = Pin::new(send_f).poll(ctx) {
261 match r {
262 Ok(()) => return Poll::Ready(None),
263 Err(SendError(e)) => {
264 return Poll::Ready(e);
265 }
266 }
267 }
268 Poll::Pending
269 }
270}
271
272#[async_trait]
273impl<M, W, S> WorkerPoolAsyncInf<M> for WorkerPoolBounded<M, W, S>
274where
275 M: Send + Sized + Unpin + 'static,
276 W: Worker<M>,
277 S: WorkerPoolImpl<M, W>,
278{
279 #[inline]
280 async fn submit(&self, msg: M) -> Option<M> {
281 self.submit(msg).await
282 }
283
284 #[inline]
285 fn try_submit(&self, msg: M) -> Option<M> {
286 self.try_submit(msg)
287 }
288}
289
290impl<M, W, S> WorkerPoolBoundedInner<M, W, S>
291where
292 M: Send + Sized + Unpin + 'static,
293 W: Worker<M>,
294 S: WorkerPoolImpl<M, W>,
295{
296 #[inline(always)]
297 fn _sender(&self) -> &mut Option<MAsyncTx<Option<M>>> {
298 unsafe { transmute(self.sender.get()) }
299 }
300
301 async fn run_worker_simple(&self, mut worker: W, rx: MAsyncRx<Option<M>>) {
302 if let Err(_) = worker.init().await {
303 let _ = self.try_exit();
304 worker.on_exit();
305 return;
306 }
307 loop {
308 match rx.recv().await {
309 Ok(item) => {
310 if item.is_none() {
311 let _ = self.try_exit();
312 break;
313 }
314 worker.run(item.unwrap()).await;
315 }
316 Err(_) => {
317 let _ = self.try_exit();
319 break;
320 }
321 }
322 }
323 worker.on_exit();
324 }
325
326 async fn run_worker_adjust(&self, mut worker: W, rx: MAsyncRx<Option<M>>) {
327 if let Err(_) = worker.init().await {
328 let _ = self.try_exit();
329 worker.on_exit();
330 return;
331 }
332
333 let worker_timeout = self.worker_timeout;
334 let mut is_idle = false;
335 'WORKER_LOOP: loop {
336 if is_idle {
337 match rx.recv_timeout(worker_timeout).await {
338 Ok(item) => {
339 if item.is_none() {
340 let _ = self.try_exit();
341 break 'WORKER_LOOP;
342 }
343 worker.run(item.unwrap()).await;
344 is_idle = false;
345 }
346 Err(RecvTimeoutError::Disconnected) => {
347 let _ = self.try_exit();
349 worker.on_exit();
350 }
351 Err(RecvTimeoutError::Timeout) => {
352 if self.try_exit() {
354 break 'WORKER_LOOP;
355 }
356 }
357 }
358 } else {
359 match rx.try_recv() {
360 Err(e) => {
361 if e.is_empty() {
362 is_idle = true;
363 } else {
364 let _ = self.try_exit();
365 break 'WORKER_LOOP;
366 }
367 }
368 Ok(Some(item)) => {
369 worker.run(item).await;
370 is_idle = false;
371 }
372 Ok(None) => {
373 let _ = self.try_exit();
374 break 'WORKER_LOOP;
375 }
376 }
377 }
378 }
379 worker.on_exit();
380 }
381
382 #[inline(always)]
383 pub fn get_worker_count(&self) -> usize {
384 self.worker_count.load(Ordering::Acquire)
385 }
386
387 #[inline(always)]
388 fn spawn(self: Arc<Self>, initial: bool, rx: MAsyncRx<Option<M>>) {
389 self.worker_count.fetch_add(1, Ordering::SeqCst);
390 let worker = self.inner.spawn();
391 let _self = self.clone();
392 if self.real_thread.load(Ordering::Acquire) {
393 let mut bind_cpu: Option<usize> = None;
394 if _self.bind_cpu.load(Ordering::Acquire) <= _self.max_cpu {
395 let cpu = _self.bind_cpu.fetch_add(1, Ordering::SeqCst);
396 if cpu < _self.max_cpu {
397 bind_cpu = Some(cpu as usize);
398 }
399 }
400 thread::spawn(move || {
401 if let Some(cpu) = bind_cpu {
402 core_affinity::set_for_current(core_affinity::CoreId { id: cpu });
403 }
404 let rt = tokio::runtime::Builder::new_current_thread()
405 .enable_all()
406 .build()
407 .expect("runtime");
408 rt.block_on(async move {
409 if initial || !_self.auto {
410 _self.run_worker_simple(worker, rx).await
411 } else {
412 _self.run_worker_adjust(worker, rx).await
413 }
414 });
415 });
416 } else {
417 tokio::spawn(async move {
418 if initial || !_self.auto {
419 _self.run_worker_simple(worker, rx).await
420 } else {
421 _self.run_worker_adjust(worker, rx).await
422 }
423 });
424 }
425 }
426
427 #[inline(always)]
429 fn try_exit(&self) -> bool {
430 if self.closing.load(Ordering::Acquire) {
431 self.worker_count.fetch_sub(1, Ordering::SeqCst);
432 return true;
433 }
434 if self.get_worker_count() > self.min_workers {
435 if self.worker_count.fetch_sub(1, Ordering::SeqCst) <= self.min_workers {
436 self.worker_count.fetch_add(1, Ordering::SeqCst); } else {
438 return true; }
440 }
441 return false;
442 }
443
444 async fn monitor(self: Arc<Self>, noti: AsyncRx<Option<()>>, rx: MAsyncRx<Option<M>>) {
445 let _self = self.as_ref();
446 loop {
447 match timeout(Duration::from_secs(1), noti.recv()).await {
448 Err(_) => {
449 if _self.closing.load(Ordering::Acquire) {
450 return;
451 }
452 continue;
453 }
454 Ok(Ok(Some(_))) => {
455 if _self.closing.load(Ordering::Acquire) {
456 return;
457 }
458 let worker_count = _self.get_worker_count();
459 if worker_count > _self.max_workers {
460 continue;
461 }
462 self.clone().spawn(false, rx.clone());
463 }
464 _ => {
465 println!("monitor exit");
466 return;
467 }
468 }
469 }
470 }
471}
472
#[cfg(test)]
mod tests {

    use crossfire::*;
    use tokio::time::{Duration, sleep};

    use super::*;

    /// Builds the two-thread multi-threaded runtime every test runs on.
    fn build_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .unwrap()
    }

    /// Factory for `MyWorker`.
    #[allow(dead_code)]
    struct MyWorkerPoolImpl();

    /// Worker that awaits a 1 ms Tokio sleep per message, then acknowledges it.
    struct MyWorker();

    /// Test message: a sequence number plus the channel used to acknowledge completion.
    struct MyMsg(i64, MAsyncTx<()>);

    impl WorkerPoolImpl<MyMsg, MyWorker> for MyWorkerPoolImpl {
        fn spawn(&self) -> MyWorker {
            MyWorker()
        }
    }

    #[async_trait]
    impl Worker<MyMsg> for MyWorker {
        async fn run(&mut self, msg: MyMsg) {
            let MyMsg(seq, ack) = msg;
            sleep(Duration::from_millis(1)).await;
            println!("done {}", seq);
            let _ = ack.send(()).await;
        }
    }

    type MyWorkerPool = WorkerPoolBounded<MyMsg, MyWorker, MyWorkerPoolImpl>;

    /// Bursts load to scale up, idles until extras retire, then closes via `close()` from
    /// inside the runtime and waits for the background shutdown.
    #[test]
    fn bounded_workerpool_adjust_close_async() {
        let min_workers = 1;
        let max_workers = 4;
        let worker_timeout = Duration::from_secs(1);
        let rt = build_runtime();
        let pool =
            MyWorkerPool::new(MyWorkerPoolImpl(), min_workers, max_workers, 1, worker_timeout);
        let closer = pool.clone();
        rt.block_on(async move {
            pool.start();
            // Five producers, two messages each.
            let producers: Vec<_> = (0..5)
                .map(|i| {
                    let producer_pool = pool.clone();
                    tokio::task::spawn(async move {
                        let (done_tx, done_rx) = mpsc::bounded_async(10);
                        for j in 0..2 {
                            let m = i * 10 + j;
                            println!("submit {} in {}/{}", m, j, i);
                            producer_pool.submit(MyMsg(m, done_tx.clone())).await;
                        }
                        for _ in 0..2 {
                            let _ = done_rx.recv().await;
                        }
                    })
                })
                .collect();
            for producer in producers {
                let _ = producer.await;
            }
            let peak = pool.get_worker_count();
            println!("cur workers {} might reach max {}", peak, max_workers);
            sleep(worker_timeout * 2).await;
            let settled = pool.get_worker_count();
            println!("cur workers: {}, extra should exit due to timeout", settled);
            assert_eq!(settled, min_workers);

            let (done_tx, done_rx) = mpsc::bounded_async(1);
            for j in 0..10 {
                pool.submit(MyMsg(80 + j, done_tx.clone())).await;
                let _ = done_rx.recv().await;
            }
            println!("closing");
            closer.close();
            sleep(Duration::from_secs(2)).await;
            assert_eq!(closer.get_worker_count(), 0);
        });
    }

    /// Same workload, but closes from outside the runtime, where `close()` blocks.
    #[test]
    fn bounded_workerpool_adjust_close() {
        let min_workers = 1;
        let max_workers = 4;
        let worker_timeout = Duration::from_secs(1);
        let rt = build_runtime();
        let pool =
            MyWorkerPool::new(MyWorkerPoolImpl(), min_workers, max_workers, 1, worker_timeout);
        let closer = pool.clone();
        rt.block_on(async move {
            pool.start();
            let producers: Vec<_> = (0..5)
                .map(|i| {
                    let producer_pool = pool.clone();
                    tokio::task::spawn(async move {
                        let (done_tx, done_rx) = mpsc::bounded_async(10);
                        for j in 0..2 {
                            let m = i * 10 + j;
                            println!("submit {} in {}/{}", m, j, i);
                            producer_pool.submit(MyMsg(m, done_tx.clone())).await;
                        }
                        for _ in 0..2 {
                            let _ = done_rx.recv().await;
                        }
                    })
                })
                .collect();
            for producer in producers {
                let _ = producer.await;
            }
            let peak = pool.get_worker_count();
            println!("cur workers {} might reach max {}", peak, max_workers);
            sleep(worker_timeout * 2).await;
            let settled = pool.get_worker_count();
            println!("cur workers: {}, extra should exit due to timeout", settled);
            assert_eq!(settled, min_workers);

            let (done_tx, done_rx) = mpsc::bounded_async(1);
            for j in 0..10 {
                pool.submit(MyMsg(80 + j, done_tx.clone())).await;
                let _ = done_rx.recv().await;
            }
        });
        println!("closing");
        closer.close();
        assert_eq!(closer.get_worker_count(), 0);
    }

    /// Factory for `MyBlockingWorker`.
    #[allow(dead_code)]
    struct MyBlockingWorkerPoolImpl();

    /// Worker that blocks its thread for 1 ms per message (meant for `set_use_thread`).
    struct MyBlockingWorker();

    impl WorkerPoolImpl<MyMsg, MyBlockingWorker> for MyBlockingWorkerPoolImpl {
        fn spawn(&self) -> MyBlockingWorker {
            MyBlockingWorker()
        }
    }

    #[async_trait]
    impl Worker<MyMsg> for MyBlockingWorker {
        async fn run(&mut self, msg: MyMsg) {
            let MyMsg(seq, ack) = msg;
            std::thread::sleep(Duration::from_millis(1));
            println!("done {}", seq);
            let _ = ack.send(()).await;
        }
    }

    type MyBlockingWorkerPool =
        WorkerPoolBounded<MyMsg, MyBlockingWorker, MyBlockingWorkerPoolImpl>;

    /// Same workload with every worker on its own OS thread.
    #[test]
    fn bounded_workerpool_adjust_real_thread() {
        let min_workers = 1;
        let max_workers = 4;
        let worker_timeout = Duration::from_secs(1);
        let rt = build_runtime();
        let mut pool = MyBlockingWorkerPool::new(
            MyBlockingWorkerPoolImpl(),
            min_workers,
            max_workers,
            1,
            worker_timeout,
        );
        pool.set_use_thread(true);
        let closer = pool.clone();
        rt.block_on(async move {
            pool.start();
            let producers: Vec<_> = (0..5)
                .map(|i| {
                    let producer_pool = pool.clone();
                    tokio::task::spawn(async move {
                        let (done_tx, done_rx) = mpsc::bounded_async(10);
                        for j in 0..2 {
                            let m = i * 10 + j;
                            println!("submit {} in {}/{}", m, j, i);
                            producer_pool.submit(MyMsg(m, done_tx.clone())).await;
                        }
                        for _ in 0..2 {
                            let _ = done_rx.recv().await;
                        }
                    })
                })
                .collect();
            for producer in producers {
                let _ = producer.await;
            }
            let peak = pool.get_worker_count();
            println!("cur workers {} might reach max {}", peak, max_workers);
            sleep(worker_timeout * 2).await;
            let settled = pool.get_worker_count();
            println!("cur workers: {}, extra should exit due to timeout", settled);
            assert_eq!(settled, min_workers);

            let (done_tx, done_rx) = mpsc::bounded_async(1);
            for j in 0..10 {
                pool.submit(MyMsg(80 + j, done_tx.clone())).await;
                let _ = done_rx.recv().await;
            }
        });
        println!("closing");
        closer.close();
        assert_eq!(closer.get_worker_count(), 0);
    }
}