warpdrive_proxy/middleware/
circuit_breaker.rs1use async_trait::async_trait;
7use parking_lot::Mutex;
8use pingora::http::ResponseHeader;
9use pingora::prelude::*;
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tracing::{debug, warn};
13
14use super::{Middleware, MiddlewareContext};
15use crate::metrics::{CIRCUIT_BREAKER_STATE, UPSTREAM_FAILURES};
16
/// The three positions of the circuit breaker state machine.
///
/// Elsewhere in this file each transition is mirrored into the
/// `CIRCUIT_BREAKER_STATE` metric as `0` (closed), `1` (open) and
/// `2` (half-open).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CircuitState {
    /// Normal operation: every request is let through.
    Closed,
    /// Tripped: requests are rejected until the timeout has elapsed.
    Open,
    /// Recovery probing: requests are let through again after the timeout.
    HalfOpen,
}
24
25#[derive(Debug)]
27struct CircuitBreakerState {
28 state: CircuitState,
29 failure_count: u32,
30 last_failure_time: Option<Instant>,
31 failure_threshold: u32,
32 timeout: Duration,
33}
34
35impl CircuitBreakerState {
36 fn new(failure_threshold: u32, timeout_secs: u64) -> Self {
37 Self {
38 state: CircuitState::Closed,
39 failure_count: 0,
40 last_failure_time: None,
41 failure_threshold,
42 timeout: Duration::from_secs(timeout_secs),
43 }
44 }
45
46 fn record_success(&mut self) {
47 if self.state == CircuitState::HalfOpen {
48 debug!("Circuit breaker: Half-Open → Closed (success)");
50 self.state = CircuitState::Closed;
51 CIRCUIT_BREAKER_STATE.set(0); }
53 self.failure_count = 0;
54 self.last_failure_time = None;
55 }
56
57 fn record_failure(&mut self) {
58 self.failure_count += 1;
59 self.last_failure_time = Some(Instant::now());
60
61 if self.failure_count >= self.failure_threshold && self.state == CircuitState::Closed {
62 warn!(
63 "Circuit breaker: Closed → Open (failures: {})",
64 self.failure_count
65 );
66 self.state = CircuitState::Open;
67 CIRCUIT_BREAKER_STATE.set(1); }
69 }
70
71 fn is_request_allowed(&mut self) -> bool {
72 match self.state {
73 CircuitState::Closed => true,
74 CircuitState::Open => {
75 if let Some(last_failure) = self.last_failure_time {
77 if last_failure.elapsed() >= self.timeout {
78 debug!("Circuit breaker: Open → Half-Open (timeout elapsed)");
79 self.state = CircuitState::HalfOpen;
80 self.failure_count = 0;
81 CIRCUIT_BREAKER_STATE.set(2); return true;
83 }
84 }
85 false
86 }
87 CircuitState::HalfOpen => true,
88 }
89 }
90}
91
92pub struct CircuitBreakerMiddleware {
97 state: Arc<Mutex<CircuitBreakerState>>,
99 enabled: bool,
101}
102
103impl CircuitBreakerMiddleware {
104 pub fn new(enabled: bool, failure_threshold: u32, timeout_secs: u64) -> Self {
112 debug!(
113 "Circuit breaker initialized: enabled={}, threshold={}, timeout={}s",
114 enabled, failure_threshold, timeout_secs
115 );
116
117 if enabled {
119 CIRCUIT_BREAKER_STATE.set(0);
120 }
121
122 Self {
123 state: Arc::new(Mutex::new(CircuitBreakerState::new(
124 failure_threshold,
125 timeout_secs,
126 ))),
127 enabled,
128 }
129 }
130}
131
132#[async_trait]
133impl Middleware for CircuitBreakerMiddleware {
134 async fn request_filter(
136 &self,
137 session: &mut Session,
138 _ctx: &mut MiddlewareContext,
139 ) -> Result<()> {
140 if !self.enabled {
141 return Ok(());
142 }
143
144 let allowed = {
146 let mut state = self.state.lock();
147 state.is_request_allowed()
148 }; if !allowed {
151 warn!("Circuit breaker: request rejected (circuit OPEN)");
152
153 session.respond_error(503).await?;
155
156 return Err(Error::explain(
157 ErrorType::HTTPStatus(503),
158 "Circuit breaker open - upstream unavailable",
159 ));
160 }
161
162 debug!("Circuit breaker: request allowed");
163 Ok(())
164 }
165
166 async fn response_filter(
168 &self,
169 _session: &mut Session,
170 upstream_response: &mut ResponseHeader,
171 _ctx: &mut MiddlewareContext,
172 ) -> Result<()> {
173 if !self.enabled {
174 return Ok(());
175 }
176
177 let status = upstream_response.status.as_u16();
178 let mut state = self.state.lock();
179
180 if status >= 500 {
182 warn!("Circuit breaker: upstream failure (status {})", status);
183
184 UPSTREAM_FAILURES
186 .with_label_values(&[&status.to_string()])
187 .inc();
188
189 state.record_failure();
190 } else {
191 debug!("Circuit breaker: upstream success (status {})", status);
192 state.record_success();
193 }
194
195 Ok(())
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// A middleware built with `enabled = true` reports itself as enabled.
    #[test]
    fn test_circuit_breaker_creation() {
        let breaker = CircuitBreakerMiddleware::new(true, 5, 60);
        assert!(breaker.enabled);
    }

    /// A middleware built with `enabled = false` reports itself as disabled.
    #[test]
    fn test_circuit_breaker_disabled() {
        let breaker = CircuitBreakerMiddleware::new(false, 5, 60);
        assert!(!breaker.enabled);
    }

    /// Reaching the failure threshold opens the circuit and blocks requests.
    #[test]
    fn test_circuit_breaker_state_transitions() {
        let breaker = CircuitBreakerMiddleware::new(true, 3, 1);
        let mut inner = breaker.state.lock();

        // A fresh breaker is closed and lets traffic through.
        assert!(inner.is_request_allowed());

        // Exactly `failure_threshold` failures.
        (0..3).for_each(|_| inner.record_failure());

        assert_eq!(inner.state, CircuitState::Open);
        // The 1s timeout has not elapsed yet, so the request is refused.
        assert!(!inner.is_request_allowed());
    }
}