use async_trait::async_trait;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::RwLock;
use tokio::time;
use tracing::{debug, info, trace, warn};

use sentinel_common::{errors::SentinelResult, types::HealthCheckType};
use sentinel_config::{HealthCheck as HealthCheckConfig, UpstreamTarget};

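/// Actively probes upstream targets at a fixed interval and tracks per-target
/// health state, applying healthy/unhealthy thresholds before flipping status.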
pub struct ActiveHealthChecker {
    config: HealthCheckConfig,
    checker: Arc<dyn HealthCheckImpl>,
    health_status: Arc<RwLock<HashMap<String, TargetHealthInfo>>>,
    check_handles: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
    shutdown_tx: Arc<tokio::sync::broadcast::Sender<()>>,
}

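/// Point-in-time health statistics for a single upstream target, as maintained
/// by the active checker.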
#[derive(Debug, Clone)]
pub struct TargetHealthInfo {
    pub healthy: bool,
    pub consecutive_successes: u32,
    pub consecutive_failures: u32,
    pub last_check: Instant,
    pub last_success: Option<Instant>,
    pub last_error: Option<String>,
    pub total_checks: u64,
    pub total_successes: u64,
    pub avg_response_time: f64,
}

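/// Probe strategy shared by the HTTP, TCP, and gRPC health checks. `check`
/// returns the observed response time on success or an error description.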
#[async_trait]
trait HealthCheckImpl: Send + Sync {
    async fn check(&self, target: &str) -> Result<Duration, String>;

    fn check_type(&self) -> &str;
}

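/// HTTP probe: issues a `GET` against the configured path and compares the
/// response status code with the expected one.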
struct HttpHealthCheck {
    path: String,
    expected_status: u16,
    host: Option<String>,
    timeout: Duration,
}

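/// TCP probe: the target is healthy if a connection can be established within
/// the timeout.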
struct TcpHealthCheck {
    timeout: Duration,
}

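/// gRPC probe: currently falls back to a TCP-level connectivity check (see the
/// note in its `check` implementation).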
struct GrpcHealthCheck {
    service: String,
    timeout: Duration,
}

impl ActiveHealthChecker {
    pub fn new(config: HealthCheckConfig) -> Self {
        debug!(
            check_type = ?config.check_type,
            interval_secs = config.interval_secs,
            timeout_secs = config.timeout_secs,
            healthy_threshold = config.healthy_threshold,
            unhealthy_threshold = config.unhealthy_threshold,
            "Creating active health checker"
        );

        let checker: Arc<dyn HealthCheckImpl> = match &config.check_type {
            HealthCheckType::Http {
                path,
                expected_status,
                host,
            } => {
                trace!(
                    path = %path,
                    expected_status = expected_status,
                    host = host.as_deref().unwrap_or("(default)"),
                    "Configuring HTTP health check"
                );
                Arc::new(HttpHealthCheck {
                    path: path.clone(),
                    expected_status: *expected_status,
                    host: host.clone(),
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
            HealthCheckType::Tcp => {
                trace!("Configuring TCP health check");
                Arc::new(TcpHealthCheck {
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
            HealthCheckType::Grpc { service } => {
                trace!(
                    service = %service,
                    "Configuring gRPC health check"
                );
                Arc::new(GrpcHealthCheck {
                    service: service.clone(),
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
        };

        let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);

        Self {
            config,
            checker,
            health_status: Arc::new(RwLock::new(HashMap::new())),
            check_handles: Arc::new(RwLock::new(Vec::new())),
            shutdown_tx: Arc::new(shutdown_tx),
        }
    }

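    /// Initializes health state for every target and spawns one periodic
    /// check task per target.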
    pub async fn start(&self, targets: &[UpstreamTarget]) -> SentinelResult<()> {
        info!(
            target_count = targets.len(),
            interval_secs = self.config.interval_secs,
            check_type = self.checker.check_type(),
            "Starting health checking"
        );

        let mut handles = self.check_handles.write().await;

        for target in targets {
            let address = target.address.clone();

            trace!(
                target = %address,
                "Initializing health status for target"
            );

            self.health_status
                .write()
                .await
                .insert(address.clone(), TargetHealthInfo::new());

            debug!(
                target = %address,
                "Spawning health check task"
            );
            let handle = self.spawn_check_task(address);
            handles.push(handle);
        }

        info!(
            target_count = targets.len(),
            interval_secs = self.config.interval_secs,
            healthy_threshold = self.config.healthy_threshold,
            unhealthy_threshold = self.config.unhealthy_threshold,
            "Health checking started successfully"
        );

        Ok(())
    }

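    /// Spawns the per-target loop: runs a probe on every interval tick,
    /// updates the shared status map, and exits when shutdown is signalled.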
    fn spawn_check_task(&self, target: String) -> tokio::task::JoinHandle<()> {
        let interval = Duration::from_secs(self.config.interval_secs);
        let checker = Arc::clone(&self.checker);
        let health_status = Arc::clone(&self.health_status);
        let healthy_threshold = self.config.healthy_threshold;
        let unhealthy_threshold = self.config.unhealthy_threshold;
        let check_type = self.checker.check_type().to_string();
        let mut shutdown_rx = self.shutdown_tx.subscribe();

        tokio::spawn(async move {
            let mut interval_timer = time::interval(interval);
            interval_timer.set_missed_tick_behavior(time::MissedTickBehavior::Skip);

            trace!(
                target = %target,
                check_type = %check_type,
                interval_ms = interval.as_millis(),
                "Health check task started"
            );

            loop {
                tokio::select! {
                    _ = interval_timer.tick() => {
                        trace!(
                            target = %target,
                            check_type = %check_type,
                            "Performing health check"
                        );
                        let start = Instant::now();
                        let result = checker.check(&target).await;
                        let check_duration = start.elapsed();

                        let mut status_map = health_status.write().await;
                        if let Some(status) = status_map.get_mut(&target) {
                            status.last_check = Instant::now();
                            status.total_checks += 1;

                            match result {
                                Ok(response_time) => {
                                    status.consecutive_successes += 1;
                                    status.consecutive_failures = 0;
                                    status.last_success = Some(Instant::now());
                                    status.last_error = None;
                                    status.total_successes += 1;

                                    let response_ms = response_time.as_millis() as f64;
                                    status.avg_response_time =
                                        (status.avg_response_time * (status.total_successes - 1) as f64
                                            + response_ms) / status.total_successes as f64;

                                    if !status.healthy && status.consecutive_successes >= healthy_threshold {
                                        status.healthy = true;
                                        info!(
                                            target = %target,
                                            consecutive_successes = status.consecutive_successes,
                                            avg_response_ms = format!("{:.2}", status.avg_response_time),
                                            total_checks = status.total_checks,
                                            "Target marked as healthy"
                                        );
                                    }

                                    trace!(
                                        target = %target,
                                        response_time_ms = response_ms,
                                        check_duration_ms = check_duration.as_millis(),
                                        consecutive_successes = status.consecutive_successes,
                                        health_score = format!("{:.2}", status.health_score()),
                                        "Health check succeeded"
                                    );
                                }
                                Err(error) => {
                                    status.consecutive_failures += 1;
                                    status.consecutive_successes = 0;
                                    status.last_error = Some(error.clone());

                                    if status.healthy && status.consecutive_failures >= unhealthy_threshold {
                                        status.healthy = false;
                                        warn!(
                                            target = %target,
                                            consecutive_failures = status.consecutive_failures,
                                            error = %error,
                                            total_checks = status.total_checks,
                                            health_score = format!("{:.2}", status.health_score()),
                                            "Target marked as unhealthy"
                                        );
                                    } else {
                                        debug!(
                                            target = %target,
                                            error = %error,
                                            consecutive_failures = status.consecutive_failures,
                                            unhealthy_threshold = unhealthy_threshold,
                                            "Health check failed"
                                        );
                                    }
                                }
                            }
                        }
                    }
                    _ = shutdown_rx.recv() => {
                        info!(target = %target, "Stopping health check task");
                        break;
                    }
                }
            }

            debug!(target = %target, "Health check task stopped");
        })
    }

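    /// Signals all check tasks to shut down and waits for them to finish.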
    pub async fn stop(&self) {
        let task_count = self.check_handles.read().await.len();
        info!(task_count = task_count, "Stopping health checker");

        let _ = self.shutdown_tx.send(());

        let mut handles = self.check_handles.write().await;
        for handle in handles.drain(..) {
            let _ = handle.await;
        }

        info!("Health checker stopped successfully");
    }

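    /// Returns a snapshot of the health info recorded for `target`, if any.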
    pub async fn get_status(&self, target: &str) -> Option<TargetHealthInfo> {
        self.health_status.read().await.get(target).cloned()
    }

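    /// Returns a snapshot of the health info for all known targets.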
    pub async fn get_all_statuses(&self) -> HashMap<String, TargetHealthInfo> {
        self.health_status.read().await.clone()
    }

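    /// Returns whether `target` is currently healthy; unknown targets are
    /// treated as unhealthy.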
    pub async fn is_healthy(&self, target: &str) -> bool {
        self.health_status
            .read()
            .await
            .get(target)
            .map(|s| s.healthy)
            .unwrap_or(false)
    }

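    /// Lists the addresses of all targets currently considered healthy.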
    pub async fn get_healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(target, status)| {
                if status.healthy {
                    Some(target.clone())
                } else {
                    None
                }
            })
            .collect()
    }

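    /// Forces a target unhealthy outside the active probe cycle, e.g. when the
    /// passive checker observes an excessive failure rate.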
    pub async fn mark_unhealthy(&self, target: &str, reason: String) {
        if let Some(status) = self.health_status.write().await.get_mut(target) {
            if status.healthy {
                status.healthy = false;
                status.consecutive_failures = self.config.unhealthy_threshold;
                status.consecutive_successes = 0;
                status.last_error = Some(reason.clone());
                warn!(
                    target = %target,
                    reason = %reason,
                    "Target marked unhealthy by passive check"
                );
            }
        }
    }
}

impl Default for TargetHealthInfo {
    fn default() -> Self {
        Self::new()
    }
}

impl TargetHealthInfo {
    pub fn new() -> Self {
        Self {
            healthy: true,
            consecutive_successes: 0,
            consecutive_failures: 0,
            last_check: Instant::now(),
            last_success: Some(Instant::now()),
            last_error: None,
            total_checks: 0,
            total_successes: 0,
            avg_response_time: 0.0,
        }
    }

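    /// Lifetime success ratio in `0.0..=1.0`; defaults to `1.0` before any
    /// checks have run.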
    pub fn health_score(&self) -> f64 {
        if self.total_checks == 0 {
            return 1.0;
        }
        self.total_successes as f64 / self.total_checks as f64
    }

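    /// True when the target is still marked healthy but has begun failing
    /// consecutive checks.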
    pub fn is_degraded(&self) -> bool {
        self.healthy && self.consecutive_failures > 0
    }
}

#[async_trait]
impl HealthCheckImpl for HttpHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        let mut stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        let host = self.host.as_deref().unwrap_or(target);
        let request = format!(
            "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Sentinel-HealthCheck/1.0\r\nConnection: close\r\n\r\n",
            self.path, host
        );

        stream
            .write_all(request.as_bytes())
            .await
            .map_err(|e| format!("Failed to send request: {}", e))?;

        let mut response = vec![0u8; 1024];
        let n = stream
            .read(&mut response)
            .await
            .map_err(|e| format!("Failed to read response: {}", e))?;

        if n == 0 {
            return Err("Empty response".to_string());
        }

        let response_str = String::from_utf8_lossy(&response[..n]);
        let status_code = parse_http_status(&response_str)
            .ok_or_else(|| "Failed to parse HTTP status".to_string())?;

        if status_code == self.expected_status {
            Ok(start.elapsed())
        } else {
            Err(format!(
                "Unexpected status code: {} (expected {})",
                status_code, self.expected_status
            ))
        }
    }

    fn check_type(&self) -> &str {
        "HTTP"
    }
}

#[async_trait]
impl HealthCheckImpl for TcpHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "TCP"
    }
}

#[async_trait]
impl HealthCheckImpl for GrpcHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        stream
            .writable()
            .await
            .map_err(|e| format!("Connection not writable: {}", e))?;

        debug!(
            target = %target,
            service = %self.service,
            "gRPC health check using TCP fallback (full gRPC protocol requires tonic)"
        );

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "gRPC"
    }
}

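/// Extracts the numeric status code from the first line of an HTTP response,
/// e.g. `"HTTP/1.1 200 OK"` -> `Some(200)`.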
fn parse_http_status(response: &str) -> Option<u16> {
    response
        .lines()
        .next()?
        .split_whitespace()
        .nth(1)?
        .parse()
        .ok()
}

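/// Derives target health from real request outcomes over a sliding window and,
/// when the failure-rate threshold is exceeded, reports the target to the
/// optional active checker.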
pub struct PassiveHealthChecker {
    failure_rate_threshold: f64,
    window_size: usize,
    outcomes: Arc<RwLock<HashMap<String, Vec<bool>>>>,
    active_checker: Option<Arc<ActiveHealthChecker>>,
}

impl PassiveHealthChecker {
    pub fn new(
        failure_rate_threshold: f64,
        window_size: usize,
        active_checker: Option<Arc<ActiveHealthChecker>>,
    ) -> Self {
        debug!(
            failure_rate_threshold = format!("{:.2}", failure_rate_threshold),
            window_size = window_size,
            has_active_checker = active_checker.is_some(),
            "Creating passive health checker"
        );
        Self {
            failure_rate_threshold,
            window_size,
            outcomes: Arc::new(RwLock::new(HashMap::new())),
            active_checker,
        }
    }

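    /// Records the outcome of a proxied request and re-evaluates the target's
    /// failure rate over the sliding window.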
    pub async fn record_outcome(&self, target: &str, success: bool) {
        trace!(
            target = %target,
            success = success,
            "Recording request outcome"
        );

        let mut outcomes = self.outcomes.write().await;
        let target_outcomes = outcomes
            .entry(target.to_string())
            .or_insert_with(|| Vec::with_capacity(self.window_size));

        if target_outcomes.len() >= self.window_size {
            target_outcomes.remove(0);
        }
        target_outcomes.push(success);

        let failures = target_outcomes.iter().filter(|&&s| !s).count();
        let failure_rate = failures as f64 / target_outcomes.len() as f64;

        trace!(
            target = %target,
            failure_rate = format!("{:.2}", failure_rate),
            window_samples = target_outcomes.len(),
            failures = failures,
            "Updated failure rate"
        );

        if failure_rate > self.failure_rate_threshold {
            warn!(
                target = %target,
                failure_rate = format!("{:.2}", failure_rate * 100.0),
                threshold = format!("{:.2}", self.failure_rate_threshold * 100.0),
                window_samples = target_outcomes.len(),
                "Failure rate exceeds threshold"
            );
            if let Some(ref checker) = self.active_checker {
                checker
                    .mark_unhealthy(
                        target,
                        format!(
                            "Failure rate {:.2}% exceeds threshold",
                            failure_rate * 100.0
                        ),
                    )
                    .await;
            }
        }
    }

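    /// Current failure rate for `target` over the recorded window, or `None`
    /// if no outcomes have been recorded yet.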
    pub async fn get_failure_rate(&self, target: &str) -> Option<f64> {
        let outcomes = self.outcomes.read().await;
        outcomes.get(target).map(|target_outcomes| {
            let failures = target_outcomes.iter().filter(|&&s| !s).count();
            failures as f64 / target_outcomes.len() as f64
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_health_status() {
        let status = TargetHealthInfo::new();
        assert!(status.healthy);
        assert_eq!(status.health_score(), 1.0);
        assert!(!status.is_degraded());
    }

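    // A still-healthy target with a recent consecutive failure is reported as
    // degraded, and `health_score()` reflects the lifetime success ratio.
    #[test]
    fn test_degraded_and_health_score() {
        let mut status = TargetHealthInfo::new();
        status.consecutive_failures = 1;
        status.total_checks = 10;
        status.total_successes = 7;
        assert!(status.is_degraded());
        let score = status.health_score();
        assert!(score > 0.69 && score < 0.71);
    }
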
    #[tokio::test]
    async fn test_passive_health_checker() {
        let checker = PassiveHealthChecker::new(0.5, 10, None);

        for _ in 0..5 {
            checker.record_outcome("target1", true).await;
        }
        for _ in 0..3 {
            checker.record_outcome("target1", false).await;
        }

        let failure_rate = checker.get_failure_rate("target1").await.unwrap();
        assert!(failure_rate > 0.3 && failure_rate < 0.4);
    }

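    // The outcome window is bounded: once `window_size` samples are recorded,
    // the oldest sample is evicted, so old failures age out of the rate.
    #[tokio::test]
    async fn test_passive_window_eviction() {
        let checker = PassiveHealthChecker::new(0.9, 4, None);
        for _ in 0..4 {
            checker.record_outcome("target1", false).await;
        }
        for _ in 0..4 {
            checker.record_outcome("target1", true).await;
        }
        assert_eq!(checker.get_failure_rate("target1").await, Some(0.0));
    }
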
    #[test]
    fn test_parse_http_status() {
        let response = "HTTP/1.1 200 OK\r\n";
        assert_eq!(parse_http_status(response), Some(200));

        let response = "HTTP/1.1 404 Not Found\r\n";
        assert_eq!(parse_http_status(response), Some(404));

        let response = "Invalid response";
        assert_eq!(parse_http_status(response), None);
    }
}