sentinel_proxy/health.rs

//! Health checking module for Sentinel proxy
//!
//! This module implements active and passive health checking for upstream servers,
//! supporting HTTP, TCP, and gRPC health checks with configurable thresholds.
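//!
//! Thresholds provide hysteresis: a target must pass `healthy_threshold`
//! consecutive checks to be marked healthy, and fail `unhealthy_threshold`
//! consecutive checks to be marked unhealthy, which prevents flapping on
//! isolated failures.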

use async_trait::async_trait;
use std::collections::{HashMap, VecDeque};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::RwLock;
use tokio::time;
use tracing::{debug, info, trace, warn};

use sentinel_common::{errors::SentinelResult, types::HealthCheckType};
use sentinel_config::{HealthCheck as HealthCheckConfig, UpstreamTarget};

/// Active health checker for upstream targets
///
/// Performs periodic health checks on upstream targets using HTTP, TCP, or gRPC
/// protocols to determine their availability for load balancing.
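///
/// # Examples
///
/// A minimal usage sketch; `config` and `targets` are assumed to come from
/// `sentinel_config`, and the target address shown is hypothetical.
///
/// ```rust,ignore
/// let checker = Arc::new(ActiveHealthChecker::new(config));
/// checker.start(&targets).await?;
///
/// // Later, consult health state when picking an upstream.
/// if checker.is_healthy("10.0.0.1:8080").await {
///     // route traffic to this target
/// }
///
/// checker.stop().await;
/// ```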
pub struct ActiveHealthChecker {
    /// Check configuration
    config: HealthCheckConfig,
    /// Health checker implementation
    checker: Arc<dyn HealthCheckImpl>,
    /// Health status per target
    health_status: Arc<RwLock<HashMap<String, TargetHealthInfo>>>,
    /// Check task handles
    check_handles: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
    /// Shutdown signal
    shutdown_tx: Arc<tokio::sync::broadcast::Sender<()>>,
}

/// Health status information for a target
#[derive(Debug, Clone)]
pub struct TargetHealthInfo {
    /// Is target healthy
    pub healthy: bool,
    /// Consecutive successes
    pub consecutive_successes: u32,
    /// Consecutive failures
    pub consecutive_failures: u32,
    /// Last check time
    pub last_check: Instant,
    /// Last successful check
    pub last_success: Option<Instant>,
    /// Last error message
    pub last_error: Option<String>,
    /// Total checks performed
    pub total_checks: u64,
    /// Total successful checks
    pub total_successes: u64,
    /// Average response time (ms)
    pub avg_response_time: f64,
}

/// Health check implementation trait
#[async_trait]
trait HealthCheckImpl: Send + Sync {
    /// Perform health check on a target
    async fn check(&self, target: &str) -> Result<Duration, String>;

    /// Get check type name
    fn check_type(&self) -> &str;
}

/// HTTP health check implementation
struct HttpHealthCheck {
    path: String,
    expected_status: u16,
    host: Option<String>,
    timeout: Duration,
}

/// TCP health check implementation
struct TcpHealthCheck {
    timeout: Duration,
}

/// gRPC health check implementation.
///
/// Currently uses a TCP connectivity check as a fallback, since the full gRPC
/// health checking protocol (grpc.health.v1.Health) requires the `tonic`
/// crate for HTTP/2 and Protocol Buffers support.
///
/// A full implementation would:
/// 1. Establish an HTTP/2 connection
/// 2. Call `grpc.health.v1.Health/Check` with the service name
/// 3. Parse `HealthCheckResponse` for SERVING/NOT_SERVING status
///
/// See: https://github.com/grpc/grpc/blob/master/doc/health-checking.md
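///
/// A sketch of that flow, assuming the `tonic` and `tonic-health` crates
/// (module paths vary between versions, so treat the paths below as
/// illustrative rather than exact):
///
/// ```rust,ignore
/// use tonic_health::pb::health_check_response::ServingStatus;
/// use tonic_health::pb::health_client::HealthClient;
/// use tonic_health::pb::HealthCheckRequest;
///
/// // Open an HTTP/2 channel and issue the standard health RPC.
/// let mut client = HealthClient::connect(format!("http://{target}")).await?;
/// let response = client
///     .check(HealthCheckRequest { service: service.clone() })
///     .await?
///     .into_inner();
/// let healthy = response.status == ServingStatus::Serving as i32;
/// ```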
struct GrpcHealthCheck {
    service: String,
    timeout: Duration,
}

impl ActiveHealthChecker {
    /// Create new active health checker
    pub fn new(config: HealthCheckConfig) -> Self {
        debug!(
            check_type = ?config.check_type,
            interval_secs = config.interval_secs,
            timeout_secs = config.timeout_secs,
            healthy_threshold = config.healthy_threshold,
            unhealthy_threshold = config.unhealthy_threshold,
            "Creating active health checker"
        );

        let checker: Arc<dyn HealthCheckImpl> = match &config.check_type {
            HealthCheckType::Http {
                path,
                expected_status,
                host,
            } => {
                trace!(
                    path = %path,
                    expected_status = expected_status,
                    host = host.as_deref().unwrap_or("(default)"),
                    "Configuring HTTP health check"
                );
                Arc::new(HttpHealthCheck {
                    path: path.clone(),
                    expected_status: *expected_status,
                    host: host.clone(),
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
            HealthCheckType::Tcp => {
                trace!("Configuring TCP health check");
                Arc::new(TcpHealthCheck {
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
            HealthCheckType::Grpc { service } => {
                trace!(
                    service = %service,
                    "Configuring gRPC health check"
                );
                Arc::new(GrpcHealthCheck {
                    service: service.clone(),
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
        };

        let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);

        Self {
            config,
            checker,
            health_status: Arc::new(RwLock::new(HashMap::new())),
            check_handles: Arc::new(RwLock::new(Vec::new())),
            shutdown_tx: Arc::new(shutdown_tx),
        }
    }

    /// Start health checking for targets
    pub async fn start(&self, targets: &[UpstreamTarget]) -> SentinelResult<()> {
        info!(
            target_count = targets.len(),
            interval_secs = self.config.interval_secs,
            check_type = self.checker.check_type(),
            "Starting health checking"
        );

        let mut handles = self.check_handles.write().await;

        for target in targets {
            let address = target.address.clone();

            trace!(
                target = %address,
                "Initializing health status for target"
            );

            // Initialize health status
            self.health_status
                .write()
                .await
                .insert(address.clone(), TargetHealthInfo::new());

            // Spawn health check task
            debug!(
                target = %address,
                "Spawning health check task"
            );
            let handle = self.spawn_check_task(address);
            handles.push(handle);
        }

        info!(
            target_count = targets.len(),
            interval_secs = self.config.interval_secs,
            healthy_threshold = self.config.healthy_threshold,
            unhealthy_threshold = self.config.unhealthy_threshold,
            "Health checking started successfully"
        );

        Ok(())
    }

    /// Spawn health check task for a target
    fn spawn_check_task(&self, target: String) -> tokio::task::JoinHandle<()> {
        let interval = Duration::from_secs(self.config.interval_secs);
        let checker = Arc::clone(&self.checker);
        let health_status = Arc::clone(&self.health_status);
        let healthy_threshold = self.config.healthy_threshold;
        let unhealthy_threshold = self.config.unhealthy_threshold;
        let check_type = self.checker.check_type().to_string();
        let mut shutdown_rx = self.shutdown_tx.subscribe();

        tokio::spawn(async move {
            let mut interval_timer = time::interval(interval);
            interval_timer.set_missed_tick_behavior(time::MissedTickBehavior::Skip);

            trace!(
                target = %target,
                check_type = %check_type,
                interval_ms = interval.as_millis(),
                "Health check task started"
            );

            loop {
                tokio::select! {
                    _ = interval_timer.tick() => {
                        // Perform health check
                        trace!(
                            target = %target,
                            check_type = %check_type,
                            "Performing health check"
                        );
                        let start = Instant::now();
                        let result = checker.check(&target).await;
                        let check_duration = start.elapsed();

                        // Update health status
                        let mut status_map = health_status.write().await;
                        if let Some(status) = status_map.get_mut(&target) {
                            status.last_check = Instant::now();
                            status.total_checks += 1;

                            match result {
                                Ok(response_time) => {
                                    status.consecutive_successes += 1;
                                    status.consecutive_failures = 0;
                                    status.last_success = Some(Instant::now());
                                    status.last_error = None;
                                    status.total_successes += 1;

                                    // Update running average response time:
                                    // new_avg = (old_avg * (n - 1) + sample) / n
                                    let response_ms = response_time.as_millis() as f64;
                                    status.avg_response_time =
                                        (status.avg_response_time * (status.total_successes - 1) as f64
                                        + response_ms) / status.total_successes as f64;

                                    // Check if should mark as healthy
                                    if !status.healthy && status.consecutive_successes >= healthy_threshold {
                                        status.healthy = true;
                                        info!(
                                            target = %target,
                                            consecutive_successes = status.consecutive_successes,
                                            avg_response_ms = format!("{:.2}", status.avg_response_time),
                                            total_checks = status.total_checks,
                                            "Target marked as healthy"
                                        );
                                    }

                                    trace!(
                                        target = %target,
                                        response_time_ms = response_ms,
                                        check_duration_ms = check_duration.as_millis(),
                                        consecutive_successes = status.consecutive_successes,
                                        health_score = format!("{:.2}", status.health_score()),
                                        "Health check succeeded"
                                    );
                                }
                                Err(error) => {
                                    status.consecutive_failures += 1;
                                    status.consecutive_successes = 0;
                                    status.last_error = Some(error.clone());

                                    // Check if should mark as unhealthy
                                    if status.healthy && status.consecutive_failures >= unhealthy_threshold {
                                        status.healthy = false;
                                        warn!(
                                            target = %target,
                                            consecutive_failures = status.consecutive_failures,
                                            error = %error,
                                            total_checks = status.total_checks,
                                            health_score = format!("{:.2}", status.health_score()),
                                            "Target marked as unhealthy"
                                        );
                                    } else {
                                        debug!(
                                            target = %target,
                                            error = %error,
                                            consecutive_failures = status.consecutive_failures,
                                            unhealthy_threshold = unhealthy_threshold,
                                            "Health check failed"
                                        );
                                    }
                                }
                            }
                        }
                    }
                    _ = shutdown_rx.recv() => {
                        info!(target = %target, "Stopping health check task");
                        break;
                    }
                }
            }

            debug!(target = %target, "Health check task stopped");
        })
    }

    /// Stop health checking
    pub async fn stop(&self) {
        let task_count = self.check_handles.read().await.len();
        info!(
            task_count = task_count,
            "Stopping health checker"
        );

        // Send shutdown signal
        let _ = self.shutdown_tx.send(());

        // Wait for all tasks to complete
        let mut handles = self.check_handles.write().await;
        for handle in handles.drain(..) {
            let _ = handle.await;
        }

        info!("Health checker stopped successfully");
    }

    /// Get health status for a target
    pub async fn get_status(&self, target: &str) -> Option<TargetHealthInfo> {
        self.health_status.read().await.get(target).cloned()
    }

    /// Get all health statuses
    pub async fn get_all_statuses(&self) -> HashMap<String, TargetHealthInfo> {
        self.health_status.read().await.clone()
    }

    /// Check if target is healthy
    pub async fn is_healthy(&self, target: &str) -> bool {
        self.health_status
            .read()
            .await
            .get(target)
            .map(|s| s.healthy)
            .unwrap_or(false)
    }

    /// Get healthy targets
    pub async fn get_healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(target, status)| {
                if status.healthy {
                    Some(target.clone())
                } else {
                    None
                }
            })
            .collect()
    }

    /// Mark target as unhealthy (for passive health checking)
    pub async fn mark_unhealthy(&self, target: &str, reason: String) {
        if let Some(status) = self.health_status.write().await.get_mut(target) {
            if status.healthy {
                status.healthy = false;
                status.consecutive_failures = self.config.unhealthy_threshold;
                status.consecutive_successes = 0;
                status.last_error = Some(reason.clone());
                warn!(
                    target = %target,
                    reason = %reason,
                    "Target marked unhealthy by passive check"
                );
            }
        }
    }
}

impl TargetHealthInfo {
    /// Create new health status (initially healthy)
    pub fn new() -> Self {
        Self {
            healthy: true,
            consecutive_successes: 0,
            consecutive_failures: 0,
            last_check: Instant::now(),
            last_success: Some(Instant::now()),
            last_error: None,
            total_checks: 0,
            total_successes: 0,
            avg_response_time: 0.0,
        }
    }

    /// Get health score (0.0 - 1.0)
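    ///
    /// Computed as `total_successes / total_checks`; with no checks yet it
    /// optimistically reports 1.0, matching the initially-healthy default.
    ///
    /// ```rust,ignore
    /// let mut info = TargetHealthInfo::new();
    /// info.total_checks = 10;
    /// info.total_successes = 8;
    /// assert!((info.health_score() - 0.8).abs() < f64::EPSILON);
    /// ```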
    pub fn health_score(&self) -> f64 {
        if self.total_checks == 0 {
            return 1.0;
        }
        self.total_successes as f64 / self.total_checks as f64
    }

    /// Check if status is degraded (healthy but with recent failures)
    pub fn is_degraded(&self) -> bool {
        self.healthy && self.consecutive_failures > 0
    }
}

#[async_trait]
impl HealthCheckImpl for HttpHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        // Parse target address
        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        // Connect with timeout
        let mut stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        // Build HTTP request
        let host = self.host.as_deref().unwrap_or(target);
        let request = format!(
            "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Sentinel-HealthCheck/1.0\r\nConnection: close\r\n\r\n",
            self.path, host
        );

        // Send request and read the first chunk of the response; the status
        // line is all we need to judge health.
        stream
            .write_all(request.as_bytes())
            .await
            .map_err(|e| format!("Failed to send request: {}", e))?;

        let mut response = vec![0u8; 1024];
        let n = stream
            .read(&mut response)
            .await
            .map_err(|e| format!("Failed to read response: {}", e))?;

        if n == 0 {
            return Err("Empty response".to_string());
        }

        // Parse status code
        let response_str = String::from_utf8_lossy(&response[..n]);
        let status_code = parse_http_status(&response_str)
            .ok_or_else(|| "Failed to parse HTTP status".to_string())?;

        if status_code == self.expected_status {
            Ok(start.elapsed())
        } else {
            Err(format!(
                "Unexpected status code: {} (expected {})",
                status_code, self.expected_status
            ))
        }
    }

    fn check_type(&self) -> &str {
        "HTTP"
    }
}

#[async_trait]
impl HealthCheckImpl for TcpHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        // Parse target address
        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        // Connect with timeout
        time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "TCP"
    }
}

#[async_trait]
impl HealthCheckImpl for GrpcHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        // NOTE: A full gRPC health check requires the `tonic` crate for HTTP/2
        // support. This implementation uses TCP connectivity as a reasonable
        // fallback. The gRPC health checking protocol (grpc.health.v1.Health/Check)
        // would return SERVING, NOT_SERVING, or UNKNOWN for the specified service.

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        // TCP connectivity check as fallback
        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        // Verify connection is writable (basic health indicator)
        stream
            .writable()
            .await
            .map_err(|e| format!("Connection not writable: {}", e))?;

        debug!(
            target = %target,
            service = %self.service,
            "gRPC health check using TCP fallback (full gRPC protocol requires tonic)"
        );

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "gRPC"
    }
}

/// Parse HTTP status code from response
fn parse_http_status(response: &str) -> Option<u16> {
    response
        .lines()
        .next()?
        .split_whitespace()
        .nth(1)?
        .parse()
        .ok()
}

/// Passive health checker that monitors request outcomes
///
/// Observes request success/failure rates to detect unhealthy targets
/// without performing explicit health checks. Works in combination with
/// `ActiveHealthChecker` for comprehensive health monitoring.
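///
/// # Examples
///
/// A minimal sketch of wiring the two checkers together; `config` and the
/// `response_ok` flag are assumed to come from the surrounding setup and
/// request path:
///
/// ```rust,ignore
/// let active = Arc::new(ActiveHealthChecker::new(config));
/// // Trip a target after >50% failures over a 100-request window.
/// let passive = PassiveHealthChecker::new(0.5, 100, Some(Arc::clone(&active)));
///
/// // In the request path, report each upstream outcome:
/// passive.record_outcome("10.0.0.1:8080", response_ok).await;
/// ```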
pub struct PassiveHealthChecker {
    /// Failure rate threshold (0.0 - 1.0)
    failure_rate_threshold: f64,
    /// Window size for calculating failure rate
    window_size: usize,
    /// Request outcomes per target (bounded sliding window)
    outcomes: Arc<RwLock<HashMap<String, VecDeque<bool>>>>,
    /// Active health checker reference
    active_checker: Option<Arc<ActiveHealthChecker>>,
}

impl PassiveHealthChecker {
    /// Create new passive health checker
    pub fn new(
        failure_rate_threshold: f64,
        window_size: usize,
        active_checker: Option<Arc<ActiveHealthChecker>>,
    ) -> Self {
        debug!(
            failure_rate_threshold = format!("{:.2}", failure_rate_threshold),
            window_size = window_size,
            has_active_checker = active_checker.is_some(),
            "Creating passive health checker"
        );
        Self {
            failure_rate_threshold,
            window_size,
            outcomes: Arc::new(RwLock::new(HashMap::new())),
            active_checker,
        }
    }

    /// Record request outcome
    pub async fn record_outcome(&self, target: &str, success: bool) {
        trace!(
            target = %target,
            success = success,
            "Recording request outcome"
        );

        let mut outcomes = self.outcomes.write().await;
        let target_outcomes = outcomes
            .entry(target.to_string())
            .or_insert_with(|| VecDeque::with_capacity(self.window_size));

        // Add outcome to the bounded window, evicting the oldest sample
        if target_outcomes.len() >= self.window_size {
            target_outcomes.pop_front();
        }
        target_outcomes.push_back(success);

        // Calculate failure rate
        let failures = target_outcomes.iter().filter(|&&s| !s).count();
        let failure_rate = failures as f64 / target_outcomes.len() as f64;

        trace!(
            target = %target,
            failure_rate = format!("{:.2}", failure_rate),
            window_samples = target_outcomes.len(),
            failures = failures,
            "Updated failure rate"
        );

        // Mark unhealthy if failure rate exceeds threshold
        if failure_rate > self.failure_rate_threshold {
            warn!(
                target = %target,
                failure_rate = format!("{:.2}", failure_rate * 100.0),
                threshold = format!("{:.2}", self.failure_rate_threshold * 100.0),
                window_samples = target_outcomes.len(),
                "Failure rate exceeds threshold"
            );
            if let Some(ref checker) = self.active_checker {
                checker
                    .mark_unhealthy(
                        target,
                        format!(
                            "Failure rate {:.2}% exceeds threshold",
                            failure_rate * 100.0
                        ),
                    )
                    .await;
            }
        }
    }

    /// Get failure rate for a target
    pub async fn get_failure_rate(&self, target: &str) -> Option<f64> {
        let outcomes = self.outcomes.read().await;
        outcomes.get(target).map(|target_outcomes| {
            let failures = target_outcomes.iter().filter(|&&s| !s).count();
            failures as f64 / target_outcomes.len() as f64
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_health_status() {
        let status = TargetHealthInfo::new();
        assert!(status.healthy);
        assert_eq!(status.health_score(), 1.0);
        assert!(!status.is_degraded());
    }

    #[tokio::test]
    async fn test_passive_health_checker() {
        let checker = PassiveHealthChecker::new(0.5, 10, None);

        // Record 5 successes then 3 failures: failure rate = 3/8 = 0.375
        for _ in 0..5 {
            checker.record_outcome("target1", true).await;
        }
        for _ in 0..3 {
            checker.record_outcome("target1", false).await;
        }

        let failure_rate = checker.get_failure_rate("target1").await.unwrap();
        assert!(failure_rate > 0.3 && failure_rate < 0.4);
    }

    #[test]
    fn test_parse_http_status() {
        let response = "HTTP/1.1 200 OK\r\n";
        assert_eq!(parse_http_status(response), Some(200));

        let response = "HTTP/1.1 404 Not Found\r\n";
        assert_eq!(parse_http_status(response), Some(404));

        let response = "Invalid response";
        assert_eq!(parse_http_status(response), None);
    }
}