1use std::collections::{HashMap};
2use std::hash::Hash;
3use std::ops::{Deref, DerefMut};
4use std::sync::{Arc, Mutex};
5use notify_future::NotifyFuture;
6use crate::PoolResult;
7
/// Marker trait for values used to classify workers inside a pool.
///
/// A classification is cloned freely, compared for equality and used as a
/// `HashMap` key by the pool, and may cross thread boundaries, hence the
/// supertrait set. `Eq` already implies `PartialEq`, so it is not repeated.
pub trait WorkerClassification: Clone + Eq + Hash + Send + 'static {}
11
12impl<T: Send + 'static + Clone + Hash + Eq + PartialEq> WorkerClassification for T {
13
14}
15
16#[async_trait::async_trait]
17pub trait ClassifiedWorker<C: WorkerClassification>: Send + 'static {
18 fn is_work(&self) -> bool;
19 fn is_valid(&self, c: C) -> bool;
20 fn classification(&self) -> C;
21}
22
23pub struct ClassifiedWorkerGuard<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> {
24 pool_ref: ClassifiedWorkerPoolRef<C, W, F>,
25 worker: Option<W>
26}
27
28impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> ClassifiedWorkerGuard<C, W, F> {
29 fn new(worker: W, pool_ref: ClassifiedWorkerPoolRef<C, W, F>) -> Self {
30 ClassifiedWorkerGuard {
31 pool_ref,
32 worker: Some(worker)
33 }
34 }
35}
36
37impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> Deref for ClassifiedWorkerGuard<C, W, F> {
38 type Target = W;
39
40 fn deref(&self) -> &Self::Target {
41 self.worker.as_ref().unwrap()
42 }
43}
44
45impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> DerefMut for ClassifiedWorkerGuard<C, W, F> {
46 fn deref_mut(&mut self) -> &mut Self::Target {
47 self.worker.as_mut().unwrap()
48 }
49}
50
51impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> Drop for ClassifiedWorkerGuard<C, W, F> {
52 fn drop(&mut self) {
53 if let Some(worker) = self.worker.take() {
54 self.pool_ref.release(worker);
55 }
56 }
57}
58
59#[async_trait::async_trait]
60pub trait ClassifiedWorkerFactory<C: WorkerClassification, W: ClassifiedWorker<C>>: Send + Sync + 'static {
61 async fn create(&self, c: Option<C>) -> PoolResult<W>;
62}
63
64struct WaitingItem<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> {
65 future: NotifyFuture<PoolResult<ClassifiedWorkerGuard<C, W, F>>>,
66 condition: Option<C>,
67}
68struct WorkerPoolState<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> {
69 current_count: u16,
70 classified_count_map: HashMap<C, u16>,
71 worker_list: Vec<W>,
72 waiting_list: Vec<WaitingItem<C, W, F>>,
73}
74
75impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> WorkerPoolState<C, W, F> {
76 fn inc_classified_count(&mut self, c: C) {
77 let count = self.classified_count_map.entry(c).or_insert(0);
78 *count += 1;
79 }
80
81 fn dec_classified_count(&mut self, c: C) {
82 let count = self.classified_count_map.entry(c).or_insert(0);
83 *count -= 1;
84 }
85
86 fn get_classified_count(&self, c: C) -> u16 {
87 *self.classified_count_map.get(&c).unwrap_or(&0)
88 }
89}
90
91pub struct ClassifiedWorkerPool<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> {
92 factory: Arc<F>,
93 max_count: u16,
94 state: Mutex<WorkerPoolState<C, W, F>>,
95}
96pub type ClassifiedWorkerPoolRef<C, W, F> = Arc<ClassifiedWorkerPool<C, W, F>>;
97
98impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> ClassifiedWorkerPool<C, W, F> {
99 pub fn new(max_count: u16, factory: F) -> ClassifiedWorkerPoolRef<C, W, F> {
100 Arc::new(ClassifiedWorkerPool {
101 factory: Arc::new(factory),
102 max_count,
103 state: Mutex::new(WorkerPoolState {
104 current_count: 0,
105 classified_count_map: HashMap::new(),
106 worker_list: Vec::with_capacity(max_count as usize),
107 waiting_list: Vec::new(),
108 }),
109 })
110 }
111
112 pub async fn get_worker(self: &ClassifiedWorkerPoolRef<C, W, F>) -> PoolResult<ClassifiedWorkerGuard<C, W, F>> {
113 let wait = {
114 let mut state = self.state.lock().unwrap();
115
116 while state.worker_list.len() > 0 {
117 let worker = state.worker_list.pop().unwrap();
118 if !worker.is_work() {
119 state.current_count -= 1;
120 state.dec_classified_count(worker.classification());
121 continue;
122 }
123 return Ok(ClassifiedWorkerGuard::new(worker, self.clone()));
124 }
125
126 if state.current_count < self.max_count {
127 state.current_count += 1;
128 None
129 } else {
130 let future = NotifyFuture::new();
131 state.waiting_list.push(WaitingItem {
132 future: future.clone(),
133 condition: None,
134 });
135 Some(future)
136 }
137 };
138
139 if let Some(wait) = wait {
140 wait.await
141 } else {
142 let worker = match self.factory.create(None).await {
143 Ok(worker) => worker,
144 Err(err) => {
145 let mut state = self.state.lock().unwrap();
146 state.current_count -= 1;
147 return Err(err)
148 },
149 };
150 let mut state = self.state.lock().unwrap();
151 state.inc_classified_count(worker.classification());
152 Ok(ClassifiedWorkerGuard::new(worker, self.clone()))
153 }
154 }
155
156 pub async fn get_classified_worker(self: &ClassifiedWorkerPoolRef<C, W, F>, classification: C) -> PoolResult<ClassifiedWorkerGuard<C, W, F>> {
157 let wait = {
158 let mut state = self.state.lock().unwrap();
159
160 let old_count = state.worker_list.len() as u16;
161 let unwork_classification = state.worker_list.iter().filter(|worker| !worker.is_work()).map(|worker| worker.classification()).collect::<Vec<C>>();
162 for classification in unwork_classification.iter() {
163 state.dec_classified_count(classification.clone());
164 }
165 state.worker_list.retain(|worker| worker.is_work());
166 state.current_count -= old_count - state.worker_list.len() as u16;
167 for (index, worker) in state.worker_list.iter().enumerate() {
168 if worker.is_valid(classification.clone()) {
169 let worker = state.worker_list.remove(index);
170 return Ok(ClassifiedWorkerGuard::new(worker, self.clone()));
171 }
172 }
173
174 if state.current_count < self.max_count || state.get_classified_count(classification.clone()) == 0 {
175 state.current_count += 1;
176 None
177 } else {
178 let future = NotifyFuture::new();
179 state.waiting_list.push(WaitingItem {
180 future: future.clone(),
181 condition: Some(classification.clone()),
182 });
183 Some(future)
184 }
185 };
186
187 if let Some(wait) = wait {
188 wait.await
189 } else {
190 let worker = match self.factory.create(Some(classification)).await {
191 Ok(worker) => worker,
192 Err(err) => {
193 let mut state = self.state.lock().unwrap();
194 state.current_count -= 1;
195 return Err(err)
196 },
197 };
198 let mut state = self.state.lock().unwrap();
199 state.inc_classified_count(worker.classification());
200 Ok(ClassifiedWorkerGuard::new(worker, self.clone()))
201 }
202 }
203
204 fn release(self: &ClassifiedWorkerPoolRef<C, W, F>, work: W) {
205 if work.is_work() {
206 let mut state = self.state.lock().unwrap();
207 for (index, waiting) in state.waiting_list.iter().enumerate() {
208 if waiting.condition.is_none() {
209 let waiting_item = state.waiting_list.remove(index);
210 waiting_item.future.set_complete(Ok(ClassifiedWorkerGuard::new(work, self.clone())));
211 return;
212 } else {
213 if work.is_valid(waiting.condition.as_ref().unwrap().clone()) {
214 let waiting_item = state.waiting_list.remove(index);
215 waiting_item.future.set_complete(Ok(ClassifiedWorkerGuard::new(work, self.clone())));
216 return;
217 }
218 }
219 }
220 state.worker_list.push(work);
221 } else {
222 let mut state = self.state.lock().unwrap();
223 let classification = work.classification();
224 for (index, waiting) in state.waiting_list.iter().enumerate() {
225 if waiting.condition.is_none() {
226 let waiting_item = state.waiting_list.remove(index);
227 let factory = self.factory.clone();
228 let this = self.clone();
229 let classification = classification.clone();
230 tokio::spawn(async move {
231 match factory.create(Some(classification.clone())).await {
232 Ok(worker) => {
233 waiting_item.future.set_complete(Ok(ClassifiedWorkerGuard::new(worker, this)));
234 }
235 Err(err) => {
236 let mut state = this.state.lock().unwrap();
237 state.current_count -= 1;
238 state.dec_classified_count(classification);
239 waiting_item.future.set_complete(Err(err));
240 }
241 }
242 });
243 return;
244 } else {
245 if classification == waiting.condition.as_ref().unwrap().clone() {
246 let waiting_item = state.waiting_list.remove(index);
247 let factory = self.factory.clone();
248 let this = self.clone();
249 let classification = classification.clone();
250 tokio::spawn(async move {
251 match factory.create(Some(classification.clone())).await {
252 Ok(worker) => {
253 waiting_item.future.set_complete(Ok(ClassifiedWorkerGuard::new(worker, this)));
254 }
255 Err(err) => {
256 let mut state = this.state.lock().unwrap();
257 state.current_count -= 1;
258 state.dec_classified_count(classification);
259 waiting_item.future.set_complete(Err(err));
260 }
261 }
262 });
263 return;
264 }
265 }
266 }
267 state.current_count -= 1;
268 state.dec_classified_count(classification);
269 }
270 }
271}
272
273#[tokio::test]
274async fn test_pool() {
275 struct TestWorker {
276 work: bool,
277 classification: TestWorkerClassification,
278 }
279
280 #[derive(Clone, Debug, Eq, PartialEq, Hash)]
281 enum TestWorkerClassification {
282 A,
283 B,
284 }
285 #[async_trait::async_trait]
286 impl ClassifiedWorker<TestWorkerClassification> for TestWorker {
287 fn is_work(&self) -> bool {
288 self.work
289 }
290
291 fn is_valid(&self, c: TestWorkerClassification) -> bool {
292 self.classification == c
293 }
294
295 fn classification(&self) -> TestWorkerClassification {
296 self.classification.clone()
297 }
298 }
299
300 struct TestWorkerFactory;
301
302 #[async_trait::async_trait]
303 impl ClassifiedWorkerFactory<TestWorkerClassification, TestWorker> for TestWorkerFactory {
304 async fn create(&self, classification: Option<TestWorkerClassification>) -> PoolResult<TestWorker> {
305 if let Some(classification) = classification {
306 Ok(TestWorker { work: true, classification })
307 } else {
308 Ok(TestWorker { work: true, classification: TestWorkerClassification::A })
309 }
310 }
311 }
312
313 let pool = ClassifiedWorkerPool::new(2, TestWorkerFactory);
314 let pool_ref = pool.clone();
315 tokio::spawn(async move {
316 let _worker = pool_ref.get_worker().await.unwrap();
317 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
318 });
319 let pool_ref = pool.clone();
320 tokio::spawn(async move {
321 let _worker = pool_ref.get_worker().await.unwrap();
322 tokio::time::sleep(std::time::Duration::from_secs(10)).await;
323 });
324
325 let pool_ref = pool.clone();
326 tokio::spawn(async move {
327 let _worker = pool_ref.get_classified_worker(TestWorkerClassification::B).await.unwrap();
328 tokio::time::sleep(std::time::Duration::from_secs(6)).await;
329 });
330
331 let pool_ref = pool.clone();
332 tokio::spawn(async move {
333 tokio::time::sleep(std::time::Duration::from_secs(2)).await;
334
335 let start = std::time::Instant::now();
336 let _worker3 = pool_ref.get_classified_worker(TestWorkerClassification::B).await.unwrap();
337 let end = std::time::Instant::now();
338 let duration = end.duration_since(start);
339 println!("classified duration {}", duration.as_millis());
340 assert!(duration.as_millis() > 2000);
341 });
342
343 let pool_ref = pool.clone();
344 tokio::spawn(async move {
345 tokio::time::sleep(std::time::Duration::from_secs(2)).await;
346
347 let start = std::time::Instant::now();
348 let _worker3 = pool_ref.get_worker().await.unwrap();
349 let end = std::time::Instant::now();
350 let duration = end.duration_since(start);
351 println!("classified duration2 {}", duration.as_millis());
352 assert!(duration.as_millis() > 2000);
353 });
354
355 tokio::time::sleep(std::time::Duration::from_secs(15)).await;
356}