// sentinel_proxy/health.rs
1//! Health checking module for Sentinel proxy
2//!
3//! This module implements active and passive health checking for upstream servers,
4//! supporting HTTP, TCP, and gRPC health checks with configurable thresholds.
5
6use async_trait::async_trait;
7use std::collections::HashMap;
8use std::net::SocketAddr;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio::net::TcpStream;
13use tokio::sync::RwLock;
14use tokio::time;
15use tracing::{debug, info, trace, warn};
16
17use sentinel_common::{errors::SentinelResult, types::HealthCheckType};
18use sentinel_config::{HealthCheck as HealthCheckConfig, UpstreamTarget};
19
/// Active health checker for upstream targets
///
/// Performs periodic health checks on upstream targets using HTTP, TCP, or gRPC
/// protocols to determine their availability for load balancing.
pub struct ActiveHealthChecker {
    /// Check configuration
    config: HealthCheckConfig,
    /// Health checker implementation (protocol-specific, chosen at construction)
    checker: Arc<dyn HealthCheckImpl>,
    /// Health status per target, keyed by target address string
    health_status: Arc<RwLock<HashMap<String, TargetHealthInfo>>>,
    /// Check task handles (one spawned task per target)
    check_handles: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
    /// Shutdown signal broadcast to all check tasks
    shutdown_tx: Arc<tokio::sync::broadcast::Sender<()>>,
}
36
/// Health status information for a target
#[derive(Debug, Clone)]
pub struct TargetHealthInfo {
    /// Is target healthy
    pub healthy: bool,
    /// Consecutive successes (reset to 0 on any failure)
    pub consecutive_successes: u32,
    /// Consecutive failures (reset to 0 on any success)
    pub consecutive_failures: u32,
    /// Last check time
    pub last_check: Instant,
    /// Last successful check (None only possible if never succeeded after reset)
    pub last_success: Option<Instant>,
    /// Last error message (cleared on a successful check)
    pub last_error: Option<String>,
    /// Total checks performed
    pub total_checks: u64,
    /// Total successful checks
    pub total_successes: u64,
    /// Average response time (ms), computed over successful checks only
    pub avg_response_time: f64,
}
59
/// Health check implementation trait
///
/// Implemented per protocol (HTTP, TCP, gRPC, inference). Implementations
/// return the observed response time on success, or a human-readable error
/// message on failure.
#[async_trait]
trait HealthCheckImpl: Send + Sync {
    /// Perform health check on a target
    ///
    /// `target` is a socket address string (implementations parse it as a
    /// `SocketAddr`, e.g. "10.0.0.1:8080").
    async fn check(&self, target: &str) -> Result<Duration, String>;

    /// Get check type name (used for logging)
    fn check_type(&self) -> &str;
}
69
/// HTTP health check implementation
struct HttpHealthCheck {
    /// Request path to probe (e.g. "/healthz")
    path: String,
    /// Status code that counts as healthy
    expected_status: u16,
    /// Optional Host header override; defaults to the target address
    host: Option<String>,
    /// Check timeout
    timeout: Duration,
}
77
/// TCP health check implementation
///
/// A target is healthy if a TCP connection can be established within the
/// timeout; no data is exchanged.
struct TcpHealthCheck {
    /// Connect timeout
    timeout: Duration,
}
82
/// gRPC health check implementation.
///
/// Currently uses TCP connectivity check as a fallback since full gRPC
/// health checking protocol (grpc.health.v1.Health) requires the `tonic`
/// crate for HTTP/2 and Protocol Buffers support.
///
/// Full implementation would:
/// 1. Establish HTTP/2 connection
/// 2. Call `grpc.health.v1.Health/Check` with service name
/// 3. Parse `HealthCheckResponse` for SERVING/NOT_SERVING status
///
/// See: https://github.com/grpc/grpc/blob/master/doc/health-checking.md
struct GrpcHealthCheck {
    /// gRPC service name to check (currently only logged; see fallback note above)
    service: String,
    /// Connect timeout
    timeout: Duration,
}
99
/// Inference health check implementation for LLM/AI backends.
///
/// Probes the models endpoint to verify the inference server is running
/// and expected models are available. Typically used with OpenAI-compatible
/// APIs that expose a `/v1/models` endpoint.
///
/// The check:
/// 1. Sends GET request to the configured endpoint (default: `/v1/models`)
/// 2. Expects HTTP 200 response
/// 3. Optionally parses response to verify expected models are available
struct InferenceHealthCheck {
    /// Path of the models endpoint to GET
    endpoint: String,
    /// Model names that must appear in the response body (empty = skip check)
    expected_models: Vec<String>,
    /// Check timeout
    timeout: Duration,
}
115
/// Inference probe health check - sends minimal completion request
///
/// Verifies model can actually process requests, not just that server is running.
struct InferenceProbeCheck {
    /// Probe settings: model, prompt, max_tokens, endpoint, latency ceiling
    config: sentinel_common::InferenceProbeConfig,
    /// Check timeout (from the probe config's own `timeout_secs`)
    timeout: Duration,
}
123
/// Model status endpoint health check
///
/// Queries provider-specific status endpoints to verify model readiness.
struct ModelStatusCheck {
    /// Status settings: model list and per-model endpoint pattern
    config: sentinel_common::ModelStatusConfig,
    /// Check timeout (from the status config's own `timeout_secs`)
    timeout: Duration,
}
131
/// Queue depth health check
///
/// Monitors queue depth from headers or response body to detect overload.
struct QueueDepthCheck {
    /// Queue-depth settings (source of the depth value, thresholds)
    config: sentinel_common::QueueDepthConfig,
    /// Endpoint to query for queue information (shared with the base check)
    models_endpoint: String,
    /// Check timeout (from the queue config's own `timeout_secs`)
    timeout: Duration,
}
140
/// Composite inference health check that runs multiple sub-checks
///
/// Runs base inference check plus any configured readiness checks.
/// All enabled checks must pass for the target to be considered healthy.
struct CompositeInferenceHealthCheck {
    /// Always-on models-endpoint check
    base_check: InferenceHealthCheck,
    /// Optional end-to-end completion probe
    inference_probe: Option<InferenceProbeCheck>,
    /// Optional per-model status-endpoint check
    model_status: Option<ModelStatusCheck>,
    /// Optional queue-depth/overload check
    queue_depth: Option<QueueDepthCheck>,
}
151
impl ActiveHealthChecker {
    /// Create new active health checker
    ///
    /// Selects the protocol-specific `HealthCheckImpl` from
    /// `config.check_type` and sets up the shared status map and shutdown
    /// channel. No checks run until [`start`](Self::start) is called.
    pub fn new(config: HealthCheckConfig) -> Self {
        debug!(
            check_type = ?config.check_type,
            interval_secs = config.interval_secs,
            timeout_secs = config.timeout_secs,
            healthy_threshold = config.healthy_threshold,
            unhealthy_threshold = config.unhealthy_threshold,
            "Creating active health checker"
        );

        let checker: Arc<dyn HealthCheckImpl> = match &config.check_type {
            HealthCheckType::Http {
                path,
                expected_status,
                host,
            } => {
                trace!(
                    path = %path,
                    expected_status = expected_status,
                    host = host.as_deref().unwrap_or("(default)"),
                    "Configuring HTTP health check"
                );
                Arc::new(HttpHealthCheck {
                    path: path.clone(),
                    expected_status: *expected_status,
                    host: host.clone(),
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
            HealthCheckType::Tcp => {
                trace!("Configuring TCP health check");
                Arc::new(TcpHealthCheck {
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
            HealthCheckType::Grpc { service } => {
                trace!(
                    service = %service,
                    "Configuring gRPC health check"
                );
                Arc::new(GrpcHealthCheck {
                    service: service.clone(),
                    timeout: Duration::from_secs(config.timeout_secs),
                })
            }
            HealthCheckType::Inference {
                endpoint,
                expected_models,
                readiness,
            } => {
                trace!(
                    endpoint = %endpoint,
                    expected_models = ?expected_models,
                    has_readiness = readiness.is_some(),
                    "Configuring inference health check"
                );

                let base_timeout = Duration::from_secs(config.timeout_secs);
                let base_check = InferenceHealthCheck {
                    endpoint: endpoint.clone(),
                    expected_models: expected_models.clone(),
                    timeout: base_timeout,
                };

                if let Some(ref readiness_config) = readiness {
                    // Create composite check with sub-checks; each sub-check
                    // carries its own timeout from its own config, not the
                    // base check's timeout.
                    let inference_probe = readiness_config.inference_probe.as_ref().map(|cfg| {
                        InferenceProbeCheck {
                            config: cfg.clone(),
                            timeout: Duration::from_secs(cfg.timeout_secs),
                        }
                    });

                    let model_status = readiness_config.model_status.as_ref().map(|cfg| {
                        ModelStatusCheck {
                            config: cfg.clone(),
                            timeout: Duration::from_secs(cfg.timeout_secs),
                        }
                    });

                    let queue_depth = readiness_config.queue_depth.as_ref().map(|cfg| {
                        QueueDepthCheck {
                            config: cfg.clone(),
                            models_endpoint: endpoint.clone(),
                            timeout: Duration::from_secs(cfg.timeout_secs),
                        }
                    });

                    Arc::new(CompositeInferenceHealthCheck {
                        base_check,
                        inference_probe,
                        model_status,
                        queue_depth,
                    })
                } else {
                    // Simple inference check without readiness sub-checks
                    Arc::new(base_check)
                }
            }
        };

        // Capacity 1 suffices: the only message ever sent is the shutdown signal.
        let (shutdown_tx, _) = tokio::sync::broadcast::channel(1);

        Self {
            config,
            checker,
            health_status: Arc::new(RwLock::new(HashMap::new())),
            check_handles: Arc::new(RwLock::new(Vec::new())),
            shutdown_tx: Arc::new(shutdown_tx),
        }
    }

    /// Start health checking for targets
    ///
    /// Initializes (or resets) the status entry for each target and spawns
    /// one background check task per target. Task handles are retained so
    /// [`stop`](Self::stop) can await their shutdown.
    pub async fn start(&self, targets: &[UpstreamTarget]) -> SentinelResult<()> {
        info!(
            target_count = targets.len(),
            interval_secs = self.config.interval_secs,
            check_type = self.checker.check_type(),
            "Starting health checking"
        );

        let mut handles = self.check_handles.write().await;

        for target in targets {
            let address = target.address.clone();

            trace!(
                target = %address,
                "Initializing health status for target"
            );

            // Initialize health status (targets start healthy; see
            // TargetHealthInfo::new)
            self.health_status
                .write()
                .await
                .insert(address.clone(), TargetHealthInfo::new());

            // Spawn health check task
            debug!(
                target = %address,
                "Spawning health check task"
            );
            let handle = self.spawn_check_task(address);
            handles.push(handle);
        }

        info!(
            target_count = targets.len(),
            interval_secs = self.config.interval_secs,
            healthy_threshold = self.config.healthy_threshold,
            unhealthy_threshold = self.config.unhealthy_threshold,
            "Health checking started successfully"
        );

        Ok(())
    }

    /// Spawn health check task for a target
    ///
    /// The task loops on an interval timer, performing one check per tick
    /// and updating the shared status map, until the shutdown broadcast fires.
    fn spawn_check_task(&self, target: String) -> tokio::task::JoinHandle<()> {
        let interval = Duration::from_secs(self.config.interval_secs);
        let checker = Arc::clone(&self.checker);
        let health_status = Arc::clone(&self.health_status);
        let healthy_threshold = self.config.healthy_threshold;
        let unhealthy_threshold = self.config.unhealthy_threshold;
        let check_type = self.checker.check_type().to_string();
        let mut shutdown_rx = self.shutdown_tx.subscribe();

        tokio::spawn(async move {
            let mut interval_timer = time::interval(interval);
            // Skip (rather than burst) any ticks missed while a slow check runs.
            interval_timer.set_missed_tick_behavior(time::MissedTickBehavior::Skip);

            trace!(
                target = %target,
                check_type = %check_type,
                interval_ms = interval.as_millis(),
                "Health check task started"
            );

            loop {
                tokio::select! {
                    _ = interval_timer.tick() => {
                        // Perform health check. The status lock is NOT held
                        // while the network check is in flight.
                        trace!(
                            target = %target,
                            check_type = %check_type,
                            "Performing health check"
                        );
                        let start = Instant::now();
                        let result = checker.check(&target).await;
                        let check_duration = start.elapsed();

                        // Update health status
                        let mut status_map = health_status.write().await;
                        if let Some(status) = status_map.get_mut(&target) {
                            status.last_check = Instant::now();
                            status.total_checks += 1;

                            match result {
                                Ok(response_time) => {
                                    status.consecutive_successes += 1;
                                    status.consecutive_failures = 0;
                                    status.last_success = Some(Instant::now());
                                    status.last_error = None;
                                    status.total_successes += 1;

                                    // Update average response time via an
                                    // incremental mean; total_successes was
                                    // just incremented, so it is >= 1 here.
                                    let response_ms = response_time.as_millis() as f64;
                                    status.avg_response_time =
                                        (status.avg_response_time * (status.total_successes - 1) as f64
                                        + response_ms) / status.total_successes as f64;

                                    // Check if should mark as healthy
                                    if !status.healthy && status.consecutive_successes >= healthy_threshold {
                                        status.healthy = true;
                                        info!(
                                            target = %target,
                                            consecutive_successes = status.consecutive_successes,
                                            avg_response_ms = format!("{:.2}", status.avg_response_time),
                                            total_checks = status.total_checks,
                                            "Target marked as healthy"
                                        );
                                    }

                                    trace!(
                                        target = %target,
                                        response_time_ms = response_ms,
                                        check_duration_ms = check_duration.as_millis(),
                                        consecutive_successes = status.consecutive_successes,
                                        health_score = format!("{:.2}", status.health_score()),
                                        "Health check succeeded"
                                    );
                                }
                                Err(error) => {
                                    status.consecutive_failures += 1;
                                    status.consecutive_successes = 0;
                                    status.last_error = Some(error.clone());

                                    // Check if should mark as unhealthy
                                    if status.healthy && status.consecutive_failures >= unhealthy_threshold {
                                        status.healthy = false;
                                        warn!(
                                            target = %target,
                                            consecutive_failures = status.consecutive_failures,
                                            error = %error,
                                            total_checks = status.total_checks,
                                            health_score = format!("{:.2}", status.health_score()),
                                            "Target marked as unhealthy"
                                        );
                                    } else {
                                        // Either still under the threshold, or
                                        // already unhealthy; log at debug only.
                                        debug!(
                                            target = %target,
                                            error = %error,
                                            consecutive_failures = status.consecutive_failures,
                                            unhealthy_threshold = unhealthy_threshold,
                                            "Health check failed"
                                        );
                                    }
                                }
                            }
                        }
                    }
                    _ = shutdown_rx.recv() => {
                        info!(target = %target, "Stopping health check task");
                        break;
                    }
                }
            }

            debug!(target = %target, "Health check task stopped");
        })
    }

    /// Stop health checking
    ///
    /// Broadcasts the shutdown signal and awaits every spawned check task.
    /// The send result is ignored: it errs only if no receivers remain,
    /// i.e. all tasks have already exited.
    pub async fn stop(&self) {
        let task_count = self.check_handles.read().await.len();
        info!(task_count = task_count, "Stopping health checker");

        // Send shutdown signal
        let _ = self.shutdown_tx.send(());

        // Wait for all tasks to complete
        let mut handles = self.check_handles.write().await;
        for handle in handles.drain(..) {
            let _ = handle.await;
        }

        info!("Health checker stopped successfully");
    }

    /// Get health status for a target (a snapshot clone), if it is tracked
    pub async fn get_status(&self, target: &str) -> Option<TargetHealthInfo> {
        self.health_status.read().await.get(target).cloned()
    }

    /// Get all health statuses (a snapshot clone of the whole map)
    pub async fn get_all_statuses(&self) -> HashMap<String, TargetHealthInfo> {
        self.health_status.read().await.clone()
    }

    /// Check if target is healthy; unknown targets report unhealthy
    pub async fn is_healthy(&self, target: &str) -> bool {
        self.health_status
            .read()
            .await
            .get(target)
            .map(|s| s.healthy)
            .unwrap_or(false)
    }

    /// Get healthy targets (addresses of all currently-healthy entries)
    pub async fn get_healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(target, status)| {
                if status.healthy {
                    Some(target.clone())
                } else {
                    None
                }
            })
            .collect()
    }

    /// Mark target as unhealthy (for passive health checking)
    ///
    /// Sets `consecutive_failures` to the unhealthy threshold so the target
    /// must accumulate a full `healthy_threshold` streak of successful active
    /// checks before being promoted back to healthy. No-op for targets that
    /// are already unhealthy or not tracked.
    pub async fn mark_unhealthy(&self, target: &str, reason: String) {
        if let Some(status) = self.health_status.write().await.get_mut(target) {
            if status.healthy {
                status.healthy = false;
                status.consecutive_failures = self.config.unhealthy_threshold;
                status.consecutive_successes = 0;
                status.last_error = Some(reason.clone());
                warn!(
                    target = %target,
                    reason = %reason,
                    "Target marked unhealthy by passive check"
                );
            }
        }
    }
}
496
impl Default for TargetHealthInfo {
    /// Defers to [`TargetHealthInfo::new`]: a fresh, initially-healthy record.
    fn default() -> Self {
        Self::new()
    }
}
502
503impl TargetHealthInfo {
504    /// Create new health status (initially healthy)
505    pub fn new() -> Self {
506        Self {
507            healthy: true,
508            consecutive_successes: 0,
509            consecutive_failures: 0,
510            last_check: Instant::now(),
511            last_success: Some(Instant::now()),
512            last_error: None,
513            total_checks: 0,
514            total_successes: 0,
515            avg_response_time: 0.0,
516        }
517    }
518
519    /// Get health score (0.0 - 1.0)
520    pub fn health_score(&self) -> f64 {
521        if self.total_checks == 0 {
522            return 1.0;
523        }
524        self.total_successes as f64 / self.total_checks as f64
525    }
526
527    /// Check if status is degraded (healthy but with recent failures)
528    pub fn is_degraded(&self) -> bool {
529        self.healthy && self.consecutive_failures > 0
530    }
531}
532
533#[async_trait]
534impl HealthCheckImpl for HttpHealthCheck {
535    async fn check(&self, target: &str) -> Result<Duration, String> {
536        let start = Instant::now();
537
538        // Parse target address
539        let addr: SocketAddr = target
540            .parse()
541            .map_err(|e| format!("Invalid address: {}", e))?;
542
543        // Connect with timeout
544        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
545            .await
546            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
547            .map_err(|e| format!("Connection failed: {}", e))?;
548
549        // Build HTTP request
550        let host = self.host.as_deref().unwrap_or(target);
551        let request = format!(
552            "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Sentinel-HealthCheck/1.0\r\nConnection: close\r\n\r\n",
553            self.path,
554            host
555        );
556
557        // Send request and read response
558        let mut stream = stream;
559        stream
560            .write_all(request.as_bytes())
561            .await
562            .map_err(|e| format!("Failed to send request: {}", e))?;
563
564        let mut response = vec![0u8; 1024];
565        let n = stream
566            .read(&mut response)
567            .await
568            .map_err(|e| format!("Failed to read response: {}", e))?;
569
570        if n == 0 {
571            return Err("Empty response".to_string());
572        }
573
574        // Parse status code
575        let response_str = String::from_utf8_lossy(&response[..n]);
576        let status_code = parse_http_status(&response_str)
577            .ok_or_else(|| "Failed to parse HTTP status".to_string())?;
578
579        if status_code == self.expected_status {
580            Ok(start.elapsed())
581        } else {
582            Err(format!(
583                "Unexpected status code: {} (expected {})",
584                status_code, self.expected_status
585            ))
586        }
587    }
588
589    fn check_type(&self) -> &str {
590        "HTTP"
591    }
592}
593
594#[async_trait]
595impl HealthCheckImpl for TcpHealthCheck {
596    async fn check(&self, target: &str) -> Result<Duration, String> {
597        let start = Instant::now();
598
599        // Parse target address
600        let addr: SocketAddr = target
601            .parse()
602            .map_err(|e| format!("Invalid address: {}", e))?;
603
604        // Connect with timeout
605        time::timeout(self.timeout, TcpStream::connect(addr))
606            .await
607            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
608            .map_err(|e| format!("Connection failed: {}", e))?;
609
610        Ok(start.elapsed())
611    }
612
613    fn check_type(&self) -> &str {
614        "TCP"
615    }
616}
617
618#[async_trait]
619impl HealthCheckImpl for GrpcHealthCheck {
620    async fn check(&self, target: &str) -> Result<Duration, String> {
621        let start = Instant::now();
622
623        // NOTE: Full gRPC health check requires `tonic` crate for HTTP/2 support.
624        // This implementation uses TCP connectivity as a reasonable fallback.
625        // The gRPC health checking protocol (grpc.health.v1.Health/Check) would
626        // return SERVING, NOT_SERVING, or UNKNOWN for the specified service.
627
628        let addr: SocketAddr = target
629            .parse()
630            .map_err(|e| format!("Invalid address: {}", e))?;
631
632        // TCP connectivity check as fallback
633        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
634            .await
635            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
636            .map_err(|e| format!("Connection failed: {}", e))?;
637
638        // Verify connection is writable (basic health indicator)
639        stream
640            .writable()
641            .await
642            .map_err(|e| format!("Connection not writable: {}", e))?;
643
644        debug!(
645            target = %target,
646            service = %self.service,
647            "gRPC health check using TCP fallback (full gRPC protocol requires tonic)"
648        );
649
650        Ok(start.elapsed())
651    }
652
653    fn check_type(&self) -> &str {
654        "gRPC"
655    }
656}
657
658#[async_trait]
659impl HealthCheckImpl for InferenceHealthCheck {
660    async fn check(&self, target: &str) -> Result<Duration, String> {
661        let start = Instant::now();
662
663        // Parse target address
664        let addr: SocketAddr = target
665            .parse()
666            .map_err(|e| format!("Invalid address: {}", e))?;
667
668        // Connect with timeout
669        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
670            .await
671            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
672            .map_err(|e| format!("Connection failed: {}", e))?;
673
674        // Build HTTP request for the models endpoint
675        let request = format!(
676            "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Sentinel-HealthCheck/1.0\r\nAccept: application/json\r\nConnection: close\r\n\r\n",
677            self.endpoint,
678            target
679        );
680
681        // Send request and read response
682        let mut stream = stream;
683        stream
684            .write_all(request.as_bytes())
685            .await
686            .map_err(|e| format!("Failed to send request: {}", e))?;
687
688        // Read response (larger buffer for JSON response)
689        let mut response = vec![0u8; 8192];
690        let n = stream
691            .read(&mut response)
692            .await
693            .map_err(|e| format!("Failed to read response: {}", e))?;
694
695        if n == 0 {
696            return Err("Empty response".to_string());
697        }
698
699        // Parse status code
700        let response_str = String::from_utf8_lossy(&response[..n]);
701        let status_code = parse_http_status(&response_str)
702            .ok_or_else(|| "Failed to parse HTTP status".to_string())?;
703
704        if status_code != 200 {
705            return Err(format!(
706                "Unexpected status code: {} (expected 200)",
707                status_code
708            ));
709        }
710
711        // If expected models are specified, verify they're in the response
712        if !self.expected_models.is_empty() {
713            // Find the JSON body (after headers)
714            if let Some(body_start) = response_str.find("\r\n\r\n") {
715                let body = &response_str[body_start + 4..];
716
717                // Check if each expected model is mentioned in the response
718                for model in &self.expected_models {
719                    if !body.contains(model) {
720                        return Err(format!(
721                            "Expected model '{}' not found in response",
722                            model
723                        ));
724                    }
725                }
726
727                debug!(
728                    target = %target,
729                    endpoint = %self.endpoint,
730                    expected_models = ?self.expected_models,
731                    "All expected models found in inference health check"
732                );
733            } else {
734                return Err("Could not find response body".to_string());
735            }
736        }
737
738        trace!(
739            target = %target,
740            endpoint = %self.endpoint,
741            response_time_ms = start.elapsed().as_millis(),
742            "Inference health check passed"
743        );
744
745        Ok(start.elapsed())
746    }
747
748    fn check_type(&self) -> &str {
749        "Inference"
750    }
751}
752
753#[async_trait]
754impl HealthCheckImpl for InferenceProbeCheck {
755    async fn check(&self, target: &str) -> Result<Duration, String> {
756        let start = Instant::now();
757
758        // Parse target address
759        let addr: SocketAddr = target
760            .parse()
761            .map_err(|e| format!("Invalid address: {}", e))?;
762
763        // Connect with timeout
764        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
765            .await
766            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
767            .map_err(|e| format!("Connection failed: {}", e))?;
768
769        // Build completion request body
770        let body = format!(
771            r#"{{"model":"{}","prompt":"{}","max_tokens":{}}}"#,
772            self.config.model, self.config.prompt, self.config.max_tokens
773        );
774
775        // Build HTTP request
776        let request = format!(
777            "POST {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Sentinel-HealthCheck/1.0\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
778            self.config.endpoint,
779            target,
780            body.len(),
781            body
782        );
783
784        // Send request
785        let mut stream = stream;
786        stream
787            .write_all(request.as_bytes())
788            .await
789            .map_err(|e| format!("Failed to send request: {}", e))?;
790
791        // Read response
792        let mut response = vec![0u8; 16384];
793        let n = stream
794            .read(&mut response)
795            .await
796            .map_err(|e| format!("Failed to read response: {}", e))?;
797
798        if n == 0 {
799            return Err("Empty response".to_string());
800        }
801
802        let latency = start.elapsed();
803
804        // Parse status code
805        let response_str = String::from_utf8_lossy(&response[..n]);
806        let status_code = parse_http_status(&response_str)
807            .ok_or_else(|| "Failed to parse HTTP status".to_string())?;
808
809        if status_code != 200 {
810            return Err(format!(
811                "Inference probe failed: status {} (expected 200)",
812                status_code
813            ));
814        }
815
816        // Verify response contains choices array
817        if let Some(body_start) = response_str.find("\r\n\r\n") {
818            let body = &response_str[body_start + 4..];
819            if !body.contains("\"choices\"") {
820                return Err("Inference probe response missing 'choices' field".to_string());
821            }
822        }
823
824        // Check latency threshold if configured
825        if let Some(max_ms) = self.config.max_latency_ms {
826            if latency.as_millis() as u64 > max_ms {
827                return Err(format!(
828                    "Inference probe latency {}ms exceeds threshold {}ms",
829                    latency.as_millis(),
830                    max_ms
831                ));
832            }
833        }
834
835        trace!(
836            target = %target,
837            model = %self.config.model,
838            latency_ms = latency.as_millis(),
839            "Inference probe health check passed"
840        );
841
842        Ok(latency)
843    }
844
    /// Short identifier for this check implementation, used in logs/metrics.
    fn check_type(&self) -> &str {
        "InferenceProbe"
    }
848}
849
850#[async_trait]
851impl HealthCheckImpl for ModelStatusCheck {
852    async fn check(&self, target: &str) -> Result<Duration, String> {
853        let start = Instant::now();
854
855        // Parse target address
856        let addr: SocketAddr = target
857            .parse()
858            .map_err(|e| format!("Invalid address: {}", e))?;
859
860        // Check each model's status
861        for model in &self.config.models {
862            let endpoint = self.config.endpoint_pattern.replace("{model}", model);
863
864            // Connect with timeout
865            let stream = time::timeout(self.timeout, TcpStream::connect(addr))
866                .await
867                .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
868                .map_err(|e| format!("Connection failed: {}", e))?;
869
870            // Build HTTP request
871            let request = format!(
872                "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Sentinel-HealthCheck/1.0\r\nAccept: application/json\r\nConnection: close\r\n\r\n",
873                endpoint,
874                target
875            );
876
877            // Send request
878            let mut stream = stream;
879            stream
880                .write_all(request.as_bytes())
881                .await
882                .map_err(|e| format!("Failed to send request: {}", e))?;
883
884            // Read response
885            let mut response = vec![0u8; 8192];
886            let n = stream
887                .read(&mut response)
888                .await
889                .map_err(|e| format!("Failed to read response: {}", e))?;
890
891            if n == 0 {
892                return Err(format!("Empty response for model '{}'", model));
893            }
894
895            let response_str = String::from_utf8_lossy(&response[..n]);
896            let status_code = parse_http_status(&response_str)
897                .ok_or_else(|| "Failed to parse HTTP status".to_string())?;
898
899            if status_code != 200 {
900                return Err(format!(
901                    "Model '{}' status check failed: HTTP {}",
902                    model, status_code
903                ));
904            }
905
906            // Extract status field from JSON body
907            if let Some(body_start) = response_str.find("\r\n\r\n") {
908                let body = &response_str[body_start + 4..];
909                let status = extract_json_field(body, &self.config.status_field);
910
911                match status {
912                    Some(s) if s == self.config.expected_status => {
913                        trace!(
914                            target = %target,
915                            model = %model,
916                            status = %s,
917                            "Model status check passed"
918                        );
919                    }
920                    Some(s) => {
921                        return Err(format!(
922                            "Model '{}' status '{}' != expected '{}'",
923                            model, s, self.config.expected_status
924                        ));
925                    }
926                    None => {
927                        return Err(format!(
928                            "Model '{}' status field '{}' not found",
929                            model, self.config.status_field
930                        ));
931                    }
932                }
933            }
934        }
935
936        Ok(start.elapsed())
937    }
938
939    fn check_type(&self) -> &str {
940        "ModelStatus"
941    }
942}
943
944#[async_trait]
945impl HealthCheckImpl for QueueDepthCheck {
946    async fn check(&self, target: &str) -> Result<Duration, String> {
947        let start = Instant::now();
948
949        // Parse target address
950        let addr: SocketAddr = target
951            .parse()
952            .map_err(|e| format!("Invalid address: {}", e))?;
953
954        let endpoint = self.config.endpoint.as_ref().unwrap_or(&self.models_endpoint);
955
956        // Connect with timeout
957        let stream = time::timeout(self.timeout, TcpStream::connect(addr))
958            .await
959            .map_err(|_| format!("Connection timeout after {:?}", self.timeout))?
960            .map_err(|e| format!("Connection failed: {}", e))?;
961
962        // Build HTTP request
963        let request = format!(
964            "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: Sentinel-HealthCheck/1.0\r\nAccept: application/json\r\nConnection: close\r\n\r\n",
965            endpoint,
966            target
967        );
968
969        // Send request
970        let mut stream = stream;
971        stream
972            .write_all(request.as_bytes())
973            .await
974            .map_err(|e| format!("Failed to send request: {}", e))?;
975
976        // Read response
977        let mut response = vec![0u8; 8192];
978        let n = stream
979            .read(&mut response)
980            .await
981            .map_err(|e| format!("Failed to read response: {}", e))?;
982
983        if n == 0 {
984            return Err("Empty response".to_string());
985        }
986
987        let response_str = String::from_utf8_lossy(&response[..n]);
988
989        // Extract queue depth from header or body
990        let queue_depth = if let Some(ref header_name) = self.config.header {
991            extract_header_value(&response_str, header_name)
992                .and_then(|v| v.parse::<u64>().ok())
993        } else if let Some(ref field) = self.config.body_field {
994            if let Some(body_start) = response_str.find("\r\n\r\n") {
995                let body = &response_str[body_start + 4..];
996                extract_json_field(body, field).and_then(|v| v.parse::<u64>().ok())
997            } else {
998                None
999            }
1000        } else {
1001            return Err("No queue depth source configured (header or body_field)".to_string());
1002        };
1003
1004        let depth = queue_depth.ok_or_else(|| "Could not extract queue depth".to_string())?;
1005
1006        // Check thresholds
1007        if depth >= self.config.unhealthy_threshold {
1008            return Err(format!(
1009                "Queue depth {} exceeds unhealthy threshold {}",
1010                depth, self.config.unhealthy_threshold
1011            ));
1012        }
1013
1014        if depth >= self.config.degraded_threshold {
1015            warn!(
1016                target = %target,
1017                queue_depth = depth,
1018                threshold = self.config.degraded_threshold,
1019                "Queue depth exceeds degraded threshold"
1020            );
1021        }
1022
1023        trace!(
1024            target = %target,
1025            queue_depth = depth,
1026            "Queue depth check passed"
1027        );
1028
1029        Ok(start.elapsed())
1030    }
1031
1032    fn check_type(&self) -> &str {
1033        "QueueDepth"
1034    }
1035}
1036
1037#[async_trait]
1038impl HealthCheckImpl for CompositeInferenceHealthCheck {
1039    async fn check(&self, target: &str) -> Result<Duration, String> {
1040        let start = Instant::now();
1041
1042        // Run base inference check first (always required)
1043        self.base_check.check(target).await?;
1044
1045        // Run optional sub-checks (all must pass)
1046        if let Some(ref probe) = self.inference_probe {
1047            probe.check(target).await?;
1048        }
1049
1050        if let Some(ref status) = self.model_status {
1051            status.check(target).await?;
1052        }
1053
1054        if let Some(ref queue) = self.queue_depth {
1055            queue.check(target).await?;
1056        }
1057
1058        trace!(
1059            target = %target,
1060            total_time_ms = start.elapsed().as_millis(),
1061            "Composite inference health check passed"
1062        );
1063
1064        Ok(start.elapsed())
1065    }
1066
1067    fn check_type(&self) -> &str {
1068        "CompositeInference"
1069    }
1070}
1071
/// Extract a header value from an HTTP response, matching the header name
/// case-insensitively.
///
/// Only the header section (lines before the first blank line) is scanned,
/// so header-like text in the body is never matched.
fn extract_header_value(response: &str, header_name: &str) -> Option<String> {
    let wanted = header_name.to_lowercase();
    response
        .lines()
        .take_while(|line| !line.is_empty() && *line != "\r")
        .filter_map(|line| line.split_once(':'))
        .find(|(name, _)| name.trim().to_lowercase() == wanted)
        .map(|(_, value)| value.trim().to_string())
}
1087
1088/// Extract a field from JSON body using dot notation (e.g., "status" or "state.loaded")
1089fn extract_json_field(body: &str, field_path: &str) -> Option<String> {
1090    let json: serde_json::Value = serde_json::from_str(body).ok()?;
1091    let parts: Vec<&str> = field_path.split('.').collect();
1092    let mut current = &json;
1093
1094    for part in parts {
1095        current = current.get(part)?;
1096    }
1097
1098    match current {
1099        serde_json::Value::String(s) => Some(s.clone()),
1100        serde_json::Value::Number(n) => Some(n.to_string()),
1101        serde_json::Value::Bool(b) => Some(b.to_string()),
1102        _ => None,
1103    }
1104}
1105
/// Parse the status code out of an HTTP response's status line.
///
/// Expects the first line to look like `HTTP/1.1 200 OK`; returns `None`
/// when the line is missing or its second token is not a number.
fn parse_http_status(response: &str) -> Option<u16> {
    let status_line = response.lines().next()?;
    let mut tokens = status_line.split_whitespace();
    let _version = tokens.next()?; // e.g. "HTTP/1.1"
    tokens.next()?.parse().ok()
}
1116
/// Passive health checker that monitors request outcomes
///
/// Observes request success/failure rates to detect unhealthy targets
/// without performing explicit health checks. Works in combination with
/// `ActiveHealthChecker` for comprehensive health monitoring.
pub struct PassiveHealthChecker {
    /// Failure rate threshold (0.0 - 1.0); exceeding it marks the target unhealthy
    failure_rate_threshold: f64,
    /// Window size for calculating failure rate (number of most recent outcomes kept)
    window_size: usize,
    /// Request outcomes per target (ring buffer; true = success, false = failure)
    outcomes: Arc<RwLock<HashMap<String, Vec<bool>>>>,
    /// Last error per target (cleared on the next success)
    last_errors: Arc<RwLock<HashMap<String, String>>>,
    /// Active health checker reference, notified when the threshold is crossed
    active_checker: Option<Arc<ActiveHealthChecker>>,
}
1134
1135impl PassiveHealthChecker {
1136    /// Create new passive health checker
1137    pub fn new(
1138        failure_rate_threshold: f64,
1139        window_size: usize,
1140        active_checker: Option<Arc<ActiveHealthChecker>>,
1141    ) -> Self {
1142        debug!(
1143            failure_rate_threshold = format!("{:.2}", failure_rate_threshold),
1144            window_size = window_size,
1145            has_active_checker = active_checker.is_some(),
1146            "Creating passive health checker"
1147        );
1148        Self {
1149            failure_rate_threshold,
1150            window_size,
1151            outcomes: Arc::new(RwLock::new(HashMap::new())),
1152            last_errors: Arc::new(RwLock::new(HashMap::new())),
1153            active_checker,
1154        }
1155    }
1156
1157    /// Record request outcome with optional error message
1158    pub async fn record_outcome(&self, target: &str, success: bool, error: Option<&str>) {
1159        trace!(
1160            target = %target,
1161            success = success,
1162            error = ?error,
1163            "Recording request outcome"
1164        );
1165
1166        // Track last error
1167        if let Some(err_msg) = error {
1168            self.last_errors
1169                .write()
1170                .await
1171                .insert(target.to_string(), err_msg.to_string());
1172        } else if success {
1173            // Clear last error on success
1174            self.last_errors.write().await.remove(target);
1175        }
1176
1177        let mut outcomes = self.outcomes.write().await;
1178        let target_outcomes = outcomes
1179            .entry(target.to_string())
1180            .or_insert_with(|| Vec::with_capacity(self.window_size));
1181
1182        // Add outcome to ring buffer
1183        if target_outcomes.len() >= self.window_size {
1184            target_outcomes.remove(0);
1185        }
1186        target_outcomes.push(success);
1187
1188        // Calculate failure rate
1189        let failures = target_outcomes.iter().filter(|&&s| !s).count();
1190        let failure_rate = failures as f64 / target_outcomes.len() as f64;
1191
1192        trace!(
1193            target = %target,
1194            failure_rate = format!("{:.2}", failure_rate),
1195            window_samples = target_outcomes.len(),
1196            failures = failures,
1197            "Updated failure rate"
1198        );
1199
1200        // Mark unhealthy if failure rate exceeds threshold
1201        if failure_rate > self.failure_rate_threshold {
1202            warn!(
1203                target = %target,
1204                failure_rate = format!("{:.2}", failure_rate * 100.0),
1205                threshold = format!("{:.2}", self.failure_rate_threshold * 100.0),
1206                window_samples = target_outcomes.len(),
1207                "Failure rate exceeds threshold"
1208            );
1209            if let Some(ref checker) = self.active_checker {
1210                checker
1211                    .mark_unhealthy(
1212                        target,
1213                        format!(
1214                            "Failure rate {:.2}% exceeds threshold",
1215                            failure_rate * 100.0
1216                        ),
1217                    )
1218                    .await;
1219            }
1220        }
1221    }
1222
1223    /// Get failure rate for a target
1224    pub async fn get_failure_rate(&self, target: &str) -> Option<f64> {
1225        let outcomes = self.outcomes.read().await;
1226        outcomes.get(target).map(|target_outcomes| {
1227            let failures = target_outcomes.iter().filter(|&&s| !s).count();
1228            failures as f64 / target_outcomes.len() as f64
1229        })
1230    }
1231
1232    /// Get last error for a target
1233    pub async fn get_last_error(&self, target: &str) -> Option<String> {
1234        self.last_errors.read().await.get(target).cloned()
1235    }
1236}
1237
1238// ============================================================================
1239// Warmth Tracker (Passive Cold Model Detection)
1240// ============================================================================
1241
1242use dashmap::DashMap;
1243use sentinel_common::{ColdModelAction, WarmthDetectionConfig};
1244use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
1245
/// Warmth tracker for detecting cold models after idle periods
///
/// This is a passive tracker that observes actual request latency rather than
/// sending active probes. It tracks baseline latency per target and detects
/// when first-request latency after an idle period indicates a cold model.
pub struct WarmthTracker {
    /// Configuration for warmth detection (baseline size, multiplier, idle window, action)
    config: WarmthDetectionConfig,
    /// Per-target warmth state, keyed by target address string
    targets: DashMap<String, TargetWarmthState>,
}
1257
/// Per-target warmth tracking state
struct TargetWarmthState {
    /// Baseline latency in milliseconds (running mean, then EWMA)
    baseline_latency_ms: AtomicU64,
    /// Number of samples collected for baseline
    sample_count: AtomicU32,
    /// Last request timestamp (millis since epoch)
    last_request_ms: AtomicU64,
    /// Currently considered cold
    is_cold: AtomicBool,
    /// Total cold starts detected (for metrics)
    cold_start_count: AtomicU64,
}

impl TargetWarmthState {
    /// Fresh state: no baseline, no samples, not cold.
    fn new() -> Self {
        Self {
            baseline_latency_ms: AtomicU64::new(0),
            sample_count: AtomicU32::new(0),
            last_request_ms: AtomicU64::new(0),
            is_cold: AtomicBool::new(false),
            cold_start_count: AtomicU64::new(0),
        }
    }

    /// Fold a new latency sample into the baseline.
    ///
    /// The first `sample_size` samples are combined as a running arithmetic
    /// mean; after that, the baseline becomes an exponentially weighted
    /// moving average (alpha = 0.1) so it adapts slowly to drift.
    fn update_baseline(&self, latency_ms: u64, sample_size: u32) {
        let seen = self.sample_count.fetch_add(1, Ordering::Relaxed);
        let previous = self.baseline_latency_ms.load(Ordering::Relaxed);

        let updated = if seen < sample_size {
            // Still building the initial baseline: running mean.
            match seen {
                0 => latency_ms,
                n => (previous * n as u64 + latency_ms) / (n as u64 + 1),
            }
        } else {
            // Smooth EWMA update: new = alpha * sample + (1 - alpha) * old.
            const ALPHA: f64 = 0.1;
            (ALPHA * latency_ms as f64 + (1.0 - ALPHA) * previous as f64) as u64
        };

        self.baseline_latency_ms.store(updated, Ordering::Relaxed);
    }
}
1305
1306impl WarmthTracker {
1307    /// Create a new warmth tracker with the given configuration
1308    pub fn new(config: WarmthDetectionConfig) -> Self {
1309        Self {
1310            config,
1311            targets: DashMap::new(),
1312        }
1313    }
1314
1315    /// Create a warmth tracker with default configuration
1316    pub fn with_defaults() -> Self {
1317        Self::new(WarmthDetectionConfig {
1318            sample_size: 10,
1319            cold_threshold_multiplier: 3.0,
1320            idle_cold_timeout_secs: 300,
1321            cold_action: ColdModelAction::LogOnly,
1322        })
1323    }
1324
1325    /// Record a completed request and detect cold starts
1326    ///
1327    /// Returns true if a cold start was detected
1328    pub fn record_request(&self, target: &str, latency: Duration) -> bool {
1329        let now_ms = std::time::SystemTime::now()
1330            .duration_since(std::time::UNIX_EPOCH)
1331            .map(|d| d.as_millis() as u64)
1332            .unwrap_or(0);
1333
1334        let latency_ms = latency.as_millis() as u64;
1335        let idle_threshold_ms = self.config.idle_cold_timeout_secs * 1000;
1336
1337        let state = self
1338            .targets
1339            .entry(target.to_string())
1340            .or_insert_with(TargetWarmthState::new);
1341
1342        let last_request = state.last_request_ms.load(Ordering::Relaxed);
1343        let idle_duration_ms = if last_request > 0 {
1344            now_ms.saturating_sub(last_request)
1345        } else {
1346            0
1347        };
1348
1349        // Update last request time
1350        state.last_request_ms.store(now_ms, Ordering::Relaxed);
1351
1352        // Check if this might be a cold start (first request after idle period)
1353        if idle_duration_ms >= idle_threshold_ms {
1354            let baseline = state.baseline_latency_ms.load(Ordering::Relaxed);
1355
1356            // Only check if we have a baseline
1357            if baseline > 0 {
1358                let threshold = (baseline as f64 * self.config.cold_threshold_multiplier) as u64;
1359
1360                if latency_ms > threshold {
1361                    // Cold start detected!
1362                    state.is_cold.store(true, Ordering::Release);
1363                    state.cold_start_count.fetch_add(1, Ordering::Relaxed);
1364
1365                    warn!(
1366                        target = %target,
1367                        latency_ms = latency_ms,
1368                        baseline_ms = baseline,
1369                        threshold_ms = threshold,
1370                        idle_duration_secs = idle_duration_ms / 1000,
1371                        cold_action = ?self.config.cold_action,
1372                        "Cold model detected - latency spike after idle period"
1373                    );
1374
1375                    return true;
1376                }
1377            }
1378        }
1379
1380        // Normal request - update baseline and clear cold flag
1381        state.is_cold.store(false, Ordering::Release);
1382        state.update_baseline(latency_ms, self.config.sample_size);
1383
1384        trace!(
1385            target = %target,
1386            latency_ms = latency_ms,
1387            baseline_ms = state.baseline_latency_ms.load(Ordering::Relaxed),
1388            sample_count = state.sample_count.load(Ordering::Relaxed),
1389            "Recorded request latency for warmth tracking"
1390        );
1391
1392        false
1393    }
1394
1395    /// Check if a target is currently considered cold
1396    pub fn is_cold(&self, target: &str) -> bool {
1397        self.targets
1398            .get(target)
1399            .map(|s| s.is_cold.load(Ordering::Acquire))
1400            .unwrap_or(false)
1401    }
1402
1403    /// Get the configured action for cold models
1404    pub fn cold_action(&self) -> ColdModelAction {
1405        self.config.cold_action
1406    }
1407
1408    /// Get baseline latency for a target (in ms)
1409    pub fn baseline_latency_ms(&self, target: &str) -> Option<u64> {
1410        self.targets
1411            .get(target)
1412            .map(|s| s.baseline_latency_ms.load(Ordering::Relaxed))
1413    }
1414
1415    /// Get cold start count for a target
1416    pub fn cold_start_count(&self, target: &str) -> u64 {
1417        self.targets
1418            .get(target)
1419            .map(|s| s.cold_start_count.load(Ordering::Relaxed))
1420            .unwrap_or(0)
1421    }
1422
1423    /// Check if warmth tracking should affect load balancing for this target
1424    pub fn should_deprioritize(&self, target: &str) -> bool {
1425        if !self.is_cold(target) {
1426            return false;
1427        }
1428
1429        match self.config.cold_action {
1430            ColdModelAction::LogOnly => false,
1431            ColdModelAction::MarkDegraded | ColdModelAction::MarkUnhealthy => true,
1432        }
1433    }
1434}
1435
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_health_status() {
        // A fresh target starts healthy with a perfect score.
        let status = TargetHealthInfo::new();
        assert!(status.healthy);
        assert_eq!(status.health_score(), 1.0);
        assert!(!status.is_degraded());
    }

    #[tokio::test]
    async fn test_passive_health_checker() {
        let checker = PassiveHealthChecker::new(0.5, 10, None);

        // Record some outcomes
        for _ in 0..5 {
            checker.record_outcome("target1", true, None).await;
        }
        for _ in 0..3 {
            checker.record_outcome("target1", false, Some("HTTP 503")).await;
        }

        // 3 failures out of 8 samples => rate 0.375, within (0.3, 0.4).
        let failure_rate = checker.get_failure_rate("target1").await.unwrap();
        assert!(failure_rate > 0.3 && failure_rate < 0.4);
    }

    #[test]
    fn test_parse_http_status() {
        let response = "HTTP/1.1 200 OK\r\n";
        assert_eq!(parse_http_status(response), Some(200));

        let response = "HTTP/1.1 404 Not Found\r\n";
        assert_eq!(parse_http_status(response), Some(404));

        // A malformed status line yields None rather than panicking.
        let response = "Invalid response";
        assert_eq!(parse_http_status(response), None);
    }

    #[test]
    fn test_warmth_tracker_baseline() {
        let tracker = WarmthTracker::with_defaults();

        // First few requests should build baseline
        for i in 0..10 {
            let cold = tracker.record_request("target1", Duration::from_millis(100));
            assert!(!cold, "Should not detect cold on request {}", i);
        }

        // Check baseline was established
        let baseline = tracker.baseline_latency_ms("target1");
        assert!(baseline.is_some());
        assert!(baseline.unwrap() > 0 && baseline.unwrap() <= 100);
    }

    #[test]
    fn test_warmth_tracker_cold_detection() {
        let config = WarmthDetectionConfig {
            sample_size: 5,
            cold_threshold_multiplier: 2.0,
            idle_cold_timeout_secs: 0, // Immediate idle for testing
            cold_action: ColdModelAction::MarkDegraded,
        };
        let tracker = WarmthTracker::new(config);

        // Build baseline with 100ms latency
        for _ in 0..5 {
            tracker.record_request("target1", Duration::from_millis(100));
        }

        // Wait a tiny bit to simulate idle
        std::thread::sleep(Duration::from_millis(10));

        // Next request with 3x latency (> 2x threshold) should detect cold
        let cold = tracker.record_request("target1", Duration::from_millis(300));
        assert!(cold, "Should detect cold start");
        assert!(tracker.is_cold("target1"));
        assert_eq!(tracker.cold_start_count("target1"), 1);
    }

    #[test]
    fn test_warmth_tracker_no_cold_on_normal_latency() {
        let config = WarmthDetectionConfig {
            sample_size: 5,
            cold_threshold_multiplier: 3.0,
            idle_cold_timeout_secs: 0,
            cold_action: ColdModelAction::LogOnly,
        };
        let tracker = WarmthTracker::new(config);

        // Build baseline
        for _ in 0..5 {
            tracker.record_request("target1", Duration::from_millis(100));
        }

        std::thread::sleep(Duration::from_millis(10));

        // Request with only 1.5x latency (< 3x threshold) should not detect cold
        let cold = tracker.record_request("target1", Duration::from_millis(150));
        assert!(!cold, "Should not detect cold for normal variation");
        assert!(!tracker.is_cold("target1"));
    }

    #[test]
    fn test_warmth_tracker_deprioritize() {
        let config = WarmthDetectionConfig {
            sample_size: 2,
            cold_threshold_multiplier: 2.0,
            idle_cold_timeout_secs: 0,
            cold_action: ColdModelAction::MarkDegraded,
        };
        let tracker = WarmthTracker::new(config);

        // Build baseline and trigger cold
        tracker.record_request("target1", Duration::from_millis(100));
        tracker.record_request("target1", Duration::from_millis(100));
        std::thread::sleep(Duration::from_millis(10));
        tracker.record_request("target1", Duration::from_millis(300));

        // Should deprioritize when cold and action is MarkDegraded
        assert!(tracker.should_deprioritize("target1"));

        // New normal request clears cold flag
        tracker.record_request("target1", Duration::from_millis(100));
        assert!(!tracker.should_deprioritize("target1"));
    }
}