use simple_agent_type::prelude::{
    CompletionChunk, CompletionRequest, CompletionResponse, Provider, ProviderHealth, Result,
    SimpleAgentsError,
};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

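/// Configuration for [`LatencyRouter`].
///
/// `alpha` is the smoothing factor for the exponentially weighted moving
/// average of per-provider latency (higher values weight recent samples more
/// heavily), and `slow_threshold` is the average latency at which a provider
/// is marked [`ProviderHealth::Degraded`].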
#[derive(Debug, Clone)]
pub struct LatencyRouterConfig {
    pub alpha: f64,
    pub slow_threshold: Duration,
}

impl Default for LatencyRouterConfig {
    fn default() -> Self {
        Self {
            alpha: 0.2,
            slow_threshold: Duration::from_secs(2),
        }
    }
}

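/// Per-provider moving-average latency state tracked by the router.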
#[derive(Clone, Copy, Debug)]
struct LatencyStats {
    avg_latency_ms: f64,
    samples: u64,
    health: ProviderHealth,
}

impl LatencyStats {
    fn new() -> Self {
        Self {
            avg_latency_ms: 0.0,
            samples: 0,
            health: ProviderHealth::Healthy,
        }
    }

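    /// Folds one latency sample into the moving average and refreshes health.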
    fn record(&mut self, latency: Duration, alpha: f64, slow_threshold: Duration) {
        let latency_ms = latency.as_secs_f64() * 1000.0;
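        // Exponentially weighted moving average: the first sample seeds the
        // average; later samples are blended in as
        // avg = alpha * sample + (1 - alpha) * avg.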
        if self.samples == 0 {
            self.avg_latency_ms = latency_ms;
        } else {
            let previous = self.avg_latency_ms;
            self.avg_latency_ms = (alpha * latency_ms) + ((1.0 - alpha) * previous);
        }
        self.samples = self.samples.saturating_add(1);

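        // Mark the provider as degraded once its average latency reaches the
        // slow threshold; it flips back to healthy when the average recovers.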
        let threshold_ms = slow_threshold.as_secs_f64() * 1000.0;
        self.health = if self.avg_latency_ms >= threshold_ms {
            ProviderHealth::Degraded
        } else {
            ProviderHealth::Healthy
        };
    }
}

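/// Routes completion requests to the provider with the lowest observed
/// average latency.
///
/// Providers with no recorded samples are served round-robin until latency
/// data exists. When at least one sampled provider is healthy, degraded
/// providers (those above the slow threshold) are skipped.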
pub struct LatencyRouter {
    providers: Vec<Arc<dyn Provider>>,
    stats: Mutex<Vec<LatencyStats>>,
    counter: AtomicUsize,
    config: LatencyRouterConfig,
}

impl LatencyRouter {
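    /// Creates a router over `providers` using the default
    /// [`LatencyRouterConfig`].
    ///
    /// Returns a routing error if `providers` is empty.
    ///
    /// A minimal usage sketch, assuming a hypothetical `my_provider()` that
    /// returns a [`Provider`] implementation and a `request` built elsewhere:
    ///
    /// ```ignore
    /// let router = LatencyRouter::new(vec![Arc::new(my_provider())])?;
    /// let response = router.complete(&request).await?;
    /// ```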
    pub fn new(providers: Vec<Arc<dyn Provider>>) -> Result<Self> {
        Self::with_config(providers, LatencyRouterConfig::default())
    }

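    /// Creates a router over `providers` with an explicit `config`.
    ///
    /// Returns a routing error if `providers` is empty.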
    pub fn with_config(
        providers: Vec<Arc<dyn Provider>>,
        config: LatencyRouterConfig,
    ) -> Result<Self> {
        if providers.is_empty() {
            return Err(SimpleAgentsError::Routing(
                "no providers configured".to_string(),
            ));
        }

        let stats = vec![LatencyStats::new(); providers.len()];
        Ok(Self {
            providers,
            stats: Mutex::new(stats),
            counter: AtomicUsize::new(0),
            config,
        })
    }

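    /// Returns the number of configured providers.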
    pub fn provider_count(&self) -> usize {
        self.providers.len()
    }

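    /// Sends `request` to the currently selected provider and records the
    /// round-trip latency of the call on success.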
    pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
        let index = self.select_provider_index()?;
        let provider = &self.providers[index];
        let start = Instant::now();
        let provider_request = provider.transform_request(request)?;
        let provider_response = provider.execute(provider_request).await?;
        let response = provider.transform_response(provider_response)?;
        self.record_latency(index, start.elapsed());
        Ok(response)
    }

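    /// Streams a completion from the currently selected provider.
    ///
    /// Uses the same provider selection as [`Self::complete`], but does not
    /// record a latency sample for the call.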
    pub async fn stream(
        &self,
        request: &CompletionRequest,
    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
        let index = self.select_provider_index()?;
        let provider = &self.providers[index];
        let provider_request = provider.transform_request(request)?;
        provider.execute_stream(provider_request).await
    }

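    /// Picks the index of the provider with the lowest average latency,
    /// preferring healthy providers when any sampled provider is healthy.
    /// Falls back to round-robin while no provider has latency samples.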
    fn select_provider_index(&self) -> Result<usize> {
        let len = self.providers.len();
        if len == 0 {
            return Err(SimpleAgentsError::Routing(
                "no providers configured".to_string(),
            ));
        }

        let stats = self.stats.lock().expect("latency stats lock poisoned");
        let mut best_index: Option<usize> = None;
        let mut best_latency = f64::MAX;
        let mut has_samples = false;
        let mut has_healthy = false;

        for stat in stats.iter() {
            if stat.samples == 0 {
                continue;
            }
            has_samples = true;
            if stat.health == ProviderHealth::Healthy {
                has_healthy = true;
            }
        }

        if has_samples {
            for (index, stat) in stats.iter().enumerate() {
                if stat.samples == 0 {
                    continue;
                }
                if has_healthy && stat.health != ProviderHealth::Healthy {
                    continue;
                }
                if stat.avg_latency_ms < best_latency {
                    best_latency = stat.avg_latency_ms;
                    best_index = Some(index);
                }
            }
        }

        if let Some(index) = best_index {
            return Ok(index);
        }

        let index = self.counter.fetch_add(1, Ordering::Relaxed);
        Ok(index % len)
    }

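    /// Records a latency sample for the provider at `index`; indexes outside
    /// the provider list are ignored.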
    fn record_latency(&self, index: usize, latency: Duration) {
        let mut stats = self.stats.lock().expect("latency stats lock poisoned");
        if let Some(stat) = stats.get_mut(index) {
            stat.record(latency, self.config.alpha, self.config.slow_threshold);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use simple_agent_type::prelude::*;

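    /// Minimal in-memory [`Provider`] stub that returns a canned response.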
    struct MockProvider {
        name: &'static str,
    }

    impl MockProvider {
        fn new(name: &'static str) -> Self {
            Self { name }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            Ok(ProviderResponse::new(200, serde_json::Value::Null))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![CompletionChoice {
                    index: 0,
                    message: Message::assistant("ok"),
                    finish_reason: FinishReason::Stop,
                    logprobs: None,
                }],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name().to_string()),
                healing_metadata: None,
            })
        }
    }

    fn build_request() -> CompletionRequest {
        CompletionRequest::builder()
            .model("test-model")
            .message(Message::user("hello"))
            .build()
            .unwrap()
    }

    #[test]
    fn empty_router_returns_error() {
        let result = LatencyRouter::new(Vec::new());
        match result {
            Ok(_) => panic!("expected error, got Ok"),
            Err(SimpleAgentsError::Routing(message)) => {
                assert_eq!(message, "no providers configured");
            }
            Err(_) => panic!("unexpected error type"),
        }
    }

    #[test]
    fn selects_lowest_latency_provider() {
        let router = LatencyRouter::new(vec![
            Arc::new(MockProvider::new("p1")),
            Arc::new(MockProvider::new("p2")),
        ])
        .unwrap();

        router.record_latency(0, Duration::from_millis(250));
        router.record_latency(1, Duration::from_millis(50));

        let index = router.select_provider_index().unwrap();
        assert_eq!(index, 1);
    }

    #[test]
    fn prefers_healthy_over_degraded() {
        let config = LatencyRouterConfig {
            alpha: 1.0,
            slow_threshold: Duration::from_millis(100),
        };
        let router = LatencyRouter::with_config(
            vec![
                Arc::new(MockProvider::new("p1")),
                Arc::new(MockProvider::new("p2")),
            ],
            config,
        )
        .unwrap();

        router.record_latency(0, Duration::from_millis(400));
        router.record_latency(1, Duration::from_millis(80));

        let index = router.select_provider_index().unwrap();
        assert_eq!(index, 1);
    }

    #[test]
    fn round_robin_when_no_metrics() {
        let router = LatencyRouter::new(vec![
            Arc::new(MockProvider::new("p1")),
            Arc::new(MockProvider::new("p2")),
        ])
        .unwrap();

        let first = router.select_provider_index().unwrap();
        let second = router.select_provider_index().unwrap();

        assert_eq!(first, 0);
        assert_eq!(second, 1);
    }

    #[tokio::test]
    async fn records_latency_on_success() {
        let router = LatencyRouter::new(vec![Arc::new(MockProvider::new("p1"))]).unwrap();
        let request = build_request();

        let _ = router.complete(&request).await.unwrap();
        let stats = router.stats.lock().expect("latency stats lock poisoned");
        assert_eq!(stats[0].samples, 1);
    }
}