1use std::{
2 sync::{
3 Arc,
4 atomic::{AtomicBool, AtomicUsize, Ordering},
5 },
6 thread,
7 time::Duration,
8};
9
10use crossfire::*;
11use tokio::time::{sleep, timeout};
12
13use super::*;
14
15#[allow(unused_must_use)]
16pub struct WorkerPoolUnbounded<
17 M: Send + Sized + Unpin + 'static,
18 W: Worker<M>,
19 S: WorkerPoolImpl<M, W>,
20>(Arc<WorkerPoolUnboundedInner<M, W, S>>);
21
22struct WorkerPoolUnboundedInner<M, W, S>
23where
24 M: Send + Sized + Unpin + 'static,
25 W: Worker<M>,
26 S: WorkerPoolImpl<M, W>,
27{
28 worker_count: AtomicUsize,
29 sender: MTx<Option<M>>,
30 recv: MAsyncRx<Option<M>>,
31 min_workers: usize,
32 max_workers: usize,
33 worker_timeout: Duration,
34 inner: S,
35 phantom: std::marker::PhantomData<W>, closing: AtomicBool,
37 notify_sender: MTx<Option<()>>,
38 notify_recv: MAsyncRx<Option<()>>, water: AtomicUsize,
40 auto: bool,
41 real_thread: AtomicBool,
42}
43
44impl<M, W, S> Clone for WorkerPoolUnbounded<M, W, S>
45where
46 M: Send + Sized + Unpin + 'static,
47 W: Worker<M>,
48 S: WorkerPoolImpl<M, W>,
49{
50 #[inline]
51 fn clone(&self) -> Self {
52 Self(self.0.clone())
53 }
54}
55
56impl<M, W, S> WorkerPoolUnbounded<M, W, S>
57where
58 M: Send + Sized + Unpin + 'static,
59 W: Worker<M>,
60 S: WorkerPoolImpl<M, W>,
61{
62 pub fn new(inner: S, min_workers: usize, max_workers: usize, worker_timeout: Duration) -> Self {
63 assert!(min_workers > 0);
64 assert!(max_workers >= min_workers);
65
66 let auto: bool = min_workers < max_workers;
67 if auto {
68 assert!(worker_timeout != ZERO_DUARTION);
69 }
70 let (sender, recv) = mpmc::unbounded_async();
71 let (noti_sender, noti_recv) = mpmc::bounded_tx_blocking_rx_async(1);
72 let pool = Arc::new(WorkerPoolUnboundedInner {
73 sender,
74 recv,
75 inner,
76 worker_count: AtomicUsize::new(0),
77 min_workers,
78 max_workers,
79 worker_timeout,
80 phantom: Default::default(),
81 closing: AtomicBool::new(false),
82 notify_sender: noti_sender,
83 notify_recv: noti_recv,
84 water: AtomicUsize::new(0),
85 auto,
86 real_thread: AtomicBool::new(false),
87 });
88 Self(pool)
89 }
90
91 pub fn get_inner(&self) -> &S {
92 &self.0.inner
93 }
94
95 pub fn set_use_thread(&mut self, ok: bool) {
97 self.0.real_thread.store(ok, Ordering::Release);
98 }
99
100 pub async fn start(&self) {
101 let _self = self.0.as_ref();
102 for _ in 0.._self.min_workers {
103 self.0.clone().spawn();
104 }
105 if _self.auto {
106 let _pool = self.0.clone();
107 tokio::spawn(async move {
108 _pool.monitor().await;
109 });
110 }
111 }
112
113 pub async fn try_spawn(&self, num: usize) {
114 let _self = self.0.as_ref();
115 if !_self.auto {
116 return;
117 }
118 for _ in 0..num {
119 if _self.get_worker_count() >= _self.max_workers {
120 return;
121 }
122 self.0.clone().spawn();
123 }
124 }
125
126 pub async fn close(&self) {
127 let _self = self.0.as_ref();
128 if _self.closing.swap(true, Ordering::SeqCst) {
129 return;
130 }
131 loop {
132 let cur = self.get_worker_count();
133 if cur == 0 {
134 break;
135 }
136 debug!("worker pool closing: cur workers {}", cur);
137 for _ in 0..cur {
138 let _ = _self.sender.send(None);
139 }
140 sleep(Duration::from_secs(1)).await;
141 }
143 let _ = _self.notify_sender.try_send(None);
144 }
145
146 pub fn get_worker_count(&self) -> usize {
147 self.0.get_worker_count()
148 }
149}
150
151impl<M, W, S> WorkerPoolInf<M> for WorkerPoolUnbounded<M, W, S>
152where
153 M: Send + Sized + Unpin + 'static,
154 W: Worker<M>,
155 S: WorkerPoolImpl<M, W>,
156{
157 #[inline]
159 fn submit(&self, msg: M) -> Option<M> {
160 let _self = self.0.as_ref();
161 if _self.closing.load(Ordering::Acquire) {
162 return Some(msg);
163 }
164 if _self.auto {
165 let worker_count = _self.get_worker_count();
166 let water = _self.water.fetch_add(1, Ordering::SeqCst);
167 if worker_count < _self.max_workers && water > worker_count + 1 {
168 let _ = _self.notify_sender.try_send(Some(()));
169 }
170 }
171 match _self.sender.send(Some(msg)) {
172 Ok(_) => None,
173 Err(SendError(_msg)) => {
174 return Some(_msg.unwrap());
175 }
176 }
177 }
178}
179
180impl<M, W, S> WorkerPoolUnboundedInner<M, W, S>
181where
182 M: Send + Sized + Unpin + 'static,
183 W: Worker<M>,
184 S: WorkerPoolImpl<M, W>,
185{
186 async fn run_worker_simple(&self, mut worker: W) {
187 if let Err(_) = worker.init().await {
188 let _ = self.try_exit();
189 worker.on_exit();
190 return;
191 }
192
193 let recv = &self.recv;
194 'WORKER_LOOP: loop {
195 match recv.recv().await {
196 Ok(item) => {
197 if item.is_none() {
198 let _ = self.try_exit();
199 break 'WORKER_LOOP;
200 }
201 worker.run(item.unwrap()).await;
202 }
203 Err(_) => {
204 let _ = self.try_exit();
206 break 'WORKER_LOOP;
207 }
208 }
209 }
210 worker.on_exit();
211 trace!("worker pool {} workers", self.get_worker_count());
212 }
213
214 async fn run_worker_adjust(&self, mut worker: W) {
215 if let Err(_) = worker.init().await {
216 let _ = self.try_exit();
217 worker.on_exit();
218 return;
219 }
220
221 let worker_timeout = self.worker_timeout;
222 let recv = &self.recv;
223 let mut is_idle = false;
224 'WORKER_LOOP: loop {
225 if is_idle {
226 match timeout(worker_timeout, recv.recv()).await {
227 Ok(res) => {
228 match res {
229 Ok(item) => {
230 if item.is_none() {
231 let _ = self.try_exit();
232 break 'WORKER_LOOP;
233 }
234 worker.run(item.unwrap()).await;
235 is_idle = false;
236 self.water.fetch_sub(1, Ordering::SeqCst);
237 }
238 Err(_) => {
239 let _ = self.try_exit();
241 worker.on_exit();
242 }
243 }
244 }
245 Err(_) => {
246 if self.try_exit() {
248 break 'WORKER_LOOP;
249 }
250 }
251 }
252 } else {
253 match recv.try_recv() {
254 Err(e) => {
255 if e.is_empty() {
256 is_idle = true;
257 } else {
258 let _ = self.try_exit();
259 break 'WORKER_LOOP;
260 }
261 }
262 Ok(Some(item)) => {
263 worker.run(item).await;
264 self.water.fetch_sub(1, Ordering::SeqCst);
265 is_idle = false;
266 }
267 Ok(None) => {
268 let _ = self.try_exit();
269 break 'WORKER_LOOP;
270 }
271 }
272 }
273 }
274 worker.on_exit();
275 trace!("worker pool {} workers", self.get_worker_count());
276 }
277
278 #[inline(always)]
279 pub fn get_worker_count(&self) -> usize {
280 self.worker_count.load(Ordering::Acquire)
281 }
282 #[inline(always)]
283 fn spawn(self: Arc<Self>) {
284 let cur_count = self.worker_count.fetch_add(1, Ordering::SeqCst) + 1;
285 let worker = self.inner.spawn();
286 let _self = self.clone();
287 if self.real_thread.load(Ordering::Acquire) {
288 thread::spawn(move || {
289 let rt = tokio::runtime::Builder::new_current_thread()
290 .enable_all()
291 .build()
292 .expect("runtime");
293 rt.block_on(async move {
294 trace!("worker pool started worker {}", cur_count);
295 if _self.auto {
296 _self.run_worker_adjust(worker).await
297 } else {
298 _self.run_worker_simple(worker).await
299 }
300 });
301 });
302 } else {
303 tokio::spawn(async move {
304 trace!("worker pool started worker {}", cur_count);
305 if _self.auto {
306 _self.run_worker_adjust(worker).await
307 } else {
308 _self.run_worker_simple(worker).await
309 }
310 });
311 }
312 }
313
314 #[inline(always)]
316 fn try_exit(&self) -> bool {
317 if self.closing.load(Ordering::Acquire) {
318 self.worker_count.fetch_sub(1, Ordering::SeqCst);
319 return true;
320 }
321 if self.get_worker_count() > self.min_workers {
322 if self.worker_count.fetch_sub(1, Ordering::SeqCst) <= self.min_workers {
323 self.worker_count.fetch_add(1, Ordering::SeqCst); } else {
325 return true; }
327 }
328 return false;
329 }
330
331 async fn monitor(self: Arc<Self>) {
332 let _self = self.as_ref();
333 loop {
334 match timeout(Duration::from_secs(1), _self.notify_recv.recv()).await {
335 Err(_) => {
336 if _self.closing.load(Ordering::Acquire) {
337 return;
338 }
339 continue;
340 }
341 Ok(Ok(Some(_))) => {
342 if _self.closing.load(Ordering::Acquire) {
343 return;
344 }
345 let worker_count = _self.get_worker_count();
346 if worker_count > _self.max_workers {
347 continue;
348 }
349 let mut pending_msg = _self.sender.len();
350 if pending_msg > worker_count {
351 pending_msg -= worker_count;
352 if pending_msg > _self.max_workers - worker_count {
353 pending_msg = _self.max_workers - worker_count;
354 }
355 for _ in 0..pending_msg {
356 self.clone().spawn();
357 }
358 }
359 }
360 _ => return,
361 }
362 }
363 }
364}
365
#[cfg(test)]
mod tests {

    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    use crossfire::*;
    use tokio::time::{Duration, sleep};

    use super::*;
    use atomic_waitgroup::WaitGroup;

    // Factory for `MyWorker`, used by `unbounded_workerpool_adjust`.
    #[allow(dead_code)]
    struct MyWorkerPoolImpl();

    // Stateless worker: sleeps 1ms per message, then acknowledges it.
    struct MyWorker();

    // Test message: an id for logging plus a channel on which completion is acknowledged.
    struct MyMsg(i64, MAsyncTx<()>);

    impl WorkerPoolImpl<MyMsg, MyWorker> for MyWorkerPoolImpl {
        fn spawn(&self) -> MyWorker {
            MyWorker()
        }
    }

    #[async_trait]
    impl Worker<MyMsg> for MyWorker {
        async fn init(&mut self) -> Result<(), ()> {
            println!("init done");
            Ok(())
        }

        async fn run(&mut self, msg: MyMsg) {
            sleep(Duration::from_millis(1)).await;
            println!("done {}", msg.0);
            // Ack to the submitter; errors are ignored if the receiver is gone.
            let _ = msg.1.send(()).await;
        }
    }

    type MyWorkerPool = WorkerPoolUnbounded<MyMsg, MyWorker, MyWorkerPoolImpl>;

    // Scaling test. It checks that:
    // - a burst of submissions grows the pool to `max_workers`;
    // - `try_spawn` does not exceed that cap;
    // - workers above `min_workers` exit after idling for `worker_timeout`;
    // - `close` leaves zero workers and a zero `water` backlog.
    // NOTE(review): the first "should reach max" assertion depends on scheduling timing.
    #[test]
    fn unbounded_workerpool_adjust() {
        let _ = captains_log::recipe::stderr_test_logger(log::Level::Debug).build();
        let min_workers = 1;
        let max_workers = 4;
        let worker_timeout = Duration::from_secs(1);
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .unwrap();
        let worker_pool =
            MyWorkerPool::new(MyWorkerPoolImpl(), min_workers, max_workers, worker_timeout);
        rt.block_on(async move {
            worker_pool.start().await;
            // 5 concurrent submitters, 2 messages each, each waiting for its own acks.
            let mut th_s = Vec::new();
            for i in 0..5 {
                let _pool = worker_pool.clone();
                th_s.push(tokio::task::spawn(async move {
                    let (done_tx, done_rx) = mpsc::bounded_async(10);
                    for j in 0..2 {
                        _pool.submit(MyMsg(i * 10 + j, done_tx.clone()));
                    }
                    for _j in 0..2 {
                        let _ = done_rx.recv().await;
                    }
                }));
            }
            for th in th_s {
                let _ = th.await;
            }
            let workers = worker_pool.get_worker_count();
            println!("cur workers {} should reach max", workers);
            assert_eq!(workers, max_workers);

            // Already at the cap: try_spawn must not add workers.
            worker_pool.try_spawn(5).await;
            let workers = worker_pool.get_worker_count();
            println!("cur workers {} should reach max", workers);
            assert_eq!(workers, max_workers);

            // After idling for twice the timeout, only min_workers should remain.
            sleep(worker_timeout * 2).await;
            let workers = worker_pool.get_worker_count();
            println!("cur workers: {}, extra should exit due to timeout", workers);
            assert_eq!(workers, min_workers);

            // Strictly sequential submissions (one in flight at a time).
            let (done_tx, done_rx) = mpsc::bounded_async(1);
            for j in 0..10 {
                worker_pool.submit(MyMsg(80 + j, done_tx.clone()));
                let _ = done_rx.recv().await;
            }
            println!("closing");
            worker_pool.close().await;
            assert_eq!(worker_pool.get_worker_count(), 0);
            assert_eq!(worker_pool.0.water.load(Ordering::Acquire), 0)
        });
    }

    // Factory handing out workers with increasing ids.
    #[allow(dead_code)]
    struct TestWorkerPoolImpl {
        id: AtomicUsize,
    }

    struct TestWorker {
        id: usize,
    }

    // Test message: a sequence number plus the WaitGroup to mark done after processing.
    #[allow(dead_code)]
    struct TestMsg(usize, WaitGroup);

    impl WorkerPoolImpl<TestMsg, TestWorker> for TestWorkerPoolImpl {
        fn spawn(&self) -> TestWorker {
            let _id = self.id.fetch_add(1, Ordering::SeqCst);
            TestWorker { id: _id }
        }
    }

    #[async_trait]
    impl Worker<TestMsg> for TestWorker {
        async fn init(&mut self) -> Result<(), ()> {
            log::info!("worker {} init done", self.id);
            Ok(())
        }

        async fn run(&mut self, msg: TestMsg) {
            // Pseudo-random 0..10ms of work derived from the wall clock.
            let run_time = (SystemTime::now().duration_since(UNIX_EPOCH).ok().unwrap().as_millis()
                % 10) as u64;
            sleep(Duration::from_millis(run_time)).await;
            msg.1.done();
        }
    }

    type TestWorkerPool = WorkerPoolUnbounded<TestMsg, TestWorker, TestWorkerPoolImpl>;

    // Stress test: 10 submitter tasks x 10000 messages each, pausing pseudo-randomly
    // between bursts. Completes once every message has been processed (or, if submit
    // handed it back, marked done by the submitter).
    #[test]
    fn unbounded_workerpool_run() {
        let _ = captains_log::recipe::stderr_test_logger(log::Level::Debug).build();

        log::info!("unbounded_workerpool test start");
        let min_workers = 8;
        let max_workers = 128;
        let worker_timeout = Duration::from_secs(5);
        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(4)
            .build()
            .unwrap();
        let worker_pool = TestWorkerPool::new(
            TestWorkerPoolImpl { id: AtomicUsize::new(0) },
            min_workers,
            max_workers,
            worker_timeout,
        );
        rt.block_on(async move {
            worker_pool.start().await;
            let total_threads = 10;
            let batch_msgs: usize = 10000;
            let wg = WaitGroup::new();
            wg.add(batch_msgs * total_threads);
            for thread in 0..total_threads {
                let _wg = wg.clone();
                let _pool = worker_pool.clone();
                tokio::spawn(async move {
                    log::info!("thread:{} run start", thread);
                    let batch_msg_start = thread * batch_msgs;
                    // Burst length (0..100) before the next pause, from the wall clock.
                    let mut submit_steps: u64 =
                        (SystemTime::now().duration_since(UNIX_EPOCH).ok().unwrap().as_millis()
                            % 100) as u64;
                    let mut current_submit_step = 0;
                    for i in batch_msg_start..(batch_msg_start + batch_msgs) {
                        let msg = TestMsg(i, _wg.clone());
                        // A message handed back by submit will never reach a worker,
                        // so account for it here to keep the WaitGroup balanced.
                        if let Some(_msg) = _pool.submit(msg) {
                            _msg.1.done();
                        }
                        current_submit_step += 1;
                        if current_submit_step >= submit_steps {
                            sleep(Duration::from_millis(submit_steps % 100)).await;
                            current_submit_step = 0;
                            submit_steps = (SystemTime::now()
                                .duration_since(UNIX_EPOCH)
                                .ok()
                                .unwrap()
                                .as_millis()
                                % 100) as u64;
                        }
                    }
                    log::info!("thread:{} run over", thread);
                });
            }
            wg.wait().await;
        });
        log::info!("unbounded_workerpool test over");
    }
}