use async_trait::async_trait;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::RwLock;
use tokio::time;
use tracing::{debug, info, trace, warn};

use sentinel_common::{errors::SentinelResult, types::HealthCheckType};
use sentinel_config::{HealthCheck as HealthCheckConfig, UpstreamTarget};

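/// Actively probes upstream targets on a fixed interval and tracks per-target
/// health based on configurable success and failure thresholds.
///
/// A minimal usage sketch (illustrative only; assumes a `HealthCheckConfig`
/// and a list of `UpstreamTarget`s are built elsewhere):
///
/// ```ignore
/// let checker = ActiveHealthChecker::new(config);
/// checker.start(&targets).await?;
/// let healthy = checker.get_healthy_targets().await;
/// checker.stop().await;
/// ```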
pub struct ActiveHealthChecker {
    /// Health check configuration (check type, interval, timeout, thresholds).
    config: HealthCheckConfig,
    /// Protocol-specific check implementation (HTTP, TCP, or gRPC).
    checker: Arc<dyn HealthCheckImpl>,
    /// Per-target health state, keyed by target address.
    health_status: Arc<RwLock<HashMap<String, TargetHealthInfo>>>,
    /// Handles of the spawned per-target check tasks.
    check_handles: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
    /// Broadcast channel used to signal check tasks to shut down.
    shutdown_tx: Arc<tokio::sync::broadcast::Sender<()>>,
}

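/// Point-in-time health information for a single upstream target, updated by
/// the active checker after every probe.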
#[derive(Debug, Clone)]
pub struct TargetHealthInfo {
    /// Whether the target is currently considered healthy.
    pub healthy: bool,
    /// Number of consecutive successful checks.
    pub consecutive_successes: u32,
    /// Number of consecutive failed checks.
    pub consecutive_failures: u32,
    /// When the most recent check ran.
    pub last_check: Instant,
    /// When the most recent successful check ran, if any.
    pub last_success: Option<Instant>,
    /// The most recent check error, if any.
    pub last_error: Option<String>,
    /// Total number of checks performed.
    pub total_checks: u64,
    /// Total number of successful checks.
    pub total_successes: u64,
    /// Running average response time of successful checks, in milliseconds.
    pub avg_response_time: f64,
}

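/// Protocol-specific health check implementation used by `ActiveHealthChecker`.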
#[async_trait]
trait HealthCheckImpl: Send + Sync {
    /// Probes `target` once, returning the response time on success or an
    /// error description on failure.
    async fn check(&self, target: &str) -> Result<Duration, String>;

    /// Human-readable name of the check protocol (e.g. "HTTP").
    fn check_type(&self) -> &str;
}

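/// HTTP health check: sends a `GET` to `path` and compares the response status
/// code against `expected_status`.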
struct HttpHealthCheck {
    path: String,
    expected_status: u16,
    host: Option<String>,
    timeout: Duration,
}

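/// TCP health check: succeeds if a connection can be established within the timeout.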
struct TcpHealthCheck {
    timeout: Duration,
}

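/// gRPC health check. Currently verifies TCP connectivity only; issuing a real
/// `grpc.health.v1.Health/Check` RPC for `service` would require a gRPC client
/// such as tonic (see the debug note in the implementation below).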
struct GrpcHealthCheck {
    service: String,
    timeout: Duration,
}

impl ActiveHealthChecker {
    pub fn new(config: HealthCheckConfig) -> Self {
        debug!(
            check_type = ?config.check_type,
            interval_secs = config.interval_secs,
            timeout_secs = config.timeout_secs,
            healthy_threshold = config.healthy_threshold,
            unhealthy_threshold = config.unhealthy_threshold,
            "Creating active health checker"
        );

        let checker: Arc<dyn HealthCheckImpl> = match &config.check_type {
            HealthCheckType::Http {
                path,
                expected_status,
                host,
            } => {
                trace!(
                    path = %path,
                    expected_status = expected_status,
                    host = host.as_deref().unwrap_or("(default)"),
                    "Configuring HTTP health check"
                );
                Arc::new(HttpHealthCheck {
                    path: path.clone(),
                    expected_status: *expected_status,
                    host: host.clone(),
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
            HealthCheckType::Tcp => {
                trace!("Configuring TCP health check");
                Arc::new(TcpHealthCheck {
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
            HealthCheckType::Grpc { service } => {
                trace!(
                    service = %service,
                    "Configuring gRPC health check"
                );
                Arc::new(GrpcHealthCheck {
                    service: service.clone(),
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
        };

        let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);

        Self {
            config,
            checker,
            health_status: Arc::new(RwLock::new(HashMap::new())),
            check_handles: Arc::new(RwLock::new(Vec::new())),
            shutdown_tx: Arc::new(shutdown_tx),
        }
    }

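    /// Initializes health state for each target and spawns one background
    /// check task per target.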
    pub async fn start(&self, targets: &[UpstreamTarget]) -> SentinelResult<()> {
        info!(
            target_count = targets.len(),
            interval_secs = self.config.interval_secs,
            check_type = self.checker.check_type(),
            "Starting health checking"
        );

        let mut handles = self.check_handles.write().await;

        for target in targets {
            let address = target.address.clone();

            trace!(
                target = %address,
                "Initializing health status for target"
            );

            self.health_status
                .write()
                .await
                .insert(address.clone(), TargetHealthInfo::new());

            debug!(
                target = %address,
                "Spawning health check task"
            );
            let handle = self.spawn_check_task(address);
            handles.push(handle);
        }

        info!(
            target_count = targets.len(),
            interval_secs = self.config.interval_secs,
            healthy_threshold = self.config.healthy_threshold,
            unhealthy_threshold = self.config.unhealthy_threshold,
            "Health checking started successfully"
        );

        Ok(())
    }

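    /// Spawns the per-target check loop: probes on the configured interval,
    /// updates the shared health map, and exits on shutdown.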
    fn spawn_check_task(&self, target: String) -> tokio::task::JoinHandle<()> {
        let interval = Duration::from_secs(self.config.interval_secs);
        let checker = Arc::clone(&self.checker);
        let health_status = Arc::clone(&self.health_status);
        let healthy_threshold = self.config.healthy_threshold;
        let unhealthy_threshold = self.config.unhealthy_threshold;
        let check_type = self.checker.check_type().to_string();
        let mut shutdown_rx = self.shutdown_tx.subscribe();

        tokio::spawn(async move {
            let mut interval_timer = time::interval(interval);
            interval_timer.set_missed_tick_behavior(time::MissedTickBehavior::Skip);

            trace!(
                target = %target,
                check_type = %check_type,
                interval_ms = interval.as_millis(),
                "Health check task started"
            );

            loop {
                tokio::select! {
                    _ = interval_timer.tick() => {
                        trace!(
                            target = %target,
                            check_type = %check_type,
                            "Performing health check"
                        );
                        let start = Instant::now();
                        let result = checker.check(&target).await;
                        let check_duration = start.elapsed();

                        let mut status_map = health_status.write().await;
                        if let Some(status) = status_map.get_mut(&target) {
                            status.last_check = Instant::now();
                            status.total_checks += 1;

                            match result {
                                Ok(response_time) => {
                                    status.consecutive_successes += 1;
                                    status.consecutive_failures = 0;
                                    status.last_success = Some(Instant::now());
                                    status.last_error = None;
                                    status.total_successes += 1;

                                    // Incremental running average over all successful checks.
                                    let response_ms = response_time.as_millis() as f64;
                                    status.avg_response_time =
                                        (status.avg_response_time * (status.total_successes - 1) as f64
                                            + response_ms) / status.total_successes as f64;

                                    // Promote to healthy once enough consecutive successes accrue.
                                    if !status.healthy && status.consecutive_successes >= healthy_threshold {
                                        status.healthy = true;
                                        info!(
                                            target = %target,
                                            consecutive_successes = status.consecutive_successes,
                                            avg_response_ms = format!("{:.2}", status.avg_response_time),
                                            total_checks = status.total_checks,
                                            "Target marked as healthy"
                                        );
                                    }

                                    trace!(
                                        target = %target,
                                        response_time_ms = response_ms,
                                        check_duration_ms = check_duration.as_millis(),
                                        consecutive_successes = status.consecutive_successes,
                                        health_score = format!("{:.2}", status.health_score()),
                                        "Health check succeeded"
                                    );
                                }
                                Err(error) => {
                                    status.consecutive_failures += 1;
                                    status.consecutive_successes = 0;
                                    status.last_error = Some(error.clone());

                                    // Demote to unhealthy once enough consecutive failures accrue.
                                    if status.healthy && status.consecutive_failures >= unhealthy_threshold {
                                        status.healthy = false;
                                        warn!(
                                            target = %target,
                                            consecutive_failures = status.consecutive_failures,
                                            error = %error,
                                            total_checks = status.total_checks,
                                            health_score = format!("{:.2}", status.health_score()),
                                            "Target marked as unhealthy"
                                        );
                                    } else {
                                        debug!(
                                            target = %target,
                                            error = %error,
                                            consecutive_failures = status.consecutive_failures,
                                            unhealthy_threshold = unhealthy_threshold,
                                            "Health check failed"
                                        );
                                    }
                                }
                            }
                        }
                    }
                    _ = shutdown_rx.recv() => {
                        info!(target = %target, "Stopping health check task");
                        break;
                    }
                }
            }

            debug!(target = %target, "Health check task stopped");
        })
    }

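    /// Signals all check tasks to shut down and waits for them to finish.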
    pub async fn stop(&self) {
        let task_count = self.check_handles.read().await.len();
        info!(task_count = task_count, "Stopping health checker");

        let _ = self.shutdown_tx.send(());

        let mut handles = self.check_handles.write().await;
        for handle in handles.drain(..) {
            let _ = handle.await;
        }

        info!("Health checker stopped successfully");
    }

    /// Returns a snapshot of the health info for a single target, if tracked.
    pub async fn get_status(&self, target: &str) -> Option<TargetHealthInfo> {
        self.health_status.read().await.get(target).cloned()
    }

    /// Returns a snapshot of the health info for all tracked targets.
    pub async fn get_all_statuses(&self) -> HashMap<String, TargetHealthInfo> {
        self.health_status.read().await.clone()
    }

    /// Returns `true` if the target is tracked and currently healthy.
    pub async fn is_healthy(&self, target: &str) -> bool {
        self.health_status
            .read()
            .await
            .get(target)
            .map(|s| s.healthy)
            .unwrap_or(false)
    }

    /// Returns the addresses of all currently healthy targets.
    pub async fn get_healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(target, status)| {
                if status.healthy {
                    Some(target.clone())
                } else {
                    None
                }
            })
            .collect()
    }

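    /// Forces a target into the unhealthy state, e.g. when a passive checker
    /// observes an excessive failure rate.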
    pub async fn mark_unhealthy(&self, target: &str, reason: String) {
        if let Some(status) = self.health_status.write().await.get_mut(target) {
            if status.healthy {
                status.healthy = false;
                status.consecutive_failures = self.config.unhealthy_threshold;
                status.consecutive_successes = 0;
                status.last_error = Some(reason.clone());
                warn!(
                    target = %target,
                    reason = %reason,
                    "Target marked unhealthy by passive check"
                );
            }
        }
    }
}

impl Default for TargetHealthInfo {
    fn default() -> Self {
        Self::new()
    }
}

impl TargetHealthInfo {
    /// Creates a new record; targets start out healthy until proven otherwise.
    pub fn new() -> Self {
        Self {
            healthy: true,
            consecutive_successes: 0,
            consecutive_failures: 0,
            last_check: Instant::now(),
            last_success: Some(Instant::now()),
            last_error: None,
            total_checks: 0,
            total_successes: 0,
            avg_response_time: 0.0,
        }
    }

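    /// Lifetime success ratio: `total_successes / total_checks`, or `1.0`
    /// before any checks have run.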
    pub fn health_score(&self) -> f64 {
        if self.total_checks == 0 {
            return 1.0;
        }
        self.total_successes as f64 / self.total_checks as f64
    }

    /// A target is degraded when it is still marked healthy but has started
    /// accumulating consecutive failures.
    pub fn is_degraded(&self) -> bool {
        self.healthy && self.consecutive_failures > 0
    }
}

#[async_trait]
impl HealthCheckImpl for HttpHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        // Connect with a timeout so a hung target cannot stall the check task.
        let mut stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        // Send a minimal HTTP/1.1 request; fall back to the target address as Host.
        let host = self.host.as_deref().unwrap_or(target);
        let request = format!(
            "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Sentinel-HealthCheck/1.0\r\nConnection: close\r\n\r\n",
            self.path,
            host
        );

        stream
            .write_all(request.as_bytes())
            .await
            .map_err(|e| format!("Failed to send request: {}", e))?;

        // Reading the first 1 KiB is enough to parse the status line.
        let mut response = vec![0u8; 1024];
        let n = stream
            .read(&mut response)
            .await
            .map_err(|e| format!("Failed to read response: {}", e))?;

        if n == 0 {
            return Err("Empty response".to_string());
        }

        let response_str = String::from_utf8_lossy(&response[..n]);
        let status_code = parse_http_status(&response_str)
            .ok_or_else(|| "Failed to parse HTTP status".to_string())?;

        if status_code == self.expected_status {
            Ok(start.elapsed())
        } else {
            Err(format!(
                "Unexpected status code: {} (expected {})",
                status_code, self.expected_status
            ))
        }
    }

    fn check_type(&self) -> &str {
        "HTTP"
    }
}

#[async_trait]
impl HealthCheckImpl for TcpHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "TCP"
    }
}

#[async_trait]
impl HealthCheckImpl for GrpcHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        // Without a full gRPC client we only verify the connection is usable.
        stream
            .writable()
            .await
            .map_err(|e| format!("Connection not writable: {}", e))?;

        debug!(
            target = %target,
            service = %self.service,
            "gRPC health check using TCP fallback (full gRPC protocol requires tonic)"
        );

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "gRPC"
    }
}

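/// Parses the status code out of an HTTP/1.x status line, e.g. `HTTP/1.1 200 OK`.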
fn parse_http_status(response: &str) -> Option<u16> {
    response
        .lines()
        .next()?
        .split_whitespace()
        .nth(1)?
        .parse()
        .ok()
}

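/// Tracks real request outcomes per target over a sliding window and, when the
/// observed failure rate crosses the threshold, marks the target unhealthy via
/// the optional `ActiveHealthChecker`.
///
/// A minimal sketch of how it might be driven (illustrative only; the caller,
/// target address, and window size are assumptions):
///
/// ```ignore
/// let passive = PassiveHealthChecker::new(0.5, 100, Some(active_checker));
/// passive.record_outcome("10.0.0.1:8080", false, Some("HTTP 503")).await;
/// ```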
pub struct PassiveHealthChecker {
    /// Failure rate (0.0..=1.0) above which a target is reported unhealthy.
    failure_rate_threshold: f64,
    /// Number of most recent outcomes kept per target.
    window_size: usize,
    /// Sliding window of request outcomes (true = success) per target.
    outcomes: Arc<RwLock<HashMap<String, Vec<bool>>>>,
    /// Most recent error message observed per target.
    last_errors: Arc<RwLock<HashMap<String, String>>>,
    /// Active checker to notify when a target should be marked unhealthy.
    active_checker: Option<Arc<ActiveHealthChecker>>,
}

impl PassiveHealthChecker {
    pub fn new(
        failure_rate_threshold: f64,
        window_size: usize,
        active_checker: Option<Arc<ActiveHealthChecker>>,
    ) -> Self {
        debug!(
            failure_rate_threshold = format!("{:.2}", failure_rate_threshold),
            window_size = window_size,
            has_active_checker = active_checker.is_some(),
            "Creating passive health checker"
        );
        Self {
            failure_rate_threshold,
            window_size,
            outcomes: Arc::new(RwLock::new(HashMap::new())),
            last_errors: Arc::new(RwLock::new(HashMap::new())),
            active_checker,
        }
    }

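    /// Records the outcome of a proxied request, updates the sliding window for
    /// the target, and escalates to the active checker if the failure rate
    /// exceeds the configured threshold.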
    pub async fn record_outcome(&self, target: &str, success: bool, error: Option<&str>) {
        trace!(
            target = %target,
            success = success,
            error = ?error,
            "Recording request outcome"
        );

        if let Some(err_msg) = error {
            self.last_errors
                .write()
                .await
                .insert(target.to_string(), err_msg.to_string());
        } else if success {
            self.last_errors.write().await.remove(target);
        }

        let mut outcomes = self.outcomes.write().await;
        let target_outcomes = outcomes
            .entry(target.to_string())
            .or_insert_with(|| Vec::with_capacity(self.window_size));

        // Evict the oldest sample once the window is full, then record the new one.
        if target_outcomes.len() >= self.window_size {
            target_outcomes.remove(0);
        }
        target_outcomes.push(success);

        let failures = target_outcomes.iter().filter(|&&s| !s).count();
        let failure_rate = failures as f64 / target_outcomes.len() as f64;

        trace!(
            target = %target,
            failure_rate = format!("{:.2}", failure_rate),
            window_samples = target_outcomes.len(),
            failures = failures,
            "Updated failure rate"
        );

        if failure_rate > self.failure_rate_threshold {
            warn!(
                target = %target,
                failure_rate = format!("{:.2}", failure_rate * 100.0),
                threshold = format!("{:.2}", self.failure_rate_threshold * 100.0),
                window_samples = target_outcomes.len(),
                "Failure rate exceeds threshold"
            );
            if let Some(ref checker) = self.active_checker {
                checker
                    .mark_unhealthy(
                        target,
                        format!(
                            "Failure rate {:.2}% exceeds threshold",
                            failure_rate * 100.0
                        ),
                    )
                    .await;
            }
        }
    }

    /// Returns the current failure rate for a target, if any outcomes were recorded.
    pub async fn get_failure_rate(&self, target: &str) -> Option<f64> {
        let outcomes = self.outcomes.read().await;
        outcomes.get(target).map(|target_outcomes| {
            let failures = target_outcomes.iter().filter(|&&s| !s).count();
            failures as f64 / target_outcomes.len() as f64
        })
    }

    /// Returns the most recently recorded error for a target, if any.
    pub async fn get_last_error(&self, target: &str) -> Option<String> {
        self.last_errors.read().await.get(target).cloned()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_health_status() {
        let status = TargetHealthInfo::new();
        assert!(status.healthy);
        assert_eq!(status.health_score(), 1.0);
        assert!(!status.is_degraded());
    }

    #[tokio::test]
    async fn test_passive_health_checker() {
        let checker = PassiveHealthChecker::new(0.5, 10, None);

        // 5 successes + 3 failures => failure rate of 3/8 = 0.375.
        for _ in 0..5 {
            checker.record_outcome("target1", true, None).await;
        }
        for _ in 0..3 {
            checker.record_outcome("target1", false, Some("HTTP 503")).await;
        }

        let failure_rate = checker.get_failure_rate("target1").await.unwrap();
        assert!(failure_rate > 0.3 && failure_rate < 0.4);
    }

    #[test]
    fn test_parse_http_status() {
        let response = "HTTP/1.1 200 OK\r\n";
        assert_eq!(parse_http_status(response), Some(200));

        let response = "HTTP/1.1 404 Not Found\r\n";
        assert_eq!(parse_http_status(response), Some(404));

        let response = "Invalid response";
        assert_eq!(parse_http_status(response), None);
    }
}