1use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
4use std::time::{Duration, Instant};
5use tokio::sync::RwLock;
6
/// The three states of the circuit-breaker state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CircuitState {
    /// Normal operation: requests are admitted.
    Closed,
    /// Tripped: requests are rejected until the reset timeout elapses.
    Open,
    /// Probing: requests are admitted while recovery is being tested.
    HalfOpen,
}
17
/// Tuning knobs for a circuit breaker.
#[derive(Debug, Clone)]
pub struct CircuitConfig {
    /// Failures while `Closed` before the circuit trips `Open`.
    pub failure_threshold: u32,
    /// Successes while `HalfOpen` required to return to `Closed`.
    pub success_threshold: u32,
    /// Failures tolerated while `HalfOpen` before reopening.
    pub half_open_failure_threshold: u32,
    /// How long an `Open` circuit waits before allowing a probe request.
    pub reset_timeout: Duration,
    /// Intended rolling window for grouping failures into one streak.
    /// NOTE(review): verify this is actually consulted by the breaker.
    pub window_duration: Duration,
}
32
33impl Default for CircuitConfig {
34 fn default() -> Self {
35 Self {
36 failure_threshold: 5,
37 success_threshold: 3,
38 half_open_failure_threshold: 2, reset_timeout: Duration::from_secs(30),
40 window_duration: Duration::from_secs(60),
41 }
42 }
43}
44
45impl CircuitConfig {
46 pub fn conservative() -> Self {
48 Self {
49 failure_threshold: 3,
50 success_threshold: 5,
51 half_open_failure_threshold: 1, reset_timeout: Duration::from_secs(60),
53 window_duration: Duration::from_secs(120),
54 }
55 }
56
57 pub fn aggressive() -> Self {
59 Self {
60 failure_threshold: 10,
61 success_threshold: 2,
62 half_open_failure_threshold: 3, reset_timeout: Duration::from_secs(10),
64 window_duration: Duration::from_secs(30),
65 }
66 }
67}
68
/// An async circuit breaker guarding calls to an unreliable dependency.
///
/// Transitions: `Closed` -> `Open` after repeated failures, `Open` ->
/// `HalfOpen` once `reset_timeout` elapses, then `HalfOpen` -> `Closed`
/// after enough successes (or back to `Open` on half-open failures).
#[derive(Debug)]
pub struct CircuitBreaker {
    // Human-readable name; only used in log output.
    name: String,
    // Thresholds and timeouts governing transitions.
    config: CircuitConfig,
    // Current state; tokio RwLock because transitions span await points.
    state: RwLock<CircuitState>,
    // Failure streak observed while Closed.
    failure_count: AtomicU32,
    // Success streak observed while HalfOpen.
    success_count: AtomicU32,
    // Failures observed while HalfOpen.
    half_open_failure_count: AtomicU32,
    // Timestamp of the most recent recorded failure, if any.
    last_failure_time: RwLock<Option<Instant>>,
    // When the circuit last changed state; drives the Open -> HalfOpen timeout.
    last_state_change: RwLock<Instant>,
    // Lifetime counters reported by stats().
    total_requests: AtomicU64,
    total_failures: AtomicU64,
    total_rejections: AtomicU64,
}
84
85impl CircuitBreaker {
86 pub fn new(name: &str, config: CircuitConfig) -> Self {
88 Self {
89 name: name.to_string(),
90 config,
91 state: RwLock::new(CircuitState::Closed),
92 failure_count: AtomicU32::new(0),
93 success_count: AtomicU32::new(0),
94 half_open_failure_count: AtomicU32::new(0),
95 last_failure_time: RwLock::new(None),
96 last_state_change: RwLock::new(Instant::now()),
97 total_requests: AtomicU64::new(0),
98 total_failures: AtomicU64::new(0),
99 total_rejections: AtomicU64::new(0),
100 }
101 }
102
103 pub async fn allow(&self) -> bool {
105 self.total_requests.fetch_add(1, Ordering::Relaxed);
106
107 let mut state = self.state.write().await;
108
109 match *state {
110 CircuitState::Closed => true,
111 CircuitState::Open => {
112 let last_change = *self.last_state_change.read().await;
114 if last_change.elapsed() >= self.config.reset_timeout {
115 *state = CircuitState::HalfOpen;
116 *self.last_state_change.write().await = Instant::now();
117 self.success_count.store(0, Ordering::Relaxed);
118 self.half_open_failure_count.store(0, Ordering::Relaxed);
119 tracing::info!(
120 circuit = %self.name,
121 "Circuit transitioned to HalfOpen"
122 );
123 true
124 } else {
125 self.total_rejections.fetch_add(1, Ordering::Relaxed);
126 false
127 }
128 }
129 CircuitState::HalfOpen => {
130 true
132 }
133 }
134 }
135
136 pub async fn record_success(&self) {
138 let mut state = self.state.write().await;
139
140 match *state {
141 CircuitState::HalfOpen => {
142 let count = self.success_count.fetch_add(1, Ordering::Relaxed) + 1;
143 if count >= self.config.success_threshold {
144 *state = CircuitState::Closed;
145 self.failure_count.store(0, Ordering::Relaxed);
146 self.success_count.store(0, Ordering::Relaxed);
147 *self.last_state_change.write().await = Instant::now();
148 tracing::info!(
149 circuit = %self.name,
150 "Circuit recovered - now Closed"
151 );
152 }
153 }
154 CircuitState::Closed => {
155 self.failure_count.store(0, Ordering::Relaxed);
157 }
158 _ => {}
159 }
160 }
161
162 pub async fn record_failure(&self) {
164 self.total_failures.fetch_add(1, Ordering::Relaxed);
165 *self.last_failure_time.write().await = Some(Instant::now());
166
167 let mut state = self.state.write().await;
168
169 match *state {
170 CircuitState::Closed => {
171 let count = self.failure_count.fetch_add(1, Ordering::Relaxed) + 1;
172 if count >= self.config.failure_threshold {
173 *state = CircuitState::Open;
174 *self.last_state_change.write().await = Instant::now();
175 tracing::warn!(
176 circuit = %self.name,
177 failures = count,
178 "Circuit tripped - now Open"
179 );
180 }
181 }
182 CircuitState::HalfOpen => {
183 let half_open_failures =
185 self.half_open_failure_count.fetch_add(1, Ordering::Relaxed) + 1;
186
187 if half_open_failures >= self.config.half_open_failure_threshold {
188 *state = CircuitState::Open;
190 self.success_count.store(0, Ordering::Relaxed);
191 self.half_open_failure_count.store(0, Ordering::Relaxed);
192 *self.last_state_change.write().await = Instant::now();
193 tracing::warn!(
194 circuit = %self.name,
195 half_open_failures = half_open_failures,
196 "Circuit tripped from HalfOpen - back to Open"
197 );
198 } else {
199 tracing::debug!(
200 circuit = %self.name,
201 half_open_failures = half_open_failures,
202 threshold = self.config.half_open_failure_threshold,
203 "HalfOpen failure recorded, still testing"
204 );
205 }
206 }
207 _ => {}
208 }
209 }
210
211 pub async fn state(&self) -> CircuitState {
213 *self.state.read().await
214 }
215
216 pub fn stats(&self) -> CircuitStats {
218 CircuitStats {
219 name: self.name.clone(),
220 total_requests: self.total_requests.load(Ordering::Relaxed),
221 total_failures: self.total_failures.load(Ordering::Relaxed),
222 total_rejections: self.total_rejections.load(Ordering::Relaxed),
223 current_failures: self.failure_count.load(Ordering::Relaxed),
224 current_successes: self.success_count.load(Ordering::Relaxed),
225 }
226 }
227
228 pub async fn call<F, T, E>(&self, f: F) -> Result<T, CircuitError<E>>
230 where
231 F: std::future::Future<Output = Result<T, E>>,
232 {
233 if !self.allow().await {
234 return Err(CircuitError::Open);
235 }
236
237 match f.await {
238 Ok(result) => {
239 self.record_success().await;
240 Ok(result)
241 }
242 Err(e) => {
243 self.record_failure().await;
244 Err(CircuitError::Failed(e))
245 }
246 }
247 }
248}
249
/// A point-in-time snapshot of a circuit breaker's counters.
#[derive(Debug, Clone)]
pub struct CircuitStats {
    /// Name of the breaker this snapshot came from.
    pub name: String,
    /// Lifetime number of `allow()` calls.
    pub total_requests: u64,
    /// Lifetime number of recorded failures.
    pub total_failures: u64,
    /// Lifetime number of requests rejected while `Open`.
    pub total_rejections: u64,
    /// Current `Closed`-state failure streak.
    pub current_failures: u32,
    /// Current `HalfOpen` success streak.
    pub current_successes: u32,
}
260
/// Error returned by `CircuitBreaker::call`.
#[derive(Debug, thiserror::Error)]
pub enum CircuitError<E> {
    /// The circuit rejected the call without running it.
    #[error("Circuit is open - service unavailable")]
    Open,
    /// The protected call ran and failed with the underlying error `E`.
    #[error("Call failed: {0}")]
    Failed(#[source] E),
}
269
/// Exponential-backoff retry policy with jitter.
pub struct RetryPolicy {
    /// Upper bound on attempts before giving up.
    /// NOTE(review): `execute` treats this as TOTAL attempts (first try
    /// included), not retries after the first — confirm naming intent.
    pub max_retries: u32,
    /// Delay before the first retry.
    pub initial_delay: Duration,
    /// Cap applied to the backoff delay.
    pub max_delay: Duration,
    /// Factor applied to the delay after each failed attempt.
    pub multiplier: f64,
}
277
278impl Default for RetryPolicy {
279 fn default() -> Self {
280 Self {
281 max_retries: 3,
282 initial_delay: Duration::from_millis(100),
283 max_delay: Duration::from_secs(10),
284 multiplier: 2.0,
285 }
286 }
287}
288
289impl RetryPolicy {
290 pub async fn execute<F, Fut, T, E>(&self, mut f: F) -> Result<T, E>
292 where
293 F: FnMut() -> Fut,
294 Fut: std::future::Future<Output = Result<T, E>>,
295 E: std::fmt::Debug,
296 {
297 let mut delay = self.initial_delay;
298 let mut attempts = 0;
299
300 loop {
301 match f().await {
302 Ok(result) => return Ok(result),
303 Err(e) => {
304 attempts += 1;
305 if attempts >= self.max_retries {
306 tracing::error!(
307 attempts = attempts,
308 error = ?e,
309 "Retry exhausted"
310 );
311 return Err(e);
312 }
313
314 tracing::warn!(
315 attempt = attempts,
316 delay_ms = delay.as_millis(),
317 error = ?e,
318 "Retrying after failure"
319 );
320
321 let jitter = delay.as_millis() as f64 * 0.1;
323 let jittered =
324 delay.as_millis() as f64 + (rand::random::<f64>() * 2.0 - 1.0) * jitter;
325
326 tokio::time::sleep(Duration::from_millis(jittered as u64)).await;
327
328 delay =
330 Duration::from_millis((delay.as_millis() as f64 * self.multiplier) as u64)
331 .min(self.max_delay);
332 }
333 }
334 }
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives a breaker through the full Closed -> Open -> HalfOpen ->
    /// Closed lifecycle using a short (100 ms) reset timeout.
    #[tokio::test]
    async fn test_circuit_breaker_trips() {
        let config = CircuitConfig {
            failure_threshold: 2,
            success_threshold: 1,
            half_open_failure_threshold: 1,
            reset_timeout: Duration::from_millis(100),
            window_duration: Duration::from_secs(60),
        };
        let cb = CircuitBreaker::new("test", config);

        // A fresh breaker starts Closed and admits traffic.
        assert_eq!(cb.state().await, CircuitState::Closed);
        assert!(cb.allow().await);

        // One failure below the threshold keeps the circuit Closed...
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Closed);
        // ...and the second (threshold = 2) trips it Open.
        cb.record_failure().await;
        assert_eq!(cb.state().await, CircuitState::Open);

        // While Open, requests are rejected.
        assert!(!cb.allow().await);

        // Sleep past reset_timeout so the next request becomes a probe.
        tokio::time::sleep(Duration::from_millis(150)).await;

        // The probe is admitted and transitions the circuit to HalfOpen.
        assert!(cb.allow().await);
        assert_eq!(cb.state().await, CircuitState::HalfOpen);

        // success_threshold = 1, so a single success closes the circuit.
        cb.record_success().await;
        assert_eq!(cb.state().await, CircuitState::Closed);
    }

    /// Verifies execute() retries failed attempts and surfaces the first
    /// success (fail twice, succeed on the third attempt).
    #[tokio::test]
    async fn test_retry_policy() {
        let policy = RetryPolicy {
            max_retries: 3,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(100),
            multiplier: 2.0,
        };

        // `attempts` is Copy, so each `async move` block captures the
        // value current when the closure ran, while the closure itself
        // mutably borrows the counter across calls.
        let mut attempts = 0;
        let result: Result<i32, &str> = policy
            .execute(|| {
                attempts += 1;
                async move {
                    if attempts < 3 {
                        Err("failed")
                    } else {
                        Ok(42)
                    }
                }
            })
            .await;

        assert_eq!(result, Ok(42));
        assert_eq!(attempts, 3);
    }
}