vex_api/
circuit_breaker.rs

1//! Circuit breaker for resilient service calls
2
3use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
4use std::time::{Duration, Instant};
5use tokio::sync::RwLock;
6
/// Circuit breaker state.
///
/// Transitions: `Closed` → `Open` (failure threshold hit), `Open` →
/// `HalfOpen` (reset timeout elapsed), `HalfOpen` → `Closed` (success
/// threshold hit) or `HalfOpen` → `Open` (probe failures exceeded).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation - requests pass through
    Closed,
    /// Circuit tripped - requests fail immediately
    Open,
    /// Testing recovery - limited requests allowed
    HalfOpen,
}
17
/// Circuit breaker configuration.
///
/// Presets: [`CircuitConfig::conservative`] for critical services,
/// [`CircuitConfig::aggressive`] for non-critical ones, and `Default`
/// for a balanced middle ground.
#[derive(Debug, Clone)]
pub struct CircuitConfig {
    /// Failure threshold to trip circuit (consecutive failures in `Closed`)
    pub failure_threshold: u32,
    /// Success threshold in half-open to close
    pub success_threshold: u32,
    /// Failure threshold in half-open before re-opening (allows a few test failures)
    pub half_open_failure_threshold: u32,
    /// Time to wait in open state before testing
    pub reset_timeout: Duration,
    /// Rolling window for failures; failures older than this do not count
    /// toward `failure_threshold`
    pub window_duration: Duration,
}
32
33impl Default for CircuitConfig {
34    fn default() -> Self {
35        Self {
36            failure_threshold: 5,
37            success_threshold: 3,
38            half_open_failure_threshold: 2, // Allow 1 test failure before re-opening
39            reset_timeout: Duration::from_secs(30),
40            window_duration: Duration::from_secs(60),
41        }
42    }
43}
44
45impl CircuitConfig {
46    /// Conservative settings for critical services
47    pub fn conservative() -> Self {
48        Self {
49            failure_threshold: 3,
50            success_threshold: 5,
51            half_open_failure_threshold: 1, // Strict: re-open on first failure
52            reset_timeout: Duration::from_secs(60),
53            window_duration: Duration::from_secs(120),
54        }
55    }
56
57    /// Aggressive settings for non-critical services
58    pub fn aggressive() -> Self {
59        Self {
60            failure_threshold: 10,
61            success_threshold: 2,
62            half_open_failure_threshold: 3, // Allow several test failures
63            reset_timeout: Duration::from_secs(10),
64            window_duration: Duration::from_secs(30),
65        }
66    }
67}
68
/// Thread-safe circuit breaker.
///
/// State transitions are serialized through the `state` write lock; the
/// atomic counters are updated with `Relaxed` ordering and are advisory
/// (statistics / thresholds), not synchronization points.
#[derive(Debug)]
pub struct CircuitBreaker {
    // Identifier used in log/tracing output.
    name: String,
    config: CircuitConfig,
    // Current state; taken as a write lock on every `allow`/record call.
    state: RwLock<CircuitState>,
    // Consecutive failures while Closed (reset by any success).
    failure_count: AtomicU32,
    // Consecutive successes while HalfOpen.
    success_count: AtomicU32,
    // Failures observed while HalfOpen, compared to the half-open threshold.
    half_open_failure_count: AtomicU32,
    // Instant of the most recent recorded failure, if any.
    last_failure_time: RwLock<Option<Instant>>,
    // Instant of the most recent state transition (drives reset_timeout).
    last_state_change: RwLock<Instant>,
    // Lifetime counters for stats().
    total_requests: AtomicU64,
    total_failures: AtomicU64,
    total_rejections: AtomicU64,
}
84
85impl CircuitBreaker {
86    /// Create a new circuit breaker
87    pub fn new(name: &str, config: CircuitConfig) -> Self {
88        Self {
89            name: name.to_string(),
90            config,
91            state: RwLock::new(CircuitState::Closed),
92            failure_count: AtomicU32::new(0),
93            success_count: AtomicU32::new(0),
94            half_open_failure_count: AtomicU32::new(0),
95            last_failure_time: RwLock::new(None),
96            last_state_change: RwLock::new(Instant::now()),
97            total_requests: AtomicU64::new(0),
98            total_failures: AtomicU64::new(0),
99            total_rejections: AtomicU64::new(0),
100        }
101    }
102
103    /// Check if request is allowed
104    pub async fn allow(&self) -> bool {
105        self.total_requests.fetch_add(1, Ordering::Relaxed);
106
107        let mut state = self.state.write().await;
108
109        match *state {
110            CircuitState::Closed => true,
111            CircuitState::Open => {
112                // Check if reset timeout has passed
113                let last_change = *self.last_state_change.read().await;
114                if last_change.elapsed() >= self.config.reset_timeout {
115                    *state = CircuitState::HalfOpen;
116                    *self.last_state_change.write().await = Instant::now();
117                    self.success_count.store(0, Ordering::Relaxed);
118                    self.half_open_failure_count.store(0, Ordering::Relaxed);
119                    tracing::info!(
120                        circuit = %self.name,
121                        "Circuit transitioned to HalfOpen"
122                    );
123                    true
124                } else {
125                    self.total_rejections.fetch_add(1, Ordering::Relaxed);
126                    false
127                }
128            }
129            CircuitState::HalfOpen => {
130                // Allow limited requests for testing
131                true
132            }
133        }
134    }
135
136    /// Record a successful call
137    pub async fn record_success(&self) {
138        let mut state = self.state.write().await;
139
140        match *state {
141            CircuitState::HalfOpen => {
142                let count = self.success_count.fetch_add(1, Ordering::Relaxed) + 1;
143                if count >= self.config.success_threshold {
144                    *state = CircuitState::Closed;
145                    self.failure_count.store(0, Ordering::Relaxed);
146                    self.success_count.store(0, Ordering::Relaxed);
147                    *self.last_state_change.write().await = Instant::now();
148                    tracing::info!(
149                        circuit = %self.name,
150                        "Circuit recovered - now Closed"
151                    );
152                }
153            }
154            CircuitState::Closed => {
155                // Reset failure count on success in closed state
156                self.failure_count.store(0, Ordering::Relaxed);
157            }
158            _ => {}
159        }
160    }
161
162    /// Record a failed call
163    pub async fn record_failure(&self) {
164        self.total_failures.fetch_add(1, Ordering::Relaxed);
165        *self.last_failure_time.write().await = Some(Instant::now());
166
167        let mut state = self.state.write().await;
168
169        match *state {
170            CircuitState::Closed => {
171                let count = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
172                if count >= self.config.failure_threshold {
173                    *state = CircuitState::Open;
174                    *self.last_state_change.write().await = Instant::now();
175                    tracing::warn!(
176                        circuit = %self.name,
177                        failures = count,
178                        "Circuit tripped - now Open"
179                    );
180                }
181            }
182            CircuitState::HalfOpen => {
183                // Track failures during recovery testing
184                let half_open_failures =
185                    self.half_open_failure_count.fetch_add(1, Ordering::Relaxed) + 1;
186
187                if half_open_failures >= self.config.half_open_failure_threshold {
188                    // Too many failures during recovery, re-open the circuit
189                    *state = CircuitState::Open;
190                    self.success_count.store(0, Ordering::Relaxed);
191                    self.half_open_failure_count.store(0, Ordering::Relaxed);
192                    *self.last_state_change.write().await = Instant::now();
193                    tracing::warn!(
194                        circuit = %self.name,
195                        half_open_failures = half_open_failures,
196                        "Circuit tripped from HalfOpen - back to Open"
197                    );
198                } else {
199                    tracing::debug!(
200                        circuit = %self.name,
201                        half_open_failures = half_open_failures,
202                        threshold = self.config.half_open_failure_threshold,
203                        "HalfOpen failure recorded, still testing"
204                    );
205                }
206            }
207            _ => {}
208        }
209    }
210
211    /// Get current state
212    pub async fn state(&self) -> CircuitState {
213        *self.state.read().await
214    }
215
216    /// Get statistics
217    pub fn stats(&self) -> CircuitStats {
218        CircuitStats {
219            name: self.name.clone(),
220            total_requests: self.total_requests.load(Ordering::Relaxed),
221            total_failures: self.total_failures.load(Ordering::Relaxed),
222            total_rejections: self.total_rejections.load(Ordering::Relaxed),
223            current_failures: self.failure_count.load(Ordering::Relaxed),
224            current_successes: self.success_count.load(Ordering::Relaxed),
225        }
226    }
227
228    /// Execute with circuit breaker protection
229    pub async fn call<F, T, E>(&self, f: F) -> Result<T, CircuitError<E>>
230    where
231        F: std::future::Future<Output = Result<T, E>>,
232    {
233        if !self.allow().await {
234            return Err(CircuitError::Open);
235        }
236
237        match f.await {
238            Ok(result) => {
239                self.record_success().await;
240                Ok(result)
241            }
242            Err(e) => {
243                self.record_failure().await;
244                Err(CircuitError::Failed(e))
245            }
246        }
247    }
248}
249
/// Circuit breaker statistics — a point-in-time snapshot from
/// [`CircuitBreaker::stats`]; counters are read with relaxed ordering and
/// may be slightly stale relative to each other.
#[derive(Debug, Clone)]
pub struct CircuitStats {
    // Breaker name this snapshot belongs to.
    pub name: String,
    // Lifetime totals.
    pub total_requests: u64,
    pub total_failures: u64,
    pub total_rejections: u64,
    // In-flight streak counters (reset on state transitions).
    pub current_failures: u32,
    pub current_successes: u32,
}
260
/// Circuit breaker error returned by [`CircuitBreaker::call`].
#[derive(Debug, thiserror::Error)]
pub enum CircuitError<E> {
    /// The breaker rejected the request without invoking the call.
    #[error("Circuit is open - service unavailable")]
    Open,
    /// The call was invoked and returned an error (wrapped as the source).
    #[error("Call failed: {0}")]
    Failed(#[source] E),
}
269
/// Retry with exponential backoff and ±10% jitter.
pub struct RetryPolicy {
    // Maximum TOTAL attempts, initial call included — not retries after the
    // first failure (see `execute`: it returns Err once `attempts` reaches
    // this value).
    pub max_retries: u32,
    // Delay before the first retry.
    pub initial_delay: Duration,
    // Upper bound on the backoff delay.
    pub max_delay: Duration,
    // Growth factor applied to the delay after each failed attempt.
    pub multiplier: f64,
}
277
278impl Default for RetryPolicy {
279    fn default() -> Self {
280        Self {
281            max_retries: 3,
282            initial_delay: Duration::from_millis(100),
283            max_delay: Duration::from_secs(10),
284            multiplier: 2.0,
285        }
286    }
287}
288
289impl RetryPolicy {
290    /// Execute with retry and exponential backoff
291    pub async fn execute<F, Fut, T, E>(&self, mut f: F) -> Result<T, E>
292    where
293        F: FnMut() -> Fut,
294        Fut: std::future::Future<Output = Result<T, E>>,
295        E: std::fmt::Debug,
296    {
297        let mut delay = self.initial_delay;
298        let mut attempts = 0;
299
300        loop {
301            match f().await {
302                Ok(result) => return Ok(result),
303                Err(e) => {
304                    attempts += 1;
305                    if attempts >= self.max_retries {
306                        tracing::error!(
307                            attempts = attempts,
308                            error = ?e,
309                            "Retry exhausted"
310                        );
311                        return Err(e);
312                    }
313
314                    tracing::warn!(
315                        attempt = attempts,
316                        delay_ms = delay.as_millis(),
317                        error = ?e,
318                        "Retrying after failure"
319                    );
320
321                    // Add jitter (±10%)
322                    let jitter = delay.as_millis() as f64 * 0.1;
323                    let jittered =
324                        delay.as_millis() as f64 + (rand::random::<f64>() * 2.0 - 1.0) * jitter;
325
326                    tokio::time::sleep(Duration::from_millis(jittered as u64)).await;
327
328                    // Exponential backoff
329                    delay =
330                        Duration::from_millis((delay.as_millis() as f64 * self.multiplier) as u64)
331                            .min(self.max_delay);
332                }
333            }
334        }
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_circuit_breaker_trips() {
        let cb = CircuitBreaker::new(
            "test",
            CircuitConfig {
                failure_threshold: 2,
                success_threshold: 1,
                half_open_failure_threshold: 1,
                reset_timeout: Duration::from_millis(100),
                window_duration: Duration::from_secs(60),
            },
        );

        // Starts closed and admits traffic.
        assert_eq!(cb.state().await, CircuitState::Closed);
        assert!(cb.allow().await);

        // One failure stays below the threshold; the second trips it.
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Closed);
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Open);

        // An open circuit rejects immediately.
        assert!(!cb.allow().await);

        // After the reset timeout, the next request is admitted as a probe.
        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(cb.allow().await);
        assert_eq!(cb.state().await, CircuitState::HalfOpen);

        // success_threshold is 1, so a single success closes it again.
        cb.record_success().await;
        assert_eq!(cb.state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_retry_policy() {
        let policy = RetryPolicy {
            max_retries: 3,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(100),
            multiplier: 2.0,
        };

        // Fails twice, then succeeds on the third attempt.
        let mut calls = 0;
        let outcome: Result<i32, &str> = policy
            .execute(|| {
                calls += 1;
                let attempt = calls;
                async move {
                    if attempt < 3 {
                        Err("failed")
                    } else {
                        Ok(42)
                    }
                }
            })
            .await;

        assert_eq!(outcome, Ok(42));
        assert_eq!(calls, 3);
    }
}