rustkernel_core/resilience/
circuit_breaker.rs

1//! Circuit Breaker Pattern
2//!
3//! Prevents cascade failures by detecting and isolating unhealthy kernels.
4//!
5//! # States
6//!
7//! - **Closed**: Normal operation, requests pass through
8//! - **Open**: Failures exceeded threshold, requests fail fast
9//! - **HalfOpen**: Testing if service has recovered
10//!
11//! # Example
12//!
13//! ```rust,ignore
//! use rustkernel_core::resilience::ResilienceError;
//! use rustkernel_core::resilience::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig};
//! use std::time::Duration;
15//!
16//! let config = CircuitBreakerConfig::default()
17//!     .failure_threshold(5)
18//!     .reset_timeout(Duration::from_secs(30));
19//!
20//! let cb = CircuitBreaker::new("graph/pagerank", config);
21//!
22//! match cb.execute(|| async { /* kernel execution */ }).await {
23//!     Ok(result) => println!("Success: {:?}", result),
24//!     Err(ResilienceError::CircuitOpen { .. }) => println!("Circuit is open"),
25//!     Err(e) => println!("Error: {:?}", e),
26//! }
27//! ```
28
29use super::{ResilienceError, ResilienceResult};
30use serde::{Deserialize, Serialize};
31use std::sync::Arc;
32use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
33use std::time::{Duration, Instant};
34use tokio::sync::RwLock;
35
/// Circuit breaker state
///
/// `Closed` is the `Default` variant. With `rename_all = "lowercase"` serde writes the
/// variants as `closed`, `open` and `halfopen`; note that the `Display` impl in this file
/// prints `HalfOpen` as `half-open`, so the two textual forms differ for that variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CircuitState {
    /// Circuit is closed, requests pass through
    #[default]
    Closed,
    /// Circuit is open, requests fail fast
    Open,
    /// Circuit is half-open, testing recovery
    HalfOpen,
}
48
49impl std::fmt::Display for CircuitState {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        match self {
52            Self::Closed => write!(f, "closed"),
53            Self::Open => write!(f, "open"),
54            Self::HalfOpen => write!(f, "half-open"),
55        }
56    }
57}
58
/// Circuit breaker configuration
///
/// Obtain a value from `Default::default()` or `CircuitBreakerConfig::production()` and
/// adjust individual fields with the chainable setter methods of the same names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitBreakerConfig {
    /// Number of failures before opening circuit.
    ///
    /// Failures are counted while the circuit is closed; a success recorded while closed
    /// sets the count back to zero.
    pub failure_threshold: u32,
    /// Number of successes to close circuit from half-open
    pub success_threshold: u32,
    /// Time to wait before transitioning from open to half-open.
    ///
    /// Measured from the most recently recorded failure.
    pub reset_timeout: Duration,
    /// Sliding window size for tracking failures.
    ///
    /// NOTE(review): nothing in this file reads this field - failures are counted as
    /// described on `failure_threshold`, not over a time window. Confirm whether
    /// windowing is implemented elsewhere or still pending.
    pub window_size: Duration,
    /// Maximum concurrent requests in half-open state
    pub half_open_max_requests: u32,
}
73
74impl Default for CircuitBreakerConfig {
75    fn default() -> Self {
76        Self {
77            failure_threshold: 5,
78            success_threshold: 2,
79            reset_timeout: Duration::from_secs(30),
80            window_size: Duration::from_secs(60),
81            half_open_max_requests: 3,
82        }
83    }
84}
85
86impl CircuitBreakerConfig {
87    /// Production configuration with conservative settings
88    pub fn production() -> Self {
89        Self {
90            failure_threshold: 5,
91            success_threshold: 3,
92            reset_timeout: Duration::from_secs(60),
93            window_size: Duration::from_secs(120),
94            half_open_max_requests: 5,
95        }
96    }
97
98    /// Set failure threshold
99    pub fn failure_threshold(mut self, threshold: u32) -> Self {
100        self.failure_threshold = threshold;
101        self
102    }
103
104    /// Set success threshold for closing
105    pub fn success_threshold(mut self, threshold: u32) -> Self {
106        self.success_threshold = threshold;
107        self
108    }
109
110    /// Set reset timeout
111    pub fn reset_timeout(mut self, timeout: Duration) -> Self {
112        self.reset_timeout = timeout;
113        self
114    }
115
116    /// Set sliding window size
117    pub fn window_size(mut self, size: Duration) -> Self {
118        self.window_size = size;
119        self
120    }
121
122    /// Set max requests in half-open state
123    pub fn half_open_max_requests(mut self, max: u32) -> Self {
124        self.half_open_max_requests = max;
125        self
126    }
127}
128
/// Circuit breaker for a kernel
///
/// The mutable state lives behind an `Arc`, and the `Clone` impl in this file clones that
/// `Arc`, so clones of a breaker observe and update the same state and counters.
pub struct CircuitBreaker {
    /// Kernel ID this circuit breaker protects
    kernel_id: String,
    /// Configuration
    config: CircuitBreakerConfig,
    /// Inner state
    inner: Arc<CircuitBreakerInner>,
}
138
/// Shared mutable state of a `CircuitBreaker`.
struct CircuitBreakerInner {
    /// Current circuit state.
    state: RwLock<CircuitState>,
    /// Failures counted while closed; cleared by a success while closed, by closing the
    /// circuit and by a manual reset.
    failure_count: AtomicU32,
    /// Successes counted while half-open; cleared on every state transition.
    success_count: AtomicU32,
    /// Time of the most recently recorded failure; compared against `reset_timeout` to
    /// decide when an open circuit becomes half-open.
    last_failure_time: RwLock<Option<Instant>>,
    /// Number of requests admitted in the half-open state that are still executing.
    half_open_requests: AtomicU32,
    /// Count of all `execute` calls, including rejected ones.
    total_requests: AtomicU64,
    /// Count of all recorded failures.
    total_failures: AtomicU64,
}
148
149impl CircuitBreaker {
150    /// Create a new circuit breaker
151    pub fn new(kernel_id: impl Into<String>, config: CircuitBreakerConfig) -> Self {
152        Self {
153            kernel_id: kernel_id.into(),
154            config,
155            inner: Arc::new(CircuitBreakerInner {
156                state: RwLock::new(CircuitState::Closed),
157                failure_count: AtomicU32::new(0),
158                success_count: AtomicU32::new(0),
159                last_failure_time: RwLock::new(None),
160                half_open_requests: AtomicU32::new(0),
161                total_requests: AtomicU64::new(0),
162                total_failures: AtomicU64::new(0),
163            }),
164        }
165    }
166
167    /// Get current state
168    pub async fn state(&self) -> CircuitState {
169        let state = *self.inner.state.read().await;
170
171        // Check if we should transition from Open to HalfOpen
172        if state == CircuitState::Open {
173            if let Some(last_failure) = *self.inner.last_failure_time.read().await {
174                if last_failure.elapsed() >= self.config.reset_timeout {
175                    return self.try_transition_to_half_open().await;
176                }
177            }
178        }
179
180        state
181    }
182
183    /// Get the kernel ID
184    pub fn kernel_id(&self) -> &str {
185        &self.kernel_id
186    }
187
188    /// Check if requests are allowed
189    pub async fn is_allowed(&self) -> bool {
190        match self.state().await {
191            CircuitState::Closed => true,
192            CircuitState::Open => false,
193            CircuitState::HalfOpen => {
194                self.inner.half_open_requests.load(Ordering::Relaxed)
195                    < self.config.half_open_max_requests
196            }
197        }
198    }
199
200    /// Execute a function with circuit breaker protection
201    pub async fn execute<F, Fut, T, E>(&self, f: F) -> ResilienceResult<T>
202    where
203        F: FnOnce() -> Fut,
204        Fut: std::future::Future<Output = Result<T, E>>,
205        E: Into<crate::error::KernelError>,
206    {
207        self.inner.total_requests.fetch_add(1, Ordering::Relaxed);
208
209        // Check if circuit allows the request
210        let state = self.state().await;
211        match state {
212            CircuitState::Open => {
213                return Err(ResilienceError::CircuitOpen {
214                    kernel_id: self.kernel_id.clone(),
215                });
216            }
217            CircuitState::HalfOpen => {
218                // Limit concurrent requests in half-open state
219                let current = self
220                    .inner
221                    .half_open_requests
222                    .fetch_add(1, Ordering::Relaxed);
223                if current >= self.config.half_open_max_requests {
224                    self.inner
225                        .half_open_requests
226                        .fetch_sub(1, Ordering::Relaxed);
227                    return Err(ResilienceError::CircuitOpen {
228                        kernel_id: self.kernel_id.clone(),
229                    });
230                }
231            }
232            CircuitState::Closed => {}
233        }
234
235        // Execute the function
236        let result = f().await;
237
238        // Record the result
239        match &result {
240            Ok(_) => self.record_success().await,
241            Err(_) => self.record_failure().await,
242        }
243
244        // If we were in half-open, decrement the counter
245        if state == CircuitState::HalfOpen {
246            self.inner
247                .half_open_requests
248                .fetch_sub(1, Ordering::Relaxed);
249        }
250
251        result.map_err(|e| ResilienceError::KernelError(e.into()))
252    }
253
254    /// Manually record a success
255    pub async fn record_success(&self) {
256        let state = *self.inner.state.read().await;
257
258        match state {
259            CircuitState::Closed => {
260                // Reset failure count on success
261                self.inner.failure_count.store(0, Ordering::Relaxed);
262            }
263            CircuitState::HalfOpen => {
264                let successes = self.inner.success_count.fetch_add(1, Ordering::Relaxed) + 1;
265                if successes >= self.config.success_threshold {
266                    self.transition_to_closed().await;
267                }
268            }
269            CircuitState::Open => {}
270        }
271    }
272
273    /// Manually record a failure
274    pub async fn record_failure(&self) {
275        self.inner.total_failures.fetch_add(1, Ordering::Relaxed);
276        *self.inner.last_failure_time.write().await = Some(Instant::now());
277
278        let state = *self.inner.state.read().await;
279
280        match state {
281            CircuitState::Closed => {
282                let failures = self.inner.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
283                if failures >= self.config.failure_threshold {
284                    self.transition_to_open().await;
285                }
286            }
287            CircuitState::HalfOpen => {
288                // Any failure in half-open goes back to open
289                self.transition_to_open().await;
290            }
291            CircuitState::Open => {}
292        }
293    }
294
295    /// Manually reset the circuit breaker
296    pub async fn reset(&self) {
297        *self.inner.state.write().await = CircuitState::Closed;
298        self.inner.failure_count.store(0, Ordering::Relaxed);
299        self.inner.success_count.store(0, Ordering::Relaxed);
300        self.inner.half_open_requests.store(0, Ordering::Relaxed);
301        *self.inner.last_failure_time.write().await = None;
302    }
303
304    /// Get statistics
305    pub fn stats(&self) -> CircuitBreakerStats {
306        CircuitBreakerStats {
307            total_requests: self.inner.total_requests.load(Ordering::Relaxed),
308            total_failures: self.inner.total_failures.load(Ordering::Relaxed),
309            current_failures: self.inner.failure_count.load(Ordering::Relaxed),
310        }
311    }
312
313    // Private transition methods
314
315    async fn transition_to_open(&self) {
316        *self.inner.state.write().await = CircuitState::Open;
317        self.inner.success_count.store(0, Ordering::Relaxed);
318        tracing::warn!(
319            kernel_id = %self.kernel_id,
320            "Circuit breaker opened"
321        );
322    }
323
324    async fn transition_to_closed(&self) {
325        *self.inner.state.write().await = CircuitState::Closed;
326        self.inner.failure_count.store(0, Ordering::Relaxed);
327        self.inner.success_count.store(0, Ordering::Relaxed);
328        tracing::info!(
329            kernel_id = %self.kernel_id,
330            "Circuit breaker closed"
331        );
332    }
333
334    async fn try_transition_to_half_open(&self) -> CircuitState {
335        let mut state = self.inner.state.write().await;
336        if *state == CircuitState::Open {
337            *state = CircuitState::HalfOpen;
338            self.inner.success_count.store(0, Ordering::Relaxed);
339            self.inner.half_open_requests.store(0, Ordering::Relaxed);
340            tracing::info!(
341                kernel_id = %self.kernel_id,
342                "Circuit breaker half-open"
343            );
344        }
345        *state
346    }
347}
348
349impl Clone for CircuitBreaker {
350    fn clone(&self) -> Self {
351        Self {
352            kernel_id: self.kernel_id.clone(),
353            config: self.config.clone(),
354            inner: self.inner.clone(),
355        }
356    }
357}
358
/// Circuit breaker statistics
///
/// A point-in-time snapshot as produced by `CircuitBreaker::stats`.
#[derive(Debug, Clone)]
pub struct CircuitBreakerStats {
    /// Total requests through this breaker
    pub total_requests: u64,
    /// Total failures recorded
    pub total_failures: u64,
    /// Current failure count in window
    ///
    /// NOTE(review): `CircuitBreaker::stats` fills this from the breaker's failure
    /// counter, which in this file is cleared by successes and state changes rather than
    /// by elapsed time - confirm whether "in window" is still the intended meaning.
    pub current_failures: u32,
}
369
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a breaker named "test" whose circuit opens after `threshold` failures.
    fn breaker_with_threshold(threshold: u32) -> CircuitBreaker {
        let config = CircuitBreakerConfig::default().failure_threshold(threshold);
        CircuitBreaker::new("test", config)
    }

    /// Records `count` failures on `breaker`, one after another.
    async fn fail_times(breaker: &CircuitBreaker, count: usize) {
        for _ in 0..count {
            breaker.record_failure().await;
        }
    }

    /// A fresh breaker is closed and lets requests through.
    #[tokio::test]
    async fn test_circuit_breaker_starts_closed() {
        let breaker = CircuitBreaker::new("test", CircuitBreakerConfig::default());

        assert_eq!(breaker.state().await, CircuitState::Closed);
        assert!(breaker.is_allowed().await);
    }

    /// Reaching the failure threshold opens the circuit and blocks requests.
    #[tokio::test]
    async fn test_circuit_opens_after_failures() {
        let breaker = breaker_with_threshold(3);

        fail_times(&breaker, 3).await;

        assert_eq!(breaker.state().await, CircuitState::Open);
        assert!(!breaker.is_allowed().await);
    }

    /// A success below the threshold clears the accumulated failure count.
    #[tokio::test]
    async fn test_circuit_resets_on_success() {
        let breaker = breaker_with_threshold(3);

        fail_times(&breaker, 2).await;
        breaker.record_success().await;

        let remaining = breaker.inner.failure_count.load(Ordering::Relaxed);
        assert_eq!(remaining, 0);
    }

    /// `reset` closes an open circuit.
    #[tokio::test]
    async fn test_manual_reset() {
        let breaker = breaker_with_threshold(3);

        fail_times(&breaker, 3).await;
        assert_eq!(breaker.state().await, CircuitState::Open);

        breaker.reset().await;
        assert_eq!(breaker.state().await, CircuitState::Closed);
    }

    /// Chained setters store the values they were given.
    #[test]
    fn test_config_builder() {
        let one_minute = Duration::from_secs(60);
        let config = CircuitBreakerConfig::default()
            .failure_threshold(10)
            .reset_timeout(one_minute)
            .success_threshold(5);

        assert_eq!(config.failure_threshold, 10);
        assert_eq!(config.reset_timeout, one_minute);
        assert_eq!(config.success_threshold, 5);
    }
}