1use parking_lot::Mutex;
2use std::collections::VecDeque;
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
7use std::time::{Duration, Instant};
8use tokio::sync::Semaphore;
9
10use crate::error::{Result, RpcError};
11
/// Tuning knobs for a [`ConnectionPool`].
///
/// Build one with [`PoolConfig::default`] (or `PoolConfig::new`) and the `with_*` setters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PoolConfig {
    /// Number of items the pool creates eagerly when it is constructed.
    pub min_connections: usize,
    /// Maximum number of items that may be checked out at the same time
    /// (the pool's semaphore is sized with this value).
    pub max_connections: usize,
    /// How long `ConnectionPool::get` waits for a free slot before failing with a timeout.
    pub acquire_timeout: Duration,
    /// Maximum time an item may sit unused in the pool before it should be discarded
    /// rather than reused.
    pub idle_timeout: Duration,
    /// Intended interval between health checks of pooled items.
    /// NOTE(review): nothing in this module reads this value — confirm whether a health
    /// checker exists elsewhere.
    pub health_check_interval: Duration,
}
19
20impl Default for PoolConfig {
21 fn default() -> Self {
22 Self {
23 min_connections: 1,
24 max_connections: 10,
25 acquire_timeout: Duration::from_secs(30),
26 idle_timeout: Duration::from_secs(300),
27 health_check_interval: Duration::from_secs(30),
28 }
29 }
30}
31
32impl PoolConfig {
33 pub fn new() -> Self {
34 Self::default()
35 }
36
37 pub fn with_min_connections(mut self, min: usize) -> Self {
38 self.min_connections = min;
39 self
40 }
41
42 pub fn with_max_connections(mut self, max: usize) -> Self {
43 self.max_connections = max;
44 self
45 }
46
47 pub fn with_acquire_timeout(mut self, timeout: Duration) -> Self {
48 self.acquire_timeout = timeout;
49 self
50 }
51
52 pub fn with_idle_timeout(mut self, timeout: Duration) -> Self {
53 self.idle_timeout = timeout;
54 self
55 }
56}
57
/// An item parked in the pool together with the instant it was last used
/// (set when it is wrapped, refreshed by `touch`).
struct PooledItem<T> {
    item: T,
    last_used: Instant,
}

// Not every helper here is guaranteed a caller (`is_idle` originally had none);
// silence the dead-code lint rather than delete small bookkeeping helpers.
#[allow(dead_code)]
impl<T> PooledItem<T> {
    /// Wraps `item`, stamping the current instant as its last use.
    fn new(item: T) -> Self {
        let last_used = Instant::now();
        PooledItem { item, last_used }
    }

    /// Refreshes the last-use stamp to the current instant.
    fn touch(&mut self) {
        let now = Instant::now();
        self.last_used = now;
    }

    /// Returns `true` when strictly more than `timeout` has elapsed since the last use.
    fn is_idle(&self, timeout: Duration) -> bool {
        timeout < self.last_used.elapsed()
    }
}
80
81type BoxedFactory<T> =
82 Box<dyn Fn() -> Pin<Box<dyn Future<Output = Result<T>> + Send>> + Send + Sync>;
83
/// Pool of reusable items created on demand by an async factory.
///
/// Construct it with `ConnectionPool::new` (which returns an `Arc<Self>`), check items out
/// with `get`, and they are handed back automatically when the returned guard is dropped.
pub struct ConnectionPool<T> {
    /// Settings the pool was constructed with.
    config: PoolConfig,
    /// Async constructor used to create new items.
    factory: BoxedFactory<T>,
    /// Items waiting to be reused; returned items are pushed to the back and `get`
    /// pops from the front.
    available: Arc<Mutex<VecDeque<PooledItem<T>>>>,
    /// Caps the number of concurrently checked-out items; created with
    /// `config.max_connections` permits.
    semaphore: Arc<Semaphore>,
    /// Number of items the factory has successfully created so far (never decremented).
    total_created: AtomicUsize,
    /// Set by `shutdown`; consulted by `get` and when items are returned.
    shutdown: AtomicBool,
}
92
93impl<T: Send + 'static> ConnectionPool<T> {
94 pub async fn new<F, Fut>(config: PoolConfig, factory: F) -> Result<Arc<Self>>
95 where
96 F: Fn() -> Fut + Send + Sync + 'static,
97 Fut: Future<Output = Result<T>> + Send + 'static,
98 {
99 let boxed_factory: BoxedFactory<T> = Box::new(move || Box::pin(factory()));
100
101 let pool = Arc::new(Self {
102 semaphore: Arc::new(Semaphore::new(config.max_connections)),
103 config,
104 factory: boxed_factory,
105 available: Arc::new(Mutex::new(VecDeque::new())),
106 total_created: AtomicUsize::new(0),
107 shutdown: AtomicBool::new(false),
108 });
109
110 for _ in 0..pool.config.min_connections {
111 let item = (pool.factory)().await?;
112 pool.available.lock().push_back(PooledItem::new(item));
113 pool.total_created.fetch_add(1, Ordering::Relaxed);
114 }
115
116 Ok(pool)
117 }
118
119 pub async fn get(self: &Arc<Self>) -> Result<PoolGuard<T>> {
120 if self.shutdown.load(Ordering::Acquire) {
121 return Err(RpcError::ClientError("Pool is shutdown".to_string()));
122 }
123
124 let permit = tokio::time::timeout(
125 self.config.acquire_timeout,
126 self.semaphore.clone().acquire_owned(),
127 )
128 .await
129 .map_err(|_| RpcError::Timeout("Acquire connection timeout".to_string()))?
130 .map_err(|_| RpcError::ClientError("Pool is closed".to_string()))?;
131
132 if let Some(mut pooled) = self.available.lock().pop_front() {
133 pooled.touch();
134 return Ok(PoolGuard {
135 item: Some(pooled.item),
136 pool: self.clone(),
137 _permit: permit,
138 });
139 }
140
141 let item = (self.factory)().await?;
142 self.total_created.fetch_add(1, Ordering::Relaxed);
143
144 Ok(PoolGuard {
145 item: Some(item),
146 pool: self.clone(),
147 _permit: permit,
148 })
149 }
150
151 fn return_item(&self, item: T) {
152 if self.shutdown.load(Ordering::Acquire) {
153 return;
154 }
155 self.available.lock().push_back(PooledItem::new(item));
156 }
157
158 pub fn available_count(&self) -> usize {
159 self.available.lock().len()
160 }
161
162 pub fn total_created(&self) -> usize {
163 self.total_created.load(Ordering::Relaxed)
164 }
165
166 pub fn shutdown(&self) {
167 self.shutdown.store(true, Ordering::Release);
168 self.available.lock().clear();
169 }
170}
171
/// RAII handle to an item checked out of a [`ConnectionPool`].
///
/// Dereferences to the item. When the guard is dropped, its `Drop` impl passes the item to
/// the pool's `return_item`; only afterwards are the fields dropped, in declaration order,
/// so `_permit` (the pool slot) is released after the item has been handed back. Keep the
/// field order as is.
pub struct PoolGuard<T: Send + 'static> {
    /// The checked-out item; `Some` for the guard's whole life, taken in `Drop`.
    item: Option<T>,
    /// Pool the item is handed back to on drop.
    pool: Arc<ConnectionPool<T>>,
    /// Permit from the pool's semaphore; dropping it frees the slot for another `get`.
    _permit: tokio::sync::OwnedSemaphorePermit,
}
177
178impl<T: Send + 'static> PoolGuard<T> {
179 pub fn get(&self) -> &T {
180 self.item.as_ref().unwrap()
181 }
182
183 pub fn get_mut(&mut self) -> &mut T {
184 self.item.as_mut().unwrap()
185 }
186}
187
188impl<T: Send + 'static> Drop for PoolGuard<T> {
189 fn drop(&mut self) {
190 if let Some(item) = self.item.take() {
191 self.pool.return_item(item);
192 }
193 }
194}
195
196impl<T: Send + 'static> std::ops::Deref for PoolGuard<T> {
197 type Target = T;
198
199 fn deref(&self) -> &Self::Target {
200 self.get()
201 }
202}
203
204impl<T: Send + 'static> std::ops::DerefMut for PoolGuard<T> {
205 fn deref_mut(&mut self) -> &mut Self::Target {
206 self.get_mut()
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicI32;

    /// Eagerly created items are counted, and a checked-out item returns to the pool on drop.
    #[tokio::test]
    async fn test_pool_basic() {
        let counter = Arc::new(AtomicI32::new(0));
        let factory_counter = Arc::clone(&counter);

        let config = PoolConfig::default()
            .with_min_connections(2)
            .with_max_connections(5);
        let pool = ConnectionPool::new(config, move || {
            let next_id = Arc::clone(&factory_counter);
            async move { Ok(next_id.fetch_add(1, Ordering::Relaxed)) }
        })
        .await
        .unwrap();

        assert_eq!(pool.available_count(), 2);
        assert_eq!(pool.total_created(), 2);

        let conn = pool.get().await.unwrap();
        assert!(matches!(*conn, 0 | 1));
        assert_eq!(pool.available_count(), 1);
        drop(conn);

        assert_eq!(pool.available_count(), 2);
    }

    /// With every slot checked out, one more `get` fails once the acquire timeout elapses.
    #[tokio::test]
    async fn test_pool_max_connections() {
        let pool = ConnectionPool::new(
            PoolConfig::default()
                .with_min_connections(0)
                .with_max_connections(2)
                .with_acquire_timeout(Duration::from_millis(100)),
            || async { Ok(()) },
        )
        .await
        .unwrap();

        let _held = [pool.get().await.unwrap(), pool.get().await.unwrap()];

        let overflow = pool.get().await;
        assert!(overflow.is_err());
    }
}