1use simple_agent_type::prelude::{
6 CompletionChunk, CompletionRequest, CompletionResponse, Provider, ProviderHealth, Result,
7 SimpleAgentsError,
8};
9use std::sync::atomic::{AtomicUsize, Ordering};
10use std::sync::{Arc, Mutex};
11use std::time::{Duration, Instant};
12
/// Tuning parameters for [`LatencyRouter`].
#[derive(Debug, Clone)]
pub struct LatencyRouterConfig {
    /// EWMA smoothing factor; higher values weight recent samples more heavily.
    pub alpha: f64,
    /// Average latency at or above this duration marks a provider as degraded.
    pub slow_threshold: Duration,
}
21
22impl Default for LatencyRouterConfig {
23 fn default() -> Self {
24 Self {
25 alpha: 0.2,
26 slow_threshold: Duration::from_secs(2),
27 }
28 }
29}
30
/// Per-provider rolling latency statistics.
#[derive(Clone, Copy, Debug)]
struct LatencyStats {
    // Exponentially weighted moving average of observed latency, in milliseconds.
    avg_latency_ms: f64,
    // Number of latency samples folded in so far; 0 means never measured.
    samples: u64,
    // Judgement derived from the average vs. the configured slow threshold.
    health: ProviderHealth,
}
37
38impl LatencyStats {
39 fn new() -> Self {
40 Self {
41 avg_latency_ms: 0.0,
42 samples: 0,
43 health: ProviderHealth::Healthy,
44 }
45 }
46
47 fn record(&mut self, latency: Duration, alpha: f64, slow_threshold: Duration) {
48 let latency_ms = latency.as_secs_f64() * 1000.0;
49 if self.samples == 0 {
50 self.avg_latency_ms = latency_ms;
51 } else {
52 let previous = self.avg_latency_ms;
53 self.avg_latency_ms = (alpha * latency_ms) + ((1.0 - alpha) * previous);
54 }
55 self.samples = self.samples.saturating_add(1);
56
57 let threshold_ms = slow_threshold.as_secs_f64() * 1000.0;
58 self.health = if self.avg_latency_ms >= threshold_ms {
59 ProviderHealth::Degraded
60 } else {
61 ProviderHealth::Healthy
62 };
63 }
64}
65
/// Routes completion requests to the provider with the lowest observed latency.
pub struct LatencyRouter {
    providers: Vec<Arc<dyn Provider>>,
    // One stats slot per provider, index-aligned with `providers`.
    stats: Mutex<Vec<LatencyStats>>,
    // Monotonic counter driving round-robin fallback selection.
    counter: AtomicUsize,
    config: LatencyRouterConfig,
}
73
74impl LatencyRouter {
75 pub fn new(providers: Vec<Arc<dyn Provider>>) -> Result<Self> {
80 Self::with_config(providers, LatencyRouterConfig::default())
81 }
82
83 pub fn with_config(
88 providers: Vec<Arc<dyn Provider>>,
89 config: LatencyRouterConfig,
90 ) -> Result<Self> {
91 if providers.is_empty() {
92 return Err(SimpleAgentsError::Routing(
93 "no providers configured".to_string(),
94 ));
95 }
96
97 let stats = vec![LatencyStats::new(); providers.len()];
98 Ok(Self {
99 providers,
100 stats: Mutex::new(stats),
101 counter: AtomicUsize::new(0),
102 config,
103 })
104 }
105
106 pub fn provider_count(&self) -> usize {
108 self.providers.len()
109 }
110
111 pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
113 let index = self.select_provider_index()?;
114 let provider = &self.providers[index];
115 let start = Instant::now();
116 let provider_request = provider.transform_request(request)?;
117 let provider_response = provider.execute(provider_request).await?;
118 let response = provider.transform_response(provider_response)?;
119 self.record_latency(index, start.elapsed());
120 Ok(response)
121 }
122
123 pub async fn stream(
125 &self,
126 request: &CompletionRequest,
127 ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
128 let index = self.select_provider_index()?;
129 let provider = &self.providers[index];
130 let provider_request = provider.transform_request(request)?;
131 provider.execute_stream(provider_request).await
132 }
133
134 fn select_provider_index(&self) -> Result<usize> {
135 let len = self.providers.len();
136 if len == 0 {
137 return Err(SimpleAgentsError::Routing(
138 "no providers configured".to_string(),
139 ));
140 }
141
142 let stats = self
143 .stats
144 .lock()
145 .unwrap_or_else(|poisoned| poisoned.into_inner());
146 let mut best_index: Option<usize> = None;
147 let mut best_latency = f64::MAX;
148 let mut has_samples = false;
149 let mut has_healthy = false;
150
151 for stat in stats.iter() {
152 if stat.samples == 0 {
153 continue;
154 }
155 has_samples = true;
156 if stat.health == ProviderHealth::Healthy {
157 has_healthy = true;
158 }
159 }
160
161 if has_samples {
162 for (index, stat) in stats.iter().enumerate() {
163 if stat.samples == 0 {
164 continue;
165 }
166 if has_healthy && stat.health != ProviderHealth::Healthy {
167 continue;
168 }
169 if stat.avg_latency_ms < best_latency {
170 best_latency = stat.avg_latency_ms;
171 best_index = Some(index);
172 }
173 }
174 }
175
176 if let Some(index) = best_index {
177 return Ok(index);
178 }
179
180 let index = self.counter.fetch_add(1, Ordering::Relaxed);
181 Ok(index % len)
182 }
183
184 fn record_latency(&self, index: usize, latency: Duration) {
185 let mut stats = self
186 .stats
187 .lock()
188 .unwrap_or_else(|poisoned| poisoned.into_inner());
189 if let Some(stat) = stats.get_mut(index) {
190 stat.record(latency, self.config.alpha, self.config.slow_threshold);
191 }
192 }
193}
194
195#[cfg(test)]
196mod tests {
197 use super::*;
198 use async_trait::async_trait;
199 use simple_agent_type::prelude::*;
200
    /// Minimal `Provider` stub that returns canned responses instantly.
    struct MockProvider {
        name: &'static str,
    }
204
    impl MockProvider {
        /// Creates a mock provider that reports `name`.
        fn new(name: &'static str) -> Self {
            Self { name }
        }
    }
210
    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        // Ignores the request contents and targets a fixed dummy endpoint.
        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        // Always succeeds immediately with an empty 200 response.
        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            Ok(ProviderResponse::new(200, serde_json::Value::Null))
        }

        // Produces a fixed single-choice completion tagged with this
        // provider's name so tests can tell which provider answered.
        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name().to_string()),
                healing_metadata: None,
            })
        }
    }
242
243 fn build_request() -> CompletionRequest {
244 CompletionRequest::builder()
245 .model("test-model")
246 .message(Message::user("hello"))
247 .build()
248 .unwrap()
249 }
250
251 #[test]
252 fn empty_router_returns_error() {
253 let result = LatencyRouter::new(Vec::new());
254 match result {
255 Ok(_) => panic!("expected error, got Ok"),
256 Err(SimpleAgentsError::Routing(message)) => {
257 assert_eq!(message, "no providers configured");
258 }
259 Err(_) => panic!("unexpected error type"),
260 }
261 }
262
263 #[test]
264 fn selects_lowest_latency_provider() {
265 let router = LatencyRouter::new(vec![
266 Arc::new(MockProvider::new("p1")),
267 Arc::new(MockProvider::new("p2")),
268 ])
269 .unwrap();
270
271 router.record_latency(0, Duration::from_millis(250));
272 router.record_latency(1, Duration::from_millis(50));
273
274 let index = router.select_provider_index().unwrap();
275 assert_eq!(index, 1);
276 }
277
278 #[test]
279 fn prefers_healthy_over_degraded() {
280 let config = LatencyRouterConfig {
281 alpha: 1.0,
282 slow_threshold: Duration::from_millis(100),
283 };
284 let router = LatencyRouter::with_config(
285 vec![
286 Arc::new(MockProvider::new("p1")),
287 Arc::new(MockProvider::new("p2")),
288 ],
289 config,
290 )
291 .unwrap();
292
293 router.record_latency(0, Duration::from_millis(400));
294 router.record_latency(1, Duration::from_millis(80));
295
296 let index = router.select_provider_index().unwrap();
297 assert_eq!(index, 1);
298 }
299
300 #[test]
301 fn round_robin_when_no_metrics() {
302 let router = LatencyRouter::new(vec![
303 Arc::new(MockProvider::new("p1")),
304 Arc::new(MockProvider::new("p2")),
305 ])
306 .unwrap();
307
308 let first = router.select_provider_index().unwrap();
309 let second = router.select_provider_index().unwrap();
310
311 assert_eq!(first, 0);
312 assert_eq!(second, 1);
313 }
314
315 #[tokio::test]
316 async fn records_latency_on_success() {
317 let router = LatencyRouter::new(vec![Arc::new(MockProvider::new("p1"))]).unwrap();
318 let request = build_request();
319
320 let _ = router.complete(&request).await.unwrap();
321 let stats = router.stats.lock().expect("latency stats lock poisoned");
322 assert_eq!(stats[0].samples, 1);
323 }
324}