Skip to main content

sfo_pool/
classified_worker_pool.rs

1use crate::{pool_cleared_error, pool_clearing_error, pool_invalid_config_error, PoolResult};
2use notify_future::Notify;
3use std::collections::HashMap;
4use std::hash::Hash;
5use std::ops::{Deref, DerefMut};
6use std::sync::{Arc, Mutex};
7
/// Marker for types that can label a worker's classification.
///
/// A classification must be clonable, hashable, comparable for equality and sendable across
/// threads. Every type satisfying those bounds implements this trait automatically.
pub trait WorkerClassification: Send + 'static + Clone + Hash + Eq + PartialEq {}

impl<T> WorkerClassification for T where T: Send + 'static + Clone + Hash + Eq + PartialEq {}
11
12#[async_trait::async_trait]
13pub trait ClassifiedWorker<C: WorkerClassification>: Send + 'static {
14    fn is_work(&self) -> bool;
15    /// Returns whether this worker can currently serve the requested classification.
16    /// The pool still tracks capacity by the worker's primary `classification()`.
17    fn is_valid(&self, c: C) -> bool;
18    /// Returns the worker's primary classification used for accounting and replacement.
19    fn classification(&self) -> C;
20}
21
/// Handle to a worker checked out of a [`ClassifiedWorkerPool`].
///
/// Dereferences to the worker. Dropping the guard hands the worker back to the pool.
pub struct ClassifiedWorkerGuard<
    C: WorkerClassification,
    W: ClassifiedWorker<C>,
    F: ClassifiedWorkerFactory<C, W>,
> {
    // Pool the worker is returned to when the guard is dropped.
    pool_ref: ClassifiedWorkerPoolRef<C, W, F>,
    // Always `Some` until `Drop` takes the worker out to release it.
    worker: Option<W>,
}
30
31impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>>
32    ClassifiedWorkerGuard<C, W, F>
33{
34    fn new(worker: W, pool_ref: ClassifiedWorkerPoolRef<C, W, F>) -> Self {
35        ClassifiedWorkerGuard {
36            pool_ref,
37            worker: Some(worker),
38        }
39    }
40}
41
42impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> Deref
43    for ClassifiedWorkerGuard<C, W, F>
44{
45    type Target = W;
46
47    fn deref(&self) -> &Self::Target {
48        self.worker.as_ref().unwrap()
49    }
50}
51
52impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> DerefMut
53    for ClassifiedWorkerGuard<C, W, F>
54{
55    fn deref_mut(&mut self) -> &mut Self::Target {
56        self.worker.as_mut().unwrap()
57    }
58}
59
60impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> Drop
61    for ClassifiedWorkerGuard<C, W, F>
62{
63    fn drop(&mut self) {
64        if let Some(worker) = self.worker.take() {
65            self.pool_ref.release(worker);
66        }
67    }
68}
69
/// Builds new workers for a [`ClassifiedWorkerPool`].
#[async_trait::async_trait]
pub trait ClassifiedWorkerFactory<C: WorkerClassification, W: ClassifiedWorker<C>>:
    Send + Sync + 'static
{
    /// Creates a new worker.
    ///
    /// `c` is the requested classification, or `None` when any classification is acceptable.
    /// The pool rejects the result unless the worker is valid for its own `classification()`.
    /// When `c` is `Some`, the pool also requires that classification to equal `c`.
    async fn create(&self, c: Option<C>) -> PoolResult<W>;
}
76
/// A request that was queued because the pool had no free capacity when it arrived.
struct WaitingItem<
    C: WorkerClassification,
    W: ClassifiedWorker<C>,
    F: ClassifiedWorkerFactory<C, W>,
> {
    // Completed with the worker (or error) eventually handed to this request.
    future: Notify<PoolResult<ClassifiedWorkerGuard<C, W, F>>>,
    // Classification the request needs; `None` accepts any working worker.
    condition: Option<C>,
}
/// Mutable pool bookkeeping, guarded by the pool's mutex.
struct WorkerPoolState<
    C: WorkerClassification,
    W: ClassifiedWorker<C>,
    F: ClassifiedWorkerFactory<C, W>,
> {
    // Occupied slots: idle workers, checked-out workers, and slots reserved for creations in flight.
    current_count: u16,
    // Created, not-yet-discarded workers per primary classification; zero entries are removed.
    classified_count_map: HashMap<C, u16>,
    // Idle workers available for reuse.
    worker_list: Vec<W>,
    // Queued requests, oldest first.
    waiting_list: Vec<WaitingItem<C, W, F>>,
    // True while `clear_all_worker` waits for outstanding workers to come back.
    clearing: bool,
    // Callers of `clear_all_worker` to wake once `current_count` reaches zero.
    clear_waiting_list: Vec<Notify<()>>,
}
97
98impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>>
99    WorkerPoolState<C, W, F>
100{
101    fn inc_classified_count(&mut self, c: C) {
102        let count = self.classified_count_map.entry(c).or_insert(0);
103        *count += 1;
104    }
105
106    fn dec_classified_count(&mut self, c: C) {
107        let mut should_remove = false;
108        if let Some(count) = self.classified_count_map.get_mut(&c) {
109            debug_assert!(*count > 0);
110            *count -= 1;
111            should_remove = *count == 0;
112        }
113        if should_remove {
114            self.classified_count_map.remove(&c);
115        }
116    }
117
118    fn take_clear_waiters_if_done(&mut self) -> Vec<Notify<()>> {
119        if self.clearing && self.current_count == 0 {
120            self.clearing = false;
121            self.clear_waiting_list.drain(..).collect()
122        } else {
123            Vec::new()
124        }
125    }
126
127    fn find_matching_waiter_index_for_worker(&self, worker: &W) -> Option<usize> {
128        self.waiting_list.iter().position(|waiting| {
129            waiting
130                .condition
131                .as_ref()
132                .map(|condition| worker.is_valid(condition.clone()))
133                .unwrap_or(true)
134        })
135    }
136}
137
/// A bounded pool of workers, each tagged with a primary classification.
///
/// At most `max_count` slots are occupied at once across all classifications. Occupied slots
/// include workers still being created. Callers ask either for any worker (`get_worker`) or
/// for one valid for a given classification (`get_classified_worker`).
pub struct ClassifiedWorkerPool<
    C: WorkerClassification,
    W: ClassifiedWorker<C>,
    F: ClassifiedWorkerFactory<C, W>,
> {
    // Builds new workers on demand.
    factory: Arc<F>,
    // Upper bound on occupied slots.
    max_count: u16,
    // All mutable bookkeeping, behind a single lock.
    state: Mutex<WorkerPoolState<C, W, F>>,
}
/// Shared handle to a [`ClassifiedWorkerPool`]; the acquisition methods take this as receiver.
pub type ClassifiedWorkerPoolRef<C, W, F> = Arc<ClassifiedWorkerPool<C, W, F>>;
148
149impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>>
150    ClassifiedWorkerPool<C, W, F>
151{
152    fn validate_created_worker(requested_classification: Option<&C>, worker: &W) -> PoolResult<()> {
153        let worker_classification = worker.classification();
154        if !worker.is_valid(worker_classification.clone()) {
155            return Err(pool_invalid_config_error(
156                "worker primary classification is not valid for itself",
157            ));
158        }
159        if let Some(classification) = requested_classification {
160            if worker_classification != classification.clone() {
161                return Err(pool_invalid_config_error(
162                    "factory returned worker with mismatched classification",
163                ));
164            }
165        }
166        Ok(())
167    }
168
169    pub fn new(max_count: u16, factory: F) -> ClassifiedWorkerPoolRef<C, W, F> {
170        Arc::new(ClassifiedWorkerPool {
171            factory: Arc::new(factory),
172            max_count,
173            state: Mutex::new(WorkerPoolState {
174                current_count: 0,
175                classified_count_map: HashMap::new(),
176                worker_list: Vec::with_capacity(max_count as usize),
177                waiting_list: Vec::new(),
178                clearing: false,
179                clear_waiting_list: Vec::new(),
180            }),
181        })
182    }
183
184    pub async fn get_worker(
185        self: &ClassifiedWorkerPoolRef<C, W, F>,
186    ) -> PoolResult<ClassifiedWorkerGuard<C, W, F>> {
187        if self.max_count == 0 {
188            return Err(pool_invalid_config_error("pool max_count is zero"));
189        }
190
191        let wait = {
192            let mut state = self.state.lock().unwrap();
193            if state.clearing {
194                return Err(pool_clearing_error());
195            }
196
197            while state.worker_list.len() > 0 {
198                let worker = state.worker_list.pop().unwrap();
199                if !worker.is_work() {
200                    state.current_count -= 1;
201                    state.dec_classified_count(worker.classification());
202                    continue;
203                }
204                return Ok(ClassifiedWorkerGuard::new(worker, self.clone()));
205            }
206
207            if state.current_count < self.max_count {
208                state.current_count += 1;
209                None
210            } else {
211                let (notify, waiter) = Notify::new();
212                state.waiting_list.push(WaitingItem {
213                    future: notify,
214                    condition: None,
215                });
216                Some(waiter)
217            }
218        };
219
220        if let Some(wait) = wait {
221            wait.await
222        } else {
223            let worker = match self.factory.create(None).await {
224                Ok(worker) => {
225                    if let Err(err) = Self::validate_created_worker(None, &worker) {
226                        let mut state = self.state.lock().unwrap();
227                        state.current_count -= 1;
228                        let clear_waiters = state.take_clear_waiters_if_done();
229                        drop(state);
230                        for waiter in clear_waiters {
231                            waiter.notify(());
232                        }
233                        return Err(err);
234                    }
235                    worker
236                }
237                Err(err) => {
238                    let mut state = self.state.lock().unwrap();
239                    state.current_count -= 1;
240                    let clear_waiters = state.take_clear_waiters_if_done();
241                    drop(state);
242                    for waiter in clear_waiters {
243                        waiter.notify(());
244                    }
245                    return Err(err);
246                }
247            };
248            let (clearing, clear_waiters) = {
249                let mut state = self.state.lock().unwrap();
250                if state.clearing {
251                    state.current_count -= 1;
252                    (true, state.take_clear_waiters_if_done())
253                } else {
254                    state.inc_classified_count(worker.classification());
255                    (false, Vec::new())
256                }
257            };
258            for waiter in clear_waiters {
259                waiter.notify(());
260            }
261            if clearing {
262                return Err(pool_cleared_error());
263            }
264            Ok(ClassifiedWorkerGuard::new(worker, self.clone()))
265        }
266    }
267
268    pub async fn get_classified_worker(
269        self: &ClassifiedWorkerPoolRef<C, W, F>,
270        classification: C,
271    ) -> PoolResult<ClassifiedWorkerGuard<C, W, F>> {
272        if self.max_count == 0 {
273            return Err(pool_invalid_config_error("pool max_count is zero"));
274        }
275
276        let wait = {
277            let mut state = self.state.lock().unwrap();
278            if state.clearing {
279                return Err(pool_clearing_error());
280            }
281
282            let old_count = state.worker_list.len() as u16;
283            let unwork_classification = state
284                .worker_list
285                .iter()
286                .filter(|worker| !worker.is_work())
287                .map(|worker| worker.classification())
288                .collect::<Vec<C>>();
289            for classification in unwork_classification.iter() {
290                state.dec_classified_count(classification.clone());
291            }
292            state.worker_list.retain(|worker| worker.is_work());
293            state.current_count -= old_count - state.worker_list.len() as u16;
294            for (index, worker) in state.worker_list.iter().enumerate() {
295                if worker.is_valid(classification.clone()) {
296                    let worker = state.worker_list.remove(index);
297                    return Ok(ClassifiedWorkerGuard::new(worker, self.clone()));
298                }
299            }
300
301            if state.current_count < self.max_count {
302                state.current_count += 1;
303                None
304            } else {
305                let (notify, waiter) = Notify::new();
306                state.waiting_list.push(WaitingItem {
307                    future: notify,
308                    condition: Some(classification.clone()),
309                });
310                Some(waiter)
311            }
312        };
313
314        if let Some(wait) = wait {
315            wait.await
316        } else {
317            let worker = match self.factory.create(Some(classification.clone())).await {
318                Ok(worker) => {
319                    if let Err(err) = Self::validate_created_worker(Some(&classification), &worker)
320                    {
321                        let mut state = self.state.lock().unwrap();
322                        state.current_count -= 1;
323                        let clear_waiters = state.take_clear_waiters_if_done();
324                        drop(state);
325                        for waiter in clear_waiters {
326                            waiter.notify(());
327                        }
328                        return Err(err);
329                    }
330                    worker
331                }
332                Err(err) => {
333                    let mut state = self.state.lock().unwrap();
334                    state.current_count -= 1;
335                    let clear_waiters = state.take_clear_waiters_if_done();
336                    drop(state);
337                    for waiter in clear_waiters {
338                        waiter.notify(());
339                    }
340                    return Err(err);
341                }
342            };
343            let (clearing, clear_waiters) = {
344                let mut state = self.state.lock().unwrap();
345                if state.clearing {
346                    state.current_count -= 1;
347                    (true, state.take_clear_waiters_if_done())
348                } else {
349                    state.inc_classified_count(worker.classification());
350                    (false, Vec::new())
351                }
352            };
353            for waiter in clear_waiters {
354                waiter.notify(());
355            }
356            if clearing {
357                return Err(pool_cleared_error());
358            }
359            Ok(ClassifiedWorkerGuard::new(worker, self.clone()))
360        }
361    }
362
363    pub async fn clear_all_worker(&self) {
364        let (waiter, waiting_list, clear_waiters) = {
365            let mut state = self.state.lock().unwrap();
366            if !state.clearing {
367                state.clearing = true;
368                let idle_classifications = state
369                    .worker_list
370                    .iter()
371                    .map(|worker| worker.classification())
372                    .collect::<Vec<_>>();
373                let cur_worker_count = idle_classifications.len();
374                state.worker_list.clear();
375                state.current_count -= cur_worker_count as u16;
376                for classification in idle_classifications {
377                    state.dec_classified_count(classification);
378                }
379            }
380
381            let waiting_list = state.waiting_list.drain(..).collect::<Vec<_>>();
382            if state.current_count == 0 {
383                let clear_waiters = state.take_clear_waiters_if_done();
384                (None, waiting_list, clear_waiters)
385            } else {
386                let (notify, waiter) = Notify::new();
387                state.clear_waiting_list.push(notify);
388                (Some(waiter), waiting_list, Vec::new())
389            }
390        };
391        for waiting in waiting_list {
392            waiting.future.notify(Err(pool_cleared_error()));
393        }
394        for waiter in clear_waiters {
395            waiter.notify(());
396        }
397        if let Some(waiter) = waiter {
398            waiter.await;
399        }
400    }
401
402    fn release(self: &ClassifiedWorkerPoolRef<C, W, F>, work: W) {
403        enum ReleaseAction<
404            C: WorkerClassification,
405            W: ClassifiedWorker<C>,
406            F: ClassifiedWorkerFactory<C, W>,
407        > {
408            None,
409            Notify(
410                Notify<PoolResult<ClassifiedWorkerGuard<C, W, F>>>,
411                ClassifiedWorkerGuard<C, W, F>,
412            ),
413            Replace(
414                Notify<PoolResult<ClassifiedWorkerGuard<C, W, F>>>,
415                Option<C>,
416            ),
417        }
418
419        let mut clear_waiters = Vec::new();
420        let action = {
421            let mut state = self.state.lock().unwrap();
422            if state.clearing {
423                state.current_count -= 1;
424                let classification = work.classification();
425                state.dec_classified_count(classification);
426                clear_waiters = state.take_clear_waiters_if_done();
427                ReleaseAction::None
428            } else if work.is_work() {
429                if let Some(index) = state.find_matching_waiter_index_for_worker(&work) {
430                    let waiting_item = state.waiting_list.remove(index);
431                    ReleaseAction::Notify(
432                        waiting_item.future,
433                        ClassifiedWorkerGuard::new(work, self.clone()),
434                    )
435                } else {
436                    state.worker_list.push(work);
437                    ReleaseAction::None
438                }
439            } else {
440                let classification = work.classification();
441                state.dec_classified_count(classification.clone());
442                if let Some(index) = state.find_matching_waiter_index_for_worker(&work) {
443                    let waiting_item = state.waiting_list.remove(index);
444                    let request_classification =
445                        waiting_item.condition.clone().or(Some(classification));
446                    ReleaseAction::Replace(waiting_item.future, request_classification)
447                } else {
448                    state.current_count -= 1;
449                    clear_waiters = state.take_clear_waiters_if_done();
450                    ReleaseAction::None
451                }
452            }
453        };
454
455        for waiter in clear_waiters {
456            waiter.notify(());
457        }
458
459        match action {
460            ReleaseAction::None => {}
461            ReleaseAction::Notify(waiting, worker) => {
462                waiting.notify(Ok(worker));
463            }
464            ReleaseAction::Replace(waiting, request_classification) => {
465                let factory = self.factory.clone();
466                let this = self.clone();
467                tokio::spawn(async move {
468                    let result = match factory.create(request_classification.clone()).await {
469                        Ok(worker) => {
470                            if let Err(err) = Self::validate_created_worker(
471                                request_classification.as_ref(),
472                                &worker,
473                            ) {
474                                let mut state = this.state.lock().unwrap();
475                                state.current_count -= 1;
476                                let clear_waiters = state.take_clear_waiters_if_done();
477                                drop(state);
478                                for waiter in clear_waiters {
479                                    waiter.notify(());
480                                }
481                                waiting.notify(Err(err));
482                                return;
483                            }
484                            let mut state = this.state.lock().unwrap();
485                            if state.clearing {
486                                state.current_count -= 1;
487                                let clear_waiters = state.take_clear_waiters_if_done();
488                                drop(state);
489                                for waiter in clear_waiters {
490                                    waiter.notify(());
491                                }
492                                Err(pool_cleared_error())
493                            } else {
494                                state.inc_classified_count(worker.classification());
495                                drop(state);
496                                Ok(ClassifiedWorkerGuard::new(worker, this))
497                            }
498                        }
499                        Err(err) => {
500                            let mut state = this.state.lock().unwrap();
501                            state.current_count -= 1;
502                            let clear_waiters = state.take_clear_waiters_if_done();
503                            drop(state);
504                            for waiter in clear_waiters {
505                                waiter.notify(());
506                            }
507                            Err(err)
508                        }
509                    };
510                    waiting.notify(result);
511                });
512            }
513        }
514    }
515}
516
/// End-to-end scenario on a pool of 3; times are seconds from the start of the test.
///
/// * t=0: two unclassified requests (held 5s and 10s) and one `B` request (held 6s) fill
///   the pool.
/// * t=2: a `B` request and an unclassified request must queue. Each asserts it waited > 2s;
///   they are expected to be served at about t=6 and t=5 respectively.
/// * t=15: one task checks out the three idle workers and holds them 5s while
///   `clear_all_worker` starts. Requests made at t=16 and t=17 must fail because the pool is
///   clearing, and the clear must take > 4s.
///
/// NOTE(review): every assertion runs inside a `tokio::spawn`ed task and no `JoinHandle` is
/// awaited. With tokio's default settings a panic there is captured by the task rather than
/// failing this test, so the assertions are effectively informational.
#[tokio::test]
async fn test_pool() {
    struct TestWorker {
        work: bool,
        classification: TestWorkerClassification,
    }

    #[derive(Clone, Debug, Eq, PartialEq, Hash)]
    enum TestWorkerClassification {
        A,
        B,
    }
    #[async_trait::async_trait]
    impl ClassifiedWorker<TestWorkerClassification> for TestWorker {
        fn is_work(&self) -> bool {
            self.work
        }

        fn is_valid(&self, c: TestWorkerClassification) -> bool {
            self.classification == c
        }

        fn classification(&self) -> TestWorkerClassification {
            self.classification.clone()
        }
    }

    // Builds working workers; unclassified requests get class `A`.
    struct TestWorkerFactory;

    #[async_trait::async_trait]
    impl ClassifiedWorkerFactory<TestWorkerClassification, TestWorker> for TestWorkerFactory {
        async fn create(
            &self,
            classification: Option<TestWorkerClassification>,
        ) -> PoolResult<TestWorker> {
            if let Some(classification) = classification {
                Ok(TestWorker {
                    work: true,
                    classification,
                })
            } else {
                Ok(TestWorker {
                    work: true,
                    classification: TestWorkerClassification::A,
                })
            }
        }
    }

    // Phase 1 (t=0): fill all three slots.
    let pool = ClassifiedWorkerPool::new(3, TestWorkerFactory);
    let pool_ref = pool.clone();
    tokio::spawn(async move {
        let _worker = pool_ref.get_worker().await.unwrap();
        tokio::time::sleep(std::time::Duration::from_secs(5)).await;
    });
    let pool_ref = pool.clone();
    tokio::spawn(async move {
        let _worker = pool_ref.get_worker().await.unwrap();
        tokio::time::sleep(std::time::Duration::from_secs(10)).await;
    });

    let pool_ref = pool.clone();
    tokio::spawn(async move {
        let _worker = pool_ref
            .get_classified_worker(TestWorkerClassification::B)
            .await
            .unwrap();
        tokio::time::sleep(std::time::Duration::from_secs(6)).await;
    });

    // Phase 2 (t=2): the pool is full, so both of these requests must wait for a release.
    let pool_ref = pool.clone();
    tokio::spawn(async move {
        tokio::time::sleep(std::time::Duration::from_secs(2)).await;

        let start = std::time::Instant::now();
        let _worker3 = pool_ref
            .get_classified_worker(TestWorkerClassification::B)
            .await
            .unwrap();
        let end = std::time::Instant::now();
        let duration = end.duration_since(start);
        println!("classified duration {}", duration.as_millis());
        assert!(duration.as_millis() > 2000);
    });

    let pool_ref = pool.clone();
    tokio::spawn(async move {
        tokio::time::sleep(std::time::Duration::from_secs(2)).await;

        let start = std::time::Instant::now();
        let _worker3 = pool_ref.get_worker().await.unwrap();
        let end = std::time::Instant::now();
        let duration = end.duration_since(start);
        println!("classified duration2 {}", duration.as_millis());
        assert!(duration.as_millis() > 2000);
    });

    tokio::time::sleep(std::time::Duration::from_secs(15)).await;

    // Phase 3 (t=15): hold every slot for 5s. This task is spawned before the clearing task,
    // so on a single-threaded runtime it checks the workers out before the clear begins.
    let pool_ref = pool.clone();
    tokio::spawn(async move {
        let _worker = pool_ref.get_worker().await;
        let _worker1 = pool_ref.get_worker().await;
        let _worker2 = pool_ref.get_worker().await;
        tokio::time::sleep(std::time::Duration::from_secs(5)).await;
    });

    // Requests made while the pool is clearing must fail.
    let pool_ref = pool.clone();
    tokio::spawn(async move {
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        let worker = pool_ref.get_worker().await;
        assert!(worker.is_err());
    });

    let pool_ref = pool.clone();
    tokio::spawn(async move {
        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
        let worker = pool_ref
            .get_classified_worker(TestWorkerClassification::B)
            .await;
        assert!(worker.is_err());
    });

    // The clear can only finish once the three held workers are dropped (about 5s later).
    let pool_ref = pool.clone();
    tokio::spawn(async move {
        let start = std::time::Instant::now();
        pool_ref.clear_all_worker().await;
        let end = std::time::Instant::now();
        let duration = end.duration_since(start);
        println!("classified duration3 {}", duration.as_millis());
        assert!(duration.as_millis() > 4000);
    });

    // Keep the runtime alive long enough for phase 3 to complete.
    tokio::time::sleep(std::time::Duration::from_secs(10)).await;
}
652
653#[tokio::test]
654async fn test_clear_all_worker_waits_for_inflight_create() {
655    use std::sync::atomic::{AtomicUsize, Ordering};
656    use std::sync::Arc;
657
658    #[derive(Clone, Debug, Eq, PartialEq, Hash)]
659    enum TestWorkerClassification {
660        A,
661    }
662
663    struct TestWorker {
664        classification: TestWorkerClassification,
665    }
666
667    #[async_trait::async_trait]
668    impl ClassifiedWorker<TestWorkerClassification> for TestWorker {
669        fn is_work(&self) -> bool {
670            true
671        }
672
673        fn is_valid(&self, c: TestWorkerClassification) -> bool {
674            self.classification == c
675        }
676
677        fn classification(&self) -> TestWorkerClassification {
678            self.classification.clone()
679        }
680    }
681
682    struct TestWorkerFactory {
683        create_count: Arc<AtomicUsize>,
684    }
685
686    #[async_trait::async_trait]
687    impl ClassifiedWorkerFactory<TestWorkerClassification, TestWorker> for TestWorkerFactory {
688        async fn create(
689            &self,
690            classification: Option<TestWorkerClassification>,
691        ) -> PoolResult<TestWorker> {
692            self.create_count.fetch_add(1, Ordering::SeqCst);
693            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
694            Ok(TestWorker {
695                classification: classification.unwrap_or(TestWorkerClassification::A),
696            })
697        }
698    }
699
700    let create_count = Arc::new(AtomicUsize::new(0));
701    let pool = ClassifiedWorkerPool::new(
702        1,
703        TestWorkerFactory {
704            create_count: create_count.clone(),
705        },
706    );
707
708    let pool_ref = pool.clone();
709    let worker_task = tokio::spawn(async move { pool_ref.get_worker().await });
710    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
711
712    pool.clear_all_worker().await;
713
714    let worker = worker_task.await.unwrap();
715    assert!(worker.is_err());
716    assert_eq!(create_count.load(Ordering::SeqCst), 1);
717}
718
/// Two overlapping `clear_all_worker` calls must both complete once the only outstanding worker
/// is returned. The pool holds one checked-out worker, which is dropped after about 100ms, and
/// both clears must finish within 1s.
#[tokio::test]
async fn test_concurrent_clear_all_worker() {
    #[derive(Clone, Debug, Eq, PartialEq, Hash)]
    enum TestWorkerClassification {
        A,
    }

    struct TestWorker {
        classification: TestWorkerClassification,
    }

    #[async_trait::async_trait]
    impl ClassifiedWorker<TestWorkerClassification> for TestWorker {
        fn is_work(&self) -> bool {
            true
        }

        fn is_valid(&self, c: TestWorkerClassification) -> bool {
            self.classification == c
        }

        fn classification(&self) -> TestWorkerClassification {
            self.classification.clone()
        }
    }

    struct TestWorkerFactory;

    #[async_trait::async_trait]
    impl ClassifiedWorkerFactory<TestWorkerClassification, TestWorker> for TestWorkerFactory {
        async fn create(
            &self,
            classification: Option<TestWorkerClassification>,
        ) -> PoolResult<TestWorker> {
            Ok(TestWorker {
                classification: classification.unwrap_or(TestWorkerClassification::A),
            })
        }
    }

    // Occupy the single slot so both clears have to wait.
    let pool = ClassifiedWorkerPool::new(1, TestWorkerFactory);
    let worker = pool.get_worker().await.unwrap();

    let pool_ref = pool.clone();
    let clear_task1 = tokio::spawn(async move {
        pool_ref.clear_all_worker().await;
    });

    let pool_ref = pool.clone();
    let clear_task2 = tokio::spawn(async move {
        pool_ref.clear_all_worker().await;
    });

    // Let both clears register as waiters before returning the worker.
    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
    drop(worker);

    // Both clears must be woken by the same release; a missed wake-up shows up as a timeout.
    tokio::time::timeout(std::time::Duration::from_secs(1), async {
        clear_task1.await.unwrap();
        clear_task2.await.unwrap();
    })
    .await
    .unwrap();
}
782
783#[tokio::test]
784async fn test_zero_max_count_returns_error() {
785    #[derive(Clone, Debug, Eq, PartialEq, Hash)]
786    enum TestWorkerClassification {
787        A,
788    }
789
790    struct TestWorker {
791        classification: TestWorkerClassification,
792    }
793
794    #[async_trait::async_trait]
795    impl ClassifiedWorker<TestWorkerClassification> for TestWorker {
796        fn is_work(&self) -> bool {
797            true
798        }
799
800        fn is_valid(&self, c: TestWorkerClassification) -> bool {
801            self.classification == c
802        }
803
804        fn classification(&self) -> TestWorkerClassification {
805            self.classification.clone()
806        }
807    }
808
809    struct TestWorkerFactory;
810
811    #[async_trait::async_trait]
812    impl ClassifiedWorkerFactory<TestWorkerClassification, TestWorker> for TestWorkerFactory {
813        async fn create(
814            &self,
815            classification: Option<TestWorkerClassification>,
816        ) -> PoolResult<TestWorker> {
817            Ok(TestWorker {
818                classification: classification.unwrap_or(TestWorkerClassification::A),
819            })
820        }
821    }
822
823    let pool = ClassifiedWorkerPool::new(0, TestWorkerFactory);
824    let worker = pool.get_worker().await;
825    assert!(worker.is_err());
826    assert_eq!(
827        worker.err().unwrap().code(),
828        crate::PoolErrorCode::InvalidConfig
829    );
830}
831
832#[tokio::test]
833async fn test_classified_pool_respects_max_count() {
834    #[derive(Clone, Debug, Eq, PartialEq, Hash)]
835    enum TestWorkerClassification {
836        A,
837        B,
838    }
839
840    struct TestWorker {
841        classification: TestWorkerClassification,
842    }
843
844    #[async_trait::async_trait]
845    impl ClassifiedWorker<TestWorkerClassification> for TestWorker {
846        fn is_work(&self) -> bool {
847            true
848        }
849
850        fn is_valid(&self, c: TestWorkerClassification) -> bool {
851            self.classification == c
852        }
853
854        fn classification(&self) -> TestWorkerClassification {
855            self.classification.clone()
856        }
857    }
858
859    struct TestWorkerFactory;
860
861    #[async_trait::async_trait]
862    impl ClassifiedWorkerFactory<TestWorkerClassification, TestWorker> for TestWorkerFactory {
863        async fn create(
864            &self,
865            classification: Option<TestWorkerClassification>,
866        ) -> PoolResult<TestWorker> {
867            Ok(TestWorker {
868                classification: classification.unwrap_or(TestWorkerClassification::A),
869            })
870        }
871    }
872
873    let pool = ClassifiedWorkerPool::new(1, TestWorkerFactory);
874    let _worker = pool.get_worker().await.unwrap();
875
876    let pool_ref = pool.clone();
877    let result = tokio::time::timeout(std::time::Duration::from_millis(100), async move {
878        pool_ref
879            .get_classified_worker(TestWorkerClassification::B)
880            .await
881    })
882    .await;
883
884    assert!(result.is_err());
885}
886
887#[tokio::test]
888async fn test_factory_must_return_matching_classification() {
889    use std::sync::atomic::{AtomicUsize, Ordering};
890    use std::sync::Arc;
891
892    #[derive(Clone, Debug, Eq, PartialEq, Hash)]
893    enum TestWorkerClassification {
894        A,
895        B,
896    }
897
898    struct TestWorker {
899        classification: TestWorkerClassification,
900    }
901
902    #[async_trait::async_trait]
903    impl ClassifiedWorker<TestWorkerClassification> for TestWorker {
904        fn is_work(&self) -> bool {
905            true
906        }
907
908        fn is_valid(&self, c: TestWorkerClassification) -> bool {
909            self.classification == c
910        }
911
912        fn classification(&self) -> TestWorkerClassification {
913            self.classification.clone()
914        }
915    }
916
917    struct TestWorkerFactory {
918        create_count: Arc<AtomicUsize>,
919    }
920
921    #[async_trait::async_trait]
922    impl ClassifiedWorkerFactory<TestWorkerClassification, TestWorker> for TestWorkerFactory {
923        async fn create(
924            &self,
925            classification: Option<TestWorkerClassification>,
926        ) -> PoolResult<TestWorker> {
927            let count = self.create_count.fetch_add(1, Ordering::SeqCst);
928            let classification = if count == 0 {
929                TestWorkerClassification::A
930            } else {
931                classification.unwrap_or(TestWorkerClassification::A)
932            };
933            Ok(TestWorker { classification })
934        }
935    }
936
937    let create_count = Arc::new(AtomicUsize::new(0));
938    let pool = ClassifiedWorkerPool::new(
939        1,
940        TestWorkerFactory {
941            create_count: create_count.clone(),
942        },
943    );
944    let worker = pool
945        .get_classified_worker(TestWorkerClassification::B)
946        .await;
947    assert!(worker.is_err());
948    assert_eq!(
949        worker.err().unwrap().code(),
950        crate::PoolErrorCode::InvalidConfig
951    );
952
953    let worker = pool
954        .get_classified_worker(TestWorkerClassification::B)
955        .await;
956    assert!(worker.is_ok());
957    assert_eq!(create_count.load(Ordering::SeqCst), 2);
958}
959
960#[tokio::test(flavor = "multi_thread")]
961async fn test_classified_waiter_keeps_queue_priority_over_later_generic_waiter() {
962    use std::sync::mpsc;
963
964    #[derive(Clone, Debug, Eq, PartialEq, Hash)]
965    enum TestWorkerClassification {
966        B,
967    }
968
969    struct TestWorker {
970        classification: TestWorkerClassification,
971    }
972
973    #[async_trait::async_trait]
974    impl ClassifiedWorker<TestWorkerClassification> for TestWorker {
975        fn is_work(&self) -> bool {
976            true
977        }
978
979        fn is_valid(&self, c: TestWorkerClassification) -> bool {
980            self.classification == c
981        }
982
983        fn classification(&self) -> TestWorkerClassification {
984            self.classification.clone()
985        }
986    }
987
988    struct TestWorkerFactory;
989
990    #[async_trait::async_trait]
991    impl ClassifiedWorkerFactory<TestWorkerClassification, TestWorker> for TestWorkerFactory {
992        async fn create(
993            &self,
994            classification: Option<TestWorkerClassification>,
995        ) -> PoolResult<TestWorker> {
996            Ok(TestWorker {
997                classification: classification.unwrap_or(TestWorkerClassification::B),
998            })
999        }
1000    }
1001
1002    let pool = ClassifiedWorkerPool::new(1, TestWorkerFactory);
1003    let worker = pool
1004        .get_classified_worker(TestWorkerClassification::B)
1005        .await
1006        .unwrap();
1007
1008    let (tx, rx) = mpsc::channel();
1009
1010    let pool_ref = pool.clone();
1011    let tx_classified = tx.clone();
1012    let classified_task = tokio::spawn(async move {
1013        let _worker = pool_ref
1014            .get_classified_worker(TestWorkerClassification::B)
1015            .await
1016            .unwrap();
1017        tx_classified.send("classified").unwrap();
1018        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1019    });
1020
1021    tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1022
1023    let pool_ref = pool.clone();
1024    let generic_task = tokio::spawn(async move {
1025        let _worker = pool_ref.get_worker().await.unwrap();
1026        tx.send("generic").unwrap();
1027    });
1028
1029    tokio::time::sleep(std::time::Duration::from_millis(20)).await;
1030    drop(worker);
1031
1032    let first = rx.recv_timeout(std::time::Duration::from_secs(2)).unwrap();
1033    assert_eq!(first, "classified");
1034
1035    classified_task.await.unwrap();
1036    generic_task.await.unwrap();
1037}
1038
1039#[tokio::test]
1040async fn test_generic_factory_worker_must_be_valid_for_its_primary_classification() {
1041    #[derive(Clone, Debug, Eq, PartialEq, Hash)]
1042    enum TestWorkerClassification {
1043        A,
1044        B,
1045    }
1046
1047    struct TestWorker {
1048        classification: TestWorkerClassification,
1049    }
1050
1051    #[async_trait::async_trait]
1052    impl ClassifiedWorker<TestWorkerClassification> for TestWorker {
1053        fn is_work(&self) -> bool {
1054            true
1055        }
1056
1057        fn is_valid(&self, c: TestWorkerClassification) -> bool {
1058            c == TestWorkerClassification::B
1059        }
1060
1061        fn classification(&self) -> TestWorkerClassification {
1062            self.classification.clone()
1063        }
1064    }
1065
1066    struct TestWorkerFactory;
1067
1068    #[async_trait::async_trait]
1069    impl ClassifiedWorkerFactory<TestWorkerClassification, TestWorker> for TestWorkerFactory {
1070        async fn create(
1071            &self,
1072            _classification: Option<TestWorkerClassification>,
1073        ) -> PoolResult<TestWorker> {
1074            Ok(TestWorker {
1075                classification: TestWorkerClassification::A,
1076            })
1077        }
1078    }
1079
1080    let pool = ClassifiedWorkerPool::new(1, TestWorkerFactory);
1081    let worker = pool.get_worker().await;
1082    assert!(worker.is_err());
1083    assert_eq!(
1084        worker.err().unwrap().code(),
1085        crate::PoolErrorCode::InvalidConfig
1086    );
1087}