Skip to main content

sfo_pool/
worker_pool.rs

1use notify_future::Notify;
2pub use sfo_result::err as pool_err;
3pub use sfo_result::into_err as into_pool_err;
4use std::collections::VecDeque;
5use std::ops::{Deref, DerefMut};
6use std::sync::{Arc, Mutex};
7
/// Machine-readable category carried by every pool error.
///
/// `Hash` is derived alongside the equality traits so a code can be used as a
/// key in hash-based collections (for example when counting errors per code).
#[derive(Debug, Copy, Clone, Default, Eq, PartialEq, Hash)]
pub enum PoolErrorCode {
    /// Unclassified failure; this is the `Default` value.
    #[default]
    Failed,
    /// The pool is currently being cleared and refuses new requests.
    Clearing,
    /// The request was pending when the pool was cleared and has been dropped.
    Cleared,
    /// The pool was configured with parameters it cannot operate with.
    InvalidConfig,
}
16pub type PoolError = sfo_result::Error<PoolErrorCode>;
17pub type PoolResult<T> = sfo_result::Result<T, PoolErrorCode>;
18
19pub(crate) fn pool_error(code: PoolErrorCode, message: &str) -> PoolError {
20    PoolError::new(code, message.to_string())
21}
22
23pub(crate) fn pool_clearing_error() -> PoolError {
24    pool_error(PoolErrorCode::Clearing, "pool is clearing")
25}
26
27pub(crate) fn pool_cleared_error() -> PoolError {
28    pool_error(PoolErrorCode::Cleared, "pool cleared")
29}
30
31pub(crate) fn pool_invalid_config_error(message: &str) -> PoolError {
32    pool_error(PoolErrorCode::InvalidConfig, message)
33}
34
35#[async_trait::async_trait]
36pub trait Worker: Send + 'static {
37    fn is_work(&self) -> bool;
38}
39
40pub struct WorkerGuard<W: Worker, F: WorkerFactory<W>> {
41    pool_ref: WorkerPoolRef<W, F>,
42    worker: Option<W>,
43}
44
45impl<W: Worker, F: WorkerFactory<W>> WorkerGuard<W, F> {
46    fn new(worker: W, pool_ref: WorkerPoolRef<W, F>) -> Self {
47        WorkerGuard {
48            pool_ref,
49            worker: Some(worker),
50        }
51    }
52}
53
54impl<W: Worker, F: WorkerFactory<W>> Deref for WorkerGuard<W, F> {
55    type Target = W;
56
57    fn deref(&self) -> &Self::Target {
58        self.worker.as_ref().unwrap()
59    }
60}
61
62impl<W: Worker, F: WorkerFactory<W>> DerefMut for WorkerGuard<W, F> {
63    fn deref_mut(&mut self) -> &mut Self::Target {
64        self.worker.as_mut().unwrap()
65    }
66}
67
68impl<W: Worker, F: WorkerFactory<W>> Drop for WorkerGuard<W, F> {
69    fn drop(&mut self) {
70        if let Some(worker) = self.worker.take() {
71            self.pool_ref.release(worker);
72        }
73    }
74}
75
76#[async_trait::async_trait]
77pub trait WorkerFactory<W: Worker>: Send + Sync + 'static {
78    async fn create(&self) -> PoolResult<W>;
79}
80
81struct WorkerPoolState<W: Worker, F: WorkerFactory<W>> {
82    current_count: u16,
83    worker_list: VecDeque<W>,
84    waiting_list: VecDeque<Notify<PoolResult<WorkerGuard<W, F>>>>,
85    clearing: bool,
86    clear_waiting_list: Vec<Notify<()>>,
87}
88
89impl<W: Worker, F: WorkerFactory<W>> WorkerPoolState<W, F> {
90    fn take_clear_waiters_if_done(&mut self) -> Vec<Notify<()>> {
91        if self.clearing && self.current_count == 0 {
92            self.clearing = false;
93            self.clear_waiting_list.drain(..).collect()
94        } else {
95            Vec::new()
96        }
97    }
98}
99pub struct WorkerPool<W: Worker, F: WorkerFactory<W>> {
100    factory: Arc<F>,
101    max_count: u16,
102    state: Mutex<WorkerPoolState<W, F>>,
103}
104pub type WorkerPoolRef<W, F> = Arc<WorkerPool<W, F>>;
105
106impl<W: Worker, F: WorkerFactory<W>> WorkerPool<W, F> {
107    pub fn new(max_count: u16, factory: F) -> WorkerPoolRef<W, F> {
108        Arc::new(WorkerPool {
109            factory: Arc::new(factory),
110            max_count,
111            state: Mutex::new(WorkerPoolState {
112                current_count: 0,
113                worker_list: VecDeque::with_capacity(max_count as usize),
114                waiting_list: VecDeque::new(),
115                clearing: false,
116                clear_waiting_list: Vec::new(),
117            }),
118        })
119    }
120
121    pub async fn get_worker(self: &WorkerPoolRef<W, F>) -> PoolResult<WorkerGuard<W, F>> {
122        if self.max_count == 0 {
123            return Err(pool_invalid_config_error("pool max_count is zero"));
124        }
125
126        let wait = {
127            let mut state = self.state.lock().unwrap();
128            if state.clearing {
129                return Err(pool_clearing_error());
130            }
131
132            while state.worker_list.len() > 0 {
133                let worker = state.worker_list.pop_front().unwrap();
134                if !worker.is_work() {
135                    state.current_count -= 1;
136                    continue;
137                }
138                return Ok(WorkerGuard::new(worker, self.clone()));
139            }
140
141            if state.current_count < self.max_count {
142                state.current_count += 1;
143                None
144            } else {
145                let (notify, waiter) = Notify::new();
146                state.waiting_list.push_back(notify);
147                Some(waiter)
148            }
149        };
150
151        if let Some(wait) = wait {
152            wait.await
153        } else {
154            let worker = match self.factory.create().await {
155                Ok(worker) => worker,
156                Err(err) => {
157                    let mut state = self.state.lock().unwrap();
158                    state.current_count -= 1;
159                    let clear_waiters = state.take_clear_waiters_if_done();
160                    drop(state);
161                    for waiter in clear_waiters {
162                        waiter.notify(());
163                    }
164                    return Err(err);
165                }
166            };
167            let (clearing, clear_waiters) = {
168                let mut state = self.state.lock().unwrap();
169                if state.clearing {
170                    state.current_count -= 1;
171                    (true, state.take_clear_waiters_if_done())
172                } else {
173                    (false, Vec::new())
174                }
175            };
176            for waiter in clear_waiters {
177                waiter.notify(());
178            }
179            if clearing {
180                return Err(pool_cleared_error());
181            }
182            Ok(WorkerGuard::new(worker, self.clone()))
183        }
184    }
185
186    pub async fn clear_all_worker(&self) {
187        let (waiter, waiting_list, clear_waiters) = {
188            let mut state = self.state.lock().unwrap();
189            if !state.clearing {
190                state.clearing = true;
191                let cur_worker_count = state.worker_list.len();
192                state.worker_list.clear();
193                state.current_count -= cur_worker_count as u16;
194            }
195
196            let waiting_list = state.waiting_list.drain(..).collect::<Vec<_>>();
197            if state.current_count == 0 {
198                let clear_waiters = state.take_clear_waiters_if_done();
199                (None, waiting_list, clear_waiters)
200            } else {
201                let (notify, waiter) = Notify::new();
202                state.clear_waiting_list.push(notify);
203                (Some(waiter), waiting_list, Vec::new())
204            }
205        };
206        for waiting in waiting_list {
207            waiting.notify(Err(pool_cleared_error()));
208        }
209        for waiter in clear_waiters {
210            waiter.notify(());
211        }
212        if let Some(waiter) = waiter {
213            waiter.await;
214        }
215    }
216
217    fn release(self: &WorkerPoolRef<W, F>, work: W) {
218        enum ReleaseAction<W: Worker, F: WorkerFactory<W>> {
219            None,
220            Notify(Notify<PoolResult<WorkerGuard<W, F>>>, WorkerGuard<W, F>),
221            Replace(Notify<PoolResult<WorkerGuard<W, F>>>),
222        }
223
224        let mut clear_waiters = Vec::new();
225        let action = {
226            let mut state = self.state.lock().unwrap();
227            if state.clearing {
228                state.current_count -= 1;
229                clear_waiters = state.take_clear_waiters_if_done();
230                ReleaseAction::None
231            } else if work.is_work() {
232                let future = state.waiting_list.pop_front();
233                if let Some(future) = future {
234                    ReleaseAction::Notify(future, WorkerGuard::new(work, self.clone()))
235                } else {
236                    state.worker_list.push_back(work);
237                    ReleaseAction::None
238                }
239            } else {
240                let future = state.waiting_list.pop_front();
241                if let Some(future) = future {
242                    ReleaseAction::Replace(future)
243                } else {
244                    state.current_count -= 1;
245                    clear_waiters = state.take_clear_waiters_if_done();
246                    ReleaseAction::None
247                }
248            }
249        };
250
251        for waiter in clear_waiters {
252            waiter.notify(());
253        }
254
255        match action {
256            ReleaseAction::None => {}
257            ReleaseAction::Notify(future, worker) => {
258                future.notify(Ok(worker));
259            }
260            ReleaseAction::Replace(future) => {
261                let factory = self.factory.clone();
262                let this = self.clone();
263                tokio::spawn(async move {
264                    let result = match factory.create().await {
265                        Ok(worker) => {
266                            let (clearing, clear_waiters) = {
267                                let mut state = this.state.lock().unwrap();
268                                if state.clearing {
269                                    state.current_count -= 1;
270                                    (true, state.take_clear_waiters_if_done())
271                                } else {
272                                    (false, Vec::new())
273                                }
274                            };
275                            for waiter in clear_waiters {
276                                waiter.notify(());
277                            }
278                            if clearing {
279                                Err(pool_cleared_error())
280                            } else {
281                                Ok(WorkerGuard::new(worker, this))
282                            }
283                        }
284                        Err(err) => {
285                            let mut state = this.state.lock().unwrap();
286                            state.current_count -= 1;
287                            let clear_waiters = state.take_clear_waiters_if_done();
288                            drop(state);
289                            for waiter in clear_waiters {
290                                waiter.notify(());
291                            }
292                            Err(err)
293                        }
294                    };
295                    future.notify(result);
296                });
297            }
298        }
299    }
300}
301
/// End-to-end timing test: waiting on a full pool, then clearing a busy pool.
///
/// Every spawned task is joined, so a failed assertion inside a task fails the
/// test instead of being lost with a dropped `JoinHandle`.
#[test]
fn test_pool() {
    use std::time::{Duration, Instant};

    struct TestWorker {
        work: bool,
    }

    #[async_trait::async_trait]
    impl Worker for TestWorker {
        fn is_work(&self) -> bool {
            self.work
        }
    }

    struct TestWorkerFactory;

    #[async_trait::async_trait]
    impl WorkerFactory<TestWorker> for TestWorkerFactory {
        async fn create(&self) -> PoolResult<TestWorker> {
            Ok(TestWorker { work: true })
        }
    }

    let pool = WorkerPool::new(2, TestWorkerFactory);
    let rt = tokio::runtime::Runtime::new().unwrap();

    // Phase 1: both workers are taken, so a third request has to wait until
    // the first holder lets go after five seconds.
    let pool_ref = pool.clone();
    let holder_short = rt.spawn(async move {
        let _worker = pool_ref.get_worker().await.unwrap();
        tokio::time::sleep(Duration::from_secs(5)).await;
    });
    let pool_ref = pool.clone();
    let holder_long = rt.spawn(async move {
        let _worker = pool_ref.get_worker().await.unwrap();
        tokio::time::sleep(Duration::from_secs(10)).await;
    });

    let pool_ref = pool.clone();
    let late_request = rt.spawn(async move {
        tokio::time::sleep(Duration::from_secs(2)).await;

        let start = Instant::now();
        let _worker3 = pool_ref.get_worker().await.unwrap();
        let duration = start.elapsed();
        println!("duration {}", duration.as_millis());
        assert!(duration.as_millis() > 2000);
    });

    rt.block_on(async {
        holder_short.await.unwrap();
        holder_long.await.unwrap();
        late_request.await.unwrap();
    });

    // Phase 2: both workers are acquired *before* the clear starts, so the
    // clear is guaranteed to find the pool busy and wait for the holder.
    let (first, second) = rt.block_on(async {
        let first = pool.get_worker().await.unwrap();
        let second = pool.get_worker().await.unwrap();
        (first, second)
    });
    let holder = rt.spawn(async move {
        let _held = (first, second);
        tokio::time::sleep(Duration::from_secs(5)).await;
    });

    let pool_ref = pool.clone();
    let clear = rt.spawn(async move {
        let start = Instant::now();
        pool_ref.clear_all_worker().await;
        let duration = start.elapsed();
        println!("duration1 {}", duration.as_millis());
        assert!(duration.as_millis() > 4000);
    });

    // Requests arriving while the clear is pending must be rejected.
    let pool_ref = pool.clone();
    let rejected_early = rt.spawn(async move {
        tokio::time::sleep(Duration::from_secs(1)).await;
        let worker = pool_ref.get_worker().await;
        assert!(worker.is_err());
    });

    let pool_ref = pool.clone();
    let rejected_late = rt.spawn(async move {
        tokio::time::sleep(Duration::from_secs(2)).await;
        let worker = pool_ref.get_worker().await;
        assert!(worker.is_err());
    });

    rt.block_on(async {
        holder.await.unwrap();
        clear.await.unwrap();
        rejected_early.await.unwrap();
        rejected_late.await.unwrap();
    });
}
384
385#[tokio::test]
386async fn test_clear_all_worker_waits_for_inflight_create() {
387    use std::sync::atomic::{AtomicUsize, Ordering};
388    use std::sync::Arc;
389
390    struct TestWorker;
391
392    #[async_trait::async_trait]
393    impl Worker for TestWorker {
394        fn is_work(&self) -> bool {
395            true
396        }
397    }
398
399    struct TestWorkerFactory {
400        create_count: Arc<AtomicUsize>,
401    }
402
403    #[async_trait::async_trait]
404    impl WorkerFactory<TestWorker> for TestWorkerFactory {
405        async fn create(&self) -> PoolResult<TestWorker> {
406            self.create_count.fetch_add(1, Ordering::SeqCst);
407            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
408            Ok(TestWorker)
409        }
410    }
411
412    let create_count = Arc::new(AtomicUsize::new(0));
413    let pool = WorkerPool::new(
414        1,
415        TestWorkerFactory {
416            create_count: create_count.clone(),
417        },
418    );
419
420    let pool_ref = pool.clone();
421    let worker_task = tokio::spawn(async move { pool_ref.get_worker().await });
422    tokio::time::sleep(std::time::Duration::from_millis(20)).await;
423
424    pool.clear_all_worker().await;
425
426    let worker = worker_task.await.unwrap();
427    assert!(worker.is_err());
428    assert_eq!(create_count.load(Ordering::SeqCst), 1);
429}
430
431#[tokio::test]
432async fn test_concurrent_clear_all_worker() {
433    struct TestWorker;
434
435    #[async_trait::async_trait]
436    impl Worker for TestWorker {
437        fn is_work(&self) -> bool {
438            true
439        }
440    }
441
442    struct TestWorkerFactory;
443
444    #[async_trait::async_trait]
445    impl WorkerFactory<TestWorker> for TestWorkerFactory {
446        async fn create(&self) -> PoolResult<TestWorker> {
447            Ok(TestWorker)
448        }
449    }
450
451    let pool = WorkerPool::new(1, TestWorkerFactory);
452    let worker = pool.get_worker().await.unwrap();
453
454    let pool_ref = pool.clone();
455    let clear_task1 = tokio::spawn(async move {
456        pool_ref.clear_all_worker().await;
457    });
458
459    let pool_ref = pool.clone();
460    let clear_task2 = tokio::spawn(async move {
461        pool_ref.clear_all_worker().await;
462    });
463
464    tokio::time::sleep(std::time::Duration::from_millis(20)).await;
465    drop(worker);
466
467    tokio::time::timeout(std::time::Duration::from_secs(1), async {
468        clear_task1.await.unwrap();
469        clear_task2.await.unwrap();
470    })
471    .await
472    .unwrap();
473}
474
475#[tokio::test]
476async fn test_zero_max_count_returns_error() {
477    struct TestWorker;
478
479    #[async_trait::async_trait]
480    impl Worker for TestWorker {
481        fn is_work(&self) -> bool {
482            true
483        }
484    }
485
486    struct TestWorkerFactory;
487
488    #[async_trait::async_trait]
489    impl WorkerFactory<TestWorker> for TestWorkerFactory {
490        async fn create(&self) -> PoolResult<TestWorker> {
491            Ok(TestWorker)
492        }
493    }
494
495    let pool = WorkerPool::new(0, TestWorkerFactory);
496    let worker = pool.get_worker().await;
497    assert!(worker.is_err());
498    assert_eq!(worker.err().unwrap().code(), PoolErrorCode::InvalidConfig);
499}
500
501#[tokio::test]
502async fn test_clearing_and_cleared_error_codes() {
503    use std::sync::atomic::{AtomicBool, Ordering};
504    use std::sync::Arc;
505
506    struct TestWorker;
507
508    #[async_trait::async_trait]
509    impl Worker for TestWorker {
510        fn is_work(&self) -> bool {
511            true
512        }
513    }
514
515    struct TestWorkerFactory {
516        should_block: Arc<AtomicBool>,
517    }
518
519    #[async_trait::async_trait]
520    impl WorkerFactory<TestWorker> for TestWorkerFactory {
521        async fn create(&self) -> PoolResult<TestWorker> {
522            while self.should_block.load(Ordering::SeqCst) {
523                tokio::task::yield_now().await;
524            }
525            Ok(TestWorker)
526        }
527    }
528
529    let should_block = Arc::new(AtomicBool::new(true));
530    let pool = WorkerPool::new(
531        1,
532        TestWorkerFactory {
533            should_block: should_block.clone(),
534        },
535    );
536
537    let pool_ref = pool.clone();
538    let inflight = tokio::spawn(async move { pool_ref.get_worker().await });
539    tokio::task::yield_now().await;
540
541    let pool_ref = pool.clone();
542    let clear_task = tokio::spawn(async move {
543        pool_ref.clear_all_worker().await;
544    });
545    tokio::task::yield_now().await;
546
547    let err = pool.get_worker().await.err().unwrap();
548    assert_eq!(err.code(), PoolErrorCode::Clearing);
549
550    should_block.store(false, Ordering::SeqCst);
551    clear_task.await.unwrap();
552
553    let err = inflight.await.unwrap().err().unwrap();
554    assert_eq!(err.code(), PoolErrorCode::Cleared);
555}