// simple_agents_router/health.rs

//! Provider health tracking for routing decisions.
//!
//! Maintains per-provider request metrics and the health state derived from them.
4
5use simple_agent_type::prelude::{ProviderHealth, ProviderMetrics};
6use std::sync::Mutex;
7use std::time::Duration;
8
/// Tunable thresholds and smoothing factor used by the health tracker.
///
/// The two thresholds are compared against a provider's failure rate with
/// `>=`, so a rate exactly equal to a threshold already triggers it.
#[derive(Debug, Clone, Copy)]
pub struct HealthTrackerConfig {
    /// Failure rate at or above which a provider is reported as degraded.
    pub degrade_threshold: f32,
    /// Failure rate at or above which a provider is reported as unavailable.
    /// Checked before `degrade_threshold`.
    pub unavailable_threshold: f32,
    /// Weight given to the newest latency sample in the exponential moving
    /// average; the previous average is weighted by `1.0 - latency_alpha`.
    pub latency_alpha: f64,
}
19
20impl Default for HealthTrackerConfig {
21    fn default() -> Self {
22        Self {
23            degrade_threshold: 0.2,
24            unavailable_threshold: 0.5,
25            latency_alpha: 0.2,
26        }
27    }
28}
29
30/// Health tracker for providers.
31#[derive(Debug)]
32pub struct HealthTracker {
33    metrics: Mutex<Vec<ProviderMetrics>>,
34    config: HealthTrackerConfig,
35}
36
37impl HealthTracker {
38    /// Create a tracker for the given number of providers.
39    pub fn new(provider_count: usize, config: HealthTrackerConfig) -> Self {
40        let metrics = vec![ProviderMetrics::default(); provider_count];
41        Self {
42            metrics: Mutex::new(metrics),
43            config,
44        }
45    }
46
47    /// Record a successful request.
48    pub fn record_success(&self, provider_index: usize, latency: Duration) {
49        let mut metrics = self
50            .metrics
51            .lock()
52            .unwrap_or_else(|poisoned| poisoned.into_inner());
53        if let Some(entry) = metrics.get_mut(provider_index) {
54            entry.total_requests = entry.total_requests.saturating_add(1);
55            entry.successful_requests = entry.successful_requests.saturating_add(1);
56            entry.avg_latency =
57                update_latency(entry.avg_latency, latency, self.config.latency_alpha);
58            entry.health = compute_health_with_config(entry, self.config);
59        }
60    }
61
62    /// Record a failed request.
63    pub fn record_failure(&self, provider_index: usize, latency: Option<Duration>) {
64        let mut metrics = self
65            .metrics
66            .lock()
67            .unwrap_or_else(|poisoned| poisoned.into_inner());
68        if let Some(entry) = metrics.get_mut(provider_index) {
69            entry.total_requests = entry.total_requests.saturating_add(1);
70            entry.failed_requests = entry.failed_requests.saturating_add(1);
71            if let Some(value) = latency {
72                entry.avg_latency =
73                    update_latency(entry.avg_latency, value, self.config.latency_alpha);
74            }
75            entry.health = compute_health_with_config(entry, self.config);
76        }
77    }
78
79    /// Get metrics for a provider.
80    pub fn metrics(&self, provider_index: usize) -> Option<ProviderMetrics> {
81        let metrics = self
82            .metrics
83            .lock()
84            .unwrap_or_else(|poisoned| poisoned.into_inner());
85        metrics.get(provider_index).copied()
86    }
87
88    /// Get health for a provider.
89    pub fn health(&self, provider_index: usize) -> Option<ProviderHealth> {
90        self.metrics(provider_index).map(|entry| entry.health)
91    }
92}
93
/// Fold `new_value` into the running latency average `current` using an
/// exponential moving average: `alpha * new + (1 - alpha) * current`.
///
/// * `current` - the average so far; a zero duration means "no sample yet",
///   in which case `new_value` is returned unchanged to seed the average.
/// * `new_value` - the latest observed latency.
/// * `alpha` - weight of the newest sample.
///
/// The arithmetic is carried out in nanoseconds and rounded, so small
/// changes are not lost: truncating to whole milliseconds would pin an
/// average of 100 ms forever when samples of 104 ms arrive with `alpha = 0.2`.
/// Sub-millisecond averages are smoothed like any other value instead of
/// being mistaken for "unset".
///
/// Never panics: a negative or NaN result collapses to zero and an
/// oversized one saturates.
fn update_latency(current: Duration, new_value: Duration, alpha: f64) -> Duration {
    const NANOS_PER_SEC: f64 = 1_000_000_000.0;

    // Only a truly zero average counts as "unset"; anything else, however
    // small, is a real running average.
    if current.is_zero() {
        return new_value;
    }

    let current_ns = current.as_secs_f64() * NANOS_PER_SEC;
    let new_ns = new_value.as_secs_f64() * NANOS_PER_SEC;
    let ema_ns = (alpha * new_ns) + ((1.0 - alpha) * current_ns);

    // `max(0.0)` maps NaN and negative values to 0, and the float-to-int
    // `as` cast saturates, so this conversion cannot panic.
    Duration::from_nanos(ema_ns.max(0.0).round() as u64)
}
103
104fn compute_health_with_config(
105    metrics: &ProviderMetrics,
106    config: HealthTrackerConfig,
107) -> ProviderHealth {
108    let failure_rate = metrics.failure_rate();
109    if failure_rate >= config.unavailable_threshold {
110        ProviderHealth::Unavailable
111    } else if failure_rate >= config.degrade_threshold {
112        ProviderHealth::Degraded
113    } else {
114        ProviderHealth::Healthy
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// A single success bumps the total and success counters and leaves the
    /// provider healthy.
    #[test]
    fn success_updates_metrics() {
        let tracker = HealthTracker::new(1, HealthTrackerConfig::default());

        tracker.record_success(0, Duration::from_millis(100));

        let snapshot = tracker.metrics(0).unwrap();
        assert_eq!(snapshot.total_requests, 1);
        assert_eq!(snapshot.successful_requests, 1);
        assert_eq!(snapshot.failed_requests, 0);
        assert_eq!(snapshot.health, ProviderHealth::Healthy);
    }

    /// Three failures out of four requests cross the unavailable threshold.
    #[test]
    fn failures_degrade_health() {
        let tracker = HealthTracker::new(
            1,
            HealthTrackerConfig {
                degrade_threshold: 0.2,
                unavailable_threshold: 0.5,
                latency_alpha: 0.2,
            },
        );
        let sample = Duration::from_millis(50);

        tracker.record_failure(0, Some(sample));
        tracker.record_failure(0, Some(sample));
        tracker.record_success(0, sample);
        tracker.record_failure(0, Some(sample));

        let snapshot = tracker.metrics(0).unwrap();
        assert_eq!(snapshot.total_requests, 4);
        assert_eq!(snapshot.failed_requests, 3);
        assert_eq!(snapshot.health, ProviderHealth::Unavailable);
    }

    /// Lookups past the configured provider count return `None`.
    #[test]
    fn metrics_out_of_range_returns_none() {
        let tracker = HealthTracker::new(1, HealthTrackerConfig::default());

        assert!(tracker.metrics(5).is_none());
        assert!(tracker.health(2).is_none());
    }
}