Skip to main content

sediment/
retry.rs

1//! Retry utilities with exponential backoff
2//!
3//! Provides generic retry logic for transient failures in async operations.
4
5use std::fmt::Display;
6use std::future::Future;
7use std::time::Duration;
8
9use tokio::time::sleep;
10use tracing::{debug, warn};
11
/// Attempt budget used when the caller does not configure one
const DEFAULT_MAX_ATTEMPTS: u32 = 3;

/// Wait before the first retry, in milliseconds, when not configured
const DEFAULT_INITIAL_DELAY_MS: u64 = 100;

/// Ceiling on the wait between retries, in milliseconds, when not configured
const DEFAULT_MAX_DELAY_MS: u64 = 2000;

/// Tunable parameters that control how an operation is retried
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Total attempt budget; the initial try counts towards it
    pub max_attempts: u32,
    /// Wait before the first retry, in milliseconds
    pub initial_delay_ms: u64,
    /// Ceiling, in milliseconds, that the exponentially growing wait never exceeds
    pub max_delay_ms: u64,
}

impl Default for RetryConfig {
    /// Build a configuration from the module-level `DEFAULT_*` constants
    fn default() -> Self {
        Self::new(DEFAULT_MAX_ATTEMPTS, DEFAULT_INITIAL_DELAY_MS, DEFAULT_MAX_DELAY_MS)
    }
}

impl RetryConfig {
    /// Build a retry configuration from explicit values
    ///
    /// The values are stored as given; no validation is performed here.
    pub fn new(max_attempts: u32, initial_delay_ms: u64, max_delay_ms: u64) -> Self {
        Self {
            max_attempts,
            initial_delay_ms,
            max_delay_ms,
        }
    }

    /// Compute how long to wait after the given failed attempt (0-indexed)
    ///
    /// The wait is `initial_delay_ms * 2^attempt`, saturating instead of
    /// overflowing, and never more than `max_delay_ms`.
    fn delay_for_attempt(&self, attempt: u32) -> Duration {
        // 2^attempt, pinned to u64::MAX once the shift would leave the u64 range.
        let multiplier = if attempt < u64::BITS {
            1u64 << attempt
        } else {
            u64::MAX
        };
        let uncapped_ms = self.initial_delay_ms.saturating_mul(multiplier);
        Duration::from_millis(uncapped_ms.min(self.max_delay_ms))
    }
}
61
62/// Execute an async operation with exponential backoff retry.
63///
64/// The operation is retried up to `config.max_attempts` times on failure.
65/// The delay between retries grows exponentially, starting at `initial_delay_ms`
66/// and capped at `max_delay_ms`.
67///
68/// # Arguments
69///
70/// * `config` - Retry configuration
71/// * `operation` - A closure that returns a Future yielding Result<T, E>
72///
73/// # Returns
74///
75/// The result of the successful operation, or the last error if all attempts fail.
76///
77/// # Example
78///
79/// ```ignore
80/// use sediment::retry::{with_retry, RetryConfig};
81///
82/// let result = with_retry(&RetryConfig::default(), || async {
83///     // Your fallible async operation here
84///     Ok::<_, String>("success")
85/// }).await;
86/// ```
87pub async fn with_retry<T, E, F, Fut>(config: &RetryConfig, operation: F) -> Result<T, E>
88where
89    F: Fn() -> Fut,
90    Fut: Future<Output = Result<T, E>>,
91    E: Display,
92{
93    let mut last_error: Option<E> = None;
94
95    for attempt in 0..config.max_attempts {
96        match operation().await {
97            Ok(result) => {
98                if attempt > 0 {
99                    debug!("Operation succeeded on attempt {}", attempt + 1);
100                }
101                return Ok(result);
102            }
103            Err(e) => {
104                let is_last_attempt = attempt + 1 >= config.max_attempts;
105
106                if is_last_attempt {
107                    warn!(
108                        "Operation failed after {} attempts: {}",
109                        config.max_attempts, e
110                    );
111                    last_error = Some(e);
112                } else {
113                    let delay = config.delay_for_attempt(attempt);
114                    warn!(
115                        "Operation failed (attempt {}/{}): {}. Retrying in {:?}...",
116                        attempt + 1,
117                        config.max_attempts,
118                        e,
119                        delay
120                    );
121                    sleep(delay).await;
122                    last_error = Some(e);
123                }
124            }
125        }
126    }
127
128    // Return the last error
129    Err(last_error.expect("at least one attempt should have been made"))
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// An operation that succeeds straight away is returned unchanged.
    #[tokio::test]
    async fn test_retry_success_first_attempt() {
        let outcome: Result<&str, &str> =
            with_retry(&RetryConfig::default(), || async { Ok("success") }).await;
        assert_eq!(outcome, Ok("success"));
    }

    /// Two transient failures followed by a success use all three attempts.
    #[tokio::test]
    async fn test_retry_success_after_failures() {
        // Millisecond-scale delays keep the test fast.
        let config = RetryConfig::new(3, 10, 100);
        let calls = Arc::new(AtomicU32::new(0));

        let outcome: Result<&str, &str> = with_retry(&config, || {
            let calls = Arc::clone(&calls);
            async move {
                // `fetch_add` yields the count *before* this call.
                match calls.fetch_add(1, Ordering::SeqCst) {
                    0 | 1 => Err("transient error"),
                    _ => Ok("success"),
                }
            }
        })
        .await;

        assert_eq!(outcome, Ok("success"));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// A persistently failing operation is tried exactly `max_attempts` times
    /// and its error is handed back to the caller.
    #[tokio::test]
    async fn test_retry_all_failures() {
        // Millisecond-scale delays keep the test fast.
        let config = RetryConfig::new(3, 10, 100);
        let calls = Arc::new(AtomicU32::new(0));

        let outcome: Result<&str, &str> = with_retry(&config, || {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err("persistent error")
            }
        })
        .await;

        assert_eq!(outcome, Err("persistent error"));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// Regression test for bug #10: attempt numbers at or beyond the width of
    /// `u64` must not panic and must still honour the configured cap.
    #[test]
    fn test_delay_for_attempt_no_overflow() {
        let config = RetryConfig::new(100, 100, 2000);
        let ceiling = Duration::from_millis(2000);

        for &attempt in &[64u32, 99] {
            assert_eq!(config.delay_for_attempt(attempt), ceiling);
        }
    }

    /// The delay doubles per attempt until it reaches `max_delay_ms`.
    #[test]
    fn test_delay_calculation() {
        let config = RetryConfig::new(5, 100, 1000);
        // Expected delay for attempts 0..=4; the last entry is the cap.
        let expected_ms = [100, 200, 400, 800, 1000];

        for (attempt, &ms) in (0u32..).zip(expected_ms.iter()) {
            assert_eq!(config.delay_for_attempt(attempt), Duration::from_millis(ms));
        }
    }
}