Skip to main content

rust_web_server/circuit_breaker/
mod.rs

1//! Circuit breaker state machine and retry middleware.
2//!
3//! # Circuit breaker
4//!
5//! [`CircuitBreaker`] tracks per-backend failure counts and transitions through
6//! three states:
7//!
8//! * **Closed** — the backend is healthy; failures are counted.  When the count
9//!   reaches `failure_threshold` the breaker moves to **Open**.
10//! * **Open** — the backend is considered unhealthy; all requests are rejected
11//!   immediately (no TCP connection is attempted).  After `recovery` seconds the
12//!   breaker moves to **HalfOpen**.
13//! * **HalfOpen** — one probe request is let through.  On success the breaker
14//!   closes; on failure it re-opens and the recovery timer resets.
15//!
16//! # Retry middleware
17//!
18//! [`RetryLayer`] wraps any [`Application`] and re-dispatches the request when
19//! the inner app returns one of the configured status codes (default: 502, 503,
20//! 504) up to `max_retries` additional times.
21//!
22//! # Example
23//!
24//! ```rust,no_run
25//! use rust_web_server::app::App;
26//! use rust_web_server::core::New;
27//! use rust_web_server::circuit_breaker::RetryLayer;
28//! use rust_web_server::middleware::WithMiddleware;
29//!
30//! let app = WithMiddleware::new(App::new())
31//!     .wrap(RetryLayer::new().max_retries(2));
32//! ```
33
34#[cfg(test)]
35mod tests;
36
37use std::collections::HashMap;
38use std::sync::{Mutex, OnceLock};
39use std::time::{Duration, Instant};
40
41use crate::application::Application;
42use crate::middleware::Middleware;
43use crate::request::Request;
44use crate::response::Response;
45use crate::server::ConnectionInfo;
46
47// ── BreakerState ─────────────────────────────────────────────────────────────
48
/// Current state of a single backend's circuit breaker.
///
/// The enum carries no data, so in addition to `PartialEq` it derives `Eq`
/// and `Hash`; this lets states be used as map/set keys and in `T: Eq` bounds.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BreakerState {
    /// Healthy — requests are forwarded and failures are counted.
    Closed,
    /// Unhealthy — requests are rejected until the recovery window expires.
    Open,
    /// Probing — requests are let through again to test backend health; the
    /// next recorded result decides whether the breaker closes or re-opens.
    HalfOpen,
}
59
60// ── BackendEntry ──────────────────────────────────────────────────────────────
61
62struct BackendEntry {
63    state: BreakerState,
64    failures: u32,
65    opened_at: Option<Instant>,
66}
67
68impl BackendEntry {
69    fn new() -> Self {
70        Self { state: BreakerState::Closed, failures: 0, opened_at: None }
71    }
72}
73
74// ── CircuitBreaker ────────────────────────────────────────────────────────────
75
76/// Per-backend circuit breaker.
77///
78/// # Concurrency
79///
80/// `CircuitBreaker` is not `Sync` on its own — wrap it in a [`Mutex`] for
81/// shared use across threads (see [`global()`]).
82pub struct CircuitBreaker {
83    backends: HashMap<String, BackendEntry>,
84    failure_threshold: u32,
85    recovery: Duration,
86}
87
88impl CircuitBreaker {
89    /// Create a new circuit breaker.
90    ///
91    /// * `failure_threshold` — how many consecutive failures open the circuit.
92    /// * `recovery_secs` — how long the circuit stays Open before testing again.
93    pub fn new(failure_threshold: u32, recovery_secs: u64) -> Self {
94        Self {
95            backends: HashMap::new(),
96            failure_threshold,
97            recovery: Duration::from_secs(recovery_secs),
98        }
99    }
100
101    /// Returns `true` if a request should be forwarded to `backend`.
102    ///
103    /// Transitions `Open → HalfOpen` when the recovery window has elapsed.
104    pub fn is_available(&mut self, backend: &str) -> bool {
105        let entry = self.backends.entry(backend.to_string()).or_insert_with(BackendEntry::new);
106        match entry.state {
107            BreakerState::Closed => true,
108            BreakerState::HalfOpen => true,
109            BreakerState::Open => {
110                if let Some(opened_at) = entry.opened_at {
111                    if opened_at.elapsed() >= self.recovery {
112                        entry.state = BreakerState::HalfOpen;
113                        entry.opened_at = None;
114                        return true;
115                    }
116                }
117                false
118            }
119        }
120    }
121
122    /// Record a successful response for `backend`.
123    ///
124    /// Transitions `HalfOpen → Closed` and resets the failure counter.
125    pub fn record_success(&mut self, backend: &str) {
126        let entry = self.backends.entry(backend.to_string()).or_insert_with(BackendEntry::new);
127        entry.state = BreakerState::Closed;
128        entry.failures = 0;
129        entry.opened_at = None;
130    }
131
132    /// Record a failed response for `backend`.
133    ///
134    /// In `Closed` state, increments the counter and opens the circuit when
135    /// `failure_threshold` is reached.  In `HalfOpen` state, immediately
136    /// re-opens the circuit and resets the recovery timer.
137    pub fn record_failure(&mut self, backend: &str) {
138        let threshold = self.failure_threshold;
139        let entry = self.backends.entry(backend.to_string()).or_insert_with(BackendEntry::new);
140        match entry.state {
141            BreakerState::Closed => {
142                entry.failures += 1;
143                if entry.failures >= threshold {
144                    entry.state = BreakerState::Open;
145                    entry.opened_at = Some(Instant::now());
146                }
147            }
148            BreakerState::HalfOpen => {
149                entry.state = BreakerState::Open;
150                entry.opened_at = Some(Instant::now());
151            }
152            BreakerState::Open => {
153                // Already open; refresh the timer.
154                entry.opened_at = Some(Instant::now());
155            }
156        }
157    }
158
159    /// Reset `backend` to `Closed` with zero failures.
160    pub fn reset(&mut self, backend: &str) {
161        let entry = self.backends.entry(backend.to_string()).or_insert_with(BackendEntry::new);
162        entry.state = BreakerState::Closed;
163        entry.failures = 0;
164        entry.opened_at = None;
165    }
166
167    /// Return the current state for `backend` (defaults to `Closed` if unseen).
168    pub fn state(&self, backend: &str) -> BreakerState {
169        self.backends
170            .get(backend)
171            .map(|e| e.state.clone())
172            .unwrap_or(BreakerState::Closed)
173    }
174}
175
176// ── global() ─────────────────────────────────────────────────────────────────
177
178static GLOBAL_BREAKER: OnceLock<Mutex<CircuitBreaker>> = OnceLock::new();
179
180/// Return the process-wide default circuit breaker (threshold=5, recovery=30 s).
181///
182/// Acquire the mutex before calling any `CircuitBreaker` method:
183///
184/// ```rust
185/// use rust_web_server::circuit_breaker;
186///
187/// let available = circuit_breaker::global().lock().unwrap().is_available("backend-a:8080");
188/// ```
189pub fn global() -> &'static Mutex<CircuitBreaker> {
190    GLOBAL_BREAKER.get_or_init(|| Mutex::new(CircuitBreaker::new(5, 30)))
191}
192
193// ── RetryLayer ────────────────────────────────────────────────────────────────
194
/// Retry middleware.
///
/// If the wrapped application answers with a status code from the configured
/// list, the same request is dispatched again, at most `max_retries` extra
/// times.  When every attempt yields a retryable status, the response of the
/// final attempt is returned unchanged.
pub struct RetryLayer {
    /// Upper bound on re-dispatches after the initial attempt.
    max_retries: u32,
    /// Status codes that cause the request to be dispatched again.
    retry_on: Vec<i16>,
}

impl RetryLayer {
    /// Create a `RetryLayer` with defaults: retry on 502, 503, 504 up to 3 times.
    pub fn new() -> Self {
        let retry_on = [502, 503, 504].to_vec();
        Self { max_retries: 3, retry_on }
    }

    /// Override the maximum number of retry attempts.
    pub fn max_retries(self, n: u32) -> Self {
        Self { max_retries: n, ..self }
    }

    /// Override the set of status codes that trigger a retry.
    pub fn retry_on(self, codes: Vec<i16>) -> Self {
        Self { retry_on: codes, ..self }
    }
}

impl Default for RetryLayer {
    /// Same configuration as [`RetryLayer::new`].
    fn default() -> Self {
        RetryLayer::new()
    }
}
230
231impl Middleware for RetryLayer {
232    fn handle(
233        &self,
234        request: &Request,
235        connection: &ConnectionInfo,
236        next: &dyn Application,
237    ) -> Result<Response, String> {
238        let mut response = next.execute(request, connection)?;
239        let mut attempts = 0u32;
240        while attempts < self.max_retries && self.retry_on.contains(&response.status_code) {
241            response = next.execute(request, connection)?;
242            attempts += 1;
243        }
244        Ok(response)
245    }
246}