warpdrive_proxy/middleware/
circuit_breaker.rs

1//! Circuit breaker middleware
2//!
3//! Implements the circuit breaker pattern to prevent cascading failures.
4//! States: Closed → Open → Half-Open → Closed
5
6use async_trait::async_trait;
7use parking_lot::Mutex;
8use pingora::http::ResponseHeader;
9use pingora::prelude::*;
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tracing::{debug, warn};
13
14use super::{Middleware, MiddlewareContext};
15use crate::metrics::{CIRCUIT_BREAKER_STATE, UPSTREAM_FAILURES};
16
/// Circuit breaker state
///
/// A fieldless, `Copy` enum compared by value when the breaker decides on a
/// transition, so it derives full `Eq` in addition to `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CircuitState {
    /// Requests are allowed; failures are counted against the threshold.
    Closed,
    /// Requests are rejected until the configured timeout has elapsed.
    Open,
    /// Requests are allowed again after the timeout; a success closes the circuit.
    HalfOpen,
}
24
25/// Internal circuit breaker state
26#[derive(Debug)]
27struct CircuitBreakerState {
28    state: CircuitState,
29    failure_count: u32,
30    last_failure_time: Option<Instant>,
31    failure_threshold: u32,
32    timeout: Duration,
33}
34
35impl CircuitBreakerState {
36    fn new(failure_threshold: u32, timeout_secs: u64) -> Self {
37        Self {
38            state: CircuitState::Closed,
39            failure_count: 0,
40            last_failure_time: None,
41            failure_threshold,
42            timeout: Duration::from_secs(timeout_secs),
43        }
44    }
45
46    fn record_success(&mut self) {
47        if self.state == CircuitState::HalfOpen {
48            // Success in half-open state -> close circuit
49            debug!("Circuit breaker: Half-Open → Closed (success)");
50            self.state = CircuitState::Closed;
51            CIRCUIT_BREAKER_STATE.set(0); // 0 = Closed
52        }
53        self.failure_count = 0;
54        self.last_failure_time = None;
55    }
56
57    fn record_failure(&mut self) {
58        self.failure_count += 1;
59        self.last_failure_time = Some(Instant::now());
60
61        if self.failure_count >= self.failure_threshold && self.state == CircuitState::Closed {
62            warn!(
63                "Circuit breaker: Closed → Open (failures: {})",
64                self.failure_count
65            );
66            self.state = CircuitState::Open;
67            CIRCUIT_BREAKER_STATE.set(1); // 1 = Open
68        }
69    }
70
71    fn is_request_allowed(&mut self) -> bool {
72        match self.state {
73            CircuitState::Closed => true,
74            CircuitState::Open => {
75                // Check if timeout has elapsed
76                if let Some(last_failure) = self.last_failure_time {
77                    if last_failure.elapsed() >= self.timeout {
78                        debug!("Circuit breaker: Open → Half-Open (timeout elapsed)");
79                        self.state = CircuitState::HalfOpen;
80                        self.failure_count = 0;
81                        CIRCUIT_BREAKER_STATE.set(2); // 2 = Half-Open
82                        return true;
83                    }
84                }
85                false
86            }
87            CircuitState::HalfOpen => true,
88        }
89    }
90}
91
92/// Circuit breaker middleware
93///
94/// Tracks upstream failures and opens the circuit when threshold is exceeded.
95/// While open, requests are rejected with 503 to prevent overwhelming the failing upstream.
96pub struct CircuitBreakerMiddleware {
97    /// Internal circuit breaker state
98    state: Arc<Mutex<CircuitBreakerState>>,
99    /// Whether circuit breaker is enabled
100    enabled: bool,
101}
102
103impl CircuitBreakerMiddleware {
104    /// Create a new circuit breaker middleware
105    ///
106    /// # Arguments
107    ///
108    /// * `enabled` - Whether to enable circuit breaking
109    /// * `failure_threshold` - Number of consecutive failures before opening
110    /// * `timeout_secs` - Seconds to wait before trying half-open state
111    pub fn new(enabled: bool, failure_threshold: u32, timeout_secs: u64) -> Self {
112        debug!(
113            "Circuit breaker initialized: enabled={}, threshold={}, timeout={}s",
114            enabled, failure_threshold, timeout_secs
115        );
116
117        // Initialize circuit breaker state metric to Closed (0)
118        if enabled {
119            CIRCUIT_BREAKER_STATE.set(0);
120        }
121
122        Self {
123            state: Arc::new(Mutex::new(CircuitBreakerState::new(
124                failure_threshold,
125                timeout_secs,
126            ))),
127            enabled,
128        }
129    }
130}
131
132#[async_trait]
133impl Middleware for CircuitBreakerMiddleware {
134    /// Check circuit state before forwarding request
135    async fn request_filter(
136        &self,
137        session: &mut Session,
138        _ctx: &mut MiddlewareContext,
139    ) -> Result<()> {
140        if !self.enabled {
141            return Ok(());
142        }
143
144        // Check if request is allowed (lock scope)
145        let allowed = {
146            let mut state = self.state.lock();
147            state.is_request_allowed()
148        }; // Lock released here
149
150        if !allowed {
151            warn!("Circuit breaker: request rejected (circuit OPEN)");
152
153            // Return 503 Service Unavailable
154            session.respond_error(503).await?;
155
156            return Err(Error::explain(
157                ErrorType::HTTPStatus(503),
158                "Circuit breaker open - upstream unavailable",
159            ));
160        }
161
162        debug!("Circuit breaker: request allowed");
163        Ok(())
164    }
165
166    /// Record success/failure based on upstream response
167    async fn response_filter(
168        &self,
169        _session: &mut Session,
170        upstream_response: &mut ResponseHeader,
171        _ctx: &mut MiddlewareContext,
172    ) -> Result<()> {
173        if !self.enabled {
174            return Ok(());
175        }
176
177        let status = upstream_response.status.as_u16();
178        let mut state = self.state.lock();
179
180        // Consider 5xx responses as failures
181        if status >= 500 {
182            warn!("Circuit breaker: upstream failure (status {})", status);
183
184            // Record upstream failure metric
185            UPSTREAM_FAILURES
186                .with_label_values(&[&status.to_string()])
187                .inc();
188
189            state.record_failure();
190        } else {
191            debug!("Circuit breaker: upstream success (status {})", status);
192            state.record_success();
193        }
194
195        Ok(())
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// An enabled middleware reports itself as enabled.
    #[test]
    fn test_circuit_breaker_creation() {
        let cb = CircuitBreakerMiddleware::new(true, 5, 60);
        assert!(cb.enabled);
    }

    /// A disabled middleware reports itself as disabled.
    #[test]
    fn test_circuit_breaker_disabled() {
        let cb = CircuitBreakerMiddleware::new(false, 5, 60);
        assert!(!cb.enabled);
    }

    /// Reaching the failure threshold opens the circuit and blocks requests.
    #[test]
    fn test_circuit_breaker_state_transitions() {
        let cb = CircuitBreakerMiddleware::new(true, 3, 1);
        let mut state = cb.state.lock();

        // Should start closed
        assert!(state.is_request_allowed());

        // Record failures
        for _ in 0..3 {
            state.record_failure();
        }

        // Circuit should be open after threshold
        assert_eq!(state.state, CircuitState::Open);
        assert!(!state.is_request_allowed());
    }

    /// A success while closed resets the counter, so only consecutive
    /// failures can reach the threshold.
    #[test]
    fn test_success_resets_failure_count() {
        let cb = CircuitBreakerMiddleware::new(true, 3, 60);
        let mut state = cb.state.lock();

        state.record_failure();
        state.record_failure();
        state.record_success();
        assert_eq!(state.failure_count, 0);

        state.record_failure();
        state.record_failure();
        assert_eq!(state.state, CircuitState::Closed);
        assert!(state.is_request_allowed());
    }

    /// With a zero timeout an open circuit goes half-open on the next check,
    /// and a success in half-open closes it again.
    #[test]
    fn test_open_to_half_open_to_closed() {
        let cb = CircuitBreakerMiddleware::new(true, 2, 0);
        let mut state = cb.state.lock();

        state.record_failure();
        state.record_failure();
        assert_eq!(state.state, CircuitState::Open);

        // Timeout of zero seconds has always elapsed.
        assert!(state.is_request_allowed());
        assert_eq!(state.state, CircuitState::HalfOpen);

        state.record_success();
        assert_eq!(state.state, CircuitState::Closed);
        assert!(state.is_request_allowed());
    }
}