// sfo_pool/worker_pool.rs

1use std::collections::VecDeque;
2use std::ops::{Deref, DerefMut};
3use std::sync::{Arc, Mutex};
4use std::thread::sleep;
5use std::time::Duration;
6use notify_future::{Notify};
7use tokio::runtime::Runtime;
8pub use sfo_result::err as pool_err;
9pub use sfo_result::into_err as into_pool_err;
10
/// Error codes reported by the worker pool.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum PoolErrorCode {
    /// Generic failure; this is also the `Default` code.
    #[default]
    Failed,
}
16pub type PoolError = sfo_result::Error<PoolErrorCode>;
17pub type PoolResult<T> = sfo_result::Result<T, PoolErrorCode>;
18
/// A pooled resource that can report whether it is still usable.
///
/// The trait has no async methods, so it needs no `async_trait` attribute.
/// Implementations may still carry `#[async_trait]` on their `impl` blocks;
/// the macro is a no-op there.
///
/// When a worker's [`Worker::is_work`] returns `false`, the pool discards it
/// instead of handing it out again.
pub trait Worker: Send + 'static {
    /// Returns `true` while this worker can still be used.
    fn is_work(&self) -> bool;
}
23
24pub struct WorkerGuard<W: Worker, F: WorkerFactory<W>> {
25    pool_ref: WorkerPoolRef<W, F>,
26    worker: Option<W>
27}
28
29impl<W: Worker, F: WorkerFactory<W>> WorkerGuard<W, F> {
30    fn new(worker: W, pool_ref: WorkerPoolRef<W, F>) -> Self {
31        WorkerGuard {
32            pool_ref,
33            worker: Some(worker)
34        }
35    }
36}
37
38impl<W: Worker, F: WorkerFactory<W>> Deref for WorkerGuard<W, F> {
39    type Target = W;
40
41    fn deref(&self) -> &Self::Target {
42        self.worker.as_ref().unwrap()
43    }
44}
45
46impl<W: Worker, F: WorkerFactory<W>> DerefMut for WorkerGuard<W, F> {
47    fn deref_mut(&mut self) -> &mut Self::Target {
48        self.worker.as_mut().unwrap()
49    }
50}
51
52impl<W: Worker, F: WorkerFactory<W>> Drop for WorkerGuard<W, F> {
53    fn drop(&mut self) {
54        if let Some(worker) = self.worker.take() {
55            self.pool_ref.release(worker);
56        }
57    }
58}
59
60#[async_trait::async_trait]
61pub trait WorkerFactory<W: Worker>: Send + Sync + 'static {
62    async fn create(&self) -> PoolResult<W>;
63}
64
65struct WorkerPoolState<W: Worker, F: WorkerFactory<W>> {
66    current_count: u16,
67    worker_list: VecDeque<W>,
68    waiting_list: VecDeque<Notify<PoolResult<WorkerGuard<W, F>>>>,
69    clear_notify: Option<Notify<()>>,
70}
71pub struct WorkerPool<W: Worker, F: WorkerFactory<W>> {
72    factory: Arc<F>,
73    max_count: u16,
74    state: Mutex<WorkerPoolState<W, F>>,
75}
76pub type WorkerPoolRef<W, F> = Arc<WorkerPool<W, F>>;
77
78impl<W: Worker, F: WorkerFactory<W>> WorkerPool<W, F> {
79    pub fn new(max_count: u16, factory: F) -> WorkerPoolRef<W, F> {
80        Arc::new(WorkerPool {
81            factory: Arc::new(factory),
82            max_count,
83            state: Mutex::new(WorkerPoolState {
84                current_count: 0,
85                worker_list: VecDeque::with_capacity(max_count as usize),
86                waiting_list: VecDeque::new(),
87                clear_notify: None,
88            }),
89        })
90    }
91
92    pub async fn get_worker(self: &WorkerPoolRef<W, F>) -> PoolResult<WorkerGuard<W, F>> {
93        let wait = {
94            let mut state = self.state.lock().unwrap();
95            if state.clear_notify.is_some() {
96                return Err(PoolError::new(PoolErrorCode::Failed, "pool is clearing".to_string()));
97            }
98
99            while state.worker_list.len() > 0 {
100                let worker = state.worker_list.pop_front().unwrap();
101                if !worker.is_work() {
102                    state.current_count -= 1;
103                    continue;
104                }
105                return Ok(WorkerGuard::new(worker, self.clone()));
106            }
107
108            if state.current_count < self.max_count {
109                state.current_count += 1;
110                None
111            } else {
112                let (notify, waiter) = Notify::new();
113                state.waiting_list.push_back(notify);
114                Some(waiter)
115            }
116        };
117
118        if let Some(wait) = wait {
119            wait.await
120        } else {
121            let worker = match self.factory.create().await {
122                Ok(worker) => worker,
123                Err(err) => {
124                    let mut state = self.state.lock().unwrap();
125                    state.current_count -= 1;
126                    if state.current_count == 0 && state.clear_notify.is_some() {
127                        state.clear_notify.take().unwrap().notify(());
128                    }
129                    return Err(err)
130                },
131            };
132            Ok(WorkerGuard::new(worker, self.clone()))
133        }
134    }
135
136    pub async fn clear_all_worker(&self) {
137        let waiter = {
138            let mut state = self.state.lock().unwrap();
139            let cur_worker_count = state.worker_list.len();
140            state.worker_list.clear();
141            state.current_count -= cur_worker_count as u16;
142
143            for waiting in state.waiting_list.drain(..) {
144                waiting.notify(Err(PoolError::new(PoolErrorCode::Failed, "pool cleared".to_string())));
145            }
146
147            if state.current_count == 0 {
148                return;
149            }
150
151            let (notify, waiter) = Notify::new();
152            state.clear_notify = Some(notify);
153            waiter
154        };
155        waiter.await;
156        {
157            let mut state = self.state.lock().unwrap();
158            for waiting in state.waiting_list.drain(..) {
159                waiting.notify(Err(PoolError::new(PoolErrorCode::Failed, "pool cleared".to_string())));
160            }
161        }
162    }
163
164    fn release(self: &WorkerPoolRef<W, F>, work: W) {
165        {
166            let mut state = self.state.lock().unwrap();
167            if state.clear_notify.is_some() {
168                state.current_count -= 1;
169                if state.current_count == 0 {
170                    state.clear_notify.take().unwrap().notify(());
171                }
172                return;
173            }
174        }
175        if work.is_work() {
176            let mut state = self.state.lock().unwrap();
177            let future = state.waiting_list.pop_front();
178            if let Some(future) = future {
179                future.notify(Ok(WorkerGuard::new(work, self.clone())));
180            } else {
181                state.worker_list.push_back(work);
182            }
183        } else {
184            let mut state = self.state.lock().unwrap();
185            let future = state.waiting_list.pop_front();
186            if let Some(future) = future {
187                let factory = self.factory.clone();
188                let this = self.clone();
189                tokio::spawn(async move {
190                    match factory.create().await {
191                        Ok(worker) => {
192                            future.notify(Ok(WorkerGuard::new(worker, this)));
193                        }
194                        Err(err) => {
195                            let mut state = this.state.lock().unwrap();
196                            state.current_count -= 1;
197                            future.notify(Err(err));
198
199                            if state.current_count == 0 && state.clear_notify.is_some() {
200                                state.clear_notify.take().unwrap().notify(());
201                            }
202                        }
203                    }
204                });
205            } else {
206                state.current_count -= 1;
207                if state.current_count == 0 && state.clear_notify.is_some() {
208                    state.clear_notify.take().unwrap().notify(());
209                }
210            }
211        }
212    }
213}
214
/// End-to-end check of the pool. It covers waiting for a busy slot, and
/// clearing while workers are checked out.
///
/// Every assertion runs inside `block_on`, so a failure fails the test. Panics
/// inside detached `rt.spawn` tasks are swallowed by the runtime. Sequencing uses
/// `join!` and short sleeps instead of racing multi-second wall-clock sleeps.
#[test]
fn test_pool() {
    struct TestWorker {
        work: bool,
    }

    #[async_trait::async_trait]
    impl Worker for TestWorker {
        fn is_work(&self) -> bool {
            self.work
        }
    }

    struct TestWorkerFactory;

    #[async_trait::async_trait]
    impl WorkerFactory<TestWorker> for TestWorkerFactory {
        async fn create(&self) -> PoolResult<TestWorker> {
            Ok(TestWorker { work: true })
        }
    }

    let pool = WorkerPool::new(2, TestWorkerFactory);
    let rt = Runtime::new().unwrap();

    // With both slots busy, a third caller waits until one worker is released.
    rt.block_on(async {
        let first = pool.get_worker().await.unwrap();
        let second = pool.get_worker().await.unwrap();

        let start = std::time::Instant::now();
        let pool_ref = pool.clone();
        let third = tokio::spawn(async move {
            let got = pool_ref.get_worker().await;
            (got.is_ok(), start.elapsed())
        });

        tokio::time::sleep(Duration::from_millis(200)).await;
        drop(first);

        let (ok, waited) = third.await.unwrap();
        println!("duration {}", waited.as_millis());
        assert!(ok);
        assert!(waited >= Duration::from_millis(200));
        drop(second);
    });

    // Clearing fails queued callers, rejects new callers while in progress,
    // and completes only after every checked-out worker has come back.
    rt.block_on(async {
        let first = pool.get_worker().await.unwrap();
        let second = pool.get_worker().await.unwrap();

        let start = std::time::Instant::now();
        // Both slots are busy, so this call queues.
        let queued = pool.get_worker();
        let clear = async {
            // Let `queued` enqueue before the clear starts.
            tokio::task::yield_now().await;
            pool.clear_all_worker().await;
            start.elapsed()
        };
        let holder = async {
            // Let the clear start before probing the pool and releasing workers.
            tokio::task::yield_now().await;
            tokio::task::yield_now().await;
            assert!(pool.get_worker().await.is_err());
            tokio::time::sleep(Duration::from_millis(200)).await;
            drop(first);
            drop(second);
        };
        let (queued, cleared_after, ()) = tokio::join!(queued, clear, holder);
        println!("duration1 {}", cleared_after.as_millis());
        assert!(queued.is_err());
        assert!(cleared_after >= Duration::from_millis(200));

        // Once the clear has finished, the pool hands out workers again.
        assert!(pool.get_worker().await.is_ok());
    });
}