// talos_api_rs/client/pool.rs

// SPDX-License-Identifier: MIT OR Apache-2.0

//! Connection pooling and multi-endpoint support for Talos API clients.
//!
//! This module provides:
//! - [`ConnectionPool`]: A pool of connections to multiple Talos endpoints
//! - [`EndpointHealth`]: Health tracking for individual endpoints
//! - [`LoadBalancer`]: Strategies for selecting endpoints
//!
//! # Example
//!
//! ```ignore
//! use std::time::Duration;
//! use talos_api_rs::client::{ConnectionPool, ConnectionPoolConfig, LoadBalancer};
//!
//! let config = ConnectionPoolConfig::new(vec![
//!     "https://node1:50000".to_string(),
//!     "https://node2:50000".to_string(),
//!     "https://node3:50000".to_string(),
//! ])
//! .with_load_balancer(LoadBalancer::RoundRobin)
//! .with_health_check_interval(Duration::from_secs(30));
//!
//! let pool = ConnectionPool::new(config).await?;
//!
//! // Get a healthy client
//! let client = pool.get_client().await?;
//! ```
29use crate::client::{TalosClient, TalosClientConfig};
30use crate::error::{Result, TalosError};
31use std::collections::HashMap;
32use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
33use std::sync::Arc;
34use std::time::{Duration, Instant};
35use tokio::sync::RwLock;
36
/// Health status of an endpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HealthStatus {
    /// Endpoint is healthy and accepting requests.
    Healthy,
    /// Endpoint is unhealthy and should not receive requests.
    Unhealthy,
    /// Health status is unknown (initial state or after [`EndpointHealth::reset`]).
    Unknown,
}
47
/// Health tracking for a single endpoint.
///
/// Counters are lock-free atomics; the timestamps are guarded by async
/// `RwLock`s, which is why the recording methods are `async`.
#[derive(Debug)]
pub struct EndpointHealth {
    /// The endpoint URL.
    pub endpoint: String,
    /// Current health status, stored as the `u64` encoding produced by
    /// `status_to_u64` so it can be read and updated atomically.
    status: AtomicU64, // Encoded HealthStatus
    /// Number of consecutive failures (reset to 0 on success).
    consecutive_failures: AtomicUsize,
    /// Number of consecutive successes (reset to 0 on failure).
    consecutive_successes: AtomicUsize,
    /// Total number of requests recorded (successes and failures).
    total_requests: AtomicU64,
    /// Total number of failures recorded.
    total_failures: AtomicU64,
    /// Time of the most recent successful request, if any.
    last_success: RwLock<Option<Instant>>,
    /// Time of the most recent failed request, if any.
    last_failure: RwLock<Option<Instant>>,
    /// Time of the most recent health check, if any.
    last_health_check: RwLock<Option<Instant>>,
}
70
71impl EndpointHealth {
72    /// Create a new endpoint health tracker.
73    #[must_use]
74    pub fn new(endpoint: String) -> Self {
75        Self {
76            endpoint,
77            status: AtomicU64::new(Self::status_to_u64(HealthStatus::Unknown)),
78            consecutive_failures: AtomicUsize::new(0),
79            consecutive_successes: AtomicUsize::new(0),
80            total_requests: AtomicU64::new(0),
81            total_failures: AtomicU64::new(0),
82            last_success: RwLock::new(None),
83            last_failure: RwLock::new(None),
84            last_health_check: RwLock::new(None),
85        }
86    }
87
88    fn status_to_u64(status: HealthStatus) -> u64 {
89        match status {
90            HealthStatus::Healthy => 0,
91            HealthStatus::Unhealthy => 1,
92            HealthStatus::Unknown => 2,
93        }
94    }
95
96    fn u64_to_status(value: u64) -> HealthStatus {
97        match value {
98            0 => HealthStatus::Healthy,
99            1 => HealthStatus::Unhealthy,
100            _ => HealthStatus::Unknown,
101        }
102    }
103
104    /// Get the current health status.
105    #[must_use]
106    pub fn status(&self) -> HealthStatus {
107        Self::u64_to_status(self.status.load(Ordering::Acquire))
108    }
109
110    /// Check if the endpoint is healthy.
111    #[must_use]
112    pub fn is_healthy(&self) -> bool {
113        self.status() == HealthStatus::Healthy
114    }
115
116    /// Record a successful request.
117    pub async fn record_success(&self) {
118        self.total_requests.fetch_add(1, Ordering::Relaxed);
119        self.consecutive_failures.store(0, Ordering::Relaxed);
120        self.consecutive_successes.fetch_add(1, Ordering::Relaxed);
121        *self.last_success.write().await = Some(Instant::now());
122        self.status.store(
123            Self::status_to_u64(HealthStatus::Healthy),
124            Ordering::Release,
125        );
126    }
127
128    /// Record a failed request.
129    pub async fn record_failure(&self, failure_threshold: usize) {
130        self.total_requests.fetch_add(1, Ordering::Relaxed);
131        self.total_failures.fetch_add(1, Ordering::Relaxed);
132        self.consecutive_successes.store(0, Ordering::Relaxed);
133        let failures = self.consecutive_failures.fetch_add(1, Ordering::Relaxed) + 1;
134        *self.last_failure.write().await = Some(Instant::now());
135
136        if failures >= failure_threshold {
137            self.status.store(
138                Self::status_to_u64(HealthStatus::Unhealthy),
139                Ordering::Release,
140            );
141        }
142    }
143
144    /// Record a health check.
145    pub async fn record_health_check(&self, healthy: bool, failure_threshold: usize) {
146        *self.last_health_check.write().await = Some(Instant::now());
147        if healthy {
148            self.record_success().await;
149        } else {
150            self.record_failure(failure_threshold).await;
151        }
152    }
153
154    /// Reset the health status to unknown.
155    pub fn reset(&self) {
156        self.status.store(
157            Self::status_to_u64(HealthStatus::Unknown),
158            Ordering::Release,
159        );
160        self.consecutive_failures.store(0, Ordering::Relaxed);
161        self.consecutive_successes.store(0, Ordering::Relaxed);
162    }
163
164    /// Get the number of consecutive failures.
165    #[must_use]
166    pub fn consecutive_failures(&self) -> usize {
167        self.consecutive_failures.load(Ordering::Relaxed)
168    }
169
170    /// Get the total number of requests.
171    #[must_use]
172    pub fn total_requests(&self) -> u64 {
173        self.total_requests.load(Ordering::Relaxed)
174    }
175
176    /// Get the total number of failures.
177    #[must_use]
178    pub fn total_failures(&self) -> u64 {
179        self.total_failures.load(Ordering::Relaxed)
180    }
181
182    /// Get the failure rate (0.0 to 1.0).
183    #[must_use]
184    pub fn failure_rate(&self) -> f64 {
185        let total = self.total_requests.load(Ordering::Relaxed);
186        if total == 0 {
187            return 0.0;
188        }
189        let failures = self.total_failures.load(Ordering::Relaxed);
190        failures as f64 / total as f64
191    }
192
193    /// Get the last successful request time.
194    pub async fn last_success(&self) -> Option<Instant> {
195        *self.last_success.read().await
196    }
197
198    /// Get the last health check time.
199    pub async fn last_health_check(&self) -> Option<Instant> {
200        *self.last_health_check.read().await
201    }
202}
203
/// Load balancing strategy for selecting endpoints.
///
/// Selection only ever considers endpoints that are currently healthy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LoadBalancer {
    /// Round-robin selection across healthy endpoints (the default).
    #[default]
    RoundRobin,
    /// Random selection among healthy endpoints.
    Random,
    /// Select the endpoint with the lowest observed failure rate.
    LeastFailures,
    /// Always prefer the first healthy endpoint (failover mode).
    Failover,
}
217
/// Configuration for the connection pool.
///
/// Construct with [`ConnectionPoolConfig::new`] and customize via the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ConnectionPoolConfig {
    /// List of endpoint URLs the pool connects to.
    pub endpoints: Vec<String>,
    /// Load balancing strategy (default: round-robin).
    pub load_balancer: LoadBalancer,
    /// Health check interval (default: 30 seconds).
    pub health_check_interval: Duration,
    /// Number of consecutive failures before marking unhealthy (default: 3).
    pub failure_threshold: usize,
    /// Number of consecutive successes before marking healthy again
    /// (default: 2).
    /// NOTE(review): not consulted anywhere in this module — a single
    /// success re-marks an endpoint healthy; confirm where it is enforced.
    pub recovery_threshold: usize,
    /// Base client configuration (TLS, timeouts, etc.) applied to every
    /// per-endpoint client; `None` falls back to `TalosClientConfig::new`.
    pub base_config: Option<TalosClientConfig>,
    /// Enable automatic health checks (default: `true`).
    pub auto_health_check: bool,
}
236
237impl ConnectionPoolConfig {
238    /// Create a new connection pool configuration.
239    #[must_use]
240    pub fn new(endpoints: Vec<String>) -> Self {
241        Self {
242            endpoints,
243            load_balancer: LoadBalancer::RoundRobin,
244            health_check_interval: Duration::from_secs(30),
245            failure_threshold: 3,
246            recovery_threshold: 2,
247            base_config: None,
248            auto_health_check: true,
249        }
250    }
251
252    /// Set the load balancing strategy.
253    #[must_use]
254    pub fn with_load_balancer(mut self, lb: LoadBalancer) -> Self {
255        self.load_balancer = lb;
256        self
257    }
258
259    /// Set the health check interval.
260    #[must_use]
261    pub fn with_health_check_interval(mut self, interval: Duration) -> Self {
262        self.health_check_interval = interval;
263        self
264    }
265
266    /// Set the failure threshold.
267    #[must_use]
268    pub fn with_failure_threshold(mut self, threshold: usize) -> Self {
269        self.failure_threshold = threshold;
270        self
271    }
272
273    /// Set the recovery threshold.
274    #[must_use]
275    pub fn with_recovery_threshold(mut self, threshold: usize) -> Self {
276        self.recovery_threshold = threshold;
277        self
278    }
279
280    /// Set the base client configuration.
281    #[must_use]
282    pub fn with_base_config(mut self, config: TalosClientConfig) -> Self {
283        self.base_config = Some(config);
284        self
285    }
286
287    /// Disable automatic health checks.
288    #[must_use]
289    pub fn disable_auto_health_check(mut self) -> Self {
290        self.auto_health_check = false;
291        self
292    }
293}
294
/// A pool of connections to multiple Talos endpoints.
///
/// The pool maintains connections to multiple Talos nodes and routes
/// requests to healthy endpoints based on the configured load balancing
/// strategy.
pub struct ConnectionPool {
    /// Pool configuration (endpoints, thresholds, strategy).
    config: ConnectionPoolConfig,
    /// Connected clients keyed by endpoint URL.
    clients: RwLock<HashMap<String, TalosClient>>,
    /// Per-endpoint health trackers; keys mirror `config.endpoints`.
    health: HashMap<String, Arc<EndpointHealth>>,
    /// Monotonic counter driving round-robin selection.
    round_robin_index: AtomicUsize,
    /// Set once `shutdown()` is called (or the pool is dropped).
    shutdown: AtomicBool,
    /// Handle for a background health-check task.
    /// NOTE(review): never populated in this module (always `None`), hence
    /// the `dead_code` allow — confirm whether the periodic task is spawned
    /// elsewhere or still to be implemented.
    #[allow(dead_code)]
    health_check_handle: Option<tokio::task::JoinHandle<()>>,
}
309
310impl ConnectionPool {
311    /// Create a new connection pool.
312    ///
313    /// # Errors
314    ///
315    /// Returns an error if no endpoints are provided or if initial connection fails.
316    pub async fn new(config: ConnectionPoolConfig) -> Result<Self> {
317        if config.endpoints.is_empty() {
318            return Err(TalosError::Config(
319                "At least one endpoint is required".to_string(),
320            ));
321        }
322
323        // Initialize health tracking for all endpoints
324        let health: HashMap<String, Arc<EndpointHealth>> = config
325            .endpoints
326            .iter()
327            .map(|e| (e.clone(), Arc::new(EndpointHealth::new(e.clone()))))
328            .collect();
329
330        let pool = Self {
331            config,
332            clients: RwLock::new(HashMap::new()),
333            health,
334            round_robin_index: AtomicUsize::new(0),
335            shutdown: AtomicBool::new(false),
336            health_check_handle: None,
337        };
338
339        // Try to connect to at least one endpoint
340        pool.connect_all().await?;
341
342        Ok(pool)
343    }
344
345    /// Connect to all endpoints, collecting errors but not failing.
346    async fn connect_all(&self) -> Result<()> {
347        let mut connected = false;
348        let mut last_error = None;
349
350        for endpoint in &self.config.endpoints {
351            match self.connect_endpoint(endpoint).await {
352                Ok(client) => {
353                    self.clients.write().await.insert(endpoint.clone(), client);
354                    if let Some(health) = self.health.get(endpoint) {
355                        health.record_success().await;
356                    }
357                    connected = true;
358                }
359                Err(e) => {
360                    if let Some(health) = self.health.get(endpoint) {
361                        health.record_failure(self.config.failure_threshold).await;
362                    }
363                    last_error = Some(e);
364                }
365            }
366        }
367
368        if connected {
369            Ok(())
370        } else {
371            Err(last_error.unwrap_or_else(|| {
372                TalosError::Connection("Failed to connect to any endpoint".to_string())
373            }))
374        }
375    }
376
377    /// Connect to a single endpoint.
378    async fn connect_endpoint(&self, endpoint: &str) -> Result<TalosClient> {
379        let config = if let Some(base) = &self.config.base_config {
380            TalosClientConfig {
381                endpoint: endpoint.to_string(),
382                crt_path: base.crt_path.clone(),
383                key_path: base.key_path.clone(),
384                ca_path: base.ca_path.clone(),
385                insecure: base.insecure,
386                connect_timeout: base.connect_timeout,
387                request_timeout: base.request_timeout,
388                keepalive_interval: base.keepalive_interval,
389                keepalive_timeout: base.keepalive_timeout,
390            }
391        } else {
392            TalosClientConfig::new(endpoint)
393        };
394
395        TalosClient::new(config).await
396    }
397
398    /// Get a healthy client using the configured load balancing strategy.
399    ///
400    /// # Errors
401    ///
402    /// Returns an error if no healthy endpoints are available.
403    pub async fn get_client(&self) -> Result<TalosClient> {
404        let healthy_endpoints = self.get_healthy_endpoints();
405
406        if healthy_endpoints.is_empty() {
407            // Try to reconnect to all endpoints
408            self.connect_all().await?;
409            let healthy = self.get_healthy_endpoints();
410            if healthy.is_empty() {
411                return Err(TalosError::Connection(
412                    "No healthy endpoints available".to_string(),
413                ));
414            }
415        }
416
417        let endpoint = self.select_endpoint(&self.get_healthy_endpoints())?;
418        let clients = self.clients.read().await;
419
420        clients.get(&endpoint).cloned().ok_or_else(|| {
421            TalosError::Connection(format!("Client for endpoint {} not found", endpoint))
422        })
423    }
424
425    /// Get a list of healthy endpoint URLs.
426    #[must_use]
427    pub fn get_healthy_endpoints(&self) -> Vec<String> {
428        self.health
429            .iter()
430            .filter(|(_, h)| h.is_healthy())
431            .map(|(e, _)| e.clone())
432            .collect()
433    }
434
435    /// Get health information for an endpoint.
436    #[must_use]
437    pub fn get_endpoint_health(&self, endpoint: &str) -> Option<&Arc<EndpointHealth>> {
438        self.health.get(endpoint)
439    }
440
441    /// Get health information for all endpoints.
442    #[must_use]
443    pub fn get_all_health(&self) -> &HashMap<String, Arc<EndpointHealth>> {
444        &self.health
445    }
446
447    /// Select an endpoint based on the load balancing strategy.
448    #[allow(clippy::result_large_err)]
449    fn select_endpoint(&self, healthy: &[String]) -> Result<String> {
450        if healthy.is_empty() {
451            return Err(TalosError::Connection(
452                "No healthy endpoints available".to_string(),
453            ));
454        }
455
456        let endpoint = match self.config.load_balancer {
457            LoadBalancer::RoundRobin => {
458                let idx = self.round_robin_index.fetch_add(1, Ordering::Relaxed) % healthy.len();
459                healthy[idx].clone()
460            }
461            LoadBalancer::Random => {
462                let idx = rand::random::<usize>() % healthy.len();
463                healthy[idx].clone()
464            }
465            LoadBalancer::LeastFailures => {
466                let mut best = healthy[0].clone();
467                let mut best_rate = f64::MAX;
468                for e in healthy {
469                    if let Some(health) = self.health.get(e) {
470                        let rate = health.failure_rate();
471                        if rate < best_rate {
472                            best_rate = rate;
473                            best = e.clone();
474                        }
475                    }
476                }
477                best
478            }
479            LoadBalancer::Failover => healthy[0].clone(),
480        };
481
482        Ok(endpoint)
483    }
484
485    /// Perform a health check on a specific endpoint.
486    ///
487    /// # Errors
488    ///
489    /// Returns an error if the health check fails.
490    pub async fn health_check(&self, endpoint: &str) -> Result<bool> {
491        let client = match self.connect_endpoint(endpoint).await {
492            Ok(c) => c,
493            Err(e) => {
494                if let Some(health) = self.health.get(endpoint) {
495                    health
496                        .record_health_check(false, self.config.failure_threshold)
497                        .await;
498                }
499                return Err(e);
500            }
501        };
502
503        // Try a simple version request as health check
504        let mut version_client = client.version();
505        let request = crate::api::version::VersionRequest { client: false };
506        match version_client.version(request).await {
507            Ok(_) => {
508                if let Some(health) = self.health.get(endpoint) {
509                    health
510                        .record_health_check(true, self.config.failure_threshold)
511                        .await;
512                }
513                // Update client in pool
514                self.clients
515                    .write()
516                    .await
517                    .insert(endpoint.to_string(), client);
518                Ok(true)
519            }
520            Err(e) => {
521                if let Some(health) = self.health.get(endpoint) {
522                    health
523                        .record_health_check(false, self.config.failure_threshold)
524                        .await;
525                }
526                Err(TalosError::Api(e))
527            }
528        }
529    }
530
531    /// Perform health checks on all endpoints.
532    pub async fn health_check_all(&self) {
533        for endpoint in &self.config.endpoints {
534            let _ = self.health_check(endpoint).await;
535        }
536    }
537
538    /// Record a successful operation for an endpoint.
539    pub async fn record_success(&self, endpoint: &str) {
540        if let Some(health) = self.health.get(endpoint) {
541            health.record_success().await;
542        }
543    }
544
545    /// Record a failed operation for an endpoint.
546    pub async fn record_failure(&self, endpoint: &str) {
547        if let Some(health) = self.health.get(endpoint) {
548            health.record_failure(self.config.failure_threshold).await;
549        }
550    }
551
552    /// Shutdown the connection pool.
553    pub fn shutdown(&self) {
554        self.shutdown.store(true, Ordering::Release);
555    }
556
557    /// Check if the pool is shut down.
558    #[must_use]
559    pub fn is_shutdown(&self) -> bool {
560        self.shutdown.load(Ordering::Acquire)
561    }
562
563    /// Get the number of connected clients.
564    pub async fn connected_count(&self) -> usize {
565        self.clients.read().await.len()
566    }
567
568    /// Get the total number of endpoints.
569    #[must_use]
570    pub fn endpoint_count(&self) -> usize {
571        self.config.endpoints.len()
572    }
573}
574
// Dropping the pool marks it as shut down so observers of `is_shutdown()`
// see the flag flip; no connections are torn down here.
impl Drop for ConnectionPool {
    fn drop(&mut self) {
        self.shutdown();
    }
}
580
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_endpoint_health_new() {
        // A fresh tracker starts Unknown with zeroed counters.
        let tracker = EndpointHealth::new("https://test:50000".to_string());
        assert_eq!(tracker.status(), HealthStatus::Unknown);
        assert_eq!(tracker.consecutive_failures(), 0);
        assert_eq!(tracker.total_requests(), 0);
    }

    #[tokio::test]
    async fn test_endpoint_health_record_success() {
        // One success flips the status to Healthy and stamps the time.
        let tracker = EndpointHealth::new("https://test:50000".to_string());
        tracker.record_success().await;
        assert_eq!(tracker.status(), HealthStatus::Healthy);
        assert_eq!(tracker.total_requests(), 1);
        assert!(tracker.last_success().await.is_some());
    }

    #[tokio::test]
    async fn test_endpoint_health_record_failure() {
        let tracker = EndpointHealth::new("https://test:50000".to_string());

        // Below the threshold the status stays Unknown.
        tracker.record_failure(3).await;
        assert_eq!(tracker.consecutive_failures(), 1);
        assert_eq!(tracker.status(), HealthStatus::Unknown);

        // Reaching the threshold marks the endpoint unhealthy.
        tracker.record_failure(3).await;
        tracker.record_failure(3).await;
        assert_eq!(tracker.status(), HealthStatus::Unhealthy);
    }

    #[tokio::test]
    async fn test_endpoint_health_recovery() {
        let tracker = EndpointHealth::new("https://test:50000".to_string());

        // Drive the tracker to Unhealthy by hitting the failure threshold.
        for _ in 0..3 {
            tracker.record_failure(3).await;
        }
        assert_eq!(tracker.status(), HealthStatus::Unhealthy);

        // A single success recovers it.
        tracker.record_success().await;
        assert_eq!(tracker.status(), HealthStatus::Healthy);
    }

    #[test]
    fn test_endpoint_health_failure_rate() {
        let tracker = EndpointHealth::new("https://test:50000".to_string());
        assert_eq!(tracker.failure_rate(), 0.0);

        // 2 failures out of 10 requests => 20% failure rate.
        tracker.total_requests.store(10, Ordering::Relaxed);
        tracker.total_failures.store(2, Ordering::Relaxed);
        assert!((tracker.failure_rate() - 0.2).abs() < f64::EPSILON);
    }

    #[test]
    fn test_load_balancer_default() {
        assert_eq!(LoadBalancer::default(), LoadBalancer::RoundRobin);
    }

    #[test]
    fn test_connection_pool_config_new() {
        let endpoints = vec![
            "https://node1:50000".to_string(),
            "https://node2:50000".to_string(),
        ];
        let config = ConnectionPoolConfig::new(endpoints);

        assert_eq!(config.endpoints.len(), 2);
        assert_eq!(config.load_balancer, LoadBalancer::RoundRobin);
        assert_eq!(config.failure_threshold, 3);
        assert!(config.auto_health_check);
    }

    #[test]
    fn test_connection_pool_config_builder() {
        // Each builder method must override its corresponding default.
        let config = ConnectionPoolConfig::new(vec!["https://node1:50000".to_string()])
            .with_load_balancer(LoadBalancer::Random)
            .with_failure_threshold(5)
            .with_recovery_threshold(3)
            .with_health_check_interval(Duration::from_secs(60))
            .disable_auto_health_check();

        assert_eq!(config.load_balancer, LoadBalancer::Random);
        assert_eq!(config.failure_threshold, 5);
        assert_eq!(config.recovery_threshold, 3);
        assert_eq!(config.health_check_interval, Duration::from_secs(60));
        assert!(!config.auto_health_check);
    }

    #[tokio::test]
    async fn test_connection_pool_empty_endpoints() {
        // An empty endpoint list must be rejected at construction time.
        let result = ConnectionPool::new(ConnectionPoolConfig::new(vec![])).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_health_status_conversions() {
        // Encoding then decoding must round-trip every status variant.
        for status in [
            HealthStatus::Healthy,
            HealthStatus::Unhealthy,
            HealthStatus::Unknown,
        ] {
            assert_eq!(
                EndpointHealth::u64_to_status(EndpointHealth::status_to_u64(status)),
                status
            );
        }
    }

    #[test]
    fn test_endpoint_health_reset() {
        let tracker = EndpointHealth::new("https://test:50000".to_string());
        tracker.status.store(
            EndpointHealth::status_to_u64(HealthStatus::Unhealthy),
            Ordering::Relaxed,
        );
        tracker.consecutive_failures.store(5, Ordering::Relaxed);

        tracker.reset();

        assert_eq!(tracker.status(), HealthStatus::Unknown);
        assert_eq!(tracker.consecutive_failures(), 0);
    }
}