zus_common/
pool.rs

1//! Connection Pool for ZUS RPC Client
2//!
3//! Provides connection pooling for high-throughput scenarios like gateways.
4//! Default configuration (pool_size=1) maintains backward compatibility.
5
use {
  parking_lot::RwLock,
  std::{
    sync::{
      Arc, Weak,
      atomic::{AtomicU64, AtomicUsize, Ordering},
    },
    time::{Duration, Instant},
  },
  tokio::sync::Semaphore,
  tracing::{debug, info, warn},
};
18
19use crate::{ZusError, endpoint::RpcEndpoint, error::Result};
20
/// Connection pool configuration
///
/// Controls how many connections a [`ConnectionPool`] keeps to a single
/// server, and how aggressively idle connections are recycled.
#[derive(Debug, Clone)]
pub struct ConnectionPoolConfig {
  /// Minimum connections to maintain (default: 1)
  pub min_connections: usize,

  /// Maximum connections allowed (default: 1 for backward compatibility)
  pub max_connections: usize,

  /// Idle timeout before closing excess connections (default: 60s)
  ///
  /// Only connections above `min_connections` with no pending requests
  /// are eligible for idle cleanup.
  pub idle_timeout: Duration,

  /// Connection timeout (default: 5s)
  ///
  /// Bounds each individual connect attempt made by the pool.
  pub connect_timeout: Duration,

  /// Health check interval (default: 10s)
  ///
  /// Also the tick period of the pool's background maintenance loop.
  pub health_check_interval: Duration,
}
39
40impl Default for ConnectionPoolConfig {
41  fn default() -> Self {
42    Self {
43      min_connections: 1,
44      max_connections: 1, // Backward compatible
45      idle_timeout: Duration::from_secs(60),
46      connect_timeout: Duration::from_secs(5),
47      health_check_interval: Duration::from_secs(10),
48    }
49  }
50}
51
52impl ConnectionPoolConfig {
53  /// Configuration optimized for gateway scenarios
54  pub fn for_gateway() -> Self {
55    Self {
56      min_connections: 2,
57      max_connections: 10,
58      idle_timeout: Duration::from_secs(300),
59      connect_timeout: Duration::from_secs(5),
60      health_check_interval: Duration::from_secs(10),
61    }
62  }
63
64  /// Create configuration with specific pool size
65  pub fn with_pool_size(min: usize, max: usize) -> Self {
66    Self {
67      min_connections: min,
68      max_connections: max,
69      ..Default::default()
70    }
71  }
72}
73
/// Strategy for selecting connections from the pool
#[derive(Debug, Clone, Copy, Default)]
pub enum ConnectionSelectionStrategy {
  /// Round-robin selection (default)
  ///
  /// Cycles through connections via a shared atomic counter.
  #[default]
  RoundRobin,

  /// Select connection with least pending requests
  ///
  /// Picks the connection whose in-flight request counter is lowest.
  LeastPending,
}
84
/// Health status of a pooled connection
///
/// Set by callers via `mark_healthy`/`mark_unhealthy`; within this module
/// the pool only reads the flag (for selection and cleanup).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionHealth {
  /// Connection is believed usable.
  Healthy,
  /// Connection failed and should be avoided and eventually recycled.
  Unhealthy,
}
91
/// A connection wrapper with health and usage tracking
pub struct PooledConnection {
  // The actual RPC endpoint, shared so callers can hold it across awaits.
  endpoint: Arc<RpcEndpoint>,
  // Current health flag; see `ConnectionHealth`.
  health: RwLock<ConnectionHealth>,
  // In-flight request count; drives least-pending selection and idle cleanup.
  pending_requests: AtomicU64,
  // Refreshed on every `endpoint()` access; drives idle cleanup.
  last_used: RwLock<Instant>,
  // Creation time; exposed via `age()`.
  created_at: Instant,
}
100
101impl PooledConnection {
102  fn new(endpoint: RpcEndpoint) -> Self {
103    Self {
104      endpoint: Arc::new(endpoint),
105      health: RwLock::new(ConnectionHealth::Healthy),
106      pending_requests: AtomicU64::new(0),
107      last_used: RwLock::new(Instant::now()),
108      created_at: Instant::now(),
109    }
110  }
111
112  /// Get the underlying endpoint
113  pub fn endpoint(&self) -> Arc<RpcEndpoint> {
114    *self.last_used.write() = Instant::now();
115    self.endpoint.clone()
116  }
117
118  /// Increment pending request count
119  pub fn inc_pending(&self) {
120    self.pending_requests.fetch_add(1, Ordering::SeqCst);
121  }
122
123  /// Decrement pending request count
124  pub fn dec_pending(&self) {
125    self.pending_requests.fetch_sub(1, Ordering::SeqCst);
126  }
127
128  /// Get pending request count
129  pub fn pending_count(&self) -> u64 {
130    self.pending_requests.load(Ordering::SeqCst)
131  }
132
133  /// Check if connection is healthy
134  pub fn is_healthy(&self) -> bool {
135    *self.health.read() == ConnectionHealth::Healthy
136  }
137
138  /// Mark connection as unhealthy
139  pub fn mark_unhealthy(&self) {
140    *self.health.write() = ConnectionHealth::Unhealthy;
141  }
142
143  /// Mark connection as healthy
144  pub fn mark_healthy(&self) {
145    *self.health.write() = ConnectionHealth::Healthy;
146  }
147
148  /// Get time since last use
149  pub fn idle_duration(&self) -> Duration {
150    self.last_used.read().elapsed()
151  }
152
153  /// Get connection age
154  pub fn age(&self) -> Duration {
155    self.created_at.elapsed()
156  }
157
158  /// Get the address of this connection
159  pub fn address(&self) -> String {
160    self.endpoint.address()
161  }
162}
163
/// Connection pool for a single server
///
/// Manages multiple connections to the same server for high throughput.
pub struct ConnectionPool {
  // Target host name or IP.
  host: String,
  // Target port.
  port: u16,
  // Sizing and timing configuration; fixed at construction.
  config: ConnectionPoolConfig,
  // Live connections (healthy and unhealthy alike).
  connections: RwLock<Vec<Arc<PooledConnection>>>,
  // Monotonic counter backing round-robin selection.
  round_robin_index: AtomicUsize,
  // How get_connection picks among available connections.
  selection_strategy: ConnectionSelectionStrategy,
  #[allow(dead_code)] // Reserved for future back-pressure support
  connection_semaphore: Semaphore,
}
177
178impl ConnectionPool {
179  /// Create a new connection pool
180  pub async fn new(host: String, port: u16, config: ConnectionPoolConfig) -> Result<Arc<Self>> {
181    let max_connections = config.max_connections;
182
183    let pool = Arc::new(Self {
184      host: host.clone(),
185      port,
186      config: config.clone(),
187      connections: RwLock::new(Vec::with_capacity(config.max_connections)),
188      round_robin_index: AtomicUsize::new(0),
189      selection_strategy: ConnectionSelectionStrategy::RoundRobin,
190      connection_semaphore: Semaphore::new(max_connections),
191    });
192
193    // Create minimum connections
194    pool.ensure_min_connections().await?;
195
196    // Start background maintenance task
197    let pool_clone = pool.clone();
198    tokio::spawn(async move {
199      pool_clone.maintenance_loop().await;
200    });
201
202    info!(
203      "Connection pool created for {}:{} (min={}, max={})",
204      host, port, config.min_connections, config.max_connections
205    );
206
207    Ok(pool)
208  }
209
210  /// Create pool with default config (single connection, backward compatible)
211  pub async fn new_default(host: String, port: u16) -> Result<Arc<Self>> {
212    Self::new(host, port, ConnectionPoolConfig::default()).await
213  }
214
215  /// Create pool with gateway config
216  pub async fn new_for_gateway(host: String, port: u16) -> Result<Arc<Self>> {
217    Self::new(host, port, ConnectionPoolConfig::for_gateway()).await
218  }
219
220  /// Set connection selection strategy
221  pub fn with_selection_strategy(self: Arc<Self>, strategy: ConnectionSelectionStrategy) -> Arc<Self> {
222    // Note: This creates a new pool with the strategy set
223    // In practice, you'd set this before sharing the pool
224    Arc::new(Self {
225      host: self.host.clone(),
226      port: self.port,
227      config: self.config.clone(),
228      connections: RwLock::new(self.connections.read().clone()),
229      round_robin_index: AtomicUsize::new(0),
230      selection_strategy: strategy,
231      connection_semaphore: Semaphore::new(self.config.max_connections),
232    })
233  }
234
235  /// Ensure minimum connections are established
236  async fn ensure_min_connections(&self) -> Result<()> {
237    let current_count = self.connections.read().len();
238    let needed = self.config.min_connections.saturating_sub(current_count);
239
240    for _ in 0..needed {
241      if let Err(e) = self.create_connection().await {
242        warn!(
243          "Failed to create initial connection to {}:{}: {:?}",
244          self.host, self.port, e
245        );
246        // Continue trying to create other connections
247      }
248    }
249
250    let final_count = self.connections.read().len();
251    if final_count == 0 {
252      return Err(ZusError::Connection(format!(
253        "Failed to create any connections to {}:{}",
254        self.host, self.port
255      )));
256    }
257
258    Ok(())
259  }
260
261  /// Create a new connection and add to pool
262  async fn create_connection(&self) -> Result<Arc<PooledConnection>> {
263    // Check if we can create more connections
264    let current_count = self.connections.read().len();
265    if current_count >= self.config.max_connections {
266      return Err(ZusError::Connection("Pool at maximum capacity".to_string()));
267    }
268
269    // Create new connection
270    let endpoint = tokio::time::timeout(
271      self.config.connect_timeout,
272      RpcEndpoint::connect(self.host.clone(), self.port),
273    )
274    .await
275    .map_err(|_| ZusError::Timeout)??;
276
277    let pooled = Arc::new(PooledConnection::new(endpoint));
278
279    // Add to pool
280    self.connections.write().push(pooled.clone());
281
282    debug!(
283      "Created new connection to {}:{} (pool size: {})",
284      self.host,
285      self.port,
286      self.connections.read().len()
287    );
288
289    Ok(pooled)
290  }
291
292  /// Get a connection from the pool
293  pub async fn get_connection(&self) -> Result<Arc<PooledConnection>> {
294    // Try to get existing healthy connection
295    if let Some(conn) = self.select_connection() {
296      return Ok(conn);
297    }
298
299    // No healthy connections, try to create new one
300    if self.connections.read().len() < self.config.max_connections
301      && let Ok(conn) = self.create_connection().await
302    {
303      return Ok(conn);
304    }
305
306    // Return any connection as last resort
307    let connections = self.connections.read();
308    if connections.is_empty() {
309      return Err(ZusError::Connection("No connections available".to_string()));
310    }
311
312    let index = self.round_robin_index.fetch_add(1, Ordering::SeqCst);
313    Ok(connections[index % connections.len()].clone())
314  }
315
316  /// Select a connection based on strategy
317  fn select_connection(&self) -> Option<Arc<PooledConnection>> {
318    let connections = self.connections.read();
319    if connections.is_empty() {
320      return None;
321    }
322
323    // Filter healthy connections
324    let healthy: Vec<_> = connections.iter().filter(|c| c.is_healthy()).cloned().collect();
325
326    let pool = if healthy.is_empty() {
327      &connections[..]
328    } else {
329      &healthy[..]
330    };
331
332    match self.selection_strategy {
333      | ConnectionSelectionStrategy::RoundRobin => {
334        let index = self.round_robin_index.fetch_add(1, Ordering::SeqCst);
335        Some(pool[index % pool.len()].clone())
336      }
337      | ConnectionSelectionStrategy::LeastPending => pool.iter().min_by_key(|c| c.pending_count()).cloned(),
338    }
339  }
340
341  /// Get pool statistics
342  pub fn stats(&self) -> PoolStats {
343    let connections = self.connections.read();
344    let total = connections.len();
345    let healthy = connections.iter().filter(|c| c.is_healthy()).count();
346    let total_pending: u64 = connections.iter().map(|c| c.pending_count()).sum();
347
348    PoolStats {
349      total_connections: total,
350      healthy_connections: healthy,
351      unhealthy_connections: total - healthy,
352      total_pending_requests: total_pending,
353      max_connections: self.config.max_connections,
354      min_connections: self.config.min_connections,
355    }
356  }
357
358  /// Get the server address
359  pub fn address(&self) -> String {
360    format!("{}:{}", self.host, self.port)
361  }
362
363  /// Background maintenance loop
364  async fn maintenance_loop(self: Arc<Self>) {
365    let mut interval = tokio::time::interval(self.config.health_check_interval);
366
367    loop {
368      interval.tick().await;
369
370      // Remove unhealthy connections that have been unhealthy for too long
371      self.cleanup_unhealthy();
372
373      // Remove excess idle connections
374      self.cleanup_idle();
375
376      // Ensure minimum connections
377      if let Err(e) = self.ensure_min_connections().await {
378        warn!("Failed to maintain minimum connections: {:?}", e);
379      }
380    }
381  }
382
383  /// Remove connections that have been unhealthy
384  fn cleanup_unhealthy(&self) {
385    let mut connections = self.connections.write();
386    let before = connections.len();
387    let min_connections = self.config.min_connections;
388
389    // Count how many we can remove while staying above minimum
390    let unhealthy_count = connections.iter().filter(|c| !c.is_healthy()).count();
391    let can_remove = before.saturating_sub(min_connections).min(unhealthy_count);
392
393    if can_remove > 0 {
394      let mut removed = 0;
395      connections.retain(|c| {
396        if !c.is_healthy() && removed < can_remove {
397          removed += 1;
398          false
399        } else {
400          true
401        }
402      });
403
404      debug!(
405        "Removed {} unhealthy connections from pool {}:{}",
406        removed, self.host, self.port
407      );
408    }
409  }
410
411  /// Remove excess idle connections
412  fn cleanup_idle(&self) {
413    let mut connections = self.connections.write();
414    let before = connections.len();
415    let min_connections = self.config.min_connections;
416    let idle_timeout = self.config.idle_timeout;
417
418    // Only cleanup if we're above minimum
419    if before <= min_connections {
420      return;
421    }
422
423    // Count how many idle connections we can remove
424    let idle_count = connections
425      .iter()
426      .filter(|c| c.idle_duration() >= idle_timeout && c.pending_count() == 0)
427      .count();
428    let can_remove = before.saturating_sub(min_connections).min(idle_count);
429
430    if can_remove > 0 {
431      let mut removed = 0;
432      connections.retain(|c| {
433        if c.idle_duration() >= idle_timeout && c.pending_count() == 0 && removed < can_remove {
434          removed += 1;
435          false
436        } else {
437          true
438        }
439      });
440
441      debug!(
442        "Removed {} idle connections from pool {}:{}",
443        removed, self.host, self.port
444      );
445    }
446  }
447}
448
/// Pool statistics
///
/// A point-in-time snapshot produced by `ConnectionPool::stats`.
#[derive(Debug, Clone)]
pub struct PoolStats {
  /// Total connections currently in the pool (healthy + unhealthy).
  pub total_connections: usize,
  /// Connections currently marked healthy.
  pub healthy_connections: usize,
  /// Connections currently marked unhealthy.
  pub unhealthy_connections: usize,
  /// Sum of in-flight request counts across all connections.
  pub total_pending_requests: u64,
  /// Configured upper bound on pool size.
  pub max_connections: usize,
  /// Configured lower bound on pool size.
  pub min_connections: usize,
}
459
#[cfg(test)]
mod tests {
  use super::*;

  /// Defaults must remain single-connection for backward compatibility.
  #[test]
  fn test_default_config() {
    let cfg = ConnectionPoolConfig::default();
    assert_eq!((cfg.min_connections, cfg.max_connections), (1, 1));
  }

  /// Gateway preset keeps two warm connections and bursts to ten.
  #[test]
  fn test_gateway_config() {
    let cfg = ConnectionPoolConfig::for_gateway();
    assert_eq!((cfg.min_connections, cfg.max_connections), (2, 10));
  }

  /// Explicit sizes are carried through unchanged.
  #[test]
  fn test_pool_config_with_size() {
    let cfg = ConnectionPoolConfig::with_pool_size(4, 20);
    assert_eq!((cfg.min_connections, cfg.max_connections), (4, 20));
  }

  /// Exercise the atomic counter pattern used for pending-request tracking;
  /// a real connection is not constructible in unit tests.
  #[test]
  fn test_pooled_connection_pending() {
    let pending = AtomicU64::new(0);
    pending.fetch_add(1, Ordering::SeqCst);
    assert_eq!(pending.load(Ordering::SeqCst), 1);
    pending.fetch_sub(1, Ordering::SeqCst);
    assert_eq!(pending.load(Ordering::SeqCst), 0);
  }

  /// PoolStats is a plain data snapshot; fields round-trip as assigned.
  #[test]
  fn test_pool_stats() {
    let stats = PoolStats {
      total_connections: 5,
      healthy_connections: 4,
      unhealthy_connections: 1,
      total_pending_requests: 10,
      max_connections: 10,
      min_connections: 2,
    };

    assert_eq!(stats.total_connections, 5);
    assert_eq!(stats.healthy_connections, 4);
  }
}