1use crate::error::{Result, TalosError};
32use std::future::Future;
33use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
34use std::time::{Duration, Instant};
35use tokio::sync::RwLock;
36
/// The three positions of a circuit breaker.
///
/// A breaker starts `Closed`, moves to `Open` when failures accumulate, and
/// moves to `HalfOpen` after the reset timeout so a limited number of trial
/// calls can decide whether to close again or re-open.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Calls flow through normally; failures are being counted.
    Closed,
    /// Calls are rejected without running the guarded operation.
    Open,
    /// A bounded number of trial calls are admitted to probe recovery.
    HalfOpen,
}
47
/// Tuning knobs for a [`CircuitBreaker`].
#[derive(Debug, Clone)]
pub struct CircuitBreakerConfig {
    /// Number of failures recorded while closed that trips the circuit open.
    pub failure_threshold: usize,
    /// Number of successful trial calls while half-open needed to close the circuit.
    pub success_threshold: usize,
    /// How long the circuit stays open before it becomes half-open.
    pub reset_timeout: Duration,
    /// Maximum number of trial calls admitted during one half-open period.
    pub half_open_max_requests: usize,
}
60
61impl Default for CircuitBreakerConfig {
62 fn default() -> Self {
63 Self {
64 failure_threshold: 5,
65 success_threshold: 2,
66 reset_timeout: Duration::from_secs(30),
67 half_open_max_requests: 3,
68 }
69 }
70}
71
72impl CircuitBreakerConfig {
73 #[must_use]
75 pub fn new() -> Self {
76 Self::default()
77 }
78
79 #[must_use]
81 pub fn with_failure_threshold(mut self, threshold: usize) -> Self {
82 self.failure_threshold = threshold;
83 self
84 }
85
86 #[must_use]
88 pub fn with_success_threshold(mut self, threshold: usize) -> Self {
89 self.success_threshold = threshold;
90 self
91 }
92
93 #[must_use]
95 pub fn with_reset_timeout(mut self, timeout: Duration) -> Self {
96 self.reset_timeout = timeout;
97 self
98 }
99
100 #[must_use]
102 pub fn with_half_open_max_requests(mut self, max: usize) -> Self {
103 self.half_open_max_requests = max;
104 self
105 }
106}
107
108pub struct CircuitBreaker {
114 config: CircuitBreakerConfig,
115 state: RwLock<CircuitState>,
116 failure_count: AtomicUsize,
117 success_count: AtomicUsize,
118 half_open_requests: AtomicUsize,
119 last_failure_time: RwLock<Option<Instant>>,
120 opened_at: RwLock<Option<Instant>>,
121 total_calls: AtomicU64,
122 total_failures: AtomicU64,
123 total_rejections: AtomicU64,
124}
125
126impl CircuitBreaker {
127 #[must_use]
129 pub fn new(config: CircuitBreakerConfig) -> Self {
130 Self {
131 config,
132 state: RwLock::new(CircuitState::Closed),
133 failure_count: AtomicUsize::new(0),
134 success_count: AtomicUsize::new(0),
135 half_open_requests: AtomicUsize::new(0),
136 last_failure_time: RwLock::new(None),
137 opened_at: RwLock::new(None),
138 total_calls: AtomicU64::new(0),
139 total_failures: AtomicU64::new(0),
140 total_rejections: AtomicU64::new(0),
141 }
142 }
143
144 #[must_use]
146 pub fn with_defaults() -> Self {
147 Self::new(CircuitBreakerConfig::default())
148 }
149
150 pub async fn state(&self) -> CircuitState {
152 let current_state = *self.state.read().await;
154 if current_state == CircuitState::Open {
155 if let Some(opened_at) = *self.opened_at.read().await {
156 if opened_at.elapsed() >= self.config.reset_timeout {
157 let mut state = self.state.write().await;
159 if *state == CircuitState::Open {
160 *state = CircuitState::HalfOpen;
161 self.half_open_requests.store(0, Ordering::Relaxed);
162 self.success_count.store(0, Ordering::Relaxed);
163 }
164 return CircuitState::HalfOpen;
165 }
166 }
167 }
168 current_state
169 }
170
171 pub async fn can_execute(&self) -> bool {
173 match self.state().await {
174 CircuitState::Closed => true,
175 CircuitState::Open => false,
176 CircuitState::HalfOpen => {
177 let current = self.half_open_requests.load(Ordering::Relaxed);
178 current < self.config.half_open_max_requests
179 }
180 }
181 }
182
183 pub async fn call<F, Fut, T>(&self, operation: F) -> Result<T>
190 where
191 F: FnOnce() -> Fut,
192 Fut: Future<Output = Result<T>>,
193 {
194 self.total_calls.fetch_add(1, Ordering::Relaxed);
195
196 if !self.can_execute().await {
198 self.total_rejections.fetch_add(1, Ordering::Relaxed);
199 return Err(TalosError::CircuitOpen(format!(
200 "Circuit breaker is open, will retry after {:?}",
201 self.time_until_retry().await
202 )));
203 }
204
205 let current_state = self.state().await;
207 if current_state == CircuitState::HalfOpen {
208 self.half_open_requests.fetch_add(1, Ordering::Relaxed);
209 }
210
211 match operation().await {
213 Ok(result) => {
214 self.on_success().await;
215 Ok(result)
216 }
217 Err(e) => {
218 self.on_failure().await;
219 Err(e)
220 }
221 }
222 }
223
224 async fn on_success(&self) {
226 let state = *self.state.read().await;
227 match state {
228 CircuitState::Closed => {
229 self.failure_count.store(0, Ordering::Relaxed);
231 }
232 CircuitState::HalfOpen => {
233 let successes = self.success_count.fetch_add(1, Ordering::Relaxed) + 1;
234 if successes >= self.config.success_threshold {
235 let mut state = self.state.write().await;
237 *state = CircuitState::Closed;
238 self.failure_count.store(0, Ordering::Relaxed);
239 self.success_count.store(0, Ordering::Relaxed);
240 }
241 }
242 CircuitState::Open => {
243 self.failure_count.store(0, Ordering::Relaxed);
245 }
246 }
247 }
248
249 async fn on_failure(&self) {
251 self.total_failures.fetch_add(1, Ordering::Relaxed);
252 *self.last_failure_time.write().await = Some(Instant::now());
253
254 let state = *self.state.read().await;
255 match state {
256 CircuitState::Closed => {
257 let failures = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
258 if failures >= self.config.failure_threshold {
259 self.open_circuit().await;
261 }
262 }
263 CircuitState::HalfOpen => {
264 self.open_circuit().await;
266 }
267 CircuitState::Open => {
268 }
270 }
271 }
272
273 async fn open_circuit(&self) {
275 let mut state = self.state.write().await;
276 *state = CircuitState::Open;
277 *self.opened_at.write().await = Some(Instant::now());
278 }
279
280 pub async fn reset(&self) {
282 let mut state = self.state.write().await;
283 *state = CircuitState::Closed;
284 self.failure_count.store(0, Ordering::Relaxed);
285 self.success_count.store(0, Ordering::Relaxed);
286 self.half_open_requests.store(0, Ordering::Relaxed);
287 *self.opened_at.write().await = None;
288 }
289
290 pub async fn time_until_retry(&self) -> Option<Duration> {
292 if *self.state.read().await != CircuitState::Open {
293 return None;
294 }
295
296 self.opened_at.read().await.map(|opened| {
297 let elapsed = opened.elapsed();
298 if elapsed >= self.config.reset_timeout {
299 Duration::ZERO
300 } else {
301 self.config.reset_timeout - elapsed
302 }
303 })
304 }
305
306 #[must_use]
308 pub fn failure_count(&self) -> usize {
309 self.failure_count.load(Ordering::Relaxed)
310 }
311
312 #[must_use]
314 pub fn total_calls(&self) -> u64 {
315 self.total_calls.load(Ordering::Relaxed)
316 }
317
318 #[must_use]
320 pub fn total_failures(&self) -> u64 {
321 self.total_failures.load(Ordering::Relaxed)
322 }
323
324 #[must_use]
326 pub fn total_rejections(&self) -> u64 {
327 self.total_rejections.load(Ordering::Relaxed)
328 }
329
330 #[must_use]
332 pub fn failure_rate(&self) -> f64 {
333 let total = self.total_calls.load(Ordering::Relaxed);
334 if total == 0 {
335 return 0.0;
336 }
337 let failures = self.total_failures.load(Ordering::Relaxed);
338 failures as f64 / total as f64
339 }
340
341 #[must_use]
343 pub fn config(&self) -> &CircuitBreakerConfig {
344 &self.config
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Issues `times` calls whose operation fails with a connection error.
    async fn fail_calls(breaker: &CircuitBreaker, times: usize) {
        for _ in 0..times {
            let _ = breaker
                .call(|| async { Err::<(), _>(TalosError::Connection("test".to_string())) })
                .await;
        }
    }

    /// Issues `times` calls whose operation succeeds.
    async fn succeed_calls(breaker: &CircuitBreaker, times: usize) {
        for _ in 0..times {
            let _ = breaker.call(|| async { Ok::<_, TalosError>("ok") }).await;
        }
    }

    #[test]
    fn test_circuit_breaker_config_default() {
        let cfg = CircuitBreakerConfig::default();
        assert_eq!(cfg.failure_threshold, 5);
        assert_eq!(cfg.success_threshold, 2);
        assert_eq!(cfg.reset_timeout, Duration::from_secs(30));
        assert_eq!(cfg.half_open_max_requests, 3);
    }

    #[test]
    fn test_circuit_breaker_config_builder() {
        let cfg = CircuitBreakerConfig::new()
            .with_failure_threshold(10)
            .with_success_threshold(5)
            .with_reset_timeout(Duration::from_secs(60))
            .with_half_open_max_requests(5);

        assert_eq!(cfg.failure_threshold, 10);
        assert_eq!(cfg.success_threshold, 5);
        assert_eq!(cfg.reset_timeout, Duration::from_secs(60));
        assert_eq!(cfg.half_open_max_requests, 5);
    }

    #[tokio::test]
    async fn test_circuit_breaker_initial_state() {
        let breaker = CircuitBreaker::with_defaults();
        assert_eq!(breaker.state().await, CircuitState::Closed);
        assert!(breaker.can_execute().await);
    }

    #[tokio::test]
    async fn test_circuit_breaker_opens_on_failures() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig::new().with_failure_threshold(3));

        fail_calls(&breaker, 3).await;

        assert_eq!(breaker.state().await, CircuitState::Open);
        assert!(!breaker.can_execute().await);
    }

    #[tokio::test]
    async fn test_circuit_breaker_rejects_when_open() {
        let breaker = CircuitBreaker::new(
            CircuitBreakerConfig::new()
                .with_failure_threshold(2)
                .with_reset_timeout(Duration::from_secs(60)),
        );

        fail_calls(&breaker, 2).await;

        let outcome = breaker
            .call(|| async { Ok::<_, TalosError>("success") })
            .await;
        assert!(matches!(outcome, Err(TalosError::CircuitOpen(_))));
        assert_eq!(breaker.total_rejections(), 1);
    }

    #[tokio::test]
    async fn test_circuit_breaker_success_resets_failures() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig::new().with_failure_threshold(3));

        fail_calls(&breaker, 2).await;
        assert_eq!(breaker.failure_count(), 2);

        succeed_calls(&breaker, 1).await;
        assert_eq!(breaker.failure_count(), 0);
    }

    #[tokio::test]
    async fn test_circuit_breaker_reset() {
        let breaker = CircuitBreaker::new(CircuitBreakerConfig::new().with_failure_threshold(2));

        fail_calls(&breaker, 2).await;
        assert_eq!(breaker.state().await, CircuitState::Open);

        breaker.reset().await;
        assert_eq!(breaker.state().await, CircuitState::Closed);
        assert!(breaker.can_execute().await);
    }

    #[tokio::test]
    async fn test_circuit_breaker_half_open_transition() {
        let breaker = CircuitBreaker::new(
            CircuitBreakerConfig::new()
                .with_failure_threshold(2)
                .with_reset_timeout(Duration::from_millis(50)),
        );

        fail_calls(&breaker, 2).await;
        assert_eq!(breaker.state().await, CircuitState::Open);

        // Wait past the reset timeout so the next state query promotes to half-open.
        tokio::time::sleep(Duration::from_millis(60)).await;
        assert_eq!(breaker.state().await, CircuitState::HalfOpen);
    }

    #[tokio::test]
    async fn test_circuit_breaker_closes_after_success_in_half_open() {
        let breaker = CircuitBreaker::new(
            CircuitBreakerConfig::new()
                .with_failure_threshold(2)
                .with_success_threshold(2)
                .with_reset_timeout(Duration::from_millis(10)),
        );

        fail_calls(&breaker, 2).await;

        tokio::time::sleep(Duration::from_millis(20)).await;
        assert_eq!(breaker.state().await, CircuitState::HalfOpen);

        succeed_calls(&breaker, 2).await;
        assert_eq!(breaker.state().await, CircuitState::Closed);
    }

    #[tokio::test]
    async fn test_circuit_breaker_failure_rate() {
        let breaker = CircuitBreaker::with_defaults();
        assert_eq!(breaker.failure_rate(), 0.0);

        succeed_calls(&breaker, 4).await;
        fail_calls(&breaker, 1).await;

        assert!((breaker.failure_rate() - 0.2).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn test_circuit_breaker_time_until_retry() {
        let breaker = CircuitBreaker::new(
            CircuitBreakerConfig::new()
                .with_failure_threshold(2)
                .with_reset_timeout(Duration::from_secs(30)),
        );

        // A closed circuit has no retry deadline.
        assert!(breaker.time_until_retry().await.is_none());

        fail_calls(&breaker, 2).await;

        let remaining = breaker.time_until_retry().await;
        assert!(matches!(remaining, Some(wait) if wait > Duration::ZERO));
    }

    #[test]
    fn test_circuit_state_equality() {
        assert_eq!(CircuitState::Closed, CircuitState::Closed);
        assert_ne!(CircuitState::Closed, CircuitState::Open);
        assert_ne!(CircuitState::Open, CircuitState::HalfOpen);
    }
}
553}