use std::future::Future;
use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
use std::time::Duration;

use recloser::{AsyncRecloser, Error as RecloserError, Recloser};
use tracing::{debug, warn};
20
/// Coarse state of a circuit breaker.
///
/// Variants carry explicit discriminants so they can be encoded as small integers
/// (e.g. `CircuitState::Open as u8 == 2`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Calls are admitted normally.
    Closed = 0,
    /// Transitional state between `Open` and `Closed`.
    HalfOpen = 1,
    /// The breaker is rejecting calls.
    Open = 2,
}

impl std::fmt::Display for CircuitState {
    /// Writes the state as a lowercase snake_case label (`closed`, `half_open`, `open`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Closed => "closed",
            Self::HalfOpen => "half_open",
            Self::Open => "open",
        };
        f.write_str(label)
    }
}
38
39#[derive(Debug, thiserror::Error)]
41pub enum CircuitError<E> {
42 #[error("circuit breaker open, request rejected")]
44 Rejected,
45
46 #[error("operation failed: {0}")]
48 Inner(#[source] E),
49}
50
51impl<E> From<RecloserError<E>> for CircuitError<E> {
52 fn from(err: RecloserError<E>) -> Self {
53 match err {
54 RecloserError::Rejected => CircuitError::Rejected,
55 RecloserError::Inner(e) => CircuitError::Inner(e),
56 }
57 }
58}
59
/// Tuning parameters for a [`CircuitBreaker`].
#[derive(Debug, Clone)]
pub struct CircuitConfig {
    /// Passed to recloser as the closed-state window length (`closed_len`); also divided
    /// by 100 to form recloser's `error_rate` in `CircuitBreaker::new`.
    pub failure_threshold: u32,
    /// Passed to recloser as the half-open window length (`half_open_len`).
    pub success_threshold: u32,
    /// Passed to recloser as `open_wait`, the time the breaker stays open.
    pub recovery_timeout: Duration,
}

impl Default for CircuitConfig {
    /// Baseline preset: failure window 5, half-open window 2, 30 s recovery.
    fn default() -> Self {
        Self {
            recovery_timeout: Duration::from_secs(30),
            failure_threshold: 5,
            success_threshold: 2,
        }
    }
}

impl CircuitConfig {
    /// Smaller failure window and longer recovery than the default:
    /// failure window 3, half-open window 3, 60 s recovery.
    #[must_use]
    pub fn aggressive() -> Self {
        Self {
            recovery_timeout: Duration::from_secs(60),
            success_threshold: 3,
            failure_threshold: 3,
        }
    }

    /// Larger failure window and shorter recovery than the default:
    /// failure window 10, half-open window 1, 15 s recovery.
    #[must_use]
    pub fn lenient() -> Self {
        Self {
            recovery_timeout: Duration::from_secs(15),
            success_threshold: 1,
            failure_threshold: 10,
        }
    }

    /// Tiny windows and a 50 ms recovery so tests can drive state changes quickly.
    #[cfg(test)]
    pub fn test() -> Self {
        Self {
            recovery_timeout: Duration::from_millis(50),
            success_threshold: 1,
            failure_threshold: 2,
        }
    }
}
112
113pub struct CircuitBreaker {
115 name: String,
116 inner: AsyncRecloser,
117
118 calls_total: AtomicU64,
120 successes: AtomicU64,
121 failures: AtomicU64,
122 rejections: AtomicU64,
123}
124
125impl CircuitBreaker {
126 pub fn new(name: impl Into<String>, config: CircuitConfig) -> Self {
128 let recloser = Recloser::custom()
129 .error_rate(config.failure_threshold as f32 / 100.0)
130 .closed_len(config.failure_threshold as usize)
131 .half_open_len(config.success_threshold as usize)
132 .open_wait(config.recovery_timeout)
133 .build();
134
135 Self {
136 name: name.into(),
137 inner: recloser.into(),
138 calls_total: AtomicU64::new(0),
139 successes: AtomicU64::new(0),
140 failures: AtomicU64::new(0),
141 rejections: AtomicU64::new(0),
142 }
143 }
144
145 pub fn with_defaults(name: impl Into<String>) -> Self {
147 Self::new(name, CircuitConfig::default())
148 }
149
150 #[must_use]
152 pub fn name(&self) -> &str {
153 &self.name
154 }
155
156 #[must_use]
158 pub fn state(&self) -> CircuitState {
159 CircuitState::Closed }
163
164 pub async fn call<F, Fut, T, E>(&self, f: F) -> Result<T, CircuitError<E>>
168 where
169 F: FnOnce() -> Fut,
170 Fut: Future<Output = Result<T, E>>,
171 {
172 self.calls_total.fetch_add(1, Ordering::Relaxed);
173
174 match self.inner.call(f()).await {
176 Ok(result) => {
177 self.successes.fetch_add(1, Ordering::Relaxed);
178 debug!(circuit = %self.name, "Circuit call succeeded");
179 crate::metrics::record_circuit_breaker_call(&self.name, "success");
180 Ok(result)
181 }
182 Err(RecloserError::Rejected) => {
183 self.rejections.fetch_add(1, Ordering::Relaxed);
184 warn!(circuit = %self.name, "Circuit breaker rejected call (open)");
185 crate::metrics::record_circuit_breaker_call(&self.name, "rejected");
186 Err(CircuitError::Rejected)
187 }
188 Err(RecloserError::Inner(e)) => {
189 self.failures.fetch_add(1, Ordering::Relaxed);
190 debug!(circuit = %self.name, "Circuit call failed");
191 crate::metrics::record_circuit_breaker_call(&self.name, "failure");
192 Err(CircuitError::Inner(e))
193 }
194 }
195 }
196
197 #[must_use]
199 pub fn calls_total(&self) -> u64 {
200 self.calls_total.load(Ordering::Relaxed)
201 }
202
203 #[must_use]
205 pub fn successes(&self) -> u64 {
206 self.successes.load(Ordering::Relaxed)
207 }
208
209 #[must_use]
211 pub fn failures(&self) -> u64 {
212 self.failures.load(Ordering::Relaxed)
213 }
214
215 #[must_use]
217 pub fn rejections(&self) -> u64 {
218 self.rejections.load(Ordering::Relaxed)
219 }
220
221 #[must_use]
223 pub fn failure_rate(&self) -> f64 {
224 let total = self.calls_total();
225 if total == 0 {
226 return 0.0;
227 }
228 self.failures() as f64 / total as f64
229 }
230
231 pub fn reset_metrics(&self) {
233 self.calls_total.store(0, Ordering::Relaxed);
234 self.successes.store(0, Ordering::Relaxed);
235 self.failures.store(0, Ordering::Relaxed);
236 self.rejections.store(0, Ordering::Relaxed);
237 }
238}
239
240pub struct BackendCircuits {
242 pub redis: CircuitBreaker,
244 pub mysql: CircuitBreaker,
246}
247
248impl Default for BackendCircuits {
249 fn default() -> Self {
250 Self::new()
251 }
252}
253
254impl BackendCircuits {
255 pub fn new() -> Self {
257 Self {
258 redis: CircuitBreaker::new("redis_l2", CircuitConfig::lenient()),
260 mysql: CircuitBreaker::new("mysql_l3", CircuitConfig::aggressive()),
262 }
263 }
264
265 pub fn metrics(&self) -> BackendCircuitMetrics {
267 BackendCircuitMetrics {
268 redis_calls: self.redis.calls_total(),
269 redis_successes: self.redis.successes(),
270 redis_failures: self.redis.failures(),
271 redis_rejections: self.redis.rejections(),
272 mysql_calls: self.mysql.calls_total(),
273 mysql_successes: self.mysql.successes(),
274 mysql_failures: self.mysql.failures(),
275 mysql_rejections: self.mysql.rejections(),
276 }
277 }
278}
279
/// Point-in-time copy of both backend breakers' counters.
///
/// Fields are plain numbers, so snapshots can be compared for equality. A zeroed snapshot
/// is available via `Default`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct BackendCircuitMetrics {
    /// All calls through the Redis breaker, rejected ones included.
    pub redis_calls: u64,
    /// Redis calls that were admitted and returned `Ok`.
    pub redis_successes: u64,
    /// Redis calls that were admitted and returned `Err`.
    pub redis_failures: u64,
    /// Redis calls refused by the breaker.
    pub redis_rejections: u64,
    /// All calls through the MySQL breaker, rejected ones included.
    pub mysql_calls: u64,
    /// MySQL calls that were admitted and returned `Ok`.
    pub mysql_successes: u64,
    /// MySQL calls that were admitted and returned `Err`.
    pub mysql_failures: u64,
    /// MySQL calls refused by the breaker.
    pub mysql_rejections: u64,
}
292
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    #[tokio::test]
    async fn test_circuit_passes_successful_calls() {
        let cb = CircuitBreaker::new("test", CircuitConfig::test());

        let result: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(42) }).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(cb.successes(), 1);
        assert_eq!(cb.failures(), 0);
    }

    #[tokio::test]
    async fn test_circuit_tracks_failures() {
        let cb = CircuitBreaker::new("test", CircuitConfig::test());

        let result: Result<i32, CircuitError<&str>> = cb.call(|| async { Err("boom") }).await;

        assert!(matches!(result, Err(CircuitError::Inner("boom"))));
        assert_eq!(cb.successes(), 0);
        assert_eq!(cb.failures(), 1);
    }

    #[tokio::test]
    async fn test_circuit_opens_after_threshold() {
        let config = CircuitConfig {
            failure_threshold: 2,
            success_threshold: 1,
            // Long enough that the breaker cannot recover during the test.
            recovery_timeout: Duration::from_secs(60),
        };
        let cb = CircuitBreaker::new("test", config);

        // Enough consecutive failures to fill the 2-call window and trip the breaker.
        for _ in 0..3 {
            let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Err("fail") }).await;
        }

        // Once open, even an operation that would succeed must be refused.
        let result: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(42) }).await;

        assert!(matches!(result, Err(CircuitError::Rejected)));
        assert_eq!(cb.calls_total(), 4);
        assert_eq!(cb.successes(), 0);
        assert!(cb.rejections() >= 1);
    }

    #[tokio::test]
    async fn test_circuit_metrics_accumulate() {
        let cb = CircuitBreaker::new("test", CircuitConfig::test());

        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(1) }).await;
        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(2) }).await;
        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(3) }).await;
        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(4) }).await;

        assert_eq!(cb.calls_total(), 4);
        assert_eq!(cb.successes(), 4);
        assert_eq!(cb.failures(), 0);
    }

    #[tokio::test]
    async fn test_failure_rate_calculation() {
        let config = CircuitConfig {
            // Large window so a handful of failures cannot open the breaker.
            failure_threshold: 100,
            success_threshold: 1,
            recovery_timeout: Duration::from_secs(60),
        };
        let cb = CircuitBreaker::new("test", config);

        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(1) }).await;
        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Err("x") }).await;
        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(2) }).await;
        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Err("y") }).await;

        assert!((cb.failure_rate() - 0.5).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_reset_metrics() {
        let cb = CircuitBreaker::new("test", CircuitConfig::test());

        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Ok(1) }).await;
        let _: Result<i32, CircuitError<&str>> = cb.call(|| async { Err("x") }).await;

        assert!(cb.calls_total() > 0);

        cb.reset_metrics();

        assert_eq!(cb.calls_total(), 0);
        assert_eq!(cb.successes(), 0);
        assert_eq!(cb.failures(), 0);
        assert_eq!(cb.rejections(), 0);
    }

    // Pure construction, no async work: a plain test is enough.
    #[test]
    fn test_backend_circuits_configs() {
        let circuits = BackendCircuits::new();

        assert_eq!(circuits.redis.name(), "redis_l2");
        assert_eq!(circuits.mysql.name(), "mysql_l3");
    }

    #[tokio::test]
    async fn test_circuit_with_async_state() {
        let cb = CircuitBreaker::new("test", CircuitConfig::test());
        let counter = std::sync::Arc::new(AtomicUsize::new(0));

        let counter_clone = counter.clone();
        let result: Result<usize, CircuitError<&str>> = cb
            .call(|| async move {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Ok(counter_clone.load(Ordering::SeqCst))
            })
            .await;

        assert_eq!(result.unwrap(), 1);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_backend_circuit_metrics() {
        let circuits = BackendCircuits::new();

        let _: Result<i32, CircuitError<&str>> = circuits.redis.call(|| async { Ok(1) }).await;
        let _: Result<i32, CircuitError<&str>> =
            circuits.mysql.call(|| async { Err("down") }).await;

        let metrics = circuits.metrics();

        assert_eq!(metrics.redis_calls, 1);
        assert_eq!(metrics.redis_successes, 1);
        assert_eq!(metrics.mysql_calls, 1);
        assert_eq!(metrics.mysql_failures, 1);
    }

    #[test]
    fn test_circuit_config_presets() {
        let default = CircuitConfig::default();
        let aggressive = CircuitConfig::aggressive();
        let lenient = CircuitConfig::lenient();

        assert!(aggressive.failure_threshold < default.failure_threshold);
        assert!(lenient.failure_threshold > default.failure_threshold);
        assert!(aggressive.recovery_timeout > lenient.recovery_timeout);
    }
}