1use crate::brain::LlmProvider;
10use crate::error::LlmError;
11use crate::types::{CompletionRequest, CompletionResponse, Message, StreamEvent};
12use async_trait::async_trait;
13use std::sync::Arc;
14use std::time::{Duration, Instant};
15use tokio::sync::{mpsc, Mutex};
16use tracing::{debug, info, warn};
17
/// Position of a [`CircuitBreaker`] in its closed → open → half-open cycle.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum CircuitState {
    /// Normal operation: calls are permitted and failures are counted.
    Closed,
    /// Tripped: calls are rejected until the recovery timeout, measured from `since`, elapses.
    Open {
        /// Moment the breaker tripped.
        since: Instant,
    },
    /// Recovery timeout elapsed: calls are permitted again as probes.
    HalfOpen,
}
32
33#[derive(Debug)]
36pub struct CircuitBreaker {
37 state: CircuitState,
38 failure_count: usize,
39 failure_threshold: usize,
40 recovery_timeout: Duration,
41}
42
43impl CircuitBreaker {
44 pub fn new(failure_threshold: usize, recovery_timeout: Duration) -> Self {
45 Self {
46 state: CircuitState::Closed,
47 failure_count: 0,
48 failure_threshold,
49 recovery_timeout,
50 }
51 }
52
53 pub fn is_call_permitted(&mut self) -> bool {
55 match self.state {
56 CircuitState::Closed => true,
57 CircuitState::Open { since } => {
58 if since.elapsed() >= self.recovery_timeout {
59 debug!("Circuit breaker transitioning to half-open");
60 self.state = CircuitState::HalfOpen;
61 true
62 } else {
63 false
64 }
65 }
66 CircuitState::HalfOpen => true,
67 }
68 }
69
70 pub fn record_success(&mut self) {
72 self.failure_count = 0;
73 if self.state == CircuitState::HalfOpen {
74 debug!("Circuit breaker closing after successful probe");
75 }
76 self.state = CircuitState::Closed;
77 }
78
79 pub fn record_failure(&mut self) {
81 self.failure_count += 1;
82 if self.failure_count >= self.failure_threshold {
83 let now = Instant::now();
84 warn!(
85 failures = self.failure_count,
86 threshold = self.failure_threshold,
87 "Circuit breaker opening"
88 );
89 self.state = CircuitState::Open { since: now };
90 }
91 }
92
93 pub fn state(&self) -> CircuitState {
95 self.state
96 }
97}
98
99#[derive(Debug, Clone)]
105pub struct AuthProfile {
106 pub api_key_env: String,
108 cooldown_until: Option<Instant>,
110 cooldown_duration: Duration,
112}
113
114impl AuthProfile {
115 pub fn new(api_key_env: impl Into<String>) -> Self {
116 Self {
117 api_key_env: api_key_env.into(),
118 cooldown_until: None,
119 cooldown_duration: Duration::from_secs(60),
120 }
121 }
122
123 pub fn with_cooldown_duration(mut self, duration: Duration) -> Self {
124 self.cooldown_duration = duration;
125 self
126 }
127
128 pub fn is_available(&self) -> bool {
130 match self.cooldown_until {
131 None => true,
132 Some(until) => Instant::now() >= until,
133 }
134 }
135
136 pub fn trigger_cooldown(&mut self) {
138 info!(
139 env_var = %self.api_key_env,
140 cooldown_secs = self.cooldown_duration.as_secs(),
141 "Auth profile entering cooldown"
142 );
143 self.cooldown_until = Some(Instant::now() + self.cooldown_duration);
144 }
145}
146
147struct ProviderEntry {
153 provider: Arc<dyn LlmProvider>,
154 circuit_breaker: Mutex<CircuitBreaker>,
155 #[allow(dead_code)]
156 priority: u8,
157}
158
159pub struct FailoverProvider {
166 providers: Vec<ProviderEntry>,
167}
168
169impl FailoverProvider {
170 pub fn new(
174 providers: Vec<Arc<dyn LlmProvider>>,
175 failure_threshold: usize,
176 recovery_timeout: Duration,
177 ) -> Self {
178 let entries = providers
179 .into_iter()
180 .enumerate()
181 .map(|(i, provider)| ProviderEntry {
182 provider,
183 circuit_breaker: Mutex::new(CircuitBreaker::new(
184 failure_threshold,
185 recovery_timeout,
186 )),
187 priority: i as u8,
188 })
189 .collect();
190
191 Self { providers: entries }
192 }
193
194 fn primary(&self) -> &dyn LlmProvider {
196 &*self.providers[0].provider
197 }
198}
199
200#[async_trait]
201impl LlmProvider for FailoverProvider {
202 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, LlmError> {
203 let mut last_error = None;
204
205 for (i, entry) in self.providers.iter().enumerate() {
206 let mut cb = entry.circuit_breaker.lock().await;
207 if !cb.is_call_permitted() {
208 debug!(provider_index = i, "Skipping provider — circuit open");
209 continue;
210 }
211 drop(cb); match entry.provider.complete(request.clone()).await {
214 Ok(response) => {
215 let mut cb = entry.circuit_breaker.lock().await;
216 cb.record_success();
217 return Ok(response);
218 }
219 Err(e) => {
220 warn!(
221 provider_index = i,
222 model = entry.provider.model_name(),
223 error = %e,
224 "Provider failed, trying next"
225 );
226 let mut cb = entry.circuit_breaker.lock().await;
227 cb.record_failure();
228 last_error = Some(e);
229 }
230 }
231 }
232
233 Err(last_error.unwrap_or(LlmError::Connection {
234 message: "All providers failed or circuits open".into(),
235 }))
236 }
237
238 async fn complete_streaming(
239 &self,
240 request: CompletionRequest,
241 tx: mpsc::Sender<StreamEvent>,
242 ) -> Result<(), LlmError> {
243 let mut last_error = None;
244
245 for (i, entry) in self.providers.iter().enumerate() {
246 let mut cb = entry.circuit_breaker.lock().await;
247 if !cb.is_call_permitted() {
248 debug!(provider_index = i, "Skipping provider — circuit open");
249 continue;
250 }
251 drop(cb);
252
253 match entry
254 .provider
255 .complete_streaming(request.clone(), tx.clone())
256 .await
257 {
258 Ok(()) => {
259 let mut cb = entry.circuit_breaker.lock().await;
260 cb.record_success();
261 return Ok(());
262 }
263 Err(e) => {
264 warn!(
265 provider_index = i,
266 error = %e,
267 "Provider streaming failed, trying next"
268 );
269 let mut cb = entry.circuit_breaker.lock().await;
270 cb.record_failure();
271 last_error = Some(e);
272 }
273 }
274 }
275
276 Err(last_error.unwrap_or(LlmError::Connection {
277 message: "All providers failed or circuits open".into(),
278 }))
279 }
280
281 fn estimate_tokens(&self, messages: &[Message]) -> usize {
282 self.primary().estimate_tokens(messages)
283 }
284
285 fn context_window(&self) -> usize {
286 self.primary().context_window()
287 }
288
289 fn supports_tools(&self) -> bool {
290 self.primary().supports_tools()
291 }
292
293 fn cost_per_token(&self) -> (f64, f64) {
294 self.primary().cost_per_token()
295 }
296
297 fn model_name(&self) -> &str {
298 self.primary().model_name()
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::brain::MockLlmProvider;
    use crate::types::{CompletionResponse, Message};

    /// Provider that fails every call. `error` selects the `LlmError` variant
    /// returned by `complete`; streaming always fails with a connection error.
    struct AlwaysFailProvider {
        model: String,
        error: String,
    }

    impl AlwaysFailProvider {
        fn new(model: &str, error: &str) -> Self {
            Self {
                model: model.to_string(),
                error: error.to_string(),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for AlwaysFailProvider {
        async fn complete(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse, LlmError> {
            match self.error.as_str() {
                "rate_limited" => Err(LlmError::RateLimited {
                    retry_after_secs: 5,
                }),
                "timeout" => Err(LlmError::Timeout { timeout_secs: 30 }),
                _ => Err(LlmError::Connection {
                    message: format!("Always fail: {}", self.error),
                }),
            }
        }

        async fn complete_streaming(
            &self,
            _request: CompletionRequest,
            _tx: mpsc::Sender<StreamEvent>,
        ) -> Result<(), LlmError> {
            Err(LlmError::Connection {
                message: "Always fail streaming".into(),
            })
        }

        fn estimate_tokens(&self, _messages: &[Message]) -> usize {
            100
        }
        fn context_window(&self) -> usize {
            128_000
        }
        fn supports_tools(&self) -> bool {
            true
        }
        fn cost_per_token(&self) -> (f64, f64) {
            (0.0, 0.0)
        }
        fn model_name(&self) -> &str {
            &self.model
        }
    }

    /// Provider whose `complete` fails `failures` times, then answers "recovered".
    struct FailNThenSucceedProvider {
        model: String,
        failures_remaining: std::sync::Mutex<usize>,
    }

    impl FailNThenSucceedProvider {
        fn new(model: &str, failures: usize) -> Self {
            Self {
                model: model.to_string(),
                failures_remaining: std::sync::Mutex::new(failures),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for FailNThenSucceedProvider {
        async fn complete(
            &self,
            _request: CompletionRequest,
        ) -> Result<CompletionResponse, LlmError> {
            let mut remaining = self.failures_remaining.lock().unwrap();
            if *remaining > 0 {
                *remaining -= 1;
                Err(LlmError::Connection {
                    message: "temporary failure".into(),
                })
            } else {
                Ok(MockLlmProvider::text_response("recovered"))
            }
        }

        async fn complete_streaming(
            &self,
            _request: CompletionRequest,
            _tx: mpsc::Sender<StreamEvent>,
        ) -> Result<(), LlmError> {
            Ok(())
        }

        fn estimate_tokens(&self, _messages: &[Message]) -> usize {
            100
        }
        fn context_window(&self) -> usize {
            128_000
        }
        fn supports_tools(&self) -> bool {
            true
        }
        fn cost_per_token(&self) -> (f64, f64) {
            (0.0, 0.0)
        }
        fn model_name(&self) -> &str {
            &self.model
        }
    }

    #[test]
    fn test_circuit_breaker_starts_closed() {
        let cb = CircuitBreaker::new(3, Duration::from_secs(60));
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_circuit_breaker_opens_after_threshold() {
        let mut cb = CircuitBreaker::new(3, Duration::from_secs(60));
        cb.record_failure();
        cb.record_failure();
        assert_eq!(cb.state(), CircuitState::Closed);
        cb.record_failure();
        assert!(matches!(cb.state(), CircuitState::Open { .. }));
    }

    #[test]
    fn test_circuit_breaker_blocks_calls_when_open() {
        let mut cb = CircuitBreaker::new(2, Duration::from_secs(600));
        cb.record_failure();
        cb.record_failure();
        assert!(!cb.is_call_permitted());
    }

    #[test]
    fn test_circuit_breaker_half_open_after_timeout() {
        let mut cb = CircuitBreaker::new(1, Duration::from_millis(1));
        cb.record_failure();
        assert!(matches!(cb.state(), CircuitState::Open { .. }));

        std::thread::sleep(Duration::from_millis(5));
        assert!(cb.is_call_permitted());
        assert_eq!(cb.state(), CircuitState::HalfOpen);
    }

    #[test]
    fn test_circuit_breaker_closes_on_success_in_half_open() {
        let mut cb = CircuitBreaker::new(1, Duration::from_millis(1));
        cb.record_failure();
        std::thread::sleep(Duration::from_millis(5));
        // The probe must be permitted before its success can close the breaker.
        assert!(cb.is_call_permitted());
        cb.record_success();
        assert_eq!(cb.state(), CircuitState::Closed);
        assert_eq!(cb.failure_count, 0);
    }

    #[test]
    fn test_circuit_breaker_success_resets_count() {
        let mut cb = CircuitBreaker::new(3, Duration::from_secs(60));
        cb.record_failure();
        cb.record_failure();
        cb.record_success();
        assert_eq!(cb.failure_count, 0);
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_auth_profile_initially_available() {
        let profile = AuthProfile::new("TEST_KEY");
        assert!(profile.is_available());
    }

    #[test]
    fn test_auth_profile_cooldown() {
        let mut profile =
            AuthProfile::new("TEST_KEY").with_cooldown_duration(Duration::from_millis(10));
        profile.trigger_cooldown();
        assert!(!profile.is_available());

        std::thread::sleep(Duration::from_millis(15));
        assert!(profile.is_available());
    }

    #[tokio::test]
    async fn test_failover_primary_succeeds() {
        let primary = Arc::new(MockLlmProvider::new());
        primary.queue_response(MockLlmProvider::text_response("primary response"));

        let fallback = Arc::new(MockLlmProvider::new());
        fallback.queue_response(MockLlmProvider::text_response("fallback response"));

        let provider = FailoverProvider::new(vec![primary, fallback], 3, Duration::from_secs(60));

        let response = provider
            .complete(CompletionRequest::default())
            .await
            .unwrap();
        assert_eq!(response.message.content.as_text(), Some("primary response"));
    }

    #[tokio::test]
    async fn test_failover_to_secondary() {
        let primary: Arc<dyn LlmProvider> =
            Arc::new(AlwaysFailProvider::new("primary", "connection"));
        let fallback = Arc::new(MockLlmProvider::new());
        fallback.queue_response(MockLlmProvider::text_response("fallback response"));
        let fallback: Arc<dyn LlmProvider> = fallback;

        let provider = FailoverProvider::new(vec![primary, fallback], 3, Duration::from_secs(60));

        let response = provider
            .complete(CompletionRequest::default())
            .await
            .unwrap();
        assert_eq!(
            response.message.content.as_text(),
            Some("fallback response")
        );
    }

    #[tokio::test]
    async fn test_all_providers_fail() {
        let p1: Arc<dyn LlmProvider> = Arc::new(AlwaysFailProvider::new("p1", "connection"));
        let p2: Arc<dyn LlmProvider> = Arc::new(AlwaysFailProvider::new("p2", "timeout"));

        let provider = FailoverProvider::new(vec![p1, p2], 3, Duration::from_secs(60));

        let result = provider.complete(CompletionRequest::default()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_circuit_breaker_opens_and_skips_provider() {
        let primary: Arc<dyn LlmProvider> =
            Arc::new(AlwaysFailProvider::new("primary", "connection"));
        let fallback = Arc::new(MockLlmProvider::new());
        for _ in 0..5 {
            fallback.queue_response(MockLlmProvider::text_response("fallback"));
        }
        let fallback: Arc<dyn LlmProvider> = fallback;

        // Threshold 1 and a long recovery timeout: one failure keeps the primary benched.
        let provider = FailoverProvider::new(vec![primary, fallback], 1, Duration::from_secs(600));

        let r1 = provider
            .complete(CompletionRequest::default())
            .await
            .unwrap();
        assert_eq!(r1.message.content.as_text(), Some("fallback"));

        // The primary's breaker tripped on that first failure.
        assert!(matches!(
            provider.providers[0].circuit_breaker.lock().await.state(),
            CircuitState::Open { .. }
        ));

        let r2 = provider
            .complete(CompletionRequest::default())
            .await
            .unwrap();
        assert_eq!(r2.message.content.as_text(), Some("fallback"));
    }

    #[tokio::test]
    async fn test_failover_recovers_after_half_open_probe() {
        let primary: Arc<dyn LlmProvider> =
            Arc::new(FailNThenSucceedProvider::new("primary", 1));
        let fallback = Arc::new(MockLlmProvider::new());
        fallback.queue_response(MockLlmProvider::text_response("fallback"));
        let fallback: Arc<dyn LlmProvider> = fallback;

        let provider = FailoverProvider::new(vec![primary, fallback], 1, Duration::from_millis(1));

        // The primary fails once, trips its breaker, and the fallback answers.
        let r1 = provider
            .complete(CompletionRequest::default())
            .await
            .unwrap();
        assert_eq!(r1.message.content.as_text(), Some("fallback"));

        // After the recovery timeout the primary is probed again and now succeeds.
        std::thread::sleep(Duration::from_millis(5));
        let r2 = provider
            .complete(CompletionRequest::default())
            .await
            .unwrap();
        assert_eq!(r2.message.content.as_text(), Some("recovered"));
        assert_eq!(
            provider.providers[0].circuit_breaker.lock().await.state(),
            CircuitState::Closed
        );
    }

    #[tokio::test]
    async fn test_failover_provider_delegates_properties() {
        let primary = Arc::new(MockLlmProvider::new());
        let provider = FailoverProvider::new(
            vec![primary as Arc<dyn LlmProvider>],
            3,
            Duration::from_secs(60),
        );

        assert_eq!(provider.model_name(), "mock-model");
        assert_eq!(provider.context_window(), 128_000);
        assert!(provider.supports_tools());
        assert_eq!(provider.cost_per_token(), (0.0, 0.0));
    }

    #[tokio::test]
    async fn test_failover_streaming() {
        let primary: Arc<dyn LlmProvider> =
            Arc::new(AlwaysFailProvider::new("primary", "connection"));
        let fallback = Arc::new(MockLlmProvider::new());
        fallback.queue_response(MockLlmProvider::text_response("streamed"));
        let fallback: Arc<dyn LlmProvider> = fallback;

        let provider = FailoverProvider::new(vec![primary, fallback], 3, Duration::from_secs(60));

        let (tx, mut rx) = mpsc::channel(32);
        provider
            .complete_streaming(CompletionRequest::default(), tx)
            .await
            .unwrap();

        let mut tokens = Vec::new();
        while let Some(event) = rx.recv().await {
            match event {
                StreamEvent::Token(t) => tokens.push(t),
                StreamEvent::Done { .. } => break,
                _ => {}
            }
        }
        assert!(!tokens.is_empty());
    }
}