Skip to main content

rustyclaw_core/models/
failover.rs

1//! Model failover and auth profile rotation.
2//!
3//! Provides automatic failover when a model provider returns errors, and
4//! rotation through multiple auth profiles (API keys) for load distribution
5//! and rate-limit avoidance.
6
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
10use std::time::{Duration, SystemTime, UNIX_EPOCH};
11use tracing::{debug, info, warn};
12
13/// An auth profile — a named set of credentials for a provider.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct AuthProfile {
16    /// Profile name (e.g., "primary", "backup", "org-key").
17    pub name: String,
18    /// Provider this profile is for (e.g., "openai", "anthropic").
19    pub provider: String,
20    /// API key or token.
21    pub api_key: String,
22    /// Optional base URL override.
23    pub base_url: Option<String>,
24    /// Whether this profile is enabled.
25    #[serde(default = "default_true")]
26    pub enabled: bool,
27    /// Max requests per minute (0 = unlimited).
28    #[serde(default)]
29    pub rate_limit_rpm: u32,
30}
31
/// Serde default provider for boolean fields that should be on when the
/// key is absent from the input.
fn default_true() -> bool {
    true
}
35
/// Health bookkeeping for a single model/profile combination.
///
/// All counters are atomics, so they can be updated through a shared
/// reference.
#[derive(Debug)]
pub struct HealthTracker {
    /// Number of failures recorded since the last success.
    consecutive_failures: AtomicU64,
    /// Number of requests recorded, successful or not.
    total_requests: AtomicU64,
    /// Number of failed requests recorded.
    total_failures: AtomicU64,
    /// Unix timestamp (seconds) at which the circuit was opened;
    /// 0 means the circuit is closed.
    circuit_open_since: AtomicU64,
    /// How many seconds an opened circuit blocks requests before a retry
    /// is allowed.
    circuit_break_secs: u64,
}
50
51impl HealthTracker {
52    pub fn new(circuit_break_secs: u64) -> Self {
53        Self {
54            consecutive_failures: AtomicU64::new(0),
55            total_requests: AtomicU64::new(0),
56            total_failures: AtomicU64::new(0),
57            circuit_open_since: AtomicU64::new(0),
58            circuit_break_secs,
59        }
60    }
61
62    /// Record a successful request.
63    pub fn record_success(&self) {
64        self.total_requests.fetch_add(1, Ordering::Relaxed);
65        self.consecutive_failures.store(0, Ordering::Relaxed);
66        // Close circuit on success
67        self.circuit_open_since.store(0, Ordering::Relaxed);
68    }
69
70    /// Record a failed request.
71    pub fn record_failure(&self) {
72        self.total_requests.fetch_add(1, Ordering::Relaxed);
73        self.total_failures.fetch_add(1, Ordering::Relaxed);
74        let failures = self.consecutive_failures.fetch_add(1, Ordering::Relaxed) + 1;
75
76        // Open circuit after 3 consecutive failures
77        if failures >= 3 {
78            let now = SystemTime::now()
79                .duration_since(UNIX_EPOCH)
80                .unwrap_or_default()
81                .as_secs();
82            self.circuit_open_since.store(now, Ordering::Relaxed);
83            warn!(
84                consecutive_failures = failures,
85                "Circuit breaker opened after {} consecutive failures",
86                failures
87            );
88        }
89    }
90
91    /// Check if the circuit is open (should not send requests).
92    pub fn is_circuit_open(&self) -> bool {
93        let opened_at = self.circuit_open_since.load(Ordering::Relaxed);
94        if opened_at == 0 {
95            return false;
96        }
97
98        let now = SystemTime::now()
99            .duration_since(UNIX_EPOCH)
100            .unwrap_or_default()
101            .as_secs();
102        let elapsed = now.saturating_sub(opened_at);
103
104        // Allow retry after circuit_break_secs (half-open state)
105        if elapsed >= self.circuit_break_secs {
106            debug!(
107                elapsed_secs = elapsed,
108                "Circuit breaker entering half-open state"
109            );
110            return false;
111        }
112
113        true
114    }
115
116    /// Get health statistics.
117    pub fn stats(&self) -> HealthStats {
118        HealthStats {
119            total_requests: self.total_requests.load(Ordering::Relaxed),
120            total_failures: self.total_failures.load(Ordering::Relaxed),
121            consecutive_failures: self.consecutive_failures.load(Ordering::Relaxed),
122            circuit_open: self.is_circuit_open(),
123        }
124    }
125}
126
127/// Health statistics snapshot.
128#[derive(Debug, Clone, Serialize)]
129pub struct HealthStats {
130    pub total_requests: u64,
131    pub total_failures: u64,
132    pub consecutive_failures: u64,
133    pub circuit_open: bool,
134}
135
136/// Failover strategy.
137#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
138#[serde(rename_all = "snake_case")]
139pub enum FailoverStrategy {
140    /// Try profiles in order, skip to next on failure.
141    Sequential,
142    /// Round-robin across healthy profiles.
143    RoundRobin,
144    /// Random selection from healthy profiles.
145    Random,
146}
147
148impl Default for FailoverStrategy {
149    fn default() -> Self {
150        Self::Sequential
151    }
152}
153
154/// Failover configuration for a provider.
155#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct FailoverConfig {
157    /// Profiles for this provider.
158    #[serde(default)]
159    pub profiles: Vec<AuthProfile>,
160
161    /// Failover strategy.
162    #[serde(default)]
163    pub strategy: FailoverStrategy,
164
165    /// Fallback model IDs to try if all profiles for the primary model fail.
166    #[serde(default)]
167    pub fallback_models: Vec<String>,
168
169    /// Seconds before retrying a circuit-broken profile.
170    #[serde(default = "default_circuit_break")]
171    pub circuit_break_secs: u64,
172
173    /// Maximum retries before giving up entirely.
174    #[serde(default = "default_max_retries")]
175    pub max_retries: u32,
176}
177
/// Default circuit-break window, in seconds.
fn default_circuit_break() -> u64 {
    const ONE_MINUTE_SECS: u64 = 60;
    ONE_MINUTE_SECS
}
181
/// Default retry budget.
fn default_max_retries() -> u32 {
    const RETRY_BUDGET: u32 = 3;
    RETRY_BUDGET
}
185
186impl Default for FailoverConfig {
187    fn default() -> Self {
188        Self {
189            profiles: Vec::new(),
190            strategy: FailoverStrategy::default(),
191            fallback_models: Vec::new(),
192            circuit_break_secs: default_circuit_break(),
193            max_retries: default_max_retries(),
194        }
195    }
196}
197
198/// Model failover manager.
199///
200/// Manages auth profile rotation and automatic failover for model providers.
201pub struct FailoverManager {
202    /// Per-provider failover configs.
203    configs: HashMap<String, FailoverConfig>,
204    /// Health trackers keyed by "provider/profile_name".
205    health: HashMap<String, HealthTracker>,
206    /// Round-robin index per provider.
207    rr_index: HashMap<String, AtomicUsize>,
208}
209
210impl FailoverManager {
211    /// Create a new failover manager.
212    pub fn new() -> Self {
213        Self {
214            configs: HashMap::new(),
215            health: HashMap::new(),
216            rr_index: HashMap::new(),
217        }
218    }
219
220    /// Register a failover configuration for a provider.
221    pub fn register(&mut self, provider: String, config: FailoverConfig) {
222        let circuit_secs = config.circuit_break_secs;
223
224        // Create health trackers for each profile
225        for profile in &config.profiles {
226            let key = format!("{}/{}", provider, profile.name);
227            self.health
228                .insert(key, HealthTracker::new(circuit_secs));
229        }
230
231        self.rr_index
232            .insert(provider.clone(), AtomicUsize::new(0));
233        self.configs.insert(provider, config);
234
235        info!("Failover config registered");
236    }
237
238    /// Select the next auth profile to use for a provider.
239    ///
240    /// Returns `None` if no healthy profiles are available.
241    pub fn select_profile(&self, provider: &str) -> Option<&AuthProfile> {
242        let config = self.configs.get(provider)?;
243        let healthy: Vec<_> = config
244            .profiles
245            .iter()
246            .filter(|p| {
247                if !p.enabled {
248                    return false;
249                }
250                let key = format!("{}/{}", provider, p.name);
251                if let Some(tracker) = self.health.get(&key) {
252                    !tracker.is_circuit_open()
253                } else {
254                    true
255                }
256            })
257            .collect();
258
259        if healthy.is_empty() {
260            warn!(provider = %provider, "No healthy profiles available");
261            return None;
262        }
263
264        match config.strategy {
265            FailoverStrategy::Sequential => Some(healthy[0]),
266            FailoverStrategy::RoundRobin => {
267                if let Some(idx) = self.rr_index.get(provider) {
268                    let i = idx.fetch_add(1, Ordering::Relaxed) % healthy.len();
269                    Some(healthy[i])
270                } else {
271                    Some(healthy[0])
272                }
273            }
274            FailoverStrategy::Random => {
275                use std::collections::hash_map::DefaultHasher;
276                use std::hash::{Hash, Hasher};
277
278                let now = SystemTime::now()
279                    .duration_since(UNIX_EPOCH)
280                    .unwrap_or(Duration::ZERO)
281                    .as_nanos();
282                let mut hasher = DefaultHasher::new();
283                now.hash(&mut hasher);
284                let i = hasher.finish() as usize % healthy.len();
285                Some(healthy[i])
286            }
287        }
288    }
289
290    /// Record a successful request for a profile.
291    pub fn record_success(&self, provider: &str, profile_name: &str) {
292        let key = format!("{}/{}", provider, profile_name);
293        if let Some(tracker) = self.health.get(&key) {
294            tracker.record_success();
295        }
296    }
297
298    /// Record a failed request for a profile.
299    pub fn record_failure(&self, provider: &str, profile_name: &str) {
300        let key = format!("{}/{}", provider, profile_name);
301        if let Some(tracker) = self.health.get(&key) {
302            tracker.record_failure();
303        }
304    }
305
306    /// Get fallback model IDs for a provider.
307    pub fn fallback_models(&self, provider: &str) -> &[String] {
308        self.configs
309            .get(provider)
310            .map(|c| c.fallback_models.as_slice())
311            .unwrap_or(&[])
312    }
313
314    /// Get health stats for all profiles.
315    pub fn health_report(&self) -> HashMap<String, HealthStats> {
316        self.health
317            .iter()
318            .map(|(key, tracker)| (key.clone(), tracker.stats()))
319            .collect()
320    }
321
322    /// Check if a provider has failover configured.
323    pub fn has_failover(&self, provider: &str) -> bool {
324        self.configs
325            .get(provider)
326            .map(|c| c.profiles.len() > 1 || !c.fallback_models.is_empty())
327            .unwrap_or(false)
328    }
329}
330
331impl Default for FailoverManager {
332    fn default() -> Self {
333        Self::new()
334    }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Enabled, unlimited-rate profile whose API key is derived from its name.
    fn make_profile(name: &str, provider: &str) -> AuthProfile {
        AuthProfile {
            name: name.to_owned(),
            provider: provider.to_owned(),
            api_key: format!("key-{}", name),
            base_url: None,
            enabled: true,
            rate_limit_rpm: 0,
        }
    }

    /// Manager with one provider registered from the given profile names.
    fn manager_with(
        provider: &str,
        profile_names: &[&str],
        strategy: FailoverStrategy,
    ) -> FailoverManager {
        let profiles = profile_names
            .iter()
            .map(|name| make_profile(name, provider))
            .collect();

        let mut mgr = FailoverManager::new();
        mgr.register(
            provider.to_owned(),
            FailoverConfig {
                profiles,
                strategy,
                ..Default::default()
            },
        );
        mgr
    }

    #[test]
    fn test_health_tracker_success_resets_failures() {
        let tracker = HealthTracker::new(60);

        for _ in 0..2 {
            tracker.record_failure();
        }
        assert_eq!(tracker.stats().consecutive_failures, 2);

        tracker.record_success();
        assert_eq!(tracker.stats().consecutive_failures, 0);
    }

    #[test]
    fn test_circuit_opens_after_3_failures() {
        let tracker = HealthTracker::new(60);

        // Two failures are still below the threshold.
        for _ in 0..2 {
            tracker.record_failure();
        }
        assert!(!tracker.is_circuit_open());

        // The third one trips the breaker.
        tracker.record_failure();
        assert!(tracker.is_circuit_open());
    }

    #[test]
    fn test_failover_sequential() {
        let mgr = manager_with(
            "openai",
            &["primary", "backup"],
            FailoverStrategy::Sequential,
        );

        let chosen = mgr.select_profile("openai").unwrap();
        assert_eq!(chosen.name, "primary");
    }

    #[test]
    fn test_failover_round_robin() {
        let mgr = manager_with("openai", &["a", "b"], FailoverStrategy::RoundRobin);

        let first = mgr.select_profile("openai").unwrap().name.clone();
        let second = mgr.select_profile("openai").unwrap().name.clone();
        assert_ne!(first, second);
    }

    #[test]
    fn test_failover_skips_broken_circuit() {
        let mgr = manager_with(
            "openai",
            &["primary", "backup"],
            FailoverStrategy::Sequential,
        );

        // Three straight failures open primary's circuit.
        (0..3).for_each(|_| mgr.record_failure("openai", "primary"));

        let chosen = mgr.select_profile("openai").unwrap();
        assert_eq!(chosen.name, "backup");
    }

    #[test]
    fn test_fallback_models() {
        let config = FailoverConfig {
            profiles: vec![make_profile("main", "anthropic")],
            fallback_models: vec!["openai/gpt-4.1".to_string()],
            ..Default::default()
        };
        let mut mgr = FailoverManager::new();
        mgr.register("anthropic".to_string(), config);

        assert_eq!(mgr.fallback_models("anthropic"), &["openai/gpt-4.1"]);
        assert!(mgr.has_failover("anthropic"));
    }

    #[test]
    fn test_no_failover() {
        let mgr = FailoverManager::new();

        assert!(!mgr.has_failover("missing"));
        assert!(mgr.select_profile("missing").is_none());
    }

    #[test]
    fn test_health_report() {
        let mgr = manager_with("openai", &["main"], FailoverStrategy::default());

        mgr.record_success("openai", "main");
        mgr.record_success("openai", "main");
        mgr.record_failure("openai", "main");

        let report = mgr.health_report();
        let stats = report.get("openai/main").unwrap();
        assert_eq!(stats.total_requests, 3);
        assert_eq!(stats.total_failures, 1);
        assert_eq!(stats.consecutive_failures, 1);
    }
}