Skip to main content

simple_agents_router/
latency.rs

//! Latency-based routing implementation.
//!
//! Routes each request to the provider with the lowest observed latency.
4
5use simple_agent_type::prelude::{
6    CompletionChunk, CompletionRequest, CompletionResponse, Provider, ProviderHealth, Result,
7    SimpleAgentsError,
8};
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::{Arc, Mutex};
11use std::time::{Duration, Instant};
12
/// Tuning knobs for latency-based routing.
#[derive(Debug, Clone)]
pub struct LatencyRouterConfig {
    /// Weight given to the newest latency sample when updating the
    /// exponential moving average; expected to lie in `0.0..=1.0`.
    pub alpha: f64,
    /// Average latency at or above which a provider is reported as degraded.
    pub slow_threshold: Duration,
}

impl Default for LatencyRouterConfig {
    /// Returns `alpha = 0.2` and a two-second slow threshold.
    fn default() -> Self {
        let slow_threshold = Duration::from_millis(2_000);
        LatencyRouterConfig {
            slow_threshold,
            alpha: 0.2,
        }
    }
}
30
31#[derive(Clone, Copy, Debug)]
32struct LatencyStats {
33    avg_latency_ms: f64,
34    samples: u64,
35    health: ProviderHealth,
36}
37
38impl LatencyStats {
39    fn new() -> Self {
40        Self {
41            avg_latency_ms: 0.0,
42            samples: 0,
43            health: ProviderHealth::Healthy,
44        }
45    }
46
47    fn record(&mut self, latency: Duration, alpha: f64, slow_threshold: Duration) {
48        let latency_ms = latency.as_secs_f64() * 1000.0;
49        if self.samples == 0 {
50            self.avg_latency_ms = latency_ms;
51        } else {
52            let previous = self.avg_latency_ms;
53            self.avg_latency_ms = (alpha * latency_ms) + ((1.0 - alpha) * previous);
54        }
55        self.samples = self.samples.saturating_add(1);
56
57        let threshold_ms = slow_threshold.as_secs_f64() * 1000.0;
58        self.health = if self.avg_latency_ms >= threshold_ms {
59            ProviderHealth::Degraded
60        } else {
61            ProviderHealth::Healthy
62        };
63    }
64}
65
66/// Router that selects providers based on observed latency.
67pub struct LatencyRouter {
68    providers: Vec<Arc<dyn Provider>>,
69    stats: Mutex<Vec<LatencyStats>>,
70    counter: AtomicUsize,
71    config: LatencyRouterConfig,
72}
73
74impl LatencyRouter {
75    /// Create a latency router with default configuration.
76    ///
77    /// # Errors
78    /// Returns a routing error if no providers are supplied.
79    pub fn new(providers: Vec<Arc<dyn Provider>>) -> Result<Self> {
80        Self::with_config(providers, LatencyRouterConfig::default())
81    }
82
83    /// Create a latency router with explicit configuration.
84    ///
85    /// # Errors
86    /// Returns a routing error if no providers are supplied.
87    pub fn with_config(
88        providers: Vec<Arc<dyn Provider>>,
89        config: LatencyRouterConfig,
90    ) -> Result<Self> {
91        if providers.is_empty() {
92            return Err(SimpleAgentsError::Routing(
93                "no providers configured".to_string(),
94            ));
95        }
96
97        let stats = vec![LatencyStats::new(); providers.len()];
98        Ok(Self {
99            providers,
100            stats: Mutex::new(stats),
101            counter: AtomicUsize::new(0),
102            config,
103        })
104    }
105
106    /// Return the number of configured providers.
107    pub fn provider_count(&self) -> usize {
108        self.providers.len()
109    }
110
111    /// Execute a completion request using latency-based selection.
112    pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
113        let index = self.select_provider_index()?;
114        let provider = &self.providers[index];
115        let start = Instant::now();
116        let provider_request = provider.transform_request(request)?;
117        let provider_response = provider.execute(provider_request).await?;
118        let response = provider.transform_response(provider_response)?;
119        self.record_latency(index, start.elapsed());
120        Ok(response)
121    }
122
123    /// Execute a streaming request using latency-based selection.
124    pub async fn stream(
125        &self,
126        request: &CompletionRequest,
127    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
128        let index = self.select_provider_index()?;
129        let provider = &self.providers[index];
130        let provider_request = provider.transform_request(request)?;
131        provider.execute_stream(provider_request).await
132    }
133
134    fn select_provider_index(&self) -> Result<usize> {
135        let len = self.providers.len();
136        if len == 0 {
137            return Err(SimpleAgentsError::Routing(
138                "no providers configured".to_string(),
139            ));
140        }
141
142        let stats = self
143            .stats
144            .lock()
145            .unwrap_or_else(|poisoned| poisoned.into_inner());
146        let mut best_index: Option<usize> = None;
147        let mut best_latency = f64::MAX;
148        let mut has_samples = false;
149        let mut has_healthy = false;
150
151        for stat in stats.iter() {
152            if stat.samples == 0 {
153                continue;
154            }
155            has_samples = true;
156            if stat.health == ProviderHealth::Healthy {
157                has_healthy = true;
158            }
159        }
160
161        if has_samples {
162            for (index, stat) in stats.iter().enumerate() {
163                if stat.samples == 0 {
164                    continue;
165                }
166                if has_healthy && stat.health != ProviderHealth::Healthy {
167                    continue;
168                }
169                if stat.avg_latency_ms < best_latency {
170                    best_latency = stat.avg_latency_ms;
171                    best_index = Some(index);
172                }
173            }
174        }
175
176        if let Some(index) = best_index {
177            return Ok(index);
178        }
179
180        let index = self.counter.fetch_add(1, Ordering::Relaxed);
181        Ok(index % len)
182    }
183
184    fn record_latency(&self, index: usize, latency: Duration) {
185        let mut stats = self
186            .stats
187            .lock()
188            .unwrap_or_else(|poisoned| poisoned.into_inner());
189        if let Some(stat) = stats.get_mut(index) {
190            stat.record(latency, self.config.alpha, self.config.slow_threshold);
191        }
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use simple_agent_type::prelude::*;

    /// Provider that always succeeds with a fixed one-choice response.
    struct MockProvider {
        name: &'static str,
    }

    impl MockProvider {
        fn new(name: &'static str) -> Self {
            Self { name }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            Ok(ProviderResponse::new(200, serde_json::Value::Null))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name().to_string()),
                healing_metadata: None,
            })
        }
    }

    /// Builds one mock provider per name, in order.
    fn providers(names: &[&'static str]) -> Vec<Arc<dyn Provider>> {
        names
            .iter()
            .map(|&name| Arc::new(MockProvider::new(name)) as Arc<dyn Provider>)
            .collect()
    }

    fn build_request() -> CompletionRequest {
        CompletionRequest::builder()
            .model("test-model")
            .message(Message::user("hello"))
            .build()
            .unwrap()
    }

    #[test]
    fn empty_router_returns_error() {
        let result = LatencyRouter::new(Vec::new());
        match result {
            Ok(_) => panic!("expected error, got Ok"),
            Err(SimpleAgentsError::Routing(message)) => {
                assert_eq!(message, "no providers configured");
            }
            Err(_) => panic!("unexpected error type"),
        }
    }

    #[test]
    fn selects_lowest_latency_provider() {
        let router = LatencyRouter::new(providers(&["p1", "p2"])).unwrap();

        router.record_latency(0, Duration::from_millis(250));
        router.record_latency(1, Duration::from_millis(50));

        assert_eq!(router.select_provider_index().unwrap(), 1);
    }

    #[test]
    fn prefers_healthy_over_degraded() {
        let config = LatencyRouterConfig {
            alpha: 1.0,
            slow_threshold: Duration::from_millis(100),
        };
        let router = LatencyRouter::with_config(providers(&["p1", "p2"]), config).unwrap();

        router.record_latency(0, Duration::from_millis(400));
        router.record_latency(1, Duration::from_millis(80));

        assert_eq!(router.select_provider_index().unwrap(), 1);
    }

    #[test]
    fn falls_back_to_fastest_degraded_when_none_healthy() {
        let config = LatencyRouterConfig {
            alpha: 1.0,
            slow_threshold: Duration::from_millis(100),
        };
        let router = LatencyRouter::with_config(providers(&["p1", "p2"]), config).unwrap();

        router.record_latency(0, Duration::from_millis(400));
        router.record_latency(1, Duration::from_millis(300));

        assert_eq!(router.select_provider_index().unwrap(), 1);
    }

    #[test]
    fn moving_average_blends_samples() {
        let config = LatencyRouterConfig {
            alpha: 0.5,
            slow_threshold: Duration::from_secs(2),
        };
        let router = LatencyRouter::with_config(providers(&["p1"]), config).unwrap();

        router.record_latency(0, Duration::from_millis(100));
        router.record_latency(0, Duration::from_millis(300));

        let stats = router.stats.lock().expect("latency stats lock poisoned");
        assert_eq!(stats[0].samples, 2);
        // 0.5 * 300 + 0.5 * 100
        assert!((stats[0].avg_latency_ms - 200.0).abs() < 1e-9);
        assert_eq!(stats[0].health, ProviderHealth::Healthy);
    }

    #[test]
    fn degraded_provider_recovers_after_fast_sample() {
        let config = LatencyRouterConfig {
            alpha: 1.0,
            slow_threshold: Duration::from_millis(100),
        };
        let router = LatencyRouter::with_config(providers(&["p1"]), config).unwrap();

        router.record_latency(0, Duration::from_millis(400));
        assert_eq!(
            router.stats.lock().expect("latency stats lock poisoned")[0].health,
            ProviderHealth::Degraded
        );

        router.record_latency(0, Duration::from_millis(50));
        assert_eq!(
            router.stats.lock().expect("latency stats lock poisoned")[0].health,
            ProviderHealth::Healthy
        );
    }

    #[test]
    fn round_robin_when_no_metrics() {
        let router = LatencyRouter::new(providers(&["p1", "p2"])).unwrap();

        let first = router.select_provider_index().unwrap();
        let second = router.select_provider_index().unwrap();

        assert_eq!(first, 0);
        assert_eq!(second, 1);
    }

    #[tokio::test]
    async fn records_latency_on_success() {
        let router = LatencyRouter::new(providers(&["p1"])).unwrap();
        let request = build_request();

        router.complete(&request).await.unwrap();

        let stats = router.stats.lock().expect("latency stats lock poisoned");
        assert_eq!(stats[0].samples, 1);
    }
}
324}