Skip to main content

simple_agents_router/
circuit_breaker.rs

//! Circuit breaker implementation for provider resilience.
//!
//! Tracks consecutive provider failures and opens the circuit once a threshold
//! is reached, then allows half-open probe requests after a cooldown.
5
6use simple_agent_type::prelude::{ProviderError, SimpleAgentsError};
7use std::sync::Mutex;
8use std::time::{Duration, Instant};
9
/// Tuning knobs for a [`CircuitBreaker`].
#[derive(Clone, Copy, Debug)]
pub struct CircuitBreakerConfig {
    /// How many failures in a row trip the circuit from closed to open.
    pub failure_threshold: u32,
    /// How long the circuit stays open before a half-open probe is permitted.
    pub open_duration: Duration,
    /// How many successes in a row, while half-open, close the circuit again.
    pub success_threshold: u32,
}

impl Default for CircuitBreakerConfig {
    /// Three consecutive failures open the circuit, it stays open for ten
    /// seconds, and a single half-open success closes it again.
    fn default() -> Self {
        let open_duration = Duration::from_secs(10);
        Self {
            failure_threshold: 3,
            open_duration,
            success_threshold: 1,
        }
    }
}
30
/// Internal state machine; the `Open` variant remembers when the circuit tripped.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum InternalState {
    /// Requests pass through normally.
    Closed,
    /// Requests are rejected until the cooldown, measured from `opened_at`, elapses.
    Open { opened_at: Instant },
    /// Probe requests are let through to test whether the provider recovered.
    HalfOpen,
}
37
/// Externally visible state of a circuit breaker.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitBreakerState {
    /// Normal operation: requests are allowed through.
    Closed,
    /// Tripped: requests should be rejected until the cooldown elapses.
    Open,
    /// Recovering: requests are let through as probes.
    HalfOpen,
}
48
49#[derive(Debug)]
50struct CircuitBreakerInner {
51    state: InternalState,
52    consecutive_failures: u32,
53    consecutive_successes: u32,
54}
55
56impl CircuitBreakerInner {
57    fn new() -> Self {
58        Self {
59            state: InternalState::Closed,
60            consecutive_failures: 0,
61            consecutive_successes: 0,
62        }
63    }
64}
65
66/// Circuit breaker for a provider.
67#[derive(Debug)]
68pub struct CircuitBreaker {
69    config: CircuitBreakerConfig,
70    inner: Mutex<CircuitBreakerInner>,
71}
72
73impl CircuitBreaker {
74    /// Create a new circuit breaker.
75    pub fn new(config: CircuitBreakerConfig) -> Self {
76        Self {
77            config,
78            inner: Mutex::new(CircuitBreakerInner::new()),
79        }
80    }
81
82    /// Check if the circuit allows a request.
83    pub fn allow_request(&self) -> bool {
84        let mut inner = self.inner.lock().expect("circuit breaker lock poisoned");
85        match inner.state {
86            InternalState::Closed => true,
87            InternalState::Open { opened_at } => {
88                if opened_at.elapsed() >= self.config.open_duration {
89                    inner.state = InternalState::HalfOpen;
90                    inner.consecutive_successes = 0;
91                    inner.consecutive_failures = 0;
92                    true
93                } else {
94                    false
95                }
96            }
97            InternalState::HalfOpen => true,
98        }
99    }
100
101    /// Record a successful request.
102    pub fn record_success(&self) {
103        let mut inner = self.inner.lock().expect("circuit breaker lock poisoned");
104        match inner.state {
105            InternalState::Closed => {
106                inner.consecutive_failures = 0;
107            }
108            InternalState::HalfOpen => {
109                inner.consecutive_successes = inner.consecutive_successes.saturating_add(1);
110                if inner.consecutive_successes >= self.config.success_threshold {
111                    inner.state = InternalState::Closed;
112                    inner.consecutive_failures = 0;
113                    inner.consecutive_successes = 0;
114                }
115            }
116            InternalState::Open { .. } => {}
117        }
118    }
119
120    /// Record a failed request.
121    pub fn record_failure(&self) {
122        let mut inner = self.inner.lock().expect("circuit breaker lock poisoned");
123        match inner.state {
124            InternalState::Closed => {
125                inner.consecutive_failures = inner.consecutive_failures.saturating_add(1);
126                if inner.consecutive_failures >= self.config.failure_threshold {
127                    inner.state = InternalState::Open {
128                        opened_at: Instant::now(),
129                    };
130                }
131            }
132            InternalState::HalfOpen => {
133                inner.state = InternalState::Open {
134                    opened_at: Instant::now(),
135                };
136                inner.consecutive_failures = 1;
137                inner.consecutive_successes = 0;
138            }
139            InternalState::Open { .. } => {}
140        }
141    }
142
143    /// Record a result to update circuit state.
144    pub fn record_result(&self, result: &std::result::Result<(), SimpleAgentsError>) {
145        match result {
146            Ok(_) => self.record_success(),
147            Err(error) => {
148                if matches!(
149                    error,
150                    SimpleAgentsError::Provider(
151                        ProviderError::RateLimit { .. }
152                            | ProviderError::Timeout(_)
153                            | ProviderError::ServerError(_)
154                    ) | SimpleAgentsError::Network(_)
155                ) {
156                    self.record_failure();
157                }
158            }
159        }
160    }
161
162    /// Current circuit state.
163    pub fn state(&self) -> CircuitBreakerState {
164        let inner = self.inner.lock().expect("circuit breaker lock poisoned");
165        match inner.state {
166            InternalState::Closed => CircuitBreakerState::Closed,
167            InternalState::Open { .. } => CircuitBreakerState::Open,
168            InternalState::HalfOpen => CircuitBreakerState::HalfOpen,
169        }
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn opens_after_failures() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            open_duration: Duration::from_secs(10),
            success_threshold: 1,
        });

        assert!(breaker.allow_request());
        breaker.record_failure();
        assert!(breaker.allow_request());
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(!breaker.allow_request());
    }

    #[test]
    fn closes_after_success_in_half_open() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::from_millis(0),
            success_threshold: 1,
        });

        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
    }

    #[test]
    fn reopens_on_failure_in_half_open() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::from_millis(0),
            success_threshold: 2,
        });

        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
    }

    #[test]
    fn default_config_values() {
        let config = CircuitBreakerConfig::default();
        assert_eq!(config.failure_threshold, 3);
        assert_eq!(config.open_duration, Duration::from_secs(10));
        assert_eq!(config.success_threshold, 1);
    }

    #[test]
    fn new_breaker_starts_closed() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig::default());
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
        assert!(breaker.allow_request());
    }

    #[test]
    fn success_resets_failure_streak_when_closed() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 2,
            open_duration: Duration::from_secs(10),
            success_threshold: 1,
        });

        breaker.record_failure();
        breaker.record_success();
        breaker.record_failure();
        // The success broke the streak, so a single new failure must not open.
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
    }

    #[test]
    fn half_open_requires_success_threshold_successes() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::from_millis(0),
            success_threshold: 2,
        });

        breaker.record_failure();
        assert!(breaker.allow_request());
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_success();
        // One success is below the threshold of two, so the circuit stays half-open.
        assert_eq!(breaker.state(), CircuitBreakerState::HalfOpen);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::Closed);
    }

    #[test]
    fn outcomes_are_ignored_while_open() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            open_duration: Duration::from_secs(10),
            success_threshold: 1,
        });

        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        breaker.record_failure();
        assert_eq!(breaker.state(), CircuitBreakerState::Open);
        assert!(!breaker.allow_request());
    }
}