use async_trait::async_trait;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::RwLock;
use tokio::time;
use tracing::{debug, info, trace, warn};

use sentinel_common::{
    errors::SentinelResult,
    types::HealthCheckType,
};
use sentinel_config::{HealthCheck as HealthCheckConfig, UpstreamTarget};

/// Actively probes upstream targets at a fixed interval and tracks per-target health state.
pub struct ActiveHealthChecker {
    /// Health check configuration (type, interval, timeout, thresholds).
    config: HealthCheckConfig,
    /// Protocol-specific check implementation (HTTP, TCP, or gRPC).
    checker: Arc<dyn HealthCheckImpl>,
    /// Current health state keyed by target address.
    health_status: Arc<RwLock<HashMap<String, TargetHealthInfo>>>,
    /// Join handles for the spawned per-target check tasks.
    check_handles: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
    /// Broadcast channel used to signal check tasks to shut down.
    shutdown_tx: Arc<tokio::sync::broadcast::Sender<()>>,
}

/// Point-in-time health information for a single upstream target.
#[derive(Debug, Clone)]
pub struct TargetHealthInfo {
    /// Whether the target is currently considered healthy.
    pub healthy: bool,
    /// Number of consecutive successful checks.
    pub consecutive_successes: u32,
    /// Number of consecutive failed checks.
    pub consecutive_failures: u32,
    /// When the target was last checked.
    pub last_check: Instant,
    /// When the target last passed a check, if ever.
    pub last_success: Option<Instant>,
    /// The most recent check error, if any.
    pub last_error: Option<String>,
    /// Total number of checks performed.
    pub total_checks: u64,
    /// Total number of successful checks.
    pub total_successes: u64,
    /// Running average response time of successful checks, in milliseconds.
    pub avg_response_time: f64,
}

/// Protocol-specific health check implementation.
#[async_trait]
trait HealthCheckImpl: Send + Sync {
    /// Performs a single check against `target`, returning the response time
    /// on success or an error description on failure.
    async fn check(&self, target: &str) -> Result<Duration, String>;

    /// Human-readable name of the check protocol (e.g. "HTTP").
    fn check_type(&self) -> &str;
}

/// HTTP health check: sends a GET request and compares the status code.
struct HttpHealthCheck {
    path: String,
    expected_status: u16,
    host: Option<String>,
    timeout: Duration,
}

/// TCP health check: a target is healthy if a connection can be established.
struct TcpHealthCheck {
    timeout: Duration,
}

/// gRPC health check (currently a TCP-level connectivity check; see `check`).
struct GrpcHealthCheck {
    service: String,
    timeout: Duration,
}

impl ActiveHealthChecker {
    /// Creates a new active health checker from the given configuration.
    pub fn new(config: HealthCheckConfig) -> Self {
        debug!(
            check_type = ?config.check_type,
            interval_secs = config.interval_secs,
            timeout_secs = config.timeout_secs,
            healthy_threshold = config.healthy_threshold,
            unhealthy_threshold = config.unhealthy_threshold,
            "Creating active health checker"
        );

        let checker: Arc<dyn HealthCheckImpl> = match &config.check_type {
            HealthCheckType::Http {
                path,
                expected_status,
                host,
            } => {
                trace!(
                    path = %path,
                    expected_status = expected_status,
                    host = host.as_deref().unwrap_or("(default)"),
                    "Configuring HTTP health check"
                );
                Arc::new(HttpHealthCheck {
                    path: path.clone(),
                    expected_status: *expected_status,
                    host: host.clone(),
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
            HealthCheckType::Tcp => {
                trace!("Configuring TCP health check");
                Arc::new(TcpHealthCheck {
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
            HealthCheckType::Grpc { service } => {
                trace!(
                    service = %service,
                    "Configuring gRPC health check"
                );
                Arc::new(GrpcHealthCheck {
                    service: service.clone(),
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
        };

        let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);

        Self {
            config,
            checker,
            health_status: Arc::new(RwLock::new(HashMap::new())),
            check_handles: Arc::new(RwLock::new(Vec::new())),
            shutdown_tx: Arc::new(shutdown_tx),
        }
    }

    /// Starts a background health check task for each of the given targets.
    pub async fn start(&self, targets: &[UpstreamTarget]) -> SentinelResult<()> {
        info!(
            target_count = targets.len(),
            interval_secs = self.config.interval_secs,
            check_type = self.checker.check_type(),
            "Starting health checking"
        );

        let mut handles = self.check_handles.write().await;

        for target in targets {
            let address = target.address.clone();

            trace!(
                target = %address,
                "Initializing health status for target"
            );

            self.health_status
                .write()
                .await
                .insert(address.clone(), TargetHealthInfo::new());

            debug!(
                target = %address,
                "Spawning health check task"
            );
            let handle = self.spawn_check_task(address);
            handles.push(handle);
        }

        info!(
            target_count = targets.len(),
            interval_secs = self.config.interval_secs,
            healthy_threshold = self.config.healthy_threshold,
            unhealthy_threshold = self.config.unhealthy_threshold,
            "Health checking started successfully"
        );

        Ok(())
    }

    /// Spawns a background task that periodically checks a single target and
    /// updates its health status until shutdown is signalled.
    fn spawn_check_task(&self, target: String) -> tokio::task::JoinHandle<()> {
        let interval = Duration::from_secs(self.config.interval_secs);
        let checker = Arc::clone(&self.checker);
        let health_status = Arc::clone(&self.health_status);
        let healthy_threshold = self.config.healthy_threshold;
        let unhealthy_threshold = self.config.unhealthy_threshold;
        let check_type = self.checker.check_type().to_string();
        let mut shutdown_rx = self.shutdown_tx.subscribe();

        tokio::spawn(async move {
            let mut interval_timer = time::interval(interval);
            interval_timer.set_missed_tick_behavior(time::MissedTickBehavior::Skip);

            trace!(
                target = %target,
                check_type = %check_type,
                interval_ms = interval.as_millis(),
                "Health check task started"
            );

            loop {
                tokio::select! {
                    _ = interval_timer.tick() => {
                        trace!(
                            target = %target,
                            check_type = %check_type,
                            "Performing health check"
                        );
                        let start = Instant::now();
                        let result = checker.check(&target).await;
                        let check_duration = start.elapsed();

                        let mut status_map = health_status.write().await;
                        if let Some(status) = status_map.get_mut(&target) {
                            status.last_check = Instant::now();
                            status.total_checks += 1;

                            match result {
                                Ok(response_time) => {
                                    status.consecutive_successes += 1;
                                    status.consecutive_failures = 0;
                                    status.last_success = Some(Instant::now());
                                    status.last_error = None;
                                    status.total_successes += 1;

                                    // Update the running mean response time (in milliseconds).
                                    let response_ms = response_time.as_millis() as f64;
                                    status.avg_response_time =
                                        (status.avg_response_time * (status.total_successes - 1) as f64
                                            + response_ms) / status.total_successes as f64;

                                    if !status.healthy && status.consecutive_successes >= healthy_threshold {
                                        status.healthy = true;
                                        info!(
                                            target = %target,
                                            consecutive_successes = status.consecutive_successes,
                                            avg_response_ms = format!("{:.2}", status.avg_response_time),
                                            total_checks = status.total_checks,
                                            "Target marked as healthy"
                                        );
                                    }

                                    trace!(
                                        target = %target,
                                        response_time_ms = response_ms,
                                        check_duration_ms = check_duration.as_millis(),
                                        consecutive_successes = status.consecutive_successes,
                                        health_score = format!("{:.2}", status.health_score()),
                                        "Health check succeeded"
                                    );
                                }
                                Err(error) => {
                                    status.consecutive_failures += 1;
                                    status.consecutive_successes = 0;
                                    status.last_error = Some(error.clone());

                                    if status.healthy && status.consecutive_failures >= unhealthy_threshold {
                                        status.healthy = false;
                                        warn!(
                                            target = %target,
                                            consecutive_failures = status.consecutive_failures,
                                            error = %error,
                                            total_checks = status.total_checks,
                                            health_score = format!("{:.2}", status.health_score()),
                                            "Target marked as unhealthy"
                                        );
                                    } else {
                                        debug!(
                                            target = %target,
                                            error = %error,
                                            consecutive_failures = status.consecutive_failures,
                                            unhealthy_threshold = unhealthy_threshold,
                                            "Health check failed"
                                        );
                                    }
                                }
                            }
                        }
                    }
                    _ = shutdown_rx.recv() => {
                        info!(target = %target, "Stopping health check task");
                        break;
                    }
                }
            }

            debug!(target = %target, "Health check task stopped");
        })
    }

    /// Signals all health check tasks to stop and waits for them to finish.
    pub async fn stop(&self) {
        let task_count = self.check_handles.read().await.len();
        info!(
            task_count = task_count,
            "Stopping health checker"
        );

        let _ = self.shutdown_tx.send(());

        let mut handles = self.check_handles.write().await;
        for handle in handles.drain(..) {
            let _ = handle.await;
        }

        info!("Health checker stopped successfully");
    }

    /// Returns the health info for a single target, if it is being tracked.
    pub async fn get_status(&self, target: &str) -> Option<TargetHealthInfo> {
        self.health_status.read().await.get(target).cloned()
    }

    /// Returns a snapshot of the health info for all tracked targets.
    pub async fn get_all_statuses(&self) -> HashMap<String, TargetHealthInfo> {
        self.health_status.read().await.clone()
    }

    /// Returns whether the target is currently healthy (false if unknown).
    pub async fn is_healthy(&self, target: &str) -> bool {
        self.health_status
            .read()
            .await
            .get(target)
            .map(|s| s.healthy)
            .unwrap_or(false)
    }

    /// Returns the addresses of all currently healthy targets.
    pub async fn get_healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(target, status)| {
                if status.healthy {
                    Some(target.clone())
                } else {
                    None
                }
            })
            .collect()
    }

    /// Immediately marks a target unhealthy (used by the passive health checker).
    pub async fn mark_unhealthy(&self, target: &str, reason: String) {
        if let Some(status) = self.health_status.write().await.get_mut(target) {
            if status.healthy {
                status.healthy = false;
                status.consecutive_failures = self.config.unhealthy_threshold;
                status.consecutive_successes = 0;
                status.last_error = Some(reason.clone());
                warn!(
                    target = %target,
                    reason = %reason,
                    "Target marked unhealthy by passive check"
                );
            }
        }
    }
}

impl TargetHealthInfo {
    /// Creates a new entry; targets start out healthy until proven otherwise.
    pub fn new() -> Self {
        Self {
            healthy: true,
            consecutive_successes: 0,
            consecutive_failures: 0,
            last_check: Instant::now(),
            last_success: Some(Instant::now()),
            last_error: None,
            total_checks: 0,
            total_successes: 0,
            avg_response_time: 0.0,
        }
    }

    /// Lifetime success ratio in the range [0.0, 1.0]; 1.0 when no checks have run yet.
    pub fn health_score(&self) -> f64 {
        if self.total_checks == 0 {
            return 1.0;
        }
        self.total_successes as f64 / self.total_checks as f64
    }

    /// True if the target is still considered healthy but has started failing.
    pub fn is_degraded(&self) -> bool {
        self.healthy && self.consecutive_failures > 0
    }
}

#[async_trait]
impl HealthCheckImpl for HttpHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        let mut stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        // Use the configured Host header if provided, otherwise fall back to the target address.
        let host = self.host.as_deref().unwrap_or(target);
        let request = format!(
            "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Sentinel-HealthCheck/1.0\r\nConnection: close\r\n\r\n",
            self.path, host
        );

        stream
            .write_all(request.as_bytes())
            .await
            .map_err(|e| format!("Failed to send request: {}", e))?;

        // Read enough of the response to parse the status line.
        let mut response = vec![0u8; 1024];
        let n = stream
            .read(&mut response)
            .await
            .map_err(|e| format!("Failed to read response: {}", e))?;

        if n == 0 {
            return Err("Empty response".to_string());
        }

        let response_str = String::from_utf8_lossy(&response[..n]);
        let status_code = parse_http_status(&response_str)
            .ok_or_else(|| "Failed to parse HTTP status".to_string())?;

        if status_code == self.expected_status {
            Ok(start.elapsed())
        } else {
            Err(format!(
                "Unexpected status code: {} (expected {})",
                status_code, self.expected_status
            ))
        }
    }

    fn check_type(&self) -> &str {
        "HTTP"
    }
}

#[async_trait]
impl HealthCheckImpl for TcpHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        // A successful TCP connection within the timeout counts as healthy.
        time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "TCP"
    }
}

#[async_trait]
impl HealthCheckImpl for GrpcHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        stream
            .writable()
            .await
            .map_err(|e| format!("Connection not writable: {}", e))?;

        debug!(
            target = %target,
            service = %self.service,
            "gRPC health check using TCP fallback (full gRPC protocol requires tonic)"
        );

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "gRPC"
    }
}

/// Extracts the status code from an HTTP response status line
/// (e.g. "HTTP/1.1 200 OK" -> 200).
fn parse_http_status(response: &str) -> Option<u16> {
    response
        .lines()
        .next()?
        .split_whitespace()
        .nth(1)?
        .parse()
        .ok()
}

/// Tracks real request outcomes per target and flags targets whose failure
/// rate exceeds a threshold over a sliding window.
pub struct PassiveHealthChecker {
    /// Failure rate (0.0-1.0) above which a target is marked unhealthy.
    failure_rate_threshold: f64,
    /// Maximum number of recent outcomes kept per target.
    window_size: usize,
    /// Recent request outcomes (true = success) keyed by target address.
    outcomes: Arc<RwLock<HashMap<String, Vec<bool>>>>,
    /// Active checker to notify when a target should be marked unhealthy.
    active_checker: Option<Arc<ActiveHealthChecker>>,
}

impl PassiveHealthChecker {
    /// Creates a new passive health checker.
    pub fn new(
        failure_rate_threshold: f64,
        window_size: usize,
        active_checker: Option<Arc<ActiveHealthChecker>>,
    ) -> Self {
        debug!(
            failure_rate_threshold = format!("{:.2}", failure_rate_threshold),
            window_size = window_size,
            has_active_checker = active_checker.is_some(),
            "Creating passive health checker"
        );
        Self {
            failure_rate_threshold,
            window_size,
            outcomes: Arc::new(RwLock::new(HashMap::new())),
            active_checker,
        }
    }

    /// Records the outcome of a real request to `target` and, if the failure
    /// rate over the window exceeds the threshold, marks the target unhealthy
    /// via the active checker.
    pub async fn record_outcome(&self, target: &str, success: bool) {
        trace!(
            target = %target,
            success = success,
            "Recording request outcome"
        );

        let mut outcomes = self.outcomes.write().await;
        let target_outcomes = outcomes
            .entry(target.to_string())
            .or_insert_with(|| Vec::with_capacity(self.window_size));

        // Keep only the most recent `window_size` outcomes.
        if target_outcomes.len() >= self.window_size {
            target_outcomes.remove(0);
        }
        target_outcomes.push(success);

        let failures = target_outcomes.iter().filter(|&&s| !s).count();
        let failure_rate = failures as f64 / target_outcomes.len() as f64;

        trace!(
            target = %target,
            failure_rate = format!("{:.2}", failure_rate),
            window_samples = target_outcomes.len(),
            failures = failures,
            "Updated failure rate"
        );

        if failure_rate > self.failure_rate_threshold {
            warn!(
                target = %target,
                failure_rate = format!("{:.2}", failure_rate * 100.0),
                threshold = format!("{:.2}", self.failure_rate_threshold * 100.0),
                window_samples = target_outcomes.len(),
                "Failure rate exceeds threshold"
            );
            if let Some(ref checker) = self.active_checker {
                checker
                    .mark_unhealthy(
                        target,
                        format!(
                            "Failure rate {:.2}% exceeds threshold",
                            failure_rate * 100.0
                        ),
                    )
                    .await;
            }
        }
    }

    /// Returns the current failure rate for `target`, or `None` if no
    /// outcomes have been recorded.
    pub async fn get_failure_rate(&self, target: &str) -> Option<f64> {
        let outcomes = self.outcomes.read().await;
        outcomes.get(target).map(|target_outcomes| {
            let failures = target_outcomes.iter().filter(|&&s| !s).count();
            failures as f64 / target_outcomes.len() as f64
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_health_status() {
        let status = TargetHealthInfo::new();
        assert!(status.healthy);
        assert_eq!(status.health_score(), 1.0);
        assert!(!status.is_degraded());
    }

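    // An additional sketch test (not part of the original suite): it exercises
    // `is_degraded` and `health_score` by mutating the public fields of
    // `TargetHealthInfo` directly, the same way the check task does.
    #[tokio::test]
    async fn test_degraded_state_and_health_score() {
        let mut status = TargetHealthInfo::new();

        // Simulate two checks: one success, one failure.
        status.total_checks = 2;
        status.total_successes = 1;
        status.consecutive_failures = 1;

        // Still healthy (below the unhealthy threshold) but degraded.
        assert!(status.healthy);
        assert!(status.is_degraded());
        assert_eq!(status.health_score(), 0.5);
    }
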
    #[tokio::test]
    async fn test_passive_health_checker() {
        let checker = PassiveHealthChecker::new(0.5, 10, None);

        for _ in 0..5 {
            checker.record_outcome("target1", true).await;
        }
        for _ in 0..3 {
            checker.record_outcome("target1", false).await;
        }

        // 3 failures out of 8 samples = 0.375.
        let failure_rate = checker.get_failure_rate("target1").await.unwrap();
        assert!(failure_rate > 0.3 && failure_rate < 0.4);
    }

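    // Another sketch test (not part of the original suite): it checks that the
    // sliding window keeps only the most recent `window_size` outcomes, so old
    // failures age out once enough successes are recorded.
    #[tokio::test]
    async fn test_passive_window_eviction() {
        // Window of 4 samples; no active checker, so exceeding the threshold has no side effects.
        let checker = PassiveHealthChecker::new(0.9, 4, None);

        // Two failures followed by four successes: the failures fall out of the
        // 4-sample window, leaving a failure rate of 0.0.
        for _ in 0..2 {
            checker.record_outcome("target1", false).await;
        }
        for _ in 0..4 {
            checker.record_outcome("target1", true).await;
        }

        let failure_rate = checker.get_failure_rate("target1").await.unwrap();
        assert_eq!(failure_rate, 0.0);
    }
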
    #[test]
    fn test_parse_http_status() {
        let response = "HTTP/1.1 200 OK\r\n";
        assert_eq!(parse_http_status(response), Some(200));

        let response = "HTTP/1.1 404 Not Found\r\n";
        assert_eq!(parse_http_status(response), Some(404));

        let response = "Invalid response";
        assert_eq!(parse_http_status(response), None);
    }
}