Skip to main content

rustant_core/providers/
failover.rs

//! Multi-provider failover with circuit breakers and auth profile cooldowns.
//!
//! Provides resilient LLM access by:
//! - Trying providers in priority order
//! - Skipping providers with open circuit breakers
//! - Tracking per-credential rate-limit cooldowns (`AuthProfile`); note that
//!   `FailoverProvider` does not rotate auth profiles yet
//! - Automatic recovery via half-open circuit state
8
9use crate::brain::LlmProvider;
10use crate::error::LlmError;
11use crate::types::{CompletionRequest, CompletionResponse, Message, StreamEvent};
12use async_trait::async_trait;
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15use tokio::sync::{mpsc, Mutex};
16use tracing::{debug, info, warn};
17
18// ---------------------------------------------------------------------------
19// Circuit Breaker
20// ---------------------------------------------------------------------------
21
/// State of a circuit breaker.
///
/// All payloads are `Copy + Eq`, so the state can be compared and passed by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation — calls are permitted.
    Closed,
    /// Too many failures — calls are blocked until the recovery timeout
    /// measured from `since` has elapsed.
    Open { since: Instant },
    /// Recovery probe — one call is permitted to test the provider.
    HalfOpen,
}
32
33/// A circuit breaker that trips after consecutive failures and recovers
34/// after a timeout.
35#[derive(Debug)]
36pub struct CircuitBreaker {
37    state: CircuitState,
38    failure_count: usize,
39    failure_threshold: usize,
40    recovery_timeout: Duration,
41}
42
43impl CircuitBreaker {
44    pub fn new(failure_threshold: usize, recovery_timeout: Duration) -> Self {
45        Self {
46            state: CircuitState::Closed,
47            failure_count: 0,
48            failure_threshold,
49            recovery_timeout,
50        }
51    }
52
53    /// Whether a call is currently permitted.
54    pub fn is_call_permitted(&mut self) -> bool {
55        match self.state {
56            CircuitState::Closed => true,
57            CircuitState::Open { since } => {
58                if since.elapsed() >= self.recovery_timeout {
59                    debug!("Circuit breaker transitioning to half-open");
60                    self.state = CircuitState::HalfOpen;
61                    true
62                } else {
63                    false
64                }
65            }
66            CircuitState::HalfOpen => true,
67        }
68    }
69
70    /// Record a successful call.
71    pub fn record_success(&mut self) {
72        self.failure_count = 0;
73        if self.state == CircuitState::HalfOpen {
74            debug!("Circuit breaker closing after successful probe");
75        }
76        self.state = CircuitState::Closed;
77    }
78
79    /// Record a failed call.
80    pub fn record_failure(&mut self) {
81        self.failure_count += 1;
82        if self.failure_count >= self.failure_threshold {
83            let now = Instant::now();
84            warn!(
85                failures = self.failure_count,
86                threshold = self.failure_threshold,
87                "Circuit breaker opening"
88            );
89            self.state = CircuitState::Open { since: now };
90        }
91    }
92
93    /// Get the current state.
94    pub fn state(&self) -> CircuitState {
95        self.state
96    }
97}
98
99// ---------------------------------------------------------------------------
100// Auth Profile
101// ---------------------------------------------------------------------------
102
103/// A single set of credentials for a provider.
104#[derive(Debug, Clone)]
105pub struct AuthProfile {
106    /// Environment variable name containing the API key.
107    pub api_key_env: String,
108    /// When the profile entered cooldown (rate-limited).
109    cooldown_until: Option<Instant>,
110    /// Cooldown duration when rate-limited.
111    cooldown_duration: Duration,
112}
113
114impl AuthProfile {
115    pub fn new(api_key_env: impl Into<String>) -> Self {
116        Self {
117            api_key_env: api_key_env.into(),
118            cooldown_until: None,
119            cooldown_duration: Duration::from_secs(60),
120        }
121    }
122
123    pub fn with_cooldown_duration(mut self, duration: Duration) -> Self {
124        self.cooldown_duration = duration;
125        self
126    }
127
128    /// Whether this profile can be used right now.
129    pub fn is_available(&self) -> bool {
130        match self.cooldown_until {
131            None => true,
132            Some(until) => Instant::now() >= until,
133        }
134    }
135
136    /// Put this profile into cooldown.
137    pub fn trigger_cooldown(&mut self) {
138        info!(
139            env_var = %self.api_key_env,
140            cooldown_secs = self.cooldown_duration.as_secs(),
141            "Auth profile entering cooldown"
142        );
143        self.cooldown_until = Some(Instant::now() + self.cooldown_duration);
144    }
145}
146
147// ---------------------------------------------------------------------------
148// Provider Entry
149// ---------------------------------------------------------------------------
150
151/// A provider with its circuit breaker and priority.
152struct ProviderEntry {
153    provider: Arc<dyn LlmProvider>,
154    circuit_breaker: Mutex<CircuitBreaker>,
155    #[allow(dead_code)]
156    priority: u8,
157}
158
159// ---------------------------------------------------------------------------
160// FailoverProvider
161// ---------------------------------------------------------------------------
162
163/// An LLM provider that tries multiple backends in priority order,
164/// skipping providers with open circuit breakers.
165pub struct FailoverProvider {
166    providers: Vec<ProviderEntry>,
167}
168
169impl FailoverProvider {
170    /// Create a new FailoverProvider.
171    ///
172    /// Providers are tried in the order given. The first provider is the primary.
173    pub fn new(
174        providers: Vec<Arc<dyn LlmProvider>>,
175        failure_threshold: usize,
176        recovery_timeout: Duration,
177    ) -> Self {
178        let entries = providers
179            .into_iter()
180            .enumerate()
181            .map(|(i, provider)| ProviderEntry {
182                provider,
183                circuit_breaker: Mutex::new(CircuitBreaker::new(
184                    failure_threshold,
185                    recovery_timeout,
186                )),
187                priority: i as u8,
188            })
189            .collect();
190
191        Self { providers: entries }
192    }
193
194    /// Get the primary (first) provider.
195    fn primary(&self) -> &dyn LlmProvider {
196        &*self.providers[0].provider
197    }
198}
199
200#[async_trait]
201impl LlmProvider for FailoverProvider {
202    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, LlmError> {
203        let mut last_error = None;
204
205        for (i, entry) in self.providers.iter().enumerate() {
206            let mut cb = entry.circuit_breaker.lock().await;
207            if !cb.is_call_permitted() {
208                debug!(provider_index = i, "Skipping provider — circuit open");
209                continue;
210            }
211            drop(cb); // release lock before making the call
212
213            match entry.provider.complete(request.clone()).await {
214                Ok(response) => {
215                    let mut cb = entry.circuit_breaker.lock().await;
216                    cb.record_success();
217                    return Ok(response);
218                }
219                Err(e) => {
220                    warn!(
221                        provider_index = i,
222                        model = entry.provider.model_name(),
223                        error = %e,
224                        "Provider failed, trying next"
225                    );
226                    let mut cb = entry.circuit_breaker.lock().await;
227                    cb.record_failure();
228                    last_error = Some(e);
229                }
230            }
231        }
232
233        Err(last_error.unwrap_or(LlmError::Connection {
234            message: "All providers failed or circuits open".into(),
235        }))
236    }
237
238    async fn complete_streaming(
239        &self,
240        request: CompletionRequest,
241        tx: mpsc::Sender<StreamEvent>,
242    ) -> Result<(), LlmError> {
243        let mut last_error = None;
244
245        for (i, entry) in self.providers.iter().enumerate() {
246            let mut cb = entry.circuit_breaker.lock().await;
247            if !cb.is_call_permitted() {
248                debug!(provider_index = i, "Skipping provider — circuit open");
249                continue;
250            }
251            drop(cb);
252
253            match entry
254                .provider
255                .complete_streaming(request.clone(), tx.clone())
256                .await
257            {
258                Ok(()) => {
259                    let mut cb = entry.circuit_breaker.lock().await;
260                    cb.record_success();
261                    return Ok(());
262                }
263                Err(e) => {
264                    warn!(
265                        provider_index = i,
266                        error = %e,
267                        "Provider streaming failed, trying next"
268                    );
269                    let mut cb = entry.circuit_breaker.lock().await;
270                    cb.record_failure();
271                    last_error = Some(e);
272                }
273            }
274        }
275
276        Err(last_error.unwrap_or(LlmError::Connection {
277            message: "All providers failed or circuits open".into(),
278        }))
279    }
280
281    fn estimate_tokens(&self, messages: &[Message]) -> usize {
282        self.primary().estimate_tokens(messages)
283    }
284
285    fn context_window(&self) -> usize {
286        self.primary().context_window()
287    }
288
289    fn supports_tools(&self) -> bool {
290        self.primary().supports_tools()
291    }
292
293    fn cost_per_token(&self) -> (f64, f64) {
294        self.primary().cost_per_token()
295    }
296
297    fn model_name(&self) -> &str {
298        self.primary().model_name()
299    }
300}
301
302// ---------------------------------------------------------------------------
303// Tests
304// ---------------------------------------------------------------------------
305
#[cfg(test)]
mod tests {
    use super::*;
    use crate::brain::MockLlmProvider;
    use crate::types::{CompletionResponse, Message};

    /// A provider that always fails with a given error type.
    struct AlwaysFailProvider {
        model: String,
        error: String,
    }

    impl AlwaysFailProvider {
        fn new(model: &str, error: &str) -> Self {
            Self {
                model: model.to_string(),
                error: error.to_string(),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for AlwaysFailProvider {
        async fn complete(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse, LlmError> {
            match self.error.as_str() {
                "rate_limited" => Err(LlmError::RateLimited {
                    retry_after_secs: 5,
                }),
                "timeout" => Err(LlmError::Timeout { timeout_secs: 30 }),
                _ => Err(LlmError::Connection {
                    message: format!("Always fail: {}", self.error),
                }),
            }
        }

        async fn complete_streaming(
            &self,
            _request: CompletionRequest,
            _tx: mpsc::Sender<StreamEvent>,
        ) -> Result<(), LlmError> {
            Err(LlmError::Connection {
                message: "Always fail streaming".into(),
            })
        }

        fn estimate_tokens(&self, _messages: &[Message]) -> usize {
            100
        }
        fn context_window(&self) -> usize {
            128_000
        }
        fn supports_tools(&self) -> bool {
            true
        }
        fn cost_per_token(&self) -> (f64, f64) {
            (0.0, 0.0)
        }
        fn model_name(&self) -> &str {
            &self.model
        }
    }

    /// A provider that fails N times then succeeds.
    struct FailNThenSucceedProvider {
        model: String,
        failures_remaining: std::sync::Mutex<usize>,
    }

    impl FailNThenSucceedProvider {
        fn new(model: &str, failures: usize) -> Self {
            Self {
                model: model.to_string(),
                failures_remaining: std::sync::Mutex::new(failures),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for FailNThenSucceedProvider {
        async fn complete(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse, LlmError> {
            let mut remaining = self.failures_remaining.lock().unwrap();
            if *remaining > 0 {
                *remaining -= 1;
                Err(LlmError::Connection {
                    message: "temporary failure".into(),
                })
            } else {
                Ok(MockLlmProvider::text_response("recovered"))
            }
        }

        async fn complete_streaming(
            &self,
            _request: CompletionRequest,
            _tx: mpsc::Sender<StreamEvent>,
        ) -> Result<(), LlmError> {
            Ok(())
        }

        fn estimate_tokens(&self, _messages: &[Message]) -> usize {
            100
        }
        fn context_window(&self) -> usize {
            128_000
        }
        fn supports_tools(&self) -> bool {
            true
        }
        fn cost_per_token(&self) -> (f64, f64) {
            (0.0, 0.0)
        }
        fn model_name(&self) -> &str {
            &self.model
        }
    }

    // --- Circuit Breaker Tests ---

    #[test]
    fn test_circuit_breaker_starts_closed() {
        let cb = CircuitBreaker::new(3, Duration::from_secs(60));
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_opens_after_threshold() {
        let mut cb = CircuitBreaker::new(3, Duration::from_secs(60));
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed); // not yet
        cb.record_failure();
        assert!(matches!(cb.state(), CircuitState::Open { .. }));
    }

    #[test]
    fn test_circuit_breaker_blocks_calls_when_open() {
        let mut cb = CircuitBreaker::new(2, Duration::from_secs(600));
        cb.record_failure();
        cb.record_failure();
        assert!(!cb.is_call_permitted());
    }

    #[test]
    fn test_circuit_breaker_half_open_after_timeout() {
        // A zero recovery timeout has always elapsed, so no sleep is needed.
        let mut cb = CircuitBreaker::new(1, Duration::ZERO);
        cb.record_failure();
        assert!(matches!(cb.state(), CircuitState::Open { .. }));

        assert!(cb.is_call_permitted()); // transitions to HalfOpen
        assert_eq!(cb.state(), CircuitState::HalfOpen);
    }

    #[test]
    fn test_circuit_breaker_closes_on_success_in_half_open() {
        let mut cb = CircuitBreaker::new(1, Duration::ZERO);
        cb.record_failure();
        assert!(cb.is_call_permitted()); // transitions to HalfOpen
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert_eq!(cb.failure_count, 0);
    }

    #[test]
    fn test_circuit_breaker_reopens_on_failure_in_half_open() {
        let mut cb = CircuitBreaker::new(1, Duration::ZERO);
        cb.record_failure();
        assert!(cb.is_call_permitted());
        assert_eq!(cb.state(), CircuitState::HalfOpen);
        cb.record_failure();
        assert!(matches!(cb.state(), CircuitState::Open { .. }));
    }

    #[test]
    fn test_circuit_breaker_success_resets_count() {
        let mut cb = CircuitBreaker::new(3, Duration::from_secs(60));
        cb.record_failure();
        cb.record_failure();
        cb.record_success();
        assert_eq!(cb.failure_count, 0);
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    // --- Auth Profile Tests ---

    #[test]
    fn test_auth_profile_initially_available() {
        let profile = AuthProfile::new("TEST_KEY");
        assert!(profile.is_available());
    }

    #[test]
    fn test_auth_profile_cooldown() {
        let mut profile =
            AuthProfile::new("TEST_KEY").with_cooldown_duration(Duration::from_millis(10));
        profile.trigger_cooldown();
        assert!(!profile.is_available());

        std::thread::sleep(Duration::from_millis(15));
        assert!(profile.is_available());
    }

    // --- FailoverProvider Tests ---

    #[tokio::test]
    async fn test_failover_primary_succeeds() {
        let primary = Arc::new(MockLlmProvider::new());
        primary.queue_response(MockLlmProvider::text_response("primary response"));

        let fallback = Arc::new(MockLlmProvider::new());
        fallback.queue_response(MockLlmProvider::text_response("fallback response"));

        let provider = FailoverProvider::new(vec![primary, fallback], 3, Duration::from_secs(60));

        let response = provider
            .complete(CompletionRequest::default())
            .await
            .unwrap();
        assert_eq!(response.message.content.as_text(), Some("primary response"));
    }

    #[tokio::test]
    async fn test_failover_to_secondary() {
        let primary: Arc<dyn LlmProvider> =
            Arc::new(AlwaysFailProvider::new("primary", "connection"));
        let fallback = Arc::new(MockLlmProvider::new());
        fallback.queue_response(MockLlmProvider::text_response("fallback response"));
        let fallback: Arc<dyn LlmProvider> = fallback;

        let provider = FailoverProvider::new(vec![primary, fallback], 3, Duration::from_secs(60));

        let response = provider
            .complete(CompletionRequest::default())
            .await
            .unwrap();
        assert_eq!(
            response.message.content.as_text(),
            Some("fallback response")
        );
    }

    #[tokio::test]
    async fn test_all_providers_fail() {
        let p1: Arc<dyn LlmProvider> = Arc::new(AlwaysFailProvider::new("p1", "connection"));
        let p2: Arc<dyn LlmProvider> = Arc::new(AlwaysFailProvider::new("p2", "timeout"));

        let provider = FailoverProvider::new(vec![p1, p2], 3, Duration::from_secs(60));

        let result = provider.complete(CompletionRequest::default()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_all_circuits_open_returns_connection_error() {
        let p1: Arc<dyn LlmProvider> = Arc::new(AlwaysFailProvider::new("p1", "connection"));
        let p2: Arc<dyn LlmProvider> = Arc::new(AlwaysFailProvider::new("p2", "timeout"));

        let provider = FailoverProvider::new(vec![p1, p2], 1, Duration::from_secs(600));

        // First call trips both circuits and surfaces the last provider's error.
        let first = provider.complete(CompletionRequest::default()).await;
        assert!(matches!(first, Err(LlmError::Timeout { .. })));

        // Second call: every provider is skipped, so nothing was attempted.
        let second = provider.complete(CompletionRequest::default()).await;
        assert!(matches!(second, Err(LlmError::Connection { .. })));
    }

    #[tokio::test]
    async fn test_circuit_breaker_opens_and_skips_provider() {
        // Primary fails with threshold=1 so circuit opens immediately
        let primary: Arc<dyn LlmProvider> =
            Arc::new(AlwaysFailProvider::new("primary", "connection"));
        let fallback = Arc::new(MockLlmProvider::new());
        // Queue enough responses for multiple calls
        for _ in 0..5 {
            fallback.queue_response(MockLlmProvider::text_response("fallback"));
        }
        let fallback: Arc<dyn LlmProvider> = fallback;

        let provider = FailoverProvider::new(
            vec![primary, fallback],
            1,                        // open after 1 failure
            Duration::from_secs(600), // long recovery so it stays open
        );

        // First call: primary fails, circuit opens, fallback succeeds
        let r1 = provider
            .complete(CompletionRequest::default())
            .await
            .unwrap();
        assert_eq!(r1.message.content.as_text(), Some("fallback"));

        // Second call: primary skipped (circuit open), fallback used directly
        let r2 = provider
            .complete(CompletionRequest::default())
            .await
            .unwrap();
        assert_eq!(r2.message.content.as_text(), Some("fallback"));
    }

    #[tokio::test]
    async fn test_failover_recovers_primary_via_half_open() {
        let primary: Arc<dyn LlmProvider> = Arc::new(FailNThenSucceedProvider::new("primary", 1));
        let fallback = Arc::new(MockLlmProvider::new());
        fallback.queue_response(MockLlmProvider::text_response("fallback"));
        let fallback: Arc<dyn LlmProvider> = fallback;

        // Threshold 1 opens the primary circuit on its first failure; a zero
        // recovery timeout lets the very next call probe it again.
        let provider = FailoverProvider::new(vec![primary, fallback], 1, Duration::ZERO);

        let r1 = provider
            .complete(CompletionRequest::default())
            .await
            .unwrap();
        assert_eq!(r1.message.content.as_text(), Some("fallback"));

        let r2 = provider
            .complete(CompletionRequest::default())
            .await
            .unwrap();
        assert_eq!(r2.message.content.as_text(), Some("recovered"));
        assert_eq!(
            provider.providers[0].circuit_breaker.lock().await.state(),
            CircuitState::Closed
        );
    }

    #[tokio::test]
    async fn test_failover_provider_delegates_properties() {
        let primary = Arc::new(MockLlmProvider::new());
        let provider = FailoverProvider::new(
            vec![primary as Arc<dyn LlmProvider>],
            3,
            Duration::from_secs(60),
        );

        assert_eq!(provider.model_name(), "mock-model");
        assert_eq!(provider.context_window(), 128_000);
        assert!(provider.supports_tools());
        assert_eq!(provider.cost_per_token(), (0.0, 0.0));
    }

    #[tokio::test]
    async fn test_failover_streaming() {
        let primary: Arc<dyn LlmProvider> =
            Arc::new(AlwaysFailProvider::new("primary", "connection"));
        let fallback = Arc::new(MockLlmProvider::new());
        fallback.queue_response(MockLlmProvider::text_response("streamed"));
        let fallback: Arc<dyn LlmProvider> = fallback;

        let provider = FailoverProvider::new(vec![primary, fallback], 3, Duration::from_secs(60));

        let (tx, mut rx) = mpsc::channel(32);
        provider
            .complete_streaming(CompletionRequest::default(), tx)
            .await
            .unwrap();

        let mut tokens = Vec::new();
        while let Some(event) = rx.recv().await {
            match event {
                StreamEvent::Token(t) => tokens.push(t),
                StreamEvent::Done { .. } => break,
                _ => {}
            }
        }
        assert!(!tokens.is_empty());
    }
}