1use std::sync::atomic::{AtomicU32, Ordering};
4use std::time::{Duration, Instant};
5use tokio::sync::RwLock;
6use tracing::{debug, info, warn};
7
8pub struct CircuitBreaker {
10 state: AtomicU32, failure_count: AtomicU32,
12 success_count: AtomicU32,
13 config: CircuitBreakerConfig,
14 last_failure_time: RwLock<Option<Instant>>,
15 last_state_change: RwLock<Instant>,
16}
17
/// Tuning knobs for a [`CircuitBreaker`].
#[derive(Debug, Clone, Copy)]
pub struct CircuitBreakerConfig {
    /// Failures (counted while closed; any success resets the count) that
    /// trip the breaker open.
    pub failure_threshold: u32,
    /// Successes required while half-open before the breaker closes again.
    /// Values larger than `half_open_max_requests` cannot all be observed
    /// before the half-open request limit is reached.
    pub success_threshold: u32,
    /// How long the breaker stays open before a probe request is allowed.
    pub reset_timeout: Duration,
    /// Limit on requests let through while half-open, measured by the
    /// successes and failures recorded since entering that state.
    pub half_open_max_requests: u32,
}
30
31impl Default for CircuitBreakerConfig {
32 fn default() -> Self {
33 Self {
34 failure_threshold: 5,
35 success_threshold: 3,
36 reset_timeout: Duration::from_secs(30),
37 half_open_max_requests: 3,
38 }
39 }
40}
41
/// Observable state of a [`CircuitBreaker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Requests are passed through to the guarded operation.
    Closed,
    /// Requests are rejected until the reset timeout has elapsed.
    Open,
    /// A limited number of trial requests are let through to probe recovery.
    HalfOpen,
}
49
/// Error returned by [`CircuitBreaker::call`].
#[derive(Debug, Clone)]
pub enum CircuitBreakerError<E> {
    /// The breaker rejected the request without running the operation.
    CircuitOpen,
    /// The operation ran and returned this error.
    OperationFailed(E),
}
58
59impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 match self {
62 Self::CircuitOpen => write!(f, "Circuit breaker is open"),
63 Self::OperationFailed(e) => write!(f, "Operation failed: {}", e),
64 }
65 }
66}
67
68impl<E: std::fmt::Debug + std::fmt::Display> std::error::Error for CircuitBreakerError<E> {}
69
70impl CircuitBreaker {
71 pub fn new(config: CircuitBreakerConfig) -> Self {
73 Self {
74 state: AtomicU32::new(0),
75 failure_count: AtomicU32::new(0),
76 success_count: AtomicU32::new(0),
77 config,
78 last_failure_time: RwLock::new(None),
79 last_state_change: RwLock::new(Instant::now()),
80 }
81 }
82
83 pub fn current_state(&self) -> CircuitState {
85 match self.state.load(Ordering::Relaxed) {
86 0 => CircuitState::Closed,
87 1 => CircuitState::Open,
88 2 => CircuitState::HalfOpen,
89 _ => CircuitState::Closed,
90 }
91 }
92
93 pub async fn should_attempt_reset(&self) -> bool {
95 if self.current_state() != CircuitState::Open {
96 return false;
97 }
98
99 let last_change = *self.last_state_change.read().await;
100 last_change.elapsed() >= self.config.reset_timeout
101 }
102
103 pub async fn call<F, Fut, T, E>(&self, operation: F) -> Result<T, CircuitBreakerError<E>>
105 where
106 F: FnOnce() -> Fut,
107 Fut: std::future::Future<Output = Result<T, E>>,
108 {
109 match self.current_state() {
111 CircuitState::Open => {
112 if self.should_attempt_reset().await {
113 self.transition_to(CircuitState::HalfOpen).await;
114 } else {
115 warn!("Circuit breaker open, rejecting request");
116 return Err(CircuitBreakerError::CircuitOpen);
117 }
118 }
119 CircuitState::HalfOpen => {
120 let requests = self.success_count.load(Ordering::Relaxed)
121 + self.failure_count.load(Ordering::Relaxed);
122 if requests >= self.config.half_open_max_requests {
123 warn!("Half-open max requests reached");
124 return Err(CircuitBreakerError::CircuitOpen);
125 }
126 }
127 CircuitState::Closed => {}
128 }
129
130 match operation().await {
132 Ok(result) => {
133 self.on_success().await;
134 Ok(result)
135 }
136 Err(e) => {
137 self.on_failure().await;
138 Err(CircuitBreakerError::OperationFailed(e))
139 }
140 }
141 }
142
143 async fn on_success(&self) {
145 let success_count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
146 debug!(success_count = success_count, "Operation succeeded");
147
148 if self.current_state() == CircuitState::HalfOpen {
149 if success_count >= self.config.success_threshold {
150 info!("Circuit breaker closing after successful recovery");
151 self.transition_to(CircuitState::Closed).await;
152 }
153 } else {
154 self.failure_count.store(0, Ordering::SeqCst);
156 }
157 }
158
159 async fn on_failure(&self) {
161 let failure_count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
162 *self.last_failure_time.write().await = Some(Instant::now());
163
164 warn!(failure_count = failure_count, "Operation failed");
165
166 if self.current_state() == CircuitState::HalfOpen {
167 info!("Failure in half-open, reopening circuit");
169 self.transition_to(CircuitState::Open).await;
170 } else if failure_count >= self.config.failure_threshold {
171 info!("Failure threshold reached, opening circuit");
172 self.transition_to(CircuitState::Open).await;
173 }
174 }
175
176 async fn transition_to(&self, new_state: CircuitState) {
178 let state_num = match new_state {
179 CircuitState::Closed => 0,
180 CircuitState::Open => 1,
181 CircuitState::HalfOpen => 2,
182 };
183
184 let old_state = self.state.swap(state_num, Ordering::SeqCst);
185 *self.last_state_change.write().await = Instant::now();
186
187 self.failure_count.store(0, Ordering::SeqCst);
189 self.success_count.store(0, Ordering::SeqCst);
190
191 info!(
192 old_state = ?match old_state {
193 0 => CircuitState::Closed,
194 1 => CircuitState::Open,
195 2 => CircuitState::HalfOpen,
196 _ => CircuitState::Closed,
197 },
198 new_state = ?new_state,
199 "Circuit breaker state changed"
200 );
201 }
202
203 pub fn metrics(&self) -> CircuitBreakerMetrics {
205 CircuitBreakerMetrics {
206 state: self.current_state(),
207 failure_count: self.failure_count.load(Ordering::Relaxed),
208 success_count: self.success_count.load(Ordering::Relaxed),
209 }
210 }
211}
212
213#[derive(Debug, Clone)]
215pub struct CircuitBreakerMetrics {
216 pub state: CircuitState,
217 pub failure_count: u32,
218 pub success_count: u32,
219}
220
221impl Default for CircuitBreaker {
222 fn default() -> Self {
223 Self::new(CircuitBreakerConfig::default())
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Thresholds and timeouts small enough to drive the whole state machine
    /// within a test run.
    fn fast_config() -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            failure_threshold: 3,
            success_threshold: 2,
            reset_timeout: Duration::from_millis(50),
            half_open_max_requests: 2,
        }
    }

    /// Sends one operation through `cb` that fails with `msg`.
    async fn fail(cb: &CircuitBreaker, msg: &'static str) -> Result<i32, CircuitBreakerError<String>> {
        cb.call(move || async move { Err::<i32, String>(msg.into()) })
            .await
    }

    /// Sends one operation through `cb` that succeeds with `value`.
    async fn succeed(cb: &CircuitBreaker, value: i32) -> Result<i32, CircuitBreakerError<String>> {
        cb.call(move || async move { Ok::<i32, String>(value) }).await
    }

    /// Sends `times` failing operations through `cb`, ignoring the results.
    async fn fail_times(cb: &CircuitBreaker, times: usize) {
        for _ in 0..times {
            let _ = fail(cb, "fail").await;
        }
    }

    /// Sleeps slightly longer than `fast_config().reset_timeout`.
    async fn wait_past_reset_timeout() {
        tokio::time::sleep(Duration::from_millis(60)).await;
    }

    #[test]
    fn test_initial_state_is_closed() {
        assert_eq!(CircuitBreaker::default().current_state(), CircuitState::Closed);
    }

    #[test]
    fn test_default_config_values() {
        let CircuitBreakerConfig {
            failure_threshold,
            success_threshold,
            reset_timeout,
            half_open_max_requests,
        } = CircuitBreakerConfig::default();
        assert_eq!(failure_threshold, 5);
        assert_eq!(success_threshold, 3);
        assert_eq!(reset_timeout, Duration::from_secs(30));
        assert_eq!(half_open_max_requests, 3);
    }

    #[test]
    fn test_initial_metrics_are_zero() {
        let snapshot = CircuitBreaker::default().metrics();
        assert_eq!(snapshot.state, CircuitState::Closed);
        assert_eq!(snapshot.failure_count, 0);
        assert_eq!(snapshot.success_count, 0);
    }

    #[tokio::test]
    async fn test_success_keeps_circuit_closed() {
        let breaker = CircuitBreaker::new(fast_config());

        let outcome = succeed(&breaker, 42).await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(breaker.current_state(), CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_failures_below_threshold_stay_closed() {
        let breaker = CircuitBreaker::new(fast_config());
        fail_times(&breaker, 2).await;

        assert_eq!(breaker.current_state(), CircuitState::Closed);
        assert_eq!(breaker.metrics().failure_count, 2);
    }

    #[tokio::test]
    async fn test_transition_to_open_after_failure_threshold() {
        let breaker = CircuitBreaker::new(fast_config());
        fail_times(&breaker, 3).await;

        assert_eq!(breaker.current_state(), CircuitState::Open);
    }

    #[tokio::test]
    async fn test_open_circuit_rejects_requests() {
        let breaker = CircuitBreaker::new(fast_config());
        fail_times(&breaker, 3).await;
        assert_eq!(breaker.current_state(), CircuitState::Open);

        let outcome = succeed(&breaker, 42).await;

        assert!(matches!(outcome, Err(CircuitBreakerError::CircuitOpen)));
    }

    #[tokio::test]
    async fn test_half_open_after_reset_timeout() {
        let breaker = CircuitBreaker::new(fast_config());
        fail_times(&breaker, 3).await;
        assert_eq!(breaker.current_state(), CircuitState::Open);

        wait_past_reset_timeout().await;
        assert!(breaker.should_attempt_reset().await);

        assert!(succeed(&breaker, 1).await.is_ok());
        assert_eq!(breaker.current_state(), CircuitState::HalfOpen);
    }

    #[tokio::test]
    async fn test_half_open_to_closed_after_success_threshold() {
        let breaker = CircuitBreaker::new(fast_config());
        fail_times(&breaker, 3).await;
        wait_past_reset_timeout().await;

        for _ in 0..2 {
            assert!(succeed(&breaker, 1).await.is_ok());
        }

        assert_eq!(breaker.current_state(), CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_half_open_failure_reopens_circuit() {
        let breaker = CircuitBreaker::new(fast_config());
        fail_times(&breaker, 3).await;
        wait_past_reset_timeout().await;

        let _ = succeed(&breaker, 1).await;
        assert_eq!(breaker.current_state(), CircuitState::HalfOpen);

        let _ = fail(&breaker, "fail again").await;
        assert_eq!(breaker.current_state(), CircuitState::Open);
    }

    #[tokio::test]
    async fn test_should_attempt_reset_false_when_closed() {
        let breaker = CircuitBreaker::default();
        assert!(!breaker.should_attempt_reset().await);
    }

    #[tokio::test]
    async fn test_should_attempt_reset_false_before_timeout() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig {
            failure_threshold: 1,
            reset_timeout: Duration::from_secs(60),
            ..CircuitBreakerConfig::default()
        });

        let _ = fail(&breaker, "fail").await;
        assert_eq!(breaker.current_state(), CircuitState::Open);

        assert!(!breaker.should_attempt_reset().await);
    }

    #[tokio::test]
    async fn test_success_resets_failure_count_in_closed() {
        let breaker = CircuitBreaker::new(fast_config());
        fail_times(&breaker, 2).await;
        assert_eq!(breaker.metrics().failure_count, 2);

        let _ = succeed(&breaker, 1).await;
        assert_eq!(breaker.metrics().failure_count, 0);
    }

    #[tokio::test]
    async fn test_metrics_track_successes() {
        let breaker = CircuitBreaker::new(fast_config());
        for _ in 0..4 {
            let _ = succeed(&breaker, 1).await;
        }

        let snapshot = breaker.metrics();
        assert_eq!(snapshot.state, CircuitState::Closed);
        assert_eq!(snapshot.success_count, 4);
    }

    #[test]
    fn test_circuit_breaker_error_display() {
        let open_err: CircuitBreakerError<String> = CircuitBreakerError::CircuitOpen;
        assert_eq!(open_err.to_string(), "Circuit breaker is open");

        let op_err: CircuitBreakerError<String> =
            CircuitBreakerError::OperationFailed(String::from("db timeout"));
        assert_eq!(op_err.to_string(), "Operation failed: db timeout");
    }
}