sentinel_proxy/health.rs

//! Health checking module for Sentinel proxy
//!
//! This module implements active and passive health checking for upstream servers,
//! supporting HTTP, TCP, and gRPC health checks with configurable thresholds.

use async_trait::async_trait;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::RwLock;
use tokio::time;
use tracing::{debug, info, trace, warn};

use sentinel_common::{errors::SentinelResult, types::HealthCheckType};
use sentinel_config::{HealthCheck as HealthCheckConfig, UpstreamTarget};

/// Active health checker for upstream targets
///
/// Performs periodic health checks on upstream targets using HTTP, TCP, or gRPC
/// protocols to determine their availability for load balancing.
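///
/// # Example
///
/// A minimal usage sketch (not compiled as a doctest; assumes a `HealthCheckConfig`
/// and a list of `UpstreamTarget`s built elsewhere from your configuration):
///
/// ```ignore
/// let checker = ActiveHealthChecker::new(config);
/// checker.start(&targets).await?;
/// let healthy = checker.get_healthy_targets().await;
/// // ...
/// checker.stop().await;
/// ```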
pub struct ActiveHealthChecker {
    /// Check configuration
    config: HealthCheckConfig,
    /// Health checker implementation
    checker: Arc<dyn HealthCheckImpl>,
    /// Health status per target
    health_status: Arc<RwLock<HashMap<String, TargetHealthInfo>>>,
    /// Check task handles
    check_handles: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
    /// Shutdown signal
    shutdown_tx: Arc<tokio::sync::broadcast::Sender<()>>,
}

/// Health status information for a target
#[derive(Debug, Clone)]
pub struct TargetHealthInfo {
    /// Is target healthy
    pub healthy: bool,
    /// Consecutive successes
    pub consecutive_successes: u32,
    /// Consecutive failures
    pub consecutive_failures: u32,
    /// Last check time
    pub last_check: Instant,
    /// Last successful check
    pub last_success: Option<Instant>,
    /// Last error message
    pub last_error: Option<String>,
    /// Total checks performed
    pub total_checks: u64,
    /// Total successful checks
    pub total_successes: u64,
    /// Average response time (ms)
    pub avg_response_time: f64,
}

/// Health check implementation trait
#[async_trait]
trait HealthCheckImpl: Send + Sync {
    /// Perform health check on a target
    async fn check(&self, target: &str) -> Result<Duration, String>;

    /// Get check type name
    fn check_type(&self) -> &str;
}

/// HTTP health check implementation
struct HttpHealthCheck {
    path: String,
    expected_status: u16,
    host: Option<String>,
    timeout: Duration,
}

/// TCP health check implementation
struct TcpHealthCheck {
    timeout: Duration,
}

/// gRPC health check implementation.
///
/// Currently uses a TCP connectivity check as a fallback, since the full gRPC
/// health checking protocol (grpc.health.v1.Health) requires the `tonic`
/// crate for HTTP/2 and Protocol Buffers support.
///
/// A full implementation would:
/// 1. Establish an HTTP/2 connection
/// 2. Call `grpc.health.v1.Health/Check` with the service name
/// 3. Parse the `HealthCheckResponse` for SERVING/NOT_SERVING status
///
/// See: https://github.com/grpc/grpc/blob/master/doc/health-checking.md
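///
/// A rough sketch of that flow with `tonic-health` (illustrative only; the
/// exact module paths and generated types vary between `tonic-health` versions):
///
/// ```ignore
/// use tonic_health::pb::{health_client::HealthClient, HealthCheckRequest};
///
/// async fn grpc_check(endpoint: String, service: String) -> Result<(), String> {
///     let mut client = HealthClient::connect(endpoint).await.map_err(|e| e.to_string())?;
///     let resp = client
///         .check(HealthCheckRequest { service })
///         .await
///         .map_err(|e| e.to_string())?;
///     // grpc.health.v1: SERVING == 1
///     match resp.into_inner().status {
///         1 => Ok(()),
///         other => Err(format!("service not serving (status {})", other)),
///     }
/// }
/// ```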
struct GrpcHealthCheck {
    service: String,
    timeout: Duration,
}

impl ActiveHealthChecker {
    /// Create new active health checker
    pub fn new(config: HealthCheckConfig) -> Self {
        debug!(
            check_type = ?config.check_type,
            interval_secs = config.interval_secs,
            timeout_secs = config.timeout_secs,
            healthy_threshold = config.healthy_threshold,
            unhealthy_threshold = config.unhealthy_threshold,
            "Creating active health checker"
        );

        let checker: Arc<dyn HealthCheckImpl> = match &config.check_type {
            HealthCheckType::Http {
                path,
                expected_status,
                host,
            } => {
                trace!(
                    path = %path,
                    expected_status = expected_status,
                    host = host.as_deref().unwrap_or("(default)"),
                    "Configuring HTTP health check"
                );
                Arc::new(HttpHealthCheck {
                    path: path.clone(),
                    expected_status: *expected_status,
                    host: host.clone(),
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
            HealthCheckType::Tcp => {
                trace!("Configuring TCP health check");
                Arc::new(TcpHealthCheck {
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
            HealthCheckType::Grpc { service } => {
                trace!(
                    service = %service,
                    "Configuring gRPC health check"
                );
                Arc::new(GrpcHealthCheck {
                    service: service.clone(),
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
        };

        let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);

        Self {
            config,
            checker,
            health_status: Arc::new(RwLock::new(HashMap::new())),
            check_handles: Arc::new(RwLock::new(Vec::new())),
            shutdown_tx: Arc::new(shutdown_tx),
        }
    }

    /// Start health checking for targets
    pub async fn start(&self, targets: &[UpstreamTarget]) -> SentinelResult<()> {
        info!(
            target_count = targets.len(),
            interval_secs = self.config.interval_secs,
            check_type = self.checker.check_type(),
            "Starting health checking"
        );

        let mut handles = self.check_handles.write().await;

        for target in targets {
            let address = target.address.clone();

            trace!(
                target = %address,
                "Initializing health status for target"
            );

            // Initialize health status
            self.health_status
                .write()
                .await
                .insert(address.clone(), TargetHealthInfo::new());

            // Spawn health check task
            debug!(
                target = %address,
                "Spawning health check task"
            );
            let handle = self.spawn_check_task(address);
            handles.push(handle);
        }

        info!(
            target_count = targets.len(),
            interval_secs = self.config.interval_secs,
            healthy_threshold = self.config.healthy_threshold,
            unhealthy_threshold = self.config.unhealthy_threshold,
            "Health checking started successfully"
        );

        Ok(())
    }

    /// Spawn health check task for a target
    fn spawn_check_task(&self, target: String) -> tokio::task::JoinHandle<()> {
        let interval = Duration::from_secs(self.config.interval_secs);
        let checker = Arc::clone(&self.checker);
        let health_status = Arc::clone(&self.health_status);
        let healthy_threshold = self.config.healthy_threshold;
        let unhealthy_threshold = self.config.unhealthy_threshold;
        let check_type = self.checker.check_type().to_string();
        let mut shutdown_rx = self.shutdown_tx.subscribe();

        tokio::spawn(async move {
            let mut interval_timer = time::interval(interval);
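            // Skip missed ticks if a check runs longer than the interval,
            // instead of firing a burst of catch-up checks.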
            interval_timer.set_missed_tick_behavior(time::MissedTickBehavior::Skip);

            trace!(
                target = %target,
                check_type = %check_type,
                interval_ms = interval.as_millis(),
                "Health check task started"
            );

            loop {
                tokio::select! {
                    _ = interval_timer.tick() => {
                        // Perform health check
                        trace!(
                            target = %target,
                            check_type = %check_type,
                            "Performing health check"
                        );
                        let start = Instant::now();
                        let result = checker.check(&target).await;
                        let check_duration = start.elapsed();

                        // Update health status
                        let mut status_map = health_status.write().await;
                        if let Some(status) = status_map.get_mut(&target) {
                            status.last_check = Instant::now();
                            status.total_checks += 1;

                            match result {
                                Ok(response_time) => {
                                    status.consecutive_successes += 1;
                                    status.consecutive_failures = 0;
                                    status.last_success = Some(Instant::now());
                                    status.last_error = None;
                                    status.total_successes += 1;

                                    // Update average response time
                                    let response_ms = response_time.as_millis() as f64;
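                                    // Incremental mean (the new sample is already counted in
                                    // total_successes): e.g. an average of 10 ms over 4 successes
                                    // plus a 20 ms sample gives (10*4 + 20) / 5 = 12 ms.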
                                    status.avg_response_time =
                                        (status.avg_response_time * (status.total_successes - 1) as f64
                                        + response_ms) / status.total_successes as f64;

                                    // Check if should mark as healthy
                                    if !status.healthy && status.consecutive_successes >= healthy_threshold {
                                        status.healthy = true;
                                        info!(
                                            target = %target,
                                            consecutive_successes = status.consecutive_successes,
                                            avg_response_ms = format!("{:.2}", status.avg_response_time),
                                            total_checks = status.total_checks,
                                            "Target marked as healthy"
                                        );
                                    }

                                    trace!(
                                        target = %target,
                                        response_time_ms = response_ms,
                                        check_duration_ms = check_duration.as_millis(),
                                        consecutive_successes = status.consecutive_successes,
                                        health_score = format!("{:.2}", status.health_score()),
                                        "Health check succeeded"
                                    );
                                }
                                Err(error) => {
                                    status.consecutive_failures += 1;
                                    status.consecutive_successes = 0;
                                    status.last_error = Some(error.clone());

                                    // Check if should mark as unhealthy
                                    if status.healthy && status.consecutive_failures >= unhealthy_threshold {
                                        status.healthy = false;
                                        warn!(
                                            target = %target,
                                            consecutive_failures = status.consecutive_failures,
                                            error = %error,
                                            total_checks = status.total_checks,
                                            health_score = format!("{:.2}", status.health_score()),
                                            "Target marked as unhealthy"
                                        );
                                    } else {
                                        debug!(
                                            target = %target,
                                            error = %error,
                                            consecutive_failures = status.consecutive_failures,
                                            unhealthy_threshold = unhealthy_threshold,
                                            "Health check failed"
                                        );
                                    }
                                }
                            }
                        }
                    }
                    _ = shutdown_rx.recv() => {
                        info!(target = %target, "Stopping health check task");
                        break;
                    }
                }
            }

            debug!(target = %target, "Health check task stopped");
        })
    }

    /// Stop health checking
    pub async fn stop(&self) {
        let task_count = self.check_handles.read().await.len();
        info!(task_count = task_count, "Stopping health checker");

        // Send shutdown signal
        let _ = self.shutdown_tx.send(());

        // Wait for all tasks to complete
        let mut handles = self.check_handles.write().await;
        for handle in handles.drain(..) {
            let _ = handle.await;
        }

        info!("Health checker stopped successfully");
    }

    /// Get health status for a target
    pub async fn get_status(&self, target: &str) -> Option<TargetHealthInfo> {
        self.health_status.read().await.get(target).cloned()
    }

    /// Get all health statuses
    pub async fn get_all_statuses(&self) -> HashMap<String, TargetHealthInfo> {
        self.health_status.read().await.clone()
    }

    /// Check if target is healthy
    pub async fn is_healthy(&self, target: &str) -> bool {
        self.health_status
            .read()
            .await
            .get(target)
            .map(|s| s.healthy)
            .unwrap_or(false)
    }

    /// Get healthy targets
    pub async fn get_healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(target, status)| {
                if status.healthy {
                    Some(target.clone())
                } else {
                    None
                }
            })
            .collect()
    }

    /// Mark target as unhealthy (for passive health checking)
    pub async fn mark_unhealthy(&self, target: &str, reason: String) {
        if let Some(status) = self.health_status.write().await.get_mut(target) {
            if status.healthy {
                status.healthy = false;
                status.consecutive_failures = self.config.unhealthy_threshold;
                status.consecutive_successes = 0;
                status.last_error = Some(reason.clone());
                warn!(
                    target = %target,
                    reason = %reason,
                    "Target marked unhealthy by passive check"
                );
            }
        }
    }
}

impl Default for TargetHealthInfo {
    fn default() -> Self {
        Self::new()
    }
}

impl TargetHealthInfo {
    /// Create new health status (initially healthy)
    pub fn new() -> Self {
        Self {
            healthy: true,
            consecutive_successes: 0,
            consecutive_failures: 0,
            last_check: Instant::now(),
            last_success: Some(Instant::now()),
            last_error: None,
            total_checks: 0,
            total_successes: 0,
            avg_response_time: 0.0,
        }
    }

    /// Get health score (0.0 - 1.0)
    pub fn health_score(&self) -> f64 {
        if self.total_checks == 0 {
            return 1.0;
        }
        self.total_successes as f64 / self.total_checks as f64
    }

    /// Check if status is degraded (healthy but with recent failures)
    pub fn is_degraded(&self) -> bool {
        self.healthy && self.consecutive_failures > 0
    }
}

#[async_trait]
impl HealthCheckImpl for HttpHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        // Parse target address
        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        // Connect with timeout
        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;
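        // Note: only the connect above is bounded by `self.timeout`; the request/response
        // exchange below relies on the peer responding and closing the connection.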

        // Build HTTP request
        let host = self.host.as_deref().unwrap_or(target);
        let request = format!(
            "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Sentinel-HealthCheck/1.0\r\nConnection: close\r\n\r\n",
            self.path,
            host
        );

        // Send request and read response
        let mut stream = stream;
        stream
            .write_all(request.as_bytes())
            .await
            .map_err(|e| format!("Failed to send request: {}", e))?;

        let mut response = vec![0u8; 1024];
        let n = stream
            .read(&mut response)
            .await
            .map_err(|e| format!("Failed to read response: {}", e))?;

        if n == 0 {
            return Err("Empty response".to_string());
        }

        // Parse status code
        let response_str = String::from_utf8_lossy(&response[..n]);
        let status_code = parse_http_status(&response_str)
            .ok_or_else(|| "Failed to parse HTTP status".to_string())?;

        if status_code == self.expected_status {
            Ok(start.elapsed())
        } else {
            Err(format!(
                "Unexpected status code: {} (expected {})",
                status_code, self.expected_status
            ))
        }
    }

    fn check_type(&self) -> &str {
        "HTTP"
    }
}

#[async_trait]
impl HealthCheckImpl for TcpHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        // Parse target address
        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        // Connect with timeout
        time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "TCP"
    }
}

#[async_trait]
impl HealthCheckImpl for GrpcHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        // NOTE: Full gRPC health check requires `tonic` crate for HTTP/2 support.
        // This implementation uses TCP connectivity as a reasonable fallback.
        // The gRPC health checking protocol (grpc.health.v1.Health/Check) would
        // return SERVING, NOT_SERVING, or UNKNOWN for the specified service.

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        // TCP connectivity check as fallback
        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        // Verify connection is writable (basic health indicator)
        stream
            .writable()
            .await
            .map_err(|e| format!("Connection not writable: {}", e))?;

        debug!(
            target = %target,
            service = %self.service,
            "gRPC health check using TCP fallback (full gRPC protocol requires tonic)"
        );

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "gRPC"
    }
}

/// Parse HTTP status code from response
fn parse_http_status(response: &str) -> Option<u16> {
    response
        .lines()
        .next()?
        .split_whitespace()
        .nth(1)?
        .parse()
        .ok()
}

/// Passive health checker that monitors request outcomes
///
/// Observes request success/failure rates to detect unhealthy targets
/// without performing explicit health checks. Works in combination with
/// `ActiveHealthChecker` for comprehensive health monitoring.
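///
/// # Example
///
/// A minimal sketch (not compiled as a doctest; the target address is hypothetical):
///
/// ```ignore
/// // 50% failure threshold over a 20-request window, no active checker attached.
/// let passive = PassiveHealthChecker::new(0.5, 20, None);
/// passive.record_outcome("10.0.0.1:8080", false, Some("connect refused")).await;
/// let rate = passive.get_failure_rate("10.0.0.1:8080").await; // Some(1.0)
/// ```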
pub struct PassiveHealthChecker {
    /// Failure rate threshold (0.0 - 1.0)
    failure_rate_threshold: f64,
    /// Window size for calculating failure rate
    window_size: usize,
    /// Request outcomes per target (sliding window of the last `window_size` results)
    outcomes: Arc<RwLock<HashMap<String, Vec<bool>>>>,
    /// Last error per target
    last_errors: Arc<RwLock<HashMap<String, String>>>,
    /// Active health checker reference
    active_checker: Option<Arc<ActiveHealthChecker>>,
}

impl PassiveHealthChecker {
    /// Create new passive health checker
    pub fn new(
        failure_rate_threshold: f64,
        window_size: usize,
        active_checker: Option<Arc<ActiveHealthChecker>>,
    ) -> Self {
        debug!(
            failure_rate_threshold = format!("{:.2}", failure_rate_threshold),
            window_size = window_size,
            has_active_checker = active_checker.is_some(),
            "Creating passive health checker"
        );
        Self {
            failure_rate_threshold,
            window_size,
            outcomes: Arc::new(RwLock::new(HashMap::new())),
            last_errors: Arc::new(RwLock::new(HashMap::new())),
            active_checker,
        }
    }

    /// Record request outcome with optional error message
    pub async fn record_outcome(&self, target: &str, success: bool, error: Option<&str>) {
        trace!(
            target = %target,
            success = success,
            error = ?error,
            "Recording request outcome"
        );

        // Track last error
        if let Some(err_msg) = error {
            self.last_errors
                .write()
                .await
                .insert(target.to_string(), err_msg.to_string());
        } else if success {
            // Clear last error on success
            self.last_errors.write().await.remove(target);
        }

        let mut outcomes = self.outcomes.write().await;
        let target_outcomes = outcomes
            .entry(target.to_string())
            .or_insert_with(|| Vec::with_capacity(self.window_size));

        // Sliding window: drop the oldest outcome once the window is full
        // (Vec::remove(0) shifts O(window_size) elements, fine for small windows)
        if target_outcomes.len() >= self.window_size {
            target_outcomes.remove(0);
        }
        target_outcomes.push(success);

        // Calculate failure rate
        let failures = target_outcomes.iter().filter(|&&s| !s).count();
        let failure_rate = failures as f64 / target_outcomes.len() as f64;

        trace!(
            target = %target,
            failure_rate = format!("{:.2}", failure_rate),
            window_samples = target_outcomes.len(),
            failures = failures,
            "Updated failure rate"
        );

        // Mark unhealthy if failure rate exceeds threshold
        if failure_rate > self.failure_rate_threshold {
            warn!(
                target = %target,
                failure_rate = format!("{:.2}", failure_rate * 100.0),
                threshold = format!("{:.2}", self.failure_rate_threshold * 100.0),
                window_samples = target_outcomes.len(),
                "Failure rate exceeds threshold"
            );
            if let Some(ref checker) = self.active_checker {
                checker
                    .mark_unhealthy(
                        target,
                        format!(
                            "Failure rate {:.2}% exceeds threshold",
                            failure_rate * 100.0
                        ),
                    )
                    .await;
            }
        }
    }

    /// Get failure rate for a target
    pub async fn get_failure_rate(&self, target: &str) -> Option<f64> {
        let outcomes = self.outcomes.read().await;
        outcomes.get(target).map(|target_outcomes| {
            let failures = target_outcomes.iter().filter(|&&s| !s).count();
            failures as f64 / target_outcomes.len() as f64
        })
    }

    /// Get last error for a target
    pub async fn get_last_error(&self, target: &str) -> Option<String> {
        self.last_errors.read().await.get(target).cloned()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_health_status() {
        let status = TargetHealthInfo::new();
        assert!(status.healthy);
        assert_eq!(status.health_score(), 1.0);
        assert!(!status.is_degraded());
    }

    #[tokio::test]
    async fn test_passive_health_checker() {
        let checker = PassiveHealthChecker::new(0.5, 10, None);

        // Record some outcomes
        for _ in 0..5 {
            checker.record_outcome("target1", true, None).await;
        }
        for _ in 0..3 {
            checker.record_outcome("target1", false, Some("HTTP 503")).await;
        }

        let failure_rate = checker.get_failure_rate("target1").await.unwrap();
        assert!(failure_rate > 0.3 && failure_rate < 0.4);
    }

    #[test]
    fn test_parse_http_status() {
        let response = "HTTP/1.1 200 OK\r\n";
        assert_eq!(parse_http_status(response), Some(200));

        let response = "HTTP/1.1 404 Not Found\r\n";
        assert_eq!(parse_http_status(response), Some(404));

        let response = "Invalid response";
        assert_eq!(parse_http_status(response), None);
    }
723}