xrpc/
pool.rs

1use parking_lot::Mutex;
2use std::collections::VecDeque;
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
7use std::time::{Duration, Instant};
8use tokio::sync::Semaphore;
9
10use crate::error::{Result, RpcError};
11
/// Tuning knobs for a [`ConnectionPool`].
///
/// Build one with [`PoolConfig::default`] / [`PoolConfig::new`] and the `with_*` builders.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PoolConfig {
    /// Number of connections opened eagerly when the pool is created.
    pub min_connections: usize,
    /// Maximum number of connections that may be checked out at the same time.
    pub max_connections: usize,
    /// How long a checkout waits for a free slot before failing with a timeout error.
    pub acquire_timeout: Duration,
    /// How long a pooled connection may sit unused before it counts as idle.
    pub idle_timeout: Duration,
    /// Interval between connection health checks.
    // NOTE(review): not read by `ConnectionPool` in this module — confirm whether a health
    // checker elsewhere consumes it.
    pub health_check_interval: Duration,
}
19
20impl Default for PoolConfig {
21    fn default() -> Self {
22        Self {
23            min_connections: 1,
24            max_connections: 10,
25            acquire_timeout: Duration::from_secs(30),
26            idle_timeout: Duration::from_secs(300),
27            health_check_interval: Duration::from_secs(30),
28        }
29    }
30}
31
32impl PoolConfig {
33    pub fn new() -> Self {
34        Self::default()
35    }
36
37    pub fn with_min_connections(mut self, min: usize) -> Self {
38        self.min_connections = min;
39        self
40    }
41
42    pub fn with_max_connections(mut self, max: usize) -> Self {
43        self.max_connections = max;
44        self
45    }
46
47    pub fn with_acquire_timeout(mut self, timeout: Duration) -> Self {
48        self.acquire_timeout = timeout;
49        self
50    }
51
52    pub fn with_idle_timeout(mut self, timeout: Duration) -> Self {
53        self.idle_timeout = timeout;
54        self
55    }
56}
57
/// A connection parked in the pool together with a last-used timestamp.
struct PooledItem<T> {
    /// Set to `Instant::now()` by the constructor and refreshed by `touch`.
    last_used: Instant,
    /// The pooled connection.
    item: T,
}
62
63impl<T> PooledItem<T> {
64    fn new(item: T) -> Self {
65        Self {
66            item,
67            last_used: Instant::now(),
68        }
69    }
70
71    fn touch(&mut self) {
72        self.last_used = Instant::now();
73    }
74
75    #[allow(dead_code)]
76    fn is_idle(&self, timeout: Duration) -> bool {
77        self.last_used.elapsed() > timeout
78    }
79}
80
81type BoxedFactory<T> =
82    Box<dyn Fn() -> Pin<Box<dyn Future<Output = Result<T>> + Send>> + Send + Sync>;
83
/// Async pool of reusable connections of type `T`, created on demand by a factory.
///
/// Concurrent checkouts are bounded by a semaphore sized from `config.max_connections`;
/// returned connections wait in a FIFO queue (pushed to the back, taken from the front).
pub struct ConnectionPool<T> {
    /// Settings the pool was created with.
    config: PoolConfig,
    /// Type-erased constructor invoked whenever a new connection is needed.
    factory: BoxedFactory<T>,
    /// Connections currently parked in the pool, oldest-returned at the front.
    available: Arc<Mutex<VecDeque<PooledItem<T>>>>,
    /// One permit per connection that may be checked out at the same time.
    semaphore: Arc<Semaphore>,
    /// Running count of connections produced by the factory; only ever incremented here.
    total_created: AtomicUsize,
    /// Set by `shutdown`; consulted by `get` and `return_item`.
    shutdown: AtomicBool,
}
92
93impl<T: Send + 'static> ConnectionPool<T> {
94    pub async fn new<F, Fut>(config: PoolConfig, factory: F) -> Result<Arc<Self>>
95    where
96        F: Fn() -> Fut + Send + Sync + 'static,
97        Fut: Future<Output = Result<T>> + Send + 'static,
98    {
99        let boxed_factory: BoxedFactory<T> = Box::new(move || Box::pin(factory()));
100
101        let pool = Arc::new(Self {
102            semaphore: Arc::new(Semaphore::new(config.max_connections)),
103            config,
104            factory: boxed_factory,
105            available: Arc::new(Mutex::new(VecDeque::new())),
106            total_created: AtomicUsize::new(0),
107            shutdown: AtomicBool::new(false),
108        });
109
110        for _ in 0..pool.config.min_connections {
111            let item = (pool.factory)().await?;
112            pool.available.lock().push_back(PooledItem::new(item));
113            pool.total_created.fetch_add(1, Ordering::Relaxed);
114        }
115
116        Ok(pool)
117    }
118
119    pub async fn get(self: &Arc<Self>) -> Result<PoolGuard<T>> {
120        if self.shutdown.load(Ordering::Acquire) {
121            return Err(RpcError::ClientError("Pool is shutdown".to_string()));
122        }
123
124        let permit = tokio::time::timeout(
125            self.config.acquire_timeout,
126            self.semaphore.clone().acquire_owned(),
127        )
128        .await
129        .map_err(|_| RpcError::Timeout("Acquire connection timeout".to_string()))?
130        .map_err(|_| RpcError::ClientError("Pool is closed".to_string()))?;
131
132        if let Some(mut pooled) = self.available.lock().pop_front() {
133            pooled.touch();
134            return Ok(PoolGuard {
135                item: Some(pooled.item),
136                pool: self.clone(),
137                _permit: permit,
138            });
139        }
140
141        let item = (self.factory)().await?;
142        self.total_created.fetch_add(1, Ordering::Relaxed);
143
144        Ok(PoolGuard {
145            item: Some(item),
146            pool: self.clone(),
147            _permit: permit,
148        })
149    }
150
151    fn return_item(&self, item: T) {
152        if self.shutdown.load(Ordering::Acquire) {
153            return;
154        }
155        self.available.lock().push_back(PooledItem::new(item));
156    }
157
158    pub fn available_count(&self) -> usize {
159        self.available.lock().len()
160    }
161
162    pub fn total_created(&self) -> usize {
163        self.total_created.load(Ordering::Relaxed)
164    }
165
166    pub fn shutdown(&self) {
167        self.shutdown.store(true, Ordering::Release);
168        self.available.lock().clear();
169    }
170}
171
172pub struct PoolGuard<T: Send + 'static> {
173    item: Option<T>,
174    pool: Arc<ConnectionPool<T>>,
175    _permit: tokio::sync::OwnedSemaphorePermit,
176}
177
178impl<T: Send + 'static> PoolGuard<T> {
179    pub fn get(&self) -> &T {
180        self.item.as_ref().unwrap()
181    }
182
183    pub fn get_mut(&mut self) -> &mut T {
184        self.item.as_mut().unwrap()
185    }
186}
187
188impl<T: Send + 'static> Drop for PoolGuard<T> {
189    fn drop(&mut self) {
190        if let Some(item) = self.item.take() {
191            self.pool.return_item(item);
192        }
193    }
194}
195
196impl<T: Send + 'static> std::ops::Deref for PoolGuard<T> {
197    type Target = T;
198
199    fn deref(&self) -> &Self::Target {
200        self.get()
201    }
202}
203
204impl<T: Send + 'static> std::ops::DerefMut for PoolGuard<T> {
205    fn deref_mut(&mut self) -> &mut Self::Target {
206        self.get_mut()
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicI32;

    #[tokio::test]
    async fn test_pool_basic() {
        let counter = Arc::new(AtomicI32::new(0));
        let factory_counter = Arc::clone(&counter);

        let config = PoolConfig::default()
            .with_min_connections(2)
            .with_max_connections(5);

        // Each connection is just the next counter value: 0, 1, 2, ...
        let pool = ConnectionPool::new(config, move || {
            let next = Arc::clone(&factory_counter);
            async move { Ok(next.fetch_add(1, Ordering::Relaxed)) }
        })
        .await
        .unwrap();

        assert_eq!(pool.available_count(), 2);
        assert_eq!(pool.total_created(), 2);

        {
            let conn = pool.get().await.unwrap();
            assert!(matches!(*conn, 0 | 1));
            assert_eq!(pool.available_count(), 1);
        }

        // Leaving the scope dropped the guard, which re-queued the connection.
        assert_eq!(pool.available_count(), 2);
    }

    #[tokio::test]
    async fn test_pool_max_connections() {
        let config = PoolConfig::default()
            .with_min_connections(0)
            .with_max_connections(2)
            .with_acquire_timeout(Duration::from_millis(100));

        let pool = ConnectionPool::new(config, || async { Ok(()) })
            .await
            .unwrap();

        // Hold both slots so that a third checkout has to time out.
        let _held = (pool.get().await.unwrap(), pool.get().await.unwrap());

        assert!(pool.get().await.is_err());
    }
}