1use crate::{pool_cleared_error, pool_clearing_error, pool_invalid_config_error, PoolResult};
2use notify_future::Notify;
3use std::collections::HashMap;
4use std::hash::Hash;
5use std::ops::{Deref, DerefMut};
6use std::sync::{Arc, Mutex};
7
/// Marker for types usable as a worker classification: clonable, hashable, comparable for
/// equality, and transferable across threads.
pub trait WorkerClassification: Clone + Eq + Hash + Send + 'static {}

/// Every type satisfying the bounds above is automatically a classification.
impl<T> WorkerClassification for T where T: Clone + Eq + Hash + Send + 'static {}
11
12#[async_trait::async_trait]
13pub trait ClassifiedWorker<C: WorkerClassification>: Send + 'static {
14 fn is_work(&self) -> bool;
15 fn is_valid(&self, c: C) -> bool;
18 fn classification(&self) -> C;
20}
21
22pub struct ClassifiedWorkerGuard<
23 C: WorkerClassification,
24 W: ClassifiedWorker<C>,
25 F: ClassifiedWorkerFactory<C, W>,
26> {
27 pool_ref: ClassifiedWorkerPoolRef<C, W, F>,
28 worker: Option<W>,
29}
30
31impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>>
32 ClassifiedWorkerGuard<C, W, F>
33{
34 fn new(worker: W, pool_ref: ClassifiedWorkerPoolRef<C, W, F>) -> Self {
35 ClassifiedWorkerGuard {
36 pool_ref,
37 worker: Some(worker),
38 }
39 }
40}
41
42impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> Deref
43 for ClassifiedWorkerGuard<C, W, F>
44{
45 type Target = W;
46
47 fn deref(&self) -> &Self::Target {
48 self.worker.as_ref().unwrap()
49 }
50}
51
52impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> DerefMut
53 for ClassifiedWorkerGuard<C, W, F>
54{
55 fn deref_mut(&mut self) -> &mut Self::Target {
56 self.worker.as_mut().unwrap()
57 }
58}
59
60impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> Drop
61 for ClassifiedWorkerGuard<C, W, F>
62{
63 fn drop(&mut self) {
64 if let Some(worker) = self.worker.take() {
65 self.pool_ref.release(worker);
66 }
67 }
68}
69
70#[async_trait::async_trait]
71pub trait ClassifiedWorkerFactory<C: WorkerClassification, W: ClassifiedWorker<C>>:
72 Send + Sync + 'static
73{
74 async fn create(&self, c: Option<C>) -> PoolResult<W>;
75}
76
77struct WaitingItem<
78 C: WorkerClassification,
79 W: ClassifiedWorker<C>,
80 F: ClassifiedWorkerFactory<C, W>,
81> {
82 future: Notify<PoolResult<ClassifiedWorkerGuard<C, W, F>>>,
83 condition: Option<C>,
84}
85struct WorkerPoolState<
86 C: WorkerClassification,
87 W: ClassifiedWorker<C>,
88 F: ClassifiedWorkerFactory<C, W>,
89> {
90 current_count: u16,
91 classified_count_map: HashMap<C, u16>,
92 worker_list: Vec<W>,
93 waiting_list: Vec<WaitingItem<C, W, F>>,
94 clearing: bool,
95 clear_waiting_list: Vec<Notify<()>>,
96}
97
98impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>>
99 WorkerPoolState<C, W, F>
100{
101 fn inc_classified_count(&mut self, c: C) {
102 let count = self.classified_count_map.entry(c).or_insert(0);
103 *count += 1;
104 }
105
106 fn dec_classified_count(&mut self, c: C) {
107 let mut should_remove = false;
108 if let Some(count) = self.classified_count_map.get_mut(&c) {
109 debug_assert!(*count > 0);
110 *count -= 1;
111 should_remove = *count == 0;
112 }
113 if should_remove {
114 self.classified_count_map.remove(&c);
115 }
116 }
117
118 fn take_clear_waiters_if_done(&mut self) -> Vec<Notify<()>> {
119 if self.clearing && self.current_count == 0 {
120 self.clearing = false;
121 self.clear_waiting_list.drain(..).collect()
122 } else {
123 Vec::new()
124 }
125 }
126
127 fn find_matching_waiter_index_for_worker(&self, worker: &W) -> Option<usize> {
128 self.waiting_list.iter().position(|waiting| {
129 waiting
130 .condition
131 .as_ref()
132 .map(|condition| worker.is_valid(condition.clone()))
133 .unwrap_or(true)
134 })
135 }
136}
137
138pub struct ClassifiedWorkerPool<
139 C: WorkerClassification,
140 W: ClassifiedWorker<C>,
141 F: ClassifiedWorkerFactory<C, W>,
142> {
143 factory: Arc<F>,
144 max_count: u16,
145 state: Mutex<WorkerPoolState<C, W, F>>,
146}
147pub type ClassifiedWorkerPoolRef<C, W, F> = Arc<ClassifiedWorkerPool<C, W, F>>;
148
149impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>>
150 ClassifiedWorkerPool<C, W, F>
151{
152 fn validate_created_worker(requested_classification: Option<&C>, worker: &W) -> PoolResult<()> {
153 let worker_classification = worker.classification();
154 if !worker.is_valid(worker_classification.clone()) {
155 return Err(pool_invalid_config_error(
156 "worker primary classification is not valid for itself",
157 ));
158 }
159 if let Some(classification) = requested_classification {
160 if worker_classification != classification.clone() {
161 return Err(pool_invalid_config_error(
162 "factory returned worker with mismatched classification",
163 ));
164 }
165 }
166 Ok(())
167 }
168
169 pub fn new(max_count: u16, factory: F) -> ClassifiedWorkerPoolRef<C, W, F> {
170 Arc::new(ClassifiedWorkerPool {
171 factory: Arc::new(factory),
172 max_count,
173 state: Mutex::new(WorkerPoolState {
174 current_count: 0,
175 classified_count_map: HashMap::new(),
176 worker_list: Vec::with_capacity(max_count as usize),
177 waiting_list: Vec::new(),
178 clearing: false,
179 clear_waiting_list: Vec::new(),
180 }),
181 })
182 }
183
184 pub async fn get_worker(
185 self: &ClassifiedWorkerPoolRef<C, W, F>,
186 ) -> PoolResult<ClassifiedWorkerGuard<C, W, F>> {
187 if self.max_count == 0 {
188 return Err(pool_invalid_config_error("pool max_count is zero"));
189 }
190
191 let wait = {
192 let mut state = self.state.lock().unwrap();
193 if state.clearing {
194 return Err(pool_clearing_error());
195 }
196
197 while state.worker_list.len() > 0 {
198 let worker = state.worker_list.pop().unwrap();
199 if !worker.is_work() {
200 state.current_count -= 1;
201 state.dec_classified_count(worker.classification());
202 continue;
203 }
204 return Ok(ClassifiedWorkerGuard::new(worker, self.clone()));
205 }
206
207 if state.current_count < self.max_count {
208 state.current_count += 1;
209 None
210 } else {
211 let (notify, waiter) = Notify::new();
212 state.waiting_list.push(WaitingItem {
213 future: notify,
214 condition: None,
215 });
216 Some(waiter)
217 }
218 };
219
220 if let Some(wait) = wait {
221 wait.await
222 } else {
223 let worker = match self.factory.create(None).await {
224 Ok(worker) => {
225 if let Err(err) = Self::validate_created_worker(None, &worker) {
226 let mut state = self.state.lock().unwrap();
227 state.current_count -= 1;
228 let clear_waiters = state.take_clear_waiters_if_done();
229 drop(state);
230 for waiter in clear_waiters {
231 waiter.notify(());
232 }
233 return Err(err);
234 }
235 worker
236 }
237 Err(err) => {
238 let mut state = self.state.lock().unwrap();
239 state.current_count -= 1;
240 let clear_waiters = state.take_clear_waiters_if_done();
241 drop(state);
242 for waiter in clear_waiters {
243 waiter.notify(());
244 }
245 return Err(err);
246 }
247 };
248 let (clearing, clear_waiters) = {
249 let mut state = self.state.lock().unwrap();
250 if state.clearing {
251 state.current_count -= 1;
252 (true, state.take_clear_waiters_if_done())
253 } else {
254 state.inc_classified_count(worker.classification());
255 (false, Vec::new())
256 }
257 };
258 for waiter in clear_waiters {
259 waiter.notify(());
260 }
261 if clearing {
262 return Err(pool_cleared_error());
263 }
264 Ok(ClassifiedWorkerGuard::new(worker, self.clone()))
265 }
266 }
267
268 pub async fn get_classified_worker(
269 self: &ClassifiedWorkerPoolRef<C, W, F>,
270 classification: C,
271 ) -> PoolResult<ClassifiedWorkerGuard<C, W, F>> {
272 if self.max_count == 0 {
273 return Err(pool_invalid_config_error("pool max_count is zero"));
274 }
275
276 let wait = {
277 let mut state = self.state.lock().unwrap();
278 if state.clearing {
279 return Err(pool_clearing_error());
280 }
281
282 let old_count = state.worker_list.len() as u16;
283 let unwork_classification = state
284 .worker_list
285 .iter()
286 .filter(|worker| !worker.is_work())
287 .map(|worker| worker.classification())
288 .collect::<Vec<C>>();
289 for classification in unwork_classification.iter() {
290 state.dec_classified_count(classification.clone());
291 }
292 state.worker_list.retain(|worker| worker.is_work());
293 state.current_count -= old_count - state.worker_list.len() as u16;
294 for (index, worker) in state.worker_list.iter().enumerate() {
295 if worker.is_valid(classification.clone()) {
296 let worker = state.worker_list.remove(index);
297 return Ok(ClassifiedWorkerGuard::new(worker, self.clone()));
298 }
299 }
300
301 if state.current_count < self.max_count {
302 state.current_count += 1;
303 None
304 } else {
305 let (notify, waiter) = Notify::new();
306 state.waiting_list.push(WaitingItem {
307 future: notify,
308 condition: Some(classification.clone()),
309 });
310 Some(waiter)
311 }
312 };
313
314 if let Some(wait) = wait {
315 wait.await
316 } else {
317 let worker = match self.factory.create(Some(classification.clone())).await {
318 Ok(worker) => {
319 if let Err(err) = Self::validate_created_worker(Some(&classification), &worker)
320 {
321 let mut state = self.state.lock().unwrap();
322 state.current_count -= 1;
323 let clear_waiters = state.take_clear_waiters_if_done();
324 drop(state);
325 for waiter in clear_waiters {
326 waiter.notify(());
327 }
328 return Err(err);
329 }
330 worker
331 }
332 Err(err) => {
333 let mut state = self.state.lock().unwrap();
334 state.current_count -= 1;
335 let clear_waiters = state.take_clear_waiters_if_done();
336 drop(state);
337 for waiter in clear_waiters {
338 waiter.notify(());
339 }
340 return Err(err);
341 }
342 };
343 let (clearing, clear_waiters) = {
344 let mut state = self.state.lock().unwrap();
345 if state.clearing {
346 state.current_count -= 1;
347 (true, state.take_clear_waiters_if_done())
348 } else {
349 state.inc_classified_count(worker.classification());
350 (false, Vec::new())
351 }
352 };
353 for waiter in clear_waiters {
354 waiter.notify(());
355 }
356 if clearing {
357 return Err(pool_cleared_error());
358 }
359 Ok(ClassifiedWorkerGuard::new(worker, self.clone()))
360 }
361 }
362
363 pub async fn clear_all_worker(&self) {
364 let (waiter, waiting_list, clear_waiters) = {
365 let mut state = self.state.lock().unwrap();
366 if !state.clearing {
367 state.clearing = true;
368 let idle_classifications = state
369 .worker_list
370 .iter()
371 .map(|worker| worker.classification())
372 .collect::<Vec<_>>();
373 let cur_worker_count = idle_classifications.len();
374 state.worker_list.clear();
375 state.current_count -= cur_worker_count as u16;
376 for classification in idle_classifications {
377 state.dec_classified_count(classification);
378 }
379 }
380
381 let waiting_list = state.waiting_list.drain(..).collect::<Vec<_>>();
382 if state.current_count == 0 {
383 let clear_waiters = state.take_clear_waiters_if_done();
384 (None, waiting_list, clear_waiters)
385 } else {
386 let (notify, waiter) = Notify::new();
387 state.clear_waiting_list.push(notify);
388 (Some(waiter), waiting_list, Vec::new())
389 }
390 };
391 for waiting in waiting_list {
392 waiting.future.notify(Err(pool_cleared_error()));
393 }
394 for waiter in clear_waiters {
395 waiter.notify(());
396 }
397 if let Some(waiter) = waiter {
398 waiter.await;
399 }
400 }
401
402 fn release(self: &ClassifiedWorkerPoolRef<C, W, F>, work: W) {
403 enum ReleaseAction<
404 C: WorkerClassification,
405 W: ClassifiedWorker<C>,
406 F: ClassifiedWorkerFactory<C, W>,
407 > {
408 None,
409 Notify(
410 Notify<PoolResult<ClassifiedWorkerGuard<C, W, F>>>,
411 ClassifiedWorkerGuard<C, W, F>,
412 ),
413 Replace(
414 Notify<PoolResult<ClassifiedWorkerGuard<C, W, F>>>,
415 Option<C>,
416 ),
417 }
418
419 let mut clear_waiters = Vec::new();
420 let action = {
421 let mut state = self.state.lock().unwrap();
422 if state.clearing {
423 state.current_count -= 1;
424 let classification = work.classification();
425 state.dec_classified_count(classification);
426 clear_waiters = state.take_clear_waiters_if_done();
427 ReleaseAction::None
428 } else if work.is_work() {
429 if let Some(index) = state.find_matching_waiter_index_for_worker(&work) {
430 let waiting_item = state.waiting_list.remove(index);
431 ReleaseAction::Notify(
432 waiting_item.future,
433 ClassifiedWorkerGuard::new(work, self.clone()),
434 )
435 } else {
436 state.worker_list.push(work);
437 ReleaseAction::None
438 }
439 } else {
440 let classification = work.classification();
441 state.dec_classified_count(classification.clone());
442 if let Some(index) = state.find_matching_waiter_index_for_worker(&work) {
443 let waiting_item = state.waiting_list.remove(index);
444 let request_classification =
445 waiting_item.condition.clone().or(Some(classification));
446 ReleaseAction::Replace(waiting_item.future, request_classification)
447 } else {
448 state.current_count -= 1;
449 clear_waiters = state.take_clear_waiters_if_done();
450 ReleaseAction::None
451 }
452 }
453 };
454
455 for waiter in clear_waiters {
456 waiter.notify(());
457 }
458
459 match action {
460 ReleaseAction::None => {}
461 ReleaseAction::Notify(waiting, worker) => {
462 waiting.notify(Ok(worker));
463 }
464 ReleaseAction::Replace(waiting, request_classification) => {
465 let factory = self.factory.clone();
466 let this = self.clone();
467 tokio::spawn(async move {
468 let result = match factory.create(request_classification.clone()).await {
469 Ok(worker) => {
470 if let Err(err) = Self::validate_created_worker(
471 request_classification.as_ref(),
472 &worker,
473 ) {
474 let mut state = this.state.lock().unwrap();
475 state.current_count -= 1;
476 let clear_waiters = state.take_clear_waiters_if_done();
477 drop(state);
478 for waiter in clear_waiters {
479 waiter.notify(());
480 }
481 waiting.notify(Err(err));
482 return;
483 }
484 let mut state = this.state.lock().unwrap();
485 if state.clearing {
486 state.current_count -= 1;
487 let clear_waiters = state.take_clear_waiters_if_done();
488 drop(state);
489 for waiter in clear_waiters {
490 waiter.notify(());
491 }
492 Err(pool_cleared_error())
493 } else {
494 state.inc_classified_count(worker.classification());
495 drop(state);
496 Ok(ClassifiedWorkerGuard::new(worker, this))
497 }
498 }
499 Err(err) => {
500 let mut state = this.state.lock().unwrap();
501 state.current_count -= 1;
502 let clear_waiters = state.take_clear_waiters_if_done();
503 drop(state);
504 for waiter in clear_waiters {
505 waiter.notify(());
506 }
507 Err(err)
508 }
509 };
510 waiting.notify(result);
511 });
512 }
513 }
514 }
515}
516
517#[tokio::test]
518async fn test_pool() {
519 struct TestWorker {
520 work: bool,
521 classification: TestWorkerClassification,
522 }
523
524 #[derive(Clone, Debug, Eq, PartialEq, Hash)]
525 enum TestWorkerClassification {
526 A,
527 B,
528 }
529 #[async_trait::async_trait]
530 impl ClassifiedWorker<TestWorkerClassification> for TestWorker {
531 fn is_work(&self) -> bool {
532 self.work
533 }
534
535 fn is_valid(&self, c: TestWorkerClassification) -> bool {
536 self.classification == c
537 }
538
539 fn classification(&self) -> TestWorkerClassification {
540 self.classification.clone()
541 }
542 }
543
544 struct TestWorkerFactory;
545
546 #[async_trait::async_trait]
547 impl ClassifiedWorkerFactory<TestWorkerClassification, TestWorker> for TestWorkerFactory {
548 async fn create(
549 &self,
550 classification: Option<TestWorkerClassification>,
551 ) -> PoolResult<TestWorker> {
552 if let Some(classification) = classification {
553 Ok(TestWorker {
554 work: true,
555 classification,
556 })
557 } else {
558 Ok(TestWorker {
559 work: true,
560 classification: TestWorkerClassification::A,
561 })
562 }
563 }
564 }
565
566 let pool = ClassifiedWorkerPool::new(3, TestWorkerFactory);
567 let pool_ref = pool.clone();
568 tokio::spawn(async move {
569 let _worker = pool_ref.get_worker().await.unwrap();
570 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
571 });
572 let pool_ref = pool.clone();
573 tokio::spawn(async move {
574 let _worker = pool_ref.get_worker().await.unwrap();
575 tokio::time::sleep(std::time::Duration::from_secs(10)).await;
576 });
577
578 let pool_ref = pool.clone();
579 tokio::spawn(async move {
580 let _worker = pool_ref
581 .get_classified_worker(TestWorkerClassification::B)
582 .await
583 .unwrap();
584 tokio::time::sleep(std::time::Duration::from_secs(6)).await;
585 });
586
587 let pool_ref = pool.clone();
588 tokio::spawn(async move {
589 tokio::time::sleep(std::time::Duration::from_secs(2)).await;
590
591 let start = std::time::Instant::now();
592 let _worker3 = pool_ref
593 .get_classified_worker(TestWorkerClassification::B)
594 .await
595 .unwrap();
596 let end = std::time::Instant::now();
597 let duration = end.duration_since(start);
598 println!("classified duration {}", duration.as_millis());
599 assert!(duration.as_millis() > 2000);
600 });
601
602 let pool_ref = pool.clone();
603 tokio::spawn(async move {
604 tokio::time::sleep(std::time::Duration::from_secs(2)).await;
605
606 let start = std::time::Instant::now();
607 let _worker3 = pool_ref.get_worker().await.unwrap();
608 let end = std::time::Instant::now();
609 let duration = end.duration_since(start);
610 println!("classified duration2 {}", duration.as_millis());
611 assert!(duration.as_millis() > 2000);
612 });
613
614 tokio::time::sleep(std::time::Duration::from_secs(15)).await;
615
616 let pool_ref = pool.clone();
617 tokio::spawn(async move {
618 let _worker = pool_ref.get_worker().await;
619 let _worker1 = pool_ref.get_worker().await;
620 let _worker2 = pool_ref.get_worker().await;
621 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
622 });
623
624 let pool_ref = pool.clone();
625 tokio::spawn(async move {
626 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
627 let worker = pool_ref.get_worker().await;
628 assert!(worker.is_err());
629 });
630
631 let pool_ref = pool.clone();
632 tokio::spawn(async move {
633 tokio::time::sleep(std::time::Duration::from_secs(2)).await;
634 let worker = pool_ref
635 .get_classified_worker(TestWorkerClassification::B)
636 .await;
637 assert!(worker.is_err());
638 });
639
640 let pool_ref = pool.clone();
641 tokio::spawn(async move {
642 let start = std::time::Instant::now();
643 pool_ref.clear_all_worker().await;
644 let end = std::time::Instant::now();
645 let duration = end.duration_since(start);
646 println!("classified duration3 {}", duration.as_millis());
647 assert!(duration.as_millis() > 4000);
648 });
649
650 tokio::time::sleep(std::time::Duration::from_secs(10)).await;
651}
652
653#[tokio::test]
654async fn test_clear_all_worker_waits_for_inflight_create() {
655 use std::sync::atomic::{AtomicUsize, Ordering};
656 use std::sync::Arc;
657
658 #[derive(Clone, Debug, Eq, PartialEq, Hash)]
659 enum TestWorkerClassification {
660 A,
661 }
662
663 struct TestWorker {
664 classification: TestWorkerClassification,
665 }
666
667 #[async_trait::async_trait]
668 impl ClassifiedWorker<TestWorkerClassification> for TestWorker {
669 fn is_work(&self) -> bool {
670 true
671 }
672
673 fn is_valid(&self, c: TestWorkerClassification) -> bool {
674 self.classification == c
675 }
676
677 fn classification(&self) -> TestWorkerClassification {
678 self.classification.clone()
679 }
680 }
681
682 struct TestWorkerFactory {
683 create_count: Arc<AtomicUsize>,
684 }
685
686 #[async_trait::async_trait]
687 impl ClassifiedWorkerFactory<TestWorkerClassification, TestWorker> for TestWorkerFactory {
688 async fn create(
689 &self,
690 classification: Option<TestWorkerClassification>,
691 ) -> PoolResult<TestWorker> {
692 self.create_count.fetch_add(1, Ordering::SeqCst);
693 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
694 Ok(TestWorker {
695 classification: classification.unwrap_or(TestWorkerClassification::A),
696 })
697 }
698 }
699
700 let create_count = Arc::new(AtomicUsize::new(0));
701 let pool = ClassifiedWorkerPool::new(
702 1,
703 TestWorkerFactory {
704 create_count: create_count.clone(),
705 },
706 );
707
708 let pool_ref = pool.clone();
709 let worker_task = tokio::spawn(async move { pool_ref.get_worker().await });
710 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
711
712 pool.clear_all_worker().await;
713
714 let worker = worker_task.await.unwrap();
715 assert!(worker.is_err());
716 assert_eq!(create_count.load(Ordering::SeqCst), 1);
717}
718
719#[tokio::test]
720async fn test_concurrent_clear_all_worker() {
721 #[derive(Clone, Debug, Eq, PartialEq, Hash)]
722 enum TestWorkerClassification {
723 A,
724 }
725
726 struct TestWorker {
727 classification: TestWorkerClassification,
728 }
729
730 #[async_trait::async_trait]
731 impl ClassifiedWorker<TestWorkerClassification> for TestWorker {
732 fn is_work(&self) -> bool {
733 true
734 }
735
736 fn is_valid(&self, c: TestWorkerClassification) -> bool {
737 self.classification == c
738 }
739
740 fn classification(&self) -> TestWorkerClassification {
741 self.classification.clone()
742 }
743 }
744
745 struct TestWorkerFactory;
746
747 #[async_trait::async_trait]
748 impl ClassifiedWorkerFactory<TestWorkerClassification, TestWorker> for TestWorkerFactory {
749 async fn create(
750 &self,
751 classification: Option<TestWorkerClassification>,
752 ) -> PoolResult<TestWorker> {
753 Ok(TestWorker {
754 classification: classification.unwrap_or(TestWorkerClassification::A),
755 })
756 }
757 }
758
759 let pool = ClassifiedWorkerPool::new(1, TestWorkerFactory);
760 let worker = pool.get_worker().await.unwrap();
761
762 let pool_ref = pool.clone();
763 let clear_task1 = tokio::spawn(async move {
764 pool_ref.clear_all_worker().await;
765 });
766
767 let pool_ref = pool.clone();
768 let clear_task2 = tokio::spawn(async move {
769 pool_ref.clear_all_worker().await;
770 });
771
772 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
773 drop(worker);
774
775 tokio::time::timeout(std::time::Duration::from_secs(1), async {
776 clear_task1.await.unwrap();
777 clear_task2.await.unwrap();
778 })
779 .await
780 .unwrap();
781}
782
783#[tokio::test]
784async fn test_zero_max_count_returns_error() {
785 #[derive(Clone, Debug, Eq, PartialEq, Hash)]
786 enum TestWorkerClassification {
787 A,
788 }
789
790 struct TestWorker {
791 classification: TestWorkerClassification,
792 }
793
794 #[async_trait::async_trait]
795 impl ClassifiedWorker<TestWorkerClassification> for TestWorker {
796 fn is_work(&self) -> bool {
797 true
798 }
799
800 fn is_valid(&self, c: TestWorkerClassification) -> bool {
801 self.classification == c
802 }
803
804 fn classification(&self) -> TestWorkerClassification {
805 self.classification.clone()
806 }
807 }
808
809 struct TestWorkerFactory;
810
811 #[async_trait::async_trait]
812 impl ClassifiedWorkerFactory<TestWorkerClassification, TestWorker> for TestWorkerFactory {
813 async fn create(
814 &self,
815 classification: Option<TestWorkerClassification>,
816 ) -> PoolResult<TestWorker> {
817 Ok(TestWorker {
818 classification: classification.unwrap_or(TestWorkerClassification::A),
819 })
820 }
821 }
822
823 let pool = ClassifiedWorkerPool::new(0, TestWorkerFactory);
824 let worker = pool.get_worker().await;
825 assert!(worker.is_err());
826 assert_eq!(
827 worker.err().unwrap().code(),
828 crate::PoolErrorCode::InvalidConfig
829 );
830}
831
832#[tokio::test]
833async fn test_classified_pool_respects_max_count() {
834 #[derive(Clone, Debug, Eq, PartialEq, Hash)]
835 enum TestWorkerClassification {
836 A,
837 B,
838 }
839
840 struct TestWorker {
841 classification: TestWorkerClassification,
842 }
843
844 #[async_trait::async_trait]
845 impl ClassifiedWorker<TestWorkerClassification> for TestWorker {
846 fn is_work(&self) -> bool {
847 true
848 }
849
850 fn is_valid(&self, c: TestWorkerClassification) -> bool {
851 self.classification == c
852 }
853
854 fn classification(&self) -> TestWorkerClassification {
855 self.classification.clone()
856 }
857 }
858
859 struct TestWorkerFactory;
860
861 #[async_trait::async_trait]
862 impl ClassifiedWorkerFactory<TestWorkerClassification, TestWorker> for TestWorkerFactory {
863 async fn create(
864 &self,
865 classification: Option<TestWorkerClassification>,
866 ) -> PoolResult<TestWorker> {
867 Ok(TestWorker {
868 classification: classification.unwrap_or(TestWorkerClassification::A),
869 })
870 }
871 }
872
873 let pool = ClassifiedWorkerPool::new(1, TestWorkerFactory);
874 let _worker = pool.get_worker().await.unwrap();
875
876 let pool_ref = pool.clone();
877 let result = tokio::time::timeout(std::time::Duration::from_millis(100), async move {
878 pool_ref
879 .get_classified_worker(TestWorkerClassification::B)
880 .await
881 })
882 .await;
883
884 assert!(result.is_err());
885}
886
887#[tokio::test]
888async fn test_factory_must_return_matching_classification() {
889 use std::sync::atomic::{AtomicUsize, Ordering};
890 use std::sync::Arc;
891
892 #[derive(Clone, Debug, Eq, PartialEq, Hash)]
893 enum TestWorkerClassification {
894 A,
895 B,
896 }
897
898 struct TestWorker {
899 classification: TestWorkerClassification,
900 }
901
902 #[async_trait::async_trait]
903 impl ClassifiedWorker<TestWorkerClassification> for TestWorker {
904 fn is_work(&self) -> bool {
905 true
906 }
907
908 fn is_valid(&self, c: TestWorkerClassification) -> bool {
909 self.classification == c
910 }
911
912 fn classification(&self) -> TestWorkerClassification {
913 self.classification.clone()
914 }
915 }
916
917 struct TestWorkerFactory {
918 create_count: Arc<AtomicUsize>,
919 }
920
921 #[async_trait::async_trait]
922 impl ClassifiedWorkerFactory<TestWorkerClassification, TestWorker> for TestWorkerFactory {
923 async fn create(
924 &self,
925 classification: Option<TestWorkerClassification>,
926 ) -> PoolResult<TestWorker> {
927 let count = self.create_count.fetch_add(1, Ordering::SeqCst);
928 let classification = if count == 0 {
929 TestWorkerClassification::A
930 } else {
931 classification.unwrap_or(TestWorkerClassification::A)
932 };
933 Ok(TestWorker { classification })
934 }
935 }
936
937 let create_count = Arc::new(AtomicUsize::new(0));
938 let pool = ClassifiedWorkerPool::new(
939 1,
940 TestWorkerFactory {
941 create_count: create_count.clone(),
942 },
943 );
944 let worker = pool
945 .get_classified_worker(TestWorkerClassification::B)
946 .await;
947 assert!(worker.is_err());
948 assert_eq!(
949 worker.err().unwrap().code(),
950 crate::PoolErrorCode::InvalidConfig
951 );
952
953 let worker = pool
954 .get_classified_worker(TestWorkerClassification::B)
955 .await;
956 assert!(worker.is_ok());
957 assert_eq!(create_count.load(Ordering::SeqCst), 2);
958}
959
960#[tokio::test(flavor = "multi_thread")]
961async fn test_classified_waiter_keeps_queue_priority_over_later_generic_waiter() {
962 use std::sync::mpsc;
963
964 #[derive(Clone, Debug, Eq, PartialEq, Hash)]
965 enum TestWorkerClassification {
966 B,
967 }
968
969 struct TestWorker {
970 classification: TestWorkerClassification,
971 }
972
973 #[async_trait::async_trait]
974 impl ClassifiedWorker<TestWorkerClassification> for TestWorker {
975 fn is_work(&self) -> bool {
976 true
977 }
978
979 fn is_valid(&self, c: TestWorkerClassification) -> bool {
980 self.classification == c
981 }
982
983 fn classification(&self) -> TestWorkerClassification {
984 self.classification.clone()
985 }
986 }
987
988 struct TestWorkerFactory;
989
990 #[async_trait::async_trait]
991 impl ClassifiedWorkerFactory<TestWorkerClassification, TestWorker> for TestWorkerFactory {
992 async fn create(
993 &self,
994 classification: Option<TestWorkerClassification>,
995 ) -> PoolResult<TestWorker> {
996 Ok(TestWorker {
997 classification: classification.unwrap_or(TestWorkerClassification::B),
998 })
999 }
1000 }
1001
1002 let pool = ClassifiedWorkerPool::new(1, TestWorkerFactory);
1003 let worker = pool
1004 .get_classified_worker(TestWorkerClassification::B)
1005 .await
1006 .unwrap();
1007
1008 let (tx, rx) = mpsc::channel();
1009
1010 let pool_ref = pool.clone();
1011 let tx_classified = tx.clone();
1012 let classified_task = tokio::spawn(async move {
1013 let _worker = pool_ref
1014 .get_classified_worker(TestWorkerClassification::B)
1015 .await
1016 .unwrap();
1017 tx_classified.send("classified").unwrap();
1018 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1019 });
1020
1021 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1022
1023 let pool_ref = pool.clone();
1024 let generic_task = tokio::spawn(async move {
1025 let _worker = pool_ref.get_worker().await.unwrap();
1026 tx.send("generic").unwrap();
1027 });
1028
1029 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1030 drop(worker);
1031
1032 let first = rx.recv_timeout(std::time::Duration::from_secs(2)).unwrap();
1033 assert_eq!(first, "classified");
1034
1035 classified_task.await.unwrap();
1036 generic_task.await.unwrap();
1037}
1038
1039#[tokio::test]
1040async fn test_generic_factory_worker_must_be_valid_for_its_primary_classification() {
1041 #[derive(Clone, Debug, Eq, PartialEq, Hash)]
1042 enum TestWorkerClassification {
1043 A,
1044 B,
1045 }
1046
1047 struct TestWorker {
1048 classification: TestWorkerClassification,
1049 }
1050
1051 #[async_trait::async_trait]
1052 impl ClassifiedWorker<TestWorkerClassification> for TestWorker {
1053 fn is_work(&self) -> bool {
1054 true
1055 }
1056
1057 fn is_valid(&self, c: TestWorkerClassification) -> bool {
1058 c == TestWorkerClassification::B
1059 }
1060
1061 fn classification(&self) -> TestWorkerClassification {
1062 self.classification.clone()
1063 }
1064 }
1065
1066 struct TestWorkerFactory;
1067
1068 #[async_trait::async_trait]
1069 impl ClassifiedWorkerFactory<TestWorkerClassification, TestWorker> for TestWorkerFactory {
1070 async fn create(
1071 &self,
1072 _classification: Option<TestWorkerClassification>,
1073 ) -> PoolResult<TestWorker> {
1074 Ok(TestWorker {
1075 classification: TestWorkerClassification::A,
1076 })
1077 }
1078 }
1079
1080 let pool = ClassifiedWorkerPool::new(1, TestWorkerFactory);
1081 let worker = pool.get_worker().await;
1082 assert!(worker.is_err());
1083 assert_eq!(
1084 worker.err().unwrap().code(),
1085 crate::PoolErrorCode::InvalidConfig
1086 );
1087}