1use std::{
2 sync::{
3 Arc,
4 atomic::{AtomicBool, AtomicUsize, Ordering},
5 },
6 time::Duration,
7};
8
9use crossfire::*;
10use tokio::{runtime::Runtime, time::timeout};
11
12use super::*;
13
14#[allow(unused_must_use)]
20pub struct WorkerPool<M: Send + Sized + Unpin + 'static, W: Worker<M>, S: WorkerPoolImpl<M, W>>(
21 Arc<WorkerPoolInner<M, W, S>>,
22);
23
/// State shared by every `WorkerPool` handle, the monitor task and all workers.
struct WorkerPoolInner<M, W, S>
where
    M: Send + Sized + Unpin + 'static,
    W: Worker<M>,
    S: WorkerPoolImpl<M, W>,
{
    /// Number of workers counted as alive: incremented when a worker is
    /// spawned, decremented as workers exit.
    worker_count: AtomicUsize,
    /// Sending half of the message channel. `Some(msg)` is work to run;
    /// `None` is the marker `close` sends to make one worker exit.
    sender: MTx<Option<M>>,
    /// Receiving half of the message channel, shared by all workers.
    recv: MAsyncRx<Option<M>>,
    /// Workers spawned when the monitor starts; idle workers do not shrink the
    /// pool below this.
    min_workers: usize,
    /// Upper bound on the number of workers the monitor will spawn.
    max_workers: usize,
    /// Idle timeout of auto-scaling workers; also the pause between drain
    /// rounds in `close`.
    worker_timeout: Duration,
    /// User-supplied factory that creates each worker (`WorkerPoolImpl::spawn`).
    inner: S,
    /// Messages submitted but not yet finished; only maintained when `auto`.
    water: AtomicUsize,
    /// Ties the otherwise unused worker type `W` to this struct.
    phantom: std::marker::PhantomData<W>,
    /// Set by `close`; once set, `submit` hands messages back to the caller.
    closing: AtomicBool,
    /// Wakes the monitor task: `Some(())` requests a scale-up check,
    /// `None` makes the monitor return.
    notify_sender: MTx<Option<()>>,
    /// `true` when `min_workers < max_workers`, i.e. the pool grows and shrinks.
    auto: bool,
    /// Capacity of the message channel, as clamped in `new`.
    buffer_size: usize,
}
44
45impl<M, W, S> Clone for WorkerPool<M, W, S>
46where
47 M: Send + Sized + Unpin + 'static,
48 W: Worker<M>,
49 S: WorkerPoolImpl<M, W>,
50{
51 #[inline]
52 fn clone(&self) -> Self {
53 Self(self.0.clone())
54 }
55}
56
57impl<M, W, S> WorkerPool<M, W, S>
58where
59 M: Send + Sized + Unpin + 'static,
60 W: Worker<M>,
61 S: WorkerPoolImpl<M, W>,
62{
63 pub fn new(
64 inner: S, min_workers: usize, max_workers: usize, mut buffer_size: usize,
65 worker_timeout: Duration, rt: &Runtime,
66 ) -> Self {
67 if buffer_size > max_workers * 2 {
68 buffer_size = max_workers * 2;
69 }
70 let (sender, recv) = mpmc::bounded_tx_blocking_rx_async(buffer_size);
71 let (noti_sender, noti_recv) = mpmc::bounded_tx_blocking_rx_async(1);
72 assert!(min_workers > 0);
73 assert!(max_workers >= min_workers);
74
75 let auto: bool = min_workers < max_workers;
76 if auto {
77 assert!(worker_timeout != ZERO_DUARTION);
78 }
79
80 let pool = Arc::new(WorkerPoolInner {
81 sender,
82 recv,
83 inner,
84 worker_count: AtomicUsize::new(0),
85 min_workers,
86 max_workers,
87 buffer_size,
88 worker_timeout,
89 phantom: Default::default(),
90 closing: AtomicBool::new(false),
91 water: AtomicUsize::new(0),
92 notify_sender: noti_sender,
93 auto,
94 });
95 let _pool = pool.clone();
96 rt.spawn(async move {
97 _pool.monitor(noti_recv).await;
98 });
99 Self(pool)
100 }
101
102 pub fn get_inner(&self) -> &S {
103 &self.0.inner
104 }
105
106 pub fn submit(&self, msg: M) -> Option<M> {
109 let _self = self.0.as_ref();
110 if _self.closing.load(Ordering::Acquire) {
111 return Some(msg);
112 }
113 if _self.auto {
114 let worker_count = _self.get_worker_count();
115 let water = _self.water.fetch_add(1, Ordering::SeqCst);
116 if worker_count < _self.max_workers {
117 if water > worker_count + 1 || water > _self.buffer_size {
118 let _ = _self.notify_sender.try_send(Some(()));
119 }
120 }
121 }
122 match _self.sender.send(Some(msg)) {
123 Ok(_) => return None,
124 Err(SendError(t)) => return t,
125 }
126 }
127
128 pub fn close(&self) {
129 let _self = self.0.as_ref();
130 if _self.closing.swap(true, Ordering::SeqCst) {
131 return;
132 }
133 loop {
134 let cur = self.get_worker_count();
135 if cur == 0 {
136 break;
137 }
138 debug!("worker pool closing: cur workers {}", cur);
139 for _ in 0..cur {
140 let _ = _self.sender.send(None);
141 }
142 std::thread::sleep(_self.worker_timeout);
143 }
145 let _ = _self.notify_sender.send(None);
147 let _ = _self.notify_sender.send(None);
148 }
149
150 pub fn get_worker_count(&self) -> usize {
151 self.0.get_worker_count()
152 }
153}
154
155impl<M, W, S> WorkerPoolInner<M, W, S>
156where
157 M: Send + Sized + Unpin + 'static,
158 W: Worker<M>,
159 S: WorkerPoolImpl<M, W>,
160{
161 async fn run_worker_simple(&self, mut worker: W) {
162 if let Err(_) = worker.init().await {
163 let _ = self.try_exit();
164 worker.on_exit();
165 return;
166 }
167
168 let recv = &self.recv;
169 'WORKER_LOOP: loop {
170 match recv.recv().await {
171 Ok(item) => {
172 if item.is_none() {
173 let _ = self.try_exit();
174 break 'WORKER_LOOP;
175 }
176 worker.run(item.unwrap()).await;
177 }
178 Err(_) => {
179 let _ = self.try_exit();
181 break 'WORKER_LOOP;
182 }
183 }
184 }
185 worker.on_exit();
186 }
187
188 async fn run_worker_adjust(&self, mut worker: W) {
189 if let Err(_) = worker.init().await {
190 let _ = self.try_exit();
191 worker.on_exit();
192 return;
193 }
194
195 let worker_timeout = self.worker_timeout;
196 let recv = &self.recv;
197 let mut is_idle = false;
198 'WORKER_LOOP: loop {
199 if is_idle {
200 match timeout(worker_timeout, recv.recv()).await {
201 Ok(res) => {
202 match res {
203 Ok(item) => {
204 if item.is_none() {
205 let _ = self.try_exit();
206 break 'WORKER_LOOP;
207 }
208 worker.run(item.unwrap()).await;
209 is_idle = false;
210 self.water.fetch_sub(1, Ordering::SeqCst);
211 }
212 Err(_) => {
213 let _ = self.try_exit();
215 worker.on_exit();
216 }
217 }
218 }
219 Err(_) => {
220 if self.try_exit() {
222 break 'WORKER_LOOP;
223 }
224 }
225 }
226 } else {
227 match recv.try_recv() {
228 Err(e) => {
229 if e.is_empty() {
230 is_idle = true;
231 } else {
232 let _ = self.try_exit();
233 break 'WORKER_LOOP;
234 }
235 }
236 Ok(Some(item)) => {
237 worker.run(item).await;
238 self.water.fetch_sub(1, Ordering::SeqCst);
239 is_idle = false;
240 }
241 Ok(None) => {
242 let _ = self.try_exit();
243 break 'WORKER_LOOP;
244 }
245 }
246 }
247 }
248 worker.on_exit();
249 }
250
251 #[inline(always)]
252 pub fn get_worker_count(&self) -> usize {
253 self.worker_count.load(Ordering::Acquire)
254 }
255
256 #[inline(always)]
257 fn spawn(self: Arc<Self>) {
258 self.worker_count.fetch_add(1, Ordering::SeqCst);
259 let worker = self.inner.spawn();
260 let _self = self.clone();
261 tokio::spawn(async move {
262 if _self.auto {
263 _self.run_worker_adjust(worker).await
264 } else {
265 _self.run_worker_simple(worker).await
266 }
267 });
268 }
269
270 #[inline(always)]
272 fn try_exit(&self) -> bool {
273 if self.closing.load(Ordering::Acquire) {
274 self.worker_count.fetch_sub(1, Ordering::SeqCst);
275 return true;
276 }
277 if self.get_worker_count() > self.min_workers {
278 if self.worker_count.fetch_sub(1, Ordering::SeqCst) <= self.min_workers {
279 self.worker_count.fetch_add(1, Ordering::SeqCst); } else {
281 return true; }
283 }
284 return false;
285 }
286
287 async fn monitor(self: Arc<Self>, noti_recv: MAsyncRx<Option<()>>) {
288 for _ in 0..self.min_workers {
289 self.clone().spawn();
290 }
291 loop {
292 if let Ok(Some(_)) = noti_recv.recv().await {
293 if self.auto {
294 let worker_count = self.get_worker_count();
295 if worker_count > self.max_workers {
296 continue;
297 }
298 let mut pending_msg = self.sender.len();
299 if pending_msg > worker_count {
300 pending_msg -= worker_count;
301 if pending_msg > self.max_workers - worker_count {
302 pending_msg = self.max_workers - worker_count;
303 }
304 for _ in 0..pending_msg {
305 self.clone().spawn();
306 }
307 }
308 } else {
309 continue;
310 }
311 } else {
312 return;
313 }
314 }
315 }
316}
317
#[cfg(test)]
mod tests {

    use std::thread;

    use crossbeam::channel::{Sender, bounded};
    use tokio::time::{Duration, sleep};

    use super::*;

    #[allow(dead_code)]
    struct MyWorkerPoolImpl();

    struct MyWorker();

    /// Test message: an id to print and a channel to signal completion on.
    struct MyMsg(i64, Sender<()>);

    impl WorkerPoolImpl<MyMsg, MyWorker> for MyWorkerPoolImpl {
        fn spawn(&self) -> MyWorker {
            MyWorker()
        }
    }

    #[async_trait]
    impl Worker<MyMsg> for MyWorker {
        async fn run(&mut self, msg: MyMsg) {
            let MyMsg(id, done) = msg;
            sleep(Duration::from_millis(1)).await;
            println!("done {}", id);
            let _ = done.send(());
        }
    }

    type MyWorkerPool = WorkerPool<MyMsg, MyWorker, MyWorkerPoolImpl>;

    /// Two-thread multi-threaded runtime the pools under test run on.
    fn test_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(2)
            .build()
            .unwrap()
    }

    /// Eight threads each submit ten messages, then wait for all ten to finish.
    fn submit_burst(pool: &MyWorkerPool) {
        let handles: Vec<_> = (0..8)
            .map(|i| {
                let pool = pool.clone();
                thread::spawn(move || {
                    let (done_tx, done_rx) = bounded(10);
                    for j in 0..10 {
                        pool.submit(MyMsg(i * 10 + j, done_tx.clone()));
                    }
                    for _ in 0..10 {
                        let _ = done_rx.recv();
                    }
                })
            })
            .collect();
        handles.into_iter().for_each(|handle| {
            let _ = handle.join();
        });
    }

    /// Submits ten more messages one at a time, waiting for each to complete.
    fn submit_sequentially(pool: &MyWorkerPool) {
        let (done_tx, done_rx) = bounded(2);
        for j in 0..10 {
            pool.submit(MyMsg(80 + j, done_tx.clone()));
            println!("send {}", j);
            let _ = done_rx.recv();
        }
    }

    /// Closes the pool and checks that no workers and no in-flight messages remain.
    fn close_and_verify(pool: &MyWorkerPool) {
        println!("closing");
        pool.close();
        assert_eq!(pool.get_worker_count(), 0);
        assert_eq!(pool.0.water.load(Ordering::Acquire), 0)
    }

    #[test]
    fn blocking_workerpool_adjust() {
        let (min_workers, max_workers) = (1, 4);
        let worker_timeout = Duration::from_secs(1);
        let rt = test_runtime();
        let worker_pool = MyWorkerPool::new(
            MyWorkerPoolImpl(),
            min_workers,
            max_workers,
            10,
            worker_timeout,
            &rt,
        );

        submit_burst(&worker_pool);
        let workers = worker_pool.get_worker_count();
        println!("cur workers {} should reach max", workers);
        assert_eq!(workers, max_workers);

        thread::sleep(worker_timeout * 2);
        let workers = worker_pool.get_worker_count();
        println!("cur workers: {}, extra should exit due to timeout", workers);
        assert_eq!(workers, min_workers);

        submit_sequentially(&worker_pool);
        close_and_verify(&worker_pool);
    }

    #[test]
    fn blocking_workerpool_fixed() {
        let (min_workers, max_workers) = (4, 4);
        let worker_timeout = Duration::from_secs(1);
        let rt = test_runtime();
        let worker_pool = MyWorkerPool::new(
            MyWorkerPoolImpl(),
            min_workers,
            max_workers,
            10,
            worker_timeout,
            &rt,
        );

        submit_burst(&worker_pool);
        let workers = worker_pool.get_worker_count();
        println!("cur workers {} should reach max", workers);
        assert_eq!(workers, max_workers);

        thread::sleep(worker_timeout * 2);
        let workers = worker_pool.get_worker_count();
        println!("cur workers {} should reach max", workers);
        assert_eq!(workers, max_workers);

        submit_sequentially(&worker_pool);
        close_and_verify(&worker_pool);
    }
}