Skip to main content

simple_agents_router/
circuit_breaker.rs

//! Circuit breaker implementation for provider resilience.
//!
//! Tracks consecutive provider failures and opens the circuit once a threshold
//! is reached, then admits half-open probe requests after a cooldown.
5
6use simple_agent_type::prelude::{ProviderError, SimpleAgentsError};
7use std::sync::Mutex;
8use std::time::{Duration, Instant};
9
/// Circuit breaker configuration.
///
/// Thresholds count *consecutive* outcomes: while the circuit is closed, any
/// recorded success resets the failure streak. A threshold of `0` behaves like
/// `1`, because transitions are only evaluated when an outcome is recorded.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CircuitBreakerConfig {
    /// Number of consecutive failures (while closed) before opening the circuit.
    pub failure_threshold: u32,
    /// Cooldown, measured from the moment the circuit opened, before a probe
    /// request is admitted and the circuit moves to half-open.
    pub open_duration: Duration,
    /// Number of consecutive successes required to close from half-open.
    pub success_threshold: u32,
}

impl Default for CircuitBreakerConfig {
    /// Three failures to open, a ten-second cooldown, and one successful probe
    /// to close again.
    fn default() -> Self {
        Self {
            failure_threshold: 3,
            open_duration: Duration::from_secs(10),
            success_threshold: 1,
        }
    }
}
30
/// Internal state machine of the breaker.
///
/// Unlike the public `CircuitBreakerState`, the open variant remembers when
/// the circuit tripped so the cooldown can be measured from that instant.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum InternalState {
    /// Requests flow normally.
    Closed,
    /// Requests are rejected until the cooldown has elapsed.
    Open {
        /// Instant at which the circuit last transitioned to open.
        opened_at: Instant,
    },
    /// Probe requests are admitted to test for recovery.
    HalfOpen,
}
37
/// Externally visible state of a circuit breaker.
///
/// This is a timestamp-free view of the breaker's internal state machine.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitBreakerState {
    /// Requests pass through as usual.
    Closed,
    /// Requests should be rejected until the cooldown ends.
    Open,
    /// Probe requests are allowed through to check for recovery.
    HalfOpen,
}
48
49#[derive(Debug)]
50struct CircuitBreakerInner {
51    state: InternalState,
52    consecutive_failures: u32,
53    consecutive_successes: u32,
54}
55
56impl CircuitBreakerInner {
57    fn new() -> Self {
58        Self {
59            state: InternalState::Closed,
60            consecutive_failures: 0,
61            consecutive_successes: 0,
62        }
63    }
64}
65
66/// Circuit breaker for a provider.
67#[derive(Debug)]
68pub struct CircuitBreaker {
69    config: CircuitBreakerConfig,
70    inner: Mutex<CircuitBreakerInner>,
71}
72
73impl CircuitBreaker {
74    /// Create a new circuit breaker.
75    pub fn new(config: CircuitBreakerConfig) -> Self {
76        Self {
77            config,
78            inner: Mutex::new(CircuitBreakerInner::new()),
79        }
80    }
81
82    /// Check if the circuit allows a request.
83    pub fn allow_request(&self) -> bool {
84        let mut inner = self
85            .inner
86            .lock()
87            .unwrap_or_else(|poisoned| poisoned.into_inner());
88        match inner.state {
89            InternalState::Closed => true,
90            InternalState::Open { opened_at } => {
91                if opened_at.elapsed() >= self.config.open_duration {
92                    inner.state = InternalState::HalfOpen;
93                    inner.consecutive_successes = 0;
94                    inner.consecutive_failures = 0;
95                    true
96                } else {
97                    false
98                }
99            }
100            InternalState::HalfOpen => true,
101        }
102    }
103
104    /// Record a successful request.
105    pub fn record_success(&self) {
106        let mut inner = self
107            .inner
108            .lock()
109            .unwrap_or_else(|poisoned| poisoned.into_inner());
110        match inner.state {
111            InternalState::Closed => {
112                inner.consecutive_failures = 0;
113            }
114            InternalState::HalfOpen => {
115                inner.consecutive_successes = inner.consecutive_successes.saturating_add(1);
116                if inner.consecutive_successes >= self.config.success_threshold {
117                    inner.state = InternalState::Closed;
118                    inner.consecutive_failures = 0;
119                    inner.consecutive_successes = 0;
120                }
121            }
122            InternalState::Open { .. } => {}
123        }
124    }
125
126    /// Record a failed request.
127    pub fn record_failure(&self) {
128        let mut inner = self
129            .inner
130            .lock()
131            .unwrap_or_else(|poisoned| poisoned.into_inner());
132        match inner.state {
133            InternalState::Closed => {
134                inner.consecutive_failures = inner.consecutive_failures.saturating_add(1);
135                if inner.consecutive_failures >= self.config.failure_threshold {
136                    inner.state = InternalState::Open {
137                        opened_at: Instant::now(),
138                    };
139                }
140            }
141            InternalState::HalfOpen => {
142                inner.state = InternalState::Open {
143                    opened_at: Instant::now(),
144                };
145                inner.consecutive_failures = 1;
146                inner.consecutive_successes = 0;
147            }
148            InternalState::Open { .. } => {}
149        }
150    }
151
152    /// Record a result to update circuit state.
153    pub fn record_result(&self, result: &std::result::Result<(), SimpleAgentsError>) {
154        match result {
155            Ok(_) => self.record_success(),
156            Err(error) => {
157                if matches!(
158                    error,
159                    SimpleAgentsError::Provider(
160                        ProviderError::RateLimit { .. }
161                            | ProviderError::Timeout(_)
162                            | ProviderError::ServerError(_)
163                    ) | SimpleAgentsError::Network(_)
164                ) {
165                    self.record_failure();
166                }
167            }
168        }
169    }
170
171    /// Current circuit state.
172    pub fn state(&self) -> CircuitBreakerState {
173        let inner = self
174            .inner
175            .lock()
176            .unwrap_or_else(|poisoned| poisoned.into_inner());
177        match inner.state {
178            InternalState::Closed => CircuitBreakerState::Closed,
179            InternalState::Open { .. } => CircuitBreakerState::Open,
180            InternalState::HalfOpen => CircuitBreakerState::HalfOpen,
181        }
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn opens_after_failures() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            open_duration: Duration::from_secs(10),
            success_threshold: 1,
        });

        assert!(breaker.allow_request());
        breaker.record_failure();
        assert!(breaker.allow_request());
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(!breaker.allow_request());
    }

    #[test]
    fn success_resets_failure_streak_when_closed() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            open_duration: Duration::from_secs(10),
            success_threshold: 1,
        });

        breaker.record_failure();
        breaker.record_success();
        breaker.record_failure();
        // Only one failure since the reset, so the circuit must still be closed.
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
    }

    #[test]
    fn stays_open_until_cooldown_elapses() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::from_secs(3600),
            success_threshold: 1,
        });

        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        // Outcomes recorded while open are ignored.
        breaker.record_success();
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(!breaker.allow_request());
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
    }

    #[test]
    fn closes_after_success_in_half_open() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::from_millis(0),
            success_threshold: 1,
        });

        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
    }

    #[test]
    fn half_open_needs_success_threshold_successes() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::from_millis(0),
            success_threshold: 2,
        });

        breaker.record_failure();
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        assert!(breaker.allow_request());
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
    }

    #[test]
    fn reopens_on_failure_in_half_open() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::from_millis(0),
            success_threshold: 2,
        });

        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
    }

    #[test]
    fn record_result_ok_counts_as_success() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            open_duration: Duration::from_secs(10),
            success_threshold: 1,
        });

        breaker.record_failure();
        breaker.record_result(&Ok(()));
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
    }
}