1use recloser::{Recloser, AsyncRecloser, Error as RecloserError};
13use std::future::Future;
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::time::Duration;
16use tracing::{debug, warn};
17
/// Coarse state of a circuit breaker.
///
/// Each variant carries an explicit discriminant so it can be converted with `as u8`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Calls are passed through to the wrapped operation.
    Closed = 0,
    /// A limited number of trial calls are being let through.
    HalfOpen = 1,
    /// Calls are being rejected.
    Open = 2,
}
25
26impl std::fmt::Display for CircuitState {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 match self {
29 Self::Closed => write!(f, "closed"),
30 Self::HalfOpen => write!(f, "half_open"),
31 Self::Open => write!(f, "open"),
32 }
33 }
34}
35
36#[derive(Debug, thiserror::Error)]
38pub enum CircuitError<E> {
39 #[error("circuit breaker open, request rejected")]
41 Rejected,
42
43 #[error("operation failed: {0}")]
45 Inner(#[source] E),
46}
47
48impl<E> From<RecloserError<E>> for CircuitError<E> {
49 fn from(err: RecloserError<E>) -> Self {
50 match err {
51 RecloserError::Rejected => CircuitError::Rejected,
52 RecloserError::Inner(e) => CircuitError::Inner(e),
53 }
54 }
55}
56
/// Tuning knobs for a [`CircuitBreaker`].
#[derive(Clone, Debug)]
pub struct CircuitConfig {
    /// Failure threshold for tripping the breaker; also sizes the closed-state
    /// outcome window (see [`CircuitBreaker::new`]).
    pub failure_threshold: u32,
    /// Number of trial calls evaluated while half-open (the half-open window size).
    pub success_threshold: u32,
    /// How long the breaker stays open before trial calls are allowed
    /// (passed to recloser as `open_wait`).
    pub recovery_timeout: Duration,
}
67
68impl Default for CircuitConfig {
69 fn default() -> Self {
70 Self {
71 failure_threshold: 5,
72 success_threshold: 2,
73 recovery_timeout: Duration::from_secs(30),
74 }
75 }
76}
77
78impl CircuitConfig {
79 #[must_use]
81 pub fn aggressive() -> Self {
82 Self {
83 failure_threshold: 3,
84 success_threshold: 3,
85 recovery_timeout: Duration::from_secs(60),
86 }
87 }
88
89 #[must_use]
91 pub fn lenient() -> Self {
92 Self {
93 failure_threshold: 10,
94 success_threshold: 1,
95 recovery_timeout: Duration::from_secs(15),
96 }
97 }
98
99 #[cfg(test)]
101 pub fn test() -> Self {
102 Self {
103 failure_threshold: 2,
104 success_threshold: 1,
105 recovery_timeout: Duration::from_millis(50),
106 }
107 }
108}
109
110pub struct CircuitBreaker {
112 name: String,
113 inner: AsyncRecloser,
114
115 calls_total: AtomicU64,
117 successes: AtomicU64,
118 failures: AtomicU64,
119 rejections: AtomicU64,
120}
121
122impl CircuitBreaker {
123 pub fn new(name: impl Into<String>, config: CircuitConfig) -> Self {
125 let recloser = Recloser::custom()
126 .error_rate(config.failure_threshold as f32 / 100.0)
127 .closed_len(config.failure_threshold as usize)
128 .half_open_len(config.success_threshold as usize)
129 .open_wait(config.recovery_timeout)
130 .build();
131
132 Self {
133 name: name.into(),
134 inner: recloser.into(),
135 calls_total: AtomicU64::new(0),
136 successes: AtomicU64::new(0),
137 failures: AtomicU64::new(0),
138 rejections: AtomicU64::new(0),
139 }
140 }
141
142 pub fn with_defaults(name: impl Into<String>) -> Self {
144 Self::new(name, CircuitConfig::default())
145 }
146
147 #[must_use]
149 pub fn name(&self) -> &str {
150 &self.name
151 }
152
153 #[must_use]
155 pub fn state(&self) -> CircuitState {
156 CircuitState::Closed }
160
161 pub async fn call<F, Fut, T, E>(&self, f: F) -> Result<T, CircuitError<E>>
165 where
166 F: FnOnce() -> Fut,
167 Fut: Future<Output = Result<T, E>>,
168 {
169 self.calls_total.fetch_add(1, Ordering::Relaxed);
170
171 match self.inner.call(f()).await {
173 Ok(result) => {
174 self.successes.fetch_add(1, Ordering::Relaxed);
175 debug!(circuit = %self.name, "Circuit call succeeded");
176 crate::metrics::record_circuit_breaker_call(&self.name, "success");
177 Ok(result)
178 }
179 Err(RecloserError::Rejected) => {
180 self.rejections.fetch_add(1, Ordering::Relaxed);
181 warn!(circuit = %self.name, "Circuit breaker rejected call (open)");
182 crate::metrics::record_circuit_breaker_call(&self.name, "rejected");
183 Err(CircuitError::Rejected)
184 }
185 Err(RecloserError::Inner(e)) => {
186 self.failures.fetch_add(1, Ordering::Relaxed);
187 debug!(circuit = %self.name, "Circuit call failed");
188 crate::metrics::record_circuit_breaker_call(&self.name, "failure");
189 Err(CircuitError::Inner(e))
190 }
191 }
192 }
193
194 #[must_use]
196 pub fn calls_total(&self) -> u64 {
197 self.calls_total.load(Ordering::Relaxed)
198 }
199
200 #[must_use]
202 pub fn successes(&self) -> u64 {
203 self.successes.load(Ordering::Relaxed)
204 }
205
206 #[must_use]
208 pub fn failures(&self) -> u64 {
209 self.failures.load(Ordering::Relaxed)
210 }
211
212 #[must_use]
214 pub fn rejections(&self) -> u64 {
215 self.rejections.load(Ordering::Relaxed)
216 }
217
218 #[must_use]
220 pub fn failure_rate(&self) -> f64 {
221 let total = self.calls_total();
222 if total == 0 {
223 return 0.0;
224 }
225 self.failures() as f64 / total as f64
226 }
227
228 pub fn reset_metrics(&self) {
230 self.calls_total.store(0, Ordering::Relaxed);
231 self.successes.store(0, Ordering::Relaxed);
232 self.failures.store(0, Ordering::Relaxed);
233 self.rejections.store(0, Ordering::Relaxed);
234 }
235}
236
237pub struct BackendCircuits {
239 pub redis: CircuitBreaker,
241 pub mysql: CircuitBreaker,
243}
244
245impl Default for BackendCircuits {
246 fn default() -> Self {
247 Self::new()
248 }
249}
250
251impl BackendCircuits {
252 pub fn new() -> Self {
254 Self {
255 redis: CircuitBreaker::new("redis_l2", CircuitConfig::lenient()),
257 mysql: CircuitBreaker::new("mysql_l3", CircuitConfig::aggressive()),
259 }
260 }
261
262 pub fn metrics(&self) -> BackendCircuitMetrics {
264 BackendCircuitMetrics {
265 redis_calls: self.redis.calls_total(),
266 redis_successes: self.redis.successes(),
267 redis_failures: self.redis.failures(),
268 redis_rejections: self.redis.rejections(),
269 mysql_calls: self.mysql.calls_total(),
270 mysql_successes: self.mysql.successes(),
271 mysql_failures: self.mysql.failures(),
272 mysql_rejections: self.mysql.rejections(),
273 }
274 }
275}
276
/// Point-in-time copy of the counters of every backend breaker.
#[derive(Clone, Debug)]
pub struct BackendCircuitMetrics {
    /// Redis: calls attempted, including rejected ones.
    pub redis_calls: u64,
    /// Redis: calls whose operation succeeded.
    pub redis_successes: u64,
    /// Redis: calls whose operation failed.
    pub redis_failures: u64,
    /// Redis: calls rejected by the breaker.
    pub redis_rejections: u64,
    /// MySQL: calls attempted, including rejected ones.
    pub mysql_calls: u64,
    /// MySQL: calls whose operation succeeded.
    pub mysql_successes: u64,
    /// MySQL: calls whose operation failed.
    pub mysql_failures: u64,
    /// MySQL: calls rejected by the breaker.
    pub mysql_rejections: u64,
}
289
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    #[tokio::test]
    async fn test_circuit_passes_successful_calls() {
        let cb = CircuitBreaker::new("test", CircuitConfig::test());

        let result: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(42) }).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(cb.successes(), 1);
        assert_eq!(cb.failures(), 0);
    }

    #[tokio::test]
    async fn test_circuit_tracks_failures() {
        let cb = CircuitBreaker::new("test", CircuitConfig::test());

        let result: Result<i32, CircuitError<&str>> = cb.call(|| async { Err("boom") }).await;

        assert!(matches!(result, Err(CircuitError::Inner("boom"))));
        assert_eq!(cb.successes(), 0);
        assert_eq!(cb.failures(), 1);
    }

    #[tokio::test]
    async fn test_circuit_opens_after_threshold() {
        let config = CircuitConfig {
            failure_threshold: 2,
            success_threshold: 1,
            // Long enough that the breaker cannot leave the open state mid-test.
            recovery_timeout: Duration::from_secs(60),
        };
        let cb = CircuitBreaker::new("test", config);

        // More consecutive failures than the threshold: the breaker must trip.
        for _ in 0..3 {
            let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Err("fail") }).await;
        }

        // Once open, even an operation that would succeed must be rejected.
        let result: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(42) }).await;

        assert!(matches!(result, Err(CircuitError::Rejected)));
        assert_eq!(cb.successes(), 0);
        assert!(cb.rejections() >= 1);
        assert_eq!(cb.calls_total(), 4);
        assert_eq!(cb.failures() + cb.rejections(), 4);
    }

    #[tokio::test]
    async fn test_circuit_metrics_accumulate() {
        let cb = CircuitBreaker::new("test", CircuitConfig::test());

        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(1) }).await;
        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(2) }).await;
        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(3) }).await;
        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(4) }).await;

        assert_eq!(cb.calls_total(), 4);
        assert_eq!(cb.successes(), 4);
        assert_eq!(cb.failures(), 0);
    }

    #[tokio::test]
    async fn test_failure_rate_calculation() {
        let config = CircuitConfig {
            // High threshold so the breaker never trips during these four calls.
            failure_threshold: 100,
            success_threshold: 1,
            recovery_timeout: Duration::from_secs(60),
        };
        let cb = CircuitBreaker::new("test", config);

        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(1) }).await;
        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Err("x") }).await;
        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(2) }).await;
        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Err("y") }).await;

        assert!((cb.failure_rate() - 0.5).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_reset_metrics() {
        let cb = CircuitBreaker::new("test", CircuitConfig::test());

        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(1) }).await;
        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Err("x") }).await;

        assert!(cb.calls_total() > 0);

        cb.reset_metrics();

        assert_eq!(cb.calls_total(), 0);
        assert_eq!(cb.successes(), 0);
        assert_eq!(cb.failures(), 0);
        assert_eq!(cb.rejections(), 0);
    }

    #[tokio::test]
    async fn test_backend_circuits_configs() {
        let circuits = BackendCircuits::new();

        assert_eq!(circuits.redis.name(), "redis_l2");
        assert_eq!(circuits.mysql.name(), "mysql_l3");
    }

    #[tokio::test]
    async fn test_circuit_with_async_state() {
        let cb = CircuitBreaker::new("test", CircuitConfig::test());
        let counter = std::sync::Arc::new(AtomicUsize::new(0));

        let counter_clone = counter.clone();
        let result: Result<usize, CircuitError<&str>> = cb
            .call(|| async move {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Ok(counter_clone.load(Ordering::SeqCst))
            })
            .await;

        assert_eq!(result.unwrap(), 1);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_backend_circuit_metrics() {
        let circuits = BackendCircuits::new();

        let _: Result<i32, CircuitError<&str>> = circuits.redis.call(|| async { Ok(1) }).await;
        let _: Result<i32, CircuitError<&str>> =
            circuits.mysql.call(|| async { Err("down") }).await;

        let metrics = circuits.metrics();

        assert_eq!(metrics.redis_calls, 1);
        assert_eq!(metrics.redis_successes, 1);
        assert_eq!(metrics.mysql_calls, 1);
        assert_eq!(metrics.mysql_failures, 1);
    }

    #[test]
    fn test_circuit_state_display() {
        assert_eq!(CircuitState::Closed.to_string(), "closed");
        assert_eq!(CircuitState::HalfOpen.to_string(), "half_open");
        assert_eq!(CircuitState::Open.to_string(), "open");
    }

    #[test]
    fn test_circuit_error_from_recloser_error() {
        let rejected = CircuitError::from(RecloserError::<&str>::Rejected);
        assert!(matches!(rejected, CircuitError::Rejected));

        let inner = CircuitError::from(RecloserError::Inner("boom"));
        assert!(matches!(inner, CircuitError::Inner("boom")));
    }

    #[test]
    fn test_circuit_config_presets() {
        let default = CircuitConfig::default();
        let aggressive = CircuitConfig::aggressive();
        let lenient = CircuitConfig::lenient();

        // Aggressive trips sooner than the default; lenient trips later.
        assert!(aggressive.failure_threshold < default.failure_threshold);
        assert!(lenient.failure_threshold > default.failure_threshold);
        // Aggressive also waits longer before letting trial calls through.
        assert!(aggressive.recovery_timeout > lenient.recovery_timeout);
    }
}