Skip to main content

serdes_ai_retries/
executor.rs

1//! Retry executor for running operations with retries.
2
3use crate::config::RetryConfig;
4use crate::error::{RetryResult, RetryableError};
5use std::future::Future;
6use std::time::Duration;
7use tokio::time::sleep;
8use tracing::{debug, warn};
9
/// Bookkeeping for a retried operation, as returned by [`with_retry_state`].
///
/// `Default` yields the "nothing has run yet" state: zero attempts, no error,
/// zero accumulated wait and an empty history.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RetryState {
    /// Number of attempts started so far (1-indexed; `0` before the first attempt).
    pub attempt: u32,
    /// Display text of the most recently recorded failure, if any.
    pub last_error: Option<String>,
    /// Sum of the backoff delays slept between attempts.
    pub total_wait_time: Duration,
    /// One entry per recorded attempt, in the order the attempts ran.
    pub history: Vec<AttemptInfo>,
}

/// Outcome of a single attempt.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AttemptInfo {
    /// 1-indexed attempt number.
    pub attempt: u32,
    /// Whether the operation returned `Ok` on this attempt.
    pub success: bool,
    /// Display text of the error when the attempt failed; `None` on success.
    pub error: Option<String>,
    /// Backoff delay slept *after* this attempt, before the next one.
    ///
    /// Zero when no retry followed this attempt.
    pub wait_time: Duration,
}
46
47/// Execute an operation with retries.
48///
49/// # Example
50///
51/// ```ignore
52/// use serdes_ai_retries::{with_retry, RetryConfig};
53///
54/// let config = RetryConfig::for_api();
55/// let result = with_retry(&config, || async {
56///     // Your async operation here
57///     Ok("success")
58/// }).await?;
59/// ```
60pub async fn with_retry<F, Fut, T>(config: &RetryConfig, operation: F) -> RetryResult<T>
61where
62    F: Fn() -> Fut,
63    Fut: Future<Output = RetryResult<T>>,
64{
65    let mut state = RetryState::default();
66    let max_attempts = config.max_retries.saturating_add(1);
67
68    loop {
69        state.attempt += 1;
70
71        debug!(
72            attempt = state.attempt,
73            max_attempts,
74            max_retries = config.max_retries,
75            "Executing retry attempt"
76        );
77
78        match operation().await {
79            Ok(result) => {
80                state.history.push(AttemptInfo {
81                    attempt: state.attempt,
82                    success: true,
83                    error: None,
84                    wait_time: Duration::ZERO,
85                });
86                return Ok(result);
87            }
88            Err(error) => {
89                let should_retry =
90                    state.attempt < max_attempts && config.retry_on.should_retry(&error);
91
92                if !should_retry {
93                    warn!(
94                        attempt = state.attempt,
95                        error = %error,
96                        "Retry exhausted or error not retryable"
97                    );
98                    return Err(error);
99                }
100
101                let wait = config.wait.calculate(state.attempt, error.retry_after());
102                state.total_wait_time += wait;
103                state.last_error = Some(format!("{}", error));
104
105                state.history.push(AttemptInfo {
106                    attempt: state.attempt,
107                    success: false,
108                    error: Some(format!("{}", error)),
109                    wait_time: wait,
110                });
111
112                debug!(
113                    attempt = state.attempt,
114                    wait_ms = wait.as_millis(),
115                    error = %error,
116                    "Waiting before retry"
117                );
118
119                sleep(wait).await;
120            }
121        }
122    }
123}
124
125/// Execute with retries and get state information.
126pub async fn with_retry_state<F, Fut, T>(
127    config: &RetryConfig,
128    operation: F,
129) -> (RetryResult<T>, RetryState)
130where
131    F: Fn() -> Fut,
132    Fut: Future<Output = RetryResult<T>>,
133{
134    let mut state = RetryState::default();
135    let max_attempts = config.max_retries.saturating_add(1);
136
137    loop {
138        state.attempt += 1;
139
140        match operation().await {
141            Ok(result) => {
142                state.history.push(AttemptInfo {
143                    attempt: state.attempt,
144                    success: true,
145                    error: None,
146                    wait_time: Duration::ZERO,
147                });
148                return (Ok(result), state);
149            }
150            Err(error) => {
151                let should_retry =
152                    state.attempt < max_attempts && config.retry_on.should_retry(&error);
153
154                if !should_retry {
155                    return (Err(error), state);
156                }
157
158                let wait = config.wait.calculate(state.attempt, error.retry_after());
159                state.total_wait_time += wait;
160                state.last_error = Some(format!("{}", error));
161
162                state.history.push(AttemptInfo {
163                    attempt: state.attempt,
164                    success: false,
165                    error: Some(format!("{}", error)),
166                    wait_time: wait,
167                });
168
169                sleep(wait).await;
170            }
171        }
172    }
173}
174
175/// Builder for retry operations.
176pub struct Retry<'a> {
177    config: &'a RetryConfig,
178}
179
180impl<'a> Retry<'a> {
181    /// Create a new retry builder.
182    pub fn new(config: &'a RetryConfig) -> Self {
183        Self { config }
184    }
185
186    /// Run the operation with retries.
187    pub async fn run<F, Fut, T>(self, operation: F) -> RetryResult<T>
188    where
189        F: Fn() -> Fut,
190        Fut: Future<Output = RetryResult<T>>,
191    {
192        with_retry(self.config, operation).await
193    }
194
195    /// Run and get state.
196    pub async fn run_with_state<F, Fut, T>(self, operation: F) -> (RetryResult<T>, RetryState)
197    where
198        F: Fn() -> Fut,
199        Fut: Future<Output = RetryResult<T>>,
200    {
201        with_retry_state(self.config, operation).await
202    }
203}
204
205/// Wrap a result type for retry compatibility.
206pub trait IntoRetryable<T> {
207    /// Convert into a retryable result.
208    fn into_retryable(self) -> RetryResult<T>;
209}
210
211impl<T, E: Into<RetryableError>> IntoRetryable<T> for Result<T, E> {
212    fn into_retryable(self) -> RetryResult<T> {
213        self.map_err(Into::into)
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[tokio::test]
    async fn test_with_retry_immediate_success() {
        let config = RetryConfig::new().max_retries(3);
        let result = with_retry(&config, || async { Ok::<_, RetryableError>(42) }).await;
        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_with_retry_eventual_success() {
        let config = RetryConfig::new()
            .max_retries(3)
            .fixed(Duration::from_millis(1));

        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_clone = attempts.clone();

        let result = with_retry(&config, || {
            let attempts = attempts_clone.clone();
            async move {
                let n = attempts.fetch_add(1, Ordering::SeqCst);
                if n < 2 {
                    Err(RetryableError::http(500, "server error"))
                } else {
                    Ok(42)
                }
            }
        })
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_with_retry_exhausted() {
        let config = RetryConfig::new()
            .max_retries(2)
            .fixed(Duration::from_millis(1));

        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_clone = attempts.clone();

        let result = with_retry(&config, || {
            let attempts = attempts_clone.clone();
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(RetryableError::http(500, "always fails"))
            }
        })
        .await;

        assert!(result.is_err());
        // One initial attempt plus two retries.
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_with_retry_non_retryable() {
        let config = RetryConfig::new().max_retries(3);

        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_clone = attempts.clone();

        let result = with_retry(&config, || {
            let attempts = attempts_clone.clone();
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(RetryableError::http(400, "bad request"))
            }
        })
        .await;

        assert!(result.is_err());
        // Should only try once since 400 is not retryable
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_state() {
        let config = RetryConfig::new()
            .max_retries(3)
            .fixed(Duration::from_millis(1));

        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_clone = attempts.clone();

        let (result, state) = with_retry_state(&config, || {
            let attempts = attempts_clone.clone();
            async move {
                let n = attempts.fetch_add(1, Ordering::SeqCst);
                if n < 1 {
                    Err(RetryableError::http(500, "error"))
                } else {
                    Ok(42)
                }
            }
        })
        .await;

        assert!(result.is_ok());
        assert_eq!(state.attempt, 2);
        assert_eq!(state.history.len(), 2);
        assert!(!state.history[0].success);
        assert!(state.history[1].success);
    }

    #[tokio::test]
    async fn test_retry_state_exhausted() {
        let config = RetryConfig::new()
            .max_retries(2)
            .fixed(Duration::from_millis(1));

        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_clone = attempts.clone();

        let (result, state) = with_retry_state(&config, || {
            let attempts = attempts_clone.clone();
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(RetryableError::http(500, "always fails"))
            }
        })
        .await;

        assert!(result.is_err());
        // One initial attempt plus two retries, and the state agrees.
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
        assert_eq!(state.attempt, 3);
        assert!(state.last_error.is_some());
    }

    #[tokio::test]
    async fn test_retry_builder() {
        let config = RetryConfig::new();
        let result = Retry::new(&config)
            .run(|| async { Ok::<_, RetryableError>("success") })
            .await;

        assert_eq!(result.unwrap(), "success");
    }
}