1use crate::client::{TalosClient, TalosClientConfig};
30use crate::error::{Result, TalosError};
31use std::collections::HashMap;
32use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
33use std::sync::Arc;
34use std::time::{Duration, Instant};
35use tokio::sync::RwLock;
36
/// Health state of a single pool endpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HealthStatus {
    /// The endpoint's most recent request/probe succeeded.
    Healthy,
    /// The endpoint accumulated enough consecutive failures to be avoided.
    Unhealthy,
    /// No observations yet (initial state, or after `reset`).
    Unknown,
}
47
/// Per-endpoint health bookkeeping, shareable across tasks: counters are
/// atomics, timestamps sit behind an async `RwLock`.
#[derive(Debug)]
pub struct EndpointHealth {
    /// Endpoint address this tracker belongs to.
    pub endpoint: String,
    /// Current `HealthStatus`, encoded as a `u64` via `status_to_u64`.
    status: AtomicU64,
    /// Failures observed since the last success.
    consecutive_failures: AtomicUsize,
    /// Successes observed since the last failure.
    consecutive_successes: AtomicUsize,
    /// Lifetime request count (successes + failures).
    total_requests: AtomicU64,
    /// Lifetime failure count.
    total_failures: AtomicU64,
    /// When the most recent successful request completed, if any.
    last_success: RwLock<Option<Instant>>,
    /// When the most recent failed request completed, if any.
    last_failure: RwLock<Option<Instant>>,
    /// When the most recent active health-check probe ran, if any.
    last_health_check: RwLock<Option<Instant>>,
}
70
71impl EndpointHealth {
72 #[must_use]
74 pub fn new(endpoint: String) -> Self {
75 Self {
76 endpoint,
77 status: AtomicU64::new(Self::status_to_u64(HealthStatus::Unknown)),
78 consecutive_failures: AtomicUsize::new(0),
79 consecutive_successes: AtomicUsize::new(0),
80 total_requests: AtomicU64::new(0),
81 total_failures: AtomicU64::new(0),
82 last_success: RwLock::new(None),
83 last_failure: RwLock::new(None),
84 last_health_check: RwLock::new(None),
85 }
86 }
87
88 fn status_to_u64(status: HealthStatus) -> u64 {
89 match status {
90 HealthStatus::Healthy => 0,
91 HealthStatus::Unhealthy => 1,
92 HealthStatus::Unknown => 2,
93 }
94 }
95
96 fn u64_to_status(value: u64) -> HealthStatus {
97 match value {
98 0 => HealthStatus::Healthy,
99 1 => HealthStatus::Unhealthy,
100 _ => HealthStatus::Unknown,
101 }
102 }
103
104 #[must_use]
106 pub fn status(&self) -> HealthStatus {
107 Self::u64_to_status(self.status.load(Ordering::Acquire))
108 }
109
110 #[must_use]
112 pub fn is_healthy(&self) -> bool {
113 self.status() == HealthStatus::Healthy
114 }
115
116 pub async fn record_success(&self) {
118 self.total_requests.fetch_add(1, Ordering::Relaxed);
119 self.consecutive_failures.store(0, Ordering::Relaxed);
120 self.consecutive_successes.fetch_add(1, Ordering::Relaxed);
121 *self.last_success.write().await = Some(Instant::now());
122 self.status.store(
123 Self::status_to_u64(HealthStatus::Healthy),
124 Ordering::Release,
125 );
126 }
127
128 pub async fn record_failure(&self, failure_threshold: usize) {
130 self.total_requests.fetch_add(1, Ordering::Relaxed);
131 self.total_failures.fetch_add(1, Ordering::Relaxed);
132 self.consecutive_successes.store(0, Ordering::Relaxed);
133 let failures = self.consecutive_failures.fetch_add(1, Ordering::Relaxed) + 1;
134 *self.last_failure.write().await = Some(Instant::now());
135
136 if failures >= failure_threshold {
137 self.status.store(
138 Self::status_to_u64(HealthStatus::Unhealthy),
139 Ordering::Release,
140 );
141 }
142 }
143
144 pub async fn record_health_check(&self, healthy: bool, failure_threshold: usize) {
146 *self.last_health_check.write().await = Some(Instant::now());
147 if healthy {
148 self.record_success().await;
149 } else {
150 self.record_failure(failure_threshold).await;
151 }
152 }
153
154 pub fn reset(&self) {
156 self.status.store(
157 Self::status_to_u64(HealthStatus::Unknown),
158 Ordering::Release,
159 );
160 self.consecutive_failures.store(0, Ordering::Relaxed);
161 self.consecutive_successes.store(0, Ordering::Relaxed);
162 }
163
164 #[must_use]
166 pub fn consecutive_failures(&self) -> usize {
167 self.consecutive_failures.load(Ordering::Relaxed)
168 }
169
170 #[must_use]
172 pub fn total_requests(&self) -> u64 {
173 self.total_requests.load(Ordering::Relaxed)
174 }
175
176 #[must_use]
178 pub fn total_failures(&self) -> u64 {
179 self.total_failures.load(Ordering::Relaxed)
180 }
181
182 #[must_use]
184 pub fn failure_rate(&self) -> f64 {
185 let total = self.total_requests.load(Ordering::Relaxed);
186 if total == 0 {
187 return 0.0;
188 }
189 let failures = self.total_failures.load(Ordering::Relaxed);
190 failures as f64 / total as f64
191 }
192
193 pub async fn last_success(&self) -> Option<Instant> {
195 *self.last_success.read().await
196 }
197
198 pub async fn last_health_check(&self) -> Option<Instant> {
200 *self.last_health_check.read().await
201 }
202}
203
/// Strategy used to pick among healthy endpoints.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LoadBalancer {
    /// Cycle through healthy endpoints in order (the default).
    #[default]
    RoundRobin,
    /// Pick a healthy endpoint at random.
    Random,
    /// Pick the healthy endpoint with the lowest lifetime failure rate.
    LeastFailures,
    /// Always use the first healthy endpoint.
    Failover,
}
217
/// Configuration for a `ConnectionPool`.
#[derive(Debug, Clone)]
pub struct ConnectionPoolConfig {
    /// Endpoint addresses to pool.
    pub endpoints: Vec<String>,
    /// Strategy for choosing among healthy endpoints.
    pub load_balancer: LoadBalancer,
    /// Interval between periodic health checks.
    pub health_check_interval: Duration,
    /// Consecutive failures before an endpoint is marked unhealthy.
    pub failure_threshold: usize,
    /// Consecutive successes before an endpoint counts as recovered.
    /// NOTE(review): not consulted anywhere in the visible code — a single
    /// success re-marks an endpoint healthy; confirm this is intended.
    pub recovery_threshold: usize,
    /// Template for per-endpoint client settings (TLS paths, timeouts);
    /// the endpoint field is overridden for each connection.
    pub base_config: Option<TalosClientConfig>,
    /// Whether background health checking should run automatically.
    pub auto_health_check: bool,
}
236
237impl ConnectionPoolConfig {
238 #[must_use]
240 pub fn new(endpoints: Vec<String>) -> Self {
241 Self {
242 endpoints,
243 load_balancer: LoadBalancer::RoundRobin,
244 health_check_interval: Duration::from_secs(30),
245 failure_threshold: 3,
246 recovery_threshold: 2,
247 base_config: None,
248 auto_health_check: true,
249 }
250 }
251
252 #[must_use]
254 pub fn with_load_balancer(mut self, lb: LoadBalancer) -> Self {
255 self.load_balancer = lb;
256 self
257 }
258
259 #[must_use]
261 pub fn with_health_check_interval(mut self, interval: Duration) -> Self {
262 self.health_check_interval = interval;
263 self
264 }
265
266 #[must_use]
268 pub fn with_failure_threshold(mut self, threshold: usize) -> Self {
269 self.failure_threshold = threshold;
270 self
271 }
272
273 #[must_use]
275 pub fn with_recovery_threshold(mut self, threshold: usize) -> Self {
276 self.recovery_threshold = threshold;
277 self
278 }
279
280 #[must_use]
282 pub fn with_base_config(mut self, config: TalosClientConfig) -> Self {
283 self.base_config = Some(config);
284 self
285 }
286
287 #[must_use]
289 pub fn disable_auto_health_check(mut self) -> Self {
290 self.auto_health_check = false;
291 self
292 }
293}
294
/// A pool of `TalosClient` connections across multiple endpoints, with
/// per-endpoint health tracking and pluggable load balancing.
pub struct ConnectionPool {
    /// Pool-wide configuration.
    config: ConnectionPoolConfig,
    /// Established clients, keyed by endpoint address.
    clients: RwLock<HashMap<String, TalosClient>>,
    /// One health tracker per configured endpoint (key set is fixed at creation).
    health: HashMap<String, Arc<EndpointHealth>>,
    /// Monotonic counter driving round-robin selection.
    round_robin_index: AtomicUsize,
    /// Set once `shutdown` is called.
    shutdown: AtomicBool,
    /// Handle for a background health-check task.
    /// NOTE(review): never populated in the visible code — always `None`.
    #[allow(dead_code)]
    health_check_handle: Option<tokio::task::JoinHandle<()>>,
}
309
310impl ConnectionPool {
311 pub async fn new(config: ConnectionPoolConfig) -> Result<Self> {
317 if config.endpoints.is_empty() {
318 return Err(TalosError::Config(
319 "At least one endpoint is required".to_string(),
320 ));
321 }
322
323 let health: HashMap<String, Arc<EndpointHealth>> = config
325 .endpoints
326 .iter()
327 .map(|e| (e.clone(), Arc::new(EndpointHealth::new(e.clone()))))
328 .collect();
329
330 let pool = Self {
331 config,
332 clients: RwLock::new(HashMap::new()),
333 health,
334 round_robin_index: AtomicUsize::new(0),
335 shutdown: AtomicBool::new(false),
336 health_check_handle: None,
337 };
338
339 pool.connect_all().await?;
341
342 Ok(pool)
343 }
344
345 async fn connect_all(&self) -> Result<()> {
347 let mut connected = false;
348 let mut last_error = None;
349
350 for endpoint in &self.config.endpoints {
351 match self.connect_endpoint(endpoint).await {
352 Ok(client) => {
353 self.clients.write().await.insert(endpoint.clone(), client);
354 if let Some(health) = self.health.get(endpoint) {
355 health.record_success().await;
356 }
357 connected = true;
358 }
359 Err(e) => {
360 if let Some(health) = self.health.get(endpoint) {
361 health.record_failure(self.config.failure_threshold).await;
362 }
363 last_error = Some(e);
364 }
365 }
366 }
367
368 if connected {
369 Ok(())
370 } else {
371 Err(last_error.unwrap_or_else(|| {
372 TalosError::Connection("Failed to connect to any endpoint".to_string())
373 }))
374 }
375 }
376
377 async fn connect_endpoint(&self, endpoint: &str) -> Result<TalosClient> {
379 let config = if let Some(base) = &self.config.base_config {
380 TalosClientConfig {
381 endpoint: endpoint.to_string(),
382 crt_path: base.crt_path.clone(),
383 key_path: base.key_path.clone(),
384 ca_path: base.ca_path.clone(),
385 insecure: base.insecure,
386 connect_timeout: base.connect_timeout,
387 request_timeout: base.request_timeout,
388 keepalive_interval: base.keepalive_interval,
389 keepalive_timeout: base.keepalive_timeout,
390 }
391 } else {
392 TalosClientConfig::new(endpoint)
393 };
394
395 TalosClient::new(config).await
396 }
397
398 pub async fn get_client(&self) -> Result<TalosClient> {
404 let healthy_endpoints = self.get_healthy_endpoints();
405
406 if healthy_endpoints.is_empty() {
407 self.connect_all().await?;
409 let healthy = self.get_healthy_endpoints();
410 if healthy.is_empty() {
411 return Err(TalosError::Connection(
412 "No healthy endpoints available".to_string(),
413 ));
414 }
415 }
416
417 let endpoint = self.select_endpoint(&self.get_healthy_endpoints())?;
418 let clients = self.clients.read().await;
419
420 clients.get(&endpoint).cloned().ok_or_else(|| {
421 TalosError::Connection(format!("Client for endpoint {} not found", endpoint))
422 })
423 }
424
425 #[must_use]
427 pub fn get_healthy_endpoints(&self) -> Vec<String> {
428 self.health
429 .iter()
430 .filter(|(_, h)| h.is_healthy())
431 .map(|(e, _)| e.clone())
432 .collect()
433 }
434
435 #[must_use]
437 pub fn get_endpoint_health(&self, endpoint: &str) -> Option<&Arc<EndpointHealth>> {
438 self.health.get(endpoint)
439 }
440
441 #[must_use]
443 pub fn get_all_health(&self) -> &HashMap<String, Arc<EndpointHealth>> {
444 &self.health
445 }
446
447 #[allow(clippy::result_large_err)]
449 fn select_endpoint(&self, healthy: &[String]) -> Result<String> {
450 if healthy.is_empty() {
451 return Err(TalosError::Connection(
452 "No healthy endpoints available".to_string(),
453 ));
454 }
455
456 let endpoint = match self.config.load_balancer {
457 LoadBalancer::RoundRobin => {
458 let idx = self.round_robin_index.fetch_add(1, Ordering::Relaxed) % healthy.len();
459 healthy[idx].clone()
460 }
461 LoadBalancer::Random => {
462 let idx = rand::random::<usize>() % healthy.len();
463 healthy[idx].clone()
464 }
465 LoadBalancer::LeastFailures => {
466 let mut best = healthy[0].clone();
467 let mut best_rate = f64::MAX;
468 for e in healthy {
469 if let Some(health) = self.health.get(e) {
470 let rate = health.failure_rate();
471 if rate < best_rate {
472 best_rate = rate;
473 best = e.clone();
474 }
475 }
476 }
477 best
478 }
479 LoadBalancer::Failover => healthy[0].clone(),
480 };
481
482 Ok(endpoint)
483 }
484
485 pub async fn health_check(&self, endpoint: &str) -> Result<bool> {
491 let client = match self.connect_endpoint(endpoint).await {
492 Ok(c) => c,
493 Err(e) => {
494 if let Some(health) = self.health.get(endpoint) {
495 health
496 .record_health_check(false, self.config.failure_threshold)
497 .await;
498 }
499 return Err(e);
500 }
501 };
502
503 let mut version_client = client.version();
505 let request = crate::api::version::VersionRequest { client: false };
506 match version_client.version(request).await {
507 Ok(_) => {
508 if let Some(health) = self.health.get(endpoint) {
509 health
510 .record_health_check(true, self.config.failure_threshold)
511 .await;
512 }
513 self.clients
515 .write()
516 .await
517 .insert(endpoint.to_string(), client);
518 Ok(true)
519 }
520 Err(e) => {
521 if let Some(health) = self.health.get(endpoint) {
522 health
523 .record_health_check(false, self.config.failure_threshold)
524 .await;
525 }
526 Err(TalosError::Api(e))
527 }
528 }
529 }
530
531 pub async fn health_check_all(&self) {
533 for endpoint in &self.config.endpoints {
534 let _ = self.health_check(endpoint).await;
535 }
536 }
537
538 pub async fn record_success(&self, endpoint: &str) {
540 if let Some(health) = self.health.get(endpoint) {
541 health.record_success().await;
542 }
543 }
544
545 pub async fn record_failure(&self, endpoint: &str) {
547 if let Some(health) = self.health.get(endpoint) {
548 health.record_failure(self.config.failure_threshold).await;
549 }
550 }
551
552 pub fn shutdown(&self) {
554 self.shutdown.store(true, Ordering::Release);
555 }
556
557 #[must_use]
559 pub fn is_shutdown(&self) -> bool {
560 self.shutdown.load(Ordering::Acquire)
561 }
562
563 pub async fn connected_count(&self) -> usize {
565 self.clients.read().await.len()
566 }
567
568 #[must_use]
570 pub fn endpoint_count(&self) -> usize {
571 self.config.endpoints.len()
572 }
573}
574
impl Drop for ConnectionPool {
    fn drop(&mut self) {
        // Flag shutdown so anything polling `is_shutdown` stops promptly.
        self.shutdown();
    }
}
580
#[cfg(test)]
mod tests {
    use super::*;

    // New trackers start Unknown with zeroed counters.
    #[test]
    fn test_endpoint_health_new() {
        let health = EndpointHealth::new("https://test:50000".to_string());
        assert_eq!(health.status(), HealthStatus::Unknown);
        assert_eq!(health.consecutive_failures(), 0);
        assert_eq!(health.total_requests(), 0);
    }

    // A success marks the endpoint Healthy and stamps last_success.
    #[tokio::test]
    async fn test_endpoint_health_record_success() {
        let health = EndpointHealth::new("https://test:50000".to_string());
        health.record_success().await;
        assert_eq!(health.status(), HealthStatus::Healthy);
        assert_eq!(health.total_requests(), 1);
        assert!(health.last_success().await.is_some());
    }

    // Failures below the threshold leave the status unchanged; reaching
    // the threshold (3 here) flips it to Unhealthy.
    #[tokio::test]
    async fn test_endpoint_health_record_failure() {
        let health = EndpointHealth::new("https://test:50000".to_string());
        health.record_failure(3).await;
        assert_eq!(health.consecutive_failures(), 1);
        assert_eq!(health.status(), HealthStatus::Unknown);

        health.record_failure(3).await;
        health.record_failure(3).await;
        assert_eq!(health.status(), HealthStatus::Unhealthy);
    }

    // A single success recovers an Unhealthy endpoint immediately
    // (recovery_threshold is not consulted by EndpointHealth).
    #[tokio::test]
    async fn test_endpoint_health_recovery() {
        let health = EndpointHealth::new("https://test:50000".to_string());
        for _ in 0..3 {
            health.record_failure(3).await;
        }
        assert_eq!(health.status(), HealthStatus::Unhealthy);

        health.record_success().await;
        assert_eq!(health.status(), HealthStatus::Healthy);
    }

    // failure_rate is 0.0 with no traffic, otherwise failures/requests.
    // Counters are poked directly since tests live in the same module.
    #[test]
    fn test_endpoint_health_failure_rate() {
        let health = EndpointHealth::new("https://test:50000".to_string());
        assert_eq!(health.failure_rate(), 0.0);

        health.total_requests.store(10, Ordering::Relaxed);
        health.total_failures.store(2, Ordering::Relaxed);
        assert!((health.failure_rate() - 0.2).abs() < f64::EPSILON);
    }

    // The derived Default must stay RoundRobin.
    #[test]
    fn test_load_balancer_default() {
        assert_eq!(LoadBalancer::default(), LoadBalancer::RoundRobin);
    }

    // ConnectionPoolConfig::new applies the documented defaults.
    #[test]
    fn test_connection_pool_config_new() {
        let config = ConnectionPoolConfig::new(vec![
            "https://node1:50000".to_string(),
            "https://node2:50000".to_string(),
        ]);

        assert_eq!(config.endpoints.len(), 2);
        assert_eq!(config.load_balancer, LoadBalancer::RoundRobin);
        assert_eq!(config.failure_threshold, 3);
        assert!(config.auto_health_check);
    }

    // Each builder method overrides exactly its own field.
    #[test]
    fn test_connection_pool_config_builder() {
        let config = ConnectionPoolConfig::new(vec!["https://node1:50000".to_string()])
            .with_load_balancer(LoadBalancer::Random)
            .with_failure_threshold(5)
            .with_recovery_threshold(3)
            .with_health_check_interval(Duration::from_secs(60))
            .disable_auto_health_check();

        assert_eq!(config.load_balancer, LoadBalancer::Random);
        assert_eq!(config.failure_threshold, 5);
        assert_eq!(config.recovery_threshold, 3);
        assert_eq!(config.health_check_interval, Duration::from_secs(60));
        assert!(!config.auto_health_check);
    }

    // A pool with no endpoints must be rejected at construction.
    #[tokio::test]
    async fn test_connection_pool_empty_endpoints() {
        let config = ConnectionPoolConfig::new(vec![]);
        let result = ConnectionPool::new(config).await;
        assert!(result.is_err());
    }

    // status_to_u64 / u64_to_status must round-trip every variant.
    #[test]
    fn test_health_status_conversions() {
        assert_eq!(
            EndpointHealth::u64_to_status(EndpointHealth::status_to_u64(HealthStatus::Healthy)),
            HealthStatus::Healthy
        );
        assert_eq!(
            EndpointHealth::u64_to_status(EndpointHealth::status_to_u64(HealthStatus::Unhealthy)),
            HealthStatus::Unhealthy
        );
        assert_eq!(
            EndpointHealth::u64_to_status(EndpointHealth::status_to_u64(HealthStatus::Unknown)),
            HealthStatus::Unknown
        );
    }

    // reset() returns status to Unknown and clears the failure streak.
    #[test]
    fn test_endpoint_health_reset() {
        let health = EndpointHealth::new("https://test:50000".to_string());
        health.status.store(
            EndpointHealth::status_to_u64(HealthStatus::Unhealthy),
            Ordering::Relaxed,
        );
        health.consecutive_failures.store(5, Ordering::Relaxed);

        health.reset();

        assert_eq!(health.status(), HealthStatus::Unknown);
        assert_eq!(health.consecutive_failures(), 0);
    }
}