sfo_pool/
worker_pool.rs

1use std::collections::VecDeque;
2use std::ops::{Deref, DerefMut};
3use std::sync::{Arc, Mutex};
4use std::thread::sleep;
5use std::time::Duration;
6use notify_future::NotifyFuture;
7use tokio::runtime::Runtime;
8pub use sfo_result::err as pool_err;
9pub use sfo_result::into_err as into_pool_err;
10
/// Error categories reported by the worker pool.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum PoolErrorCode {
    /// Generic failure; this is also the value produced by `PoolErrorCode::default()`.
    #[default]
    Failed,
}
16pub type PoolError = sfo_result::Error<PoolErrorCode>;
17pub type PoolResult<T> = sfo_result::Result<T, PoolErrorCode>;
18
/// A value that can be checked out of a [`WorkerPool`] and returned to it.
///
/// The trait has no async methods, so it is declared as a plain trait; the
/// `async_trait` attribute previously attached to it had nothing to rewrite.
pub trait Worker: Send + 'static {
    /// Reports whether this worker is still usable.
    ///
    /// The pool calls this when handing out an idle worker and when a worker
    /// is released; a worker that returns `false` is discarded instead of
    /// being reused.
    fn is_work(&self) -> bool;
}
23
24pub struct WorkerGuard<W: Worker, F: WorkerFactory<W>> {
25    pool_ref: WorkerPoolRef<W, F>,
26    worker: Option<W>
27}
28
29impl<W: Worker, F: WorkerFactory<W>> WorkerGuard<W, F> {
30    fn new(worker: W, pool_ref: WorkerPoolRef<W, F>) -> Self {
31        WorkerGuard {
32            pool_ref,
33            worker: Some(worker)
34        }
35    }
36}
37
38impl<W: Worker, F: WorkerFactory<W>> Deref for WorkerGuard<W, F> {
39    type Target = W;
40
41    fn deref(&self) -> &Self::Target {
42        self.worker.as_ref().unwrap()
43    }
44}
45
46impl<W: Worker, F: WorkerFactory<W>> DerefMut for WorkerGuard<W, F> {
47    fn deref_mut(&mut self) -> &mut Self::Target {
48        self.worker.as_mut().unwrap()
49    }
50}
51
52impl<W: Worker, F: WorkerFactory<W>> Drop for WorkerGuard<W, F> {
53    fn drop(&mut self) {
54        if let Some(worker) = self.worker.take() {
55            self.pool_ref.release(worker);
56        }
57    }
58}
59
60#[async_trait::async_trait]
61pub trait WorkerFactory<W: Worker>: Send + Sync + 'static {
62    async fn create(&self) -> PoolResult<W>;
63}
64
65struct WorkerPoolState<W: Worker, F: WorkerFactory<W>> {
66    current_count: u16,
67    worker_list: VecDeque<W>,
68    waiting_list: VecDeque<NotifyFuture<PoolResult<WorkerGuard<W, F>>>>,
69}
70pub struct WorkerPool<W: Worker, F: WorkerFactory<W>> {
71    factory: Arc<F>,
72    max_count: u16,
73    state: Mutex<WorkerPoolState<W, F>>,
74}
75pub type WorkerPoolRef<W, F> = Arc<WorkerPool<W, F>>;
76
77impl<W: Worker, F: WorkerFactory<W>> WorkerPool<W, F> {
78    pub fn new(max_count: u16, factory: F) -> WorkerPoolRef<W, F> {
79        Arc::new(WorkerPool {
80            factory: Arc::new(factory),
81            max_count,
82            state: Mutex::new(WorkerPoolState {
83                current_count: 0,
84                worker_list: VecDeque::with_capacity(max_count as usize),
85                waiting_list: VecDeque::new(),
86            }),
87        })
88    }
89
90    pub async fn get_worker(self: &WorkerPoolRef<W, F>) -> PoolResult<WorkerGuard<W, F>> {
91        let wait = {
92            let mut state = self.state.lock().unwrap();
93
94            while state.worker_list.len() > 0 {
95                let worker = state.worker_list.pop_front().unwrap();
96                if !worker.is_work() {
97                    state.current_count -= 1;
98                    continue;
99                }
100                return Ok(WorkerGuard::new(worker, self.clone()));
101            }
102
103            if state.current_count < self.max_count {
104                state.current_count += 1;
105                None
106            } else {
107                let future = NotifyFuture::new();
108                state.waiting_list.push_back(future.clone());
109                Some(future)
110            }
111        };
112
113        if let Some(wait) = wait {
114            wait.await
115        } else {
116            let worker = match self.factory.create().await {
117                Ok(worker) => worker,
118                Err(err) => {
119                    let mut state = self.state.lock().unwrap();
120                    state.current_count -= 1;
121                    return Err(err)
122                },
123            };
124            Ok(WorkerGuard::new(worker, self.clone()))
125        }
126    }
127
128    fn release(self: &WorkerPoolRef<W, F>, work: W) {
129        if work.is_work() {
130            let mut state = self.state.lock().unwrap();
131            let future = state.waiting_list.pop_front();
132            if let Some(future) = future {
133                future.set_complete(Ok(WorkerGuard::new(work, self.clone())));
134            } else {
135                state.worker_list.push_back(work);
136            }
137        } else {
138            let mut state = self.state.lock().unwrap();
139            let future = state.waiting_list.pop_front();
140            if let Some(future) = future {
141                let rt = Runtime::new().unwrap();
142                let factory = self.factory.clone();
143                let this = self.clone();
144                rt.spawn(async move {
145                    match factory.create().await {
146                        Ok(worker) => {
147                            future.set_complete(Ok(WorkerGuard::new(worker, this)));
148                        }
149                        Err(err) => {
150                            let mut state = this.state.lock().unwrap();
151                            state.current_count -= 1;
152                            future.set_complete(Err(err));
153                        }
154                    }
155                });
156            } else {
157                state.current_count -= 1;
158            }
159        }
160    }
161}
162
/// With a pool of two workers, a third caller must wait until one of the first
/// two guards is dropped.
///
/// The tasks are joined and the timing is asserted on the test thread: a panic
/// inside a detached task whose join handle is dropped would not fail the test.
#[test]
fn test_pool() {
    struct TestWorker {
        work: bool,
    }

    impl Worker for TestWorker {
        fn is_work(&self) -> bool {
            self.work
        }
    }

    struct TestWorkerFactory;

    #[async_trait::async_trait]
    impl WorkerFactory<TestWorker> for TestWorkerFactory {
        async fn create(&self) -> PoolResult<TestWorker> {
            Ok(TestWorker { work: true })
        }
    }

    let pool = WorkerPool::new(2, TestWorkerFactory);
    let rt = Runtime::new().unwrap();

    // First holder: keeps its worker for 500 ms.
    let pool_ref = pool.clone();
    let first = rt.spawn(async move {
        let worker = pool_ref.get_worker().await;
        assert!(worker.is_ok());
        tokio::time::sleep(Duration::from_millis(500)).await;
    });

    // Second holder: keeps its worker for 1 s, so only the first one frees up early.
    let pool_ref = pool.clone();
    let second = rt.spawn(async move {
        let worker = pool_ref.get_worker().await;
        assert!(worker.is_ok());
        tokio::time::sleep(Duration::from_millis(1000)).await;
    });

    // Third caller arrives after both workers are taken and measures its wait.
    let pool_ref = pool.clone();
    let third = rt.spawn(async move {
        tokio::time::sleep(Duration::from_millis(100)).await;

        let start = std::time::Instant::now();
        let worker = pool_ref.get_worker().await;
        let waited = start.elapsed();
        assert!(worker.is_ok());
        waited
    });

    // Joining surfaces any panic raised inside the tasks.
    let waited = rt.block_on(third).unwrap();
    rt.block_on(first).unwrap();
    rt.block_on(second).unwrap();

    println!("duration {}", waited.as_millis());
    // Expected wait is about 400 ms (first worker released at ~500 ms, request
    // made at ~100 ms); without blocking it would be close to zero.
    assert!(waited >= Duration::from_millis(200));
}
209
210    sleep(Duration::from_secs(10));
211}