//! Active and passive health checking for upstream targets.

use async_trait::async_trait;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::RwLock;
use tokio::time;
use tracing::{debug, info, warn};

use sentinel_common::{errors::SentinelResult, types::HealthCheckType};
use sentinel_config::{HealthCheck as HealthCheckConfig, UpstreamTarget};

/// Actively probes upstream targets on a fixed interval and tracks per-target health.
pub struct ActiveHealthChecker {
    /// Health check configuration (check type, interval, timeout, thresholds).
    config: HealthCheckConfig,
    /// Protocol-specific check implementation (HTTP, TCP, or gRPC).
    checker: Arc<dyn HealthCheckImpl>,
    /// Current health state for each target, keyed by address.
    health_status: Arc<RwLock<HashMap<String, TargetHealthInfo>>>,
    /// Join handles for the spawned per-target check tasks.
    check_handles: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
    /// Broadcast channel used to signal check tasks to shut down.
    shutdown_tx: Arc<tokio::sync::broadcast::Sender<()>>,
}

/// Health state tracked for a single upstream target.
#[derive(Debug, Clone)]
pub struct TargetHealthInfo {
    /// Whether the target is currently considered healthy.
    pub healthy: bool,
    /// Consecutive successful checks since the last failure.
    pub consecutive_successes: u32,
    /// Consecutive failed checks since the last success.
    pub consecutive_failures: u32,
    /// When the target was last checked.
    pub last_check: Instant,
    /// When the last successful check completed, if any.
    pub last_success: Option<Instant>,
    /// Error message from the most recent failure, if any.
    pub last_error: Option<String>,
    /// Total number of checks performed.
    pub total_checks: u64,
    /// Total number of successful checks.
    pub total_successes: u64,
    /// Running average response time of successful checks, in milliseconds.
    pub avg_response_time: f64,
}

/// Protocol-specific health check implementation.
#[async_trait]
trait HealthCheckImpl: Send + Sync {
    /// Probes `target`, returning the response time on success.
    async fn check(&self, target: &str) -> Result<Duration, String>;

    /// Human-readable name of the check protocol.
    fn check_type(&self) -> &str;
}

/// HTTP health check: sends a GET request and verifies the response status code.
struct HttpHealthCheck {
    path: String,
    expected_status: u16,
    host: Option<String>,
    timeout: Duration,
}

/// TCP health check: verifies that a connection can be established.
struct TcpHealthCheck {
    timeout: Duration,
}

/// gRPC health check. Currently falls back to a TCP connectivity probe; the
/// full gRPC health protocol would require a gRPC client such as tonic.
struct GrpcHealthCheck {
    service: String,
    timeout: Duration,
}

impl ActiveHealthChecker {
    /// Builds a checker with the protocol implementation selected by the configuration.
    pub fn new(config: HealthCheckConfig) -> Self {
        let checker: Arc<dyn HealthCheckImpl> = match &config.check_type {
            HealthCheckType::Http {
                path,
                expected_status,
                host,
            } => Arc::new(HttpHealthCheck {
                path: path.clone(),
                expected_status: *expected_status,
                host: host.clone(),
                timeout: Duration::from_secs(config.timeout_secs),
            }),
            HealthCheckType::Tcp => Arc::new(TcpHealthCheck {
                timeout: Duration::from_secs(config.timeout_secs),
            }),
            HealthCheckType::Grpc { service } => Arc::new(GrpcHealthCheck {
                service: service.clone(),
                timeout: Duration::from_secs(config.timeout_secs),
            }),
        };

        let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);

        Self {
            config,
            checker,
            health_status: Arc::new(RwLock::new(HashMap::new())),
            check_handles: Arc::new(RwLock::new(Vec::new())),
            shutdown_tx: Arc::new(shutdown_tx),
        }
    }

    /// Registers the given targets and spawns a periodic check task for each one.
    pub async fn start(&self, targets: &[UpstreamTarget]) -> SentinelResult<()> {
        let mut handles = self.check_handles.write().await;

        for target in targets {
            let address = target.address.clone();

            // Targets start out healthy until checks prove otherwise.
            self.health_status
                .write()
                .await
                .insert(address.clone(), TargetHealthInfo::new());

            let handle = self.spawn_check_task(address);
            handles.push(handle);
        }

        info!(
            "Started health checking for {} targets, interval: {}s",
            targets.len(),
            self.config.interval_secs
        );

        Ok(())
    }

    /// Spawns a background task that checks `target` on every interval tick and
    /// updates its health state until a shutdown signal is received.
    fn spawn_check_task(&self, target: String) -> tokio::task::JoinHandle<()> {
        let interval = Duration::from_secs(self.config.interval_secs);
        let checker = Arc::clone(&self.checker);
        let health_status = Arc::clone(&self.health_status);
        let healthy_threshold = self.config.healthy_threshold;
        let unhealthy_threshold = self.config.unhealthy_threshold;
        let mut shutdown_rx = self.shutdown_tx.subscribe();

        tokio::spawn(async move {
            let mut interval_timer = time::interval(interval);
            interval_timer.set_missed_tick_behavior(time::MissedTickBehavior::Skip);

            loop {
                tokio::select! {
                    _ = interval_timer.tick() => {
                        let result = checker.check(&target).await;

                        let mut status_map = health_status.write().await;
                        if let Some(status) = status_map.get_mut(&target) {
                            status.last_check = Instant::now();
                            status.total_checks += 1;

                            match result {
                                Ok(response_time) => {
                                    status.consecutive_successes += 1;
                                    status.consecutive_failures = 0;
                                    status.last_success = Some(Instant::now());
                                    status.last_error = None;
                                    status.total_successes += 1;

                                    // Incremental mean over all successful checks.
                                    let response_ms = response_time.as_millis() as f64;
                                    status.avg_response_time =
                                        (status.avg_response_time * (status.total_successes - 1) as f64
                                            + response_ms) / status.total_successes as f64;

                                    if !status.healthy && status.consecutive_successes >= healthy_threshold {
                                        status.healthy = true;
                                        info!(
                                            target = %target,
                                            consecutive_successes = status.consecutive_successes,
                                            "Target marked as healthy"
                                        );
                                    }

                                    debug!(
                                        target = %target,
                                        response_time_ms = response_ms,
                                        "Health check succeeded"
                                    );
                                }
                                Err(error) => {
                                    status.consecutive_failures += 1;
                                    status.consecutive_successes = 0;
                                    status.last_error = Some(error.clone());

                                    if status.healthy && status.consecutive_failures >= unhealthy_threshold {
                                        status.healthy = false;
                                        warn!(
                                            target = %target,
                                            consecutive_failures = status.consecutive_failures,
                                            error = %error,
                                            "Target marked as unhealthy"
                                        );
                                    }

                                    debug!(
                                        target = %target,
                                        error = %error,
                                        "Health check failed"
                                    );
                                }
                            }
                        }
                    }
                    _ = shutdown_rx.recv() => {
                        info!(target = %target, "Stopping health check task");
                        break;
                    }
                }
            }
        })
    }

    /// Signals all check tasks to stop and waits for them to finish.
    pub async fn stop(&self) {
        info!("Stopping health checker");

        let _ = self.shutdown_tx.send(());

        let mut handles = self.check_handles.write().await;
        for handle in handles.drain(..) {
            let _ = handle.await;
        }
    }

    /// Returns the health info for a single target, if it is tracked.
    pub async fn get_status(&self, target: &str) -> Option<TargetHealthInfo> {
        self.health_status.read().await.get(target).cloned()
    }

    /// Returns a snapshot of the health info for all tracked targets.
    pub async fn get_all_statuses(&self) -> HashMap<String, TargetHealthInfo> {
        self.health_status.read().await.clone()
    }

    /// Returns whether the target is currently healthy; unknown targets count as unhealthy.
    pub async fn is_healthy(&self, target: &str) -> bool {
        self.health_status
            .read()
            .await
            .get(target)
            .map(|s| s.healthy)
            .unwrap_or(false)
    }

    /// Returns the addresses of all currently healthy targets.
    pub async fn get_healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(target, status)| {
                if status.healthy {
                    Some(target.clone())
                } else {
                    None
                }
            })
            .collect()
    }

    /// Marks a target unhealthy immediately, bypassing the failure threshold.
    /// Used by passive health checking when live traffic indicates a problem.
    pub async fn mark_unhealthy(&self, target: &str, reason: String) {
        if let Some(status) = self.health_status.write().await.get_mut(target) {
            if status.healthy {
                status.healthy = false;
                status.consecutive_failures = self.config.unhealthy_threshold;
                status.consecutive_successes = 0;
                status.last_error = Some(reason.clone());
                warn!(
                    target = %target,
                    reason = %reason,
                    "Target marked unhealthy by passive check"
                );
            }
        }
    }
}

impl TargetHealthInfo {
    /// Creates a fresh entry; targets start out healthy.
    pub fn new() -> Self {
        Self {
            healthy: true,
            consecutive_successes: 0,
            consecutive_failures: 0,
            last_check: Instant::now(),
            last_success: Some(Instant::now()),
            last_error: None,
            total_checks: 0,
            total_successes: 0,
            avg_response_time: 0.0,
        }
    }

    /// Lifetime success ratio in [0.0, 1.0]; 1.0 before any checks have run.
    pub fn health_score(&self) -> f64 {
        if self.total_checks == 0 {
            return 1.0;
        }
        self.total_successes as f64 / self.total_checks as f64
    }

    /// True if the target is still considered healthy but has recent failures.
    pub fn is_degraded(&self) -> bool {
        self.healthy && self.consecutive_failures > 0
    }
}

#[async_trait]
impl HealthCheckImpl for HttpHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        let mut stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        // Use the configured Host header if present, otherwise the target address.
        let host = self.host.as_deref().unwrap_or(target);
        let request = format!(
            "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Sentinel-HealthCheck/1.0\r\nConnection: close\r\n\r\n",
            self.path, host
        );

        stream
            .write_all(request.as_bytes())
            .await
            .map_err(|e| format!("Failed to send request: {}", e))?;

        // Only the status line is needed, so a single small read is enough.
        let mut response = vec![0u8; 1024];
        let n = stream
            .read(&mut response)
            .await
            .map_err(|e| format!("Failed to read response: {}", e))?;

        if n == 0 {
            return Err("Empty response".to_string());
        }

        let response_str = String::from_utf8_lossy(&response[..n]);
        let status_code = parse_http_status(&response_str)
            .ok_or_else(|| "Failed to parse HTTP status".to_string())?;

        if status_code == self.expected_status {
            Ok(start.elapsed())
        } else {
            Err(format!(
                "Unexpected status code: {} (expected {})",
                status_code, self.expected_status
            ))
        }
    }

    fn check_type(&self) -> &str {
        "HTTP"
    }
}

#[async_trait]
impl HealthCheckImpl for TcpHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "TCP"
    }
}

#[async_trait]
impl HealthCheckImpl for GrpcHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        // A full gRPC health check would call the grpc.health.v1 Check RPC;
        // for now only TCP connectivity and writability are verified.
        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        stream
            .writable()
            .await
            .map_err(|e| format!("Connection not writable: {}", e))?;

        debug!(
            target = %target,
            service = %self.service,
            "gRPC health check using TCP fallback (full gRPC protocol requires tonic)"
        );

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "gRPC"
    }
}

/// Extracts the status code from the first line of an HTTP response,
/// e.g. "HTTP/1.1 200 OK" yields 200.
fn parse_http_status(response: &str) -> Option<u16> {
    response
        .lines()
        .next()?
        .split_whitespace()
        .nth(1)?
        .parse()
        .ok()
}

/// Derives target health from the outcomes of real proxied requests and
/// reports targets to the active checker when their failure rate is too high.
pub struct PassiveHealthChecker {
    /// Failure rate (0.0 to 1.0) above which a target is marked unhealthy.
    failure_rate_threshold: f64,
    /// Number of recent outcomes kept per target (sliding window).
    window_size: usize,
    /// Recent request outcomes per target; `true` means success.
    outcomes: Arc<RwLock<HashMap<String, Vec<bool>>>>,
    /// Active checker to notify when a target should be marked unhealthy.
    active_checker: Option<Arc<ActiveHealthChecker>>,
}

impl PassiveHealthChecker {
    pub fn new(
        failure_rate_threshold: f64,
        window_size: usize,
        active_checker: Option<Arc<ActiveHealthChecker>>,
    ) -> Self {
        Self {
            failure_rate_threshold,
            window_size,
            outcomes: Arc::new(RwLock::new(HashMap::new())),
            active_checker,
        }
    }

    /// Records the outcome of a proxied request; if the failure rate over the
    /// sliding window exceeds the threshold, the target is marked unhealthy.
    pub async fn record_outcome(&self, target: &str, success: bool) {
        let mut outcomes = self.outcomes.write().await;
        let target_outcomes = outcomes
            .entry(target.to_string())
            .or_insert_with(|| Vec::with_capacity(self.window_size));

        // Keep only the most recent `window_size` outcomes.
        if target_outcomes.len() >= self.window_size {
            target_outcomes.remove(0);
        }
        target_outcomes.push(success);

        let failures = target_outcomes.iter().filter(|&&s| !s).count();
        let failure_rate = failures as f64 / target_outcomes.len() as f64;

        if failure_rate > self.failure_rate_threshold {
            if let Some(ref checker) = self.active_checker {
                checker
                    .mark_unhealthy(
                        target,
                        format!(
                            "Failure rate {:.2}% exceeds threshold",
                            failure_rate * 100.0
                        ),
                    )
                    .await;
            }
        }
    }

    /// Returns the failure rate over the current window, if any outcomes have
    /// been recorded for the target.
    pub async fn get_failure_rate(&self, target: &str) -> Option<f64> {
        let outcomes = self.outcomes.read().await;
        outcomes.get(target).map(|target_outcomes| {
            let failures = target_outcomes.iter().filter(|&&s| !s).count();
            failures as f64 / target_outcomes.len() as f64
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use sentinel_config::HealthCheck as HealthCheckConfig;

    #[tokio::test]
    async fn test_health_status() {
        let status = TargetHealthInfo::new();
        assert!(status.healthy);
        assert_eq!(status.health_score(), 1.0);
        assert!(!status.is_degraded());
    }
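
    // Added example (not in the original tests): a minimal sketch of the
    // health_score() arithmetic and the is_degraded() condition, driven by
    // mutating the public fields directly.
    #[test]
    fn test_health_score_and_degraded() {
        let mut status = TargetHealthInfo::new();

        // 7 successes out of 10 checks gives a score of 0.7.
        status.total_checks = 10;
        status.total_successes = 7;
        assert!((status.health_score() - 0.7).abs() < 1e-9);

        // Still healthy but with recent failures counts as degraded.
        status.consecutive_failures = 2;
        assert!(status.is_degraded());

        // Once marked unhealthy, the target is no longer merely "degraded".
        status.healthy = false;
        assert!(!status.is_degraded());
    }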

    #[tokio::test]
    async fn test_passive_health_checker() {
        let checker = PassiveHealthChecker::new(0.5, 10, None);

        for _ in 0..5 {
            checker.record_outcome("target1", true).await;
        }
        for _ in 0..3 {
            checker.record_outcome("target1", false).await;
        }

        let failure_rate = checker.get_failure_rate("target1").await.unwrap();
        assert!(failure_rate > 0.3 && failure_rate < 0.4);
    }
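
    // Added example (not in the original tests): a minimal sketch of the
    // sliding-window behaviour; with a window of 4, older outcomes are evicted
    // and only the most recent ones are counted.
    #[tokio::test]
    async fn test_passive_window_eviction() {
        let checker = PassiveHealthChecker::new(0.9, 4, None);

        // Fill the window with failures, then push them out with successes.
        for _ in 0..4 {
            checker.record_outcome("target1", false).await;
        }
        assert_eq!(checker.get_failure_rate("target1").await, Some(1.0));

        for _ in 0..4 {
            checker.record_outcome("target1", true).await;
        }
        assert_eq!(checker.get_failure_rate("target1").await, Some(0.0));
    }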

    #[test]
    fn test_parse_http_status() {
        let response = "HTTP/1.1 200 OK\r\n";
        assert_eq!(parse_http_status(response), Some(200));

        let response = "HTTP/1.1 404 Not Found\r\n";
        assert_eq!(parse_http_status(response), Some(404));

        let response = "Invalid response";
        assert_eq!(parse_http_status(response), None);
    }
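
    // Added examples (not in the original tests): minimal sketches of the TCP
    // and HTTP checkers against a local listener bound to an OS-assigned port.
    #[tokio::test]
    async fn test_tcp_health_check_against_local_listener() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let check = TcpHealthCheck {
            timeout: Duration::from_secs(1),
        };
        assert_eq!(check.check_type(), "TCP");

        // Connecting to the live listener succeeds and yields a duration.
        assert!(check.check(&addr.to_string()).await.is_ok());

        // An address that cannot be parsed is reported as an error.
        assert!(check.check("not-an-address").await.is_err());
    }

    #[tokio::test]
    async fn test_http_health_check_against_local_listener() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // Accept one connection and answer with a canned 200 response.
        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut buf = vec![0u8; 1024];
            let _ = socket.read(&mut buf).await;
            let _ = socket
                .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
                .await;
        });

        let check = HttpHealthCheck {
            path: "/health".to_string(),
            expected_status: 200,
            host: None,
            timeout: Duration::from_secs(1),
        };
        assert!(check.check(&addr.to_string()).await.is_ok());
    }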
}