Skip to main content

selfware/supervision/
circuit_breaker.rs

1//! Circuit breaker pattern for fault tolerance
2
3use std::sync::atomic::{AtomicU32, Ordering};
4use std::time::{Duration, Instant};
5use tokio::sync::RwLock;
6use tracing::{debug, info, warn};
7
8/// Circuit breaker for protecting against cascading failures
9pub struct CircuitBreaker {
10    state: AtomicU32, // 0=Closed, 1=Open, 2=HalfOpen
11    failure_count: AtomicU32,
12    success_count: AtomicU32,
13    config: CircuitBreakerConfig,
14    last_failure_time: RwLock<Option<Instant>>,
15    last_state_change: RwLock<Instant>,
16}
17
/// Tuning knobs for a [`CircuitBreaker`].
///
/// Counts are compared with `>=`, so e.g. `failure_threshold: 5` opens the circuit on the
/// fifth recorded failure.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CircuitBreakerConfig {
    /// Number of recorded failures (cleared by any success while closed) that opens the
    /// circuit.
    pub failure_threshold: u32,
    /// Number of successful requests needed while half-open to close the circuit again.
    pub success_threshold: u32,
    /// How long the circuit stays open before a probe request is allowed (half-open).
    pub reset_timeout: Duration,
    /// Maximum number of requests admitted while the circuit is half-open.
    pub half_open_max_requests: u32,
}
30
31impl Default for CircuitBreakerConfig {
32    fn default() -> Self {
33        Self {
34            failure_threshold: 5,
35            success_threshold: 3,
36            reset_timeout: Duration::from_secs(30),
37            half_open_max_requests: 3,
38        }
39    }
40}
41
/// Lifecycle state of a circuit breaker.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation: requests are passed through.
    Closed,
    /// Failing: requests are rejected without running the operation.
    Open,
    /// Probing: a limited number of requests test whether the service recovered.
    HalfOpen,
}
49
/// Error returned by a circuit-breaker-protected call.
///
/// `E` is the error type produced by the wrapped operation.
#[derive(Debug, Clone)]
pub enum CircuitBreakerError<E> {
    /// The request was rejected without running the operation, because the circuit is open
    /// or its half-open request limit has been reached.
    CircuitOpen,
    /// The operation ran and returned this error.
    OperationFailed(E),
}
58
59impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        match self {
62            Self::CircuitOpen => write!(f, "Circuit breaker is open"),
63            Self::OperationFailed(e) => write!(f, "Operation failed: {}", e),
64        }
65    }
66}
67
68impl<E: std::fmt::Debug + std::fmt::Display> std::error::Error for CircuitBreakerError<E> {}
69
70impl CircuitBreaker {
71    /// Create a new circuit breaker
72    pub fn new(config: CircuitBreakerConfig) -> Self {
73        Self {
74            state: AtomicU32::new(0),
75            failure_count: AtomicU32::new(0),
76            success_count: AtomicU32::new(0),
77            config,
78            last_failure_time: RwLock::new(None),
79            last_state_change: RwLock::new(Instant::now()),
80        }
81    }
82
83    /// Get current circuit state
84    pub fn current_state(&self) -> CircuitState {
85        match self.state.load(Ordering::Relaxed) {
86            0 => CircuitState::Closed,
87            1 => CircuitState::Open,
88            2 => CircuitState::HalfOpen,
89            _ => CircuitState::Closed,
90        }
91    }
92
93    /// Check if we should attempt reset
94    pub async fn should_attempt_reset(&self) -> bool {
95        if self.current_state() != CircuitState::Open {
96            return false;
97        }
98
99        let last_change = *self.last_state_change.read().await;
100        last_change.elapsed() >= self.config.reset_timeout
101    }
102
103    /// Execute operation with circuit breaker protection
104    pub async fn call<F, Fut, T, E>(&self, operation: F) -> Result<T, CircuitBreakerError<E>>
105    where
106        F: FnOnce() -> Fut,
107        Fut: std::future::Future<Output = Result<T, E>>,
108    {
109        // Check current state
110        match self.current_state() {
111            CircuitState::Open => {
112                if self.should_attempt_reset().await {
113                    self.transition_to(CircuitState::HalfOpen).await;
114                } else {
115                    warn!("Circuit breaker open, rejecting request");
116                    return Err(CircuitBreakerError::CircuitOpen);
117                }
118            }
119            CircuitState::HalfOpen => {
120                let requests = self.success_count.load(Ordering::Relaxed)
121                    + self.failure_count.load(Ordering::Relaxed);
122                if requests >= self.config.half_open_max_requests {
123                    warn!("Half-open max requests reached");
124                    return Err(CircuitBreakerError::CircuitOpen);
125                }
126            }
127            CircuitState::Closed => {}
128        }
129
130        // Execute operation
131        match operation().await {
132            Ok(result) => {
133                self.on_success().await;
134                Ok(result)
135            }
136            Err(e) => {
137                self.on_failure().await;
138                Err(CircuitBreakerError::OperationFailed(e))
139            }
140        }
141    }
142
143    /// Handle successful operation
144    async fn on_success(&self) {
145        let success_count = self.success_count.fetch_add(1, Ordering::SeqCst) + 1;
146        debug!(success_count = success_count, "Operation succeeded");
147
148        if self.current_state() == CircuitState::HalfOpen {
149            if success_count >= self.config.success_threshold {
150                info!("Circuit breaker closing after successful recovery");
151                self.transition_to(CircuitState::Closed).await;
152            }
153        } else {
154            // Reset failure count in closed state
155            self.failure_count.store(0, Ordering::SeqCst);
156        }
157    }
158
159    /// Handle failed operation
160    async fn on_failure(&self) {
161        let failure_count = self.failure_count.fetch_add(1, Ordering::SeqCst) + 1;
162        *self.last_failure_time.write().await = Some(Instant::now());
163
164        warn!(failure_count = failure_count, "Operation failed");
165
166        if self.current_state() == CircuitState::HalfOpen {
167            // Any failure in half-open goes back to open
168            info!("Failure in half-open, reopening circuit");
169            self.transition_to(CircuitState::Open).await;
170        } else if failure_count >= self.config.failure_threshold {
171            info!("Failure threshold reached, opening circuit");
172            self.transition_to(CircuitState::Open).await;
173        }
174    }
175
176    /// Transition to a new state
177    async fn transition_to(&self, new_state: CircuitState) {
178        let state_num = match new_state {
179            CircuitState::Closed => 0,
180            CircuitState::Open => 1,
181            CircuitState::HalfOpen => 2,
182        };
183
184        let old_state = self.state.swap(state_num, Ordering::SeqCst);
185        *self.last_state_change.write().await = Instant::now();
186
187        // Reset counters on state change
188        self.failure_count.store(0, Ordering::SeqCst);
189        self.success_count.store(0, Ordering::SeqCst);
190
191        info!(
192            old_state = ?match old_state {
193                0 => CircuitState::Closed,
194                1 => CircuitState::Open,
195                2 => CircuitState::HalfOpen,
196                _ => CircuitState::Closed,
197            },
198            new_state = ?new_state,
199            "Circuit breaker state changed"
200        );
201    }
202
203    /// Get metrics
204    pub fn metrics(&self) -> CircuitBreakerMetrics {
205        CircuitBreakerMetrics {
206            state: self.current_state(),
207            failure_count: self.failure_count.load(Ordering::Relaxed),
208            success_count: self.success_count.load(Ordering::Relaxed),
209        }
210    }
211}
212
213/// Circuit breaker metrics
214#[derive(Debug, Clone)]
215pub struct CircuitBreakerMetrics {
216    pub state: CircuitState,
217    pub failure_count: u32,
218    pub success_count: u32,
219}
220
221impl Default for CircuitBreaker {
222    fn default() -> Self {
223        Self::new(CircuitBreakerConfig::default())
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Small thresholds and a short reset timeout so every transition is quick to reach.
    fn fast_config() -> CircuitBreakerConfig {
        CircuitBreakerConfig {
            failure_threshold: 3,
            success_threshold: 2,
            reset_timeout: Duration::from_millis(50),
            half_open_max_requests: 2,
        }
    }

    /// Sleep duration safely longer than `fast_config().reset_timeout`.
    const PAST_RESET_TIMEOUT: Duration = Duration::from_millis(60);

    /// Run one operation through `cb` that fails with `"fail"`.
    async fn fail_once(cb: &CircuitBreaker) -> Result<i32, CircuitBreakerError<String>> {
        cb.call(|| async { Err::<i32, String>("fail".into()) }).await
    }

    /// Run one operation through `cb` that succeeds with `value`.
    async fn succeed_once(
        cb: &CircuitBreaker,
        value: i32,
    ) -> Result<i32, CircuitBreakerError<String>> {
        cb.call(|| async move { Ok(value) }).await
    }

    /// Run `times` failing operations through `cb`, ignoring their results.
    async fn fail_times(cb: &CircuitBreaker, times: usize) {
        for _ in 0..times {
            let _ = fail_once(cb).await;
        }
    }

    #[test]
    fn test_initial_state_is_closed() {
        let cb = CircuitBreaker::default();
        assert_eq!(cb.current_state(), CircuitState::Closed);
    }

    #[test]
    fn test_default_config_values() {
        let config = CircuitBreakerConfig::default();
        assert_eq!(config.failure_threshold, 5);
        assert_eq!(config.success_threshold, 3);
        assert_eq!(config.reset_timeout, Duration::from_secs(30));
        assert_eq!(config.half_open_max_requests, 3);
    }

    #[test]
    fn test_initial_metrics_are_zero() {
        let cb = CircuitBreaker::default();
        let metrics = cb.metrics();
        assert_eq!(metrics.state, CircuitState::Closed);
        assert_eq!(metrics.failure_count, 0);
        assert_eq!(metrics.success_count, 0);
    }

    #[tokio::test]
    async fn test_success_keeps_circuit_closed() {
        let cb = CircuitBreaker::new(fast_config());

        let result = succeed_once(&cb, 42).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(cb.current_state(), CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_operation_error_is_wrapped() {
        let cb = CircuitBreaker::new(fast_config());

        let result = fail_once(&cb).await;

        assert!(matches!(
            result,
            Err(CircuitBreakerError::OperationFailed(ref msg)) if msg.as_str() == "fail"
        ));
    }

    #[tokio::test]
    async fn test_failures_below_threshold_stay_closed() {
        let cb = CircuitBreaker::new(fast_config());

        // Cause 2 failures (threshold is 3)
        fail_times(&cb, 2).await;

        assert_eq!(cb.current_state(), CircuitState::Closed);
        assert_eq!(cb.metrics().failure_count, 2);
    }

    #[tokio::test]
    async fn test_transition_to_open_after_failure_threshold() {
        let cb = CircuitBreaker::new(fast_config());

        // Cause 3 failures (threshold is 3)
        fail_times(&cb, 3).await;

        assert_eq!(cb.current_state(), CircuitState::Open);
    }

    #[tokio::test]
    async fn test_open_circuit_rejects_requests() {
        let cb = CircuitBreaker::new(fast_config());
        fail_times(&cb, 3).await;
        assert_eq!(cb.current_state(), CircuitState::Open);

        // Next call should be rejected immediately
        let result = succeed_once(&cb, 42).await;

        assert!(matches!(result, Err(CircuitBreakerError::CircuitOpen)));
    }

    #[tokio::test]
    async fn test_half_open_after_reset_timeout() {
        let cb = CircuitBreaker::new(fast_config());
        fail_times(&cb, 3).await;
        assert_eq!(cb.current_state(), CircuitState::Open);

        tokio::time::sleep(PAST_RESET_TIMEOUT).await;
        assert!(cb.should_attempt_reset().await);

        // Next call should transition to half-open and succeed
        let result = succeed_once(&cb, 1).await;
        assert!(result.is_ok());
        assert_eq!(cb.current_state(), CircuitState::HalfOpen);
    }

    #[tokio::test]
    async fn test_half_open_to_closed_after_success_threshold() {
        let cb = CircuitBreaker::new(fast_config());
        fail_times(&cb, 3).await;
        tokio::time::sleep(PAST_RESET_TIMEOUT).await;

        // Succeed enough times (success_threshold = 2)
        for _ in 0..2 {
            assert!(succeed_once(&cb, 1).await.is_ok());
        }

        assert_eq!(cb.current_state(), CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_half_open_failure_reopens_circuit() {
        let cb = CircuitBreaker::new(fast_config());
        fail_times(&cb, 3).await;
        tokio::time::sleep(PAST_RESET_TIMEOUT).await;

        // One success to get into half-open
        let _ = succeed_once(&cb, 1).await;
        assert_eq!(cb.current_state(), CircuitState::HalfOpen);

        // Fail in half-open -> back to open
        let _ = cb
            .call(|| async { Err::<i32, String>("fail again".into()) })
            .await;
        assert_eq!(cb.current_state(), CircuitState::Open);
    }

    #[tokio::test]
    async fn test_should_attempt_reset_false_when_closed() {
        let cb = CircuitBreaker::default();
        assert!(!cb.should_attempt_reset().await);
    }

    #[tokio::test]
    async fn test_should_attempt_reset_false_before_timeout() {
        let config = CircuitBreakerConfig {
            failure_threshold: 1,
            reset_timeout: Duration::from_secs(60),
            ..CircuitBreakerConfig::default()
        };
        let cb = CircuitBreaker::new(config);

        // Trip breaker
        let _ = fail_once(&cb).await;
        assert_eq!(cb.current_state(), CircuitState::Open);

        // Should not reset yet (timeout is 60s)
        assert!(!cb.should_attempt_reset().await);
    }

    #[tokio::test]
    async fn test_success_resets_failure_count_in_closed() {
        let cb = CircuitBreaker::new(fast_config());

        // Cause 2 failures (below threshold of 3)
        fail_times(&cb, 2).await;
        assert_eq!(cb.metrics().failure_count, 2);

        // A success should reset failure count
        let _ = succeed_once(&cb, 1).await;
        assert_eq!(cb.metrics().failure_count, 0);
    }

    #[tokio::test]
    async fn test_metrics_track_successes() {
        let cb = CircuitBreaker::new(fast_config());

        for _ in 0..4 {
            let _ = succeed_once(&cb, 1).await;
        }

        let metrics = cb.metrics();
        assert_eq!(metrics.state, CircuitState::Closed);
        // While closed, a success clears failure_count but never success_count,
        // so all four successes are counted.
        assert_eq!(metrics.success_count, 4);
    }

    #[test]
    fn test_circuit_breaker_error_display() {
        let open_err: CircuitBreakerError<String> = CircuitBreakerError::CircuitOpen;
        assert_eq!(format!("{}", open_err), "Circuit breaker is open");

        let op_err: CircuitBreakerError<String> =
            CircuitBreakerError::OperationFailed("db timeout".into());
        assert_eq!(format!("{}", op_err), "Operation failed: db timeout");
    }
}