//! Active and passive health checking for upstream targets, plus warmth
//! (cold-start) tracking for inference backends.

use async_trait::async_trait;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::RwLock;
use tokio::time;
use tracing::{debug, info, trace, warn};

use sentinel_common::{errors::SentinelResult, types::HealthCheckType};
use sentinel_config::{HealthCheck as HealthCheckConfig, UpstreamTarget};

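/// Actively probes upstream targets on a fixed interval and tracks their health.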
pub struct ActiveHealthChecker {
    /// Health check configuration (type, interval, thresholds).
    config: HealthCheckConfig,
    /// Protocol-specific check implementation.
    checker: Arc<dyn HealthCheckImpl>,
    /// Per-target health state, keyed by target address.
    health_status: Arc<RwLock<HashMap<String, TargetHealthInfo>>>,
    /// Handles of the spawned per-target check tasks.
    check_handles: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
    /// Broadcast channel used to signal shutdown to check tasks.
    shutdown_tx: Arc<tokio::sync::broadcast::Sender<()>>,
}

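/// Snapshot of a single target's health as observed by active checks.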
#[derive(Debug, Clone)]
pub struct TargetHealthInfo {
    /// Whether the target is currently considered healthy.
    pub healthy: bool,
    /// Consecutive successful checks since the last failure.
    pub consecutive_successes: u32,
    /// Consecutive failed checks since the last success.
    pub consecutive_failures: u32,
    /// When the target was last checked.
    pub last_check: Instant,
    /// When the target last passed a check, if ever.
    pub last_success: Option<Instant>,
    /// Error message from the most recent failed check.
    pub last_error: Option<String>,
    /// Total number of checks performed.
    pub total_checks: u64,
    /// Total number of successful checks.
    pub total_successes: u64,
    /// Running average response time in milliseconds.
    pub avg_response_time: f64,
}

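/// Protocol-specific health check behavior.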
#[async_trait]
trait HealthCheckImpl: Send + Sync {
    /// Run one check against `target`, returning the measured response time on success.
    async fn check(&self, target: &str) -> Result<Duration, String>;

    /// Human-readable name of the check type, used in logs.
    fn check_type(&self) -> &str;
}

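/// HTTP health check: connects, sends a GET request, and verifies the status code.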
struct HttpHealthCheck {
    path: String,
    expected_status: u16,
    host: Option<String>,
    timeout: Duration,
}

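/// TCP health check: verifies that a connection can be established.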
struct TcpHealthCheck {
    timeout: Duration,
}

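/// gRPC health check (currently a TCP-level connectivity check; see `check`).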
struct GrpcHealthCheck {
    service: String,
    timeout: Duration,
}

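/// Inference health check: queries a models endpoint and verifies that the expected models are listed.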
struct InferenceHealthCheck {
    endpoint: String,
    expected_models: Vec<String>,
    timeout: Duration,
}

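/// Readiness probe that issues a small inference request and validates the response.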
struct InferenceProbeCheck {
    config: sentinel_common::InferenceProbeConfig,
    timeout: Duration,
}

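/// Checks per-model status endpoints for an expected status value.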
struct ModelStatusCheck {
    config: sentinel_common::ModelStatusConfig,
    timeout: Duration,
}

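/// Checks reported queue depth against degraded and unhealthy thresholds.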
struct QueueDepthCheck {
    config: sentinel_common::QueueDepthConfig,
    models_endpoint: String,
    timeout: Duration,
}

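/// Combines the base inference check with optional readiness sub-checks.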
struct CompositeInferenceHealthCheck {
    base_check: InferenceHealthCheck,
    inference_probe: Option<InferenceProbeCheck>,
    model_status: Option<ModelStatusCheck>,
    queue_depth: Option<QueueDepthCheck>,
}

impl ActiveHealthChecker {
    pub fn new(config: HealthCheckConfig) -> Self {
        debug!(
            check_type = ?config.check_type,
            interval_secs = config.interval_secs,
            timeout_secs = config.timeout_secs,
            healthy_threshold = config.healthy_threshold,
            unhealthy_threshold = config.unhealthy_threshold,
            "Creating active health checker"
        );

        let checker: Arc<dyn HealthCheckImpl> = match &config.check_type {
            HealthCheckType::Http {
                path,
                expected_status,
                host,
            } => {
                trace!(
                    path = %path,
                    expected_status = expected_status,
                    host = host.as_deref().unwrap_or("(default)"),
                    "Configuring HTTP health check"
                );
                Arc::new(HttpHealthCheck {
                    path: path.clone(),
                    expected_status: *expected_status,
                    host: host.clone(),
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
            HealthCheckType::Tcp => {
                trace!("Configuring TCP health check");
                Arc::new(TcpHealthCheck {
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
            HealthCheckType::Grpc { service } => {
                trace!(
                    service = %service,
                    "Configuring gRPC health check"
                );
                Arc::new(GrpcHealthCheck {
                    service: service.clone(),
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
            HealthCheckType::Inference {
                endpoint,
                expected_models,
                readiness,
            } => {
                trace!(
                    endpoint = %endpoint,
                    expected_models = ?expected_models,
                    has_readiness = readiness.is_some(),
                    "Configuring inference health check"
                );

                let base_timeout = Duration::from_secs(config.timeout_secs);
                let base_check = InferenceHealthCheck {
                    endpoint: endpoint.clone(),
                    expected_models: expected_models.clone(),
                    timeout: base_timeout,
                };

                if let Some(ref readiness_config) = readiness {
                    let inference_probe = readiness_config.inference_probe.as_ref().map(|cfg| {
                        InferenceProbeCheck {
                            config: cfg.clone(),
                            timeout: Duration::from_secs(cfg.timeout_secs),
                        }
                    });

                    let model_status = readiness_config.model_status.as_ref().map(|cfg| {
                        ModelStatusCheck {
                            config: cfg.clone(),
                            timeout: Duration::from_secs(cfg.timeout_secs),
                        }
                    });

                    let queue_depth = readiness_config.queue_depth.as_ref().map(|cfg| {
                        QueueDepthCheck {
                            config: cfg.clone(),
                            models_endpoint: endpoint.clone(),
                            timeout: Duration::from_secs(cfg.timeout_secs),
                        }
                    });

                    Arc::new(CompositeInferenceHealthCheck {
                        base_check,
                        inference_probe,
                        model_status,
                        queue_depth,
                    })
                } else {
                    Arc::new(base_check)
                }
            }
        };

        let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);

        Self {
            config,
            checker,
            health_status: Arc::new(RwLock::new(HashMap::new())),
            check_handles: Arc::new(RwLock::new(Vec::new())),
            shutdown_tx: Arc::new(shutdown_tx),
        }
    }

    pub async fn start(&self, targets: &[UpstreamTarget]) -> SentinelResult<()> {
        info!(
            target_count = targets.len(),
            interval_secs = self.config.interval_secs,
            check_type = self.checker.check_type(),
            "Starting health checking"
        );

        let mut handles = self.check_handles.write().await;

        for target in targets {
            let address = target.address.clone();

            trace!(
                target = %address,
                "Initializing health status for target"
            );

            self.health_status
                .write()
                .await
                .insert(address.clone(), TargetHealthInfo::new());

            debug!(
                target = %address,
                "Spawning health check task"
            );
            let handle = self.spawn_check_task(address);
            handles.push(handle);
        }

        info!(
            target_count = targets.len(),
            interval_secs = self.config.interval_secs,
            healthy_threshold = self.config.healthy_threshold,
            unhealthy_threshold = self.config.unhealthy_threshold,
            "Health checking started successfully"
        );

        Ok(())
    }

    fn spawn_check_task(&self, target: String) -> tokio::task::JoinHandle<()> {
        let interval = Duration::from_secs(self.config.interval_secs);
        let checker = Arc::clone(&self.checker);
        let health_status = Arc::clone(&self.health_status);
        let healthy_threshold = self.config.healthy_threshold;
        let unhealthy_threshold = self.config.unhealthy_threshold;
        let check_type = self.checker.check_type().to_string();
        let mut shutdown_rx = self.shutdown_tx.subscribe();

        tokio::spawn(async move {
            let mut interval_timer = time::interval(interval);
            interval_timer.set_missed_tick_behavior(time::MissedTickBehavior::Skip);

            trace!(
                target = %target,
                check_type = %check_type,
                interval_ms = interval.as_millis(),
                "Health check task started"
            );

            loop {
                tokio::select! {
                    _ = interval_timer.tick() => {
                        trace!(
                            target = %target,
                            check_type = %check_type,
                            "Performing health check"
                        );
                        let start = Instant::now();
                        let result = checker.check(&target).await;
                        let check_duration = start.elapsed();

                        let mut status_map = health_status.write().await;
                        if let Some(status) = status_map.get_mut(&target) {
                            status.last_check = Instant::now();
                            status.total_checks += 1;

                            match result {
                                Ok(response_time) => {
                                    status.consecutive_successes += 1;
                                    status.consecutive_failures = 0;
                                    status.last_success = Some(Instant::now());
                                    status.last_error = None;
                                    status.total_successes += 1;

                                    let response_ms = response_time.as_millis() as f64;
                                    status.avg_response_time =
                                        (status.avg_response_time * (status.total_successes - 1) as f64
                                            + response_ms)
                                            / status.total_successes as f64;

                                    if !status.healthy && status.consecutive_successes >= healthy_threshold {
                                        status.healthy = true;
                                        info!(
                                            target = %target,
                                            consecutive_successes = status.consecutive_successes,
                                            avg_response_ms = format!("{:.2}", status.avg_response_time),
                                            total_checks = status.total_checks,
                                            "Target marked as healthy"
                                        );
                                    }

                                    trace!(
                                        target = %target,
                                        response_time_ms = response_ms,
                                        check_duration_ms = check_duration.as_millis(),
                                        consecutive_successes = status.consecutive_successes,
                                        health_score = format!("{:.2}", status.health_score()),
                                        "Health check succeeded"
                                    );
                                }
                                Err(error) => {
                                    status.consecutive_failures += 1;
                                    status.consecutive_successes = 0;
                                    status.last_error = Some(error.clone());

                                    if status.healthy && status.consecutive_failures >= unhealthy_threshold {
                                        status.healthy = false;
                                        warn!(
                                            target = %target,
                                            consecutive_failures = status.consecutive_failures,
                                            error = %error,
                                            total_checks = status.total_checks,
                                            health_score = format!("{:.2}", status.health_score()),
                                            "Target marked as unhealthy"
                                        );
                                    } else {
                                        debug!(
                                            target = %target,
                                            error = %error,
                                            consecutive_failures = status.consecutive_failures,
                                            unhealthy_threshold = unhealthy_threshold,
                                            "Health check failed"
                                        );
                                    }
                                }
                            }
                        }
                    }
                    _ = shutdown_rx.recv() => {
                        info!(target = %target, "Stopping health check task");
                        break;
                    }
                }
            }

            debug!(target = %target, "Health check task stopped");
        })
    }

    pub async fn stop(&self) {
        let task_count = self.check_handles.read().await.len();
        info!(task_count = task_count, "Stopping health checker");

        let _ = self.shutdown_tx.send(());

        let mut handles = self.check_handles.write().await;
        for handle in handles.drain(..) {
            let _ = handle.await;
        }

        info!("Health checker stopped successfully");
    }

    pub async fn get_status(&self, target: &str) -> Option<TargetHealthInfo> {
        self.health_status.read().await.get(target).cloned()
    }

    pub async fn get_all_statuses(&self) -> HashMap<String, TargetHealthInfo> {
        self.health_status.read().await.clone()
    }

    pub async fn is_healthy(&self, target: &str) -> bool {
        self.health_status
            .read()
            .await
            .get(target)
            .map(|s| s.healthy)
            .unwrap_or(false)
    }

    pub async fn get_healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(target, status)| {
                if status.healthy {
                    Some(target.clone())
                } else {
                    None
                }
            })
            .collect()
    }

    pub async fn mark_unhealthy(&self, target: &str, reason: String) {
        if let Some(status) = self.health_status.write().await.get_mut(target) {
            if status.healthy {
                status.healthy = false;
                status.consecutive_failures = self.config.unhealthy_threshold;
                status.consecutive_successes = 0;
                status.last_error = Some(reason.clone());
                warn!(
                    target = %target,
                    reason = %reason,
                    "Target marked unhealthy by passive check"
                );
            }
        }
    }
}

impl Default for TargetHealthInfo {
    fn default() -> Self {
        Self::new()
    }
}

impl TargetHealthInfo {
    pub fn new() -> Self {
        Self {
            healthy: true,
            consecutive_successes: 0,
            consecutive_failures: 0,
            last_check: Instant::now(),
            last_success: Some(Instant::now()),
            last_error: None,
            total_checks: 0,
            total_successes: 0,
            avg_response_time: 0.0,
        }
    }

    /// Success ratio over all checks; defaults to 1.0 before any check has run.
    pub fn health_score(&self) -> f64 {
        if self.total_checks == 0 {
            return 1.0;
        }
        self.total_successes as f64 / self.total_checks as f64
    }

    /// True when the target is still healthy but has started failing checks.
    pub fn is_degraded(&self) -> bool {
        self.healthy && self.consecutive_failures > 0
    }
}

#[async_trait]
impl HealthCheckImpl for HttpHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        let host = self.host.as_deref().unwrap_or(target);
        let request = format!(
            "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Sentinel-HealthCheck/1.0\r\nConnection: close\r\n\r\n",
            self.path, host
        );

        let mut stream = stream;
        stream
            .write_all(request.as_bytes())
            .await
            .map_err(|e| format!("Failed to send request: {}", e))?;

        let mut response = vec![0u8; 1024];
        let n = stream
            .read(&mut response)
            .await
            .map_err(|e| format!("Failed to read response: {}", e))?;

        if n == 0 {
            return Err("Empty response".to_string());
        }

        let response_str = String::from_utf8_lossy(&response[..n]);
        let status_code = parse_http_status(&response_str)
            .ok_or_else(|| "Failed to parse HTTP status".to_string())?;

        if status_code == self.expected_status {
            Ok(start.elapsed())
        } else {
            Err(format!(
                "Unexpected status code: {} (expected {})",
                status_code, self.expected_status
            ))
        }
    }

    fn check_type(&self) -> &str {
        "HTTP"
    }
}

#[async_trait]
impl HealthCheckImpl for TcpHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "TCP"
    }
}

#[async_trait]
impl HealthCheckImpl for GrpcHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        stream
            .writable()
            .await
            .map_err(|e| format!("Connection not writable: {}", e))?;

        debug!(
            target = %target,
            service = %self.service,
            "gRPC health check using TCP fallback (full gRPC protocol requires tonic)"
        );

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "gRPC"
    }
}

#[async_trait]
impl HealthCheckImpl for InferenceHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        let request = format!(
            "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Sentinel-HealthCheck/1.0\r\nAccept: application/json\r\nConnection: close\r\n\r\n",
            self.endpoint, target
        );

        let mut stream = stream;
        stream
            .write_all(request.as_bytes())
            .await
            .map_err(|e| format!("Failed to send request: {}", e))?;

        let mut response = vec![0u8; 8192];
        let n = stream
            .read(&mut response)
            .await
            .map_err(|e| format!("Failed to read response: {}", e))?;

        if n == 0 {
            return Err("Empty response".to_string());
        }

        let response_str = String::from_utf8_lossy(&response[..n]);
        let status_code = parse_http_status(&response_str)
            .ok_or_else(|| "Failed to parse HTTP status".to_string())?;

        if status_code != 200 {
            return Err(format!(
                "Unexpected status code: {} (expected 200)",
                status_code
            ));
        }

        if !self.expected_models.is_empty() {
            if let Some(body_start) = response_str.find("\r\n\r\n") {
                let body = &response_str[body_start + 4..];

                for model in &self.expected_models {
                    if !body.contains(model) {
                        return Err(format!(
                            "Expected model '{}' not found in response",
                            model
                        ));
                    }
                }

                debug!(
                    target = %target,
                    endpoint = %self.endpoint,
                    expected_models = ?self.expected_models,
                    "All expected models found in inference health check"
                );
            } else {
                return Err("Could not find response body".to_string());
            }
        }

        trace!(
            target = %target,
            endpoint = %self.endpoint,
            response_time_ms = start.elapsed().as_millis(),
            "Inference health check passed"
        );

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "Inference"
    }
}

#[async_trait]
impl HealthCheckImpl for InferenceProbeCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        let body = format!(
            r#"{{"model":"{}","prompt":"{}","max_tokens":{}}}"#,
            self.config.model, self.config.prompt, self.config.max_tokens
        );

        let request = format!(
            "POST {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Sentinel-HealthCheck/1.0\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
            self.config.endpoint,
            target,
            body.len(),
            body
        );

        let mut stream = stream;
        stream
            .write_all(request.as_bytes())
            .await
            .map_err(|e| format!("Failed to send request: {}", e))?;

        let mut response = vec![0u8; 16384];
        let n = stream
            .read(&mut response)
            .await
            .map_err(|e| format!("Failed to read response: {}", e))?;

        if n == 0 {
            return Err("Empty response".to_string());
        }

        let latency = start.elapsed();

        let response_str = String::from_utf8_lossy(&response[..n]);
        let status_code = parse_http_status(&response_str)
            .ok_or_else(|| "Failed to parse HTTP status".to_string())?;

        if status_code != 200 {
            return Err(format!(
                "Inference probe failed: status {} (expected 200)",
                status_code
            ));
        }

        if let Some(body_start) = response_str.find("\r\n\r\n") {
            let body = &response_str[body_start + 4..];
            if !body.contains("\"choices\"") {
                return Err("Inference probe response missing 'choices' field".to_string());
            }
        }

        if let Some(max_ms) = self.config.max_latency_ms {
            if latency.as_millis() as u64 > max_ms {
                return Err(format!(
                    "Inference probe latency {}ms exceeds threshold {}ms",
                    latency.as_millis(),
                    max_ms
                ));
            }
        }

        trace!(
            target = %target,
            model = %self.config.model,
            latency_ms = latency.as_millis(),
            "Inference probe health check passed"
        );

        Ok(latency)
    }

    fn check_type(&self) -> &str {
        "InferenceProbe"
    }
}

#[async_trait]
impl HealthCheckImpl for ModelStatusCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        for model in &self.config.models {
            let endpoint = self.config.endpoint_pattern.replace("{model}", model);

            let stream = time::timeout(self.timeout, TcpStream::connect(addr))
                .await
                .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
                .map_err(|e| format!("Connection failed: {}", e))?;

            let request = format!(
                "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Sentinel-HealthCheck/1.0\r\nAccept: application/json\r\nConnection: close\r\n\r\n",
                endpoint, target
            );

            let mut stream = stream;
            stream
                .write_all(request.as_bytes())
                .await
                .map_err(|e| format!("Failed to send request: {}", e))?;

            let mut response = vec![0u8; 8192];
            let n = stream
                .read(&mut response)
                .await
                .map_err(|e| format!("Failed to read response: {}", e))?;

            if n == 0 {
                return Err(format!("Empty response for model '{}'", model));
            }

            let response_str = String::from_utf8_lossy(&response[..n]);
            let status_code = parse_http_status(&response_str)
                .ok_or_else(|| "Failed to parse HTTP status".to_string())?;

            if status_code != 200 {
                return Err(format!(
                    "Model '{}' status check failed: HTTP {}",
                    model, status_code
                ));
            }

            if let Some(body_start) = response_str.find("\r\n\r\n") {
                let body = &response_str[body_start + 4..];
                let status = extract_json_field(body, &self.config.status_field);

                match status {
                    Some(s) if s == self.config.expected_status => {
                        trace!(
                            target = %target,
                            model = %model,
                            status = %s,
                            "Model status check passed"
                        );
                    }
                    Some(s) => {
                        return Err(format!(
                            "Model '{}' status '{}' != expected '{}'",
                            model, s, self.config.expected_status
                        ));
                    }
                    None => {
                        return Err(format!(
                            "Model '{}' status field '{}' not found",
                            model, self.config.status_field
                        ));
                    }
                }
            }
        }

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "ModelStatus"
    }
}

#[async_trait]
impl HealthCheckImpl for QueueDepthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        let addr: SocketAddr = target
            .parse()
            .map_err(|e| format!("Invalid address: {}", e))?;

        let endpoint = self.config.endpoint.as_ref().unwrap_or(&self.models_endpoint);

        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
            .await
            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
            .map_err(|e| format!("Connection failed: {}", e))?;

        let request = format!(
            "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Sentinel-HealthCheck/1.0\r\nAccept: application/json\r\nConnection: close\r\n\r\n",
            endpoint, target
        );

        let mut stream = stream;
        stream
            .write_all(request.as_bytes())
            .await
            .map_err(|e| format!("Failed to send request: {}", e))?;

        let mut response = vec![0u8; 8192];
        let n = stream
            .read(&mut response)
            .await
            .map_err(|e| format!("Failed to read response: {}", e))?;

        if n == 0 {
            return Err("Empty response".to_string());
        }

        let response_str = String::from_utf8_lossy(&response[..n]);

        let queue_depth = if let Some(ref header_name) = self.config.header {
            extract_header_value(&response_str, header_name).and_then(|v| v.parse::<u64>().ok())
        } else if let Some(ref field) = self.config.body_field {
            if let Some(body_start) = response_str.find("\r\n\r\n") {
                let body = &response_str[body_start + 4..];
                extract_json_field(body, field).and_then(|v| v.parse::<u64>().ok())
            } else {
                None
            }
        } else {
            return Err("No queue depth source configured (header or body_field)".to_string());
        };

        let depth = queue_depth.ok_or_else(|| "Could not extract queue depth".to_string())?;

        if depth >= self.config.unhealthy_threshold {
            return Err(format!(
                "Queue depth {} exceeds unhealthy threshold {}",
                depth, self.config.unhealthy_threshold
            ));
        }

        if depth >= self.config.degraded_threshold {
            warn!(
                target = %target,
                queue_depth = depth,
                threshold = self.config.degraded_threshold,
                "Queue depth exceeds degraded threshold"
            );
        }

        trace!(
            target = %target,
            queue_depth = depth,
            "Queue depth check passed"
        );

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "QueueDepth"
    }
}

#[async_trait]
impl HealthCheckImpl for CompositeInferenceHealthCheck {
    async fn check(&self, target: &str) -> Result<Duration, String> {
        let start = Instant::now();

        self.base_check.check(target).await?;

        if let Some(ref probe) = self.inference_probe {
            probe.check(target).await?;
        }

        if let Some(ref status) = self.model_status {
            status.check(target).await?;
        }

        if let Some(ref queue) = self.queue_depth {
            queue.check(target).await?;
        }

        trace!(
            target = %target,
            total_time_ms = start.elapsed().as_millis(),
            "Composite inference health check passed"
        );

        Ok(start.elapsed())
    }

    fn check_type(&self) -> &str {
        "CompositeInference"
    }
}

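/// Extracts a header value (case-insensitive name) from a raw HTTP response.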
fn extract_header_value(response: &str, header_name: &str) -> Option<String> {
    let header_lower = header_name.to_lowercase();
    for line in response.lines() {
        if line.is_empty() || line == "\r" {
            break;
        }
        if let Some((name, value)) = line.split_once(':') {
            if name.trim().to_lowercase() == header_lower {
                return Some(value.trim().to_string());
            }
        }
    }
    None
}

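/// Looks up a dot-separated field path in a JSON body and returns it as a string.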
fn extract_json_field(body: &str, field_path: &str) -> Option<String> {
    let json: serde_json::Value = serde_json::from_str(body).ok()?;
    let parts: Vec<&str> = field_path.split('.').collect();
    let mut current = &json;

    for part in parts {
        current = current.get(part)?;
    }

    match current {
        serde_json::Value::String(s) => Some(s.clone()),
        serde_json::Value::Number(n) => Some(n.to_string()),
        serde_json::Value::Bool(b) => Some(b.to_string()),
        _ => None,
    }
}

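/// Parses the status code from an HTTP status line such as "HTTP/1.1 200 OK".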
fn parse_http_status(response: &str) -> Option<u16> {
    response
        .lines()
        .next()?
        .split_whitespace()
        .nth(1)?
        .parse()
        .ok()
}

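/// Tracks observed request outcomes and flags targets whose failure rate exceeds a threshold.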
pub struct PassiveHealthChecker {
    /// Failure rate (0.0..=1.0) above which a target is reported unhealthy.
    failure_rate_threshold: f64,
    /// Number of recent outcomes kept per target.
    window_size: usize,
    /// Sliding window of outcomes per target (true = success).
    outcomes: Arc<RwLock<HashMap<String, Vec<bool>>>>,
    /// Most recent error message per target.
    last_errors: Arc<RwLock<HashMap<String, String>>>,
    /// Optional active checker to notify when a target should be marked unhealthy.
    active_checker: Option<Arc<ActiveHealthChecker>>,
}

impl PassiveHealthChecker {
    pub fn new(
        failure_rate_threshold: f64,
        window_size: usize,
        active_checker: Option<Arc<ActiveHealthChecker>>,
    ) -> Self {
        debug!(
            failure_rate_threshold = format!("{:.2}", failure_rate_threshold),
            window_size = window_size,
            has_active_checker = active_checker.is_some(),
            "Creating passive health checker"
        );
        Self {
            failure_rate_threshold,
            window_size,
            outcomes: Arc::new(RwLock::new(HashMap::new())),
            last_errors: Arc::new(RwLock::new(HashMap::new())),
            active_checker,
        }
    }

    pub async fn record_outcome(&self, target: &str, success: bool, error: Option<&str>) {
        trace!(
            target = %target,
            success = success,
            error = ?error,
            "Recording request outcome"
        );

        if let Some(err_msg) = error {
            self.last_errors
                .write()
                .await
                .insert(target.to_string(), err_msg.to_string());
        } else if success {
            self.last_errors.write().await.remove(target);
        }

        let mut outcomes = self.outcomes.write().await;
        let target_outcomes = outcomes
            .entry(target.to_string())
            .or_insert_with(|| Vec::with_capacity(self.window_size));

        if target_outcomes.len() >= self.window_size {
            target_outcomes.remove(0);
        }
        target_outcomes.push(success);

        let failures = target_outcomes.iter().filter(|&&s| !s).count();
        let failure_rate = failures as f64 / target_outcomes.len() as f64;

        trace!(
            target = %target,
            failure_rate = format!("{:.2}", failure_rate),
            window_samples = target_outcomes.len(),
            failures = failures,
            "Updated failure rate"
        );

        if failure_rate > self.failure_rate_threshold {
            warn!(
                target = %target,
                failure_rate = format!("{:.2}", failure_rate * 100.0),
                threshold = format!("{:.2}", self.failure_rate_threshold * 100.0),
                window_samples = target_outcomes.len(),
                "Failure rate exceeds threshold"
            );
            if let Some(ref checker) = self.active_checker {
                checker
                    .mark_unhealthy(
                        target,
                        format!(
                            "Failure rate {:.2}% exceeds threshold",
                            failure_rate * 100.0
                        ),
                    )
                    .await;
            }
        }
    }

    pub async fn get_failure_rate(&self, target: &str) -> Option<f64> {
        let outcomes = self.outcomes.read().await;
        outcomes.get(target).map(|target_outcomes| {
            let failures = target_outcomes.iter().filter(|&&s| !s).count();
            failures as f64 / target_outcomes.len() as f64
        })
    }

    pub async fn get_last_error(&self, target: &str) -> Option<String> {
        self.last_errors.read().await.get(target).cloned()
    }
}

use dashmap::DashMap;
use sentinel_common::{ColdModelAction, WarmthDetectionConfig};
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};

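/// Detects cold models by comparing post-idle request latency against a per-target baseline.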
pub struct WarmthTracker {
    /// Warmth detection settings (sample size, multiplier, idle timeout, action).
    config: WarmthDetectionConfig,
    /// Per-target warmth state, keyed by target address.
    targets: DashMap<String, TargetWarmthState>,
}

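/// Atomic per-target state used by `WarmthTracker`.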
struct TargetWarmthState {
    /// Baseline request latency in milliseconds.
    baseline_latency_ms: AtomicU64,
    /// Number of samples folded into the baseline.
    sample_count: AtomicU32,
    /// Timestamp (ms since the Unix epoch) of the most recent request.
    last_request_ms: AtomicU64,
    /// Whether the target is currently considered cold.
    is_cold: AtomicBool,
    /// Number of cold starts detected for this target.
    cold_start_count: AtomicU64,
}

impl TargetWarmthState {
    fn new() -> Self {
        Self {
            baseline_latency_ms: AtomicU64::new(0),
            sample_count: AtomicU32::new(0),
            last_request_ms: AtomicU64::new(0),
            is_cold: AtomicBool::new(false),
            cold_start_count: AtomicU64::new(0),
        }
    }

    fn update_baseline(&self, latency_ms: u64, sample_size: u32) {
        let count = self.sample_count.fetch_add(1, Ordering::Relaxed);
        let current = self.baseline_latency_ms.load(Ordering::Relaxed);

        if count < sample_size {
            // Simple running mean while the initial sample window is filling.
            let new_baseline = if count == 0 {
                latency_ms
            } else {
                (current * count as u64 + latency_ms) / (count as u64 + 1)
            };
            self.baseline_latency_ms.store(new_baseline, Ordering::Relaxed);
        } else {
            // Exponentially weighted moving average once the window is full.
            let alpha = 0.1_f64;
            let new_baseline =
                (alpha * latency_ms as f64 + (1.0 - alpha) * current as f64) as u64;
            self.baseline_latency_ms.store(new_baseline, Ordering::Relaxed);
        }
    }
}

impl WarmthTracker {
    pub fn new(config: WarmthDetectionConfig) -> Self {
        Self {
            config,
            targets: DashMap::new(),
        }
    }

    pub fn with_defaults() -> Self {
        Self::new(WarmthDetectionConfig {
            sample_size: 10,
            cold_threshold_multiplier: 3.0,
            idle_cold_timeout_secs: 300,
            cold_action: ColdModelAction::LogOnly,
        })
    }

    pub fn record_request(&self, target: &str, latency: Duration) -> bool {
        let now_ms = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_millis() as u64)
            .unwrap_or(0);

        let latency_ms = latency.as_millis() as u64;
        let idle_threshold_ms = self.config.idle_cold_timeout_secs * 1000;

        let state = self
            .targets
            .entry(target.to_string())
            .or_insert_with(TargetWarmthState::new);

        let last_request = state.last_request_ms.load(Ordering::Relaxed);
        let idle_duration_ms = if last_request > 0 {
            now_ms.saturating_sub(last_request)
        } else {
            0
        };

        state.last_request_ms.store(now_ms, Ordering::Relaxed);

        if idle_duration_ms >= idle_threshold_ms {
            let baseline = state.baseline_latency_ms.load(Ordering::Relaxed);

            if baseline > 0 {
                let threshold = (baseline as f64 * self.config.cold_threshold_multiplier) as u64;

                if latency_ms > threshold {
                    state.is_cold.store(true, Ordering::Release);
                    state.cold_start_count.fetch_add(1, Ordering::Relaxed);

                    warn!(
                        target = %target,
                        latency_ms = latency_ms,
                        baseline_ms = baseline,
                        threshold_ms = threshold,
                        idle_duration_secs = idle_duration_ms / 1000,
                        cold_action = ?self.config.cold_action,
                        "Cold model detected - latency spike after idle period"
                    );

                    return true;
                }
            }
        }

        state.is_cold.store(false, Ordering::Release);
        state.update_baseline(latency_ms, self.config.sample_size);

        trace!(
            target = %target,
            latency_ms = latency_ms,
            baseline_ms = state.baseline_latency_ms.load(Ordering::Relaxed),
            sample_count = state.sample_count.load(Ordering::Relaxed),
            "Recorded request latency for warmth tracking"
        );

        false
    }

    pub fn is_cold(&self, target: &str) -> bool {
        self.targets
            .get(target)
            .map(|s| s.is_cold.load(Ordering::Acquire))
            .unwrap_or(false)
    }

    pub fn cold_action(&self) -> ColdModelAction {
        self.config.cold_action
    }

    pub fn baseline_latency_ms(&self, target: &str) -> Option<u64> {
        self.targets
            .get(target)
            .map(|s| s.baseline_latency_ms.load(Ordering::Relaxed))
    }

    pub fn cold_start_count(&self, target: &str) -> u64 {
        self.targets
            .get(target)
            .map(|s| s.cold_start_count.load(Ordering::Relaxed))
            .unwrap_or(0)
    }

    pub fn should_deprioritize(&self, target: &str) -> bool {
        if !self.is_cold(target) {
            return false;
        }

        match self.config.cold_action {
            ColdModelAction::LogOnly => false,
            ColdModelAction::MarkDegraded | ColdModelAction::MarkUnhealthy => true,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_health_status() {
        let status = TargetHealthInfo::new();
        assert!(status.healthy);
        assert_eq!(status.health_score(), 1.0);
        assert!(!status.is_degraded());
    }

    #[tokio::test]
    async fn test_passive_health_checker() {
        let checker = PassiveHealthChecker::new(0.5, 10, None);

        for _ in 0..5 {
            checker.record_outcome("target1", true, None).await;
        }
        for _ in 0..3 {
            checker.record_outcome("target1", false, Some("HTTP 503")).await;
        }

        let failure_rate = checker.get_failure_rate("target1").await.unwrap();
        assert!(failure_rate > 0.3 && failure_rate < 0.4);
    }

    #[test]
    fn test_parse_http_status() {
        let response = "HTTP/1.1 200 OK\r\n";
        assert_eq!(parse_http_status(response), Some(200));

        let response = "HTTP/1.1 404 Not Found\r\n";
        assert_eq!(parse_http_status(response), Some(404));

        let response = "Invalid response";
        assert_eq!(parse_http_status(response), None);
    }

    #[test]
    fn test_warmth_tracker_baseline() {
        let tracker = WarmthTracker::with_defaults();

        for i in 0..10 {
            let cold = tracker.record_request("target1", Duration::from_millis(100));
            assert!(!cold, "Should not detect cold on request {}", i);
        }

        let baseline = tracker.baseline_latency_ms("target1");
        assert!(baseline.is_some());
        assert!(baseline.unwrap() > 0 && baseline.unwrap() <= 100);
    }

    #[test]
    fn test_warmth_tracker_cold_detection() {
        let config = WarmthDetectionConfig {
            sample_size: 5,
            cold_threshold_multiplier: 2.0,
            idle_cold_timeout_secs: 0,
            cold_action: ColdModelAction::MarkDegraded,
        };
        let tracker = WarmthTracker::new(config);

        for _ in 0..5 {
            tracker.record_request("target1", Duration::from_millis(100));
        }

        std::thread::sleep(Duration::from_millis(10));

        let cold = tracker.record_request("target1", Duration::from_millis(300));
        assert!(cold, "Should detect cold start");
        assert!(tracker.is_cold("target1"));
        assert_eq!(tracker.cold_start_count("target1"), 1);
    }

    #[test]
    fn test_warmth_tracker_no_cold_on_normal_latency() {
        let config = WarmthDetectionConfig {
            sample_size: 5,
            cold_threshold_multiplier: 3.0,
            idle_cold_timeout_secs: 0,
            cold_action: ColdModelAction::LogOnly,
        };
        let tracker = WarmthTracker::new(config);

        for _ in 0..5 {
            tracker.record_request("target1", Duration::from_millis(100));
        }

        std::thread::sleep(Duration::from_millis(10));

        let cold = tracker.record_request("target1", Duration::from_millis(150));
        assert!(!cold, "Should not detect cold for normal variation");
        assert!(!tracker.is_cold("target1"));
    }

    #[test]
    fn test_warmth_tracker_deprioritize() {
        let config = WarmthDetectionConfig {
            sample_size: 2,
            cold_threshold_multiplier: 2.0,
            idle_cold_timeout_secs: 0,
            cold_action: ColdModelAction::MarkDegraded,
        };
        let tracker = WarmthTracker::new(config);

        tracker.record_request("target1", Duration::from_millis(100));
        tracker.record_request("target1", Duration::from_millis(100));
        std::thread::sleep(Duration::from_millis(10));
        tracker.record_request("target1", Duration::from_millis(300));

        assert!(tracker.should_deprioritize("target1"));

        tracker.record_request("target1", Duration::from_millis(100));
        assert!(!tracker.should_deprioritize("target1"));
    }
}