1use std::collections::VecDeque;
2use std::ops::{Deref, DerefMut};
3use std::sync::{Arc, Mutex};
4use std::thread::sleep;
5use std::time::Duration;
6use notify_future::NotifyFuture;
7use tokio::runtime::Runtime;
8pub use sfo_result::err as pool_err;
9pub use sfo_result::into_err as into_pool_err;
10
/// Error codes attached to pool errors.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum PoolErrorCode {
    /// Generic failure; also the `Default` code.
    #[default]
    Failed,
}
16pub type PoolError = sfo_result::Error<PoolErrorCode>;
17pub type PoolResult<T> = sfo_result::Result<T, PoolErrorCode>;
18
19#[async_trait::async_trait]
20pub trait Worker: Send + 'static {
21 fn is_work(&self) -> bool;
22}
23
24pub struct WorkerGuard<W: Worker, F: WorkerFactory<W>> {
25 pool_ref: WorkerPoolRef<W, F>,
26 worker: Option<W>
27}
28
29impl<W: Worker, F: WorkerFactory<W>> WorkerGuard<W, F> {
30 fn new(worker: W, pool_ref: WorkerPoolRef<W, F>) -> Self {
31 WorkerGuard {
32 pool_ref,
33 worker: Some(worker)
34 }
35 }
36}
37
38impl<W: Worker, F: WorkerFactory<W>> Deref for WorkerGuard<W, F> {
39 type Target = W;
40
41 fn deref(&self) -> &Self::Target {
42 self.worker.as_ref().unwrap()
43 }
44}
45
46impl<W: Worker, F: WorkerFactory<W>> DerefMut for WorkerGuard<W, F> {
47 fn deref_mut(&mut self) -> &mut Self::Target {
48 self.worker.as_mut().unwrap()
49 }
50}
51
52impl<W: Worker, F: WorkerFactory<W>> Drop for WorkerGuard<W, F> {
53 fn drop(&mut self) {
54 if let Some(worker) = self.worker.take() {
55 self.pool_ref.release(worker);
56 }
57 }
58}
59
60#[async_trait::async_trait]
61pub trait WorkerFactory<W: Worker>: Send + Sync + 'static {
62 async fn create(&self) -> PoolResult<W>;
63}
64
65struct WorkerPoolState<W: Worker, F: WorkerFactory<W>> {
66 current_count: u16,
67 worker_list: VecDeque<W>,
68 waiting_list: VecDeque<NotifyFuture<PoolResult<WorkerGuard<W, F>>>>,
69}
70pub struct WorkerPool<W: Worker, F: WorkerFactory<W>> {
71 factory: Arc<F>,
72 max_count: u16,
73 state: Mutex<WorkerPoolState<W, F>>,
74}
75pub type WorkerPoolRef<W, F> = Arc<WorkerPool<W, F>>;
76
77impl<W: Worker, F: WorkerFactory<W>> WorkerPool<W, F> {
78 pub fn new(max_count: u16, factory: F) -> WorkerPoolRef<W, F> {
79 Arc::new(WorkerPool {
80 factory: Arc::new(factory),
81 max_count,
82 state: Mutex::new(WorkerPoolState {
83 current_count: 0,
84 worker_list: VecDeque::with_capacity(max_count as usize),
85 waiting_list: VecDeque::new(),
86 }),
87 })
88 }
89
90 pub async fn get_worker(self: &WorkerPoolRef<W, F>) -> PoolResult<WorkerGuard<W, F>> {
91 let wait = {
92 let mut state = self.state.lock().unwrap();
93
94 while state.worker_list.len() > 0 {
95 let worker = state.worker_list.pop_front().unwrap();
96 if !worker.is_work() {
97 state.current_count -= 1;
98 continue;
99 }
100 return Ok(WorkerGuard::new(worker, self.clone()));
101 }
102
103 if state.current_count < self.max_count {
104 state.current_count += 1;
105 None
106 } else {
107 let future = NotifyFuture::new();
108 state.waiting_list.push_back(future.clone());
109 Some(future)
110 }
111 };
112
113 if let Some(wait) = wait {
114 wait.await
115 } else {
116 let worker = match self.factory.create().await {
117 Ok(worker) => worker,
118 Err(err) => {
119 let mut state = self.state.lock().unwrap();
120 state.current_count -= 1;
121 return Err(err)
122 },
123 };
124 Ok(WorkerGuard::new(worker, self.clone()))
125 }
126 }
127
128 fn release(self: &WorkerPoolRef<W, F>, work: W) {
129 if work.is_work() {
130 let mut state = self.state.lock().unwrap();
131 let future = state.waiting_list.pop_front();
132 if let Some(future) = future {
133 future.set_complete(Ok(WorkerGuard::new(work, self.clone())));
134 } else {
135 state.worker_list.push_back(work);
136 }
137 } else {
138 let mut state = self.state.lock().unwrap();
139 let future = state.waiting_list.pop_front();
140 if let Some(future) = future {
141 let rt = Runtime::new().unwrap();
142 let factory = self.factory.clone();
143 let this = self.clone();
144 rt.spawn(async move {
145 match factory.create().await {
146 Ok(worker) => {
147 future.set_complete(Ok(WorkerGuard::new(worker, this)));
148 }
149 Err(err) => {
150 let mut state = this.state.lock().unwrap();
151 state.current_count -= 1;
152 future.set_complete(Err(err));
153 }
154 }
155 });
156 } else {
157 state.current_count -= 1;
158 }
159 }
160 }
161}
162
/// With capacity 2, two callers hold workers for different durations. A third caller
/// that arrives later must wait until the first worker is returned.
#[test]
fn test_pool() {
    struct TestWorker {
        work: bool,
    }

    #[async_trait::async_trait]
    impl Worker for TestWorker {
        fn is_work(&self) -> bool {
            self.work
        }
    }

    struct TestWorkerFactory;

    #[async_trait::async_trait]
    impl WorkerFactory<TestWorker> for TestWorkerFactory {
        async fn create(&self) -> PoolResult<TestWorker> {
            Ok(TestWorker { work: true })
        }
    }

    let pool = WorkerPool::new(2, TestWorkerFactory);
    let rt = Runtime::new().unwrap();

    let pool_ref = pool.clone();
    let first = rt.spawn(async move {
        let worker = pool_ref.get_worker().await;
        assert!(worker.is_ok());
        tokio::time::sleep(Duration::from_millis(1000)).await;
    });

    let pool_ref = pool.clone();
    let second = rt.spawn(async move {
        let worker = pool_ref.get_worker().await;
        assert!(worker.is_ok());
        tokio::time::sleep(Duration::from_millis(1500)).await;
    });

    let pool_ref = pool.clone();
    let third = rt.spawn(async move {
        tokio::time::sleep(Duration::from_millis(400)).await;

        let start = std::time::Instant::now();
        let worker3 = pool_ref.get_worker().await;
        let duration = start.elapsed();
        println!("duration {}", duration.as_millis());
        assert!(worker3.is_ok());
        // The pool is full until `first` returns its worker (~600 ms from now).
        assert!(duration.as_millis() > 300);
    });

    // Joining the handles propagates panics (failed asserts) out of the spawned tasks;
    // fire-and-forget spawns would silently swallow them.
    rt.block_on(async {
        first.await.unwrap();
        second.await.unwrap();
        third.await.unwrap();
    });
}