sentinel_proxy/health.rs

//! Health checking module for Sentinel proxy
//!
//! This module implements active and passive health checking for upstream servers,
//! supporting HTTP, TCP, and gRPC health checks with configurable thresholds.

use async_trait::async_trait;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::RwLock;
use tokio::time;
use tracing::{debug, info, warn};

use sentinel_common::{
    errors::SentinelResult,
    types::HealthCheckType,
};
use sentinel_config::{HealthCheck as HealthCheckConfig, UpstreamTarget};

/// Active health checker for upstream targets
///
/// Performs periodic health checks on upstream targets using HTTP, TCP, or gRPC
/// protocols to determine their availability for load balancing.
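///
/// # Example
///
/// A minimal usage sketch (not compiled; `health_config` and `targets` stand in
/// for values built elsewhere from the proxy configuration):
///
/// ```ignore
/// // Build the checker from an existing health-check config section.
/// let checker = ActiveHealthChecker::new(health_config);
///
/// // Start periodic checks for the upstream targets.
/// checker.start(&targets).await?;
///
/// // Consult health state when picking an upstream.
/// if checker.is_healthy("10.0.0.1:8080").await {
///     // route traffic to this target
/// }
///
/// // On shutdown, stop all background check tasks.
/// checker.stop().await;
/// ```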
pub struct ActiveHealthChecker {
    /// Check configuration
    config: HealthCheckConfig,
    /// Health checker implementation
    checker: Arc<dyn HealthCheckImpl>,
    /// Health status per target
    health_status: Arc<RwLock<HashMap<String, TargetHealthInfo>>>,
    /// Check task handles
    check_handles: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
    /// Shutdown signal
    shutdown_tx: Arc<tokio::sync::broadcast::Sender<()>>,
}

/// Health status information for a target
#[derive(Debug, Clone)]
pub struct TargetHealthInfo {
    /// Is target healthy
    pub healthy: bool,
    /// Consecutive successes
    pub consecutive_successes: u32,
    /// Consecutive failures
    pub consecutive_failures: u32,
    /// Last check time
    pub last_check: Instant,
    /// Last successful check
    pub last_success: Option<Instant>,
    /// Last error message
    pub last_error: Option<String>,
    /// Total checks performed
    pub total_checks: u64,
    /// Total successful checks
    pub total_successes: u64,
    /// Average response time (ms)
    pub avg_response_time: f64,
}

/// Health check implementation trait
#[async_trait]
trait HealthCheckImpl: Send + Sync {
    /// Perform health check on a target
    async fn check(&self, target: &str) -> Result<Duration, String>;

    /// Get check type name
    fn check_type(&self) -> &str;
}

/// HTTP health check implementation
struct HttpHealthCheck {
    path: String,
    expected_status: u16,
    host: Option<String>,
    timeout: Duration,
}

/// TCP health check implementation
struct TcpHealthCheck {
    timeout: Duration,
}
/// gRPC health check implementation.
///
/// Currently uses a TCP connectivity check as a fallback, since the full gRPC
/// health checking protocol (grpc.health.v1.Health) requires the `tonic`
/// crate for HTTP/2 and Protocol Buffers support.
///
/// A full implementation would:
/// 1. Establish an HTTP/2 connection
/// 2. Call `grpc.health.v1.Health/Check` with the service name
/// 3. Parse the `HealthCheckResponse` for SERVING/NOT_SERVING status
///
/// See: https://github.com/grpc/grpc/blob/master/doc/health-checking.md
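///
/// # Example: full gRPC health check (sketch)
///
/// Illustrative only and not compiled here: it assumes the `tonic` and
/// `tonic-health` crates are added as dependencies, and the generated-code
/// module path (`tonic_health::pb` below) differs between versions.
///
/// ```ignore
/// use tonic_health::pb::health_check_response::ServingStatus;
/// use tonic_health::pb::health_client::HealthClient;
/// use tonic_health::pb::HealthCheckRequest;
///
/// async fn grpc_check(target: &str, service: &str) -> Result<(), String> {
///     // HTTP/2 connection to the target.
///     let mut client = HealthClient::connect(format!("http://{}", target))
///         .await
///         .map_err(|e| format!("Connection failed: {}", e))?;
///
///     // Call grpc.health.v1.Health/Check for the configured service.
///     let response = client
///         .check(HealthCheckRequest { service: service.to_string() })
///         .await
///         .map_err(|e| format!("Health RPC failed: {}", e))?;
///
///     // SERVING is healthy; anything else is treated as a failure.
///     match response.into_inner().status() {
///         ServingStatus::Serving => Ok(()),
///         other => Err(format!("Service not serving: {:?}", other)),
///     }
/// }
/// ```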
struct GrpcHealthCheck {
    service: String,
    timeout: Duration,
}

impl ActiveHealthChecker {
    /// Create new active health checker
    pub fn new(config: HealthCheckConfig) -> Self {
        let checker: Arc<dyn HealthCheckImpl> = match &config.check_type {
            HealthCheckType::Http {
                path,
                expected_status,
                host,
            } => Arc::new(HttpHealthCheck {
                path: path.clone(),
                expected_status: *expected_status,
                host: host.clone(),
                timeout: Duration::from_secs(config.timeout_secs),
            }),
            HealthCheckType::Tcp => Arc::new(TcpHealthCheck {
                timeout: Duration::from_secs(config.timeout_secs),
            }),
            HealthCheckType::Grpc { service } => Arc::new(GrpcHealthCheck {
                service: service.clone(),
                timeout: Duration::from_secs(config.timeout_secs),
            }),
        };

        let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);

        Self {
            config,
            checker,
            health_status: Arc::new(RwLock::new(HashMap::new())),
            check_handles: Arc::new(RwLock::new(Vec::new())),
            shutdown_tx: Arc::new(shutdown_tx),
        }
    }

    /// Start health checking for targets
    pub async fn start(&self, targets: &[UpstreamTarget]) -> SentinelResult<()> {
        let mut handles = self.check_handles.write().await;

        for target in targets {
            let address = target.address.clone();

            // Initialize health status
            self.health_status
                .write()
                .await
                .insert(address.clone(), TargetHealthInfo::new());

            // Spawn health check task
            let handle = self.spawn_check_task(address);
            handles.push(handle);
        }

        info!(
            "Started health checking for {} targets, interval: {}s",
            targets.len(),
            self.config.interval_secs
        );

        Ok(())
    }

    /// Spawn health check task for a target
    fn spawn_check_task(&self, target: String) -> tokio::task::JoinHandle<()> {
        let interval = Duration::from_secs(self.config.interval_secs);
        let checker = Arc::clone(&self.checker);
        let health_status = Arc::clone(&self.health_status);
        let healthy_threshold = self.config.healthy_threshold;
        let unhealthy_threshold = self.config.unhealthy_threshold;
        let mut shutdown_rx = self.shutdown_tx.subscribe();

        tokio::spawn(async move {
            let mut interval_timer = time::interval(interval);
            interval_timer.set_missed_tick_behavior(time::MissedTickBehavior::Skip);

            loop {
                tokio::select! {
                    _ = interval_timer.tick() => {
                        // Perform health check
                        let result = checker.check(&target).await;

                        // Update health status
                        let mut status_map = health_status.write().await;
                        if let Some(status) = status_map.get_mut(&target) {
                            status.last_check = Instant::now();
                            status.total_checks += 1;

                            match result {
                                Ok(response_time) => {
                                    status.consecutive_successes += 1;
                                    status.consecutive_failures = 0;
                                    status.last_success = Some(Instant::now());
                                    status.last_error = None;
                                    status.total_successes += 1;

                                    // Update average response time
                                    let response_ms = response_time.as_millis() as f64;
                                    status.avg_response_time =
                                        (status.avg_response_time * (status.total_successes - 1) as f64
                                        + response_ms) / status.total_successes as f64;

                                    // Check if should mark as healthy
                                    if !status.healthy && status.consecutive_successes >= healthy_threshold {
                                        status.healthy = true;
                                        info!(
                                            target = %target,
                                            consecutive_successes = status.consecutive_successes,
                                            "Target marked as healthy"
                                        );
                                    }

                                    debug!(
                                        target = %target,
                                        response_time_ms = response_ms,
                                        "Health check succeeded"
                                    );
                                }
                                Err(error) => {
                                    status.consecutive_failures += 1;
                                    status.consecutive_successes = 0;
                                    status.last_error = Some(error.clone());

                                    // Check if should mark as unhealthy
                                    if status.healthy && status.consecutive_failures >= unhealthy_threshold {
                                        status.healthy = false;
                                        warn!(
                                            target = %target,
                                            consecutive_failures = status.consecutive_failures,
                                            error = %error,
                                            "Target marked as unhealthy"
                                        );
                                    }

                                    debug!(
                                        target = %target,
                                        error = %error,
                                        "Health check failed"
                                    );
                                }
                            }
                        }
                    }
                    _ = shutdown_rx.recv() => {
                        info!(target = %target, "Stopping health check task");
                        break;
                    }
                }
            }
        })
    }

    /// Stop health checking
    pub async fn stop(&self) {
        info!("Stopping health checker");

        // Send shutdown signal
        let _ = self.shutdown_tx.send(());

        // Wait for all tasks to complete
        let mut handles = self.check_handles.write().await;
        for handle in handles.drain(..) {
            let _ = handle.await;
        }
    }

    /// Get health status for a target
    pub async fn get_status(&self, target: &str) -> Option<TargetHealthInfo> {
        self.health_status.read().await.get(target).cloned()
    }

    /// Get all health statuses
    pub async fn get_all_statuses(&self) -> HashMap<String, TargetHealthInfo> {
        self.health_status.read().await.clone()
    }

    /// Check if target is healthy
    pub async fn is_healthy(&self, target: &str) -> bool {
        self.health_status
            .read()
            .await
            .get(target)
            .map(|s| s.healthy)
            .unwrap_or(false)
    }

    /// Get healthy targets
    pub async fn get_healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(target, status)| {
                if status.healthy {
                    Some(target.clone())
                } else {
                    None
                }
            })
            .collect()
    }

    /// Mark target as unhealthy (for passive health checking)
    pub async fn mark_unhealthy(&self, target: &str, reason: String) {
        if let Some(status) = self.health_status.write().await.get_mut(target) {
            if status.healthy {
                status.healthy = false;
                status.consecutive_failures = self.config.unhealthy_threshold;
                status.consecutive_successes = 0;
                status.last_error = Some(reason.clone());
                warn!(
                    target = %target,
                    reason = %reason,
                    "Target marked unhealthy by passive check"
                );
            }
        }
    }
}

impl TargetHealthInfo {
    /// Create new health status (initially healthy)
    pub fn new() -> Self {
        Self {
            healthy: true,
            consecutive_successes: 0,
            consecutive_failures: 0,
            last_check: Instant::now(),
            last_success: Some(Instant::now()),
            last_error: None,
            total_checks: 0,
            total_successes: 0,
            avg_response_time: 0.0,
        }
    }

    /// Get health score (0.0 - 1.0)
    pub fn health_score(&self) -> f64 {
        if self.total_checks == 0 {
            return 1.0;
        }
        self.total_successes as f64 / self.total_checks as f64
    }

    /// Check if status is degraded (healthy but with recent failures)
    pub fn is_degraded(&self) -> bool {
        self.healthy && self.consecutive_failures > 0
    }
}

#[async_trait]
impl HealthCheckImpl for HttpHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        // Parse target address
        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        // Connect with timeout
        let mut stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        // Build HTTP request
        let host = self.host.as_deref().unwrap_or(target);
        let request = format!(
            "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Sentinel-HealthCheck/1.0\r\nConnection: close\r\n\r\n",
            self.path,
            host
        );

        // Send request and read response
        stream
            .write_all(request.as_bytes())
            .await
            .map_err(|e| format!("Failed to send request: {}", e))?;

        let mut response = vec![0u8; 1024];
        let n = stream
            .read(&mut response)
            .await
            .map_err(|e| format!("Failed to read response: {}", e))?;

        if n == 0 {
            return Err("Empty response".to_string());
        }

        // Parse status code
        let response_str = String::from_utf8_lossy(&response[..n]);
        let status_code = parse_http_status(&response_str)
            .ok_or_else(|| "Failed to parse HTTP status".to_string())?;

        if status_code == self.expected_status {
            Ok(start.elapsed())
        } else {
            Err(format!(
                "Unexpected status code: {} (expected {})",
                status_code, self.expected_status
            ))
        }
    }

    fn check_type(&self) -> &str {
        "HTTP"
    }
}

#[async_trait]
impl HealthCheckImpl for TcpHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        // Parse target address
        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        // Connect with timeout
        time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "TCP"
    }
}

#[async_trait]
impl HealthCheckImpl for GrpcHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        // NOTE: A full gRPC health check requires the `tonic` crate for HTTP/2 support.
        // This implementation uses TCP connectivity as a reasonable fallback.
        // The gRPC health checking protocol (grpc.health.v1.Health/Check) would
        // return SERVING, NOT_SERVING, or UNKNOWN for the specified service.

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        // TCP connectivity check as fallback
        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        // Verify connection is writable (basic health indicator)
        stream
            .writable()
            .await
            .map_err(|e| format!("Connection not writable: {}", e))?;

        debug!(
            target = %target,
            service = %self.service,
            "gRPC health check using TCP fallback (full gRPC protocol requires tonic)"
        );

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "gRPC"
    }
}

/// Parse HTTP status code from response
fn parse_http_status(response: &str) -> Option<u16> {
    response
        .lines()
        .next()?
        .split_whitespace()
        .nth(1)?
        .parse()
        .ok()
}

/// Passive health checker that monitors request outcomes
///
/// Observes request success/failure rates to detect unhealthy targets
/// without performing explicit health checks. Works in combination with
/// `ActiveHealthChecker` for comprehensive health monitoring.
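///
/// # Example
///
/// A minimal sketch of the intended flow (not compiled; the target address is
/// illustrative and `active_checker` would wrap a running `ActiveHealthChecker`):
///
/// ```ignore
/// // Mark a target unhealthy once it fails more than 50% of the last 100 requests.
/// let passive = PassiveHealthChecker::new(0.5, 100, Some(active_checker));
///
/// // Record the outcome of each proxied request.
/// passive.record_outcome("10.0.0.1:8080", response_ok).await;
///
/// // Inspect the current failure rate for a target.
/// if let Some(rate) = passive.get_failure_rate("10.0.0.1:8080").await {
///     println!("failure rate: {:.2}%", rate * 100.0);
/// }
/// ```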
pub struct PassiveHealthChecker {
    /// Failure rate threshold (0.0 - 1.0)
    failure_rate_threshold: f64,
    /// Window size for calculating failure rate
    window_size: usize,
    /// Request outcomes per target (ring buffer)
    outcomes: Arc<RwLock<HashMap<String, Vec<bool>>>>,
    /// Active health checker reference
    active_checker: Option<Arc<ActiveHealthChecker>>,
}

impl PassiveHealthChecker {
    /// Create new passive health checker
    pub fn new(
        failure_rate_threshold: f64,
        window_size: usize,
        active_checker: Option<Arc<ActiveHealthChecker>>,
    ) -> Self {
        Self {
            failure_rate_threshold,
            window_size,
            outcomes: Arc::new(RwLock::new(HashMap::new())),
            active_checker,
        }
    }

    /// Record request outcome
    pub async fn record_outcome(&self, target: &str, success: bool) {
        let mut outcomes = self.outcomes.write().await;
        let target_outcomes = outcomes
            .entry(target.to_string())
            .or_insert_with(|| Vec::with_capacity(self.window_size));

        // Add outcome to ring buffer
        if target_outcomes.len() >= self.window_size {
            target_outcomes.remove(0);
        }
        target_outcomes.push(success);

        // Calculate failure rate
        let failures = target_outcomes.iter().filter(|&&s| !s).count();
        let failure_rate = failures as f64 / target_outcomes.len() as f64;

        // Mark unhealthy if failure rate exceeds threshold
        if failure_rate > self.failure_rate_threshold {
            if let Some(ref checker) = self.active_checker {
                checker
                    .mark_unhealthy(
                        target,
                        format!(
                            "Failure rate {:.2}% exceeds threshold",
                            failure_rate * 100.0
                        ),
                    )
                    .await;
            }
        }
    }

    /// Get failure rate for a target
    pub async fn get_failure_rate(&self, target: &str) -> Option<f64> {
        let outcomes = self.outcomes.read().await;
        outcomes.get(target).map(|target_outcomes| {
            let failures = target_outcomes.iter().filter(|&&s| !s).count();
            failures as f64 / target_outcomes.len() as f64
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_health_status() {
        let status = TargetHealthInfo::new();
        assert!(status.healthy);
        assert_eq!(status.health_score(), 1.0);
        assert!(!status.is_degraded());
    }

    #[tokio::test]
    async fn test_passive_health_checker() {
        let checker = PassiveHealthChecker::new(0.5, 10, None);

        // Record some outcomes
        for _ in 0..5 {
            checker.record_outcome("target1", true).await;
        }
        for _ in 0..3 {
            checker.record_outcome("target1", false).await;
        }

        let failure_rate = checker.get_failure_rate("target1").await.unwrap();
        assert!(failure_rate > 0.3 && failure_rate < 0.4);
    }

    #[test]
    fn test_parse_http_status() {
        let response = "HTTP/1.1 200 OK\r\n";
        assert_eq!(parse_http_status(response), Some(200));

        let response = "HTTP/1.1 404 Not Found\r\n";
        assert_eq!(parse_http_status(response), Some(404));

        let response = "Invalid response";
        assert_eq!(parse_http_status(response), None);
    }
}