1use std::future::Future;
10use std::time::Duration;
11
/// Exponential-backoff settings consumed by the retry helpers in this module.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Retries allowed after the first attempt; `0` means the operation runs once.
    pub max_retries: u32,
    /// Delay used before the first retry; later retries scale it by `backoff_multiplier`.
    pub initial_delay: Duration,
    /// Ceiling applied to the computed backoff delay (jitter, when enabled, is applied on top).
    pub max_delay: Duration,
    /// Growth factor applied once per retry to `initial_delay`.
    pub backoff_multiplier: f64,
    /// When `true`, each computed delay is randomly stretched.
    pub jitter: bool,
}
26
27impl Default for RetryConfig {
28 fn default() -> Self {
29 Self {
30 max_retries: 3,
31 initial_delay: Duration::from_millis(100),
32 max_delay: Duration::from_secs(10),
33 backoff_multiplier: 2.0,
34 jitter: true,
35 }
36 }
37}
38
39impl RetryConfig {
40 #[must_use]
42 pub fn new(max_retries: u32) -> Self {
43 Self {
44 max_retries,
45 ..Default::default()
46 }
47 }
48
49 #[must_use]
51 pub fn none() -> Self {
52 Self {
53 max_retries: 0,
54 ..Default::default()
55 }
56 }
57
58 #[must_use]
60 pub fn quick() -> Self {
61 Self {
62 max_retries: 2,
63 initial_delay: Duration::from_millis(50),
64 max_delay: Duration::from_secs(1),
65 backoff_multiplier: 2.0,
66 jitter: true,
67 }
68 }
69
70 #[must_use]
72 pub fn batch() -> Self {
73 Self {
74 max_retries: 5,
75 initial_delay: Duration::from_millis(200),
76 max_delay: Duration::from_secs(30),
77 backoff_multiplier: 2.0,
78 jitter: true,
79 }
80 }
81
82 #[must_use]
84 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
85 self.max_retries = max_retries;
86 self
87 }
88
89 #[must_use]
91 pub fn with_initial_delay(mut self, delay: Duration) -> Self {
92 self.initial_delay = delay;
93 self
94 }
95
96 #[must_use]
98 pub fn with_max_delay(mut self, delay: Duration) -> Self {
99 self.max_delay = delay;
100 self
101 }
102
103 #[must_use]
105 pub fn with_backoff_multiplier(mut self, multiplier: f64) -> Self {
106 self.backoff_multiplier = multiplier;
107 self
108 }
109
110 #[must_use]
112 pub fn with_jitter(mut self, jitter: bool) -> Self {
113 self.jitter = jitter;
114 self
115 }
116
117 fn delay_for_attempt(&self, attempt: u32) -> Duration {
119 let base_delay =
120 self.initial_delay.as_millis() as f64 * self.backoff_multiplier.powi(attempt as i32);
121 let capped_delay = base_delay.min(self.max_delay.as_millis() as f64);
122
123 let final_delay = if self.jitter {
124 let jitter_factor = 1.0 + (random_f64() * 0.25);
126 capped_delay * jitter_factor
127 } else {
128 capped_delay
129 };
130
131 Duration::from_millis(final_delay as u64)
132 }
133}
134
135fn random_f64() -> f64 {
137 fastrand::f64()
138}
139
/// Lets an error type tell `with_retry` whether, and how soon, to try again.
pub trait RetryableError {
    /// Whether `with_retry` should attempt the operation again after this error.
    fn is_retryable(&self) -> bool;

    /// An explicit wait before the next attempt.
    ///
    /// When this returns `Some`, `with_retry` sleeps for exactly that long instead
    /// of the configured backoff delay (it is not capped by `max_delay`). The
    /// default implementation provides no hint.
    fn retry_after(&self) -> Option<Duration> {
        Option::None
    }
}
150
/// The failure reported by `with_retry` once it stops retrying: the error from
/// the last attempt together with the total number of attempts made.
#[derive(Debug)]
pub struct RetryError<E> {
    /// Error produced by the final attempt.
    pub error: E,
    /// Total attempts made, counting the first one.
    pub attempts: u32,
}
159
160impl<E: std::fmt::Display> std::fmt::Display for RetryError<E> {
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 write!(
163 f,
164 "{} (after {} attempt{})",
165 self.error,
166 self.attempts,
167 if self.attempts == 1 { "" } else { "s" }
168 )
169 }
170}
171
172impl<E: std::error::Error + 'static> std::error::Error for RetryError<E> {
173 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
174 Some(&self.error)
175 }
176}
177
178impl<E> RetryError<E> {
179 pub fn into_inner(self) -> E {
181 self.error
182 }
183}
184
185pub async fn with_retry<T, E, F, Fut>(
220 config: &RetryConfig,
221 mut operation: F,
222) -> Result<T, RetryError<E>>
223where
224 F: FnMut() -> Fut,
225 Fut: Future<Output = Result<T, E>>,
226 E: RetryableError,
227{
228 let mut attempts = 0;
229 let max_attempts = config.max_retries + 1;
230
231 loop {
232 attempts += 1;
233
234 match operation().await {
235 Ok(result) => return Ok(result),
236 Err(e) => {
237 if attempts >= max_attempts || !e.is_retryable() {
238 return Err(RetryError { error: e, attempts });
239 }
240
241 let delay = e
243 .retry_after()
244 .unwrap_or_else(|| config.delay_for_attempt(attempts - 1));
245 tokio::time::sleep(delay).await;
246 }
247 }
248 }
249}
250
251pub async fn with_simple_retry<T, E, F, Fut>(max_retries: u32, mut operation: F) -> Result<T, E>
270where
271 F: FnMut() -> Fut,
272 Fut: Future<Output = Result<T, E>>,
273{
274 let config = RetryConfig::new(max_retries);
275 let mut attempts = 0;
276 let max_attempts = config.max_retries + 1;
277
278 loop {
279 attempts += 1;
280
281 match operation().await {
282 Ok(result) => return Ok(result),
283 Err(e) => {
284 if attempts >= max_attempts {
285 return Err(e);
286 }
287
288 let delay = config.delay_for_attempt(attempts - 1);
289 tokio::time::sleep(delay).await;
290 }
291 }
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic (jitter-free) doubling configuration for delay tests.
    fn doubling_without_jitter(
        max_retries: u32,
        initial_delay: Duration,
        max_delay: Duration,
    ) -> RetryConfig {
        RetryConfig {
            max_retries,
            initial_delay,
            max_delay,
            backoff_multiplier: 2.0,
            jitter: false,
        }
    }

    #[test]
    fn test_delay_calculation() {
        let config =
            doubling_without_jitter(5, Duration::from_millis(100), Duration::from_secs(10));

        // Each attempt doubles the previous delay while it stays under the ceiling.
        let expected_ms = [100, 200, 400, 800];
        for (attempt, ms) in (0u32..).zip(expected_ms.iter().copied()) {
            assert_eq!(
                config.delay_for_attempt(attempt),
                Duration::from_millis(ms),
                "attempt {}",
                attempt
            );
        }
    }

    #[test]
    fn test_delay_cap() {
        let config = doubling_without_jitter(10, Duration::from_secs(1), Duration::from_secs(5));

        // 1s * 2^5 = 32s and 1s * 2^10 = 1024s both exceed the 5s ceiling.
        for attempt in [5u32, 10].iter().copied() {
            assert_eq!(
                config.delay_for_attempt(attempt),
                Duration::from_secs(5),
                "attempt {}",
                attempt
            );
        }
    }

    #[test]
    fn test_presets() {
        let quick = RetryConfig::quick();
        assert_eq!(
            (quick.max_retries, quick.initial_delay),
            (2, Duration::from_millis(50))
        );

        let batch = RetryConfig::batch();
        assert_eq!(
            (batch.max_retries, batch.initial_delay),
            (5, Duration::from_millis(200))
        );

        assert_eq!(RetryConfig::none().max_retries, 0);
    }
}