1use std::collections::{HashMap};
2use std::hash::Hash;
3use std::ops::{Deref, DerefMut};
4use std::sync::{Arc, Mutex};
5use std::thread::sleep;
6use std::time::Duration;
7use notify_future::{Notify};
8use crate::{PoolError, PoolErrorCode, PoolResult};
9
/// Marker trait for the key type used to partition pooled workers into classes.
///
/// It declares no methods; it only bundles the bounds the pool needs: the key is
/// stored in a `HashMap` (`Hash + Eq`), copied around (`Clone`) and moved into
/// spawned tasks (`Send + 'static`).
pub trait WorkerClassification: Clone + Eq + PartialEq + Hash + Send + 'static {}
13
14impl<T: Send + 'static + Clone + Hash + Eq + PartialEq> WorkerClassification for T {
15
16}
17
18#[async_trait::async_trait]
19pub trait ClassifiedWorker<C: WorkerClassification>: Send + 'static {
20 fn is_work(&self) -> bool;
21 fn is_valid(&self, c: C) -> bool;
22 fn classification(&self) -> C;
23}
24
25pub struct ClassifiedWorkerGuard<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> {
26 pool_ref: ClassifiedWorkerPoolRef<C, W, F>,
27 worker: Option<W>
28}
29
30impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> ClassifiedWorkerGuard<C, W, F> {
31 fn new(worker: W, pool_ref: ClassifiedWorkerPoolRef<C, W, F>) -> Self {
32 ClassifiedWorkerGuard {
33 pool_ref,
34 worker: Some(worker)
35 }
36 }
37}
38
39impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> Deref for ClassifiedWorkerGuard<C, W, F> {
40 type Target = W;
41
42 fn deref(&self) -> &Self::Target {
43 self.worker.as_ref().unwrap()
44 }
45}
46
47impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> DerefMut for ClassifiedWorkerGuard<C, W, F> {
48 fn deref_mut(&mut self) -> &mut Self::Target {
49 self.worker.as_mut().unwrap()
50 }
51}
52
53impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> Drop for ClassifiedWorkerGuard<C, W, F> {
54 fn drop(&mut self) {
55 if let Some(worker) = self.worker.take() {
56 self.pool_ref.release(worker);
57 }
58 }
59}
60
61#[async_trait::async_trait]
62pub trait ClassifiedWorkerFactory<C: WorkerClassification, W: ClassifiedWorker<C>>: Send + Sync + 'static {
63 async fn create(&self, c: Option<C>) -> PoolResult<W>;
64}
65
66struct WaitingItem<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> {
67 future: Notify<PoolResult<ClassifiedWorkerGuard<C, W, F>>>,
68 condition: Option<C>,
69}
70struct WorkerPoolState<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> {
71 current_count: u16,
72 classified_count_map: HashMap<C, u16>,
73 worker_list: Vec<W>,
74 waiting_list: Vec<WaitingItem<C, W, F>>,
75 clear_notify: Option<Notify<()>>,
76}
77
78impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> WorkerPoolState<C, W, F> {
79 fn inc_classified_count(&mut self, c: C) {
80 let count = self.classified_count_map.entry(c).or_insert(0);
81 *count += 1;
82 }
83
84 fn dec_classified_count(&mut self, c: C) {
85 let count = self.classified_count_map.entry(c).or_insert(0);
86 *count -= 1;
87 }
88
89 fn get_classified_count(&self, c: C) -> u16 {
90 *self.classified_count_map.get(&c).unwrap_or(&0)
91 }
92}
93
94pub struct ClassifiedWorkerPool<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> {
95 factory: Arc<F>,
96 max_count: u16,
97 state: Mutex<WorkerPoolState<C, W, F>>,
98}
99pub type ClassifiedWorkerPoolRef<C, W, F> = Arc<ClassifiedWorkerPool<C, W, F>>;
100
101impl<C: WorkerClassification, W: ClassifiedWorker<C>, F: ClassifiedWorkerFactory<C, W>> ClassifiedWorkerPool<C, W, F> {
102 pub fn new(max_count: u16, factory: F) -> ClassifiedWorkerPoolRef<C, W, F> {
103 Arc::new(ClassifiedWorkerPool {
104 factory: Arc::new(factory),
105 max_count,
106 state: Mutex::new(WorkerPoolState {
107 current_count: 0,
108 classified_count_map: HashMap::new(),
109 worker_list: Vec::with_capacity(max_count as usize),
110 waiting_list: Vec::new(),
111 clear_notify: None,
112 }),
113 })
114 }
115
116 pub async fn get_worker(self: &ClassifiedWorkerPoolRef<C, W, F>) -> PoolResult<ClassifiedWorkerGuard<C, W, F>> {
117 let wait = {
118 let mut state = self.state.lock().unwrap();
119 if state.clear_notify.is_some() {
120 return Err(PoolError::new(PoolErrorCode::Failed, "pool is clearing".to_string()));
121 }
122
123 while state.worker_list.len() > 0 {
124 let worker = state.worker_list.pop().unwrap();
125 if !worker.is_work() {
126 state.current_count -= 1;
127 state.dec_classified_count(worker.classification());
128 continue;
129 }
130 return Ok(ClassifiedWorkerGuard::new(worker, self.clone()));
131 }
132
133 if state.current_count < self.max_count {
134 state.current_count += 1;
135 None
136 } else {
137 let (notify, waiter) = Notify::new();
138 state.waiting_list.push(WaitingItem {
139 future: notify,
140 condition: None,
141 });
142 Some(waiter)
143 }
144 };
145
146 if let Some(wait) = wait {
147 wait.await
148 } else {
149 let worker = match self.factory.create(None).await {
150 Ok(worker) => worker,
151 Err(err) => {
152 let mut state = self.state.lock().unwrap();
153 state.current_count -= 1;
154 if state.current_count == 0 && state.clear_notify.is_some() {
155 state.clear_notify.take().unwrap().notify(());
156 }
157 return Err(err)
158 },
159 };
160 let mut state = self.state.lock().unwrap();
161 state.inc_classified_count(worker.classification());
162 Ok(ClassifiedWorkerGuard::new(worker, self.clone()))
163 }
164 }
165
166 pub async fn get_classified_worker(self: &ClassifiedWorkerPoolRef<C, W, F>, classification: C) -> PoolResult<ClassifiedWorkerGuard<C, W, F>> {
167 let wait = {
168 let mut state = self.state.lock().unwrap();
169 if state.clear_notify.is_some() {
170 return Err(PoolError::new(PoolErrorCode::Failed, "pool is clearing".to_string()));
171 }
172
173 let old_count = state.worker_list.len() as u16;
174 let unwork_classification = state.worker_list.iter().filter(|worker| !worker.is_work()).map(|worker| worker.classification()).collect::<Vec<C>>();
175 for classification in unwork_classification.iter() {
176 state.dec_classified_count(classification.clone());
177 }
178 state.worker_list.retain(|worker| worker.is_work());
179 state.current_count -= old_count - state.worker_list.len() as u16;
180 for (index, worker) in state.worker_list.iter().enumerate() {
181 if worker.is_valid(classification.clone()) {
182 let worker = state.worker_list.remove(index);
183 return Ok(ClassifiedWorkerGuard::new(worker, self.clone()));
184 }
185 }
186
187 if state.current_count < self.max_count || state.get_classified_count(classification.clone()) == 0 {
188 state.current_count += 1;
189 None
190 } else {
191 let (notify, waiter) = Notify::new();
192 state.waiting_list.push(WaitingItem {
193 future: notify,
194 condition: Some(classification.clone()),
195 });
196 Some(waiter)
197 }
198 };
199
200 if let Some(wait) = wait {
201 wait.await
202 } else {
203 let worker = match self.factory.create(Some(classification)).await {
204 Ok(worker) => worker,
205 Err(err) => {
206 let mut state = self.state.lock().unwrap();
207 state.current_count -= 1;
208 if state.current_count == 0 && state.clear_notify.is_some() {
209 state.clear_notify.take().unwrap().notify(());
210 }
211 return Err(err)
212 },
213 };
214 let mut state = self.state.lock().unwrap();
215 state.inc_classified_count(worker.classification());
216 Ok(ClassifiedWorkerGuard::new(worker, self.clone()))
217 }
218 }
219
220 pub async fn clear_all_worker(&self) {
221 let waiter = {
222 let mut state = self.state.lock().unwrap();
223 let cur_worker_count = state.worker_list.len();
224 state.worker_list.clear();
225 state.current_count -= cur_worker_count as u16;
226
227 for waiting in state.waiting_list.drain(..) {
228 waiting.future.notify(Err(PoolError::new(PoolErrorCode::Failed, "pool cleared".to_string())));
229 }
230 state.classified_count_map.clear();
231
232 if state.current_count == 0 {
233 return;
234 }
235 let (notify, waiter) = Notify::new();
236 state.clear_notify = Some(notify);
237 waiter
238 };
239 waiter.await;
240 {
241 let mut state = self.state.lock().unwrap();
242 for waiting in state.waiting_list.drain(..) {
243 waiting.future.notify(Err(PoolError::new(PoolErrorCode::Failed, "pool cleared".to_string())));
244 }
245 state.classified_count_map.clear();
246 }
247 }
248
249 fn release(self: &ClassifiedWorkerPoolRef<C, W, F>, work: W) {
250 {
251 let mut state = self.state.lock().unwrap();
252 if state.clear_notify.is_some() {
253 state.current_count -= 1;
254 if state.current_count == 0 {
255 state.clear_notify.take().unwrap().notify(());
256 }
257 return;
258 }
259 }
260 if work.is_work() {
261 let mut state = self.state.lock().unwrap();
262 for (index, waiting) in state.waiting_list.iter().enumerate() {
263 if waiting.condition.is_none() {
264 let waiting_item = state.waiting_list.remove(index);
265 waiting_item.future.notify(Ok(ClassifiedWorkerGuard::new(work, self.clone())));
266 return;
267 } else {
268 if work.is_valid(waiting.condition.as_ref().unwrap().clone()) {
269 let waiting_item = state.waiting_list.remove(index);
270 waiting_item.future.notify(Ok(ClassifiedWorkerGuard::new(work, self.clone())));
271 return;
272 }
273 }
274 }
275 state.worker_list.push(work);
276 } else {
277 let mut state = self.state.lock().unwrap();
278 let classification = work.classification();
279 for (index, waiting) in state.waiting_list.iter().enumerate() {
280 if waiting.condition.is_none() {
281 let waiting_item = state.waiting_list.remove(index);
282 let factory = self.factory.clone();
283 let this = self.clone();
284 let classification = classification.clone();
285 tokio::spawn(async move {
286 match factory.create(Some(classification.clone())).await {
287 Ok(worker) => {
288 waiting_item.future.notify(Ok(ClassifiedWorkerGuard::new(worker, this)));
289 }
290 Err(err) => {
291 let mut state = this.state.lock().unwrap();
292 state.current_count -= 1;
293 state.dec_classified_count(classification);
294 waiting_item.future.notify(Err(err));
295 if state.current_count == 0 && state.clear_notify.is_some() {
296 state.clear_notify.take().unwrap().notify(());
297 }
298 }
299 }
300 });
301 return;
302 } else {
303 if classification == waiting.condition.as_ref().unwrap().clone() {
304 let waiting_item = state.waiting_list.remove(index);
305 let factory = self.factory.clone();
306 let this = self.clone();
307 let classification = classification.clone();
308 tokio::spawn(async move {
309 match factory.create(Some(classification.clone())).await {
310 Ok(worker) => {
311 waiting_item.future.notify(Ok(ClassifiedWorkerGuard::new(worker, this)));
312 }
313 Err(err) => {
314 let mut state = this.state.lock().unwrap();
315 state.current_count -= 1;
316 state.dec_classified_count(classification);
317 waiting_item.future.notify(Err(err));
318 if state.current_count == 0 && state.clear_notify.is_some() {
319 state.clear_notify.take().unwrap().notify(());
320 }
321 }
322 }
323 });
324 return;
325 }
326 }
327 }
328 state.current_count -= 1;
329 state.dec_classified_count(classification);
330 if state.current_count == 0 && state.clear_notify.is_some() {
331 state.clear_notify.take().unwrap().notify(());
332 }
333 }
334 }
335}
336
337#[tokio::test]
338async fn test_pool() {
339 struct TestWorker {
340 work: bool,
341 classification: TestWorkerClassification,
342 }
343
344 #[derive(Clone, Debug, Eq, PartialEq, Hash)]
345 enum TestWorkerClassification {
346 A,
347 B,
348 }
349 #[async_trait::async_trait]
350 impl ClassifiedWorker<TestWorkerClassification> for TestWorker {
351 fn is_work(&self) -> bool {
352 self.work
353 }
354
355 fn is_valid(&self, c: TestWorkerClassification) -> bool {
356 self.classification == c
357 }
358
359 fn classification(&self) -> TestWorkerClassification {
360 self.classification.clone()
361 }
362 }
363
364 struct TestWorkerFactory;
365
366 #[async_trait::async_trait]
367 impl ClassifiedWorkerFactory<TestWorkerClassification, TestWorker> for TestWorkerFactory {
368 async fn create(&self, classification: Option<TestWorkerClassification>) -> PoolResult<TestWorker> {
369 if let Some(classification) = classification {
370 Ok(TestWorker { work: true, classification })
371 } else {
372 Ok(TestWorker { work: true, classification: TestWorkerClassification::A })
373 }
374 }
375 }
376
377 let pool = ClassifiedWorkerPool::new(2, TestWorkerFactory);
378 let pool_ref = pool.clone();
379 tokio::spawn(async move {
380 let _worker = pool_ref.get_worker().await.unwrap();
381 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
382 });
383 let pool_ref = pool.clone();
384 tokio::spawn(async move {
385 let _worker = pool_ref.get_worker().await.unwrap();
386 tokio::time::sleep(std::time::Duration::from_secs(10)).await;
387 });
388
389 let pool_ref = pool.clone();
390 tokio::spawn(async move {
391 let _worker = pool_ref.get_classified_worker(TestWorkerClassification::B).await.unwrap();
392 tokio::time::sleep(std::time::Duration::from_secs(6)).await;
393 });
394
395 let pool_ref = pool.clone();
396 tokio::spawn(async move {
397 tokio::time::sleep(std::time::Duration::from_secs(2)).await;
398
399 let start = std::time::Instant::now();
400 let _worker3 = pool_ref.get_classified_worker(TestWorkerClassification::B).await.unwrap();
401 let end = std::time::Instant::now();
402 let duration = end.duration_since(start);
403 println!("classified duration {}", duration.as_millis());
404 assert!(duration.as_millis() > 2000);
405 });
406
407 let pool_ref = pool.clone();
408 tokio::spawn(async move {
409 tokio::time::sleep(std::time::Duration::from_secs(2)).await;
410
411 let start = std::time::Instant::now();
412 let _worker3 = pool_ref.get_worker().await.unwrap();
413 let end = std::time::Instant::now();
414 let duration = end.duration_since(start);
415 println!("classified duration2 {}", duration.as_millis());
416 assert!(duration.as_millis() > 2000);
417 });
418
419 tokio::time::sleep(std::time::Duration::from_secs(15)).await;
420
421 let pool_ref = pool.clone();
422 tokio::spawn(async move {
423 let _worker = pool_ref.get_worker().await;
424 let _worker1 = pool_ref.get_worker().await;
425 tokio::time::sleep(Duration::from_secs(5)).await;
426 });
427
428 let pool_ref = pool.clone();
429 tokio::spawn(async move {
430 tokio::time::sleep(Duration::from_secs(1)).await;
431 let worker = pool_ref.get_worker().await;
432 assert!(worker.is_err());
433 });
434
435 let pool_ref = pool.clone();
436 tokio::spawn(async move {
437 tokio::time::sleep(Duration::from_secs(2)).await;
438 let worker = pool_ref.get_classified_worker(TestWorkerClassification::B).await;
439 assert!(worker.is_err());
440 });
441
442 let pool_ref = pool.clone();
443 tokio::spawn(async move {
444 let start = std::time::Instant::now();
445 pool_ref.clear_all_worker().await;
446 let end = std::time::Instant::now();
447 let duration = end.duration_since(start);
448 println!("classified duration3 {}", duration.as_millis());
449 assert!(duration.as_millis() > 4000);
450 });
451
452 tokio::time::sleep(Duration::from_secs(10)).await;
453}