1use notify_future::Notify;
2pub use sfo_result::err as pool_err;
3pub use sfo_result::into_err as into_pool_err;
4use std::collections::VecDeque;
5use std::ops::{Deref, DerefMut};
6use std::sync::{Arc, Mutex};
7
/// Error codes carried by [`PoolError`].
#[derive(Debug, Default)]
#[derive(Copy, Clone)]
#[derive(Eq, PartialEq)]
pub enum PoolErrorCode {
    /// The default code. The pool logic in this file never emits it itself;
    /// presumably it is meant for generic factory failures.
    #[default]
    Failed,
    /// `get_worker` was called while `clear_all_worker` is in progress.
    Clearing,
    /// A worker request was abandoned because the pool was cleared.
    Cleared,
    /// The pool configuration is unusable (e.g. `max_count == 0`).
    InvalidConfig,
}
16pub type PoolError = sfo_result::Error<PoolErrorCode>;
17pub type PoolResult<T> = sfo_result::Result<T, PoolErrorCode>;
18
19pub(crate) fn pool_error(code: PoolErrorCode, message: &str) -> PoolError {
20 PoolError::new(code, message.to_string())
21}
22
23pub(crate) fn pool_clearing_error() -> PoolError {
24 pool_error(PoolErrorCode::Clearing, "pool is clearing")
25}
26
27pub(crate) fn pool_cleared_error() -> PoolError {
28 pool_error(PoolErrorCode::Cleared, "pool cleared")
29}
30
31pub(crate) fn pool_invalid_config_error(message: &str) -> PoolError {
32 pool_error(PoolErrorCode::InvalidConfig, message)
33}
34
/// A pooled resource handed out by [`WorkerPool`].
// NOTE(review): this trait has no `async fn`, so `#[async_trait]` has nothing
// to rewrite here; it appears to be kept only for symmetry with implementors.
#[async_trait::async_trait]
pub trait Worker: Send + 'static {
    /// Health check. The pool discards a worker for which this returns `false`,
    /// both when reusing an idle worker and when one is released.
    fn is_work(&self) -> bool;
}
39
40pub struct WorkerGuard<W: Worker, F: WorkerFactory<W>> {
41 pool_ref: WorkerPoolRef<W, F>,
42 worker: Option<W>,
43}
44
45impl<W: Worker, F: WorkerFactory<W>> WorkerGuard<W, F> {
46 fn new(worker: W, pool_ref: WorkerPoolRef<W, F>) -> Self {
47 WorkerGuard {
48 pool_ref,
49 worker: Some(worker),
50 }
51 }
52}
53
54impl<W: Worker, F: WorkerFactory<W>> Deref for WorkerGuard<W, F> {
55 type Target = W;
56
57 fn deref(&self) -> &Self::Target {
58 self.worker.as_ref().unwrap()
59 }
60}
61
62impl<W: Worker, F: WorkerFactory<W>> DerefMut for WorkerGuard<W, F> {
63 fn deref_mut(&mut self) -> &mut Self::Target {
64 self.worker.as_mut().unwrap()
65 }
66}
67
68impl<W: Worker, F: WorkerFactory<W>> Drop for WorkerGuard<W, F> {
69 fn drop(&mut self) {
70 if let Some(worker) = self.worker.take() {
71 self.pool_ref.release(worker);
72 }
73 }
74}
75
/// Builds new workers on demand for a [`WorkerPool`].
///
/// `create` is called by `get_worker` when there is no healthy idle worker and
/// the pool is below `max_count`. It is also called from a spawned Tokio task
/// to replace a broken worker when a caller is waiting for one.
#[async_trait::async_trait]
pub trait WorkerFactory<W: Worker>: Send + Sync + 'static {
    /// Creates a fresh worker. An `Err` is passed through to the caller that
    /// requested the worker.
    async fn create(&self) -> PoolResult<W>;
}
80
81struct WorkerPoolState<W: Worker, F: WorkerFactory<W>> {
82 current_count: u16,
83 worker_list: VecDeque<W>,
84 waiting_list: VecDeque<Notify<PoolResult<WorkerGuard<W, F>>>>,
85 clearing: bool,
86 clear_waiting_list: Vec<Notify<()>>,
87}
88
89impl<W: Worker, F: WorkerFactory<W>> WorkerPoolState<W, F> {
90 fn take_clear_waiters_if_done(&mut self) -> Vec<Notify<()>> {
91 if self.clearing && self.current_count == 0 {
92 self.clearing = false;
93 self.clear_waiting_list.drain(..).collect()
94 } else {
95 Vec::new()
96 }
97 }
98}
99pub struct WorkerPool<W: Worker, F: WorkerFactory<W>> {
100 factory: Arc<F>,
101 max_count: u16,
102 state: Mutex<WorkerPoolState<W, F>>,
103}
104pub type WorkerPoolRef<W, F> = Arc<WorkerPool<W, F>>;
105
106impl<W: Worker, F: WorkerFactory<W>> WorkerPool<W, F> {
107 pub fn new(max_count: u16, factory: F) -> WorkerPoolRef<W, F> {
108 Arc::new(WorkerPool {
109 factory: Arc::new(factory),
110 max_count,
111 state: Mutex::new(WorkerPoolState {
112 current_count: 0,
113 worker_list: VecDeque::with_capacity(max_count as usize),
114 waiting_list: VecDeque::new(),
115 clearing: false,
116 clear_waiting_list: Vec::new(),
117 }),
118 })
119 }
120
121 pub async fn get_worker(self: &WorkerPoolRef<W, F>) -> PoolResult<WorkerGuard<W, F>> {
122 if self.max_count == 0 {
123 return Err(pool_invalid_config_error("pool max_count is zero"));
124 }
125
126 let wait = {
127 let mut state = self.state.lock().unwrap();
128 if state.clearing {
129 return Err(pool_clearing_error());
130 }
131
132 while state.worker_list.len() > 0 {
133 let worker = state.worker_list.pop_front().unwrap();
134 if !worker.is_work() {
135 state.current_count -= 1;
136 continue;
137 }
138 return Ok(WorkerGuard::new(worker, self.clone()));
139 }
140
141 if state.current_count < self.max_count {
142 state.current_count += 1;
143 None
144 } else {
145 let (notify, waiter) = Notify::new();
146 state.waiting_list.push_back(notify);
147 Some(waiter)
148 }
149 };
150
151 if let Some(wait) = wait {
152 wait.await
153 } else {
154 let worker = match self.factory.create().await {
155 Ok(worker) => worker,
156 Err(err) => {
157 let mut state = self.state.lock().unwrap();
158 state.current_count -= 1;
159 let clear_waiters = state.take_clear_waiters_if_done();
160 drop(state);
161 for waiter in clear_waiters {
162 waiter.notify(());
163 }
164 return Err(err);
165 }
166 };
167 let (clearing, clear_waiters) = {
168 let mut state = self.state.lock().unwrap();
169 if state.clearing {
170 state.current_count -= 1;
171 (true, state.take_clear_waiters_if_done())
172 } else {
173 (false, Vec::new())
174 }
175 };
176 for waiter in clear_waiters {
177 waiter.notify(());
178 }
179 if clearing {
180 return Err(pool_cleared_error());
181 }
182 Ok(WorkerGuard::new(worker, self.clone()))
183 }
184 }
185
186 pub async fn clear_all_worker(&self) {
187 let (waiter, waiting_list, clear_waiters) = {
188 let mut state = self.state.lock().unwrap();
189 if !state.clearing {
190 state.clearing = true;
191 let cur_worker_count = state.worker_list.len();
192 state.worker_list.clear();
193 state.current_count -= cur_worker_count as u16;
194 }
195
196 let waiting_list = state.waiting_list.drain(..).collect::<Vec<_>>();
197 if state.current_count == 0 {
198 let clear_waiters = state.take_clear_waiters_if_done();
199 (None, waiting_list, clear_waiters)
200 } else {
201 let (notify, waiter) = Notify::new();
202 state.clear_waiting_list.push(notify);
203 (Some(waiter), waiting_list, Vec::new())
204 }
205 };
206 for waiting in waiting_list {
207 waiting.notify(Err(pool_cleared_error()));
208 }
209 for waiter in clear_waiters {
210 waiter.notify(());
211 }
212 if let Some(waiter) = waiter {
213 waiter.await;
214 }
215 }
216
217 fn release(self: &WorkerPoolRef<W, F>, work: W) {
218 enum ReleaseAction<W: Worker, F: WorkerFactory<W>> {
219 None,
220 Notify(Notify<PoolResult<WorkerGuard<W, F>>>, WorkerGuard<W, F>),
221 Replace(Notify<PoolResult<WorkerGuard<W, F>>>),
222 }
223
224 let mut clear_waiters = Vec::new();
225 let action = {
226 let mut state = self.state.lock().unwrap();
227 if state.clearing {
228 state.current_count -= 1;
229 clear_waiters = state.take_clear_waiters_if_done();
230 ReleaseAction::None
231 } else if work.is_work() {
232 let future = state.waiting_list.pop_front();
233 if let Some(future) = future {
234 ReleaseAction::Notify(future, WorkerGuard::new(work, self.clone()))
235 } else {
236 state.worker_list.push_back(work);
237 ReleaseAction::None
238 }
239 } else {
240 let future = state.waiting_list.pop_front();
241 if let Some(future) = future {
242 ReleaseAction::Replace(future)
243 } else {
244 state.current_count -= 1;
245 clear_waiters = state.take_clear_waiters_if_done();
246 ReleaseAction::None
247 }
248 }
249 };
250
251 for waiter in clear_waiters {
252 waiter.notify(());
253 }
254
255 match action {
256 ReleaseAction::None => {}
257 ReleaseAction::Notify(future, worker) => {
258 future.notify(Ok(worker));
259 }
260 ReleaseAction::Replace(future) => {
261 let factory = self.factory.clone();
262 let this = self.clone();
263 tokio::spawn(async move {
264 let result = match factory.create().await {
265 Ok(worker) => {
266 let (clearing, clear_waiters) = {
267 let mut state = this.state.lock().unwrap();
268 if state.clearing {
269 state.current_count -= 1;
270 (true, state.take_clear_waiters_if_done())
271 } else {
272 (false, Vec::new())
273 }
274 };
275 for waiter in clear_waiters {
276 waiter.notify(());
277 }
278 if clearing {
279 Err(pool_cleared_error())
280 } else {
281 Ok(WorkerGuard::new(worker, this))
282 }
283 }
284 Err(err) => {
285 let mut state = this.state.lock().unwrap();
286 state.current_count -= 1;
287 let clear_waiters = state.take_clear_waiters_if_done();
288 drop(state);
289 for waiter in clear_waiters {
290 waiter.notify(());
291 }
292 Err(err)
293 }
294 };
295 future.notify(result);
296 });
297 }
298 }
299 }
300}
301
/// End-to-end timing test of `get_worker` blocking and `clear_all_worker`.
///
/// Phase 1 (2-worker pool):
/// * Two tasks hold both slots, one for 5s and one for 10s.
/// * A third task asks for a worker at t=2s and is expected to wait more than
///   2s for the first slot to free up.
///
/// Phase 2 (after the main thread sleeps 10s):
/// * One task takes two workers and holds them for 5s.
/// * Two tasks ask for a worker 1s and 2s later and expect errors.
/// * A clear task expects `clear_all_worker` to take more than 4s.
// NOTE(review): all asserts run inside `rt.spawn` tasks whose `JoinHandle`s are
// dropped. By default Tokio does not propagate panics from such tasks, so a
// failing assert here likely does not fail the test — confirm.
// NOTE(review): the 10s `thread::sleep` races the 10s worker hold from
// phase 1. Phase 2 also races the "two workers" task against the clear task.
// The expected durations are therefore not guaranteed.
#[test]
fn test_pool() {
    struct TestWorker {
        work: bool,
    }

    #[async_trait::async_trait]
    impl Worker for TestWorker {
        fn is_work(&self) -> bool {
            self.work
        }
    }

    // Always produces healthy workers.
    struct TestWorkerFactory;

    #[async_trait::async_trait]
    impl WorkerFactory<TestWorker> for TestWorkerFactory {
        async fn create(&self) -> PoolResult<TestWorker> {
            Ok(TestWorker { work: true })
        }
    }

    let pool = WorkerPool::new(2, TestWorkerFactory);
    let rt = tokio::runtime::Runtime::new().unwrap();
    // Phase 1: occupy both slots (released after ~5s and ~10s).
    let pool_ref = pool.clone();
    rt.spawn(async move {
        let _worker = pool_ref.get_worker().await;
        tokio::time::sleep(std::time::Duration::from_secs(5)).await;
    });
    let pool_ref = pool.clone();
    rt.spawn(async move {
        let _worker = pool_ref.get_worker().await;
        tokio::time::sleep(std::time::Duration::from_secs(10)).await;
    });

    // Asks at t=2s while the pool is full; should block until the 5s holder releases.
    let pool_ref = pool.clone();
    rt.spawn(async move {
        tokio::time::sleep(std::time::Duration::from_secs(2)).await;

        let start = std::time::Instant::now();
        let _worker3 = pool_ref.get_worker().await;
        let end = std::time::Instant::now();
        let duration = end.duration_since(start);
        println!("duration {}", duration.as_millis());
        assert!(duration.as_millis() > 2000);
    });

    std::thread::sleep(std::time::Duration::from_secs(10));

    // Phase 2: hold two workers for 5s so the clear below has to wait.
    let pool_ref = pool.clone();
    rt.spawn(async move {
        let _worker = pool_ref.get_worker().await;
        let _worker1 = pool_ref.get_worker().await;
        tokio::time::sleep(std::time::Duration::from_secs(5)).await;
    });

    // Requests during the clear are expected to fail.
    let pool_ref = pool.clone();
    rt.spawn(async move {
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;
        let worker = pool_ref.get_worker().await;
        assert!(worker.is_err());
    });

    let pool_ref = pool.clone();
    rt.spawn(async move {
        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
        let worker = pool_ref.get_worker().await;
        assert!(worker.is_err());
    });

    // The clear should wait for the 5s holder to release both workers.
    let pool_ref = pool.clone();
    rt.spawn(async move {
        let start = std::time::Instant::now();
        pool_ref.clear_all_worker().await;
        let end = std::time::Instant::now();
        let duration = end.duration_since(start);
        println!("duration1 {}", duration.as_millis());
        assert!(duration.as_millis() > 4000);
    });

    std::thread::sleep(std::time::Duration::from_secs(10));
}
384
385#[tokio::test]
386async fn test_clear_all_worker_waits_for_inflight_create() {
387 use std::sync::atomic::{AtomicUsize, Ordering};
388 use std::sync::Arc;
389
390 struct TestWorker;
391
392 #[async_trait::async_trait]
393 impl Worker for TestWorker {
394 fn is_work(&self) -> bool {
395 true
396 }
397 }
398
399 struct TestWorkerFactory {
400 create_count: Arc<AtomicUsize>,
401 }
402
403 #[async_trait::async_trait]
404 impl WorkerFactory<TestWorker> for TestWorkerFactory {
405 async fn create(&self) -> PoolResult<TestWorker> {
406 self.create_count.fetch_add(1, Ordering::SeqCst);
407 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
408 Ok(TestWorker)
409 }
410 }
411
412 let create_count = Arc::new(AtomicUsize::new(0));
413 let pool = WorkerPool::new(
414 1,
415 TestWorkerFactory {
416 create_count: create_count.clone(),
417 },
418 );
419
420 let pool_ref = pool.clone();
421 let worker_task = tokio::spawn(async move { pool_ref.get_worker().await });
422 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
423
424 pool.clear_all_worker().await;
425
426 let worker = worker_task.await.unwrap();
427 assert!(worker.is_err());
428 assert_eq!(create_count.load(Ordering::SeqCst), 1);
429}
430
431#[tokio::test]
432async fn test_concurrent_clear_all_worker() {
433 struct TestWorker;
434
435 #[async_trait::async_trait]
436 impl Worker for TestWorker {
437 fn is_work(&self) -> bool {
438 true
439 }
440 }
441
442 struct TestWorkerFactory;
443
444 #[async_trait::async_trait]
445 impl WorkerFactory<TestWorker> for TestWorkerFactory {
446 async fn create(&self) -> PoolResult<TestWorker> {
447 Ok(TestWorker)
448 }
449 }
450
451 let pool = WorkerPool::new(1, TestWorkerFactory);
452 let worker = pool.get_worker().await.unwrap();
453
454 let pool_ref = pool.clone();
455 let clear_task1 = tokio::spawn(async move {
456 pool_ref.clear_all_worker().await;
457 });
458
459 let pool_ref = pool.clone();
460 let clear_task2 = tokio::spawn(async move {
461 pool_ref.clear_all_worker().await;
462 });
463
464 tokio::time::sleep(std::time::Duration::from_millis(20)).await;
465 drop(worker);
466
467 tokio::time::timeout(std::time::Duration::from_secs(1), async {
468 clear_task1.await.unwrap();
469 clear_task2.await.unwrap();
470 })
471 .await
472 .unwrap();
473}
474
475#[tokio::test]
476async fn test_zero_max_count_returns_error() {
477 struct TestWorker;
478
479 #[async_trait::async_trait]
480 impl Worker for TestWorker {
481 fn is_work(&self) -> bool {
482 true
483 }
484 }
485
486 struct TestWorkerFactory;
487
488 #[async_trait::async_trait]
489 impl WorkerFactory<TestWorker> for TestWorkerFactory {
490 async fn create(&self) -> PoolResult<TestWorker> {
491 Ok(TestWorker)
492 }
493 }
494
495 let pool = WorkerPool::new(0, TestWorkerFactory);
496 let worker = pool.get_worker().await;
497 assert!(worker.is_err());
498 assert_eq!(worker.err().unwrap().code(), PoolErrorCode::InvalidConfig);
499}
500
/// Checks both clear-related error codes.
///
/// * `Clearing` is returned to a request made while a clear is in progress.
/// * `Cleared` is returned to a request whose worker creation finishes after
///   the clear began.
///
/// The factory spins (yielding) while `should_block` is set. This keeps one
/// creation in flight, so `clear_all_worker` has to wait for it.
// NOTE(review): the ordering relies on `yield_now` letting each spawned task
// run up to its first suspension point. That presumably depends on the default
// current-thread flavor of `#[tokio::test]` — confirm before changing flavor.
#[tokio::test]
async fn test_clearing_and_cleared_error_codes() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    struct TestWorker;

    #[async_trait::async_trait]
    impl Worker for TestWorker {
        fn is_work(&self) -> bool {
            true
        }
    }

    struct TestWorkerFactory {
        should_block: Arc<AtomicBool>,
    }

    #[async_trait::async_trait]
    impl WorkerFactory<TestWorker> for TestWorkerFactory {
        async fn create(&self) -> PoolResult<TestWorker> {
            // Busy-yield until the test releases the gate.
            while self.should_block.load(Ordering::SeqCst) {
                tokio::task::yield_now().await;
            }
            Ok(TestWorker)
        }
    }

    let should_block = Arc::new(AtomicBool::new(true));
    let pool = WorkerPool::new(
        1,
        TestWorkerFactory {
            should_block: should_block.clone(),
        },
    );

    // Start a request that reserves the only slot and blocks inside `create`.
    let pool_ref = pool.clone();
    let inflight = tokio::spawn(async move { pool_ref.get_worker().await });
    tokio::task::yield_now().await;

    // Start a clear; it must wait for the in-flight creation.
    let pool_ref = pool.clone();
    let clear_task = tokio::spawn(async move {
        pool_ref.clear_all_worker().await;
    });
    tokio::task::yield_now().await;

    // A request made during the clear is rejected immediately.
    let err = pool.get_worker().await.err().unwrap();
    assert_eq!(err.code(), PoolErrorCode::Clearing);

    // Let the creation finish; the clear can now complete.
    should_block.store(false, Ordering::SeqCst);
    clear_task.await.unwrap();

    // The in-flight request sees that the pool was cleared under it.
    let err = inflight.await.unwrap().err().unwrap();
    assert_eq!(err.code(), PoolErrorCode::Cleared);
}