// sfo_pool/classified_worker_pool.rs

use std::collections::HashMap;
use std::hash::Hash;
use std::ops::{Deref, DerefMut};
use std::sync::{Arc, Mutex};
use std::thread::sleep;
use std::time::Duration;

use notify_future::Notify;

use crate::{PoolError, PoolErrorCode, PoolResult};

/// Key used by the pool to group workers (and waiting requests) into classes.
///
/// This is a pure marker trait: any owned, thread-sendable, hashable value type qualifies
/// automatically through the blanket implementation below.
pub trait WorkerClassification: Send + 'static + Clone + Hash + Eq + PartialEq {}

/// Every type meeting the bounds is a valid classification; no manual impl is needed.
impl<T> WorkerClassification for T where T: Send + 'static + Clone + Hash + Eq + PartialEq {}

18#[async_trait::async_trait]
19pub trait ClassifiedWorker<C: WorkerClassification>: Send + 'static {
20    fn is_work(&self) -> bool;
21    fn is_valid(&self, c: C) -> bool;
22    fn classification(&self) -> C;
23}
24
25pub struct ClassifiedWorkerGuard<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> {
26    pool_ref: ClassifiedWorkerPoolRef<C, W, F>,
27    worker: Option<W>
28}
29
30impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> ClassifiedWorkerGuard<C, W, F> {
31    fn new(worker: W, pool_ref: ClassifiedWorkerPoolRef<C, W, F>) -> Self {
32        ClassifiedWorkerGuard {
33            pool_ref,
34            worker: Some(worker)
35        }
36    }
37}
38
39impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> Deref for ClassifiedWorkerGuard<C, W, F> {
40    type Target = W;
41
42    fn deref(&self) -> &Self::Target {
43        self.worker.as_ref().unwrap()
44    }
45}
46
47impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> DerefMut for ClassifiedWorkerGuard<C, W, F> {
48    fn deref_mut(&mut self) -> &mut Self::Target {
49        self.worker.as_mut().unwrap()
50    }
51}
52
53impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> Drop for ClassifiedWorkerGuard<C, W, F> {
54    fn drop(&mut self) {
55        if let Some(worker) = self.worker.take() {
56            self.pool_ref.release(worker);
57        }
58    }
59}
60
61#[async_trait::async_trait]
62pub trait ClassifiedWorkerFactory<C: WorkerClassification, W: ClassifiedWorker<C>>: Send + Sync + 'static {
63    async fn create(&self, c: Option<C>) -> PoolResult<W>;
64}
65
66struct WaitingItem<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> {
67    future: Notify<PoolResult<ClassifiedWorkerGuard<C, W, F>>>,
68    condition: Option<C>,
69}
70struct WorkerPoolState<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> {
71    current_count: u16,
72    classified_count_map: HashMap<C, u16>,
73    worker_list: Vec<W>,
74    waiting_list: Vec<WaitingItem<C, W, F>>,
75    clear_notify: Option<Notify<()>>,
76}
77
78impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> WorkerPoolState<C, W, F> {
79    fn inc_classified_count(&mut self, c: C) {
80        let count = self.classified_count_map.entry(c).or_insert(0);
81        *count += 1;
82    }
83
84    fn dec_classified_count(&mut self, c: C) {
85        let count = self.classified_count_map.entry(c).or_insert(0);
86        *count -= 1;
87    }
88
89    fn get_classified_count(&self, c: C) -> u16 {
90        *self.classified_count_map.get(&c).unwrap_or(&0)
91    }
92}
93
94pub struct ClassifiedWorkerPool<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> {
95    factory: Arc<F>,
96    max_count: u16,
97    state: Mutex<WorkerPoolState<C, W, F>>,
98}
99pub type ClassifiedWorkerPoolRef<C, W, F> = Arc<ClassifiedWorkerPool<C, W, F>>;
100
101impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> ClassifiedWorkerPool<C, W, F> {
102    pub fn new(max_count: u16, factory: F) -> ClassifiedWorkerPoolRef<C, W, F> {
103        Arc::new(ClassifiedWorkerPool {
104            factory: Arc::new(factory),
105            max_count,
106            state: Mutex::new(WorkerPoolState {
107                current_count: 0,
108                classified_count_map: HashMap::new(),
109                worker_list: Vec::with_capacity(max_count as usize),
110                waiting_list: Vec::new(),
111                clear_notify: None,
112            }),
113        })
114    }
115
116    pub async fn get_worker(self: &ClassifiedWorkerPoolRef<C, W, F>) -> PoolResult<ClassifiedWorkerGuard<C, W, F>> {
117        let wait = {
118            let mut state = self.state.lock().unwrap();
119            if state.clear_notify.is_some() {
120                return Err(PoolError::new(PoolErrorCode::Failed, "pool is clearing".to_string()));
121            }
122
123            while state.worker_list.len() > 0 {
124                let worker = state.worker_list.pop().unwrap();
125                if !worker.is_work() {
126                    state.current_count -= 1;
127                    state.dec_classified_count(worker.classification());
128                    continue;
129                }
130                return Ok(ClassifiedWorkerGuard::new(worker, self.clone()));
131            }
132
133            if state.current_count < self.max_count {
134                state.current_count += 1;
135                None
136            } else {
137                let (notify, waiter) = Notify::new();
138                state.waiting_list.push(WaitingItem {
139                    future: notify,
140                    condition: None,
141                });
142                Some(waiter)
143            }
144        };
145
146        if let Some(wait) = wait {
147            wait.await
148        } else {
149            let worker = match self.factory.create(None).await {
150                Ok(worker) => worker,
151                Err(err) => {
152                    let mut state = self.state.lock().unwrap();
153                    state.current_count -= 1;
154                    if state.current_count == 0 && state.clear_notify.is_some() {
155                        state.clear_notify.take().unwrap().notify(());
156                    }
157                    return Err(err)
158                },
159            };
160            let mut state = self.state.lock().unwrap();
161            state.inc_classified_count(worker.classification());
162            Ok(ClassifiedWorkerGuard::new(worker, self.clone()))
163        }
164    }
165
166    pub async fn get_classified_worker(self: &ClassifiedWorkerPoolRef<C, W, F>, classification: C) -> PoolResult<ClassifiedWorkerGuard<C, W, F>> {
167        let wait = {
168            let mut state = self.state.lock().unwrap();
169            if state.clear_notify.is_some() {
170                return Err(PoolError::new(PoolErrorCode::Failed, "pool is clearing".to_string()));
171            }
172
173            let old_count = state.worker_list.len() as u16;
174            let unwork_classification = state.worker_list.iter().filter(|worker| !worker.is_work()).map(|worker| worker.classification()).collect::<Vec<C>>();
175            for classification in unwork_classification.iter() {
176                state.dec_classified_count(classification.clone());
177            }
178            state.worker_list.retain(|worker| worker.is_work());
179            state.current_count -= old_count - state.worker_list.len() as u16;
180            for (index, worker) in state.worker_list.iter().enumerate() {
181                if worker.is_valid(classification.clone()) {
182                    let worker = state.worker_list.remove(index);
183                    return Ok(ClassifiedWorkerGuard::new(worker, self.clone()));
184                }
185            }
186
187            if state.current_count < self.max_count || state.get_classified_count(classification.clone()) == 0 {
188                state.current_count += 1;
189                None
190            } else {
191                let (notify, waiter) = Notify::new();
192                state.waiting_list.push(WaitingItem {
193                    future: notify,
194                    condition: Some(classification.clone()),
195                });
196                Some(waiter)
197            }
198        };
199
200        if let Some(wait) = wait {
201            wait.await
202        } else {
203            let worker = match self.factory.create(Some(classification)).await {
204                Ok(worker) => worker,
205                Err(err) => {
206                    let mut state = self.state.lock().unwrap();
207                    state.current_count -= 1;
208                    if state.current_count == 0 && state.clear_notify.is_some() {
209                        state.clear_notify.take().unwrap().notify(());
210                    }
211                    return Err(err)
212                },
213            };
214            let mut state = self.state.lock().unwrap();
215            state.inc_classified_count(worker.classification());
216            Ok(ClassifiedWorkerGuard::new(worker, self.clone()))
217        }
218    }
219
220    pub async fn clear_all_worker(&self) {
221        let waiter = {
222            let mut state = self.state.lock().unwrap();
223            let cur_worker_count = state.worker_list.len();
224            state.worker_list.clear();
225            state.current_count -= cur_worker_count as u16;
226
227            for waiting in state.waiting_list.drain(..) {
228                waiting.future.notify(Err(PoolError::new(PoolErrorCode::Failed, "pool cleared".to_string())));
229            }
230            state.classified_count_map.clear();
231
232            if state.current_count == 0 {
233                return;
234            }
235            let (notify, waiter) = Notify::new();
236            state.clear_notify = Some(notify);
237            waiter
238        };
239        waiter.await;
240        {
241            let mut state = self.state.lock().unwrap();
242            for waiting in state.waiting_list.drain(..) {
243                waiting.future.notify(Err(PoolError::new(PoolErrorCode::Failed, "pool cleared".to_string())));
244            }
245            state.classified_count_map.clear();
246        }
247    }
248
249    fn release(self: &ClassifiedWorkerPoolRef<C, W, F>, work: W) {
250        {
251            let mut state = self.state.lock().unwrap();
252            if state.clear_notify.is_some() {
253                state.current_count -= 1;
254                if state.current_count == 0 {
255                    state.clear_notify.take().unwrap().notify(());
256                }
257                return;
258            }
259        }
260        if work.is_work() {
261            let mut state = self.state.lock().unwrap();
262            for (index, waiting) in state.waiting_list.iter().enumerate() {
263                if waiting.condition.is_none() {
264                    let waiting_item = state.waiting_list.remove(index);
265                    waiting_item.future.notify(Ok(ClassifiedWorkerGuard::new(work, self.clone())));
266                    return;
267                } else {
268                    if work.is_valid(waiting.condition.as_ref().unwrap().clone()) {
269                        let waiting_item = state.waiting_list.remove(index);
270                        waiting_item.future.notify(Ok(ClassifiedWorkerGuard::new(work, self.clone())));
271                        return;
272                    }
273                }
274            }
275            state.worker_list.push(work);
276        } else {
277            let mut state = self.state.lock().unwrap();
278            let classification = work.classification();
279            for (index, waiting) in state.waiting_list.iter().enumerate() {
280                if waiting.condition.is_none() {
281                    let waiting_item = state.waiting_list.remove(index);
282                    let factory = self.factory.clone();
283                    let this = self.clone();
284                    let classification = classification.clone();
285                    tokio::spawn(async move {
286                        match factory.create(Some(classification.clone())).await {
287                            Ok(worker) => {
288                                waiting_item.future.notify(Ok(ClassifiedWorkerGuard::new(worker, this)));
289                            }
290                            Err(err) => {
291                                let mut state = this.state.lock().unwrap();
292                                state.current_count -= 1;
293                                state.dec_classified_count(classification);
294                                waiting_item.future.notify(Err(err));
295                                if state.current_count == 0 && state.clear_notify.is_some() {
296                                    state.clear_notify.take().unwrap().notify(());
297                                }
298                            }
299                        }
300                    });
301                    return;
302                } else {
303                    if classification == waiting.condition.as_ref().unwrap().clone() {
304                        let waiting_item = state.waiting_list.remove(index);
305                        let factory = self.factory.clone();
306                        let this = self.clone();
307                        let classification = classification.clone();
308                        tokio::spawn(async move {
309                            match factory.create(Some(classification.clone())).await {
310                                Ok(worker) => {
311                                    waiting_item.future.notify(Ok(ClassifiedWorkerGuard::new(worker, this)));
312                                }
313                                Err(err) => {
314                                    let mut state = this.state.lock().unwrap();
315                                    state.current_count -= 1;
316                                    state.dec_classified_count(classification);
317                                    waiting_item.future.notify(Err(err));
318                                    if state.current_count == 0 && state.clear_notify.is_some() {
319                                        state.clear_notify.take().unwrap().notify(());
320                                    }
321                                }
322                            }
323                        });
324                        return;
325                    }
326                }
327            }
328            state.current_count -= 1;
329            state.dec_classified_count(classification);
330            if state.current_count == 0 && state.clear_notify.is_some() {
331                state.clear_notify.take().unwrap().notify(());
332            }
333        }
334    }
335}
336
337#[tokio::test]
338async fn test_pool() {
339    struct TestWorker {
340        work: bool,
341        classification: TestWorkerClassification,
342    }
343
344    #[derive(Clone, Debug, Eq, PartialEq, Hash)]
345    enum TestWorkerClassification {
346        A,
347        B,
348    }
349    #[async_trait::async_trait]
350    impl ClassifiedWorker<TestWorkerClassification> for TestWorker {
351        fn is_work(&self) -> bool {
352            self.work
353        }
354
355        fn is_valid(&self, c: TestWorkerClassification) -> bool {
356            self.classification == c
357        }
358
359        fn classification(&self) -> TestWorkerClassification {
360            self.classification.clone()
361        }
362    }
363
364    struct TestWorkerFactory;
365
366    #[async_trait::async_trait]
367    impl ClassifiedWorkerFactory<TestWorkerClassification, TestWorker> for TestWorkerFactory {
368        async fn create(&self, classification: Option<TestWorkerClassification>) -> PoolResult<TestWorker> {
369            if let Some(classification) = classification {
370                Ok(TestWorker { work: true, classification })
371            } else {
372                Ok(TestWorker { work: true, classification: TestWorkerClassification::A })
373            }
374        }
375    }
376
377    let pool = ClassifiedWorkerPool::new(2, TestWorkerFactory);
378    let pool_ref = pool.clone();
379    tokio::spawn(async move {
380        let _worker = pool_ref.get_worker().await.unwrap();
381        tokio::time::sleep(std::time::Duration::from_secs(5)).await;
382    });
383    let pool_ref = pool.clone();
384    tokio::spawn(async move {
385        let _worker = pool_ref.get_worker().await.unwrap();
386        tokio::time::sleep(std::time::Duration::from_secs(10)).await;
387    });
388
389    let pool_ref = pool.clone();
390    tokio::spawn(async move {
391        let _worker = pool_ref.get_classified_worker(TestWorkerClassification::B).await.unwrap();
392        tokio::time::sleep(std::time::Duration::from_secs(6)).await;
393    });
394
395    let pool_ref = pool.clone();
396    tokio::spawn(async move {
397        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
398
399        let start = std::time::Instant::now();
400        let _worker3 = pool_ref.get_classified_worker(TestWorkerClassification::B).await.unwrap();
401        let end = std::time::Instant::now();
402        let duration = end.duration_since(start);
403        println!("classified duration {}", duration.as_millis());
404        assert!(duration.as_millis() > 2000);
405    });
406
407    let pool_ref = pool.clone();
408    tokio::spawn(async move {
409        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
410
411        let start = std::time::Instant::now();
412        let _worker3 = pool_ref.get_worker().await.unwrap();
413        let end = std::time::Instant::now();
414        let duration = end.duration_since(start);
415        println!("classified duration2 {}", duration.as_millis());
416        assert!(duration.as_millis() > 2000);
417    });
418
419    tokio::time::sleep(std::time::Duration::from_secs(15)).await;
420
421    let pool_ref = pool.clone();
422    tokio::spawn(async move {
423        let _worker = pool_ref.get_worker().await;
424        let _worker1 = pool_ref.get_worker().await;
425        tokio::time::sleep(Duration::from_secs(5)).await;
426    });
427
428    let pool_ref = pool.clone();
429    tokio::spawn(async move {
430        tokio::time::sleep(Duration::from_secs(1)).await;
431        let worker = pool_ref.get_worker().await;
432        assert!(worker.is_err());
433    });
434
435    let pool_ref = pool.clone();
436    tokio::spawn(async move {
437        tokio::time::sleep(Duration::from_secs(2)).await;
438        let worker = pool_ref.get_classified_worker(TestWorkerClassification::B).await;
439        assert!(worker.is_err());
440    });
441
442    let pool_ref = pool.clone();
443    tokio::spawn(async move {
444        let start = std::time::Instant::now();
445        pool_ref.clear_all_worker().await;
446        let end = std::time::Instant::now();
447        let duration = end.duration_since(start);
448        println!("classified duration3 {}", duration.as_millis());
449        assert!(duration.as_millis() > 4000);
450    });
451
452    tokio::time::sleep(Duration::from_secs(10)).await;
453}