1use crate::config::RetryConfig;
4use crate::error::{RetryResult, RetryableError};
5use std::future::Future;
6use std::time::Duration;
7use tokio::time::sleep;
8use tracing::{debug, warn};
9
10#[derive(Debug, Clone)]
12pub struct RetryState {
13 pub attempt: u32,
15 pub last_error: Option<String>,
17 pub total_wait_time: Duration,
19 pub history: Vec<AttemptInfo>,
21}
22
23impl Default for RetryState {
24 fn default() -> Self {
25 Self {
26 attempt: 0,
27 last_error: None,
28 total_wait_time: Duration::ZERO,
29 history: Vec::new(),
30 }
31 }
32}
33
/// Record of a single invocation of the retried operation.
#[derive(Debug, Clone)]
pub struct AttemptInfo {
    /// 1-based index of the attempt this entry describes.
    pub attempt: u32,
    /// Whether the operation returned `Ok` on this attempt.
    pub success: bool,
    /// Display text of the error for a failed attempt; `None` on success.
    pub error: Option<String>,
    /// Delay scheduled after this attempt before the next one
    /// (zero when no further attempt followed).
    pub wait_time: Duration,
}
46
47pub async fn with_retry<F, Fut, T>(config: &RetryConfig, operation: F) -> RetryResult<T>
61where
62 F: Fn() -> Fut,
63 Fut: Future<Output = RetryResult<T>>,
64{
65 let mut state = RetryState::default();
66 let max_attempts = config.max_retries.saturating_add(1);
67
68 loop {
69 state.attempt += 1;
70
71 debug!(
72 attempt = state.attempt,
73 max_attempts,
74 max_retries = config.max_retries,
75 "Executing retry attempt"
76 );
77
78 match operation().await {
79 Ok(result) => {
80 state.history.push(AttemptInfo {
81 attempt: state.attempt,
82 success: true,
83 error: None,
84 wait_time: Duration::ZERO,
85 });
86 return Ok(result);
87 }
88 Err(error) => {
89 let should_retry =
90 state.attempt < max_attempts && config.retry_on.should_retry(&error);
91
92 if !should_retry {
93 warn!(
94 attempt = state.attempt,
95 error = %error,
96 "Retry exhausted or error not retryable"
97 );
98 return Err(error);
99 }
100
101 let wait = config.wait.calculate(state.attempt, error.retry_after());
102 state.total_wait_time += wait;
103 state.last_error = Some(format!("{}", error));
104
105 state.history.push(AttemptInfo {
106 attempt: state.attempt,
107 success: false,
108 error: Some(format!("{}", error)),
109 wait_time: wait,
110 });
111
112 debug!(
113 attempt = state.attempt,
114 wait_ms = wait.as_millis(),
115 error = %error,
116 "Waiting before retry"
117 );
118
119 sleep(wait).await;
120 }
121 }
122 }
123}
124
125pub async fn with_retry_state<F, Fut, T>(
127 config: &RetryConfig,
128 operation: F,
129) -> (RetryResult<T>, RetryState)
130where
131 F: Fn() -> Fut,
132 Fut: Future<Output = RetryResult<T>>,
133{
134 let mut state = RetryState::default();
135 let max_attempts = config.max_retries.saturating_add(1);
136
137 loop {
138 state.attempt += 1;
139
140 match operation().await {
141 Ok(result) => {
142 state.history.push(AttemptInfo {
143 attempt: state.attempt,
144 success: true,
145 error: None,
146 wait_time: Duration::ZERO,
147 });
148 return (Ok(result), state);
149 }
150 Err(error) => {
151 let should_retry =
152 state.attempt < max_attempts && config.retry_on.should_retry(&error);
153
154 if !should_retry {
155 return (Err(error), state);
156 }
157
158 let wait = config.wait.calculate(state.attempt, error.retry_after());
159 state.total_wait_time += wait;
160 state.last_error = Some(format!("{}", error));
161
162 state.history.push(AttemptInfo {
163 attempt: state.attempt,
164 success: false,
165 error: Some(format!("{}", error)),
166 wait_time: wait,
167 });
168
169 sleep(wait).await;
170 }
171 }
172 }
173}
174
175pub struct Retry<'a> {
177 config: &'a RetryConfig,
178}
179
180impl<'a> Retry<'a> {
181 pub fn new(config: &'a RetryConfig) -> Self {
183 Self { config }
184 }
185
186 pub async fn run<F, Fut, T>(self, operation: F) -> RetryResult<T>
188 where
189 F: Fn() -> Fut,
190 Fut: Future<Output = RetryResult<T>>,
191 {
192 with_retry(self.config, operation).await
193 }
194
195 pub async fn run_with_state<F, Fut, T>(self, operation: F) -> (RetryResult<T>, RetryState)
197 where
198 F: Fn() -> Fut,
199 Fut: Future<Output = RetryResult<T>>,
200 {
201 with_retry_state(self.config, operation).await
202 }
203}
204
205pub trait IntoRetryable<T> {
207 fn into_retryable(self) -> RetryResult<T>;
209}
210
211impl<T, E: Into<RetryableError>> IntoRetryable<T> for Result<T, E> {
212 fn into_retryable(self) -> RetryResult<T> {
213 self.map_err(Into::into)
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[tokio::test]
    async fn test_with_retry_immediate_success() {
        let config = RetryConfig::new().max_retries(3);
        let result = with_retry(&config, || async { Ok::<_, RetryableError>(42) }).await;
        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_with_retry_eventual_success() {
        let config = RetryConfig::new()
            .max_retries(3)
            .fixed(Duration::from_millis(1));

        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_clone = attempts.clone();

        let result = with_retry(&config, || {
            let attempts = attempts_clone.clone();
            async move {
                let n = attempts.fetch_add(1, Ordering::SeqCst);
                if n < 2 {
                    Err(RetryableError::http(500, "server error"))
                } else {
                    Ok(42)
                }
            }
        })
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_with_retry_exhausted() {
        let config = RetryConfig::new()
            .max_retries(2)
            .fixed(Duration::from_millis(1));

        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_clone = attempts.clone();

        let result = with_retry(&config, || {
            let attempts = attempts_clone.clone();
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(RetryableError::http(500, "always fails"))
            }
        })
        .await;

        assert!(result.is_err());
        // One initial attempt plus `max_retries` retries.
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_with_retry_zero_retries_runs_once() {
        let config = RetryConfig::new()
            .max_retries(0)
            .fixed(Duration::from_millis(1));

        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_clone = attempts.clone();

        let result = with_retry(&config, || {
            let attempts = attempts_clone.clone();
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(RetryableError::http(500, "always fails"))
            }
        })
        .await;

        assert!(result.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_with_retry_non_retryable() {
        let config = RetryConfig::new().max_retries(3);

        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_clone = attempts.clone();

        let result = with_retry(&config, || {
            let attempts = attempts_clone.clone();
            async move {
                attempts.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(RetryableError::http(400, "bad request"))
            }
        })
        .await;

        assert!(result.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_state() {
        let config = RetryConfig::new()
            .max_retries(3)
            .fixed(Duration::from_millis(1));

        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_clone = attempts.clone();

        let (result, state) = with_retry_state(&config, || {
            let attempts = attempts_clone.clone();
            async move {
                let n = attempts.fetch_add(1, Ordering::SeqCst);
                if n < 1 {
                    Err(RetryableError::http(500, "error"))
                } else {
                    Ok(42)
                }
            }
        })
        .await;

        assert!(result.is_ok());
        assert_eq!(state.attempt, 2);
        assert_eq!(state.history.len(), 2);
        assert!(!state.history[0].success);
        assert!(state.history[1].success);
        assert!(state.last_error.is_some());
        // Only one wait happened, so the total equals that single delay.
        assert_eq!(state.total_wait_time, state.history[0].wait_time);
    }

    #[tokio::test]
    async fn test_retry_state_exhausted() {
        let config = RetryConfig::new()
            .max_retries(2)
            .fixed(Duration::from_millis(1));

        let (result, state) = with_retry_state(&config, || async {
            Err::<i32, _>(RetryableError::http(500, "always fails"))
        })
        .await;

        assert!(result.is_err());
        assert_eq!(state.attempt, 3);
        assert!(state.last_error.is_some());
    }

    #[tokio::test]
    async fn test_retry_builder() {
        let config = RetryConfig::new();
        let result = Retry::new(&config)
            .run(|| async { Ok::<_, RetryableError>("success") })
            .await;

        assert_eq!(result.unwrap(), "success");
    }

    #[tokio::test]
    async fn test_retry_builder_with_state() {
        let config = RetryConfig::new();
        let (result, state) = Retry::new(&config)
            .run_with_state(|| async { Ok::<_, RetryableError>("success") })
            .await;

        assert_eq!(result.unwrap(), "success");
        assert_eq!(state.attempt, 1);
        assert_eq!(state.history.len(), 1);
        assert!(state.history[0].success);
        assert!(state.last_error.is_none());
        assert_eq!(state.total_wait_time, Duration::ZERO);
    }
}