1use std::collections::VecDeque;
2use std::ops::{Deref, DerefMut};
3use std::sync::{Arc, Mutex};
4use std::thread::sleep;
5use std::time::Duration;
6use notify_future::{Notify};
7use tokio::runtime::Runtime;
8pub use sfo_result::err as pool_err;
9pub use sfo_result::into_err as into_pool_err;
10
/// Error codes attached to errors raised by the worker pool.
///
/// `Failed` is currently the only code, so every pool error carries it — including
/// the "pool is clearing" and "pool cleared" errors.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum PoolErrorCode {
    /// Generic failure; also the `Default` value.
    #[default]
    Failed,
}
/// Error type returned by pool operations; built with a [`PoolErrorCode`] and a message.
pub type PoolError = sfo_result::Error<PoolErrorCode>;
/// Result alias used by the pool and by [`WorkerFactory::create`].
pub type PoolResult<T> = sfo_result::Result<T, PoolErrorCode>;
18
/// A resource managed by a [`WorkerPool`].
///
/// `Send + 'static` lets pool futures and tasks spawned by the pool own workers.
// NOTE(review): this trait declares no `async fn`, so `#[async_trait]` appears to have
// no effect here; kept as-is because the macro rewrites the trait definition.
#[async_trait::async_trait]
pub trait Worker: Send + 'static {
    /// Reports whether this worker is still usable.
    ///
    /// The pool checks this before handing out an idle worker and when a worker is
    /// released; a worker reporting `false` is dropped and its slot is freed or used
    /// to create a replacement for a queued caller.
    fn is_work(&self) -> bool;
}
23
24pub struct WorkerGuard<W: Worker, F: WorkerFactory<W>> {
25 pool_ref: WorkerPoolRef<W, F>,
26 worker: Option<W>
27}
28
29impl<W: Worker, F: WorkerFactory<W>> WorkerGuard<W, F> {
30 fn new(worker: W, pool_ref: WorkerPoolRef<W, F>) -> Self {
31 WorkerGuard {
32 pool_ref,
33 worker: Some(worker)
34 }
35 }
36}
37
38impl<W: Worker, F: WorkerFactory<W>> Deref for WorkerGuard<W, F> {
39 type Target = W;
40
41 fn deref(&self) -> &Self::Target {
42 self.worker.as_ref().unwrap()
43 }
44}
45
46impl<W: Worker, F: WorkerFactory<W>> DerefMut for WorkerGuard<W, F> {
47 fn deref_mut(&mut self) -> &mut Self::Target {
48 self.worker.as_mut().unwrap()
49 }
50}
51
52impl<W: Worker, F: WorkerFactory<W>> Drop for WorkerGuard<W, F> {
53 fn drop(&mut self) {
54 if let Some(worker) = self.worker.take() {
55 self.pool_ref.release(worker);
56 }
57 }
58}
59
/// Builds new workers for a [`WorkerPool`] on demand.
///
/// The pool keeps the factory in an `Arc` and also calls it from spawned tasks,
/// hence `Send + Sync + 'static`.
#[async_trait::async_trait]
pub trait WorkerFactory<W: Worker>: Send + Sync + 'static {
    /// Creates a fresh worker.
    ///
    /// # Errors
    ///
    /// Any error is passed through to the `get_worker` caller the worker was meant for.
    async fn create(&self) -> PoolResult<W>;
}
64
65struct WorkerPoolState<W: Worker, F: WorkerFactory<W>> {
66 current_count: u16,
67 worker_list: VecDeque<W>,
68 waiting_list: VecDeque<Notify<PoolResult<WorkerGuard<W, F>>>>,
69 clear_notify: Option<Notify<()>>,
70}
71pub struct WorkerPool<W: Worker, F: WorkerFactory<W>> {
72 factory: Arc<F>,
73 max_count: u16,
74 state: Mutex<WorkerPoolState<W, F>>,
75}
76pub type WorkerPoolRef<W, F> = Arc<WorkerPool<W, F>>;
77
78impl<W: Worker, F: WorkerFactory<W>> WorkerPool<W, F> {
79 pub fn new(max_count: u16, factory: F) -> WorkerPoolRef<W, F> {
80 Arc::new(WorkerPool {
81 factory: Arc::new(factory),
82 max_count,
83 state: Mutex::new(WorkerPoolState {
84 current_count: 0,
85 worker_list: VecDeque::with_capacity(max_count as usize),
86 waiting_list: VecDeque::new(),
87 clear_notify: None,
88 }),
89 })
90 }
91
92 pub async fn get_worker(self: &WorkerPoolRef<W, F>) -> PoolResult<WorkerGuard<W, F>> {
93 let wait = {
94 let mut state = self.state.lock().unwrap();
95 if state.clear_notify.is_some() {
96 return Err(PoolError::new(PoolErrorCode::Failed, "pool is clearing".to_string()));
97 }
98
99 while state.worker_list.len() > 0 {
100 let worker = state.worker_list.pop_front().unwrap();
101 if !worker.is_work() {
102 state.current_count -= 1;
103 continue;
104 }
105 return Ok(WorkerGuard::new(worker, self.clone()));
106 }
107
108 if state.current_count < self.max_count {
109 state.current_count += 1;
110 None
111 } else {
112 let (notify, waiter) = Notify::new();
113 state.waiting_list.push_back(notify);
114 Some(waiter)
115 }
116 };
117
118 if let Some(wait) = wait {
119 wait.await
120 } else {
121 let worker = match self.factory.create().await {
122 Ok(worker) => worker,
123 Err(err) => {
124 let mut state = self.state.lock().unwrap();
125 state.current_count -= 1;
126 if state.current_count == 0 && state.clear_notify.is_some() {
127 state.clear_notify.take().unwrap().notify(());
128 }
129 return Err(err)
130 },
131 };
132 Ok(WorkerGuard::new(worker, self.clone()))
133 }
134 }
135
136 pub async fn clear_all_worker(&self) {
137 let waiter = {
138 let mut state = self.state.lock().unwrap();
139 let cur_worker_count = state.worker_list.len();
140 state.worker_list.clear();
141 state.current_count -= cur_worker_count as u16;
142
143 for waiting in state.waiting_list.drain(..) {
144 waiting.notify(Err(PoolError::new(PoolErrorCode::Failed, "pool cleared".to_string())));
145 }
146
147 if state.current_count == 0 {
148 return;
149 }
150
151 let (notify, waiter) = Notify::new();
152 state.clear_notify = Some(notify);
153 waiter
154 };
155 waiter.await;
156 {
157 let mut state = self.state.lock().unwrap();
158 for waiting in state.waiting_list.drain(..) {
159 waiting.notify(Err(PoolError::new(PoolErrorCode::Failed, "pool cleared".to_string())));
160 }
161 }
162 }
163
164 fn release(self: &WorkerPoolRef<W, F>, work: W) {
165 {
166 let mut state = self.state.lock().unwrap();
167 if state.clear_notify.is_some() {
168 state.current_count -= 1;
169 if state.current_count == 0 {
170 state.clear_notify.take().unwrap().notify(());
171 }
172 return;
173 }
174 }
175 if work.is_work() {
176 let mut state = self.state.lock().unwrap();
177 let future = state.waiting_list.pop_front();
178 if let Some(future) = future {
179 future.notify(Ok(WorkerGuard::new(work, self.clone())));
180 } else {
181 state.worker_list.push_back(work);
182 }
183 } else {
184 let mut state = self.state.lock().unwrap();
185 let future = state.waiting_list.pop_front();
186 if let Some(future) = future {
187 let factory = self.factory.clone();
188 let this = self.clone();
189 tokio::spawn(async move {
190 match factory.create().await {
191 Ok(worker) => {
192 future.notify(Ok(WorkerGuard::new(worker, this)));
193 }
194 Err(err) => {
195 let mut state = this.state.lock().unwrap();
196 state.current_count -= 1;
197 future.notify(Err(err));
198
199 if state.current_count == 0 && state.clear_notify.is_some() {
200 state.clear_notify.take().unwrap().notify(());
201 }
202 }
203 }
204 });
205 } else {
206 state.current_count -= 1;
207 if state.current_count == 0 && state.clear_notify.is_some() {
208 state.clear_notify.take().unwrap().notify(());
209 }
210 }
211 }
212 }
213}
214
/// End-to-end check of queueing and clearing behaviour using real timers.
///
/// Assertions run inside spawned tasks, so every `JoinHandle` is awaited: a panic in a
/// task then fails this test instead of being silently swallowed by Tokio.
#[test]
fn test_pool() {
    struct TestWorker {
        work: bool,
    }

    #[async_trait::async_trait]
    impl Worker for TestWorker {
        fn is_work(&self) -> bool {
            self.work
        }
    }

    struct TestWorkerFactory;

    #[async_trait::async_trait]
    impl WorkerFactory<TestWorker> for TestWorkerFactory {
        async fn create(&self) -> PoolResult<TestWorker> {
            Ok(TestWorker { work: true })
        }
    }

    let pool = WorkerPool::new(2, TestWorkerFactory);
    let rt = Runtime::new().unwrap();

    // Phase 1: two holders exhaust the pool; a third caller must wait for a release.
    let pool_ref = pool.clone();
    let holder1 = rt.spawn(async move {
        let _worker = pool_ref.get_worker().await;
        tokio::time::sleep(Duration::from_secs(5)).await;
    });
    let pool_ref = pool.clone();
    let holder2 = rt.spawn(async move {
        let _worker = pool_ref.get_worker().await;
        tokio::time::sleep(Duration::from_secs(10)).await;
    });
    let pool_ref = pool.clone();
    let waiter = rt.spawn(async move {
        tokio::time::sleep(Duration::from_secs(2)).await;

        let start = std::time::Instant::now();
        let worker3 = pool_ref.get_worker().await;
        let duration = start.elapsed();
        println!("duration {}", duration.as_millis());
        assert!(worker3.is_ok());
        assert!(duration.as_millis() > 2000);
    });

    // Join instead of sleeping a fixed 10 s: that sleep raced with `holder2`'s 10 s hold,
    // and joining is what makes panics in the tasks fail the test.
    rt.block_on(async {
        holder1.await.unwrap();
        holder2.await.unwrap();
        waiter.await.unwrap();
    });

    // Phase 2: a holder checks out both idle workers; a clear must wait for it while
    // new callers are rejected.
    let pool_ref = pool.clone();
    let holder = rt.spawn(async move {
        let _worker = pool_ref.get_worker().await;
        let _worker1 = pool_ref.get_worker().await;
        tokio::time::sleep(Duration::from_secs(5)).await;
    });

    let pool_ref = pool.clone();
    let rejected1 = rt.spawn(async move {
        tokio::time::sleep(Duration::from_secs(1)).await;
        let worker = pool_ref.get_worker().await;
        assert!(worker.is_err());
    });

    let pool_ref = pool.clone();
    let rejected2 = rt.spawn(async move {
        tokio::time::sleep(Duration::from_secs(2)).await;
        let worker = pool_ref.get_worker().await;
        assert!(worker.is_err());
    });

    let pool_ref = pool.clone();
    let clearer = rt.spawn(async move {
        // Let `holder` check out both workers first; otherwise the clear could find them
        // idle, finish immediately and fail the timing assertion.
        tokio::time::sleep(Duration::from_millis(500)).await;

        let start = std::time::Instant::now();
        pool_ref.clear_all_worker().await;
        let duration = start.elapsed();
        println!("duration1 {}", duration.as_millis());
        assert!(duration.as_millis() > 4000);
    });

    rt.block_on(async {
        holder.await.unwrap();
        rejected1.await.unwrap();
        rejected2.await.unwrap();
        clearer.await.unwrap();
    });
}