1use std::future::Future;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6use std::time::Duration;
7
8use pin_project_lite::pin_project;
9use tokio::time::{Sleep, sleep};
10
11pin_project! {
12 pub struct Timeout<F> {
14 #[pin]
15 future: F,
16 #[pin]
17 delay: Sleep,
18 }
19}
20
21impl<F> Timeout<F> {
22 pub fn new(future: F, duration: Duration) -> Self {
24 Self {
25 future,
26 delay: sleep(duration),
27 }
28 }
29}
30
31impl<F> Future for Timeout<F>
32where
33 F: Future,
34{
35 type Output = Result<F::Output, TimeoutError>;
36
37 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
38 let this = self.project();
39
40 if let Poll::Ready(output) = this.future.poll(cx) {
42 return Poll::Ready(Ok(output));
43 }
44
45 match this.delay.poll(cx) {
47 Poll::Ready(()) => Poll::Ready(Err(TimeoutError)),
48 Poll::Pending => Poll::Pending,
49 }
50 }
51}
52
/// Error produced by [`Timeout`] when its deadline passes before the wrapped
/// future completes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct TimeoutError;

impl std::fmt::Display for TimeoutError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("Operation timed out")
    }
}

impl std::error::Error for TimeoutError {}
64
65pub fn timeout<F>(duration: Duration, future: F) -> Timeout<F>
67where
68 F: Future,
69{
70 Timeout::new(future, duration)
71}
72
/// Tuning knobs for [`retry_with_backoff`]'s exponential backoff.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Total number of times the operation may be invoked, counting the first call.
    pub max_attempts: usize,
    /// Delay before the first retry; later retries scale it by `backoff_multiplier`.
    pub base_delay: Duration,
    /// Upper bound applied to every computed delay, jitter included.
    pub max_delay: Duration,
    /// Factor by which the delay grows on each successive retry.
    pub backoff_multiplier: f64,
    /// When `true`, each delay is randomly perturbed by up to ±5%.
    pub jitter: bool,
}
87
88impl Default for RetryConfig {
89 fn default() -> Self {
90 Self {
91 max_attempts: 3,
92 base_delay: Duration::from_millis(100),
93 max_delay: Duration::from_secs(30),
94 backoff_multiplier: 2.0,
95 jitter: true,
96 }
97 }
98}
99
100impl RetryConfig {
101 #[must_use]
103 pub fn new() -> Self {
104 Self::default()
105 }
106
107 #[must_use]
109 pub const fn with_max_attempts(mut self, max_attempts: usize) -> Self {
110 self.max_attempts = max_attempts;
111 self
112 }
113
114 #[must_use]
116 pub const fn with_base_delay(mut self, delay: Duration) -> Self {
117 self.base_delay = delay;
118 self
119 }
120
121 #[must_use]
123 pub const fn with_max_delay(mut self, delay: Duration) -> Self {
124 self.max_delay = delay;
125 self
126 }
127
128 #[must_use]
130 pub const fn with_backoff_multiplier(mut self, multiplier: f64) -> Self {
131 self.backoff_multiplier = multiplier;
132 self
133 }
134
135 #[must_use]
137 pub const fn with_jitter(mut self, jitter: bool) -> Self {
138 self.jitter = jitter;
139 self
140 }
141
142 #[must_use]
144 pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
145 if attempt == 0 {
146 return Duration::ZERO;
147 }
148
149 let base_delay_ms = self.base_delay.as_millis() as f64;
150 let multiplier = self.backoff_multiplier.powi((attempt - 1) as i32);
151 let delay_ms = base_delay_ms * multiplier;
152
153 let delay = Duration::from_millis(delay_ms as u64).min(self.max_delay);
154
155 if self.jitter {
156 let jitter_factor = (rand::random::<f64>() - 0.5).mul_add(0.1, 1.0); let jittered_delay = delay.mul_f64(jitter_factor);
158 jittered_delay.min(self.max_delay)
159 } else {
160 delay
161 }
162 }
163}
164
165pub async fn retry_with_backoff<F, Fut, T, E>(
171 mut operation: F,
172 config: RetryConfig,
173 should_retry: impl Fn(&E) -> bool,
174) -> Result<T, E>
175where
176 F: FnMut() -> Fut,
177 Fut: Future<Output = Result<T, E>>,
178{
179 let mut last_error = None;
180
181 for attempt in 0..config.max_attempts {
182 match operation().await {
183 Ok(result) => return Ok(result),
184 Err(error) => {
185 if !should_retry(&error) || attempt + 1 >= config.max_attempts {
186 return Err(error);
187 }
188
189 let delay = config.delay_for_attempt(attempt + 1);
190 sleep(delay).await;
191 last_error = Some(error);
192 }
193 }
194 }
195
196 Err(last_error.expect("Retry loop ended without attempts - this is a bug in retry logic"))
199}
200
/// Phase of a [`CircuitBreaker`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CircuitState {
    /// Normal operation: calls run and consecutive failures are counted.
    Closed,
    /// Tripped: calls are rejected until the recovery timeout has elapsed.
    Open,
    /// Probing: calls run; enough consecutive successes close the breaker and
    /// any failure re-opens it.
    HalfOpen,
}
211
212#[derive(Debug)]
214pub struct CircuitBreaker {
215 state: parking_lot::Mutex<CircuitBreakerState>,
216 failure_threshold: usize,
217 recovery_timeout: Duration,
218 success_threshold: usize,
219}
220
221#[derive(Debug)]
222struct CircuitBreakerState {
223 state: CircuitState,
224 failure_count: usize,
225 success_count: usize,
226 last_failure_time: Option<std::time::Instant>,
227}
228
229impl CircuitBreaker {
230 #[must_use]
232 pub const fn new(failure_threshold: usize, recovery_timeout: Duration) -> Self {
233 Self {
234 state: parking_lot::Mutex::new(CircuitBreakerState {
235 state: CircuitState::Closed,
236 failure_count: 0,
237 success_count: 0,
238 last_failure_time: None,
239 }),
240 failure_threshold,
241 recovery_timeout,
242 success_threshold: 3,
243 }
244 }
245
246 pub async fn call<F, Fut, T, E>(&self, operation: F) -> Result<T, CircuitBreakerError<E>>
248 where
249 F: FnOnce() -> Fut,
250 Fut: Future<Output = Result<T, E>>,
251 {
252 if self.is_open() {
254 return Err(CircuitBreakerError::Open);
255 }
256
257 match operation().await {
259 Ok(result) => {
260 self.record_success();
261 Ok(result)
262 }
263 Err(error) => {
264 self.record_failure();
265 Err(CircuitBreakerError::Operation(error))
266 }
267 }
268 }
269
270 pub fn state(&self) -> CircuitState {
272 self.state.lock().state
273 }
274
275 fn is_open(&self) -> bool {
276 let mut state = self.state.lock();
277
278 match state.state {
279 CircuitState::Open => {
280 state.last_failure_time.is_none_or(|last_failure| {
282 if last_failure.elapsed() >= self.recovery_timeout {
283 state.state = CircuitState::HalfOpen;
284 state.success_count = 0;
285 false
286 } else {
287 true
288 }
289 })
290 }
291 _ => false,
292 }
293 }
294
295 fn record_success(&self) {
296 let mut state = self.state.lock();
297
298 match state.state {
299 CircuitState::Closed => {
300 state.failure_count = 0;
301 }
302 CircuitState::HalfOpen => {
303 state.success_count += 1;
304 if state.success_count >= self.success_threshold {
305 state.state = CircuitState::Closed;
306 state.failure_count = 0;
307 state.success_count = 0;
308 }
309 }
310 CircuitState::Open => {
311 }
313 }
314 }
315
316 fn record_failure(&self) {
317 let mut state = self.state.lock();
318
319 state.failure_count += 1;
320 state.last_failure_time = Some(std::time::Instant::now());
321
322 match state.state {
323 CircuitState::Closed => {
324 if state.failure_count >= self.failure_threshold {
325 state.state = CircuitState::Open;
326 }
327 }
328 CircuitState::HalfOpen => {
329 state.state = CircuitState::Open;
330 state.success_count = 0;
331 }
332 CircuitState::Open => {
333 }
335 }
336 }
337}
338
/// Error returned by [`CircuitBreaker::call`].
#[derive(Debug)]
pub enum CircuitBreakerError<E> {
    /// The breaker rejected the call without running the operation.
    Open,
    /// The operation ran and failed with this error.
    Operation(E),
}

impl<E: std::fmt::Display> std::fmt::Display for CircuitBreakerError<E> {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Operation(inner) => write!(formatter, "Operation failed: {inner}"),
            Self::Open => formatter.write_str("Circuit breaker is open"),
        }
    }
}

impl<E: std::error::Error + 'static> std::error::Error for CircuitBreakerError<E> {
    /// Exposes the operation's error as the cause; a rejected call has none.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Operation(inner) => Some(inner as &(dyn std::error::Error + 'static)),
            Self::Open => None,
        }
    }
}
371
/// Evaluates `$block`, measures how long it took, and returns the block's value.
///
/// When the `tracing` feature is enabled, the elapsed time is logged at debug
/// level as `"<name> took <elapsed>"`. Otherwise `$name` is not evaluated at all.
///
/// NOTE: `#[cfg(feature = "tracing")]` and `tracing::debug!` in an exported
/// macro are resolved in the crate that *invokes* the macro. The caller's
/// feature flags and dependencies control logging, not this crate's.
#[macro_export]
macro_rules! measure_time {
    ($name:expr, $block:block) => {{
        // Absolute `::std` path so a caller-side item named `std` cannot shadow it.
        let _start = ::std::time::Instant::now();
        let result = $block;
        let _elapsed = _start.elapsed();

        #[cfg(feature = "tracing")]
        tracing::debug!("{} took {:?}", $name, _elapsed);

        result
    }};
}
386
/// Compiles `$body` only when the cargo feature named `$flag` is enabled.
///
/// The three-argument form compiles `$enabled` when the feature is on and
/// `$disabled` when it is off. The expansion is made of `#[cfg]`-annotated
/// blocks, which are intended for statement position.
///
/// NOTE: the `cfg` is evaluated in the crate that invokes the macro, so it is
/// that crate's feature set that matters.
#[macro_export]
macro_rules! feature_gate {
    ($flag:expr, $body:block) => {
        #[cfg(feature = $flag)]
        $body
    };
    ($flag:expr, $enabled:block, $disabled:block) => {
        #[cfg(feature = $flag)]
        $enabled
        #[cfg(not(feature = $flag))]
        $disabled
    };
}
401
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[tokio::test]
    async fn test_timeout() {
        let result = timeout(Duration::from_millis(100), async { 42 }).await;
        assert_eq!(result.unwrap(), 42);

        // `Timeout` polls the inner future before its own timer. If the two
        // deadlines were close, a stalled scheduler could let the inner future
        // win. Keep the inner sleep far longer than the timeout; the test still
        // ends after ~10 ms because the timeout drops the inner future.
        let result = timeout(Duration::from_millis(10), async {
            sleep(Duration::from_secs(5)).await;
            42
        })
        .await;
        assert_eq!(result, Err(TimeoutError));
    }

    #[test]
    fn test_retry_config() {
        let config = RetryConfig::new()
            .with_max_attempts(5)
            .with_base_delay(Duration::from_millis(50))
            .with_jitter(false);

        assert_eq!(config.max_attempts, 5);
        assert_eq!(config.base_delay, Duration::from_millis(50));
        assert!(!config.jitter);

        assert_eq!(config.delay_for_attempt(0), Duration::ZERO);
        assert_eq!(config.delay_for_attempt(1), Duration::from_millis(50));
        assert_eq!(config.delay_for_attempt(2), Duration::from_millis(100));
    }

    #[tokio::test]
    async fn test_retry_with_backoff() {
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let config = RetryConfig::new()
            .with_max_attempts(3)
            .with_base_delay(Duration::from_millis(1))
            .with_jitter(false);

        let result = retry_with_backoff(
            move || {
                let count = counter_clone.fetch_add(1, Ordering::SeqCst);
                async move {
                    if count < 2 {
                        Err("fail")
                    } else {
                        Ok("success")
                    }
                }
            },
            config,
            |_| true,
        )
        .await;

        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_stops_on_non_retryable_error() {
        let calls = Arc::new(AtomicU32::new(0));
        let calls_clone = Arc::clone(&calls);

        let config = RetryConfig::new()
            .with_max_attempts(5)
            .with_base_delay(Duration::from_millis(1))
            .with_jitter(false);

        let result: Result<(), &str> = retry_with_backoff(
            move || {
                calls_clone.fetch_add(1, Ordering::SeqCst);
                async { Err("fatal") }
            },
            config,
            |_| false,
        )
        .await;

        assert_eq!(result, Err("fatal"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_circuit_breaker() {
        // A long recovery timeout keeps the breaker open for the rest of the
        // test; a short one could elapse between calls on a loaded machine and
        // let the third call through as a half-open trial.
        let cb = CircuitBreaker::new(2, Duration::from_secs(60));
        let counter = Arc::new(AtomicU32::new(0));

        let result = cb
            .call({
                let counter = Arc::clone(&counter);
                move || async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>("error")
                }
            })
            .await;
        assert!(matches!(result, Err(CircuitBreakerError::Operation(_))));
        assert_eq!(cb.state(), CircuitState::Closed);

        let result = cb
            .call({
                let counter = Arc::clone(&counter);
                move || async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>("error")
                }
            })
            .await;
        assert!(matches!(result, Err(CircuitBreakerError::Operation(_))));
        assert_eq!(cb.state(), CircuitState::Open);

        let result: Result<(), CircuitBreakerError<&str>> = cb
            .call({
                let counter = Arc::clone(&counter);
                move || async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok(())
                }
            })
            .await;
        assert!(matches!(result, Err(CircuitBreakerError::Open)));

        // The rejected call must not have run the operation.
        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_circuit_breaker_recovers_after_timeout() {
        let cb = CircuitBreaker::new(1, Duration::from_millis(1));

        let result = cb.call(|| async { Err::<(), _>("error") }).await;
        assert!(matches!(result, Err(CircuitBreakerError::Operation("error"))));
        assert_eq!(cb.state(), CircuitState::Open);

        // Comfortably longer than the 1 ms recovery timeout.
        sleep(Duration::from_millis(20)).await;

        // Three consecutive half-open successes (the default threshold) close it.
        for _ in 0..3 {
            let result: Result<(), CircuitBreakerError<&str>> =
                cb.call(|| async { Ok(()) }).await;
            assert!(result.is_ok());
        }
        assert_eq!(cb.state(), CircuitState::Closed);
    }

    #[test]
    fn test_measure_time_macro() {
        let result = measure_time!("test_operation", {
            std::thread::sleep(Duration::from_millis(1));
            42
        });
        assert_eq!(result, 42);
    }
}