1use {
7 parking_lot::RwLock,
8 std::{
9 sync::{
10 Arc,
11 atomic::{AtomicU64, AtomicUsize, Ordering},
12 },
13 time::{Duration, Instant},
14 },
15 tokio::sync::Semaphore,
16 tracing::{debug, info, warn},
17};
18
19use crate::{ZusError, endpoint::RpcEndpoint, error::Result};
20
/// Tunable limits and timeouts for a connection pool.
#[derive(Debug, Clone)]
pub struct ConnectionPoolConfig {
    /// Number of connections the pool tries to keep open; maintenance never
    /// shrinks the pool below this count.
    pub min_connections: usize,

    /// Upper bound on the number of connections the pool will create.
    pub max_connections: usize,

    /// How long a connection with no pending requests may go unused before
    /// maintenance is allowed to drop it.
    pub idle_timeout: Duration,

    /// Maximum time allowed for a single connection attempt.
    pub connect_timeout: Duration,

    /// Period of the background maintenance pass.
    pub health_check_interval: Duration,
}

impl Default for ConnectionPoolConfig {
    /// A fixed single-connection pool (`min == max == 1`) with a 60 s idle
    /// timeout, 5 s connect timeout and 10 s maintenance period.
    fn default() -> Self {
        Self {
            min_connections: 1,
            max_connections: 1,
            idle_timeout: Duration::from_secs(60),
            connect_timeout: Duration::from_secs(5),
            health_check_interval: Duration::from_secs(10),
        }
    }
}

impl ConnectionPoolConfig {
    /// Preset with 2..=10 connections and a 5 minute idle timeout.
    pub fn for_gateway() -> Self {
        Self {
            min_connections: 2,
            max_connections: 10,
            idle_timeout: Duration::from_secs(300),
            connect_timeout: Duration::from_secs(5),
            health_check_interval: Duration::from_secs(10),
        }
    }

    /// Default timeouts with explicit pool bounds.
    ///
    /// If `min` exceeds `max`, the minimum is clamped to `max`: a pool can
    /// never hold more than `max` connections, so a larger minimum could not
    /// be satisfied and would only make every maintenance pass fail.
    pub fn with_pool_size(min: usize, max: usize) -> Self {
        Self {
            min_connections: min.min(max),
            max_connections: max,
            ..Self::default()
        }
    }
}
73
/// Policy the pool uses to pick a connection for a request.
#[derive(Debug, Clone, Copy, Default)]
pub enum ConnectionSelectionStrategy {
    /// Cycle through the candidate connections using a shared counter
    /// (the default).
    #[default]
    RoundRobin,

    /// Pick the candidate reporting the fewest pending requests.
    LeastPending,
}
84
/// Health flag stored on each pooled connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionHealth {
    /// Preferred by connection selection.
    Healthy,
    /// Only selected when no healthy connection exists; may be dropped by
    /// maintenance while the pool is above its minimum size.
    Unhealthy,
}
91
/// One RPC endpoint plus the bookkeeping the pool needs to select and evict it.
pub struct PooledConnection {
    /// The underlying endpoint, shared with callers via `Arc`.
    endpoint: Arc<RpcEndpoint>,
    /// Current health flag; starts as `Healthy`.
    health: RwLock<ConnectionHealth>,
    /// Counter adjusted by `inc_pending` / `dec_pending`.
    pending_requests: AtomicU64,
    /// Refreshed every time `endpoint()` is called.
    last_used: RwLock<Instant>,
    /// Set once when the wrapper is constructed.
    created_at: Instant,
}
100
101impl PooledConnection {
102 fn new(endpoint: RpcEndpoint) -> Self {
103 Self {
104 endpoint: Arc::new(endpoint),
105 health: RwLock::new(ConnectionHealth::Healthy),
106 pending_requests: AtomicU64::new(0),
107 last_used: RwLock::new(Instant::now()),
108 created_at: Instant::now(),
109 }
110 }
111
112 pub fn endpoint(&self) -> Arc<RpcEndpoint> {
114 *self.last_used.write() = Instant::now();
115 self.endpoint.clone()
116 }
117
118 pub fn inc_pending(&self) {
120 self.pending_requests.fetch_add(1, Ordering::SeqCst);
121 }
122
123 pub fn dec_pending(&self) {
125 self.pending_requests.fetch_sub(1, Ordering::SeqCst);
126 }
127
128 pub fn pending_count(&self) -> u64 {
130 self.pending_requests.load(Ordering::SeqCst)
131 }
132
133 pub fn is_healthy(&self) -> bool {
135 *self.health.read() == ConnectionHealth::Healthy
136 }
137
138 pub fn mark_unhealthy(&self) {
140 *self.health.write() = ConnectionHealth::Unhealthy;
141 }
142
143 pub fn mark_healthy(&self) {
145 *self.health.write() = ConnectionHealth::Healthy;
146 }
147
148 pub fn idle_duration(&self) -> Duration {
150 self.last_used.read().elapsed()
151 }
152
153 pub fn age(&self) -> Duration {
155 self.created_at.elapsed()
156 }
157
158 pub fn address(&self) -> String {
160 self.endpoint.address()
161 }
162}
163
/// A pool of connections to a single `host:port` target.
pub struct ConnectionPool {
    /// Target host passed to `RpcEndpoint::connect`.
    host: String,
    /// Target port passed to `RpcEndpoint::connect`.
    port: u16,
    /// Limits and timeouts governing this pool.
    config: ConnectionPoolConfig,
    /// Connections currently owned by the pool.
    connections: RwLock<Vec<Arc<PooledConnection>>>,
    /// Monotonic counter; taken modulo the candidate count to pick a slot.
    round_robin_index: AtomicUsize,
    /// How `select_connection` chooses among candidates.
    selection_strategy: ConnectionSelectionStrategy,
    // Sized to `max_connections` at construction but never acquired anywhere
    // in this file, hence the lint suppression.
    // NOTE(review): presumably reserved for bounding concurrent connects — confirm.
    #[allow(dead_code)] connection_semaphore: Semaphore,
}
177
178impl ConnectionPool {
179 pub async fn new(host: String, port: u16, config: ConnectionPoolConfig) -> Result<Arc<Self>> {
181 let max_connections = config.max_connections;
182
183 let pool = Arc::new(Self {
184 host: host.clone(),
185 port,
186 config: config.clone(),
187 connections: RwLock::new(Vec::with_capacity(config.max_connections)),
188 round_robin_index: AtomicUsize::new(0),
189 selection_strategy: ConnectionSelectionStrategy::RoundRobin,
190 connection_semaphore: Semaphore::new(max_connections),
191 });
192
193 pool.ensure_min_connections().await?;
195
196 let pool_clone = pool.clone();
198 tokio::spawn(async move {
199 pool_clone.maintenance_loop().await;
200 });
201
202 info!(
203 "Connection pool created for {}:{} (min={}, max={})",
204 host, port, config.min_connections, config.max_connections
205 );
206
207 Ok(pool)
208 }
209
210 pub async fn new_default(host: String, port: u16) -> Result<Arc<Self>> {
212 Self::new(host, port, ConnectionPoolConfig::default()).await
213 }
214
215 pub async fn new_for_gateway(host: String, port: u16) -> Result<Arc<Self>> {
217 Self::new(host, port, ConnectionPoolConfig::for_gateway()).await
218 }
219
220 pub fn with_selection_strategy(self: Arc<Self>, strategy: ConnectionSelectionStrategy) -> Arc<Self> {
222 Arc::new(Self {
225 host: self.host.clone(),
226 port: self.port,
227 config: self.config.clone(),
228 connections: RwLock::new(self.connections.read().clone()),
229 round_robin_index: AtomicUsize::new(0),
230 selection_strategy: strategy,
231 connection_semaphore: Semaphore::new(self.config.max_connections),
232 })
233 }
234
235 async fn ensure_min_connections(&self) -> Result<()> {
237 let current_count = self.connections.read().len();
238 let needed = self.config.min_connections.saturating_sub(current_count);
239
240 for _ in 0..needed {
241 if let Err(e) = self.create_connection().await {
242 warn!(
243 "Failed to create initial connection to {}:{}: {:?}",
244 self.host, self.port, e
245 );
246 }
248 }
249
250 let final_count = self.connections.read().len();
251 if final_count == 0 {
252 return Err(ZusError::Connection(format!(
253 "Failed to create any connections to {}:{}",
254 self.host, self.port
255 )));
256 }
257
258 Ok(())
259 }
260
261 async fn create_connection(&self) -> Result<Arc<PooledConnection>> {
263 let current_count = self.connections.read().len();
265 if current_count >= self.config.max_connections {
266 return Err(ZusError::Connection("Pool at maximum capacity".to_string()));
267 }
268
269 let endpoint = tokio::time::timeout(
271 self.config.connect_timeout,
272 RpcEndpoint::connect(self.host.clone(), self.port),
273 )
274 .await
275 .map_err(|_| ZusError::Timeout)??;
276
277 let pooled = Arc::new(PooledConnection::new(endpoint));
278
279 self.connections.write().push(pooled.clone());
281
282 debug!(
283 "Created new connection to {}:{} (pool size: {})",
284 self.host,
285 self.port,
286 self.connections.read().len()
287 );
288
289 Ok(pooled)
290 }
291
292 pub async fn get_connection(&self) -> Result<Arc<PooledConnection>> {
294 if let Some(conn) = self.select_connection() {
296 return Ok(conn);
297 }
298
299 if self.connections.read().len() < self.config.max_connections
301 && let Ok(conn) = self.create_connection().await
302 {
303 return Ok(conn);
304 }
305
306 let connections = self.connections.read();
308 if connections.is_empty() {
309 return Err(ZusError::Connection("No connections available".to_string()));
310 }
311
312 let index = self.round_robin_index.fetch_add(1, Ordering::SeqCst);
313 Ok(connections[index % connections.len()].clone())
314 }
315
316 fn select_connection(&self) -> Option<Arc<PooledConnection>> {
318 let connections = self.connections.read();
319 if connections.is_empty() {
320 return None;
321 }
322
323 let healthy: Vec<_> = connections.iter().filter(|c| c.is_healthy()).cloned().collect();
325
326 let pool = if healthy.is_empty() {
327 &connections[..]
328 } else {
329 &healthy[..]
330 };
331
332 match self.selection_strategy {
333 | ConnectionSelectionStrategy::RoundRobin => {
334 let index = self.round_robin_index.fetch_add(1, Ordering::SeqCst);
335 Some(pool[index % pool.len()].clone())
336 }
337 | ConnectionSelectionStrategy::LeastPending => pool.iter().min_by_key(|c| c.pending_count()).cloned(),
338 }
339 }
340
341 pub fn stats(&self) -> PoolStats {
343 let connections = self.connections.read();
344 let total = connections.len();
345 let healthy = connections.iter().filter(|c| c.is_healthy()).count();
346 let total_pending: u64 = connections.iter().map(|c| c.pending_count()).sum();
347
348 PoolStats {
349 total_connections: total,
350 healthy_connections: healthy,
351 unhealthy_connections: total - healthy,
352 total_pending_requests: total_pending,
353 max_connections: self.config.max_connections,
354 min_connections: self.config.min_connections,
355 }
356 }
357
358 pub fn address(&self) -> String {
360 format!("{}:{}", self.host, self.port)
361 }
362
363 async fn maintenance_loop(self: Arc<Self>) {
365 let mut interval = tokio::time::interval(self.config.health_check_interval);
366
367 loop {
368 interval.tick().await;
369
370 self.cleanup_unhealthy();
372
373 self.cleanup_idle();
375
376 if let Err(e) = self.ensure_min_connections().await {
378 warn!("Failed to maintain minimum connections: {:?}", e);
379 }
380 }
381 }
382
383 fn cleanup_unhealthy(&self) {
385 let mut connections = self.connections.write();
386 let before = connections.len();
387 let min_connections = self.config.min_connections;
388
389 let unhealthy_count = connections.iter().filter(|c| !c.is_healthy()).count();
391 let can_remove = before.saturating_sub(min_connections).min(unhealthy_count);
392
393 if can_remove > 0 {
394 let mut removed = 0;
395 connections.retain(|c| {
396 if !c.is_healthy() && removed < can_remove {
397 removed += 1;
398 false
399 } else {
400 true
401 }
402 });
403
404 debug!(
405 "Removed {} unhealthy connections from pool {}:{}",
406 removed, self.host, self.port
407 );
408 }
409 }
410
411 fn cleanup_idle(&self) {
413 let mut connections = self.connections.write();
414 let before = connections.len();
415 let min_connections = self.config.min_connections;
416 let idle_timeout = self.config.idle_timeout;
417
418 if before <= min_connections {
420 return;
421 }
422
423 let idle_count = connections
425 .iter()
426 .filter(|c| c.idle_duration() >= idle_timeout && c.pending_count() == 0)
427 .count();
428 let can_remove = before.saturating_sub(min_connections).min(idle_count);
429
430 if can_remove > 0 {
431 let mut removed = 0;
432 connections.retain(|c| {
433 if c.idle_duration() >= idle_timeout && c.pending_count() == 0 && removed < can_remove {
434 removed += 1;
435 false
436 } else {
437 true
438 }
439 });
440
441 debug!(
442 "Removed {} idle connections from pool {}:{}",
443 removed, self.host, self.port
444 );
445 }
446 }
447}
448
/// Point-in-time counters describing a pool, as produced by `ConnectionPool::stats`.
#[derive(Debug, Clone)]
pub struct PoolStats {
    /// Number of connections currently held by the pool.
    pub total_connections: usize,
    /// Connections whose health flag is `Healthy`.
    pub healthy_connections: usize,
    /// `total_connections - healthy_connections`.
    pub unhealthy_connections: usize,
    /// Sum of the pending-request counters across all connections.
    pub total_pending_requests: u64,
    /// Configured upper bound, copied from the pool's config.
    pub max_connections: usize,
    /// Configured lower bound, copied from the pool's config.
    pub min_connections: usize,
}
459
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration is a fixed single-connection pool.
    #[test]
    fn test_default_config() {
        let ConnectionPoolConfig {
            min_connections,
            max_connections,
            ..
        } = ConnectionPoolConfig::default();
        assert_eq!(min_connections, 1);
        assert_eq!(max_connections, 1);
    }

    /// The gateway preset allows between 2 and 10 connections.
    #[test]
    fn test_gateway_config() {
        let ConnectionPoolConfig {
            min_connections,
            max_connections,
            ..
        } = ConnectionPoolConfig::for_gateway();
        assert_eq!(min_connections, 2);
        assert_eq!(max_connections, 10);
    }

    /// Explicit bounds are carried through unchanged.
    #[test]
    fn test_pool_config_with_size() {
        let ConnectionPoolConfig {
            min_connections,
            max_connections,
            ..
        } = ConnectionPoolConfig::with_pool_size(4, 20);
        assert_eq!(min_connections, 4);
        assert_eq!(max_connections, 20);
    }

    /// Exercises a bare `AtomicU64` counter with the same operations the
    /// pending-request counter uses; no `PooledConnection` is built here.
    #[test]
    fn test_pooled_connection_pending() {
        let counter = AtomicU64::new(0);

        counter.fetch_add(1, Ordering::SeqCst);
        assert_eq!(counter.load(Ordering::SeqCst), 1);

        counter.fetch_sub(1, Ordering::SeqCst);
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    /// `PoolStats` is a plain data carrier: fields read back as written.
    #[test]
    fn test_pool_stats() {
        let snapshot = PoolStats {
            total_connections: 5,
            healthy_connections: 4,
            unhealthy_connections: 1,
            total_pending_requests: 10,
            max_connections: 10,
            min_connections: 2,
        };

        assert_eq!(snapshot.total_connections, 5);
        assert_eq!(snapshot.healthy_connections, 4);
    }
}