sfo_pool/
classified_worker_pool.rs

1use std::collections::{HashMap};
2use std::hash::Hash;
3use std::ops::{Deref, DerefMut};
4use std::sync::{Arc, Mutex};
5use notify_future::NotifyFuture;
6use crate::PoolResult;
7
/// Key describing which kind of worker a request needs.
///
/// The pool stores classifications as hash-map keys, hence `Eq + Hash`.
/// Every type meeting the bounds implements this trait automatically, so
/// plain enums or strings can be used directly.
pub trait WorkerClassification: Clone + Eq + Hash + Send + 'static {}
11
12impl<T: Send + 'static + Clone + Hash + Eq + PartialEq> WorkerClassification for T {
13
14}
15
16#[async_trait::async_trait]
17pub trait ClassifiedWorker<C: WorkerClassification>: Send + 'static {
18    fn is_work(&self) -> bool;
19    fn is_valid(&self, c: C) -> bool;
20    fn classification(&self) -> C;
21}
22
23pub struct ClassifiedWorkerGuard<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> {
24    pool_ref: ClassifiedWorkerPoolRef<C, W, F>,
25    worker: Option<W>
26}
27
28impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> ClassifiedWorkerGuard<C, W, F> {
29    fn new(worker: W, pool_ref: ClassifiedWorkerPoolRef<C, W, F>) -> Self {
30        ClassifiedWorkerGuard {
31            pool_ref,
32            worker: Some(worker)
33        }
34    }
35}
36
37impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> Deref for ClassifiedWorkerGuard<C, W, F> {
38    type Target = W;
39
40    fn deref(&self) -> &Self::Target {
41        self.worker.as_ref().unwrap()
42    }
43}
44
45impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> DerefMut for ClassifiedWorkerGuard<C, W, F> {
46    fn deref_mut(&mut self) -> &mut Self::Target {
47        self.worker.as_mut().unwrap()
48    }
49}
50
51impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> Drop for ClassifiedWorkerGuard<C, W, F> {
52    fn drop(&mut self) {
53        if let Some(worker) = self.worker.take() {
54            self.pool_ref.release(worker);
55        }
56    }
57}
58
59#[async_trait::async_trait]
60pub trait ClassifiedWorkerFactory<C: WorkerClassification, W: ClassifiedWorker<C>>: Send + Sync + 'static {
61    async fn create(&self, c: Option<C>) -> PoolResult<W>;
62}
63
64struct WaitingItem<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> {
65    future: NotifyFuture<PoolResult<ClassifiedWorkerGuard<C, W, F>>>,
66    condition: Option<C>,
67}
68struct WorkerPoolState<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> {
69    current_count: u16,
70    classified_count_map: HashMap<C, u16>,
71    worker_list: Vec<W>,
72    waiting_list: Vec<WaitingItem<C, W, F>>,
73}
74
75impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> WorkerPoolState<C, W, F> {
76    fn inc_classified_count(&mut self, c: C) {
77        let count = self.classified_count_map.entry(c).or_insert(0);
78        *count += 1;
79    }
80
81    fn dec_classified_count(&mut self, c: C) {
82        let count = self.classified_count_map.entry(c).or_insert(0);
83        *count -= 1;
84    }
85
86    fn get_classified_count(&self, c: C) -> u16 {
87        *self.classified_count_map.get(&c).unwrap_or(&0)
88    }
89}
90
91pub struct ClassifiedWorkerPool<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> {
92    factory: Arc<F>,
93    max_count: u16,
94    state: Mutex<WorkerPoolState<C, W, F>>,
95}
96pub type ClassifiedWorkerPoolRef<C, W, F> = Arc<ClassifiedWorkerPool<C, W, F>>;
97
98impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> ClassifiedWorkerPool<C, W, F> {
99    pub fn new(max_count: u16, factory: F) -> ClassifiedWorkerPoolRef<C, W, F> {
100        Arc::new(ClassifiedWorkerPool {
101            factory: Arc::new(factory),
102            max_count,
103            state: Mutex::new(WorkerPoolState {
104                current_count: 0,
105                classified_count_map: HashMap::new(),
106                worker_list: Vec::with_capacity(max_count as usize),
107                waiting_list: Vec::new(),
108            }),
109        })
110    }
111
112    pub async fn get_worker(self: &ClassifiedWorkerPoolRef<C, W, F>) -> PoolResult<ClassifiedWorkerGuard<C, W, F>> {
113        let wait = {
114            let mut state = self.state.lock().unwrap();
115
116            while state.worker_list.len() > 0 {
117                let worker = state.worker_list.pop().unwrap();
118                if !worker.is_work() {
119                    state.current_count -= 1;
120                    state.dec_classified_count(worker.classification());
121                    continue;
122                }
123                return Ok(ClassifiedWorkerGuard::new(worker, self.clone()));
124            }
125
126            if state.current_count < self.max_count {
127                state.current_count += 1;
128                None
129            } else {
130                let future = NotifyFuture::new();
131                state.waiting_list.push(WaitingItem {
132                    future: future.clone(),
133                    condition: None,
134                });
135                Some(future)
136            }
137        };
138
139        if let Some(wait) = wait {
140            wait.await
141        } else {
142            let worker = match self.factory.create(None).await {
143                Ok(worker) => worker,
144                Err(err) => {
145                    let mut state = self.state.lock().unwrap();
146                    state.current_count -= 1;
147                    return Err(err)
148                },
149            };
150            let mut state = self.state.lock().unwrap();
151            state.inc_classified_count(worker.classification());
152            Ok(ClassifiedWorkerGuard::new(worker, self.clone()))
153        }
154    }
155
156    pub async fn get_classified_worker(self: &ClassifiedWorkerPoolRef<C, W, F>, classification: C) -> PoolResult<ClassifiedWorkerGuard<C, W, F>> {
157        let wait = {
158            let mut state = self.state.lock().unwrap();
159
160            let old_count = state.worker_list.len() as u16;
161            let unwork_classification = state.worker_list.iter().filter(|worker| !worker.is_work()).map(|worker| worker.classification()).collect::<Vec<C>>();
162            for classification in unwork_classification.iter() {
163                state.dec_classified_count(classification.clone());
164            }
165            state.worker_list.retain(|worker| worker.is_work());
166            state.current_count -= old_count - state.worker_list.len() as u16;
167            for (index, worker) in state.worker_list.iter().enumerate() {
168                if worker.is_valid(classification.clone()) {
169                    let worker = state.worker_list.remove(index);
170                    return Ok(ClassifiedWorkerGuard::new(worker, self.clone()));
171                }
172            }
173
174            if state.current_count < self.max_count || state.get_classified_count(classification.clone()) == 0 {
175                state.current_count += 1;
176                None
177            } else {
178                let future = NotifyFuture::new();
179                state.waiting_list.push(WaitingItem {
180                    future: future.clone(),
181                    condition: Some(classification.clone()),
182                });
183                Some(future)
184            }
185        };
186
187        if let Some(wait) = wait {
188            wait.await
189        } else {
190            let worker = match self.factory.create(Some(classification)).await {
191                Ok(worker) => worker,
192                Err(err) => {
193                    let mut state = self.state.lock().unwrap();
194                    state.current_count -= 1;
195                    return Err(err)
196                },
197            };
198            let mut state = self.state.lock().unwrap();
199            state.inc_classified_count(worker.classification());
200            Ok(ClassifiedWorkerGuard::new(worker, self.clone()))
201        }
202    }
203
204    fn release(self: &ClassifiedWorkerPoolRef<C, W, F>, work: W) {
205        if work.is_work() {
206            let mut state = self.state.lock().unwrap();
207            for (index, waiting) in state.waiting_list.iter().enumerate() {
208                if waiting.condition.is_none() {
209                    let waiting_item = state.waiting_list.remove(index);
210                    waiting_item.future.set_complete(Ok(ClassifiedWorkerGuard::new(work, self.clone())));
211                    return;
212                } else {
213                    if work.is_valid(waiting.condition.as_ref().unwrap().clone()) {
214                        let waiting_item = state.waiting_list.remove(index);
215                        waiting_item.future.set_complete(Ok(ClassifiedWorkerGuard::new(work, self.clone())));
216                        return;
217                    }
218                }
219            }
220            state.worker_list.push(work);
221        } else {
222            let mut state = self.state.lock().unwrap();
223            let classification = work.classification();
224            for (index, waiting) in state.waiting_list.iter().enumerate() {
225                if waiting.condition.is_none() {
226                    let waiting_item = state.waiting_list.remove(index);
227                    let factory = self.factory.clone();
228                    let this = self.clone();
229                    let classification = classification.clone();
230                    tokio::spawn(async move {
231                        match factory.create(Some(classification.clone())).await {
232                            Ok(worker) => {
233                                waiting_item.future.set_complete(Ok(ClassifiedWorkerGuard::new(worker, this)));
234                            }
235                            Err(err) => {
236                                let mut state = this.state.lock().unwrap();
237                                state.current_count -= 1;
238                                state.dec_classified_count(classification);
239                                waiting_item.future.set_complete(Err(err));
240                            }
241                        }
242                    });
243                    return;
244                } else {
245                    if classification == waiting.condition.as_ref().unwrap().clone() {
246                        let waiting_item = state.waiting_list.remove(index);
247                        let factory = self.factory.clone();
248                        let this = self.clone();
249                        let classification = classification.clone();
250                        tokio::spawn(async move {
251                            match factory.create(Some(classification.clone())).await {
252                                Ok(worker) => {
253                                    waiting_item.future.set_complete(Ok(ClassifiedWorkerGuard::new(worker, this)));
254                                }
255                                Err(err) => {
256                                    let mut state = this.state.lock().unwrap();
257                                    state.current_count -= 1;
258                                    state.dec_classified_count(classification);
259                                    waiting_item.future.set_complete(Err(err));
260                                }
261                            }
262                        });
263                        return;
264                    }
265                }
266            }
267            state.current_count -= 1;
268            state.dec_classified_count(classification);
269        }
270    }
271}
272
273#[tokio::test]
274async fn test_pool() {
275    struct TestWorker {
276        work: bool,
277        classification: TestWorkerClassification,
278    }
279
280    #[derive(Clone, Debug, Eq, PartialEq, Hash)]
281    enum TestWorkerClassification {
282        A,
283        B,
284    }
285    #[async_trait::async_trait]
286    impl ClassifiedWorker<TestWorkerClassification> for TestWorker {
287        fn is_work(&self) -> bool {
288            self.work
289        }
290
291        fn is_valid(&self, c: TestWorkerClassification) -> bool {
292            self.classification == c
293        }
294
295        fn classification(&self) -> TestWorkerClassification {
296            self.classification.clone()
297        }
298    }
299
300    struct TestWorkerFactory;
301
302    #[async_trait::async_trait]
303    impl ClassifiedWorkerFactory<TestWorkerClassification, TestWorker> for TestWorkerFactory {
304        async fn create(&self, classification: Option<TestWorkerClassification>) -> PoolResult<TestWorker> {
305            if let Some(classification) = classification {
306                Ok(TestWorker { work: true, classification })
307            } else {
308                Ok(TestWorker { work: true, classification: TestWorkerClassification::A })
309            }
310        }
311    }
312
313    let pool = ClassifiedWorkerPool::new(2, TestWorkerFactory);
314    let pool_ref = pool.clone();
315    tokio::spawn(async move {
316        let _worker = pool_ref.get_worker().await.unwrap();
317        tokio::time::sleep(std::time::Duration::from_secs(5)).await;
318    });
319    let pool_ref = pool.clone();
320    tokio::spawn(async move {
321        let _worker = pool_ref.get_worker().await.unwrap();
322        tokio::time::sleep(std::time::Duration::from_secs(10)).await;
323    });
324
325    let pool_ref = pool.clone();
326    tokio::spawn(async move {
327        let _worker = pool_ref.get_classified_worker(TestWorkerClassification::B).await.unwrap();
328        tokio::time::sleep(std::time::Duration::from_secs(6)).await;
329    });
330
331    let pool_ref = pool.clone();
332    tokio::spawn(async move {
333        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
334
335        let start = std::time::Instant::now();
336        let _worker3 = pool_ref.get_classified_worker(TestWorkerClassification::B).await.unwrap();
337        let end = std::time::Instant::now();
338        let duration = end.duration_since(start);
339        println!("classified duration {}", duration.as_millis());
340        assert!(duration.as_millis() > 2000);
341    });
342
343    let pool_ref = pool.clone();
344    tokio::spawn(async move {
345        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
346
347        let start = std::time::Instant::now();
348        let _worker3 = pool_ref.get_worker().await.unwrap();
349        let end = std::time::Instant::now();
350        let duration = end.duration_since(start);
351        println!("classified duration2 {}", duration.as_millis());
352        assert!(duration.as_millis() > 2000);
353    });
354
355    tokio::time::sleep(std::time::Duration::from_secs(15)).await;
356}