1use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
10use std::time::{Duration, SystemTime, UNIX_EPOCH};
11use tracing::{debug, info, warn};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct AuthProfile {
16 pub name: String,
18 pub provider: String,
20 pub api_key: String,
22 pub base_url: Option<String>,
24 #[serde(default = "default_true")]
26 pub enabled: bool,
27 #[serde(default)]
29 pub rate_limit_rpm: u32,
30}
31
/// Serde default hook that yields `true`; used for `AuthProfile::enabled` so a
/// profile without an explicit `enabled` field is treated as enabled.
fn default_true() -> bool {
    const ENABLED_WHEN_ABSENT: bool = true;
    ENABLED_WHEN_ABSENT
}
35
/// Atomic health counters and circuit-breaker state for a single profile.
///
/// Every mutable field is an atomic, so the tracker is updated through `&self`.
#[derive(Debug)]
pub struct HealthTracker {
    /// Failures recorded since the most recent success.
    consecutive_failures: AtomicU64,
    /// Count of all recorded outcomes (successes and failures).
    total_requests: AtomicU64,
    /// Count of all recorded failures.
    total_failures: AtomicU64,
    /// Unix time in seconds at which the circuit was last opened; `0` means
    /// the circuit is closed.
    circuit_open_since: AtomicU64,
    /// Number of seconds the circuit stays open before it is reported as
    /// usable again.
    circuit_break_secs: u64,
}
50
51impl HealthTracker {
52 pub fn new(circuit_break_secs: u64) -> Self {
53 Self {
54 consecutive_failures: AtomicU64::new(0),
55 total_requests: AtomicU64::new(0),
56 total_failures: AtomicU64::new(0),
57 circuit_open_since: AtomicU64::new(0),
58 circuit_break_secs,
59 }
60 }
61
62 pub fn record_success(&self) {
64 self.total_requests.fetch_add(1, Ordering::Relaxed);
65 self.consecutive_failures.store(0, Ordering::Relaxed);
66 self.circuit_open_since.store(0, Ordering::Relaxed);
68 }
69
70 pub fn record_failure(&self) {
72 self.total_requests.fetch_add(1, Ordering::Relaxed);
73 self.total_failures.fetch_add(1, Ordering::Relaxed);
74 let failures = self.consecutive_failures.fetch_add(1, Ordering::Relaxed) + 1;
75
76 if failures >= 3 {
78 let now = SystemTime::now()
79 .duration_since(UNIX_EPOCH)
80 .unwrap_or_default()
81 .as_secs();
82 self.circuit_open_since.store(now, Ordering::Relaxed);
83 warn!(
84 consecutive_failures = failures,
85 "Circuit breaker opened after {} consecutive failures",
86 failures
87 );
88 }
89 }
90
91 pub fn is_circuit_open(&self) -> bool {
93 let opened_at = self.circuit_open_since.load(Ordering::Relaxed);
94 if opened_at == 0 {
95 return false;
96 }
97
98 let now = SystemTime::now()
99 .duration_since(UNIX_EPOCH)
100 .unwrap_or_default()
101 .as_secs();
102 let elapsed = now.saturating_sub(opened_at);
103
104 if elapsed >= self.circuit_break_secs {
106 debug!(
107 elapsed_secs = elapsed,
108 "Circuit breaker entering half-open state"
109 );
110 return false;
111 }
112
113 true
114 }
115
116 pub fn stats(&self) -> HealthStats {
118 HealthStats {
119 total_requests: self.total_requests.load(Ordering::Relaxed),
120 total_failures: self.total_failures.load(Ordering::Relaxed),
121 consecutive_failures: self.consecutive_failures.load(Ordering::Relaxed),
122 circuit_open: self.is_circuit_open(),
123 }
124 }
125}
126
127#[derive(Debug, Clone, Serialize)]
129pub struct HealthStats {
130 pub total_requests: u64,
131 pub total_failures: u64,
132 pub consecutive_failures: u64,
133 pub circuit_open: bool,
134}
135
136#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
138#[serde(rename_all = "snake_case")]
139pub enum FailoverStrategy {
140 Sequential,
142 RoundRobin,
144 Random,
146}
147
148impl Default for FailoverStrategy {
149 fn default() -> Self {
150 Self::Sequential
151 }
152}
153
154#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct FailoverConfig {
157 #[serde(default)]
159 pub profiles: Vec<AuthProfile>,
160
161 #[serde(default)]
163 pub strategy: FailoverStrategy,
164
165 #[serde(default)]
167 pub fallback_models: Vec<String>,
168
169 #[serde(default = "default_circuit_break")]
171 pub circuit_break_secs: u64,
172
173 #[serde(default = "default_max_retries")]
175 pub max_retries: u32,
176}
177
/// Serde default for `FailoverConfig::circuit_break_secs`: 60 seconds.
fn default_circuit_break() -> u64 {
    const DEFAULT_CIRCUIT_BREAK_SECS: u64 = 60;
    DEFAULT_CIRCUIT_BREAK_SECS
}
181
/// Serde default for `FailoverConfig::max_retries`: 3 attempts.
fn default_max_retries() -> u32 {
    const DEFAULT_MAX_RETRIES: u32 = 3;
    DEFAULT_MAX_RETRIES
}
185
186impl Default for FailoverConfig {
187 fn default() -> Self {
188 Self {
189 profiles: Vec::new(),
190 strategy: FailoverStrategy::default(),
191 fallback_models: Vec::new(),
192 circuit_break_secs: default_circuit_break(),
193 max_retries: default_max_retries(),
194 }
195 }
196}
197
198pub struct FailoverManager {
202 configs: HashMap<String, FailoverConfig>,
204 health: HashMap<String, HealthTracker>,
206 rr_index: HashMap<String, AtomicUsize>,
208}
209
210impl FailoverManager {
211 pub fn new() -> Self {
213 Self {
214 configs: HashMap::new(),
215 health: HashMap::new(),
216 rr_index: HashMap::new(),
217 }
218 }
219
220 pub fn register(&mut self, provider: String, config: FailoverConfig) {
222 let circuit_secs = config.circuit_break_secs;
223
224 for profile in &config.profiles {
226 let key = format!("{}/{}", provider, profile.name);
227 self.health
228 .insert(key, HealthTracker::new(circuit_secs));
229 }
230
231 self.rr_index
232 .insert(provider.clone(), AtomicUsize::new(0));
233 self.configs.insert(provider, config);
234
235 info!("Failover config registered");
236 }
237
238 pub fn select_profile(&self, provider: &str) -> Option<&AuthProfile> {
242 let config = self.configs.get(provider)?;
243 let healthy: Vec<_> = config
244 .profiles
245 .iter()
246 .filter(|p| {
247 if !p.enabled {
248 return false;
249 }
250 let key = format!("{}/{}", provider, p.name);
251 if let Some(tracker) = self.health.get(&key) {
252 !tracker.is_circuit_open()
253 } else {
254 true
255 }
256 })
257 .collect();
258
259 if healthy.is_empty() {
260 warn!(provider = %provider, "No healthy profiles available");
261 return None;
262 }
263
264 match config.strategy {
265 FailoverStrategy::Sequential => Some(healthy[0]),
266 FailoverStrategy::RoundRobin => {
267 if let Some(idx) = self.rr_index.get(provider) {
268 let i = idx.fetch_add(1, Ordering::Relaxed) % healthy.len();
269 Some(healthy[i])
270 } else {
271 Some(healthy[0])
272 }
273 }
274 FailoverStrategy::Random => {
275 use std::collections::hash_map::DefaultHasher;
276 use std::hash::{Hash, Hasher};
277
278 let now = SystemTime::now()
279 .duration_since(UNIX_EPOCH)
280 .unwrap_or(Duration::ZERO)
281 .as_nanos();
282 let mut hasher = DefaultHasher::new();
283 now.hash(&mut hasher);
284 let i = hasher.finish() as usize % healthy.len();
285 Some(healthy[i])
286 }
287 }
288 }
289
290 pub fn record_success(&self, provider: &str, profile_name: &str) {
292 let key = format!("{}/{}", provider, profile_name);
293 if let Some(tracker) = self.health.get(&key) {
294 tracker.record_success();
295 }
296 }
297
298 pub fn record_failure(&self, provider: &str, profile_name: &str) {
300 let key = format!("{}/{}", provider, profile_name);
301 if let Some(tracker) = self.health.get(&key) {
302 tracker.record_failure();
303 }
304 }
305
306 pub fn fallback_models(&self, provider: &str) -> &[String] {
308 self.configs
309 .get(provider)
310 .map(|c| c.fallback_models.as_slice())
311 .unwrap_or(&[])
312 }
313
314 pub fn health_report(&self) -> HashMap<String, HealthStats> {
316 self.health
317 .iter()
318 .map(|(key, tracker)| (key.clone(), tracker.stats()))
319 .collect()
320 }
321
322 pub fn has_failover(&self, provider: &str) -> bool {
324 self.configs
325 .get(provider)
326 .map(|c| c.profiles.len() > 1 || !c.fallback_models.is_empty())
327 .unwrap_or(false)
328 }
329}
330
331impl Default for FailoverManager {
332 fn default() -> Self {
333 Self::new()
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an enabled profile with a key derived from its name, no base
    /// URL and a zero rate limit.
    fn make_profile(name: &str, provider: &str) -> AuthProfile {
        AuthProfile {
            name: name.to_owned(),
            provider: provider.to_owned(),
            api_key: format!("key-{}", name),
            base_url: None,
            enabled: true,
            rate_limit_rpm: 0,
        }
    }

    /// Returns a manager with `provider` registered, one profile per entry in
    /// `names`, using `strategy` and defaults for everything else.
    fn manager_with(
        provider: &str,
        names: &[&str],
        strategy: FailoverStrategy,
    ) -> FailoverManager {
        let profiles = names
            .iter()
            .map(|name| make_profile(name, provider))
            .collect();

        let mut manager = FailoverManager::new();
        manager.register(
            provider.to_owned(),
            FailoverConfig {
                profiles,
                strategy,
                ..Default::default()
            },
        );
        manager
    }

    #[test]
    fn test_health_tracker_success_resets_failures() {
        let tracker = HealthTracker::new(60);
        for _ in 0..2 {
            tracker.record_failure();
        }
        assert_eq!(tracker.stats().consecutive_failures, 2);

        tracker.record_success();
        assert_eq!(tracker.stats().consecutive_failures, 0);
    }

    #[test]
    fn test_circuit_opens_after_3_failures() {
        let tracker = HealthTracker::new(60);
        for _ in 0..2 {
            tracker.record_failure();
        }
        assert!(!tracker.is_circuit_open());

        // The third consecutive failure trips the breaker.
        tracker.record_failure();
        assert!(tracker.is_circuit_open());
    }

    #[test]
    fn test_failover_sequential() {
        let mgr = manager_with(
            "openai",
            &["primary", "backup"],
            FailoverStrategy::Sequential,
        );

        let chosen = mgr.select_profile("openai").unwrap();
        assert_eq!(chosen.name, "primary");
    }

    #[test]
    fn test_failover_round_robin() {
        let mgr = manager_with("openai", &["a", "b"], FailoverStrategy::RoundRobin);

        let first = mgr.select_profile("openai").unwrap().name.clone();
        let second = mgr.select_profile("openai").unwrap().name.clone();
        assert_ne!(first, second);
    }

    #[test]
    fn test_failover_skips_broken_circuit() {
        let mgr = manager_with(
            "openai",
            &["primary", "backup"],
            FailoverStrategy::Sequential,
        );

        // Trip the primary's breaker so selection must fall through.
        (0..3).for_each(|_| mgr.record_failure("openai", "primary"));

        let chosen = mgr.select_profile("openai").unwrap();
        assert_eq!(chosen.name, "backup");
    }

    #[test]
    fn test_fallback_models() {
        let mut mgr = FailoverManager::new();
        let config = FailoverConfig {
            profiles: vec![make_profile("main", "anthropic")],
            fallback_models: vec!["openai/gpt-4.1".to_string()],
            ..Default::default()
        };
        mgr.register("anthropic".to_string(), config);

        assert_eq!(mgr.fallback_models("anthropic"), &["openai/gpt-4.1"]);
        assert!(mgr.has_failover("anthropic"));
    }

    #[test]
    fn test_no_failover() {
        let mgr = FailoverManager::new();
        assert!(!mgr.has_failover("missing"));
        assert!(mgr.select_profile("missing").is_none());
    }

    #[test]
    fn test_health_report() {
        let mgr = manager_with("openai", &["main"], FailoverStrategy::default());

        mgr.record_success("openai", "main");
        mgr.record_success("openai", "main");
        mgr.record_failure("openai", "main");

        let report = mgr.health_report();
        let stats = report.get("openai/main").unwrap();
        assert_eq!(stats.total_requests, 3);
        assert_eq!(stats.total_failures, 1);
        assert_eq!(stats.consecutive_failures, 1);
    }
}