1use crate::error::RetryableError;
4use std::time::Duration;
5
6#[derive(Debug, Clone)]
8pub struct RetryConfig {
9 pub max_retries: u32,
11 pub wait: WaitStrategy,
13 pub retry_on: RetryCondition,
15 pub reraise: bool,
17}
18
19impl Default for RetryConfig {
20 fn default() -> Self {
21 Self {
22 max_retries: 3,
23 wait: WaitStrategy::ExponentialBackoff {
24 initial: Duration::from_millis(500),
25 max: Duration::from_secs(60),
26 multiplier: 2.0,
27 },
28 retry_on: RetryCondition::default(),
29 reraise: true,
30 }
31 }
32}
33
34impl RetryConfig {
35 pub fn new() -> Self {
37 Self::default()
38 }
39
40 pub fn max_retries(mut self, n: u32) -> Self {
42 self.max_retries = n;
43 self
44 }
45
46 pub fn wait(mut self, strategy: WaitStrategy) -> Self {
48 self.wait = strategy;
49 self
50 }
51
52 pub fn exponential(mut self, initial: Duration, max: Duration) -> Self {
54 self.wait = WaitStrategy::ExponentialBackoff {
55 initial,
56 max,
57 multiplier: 2.0,
58 };
59 self
60 }
61
62 pub fn exponential_jitter(mut self, initial: Duration, max: Duration, jitter: f64) -> Self {
64 self.wait = WaitStrategy::ExponentialJitter {
65 initial,
66 max,
67 multiplier: 2.0,
68 jitter,
69 };
70 self
71 }
72
73 pub fn fixed(mut self, delay: Duration) -> Self {
75 self.wait = WaitStrategy::Fixed(delay);
76 self
77 }
78
79 pub fn linear(mut self, initial: Duration, increment: Duration, max: Duration) -> Self {
81 self.wait = WaitStrategy::Linear {
82 initial,
83 increment,
84 max,
85 };
86 self
87 }
88
89 pub fn retry_on(mut self, condition: RetryCondition) -> Self {
91 self.retry_on = condition;
92 self
93 }
94
95 pub fn reraise(mut self, reraise: bool) -> Self {
97 self.reraise = reraise;
98 self
99 }
100
101 pub fn for_api() -> Self {
103 Self::new()
104 .max_retries(3)
105 .exponential_jitter(Duration::from_millis(500), Duration::from_secs(60), 0.1)
106 .retry_on(RetryCondition::new().on_rate_limit().on_server_errors())
107 }
108
109 pub fn no_retry() -> Self {
111 Self::new().max_retries(0)
112 }
113}
114
/// How long to wait before a retry; evaluated by `WaitStrategy::calculate`.
///
/// `PartialEq` is derived so strategies can be compared directly, for example
/// when asserting on a built configuration. `Eq` is not available because some
/// variants carry `f64` fields.
#[derive(Debug, Clone, PartialEq)]
pub enum WaitStrategy {
    /// No delay at all.
    None,
    /// The same delay for every attempt.
    Fixed(Duration),
    /// `initial * multiplier^(attempt - 1)`, capped at `max`.
    ExponentialBackoff {
        /// Delay used for attempt 1.
        initial: Duration,
        /// Upper bound on the computed delay.
        max: Duration,
        /// Growth factor applied once per additional attempt.
        multiplier: f64,
    },
    /// Like [`WaitStrategy::ExponentialBackoff`], plus a random offset
    /// proportional to the un-jittered delay.
    ExponentialJitter {
        /// Delay used for attempt 1, before jitter.
        initial: Duration,
        /// Upper bound on the computed delay, applied after jitter.
        max: Duration,
        /// Growth factor applied once per additional attempt.
        multiplier: f64,
        /// Scale of the random offset: the offset is the base delay times
        /// `jitter` times a random factor.
        jitter: f64,
    },
    /// `initial + increment * (attempt - 1)`, capped at `max`.
    Linear {
        /// Delay used for attempt 1.
        initial: Duration,
        /// Amount added for each additional attempt.
        increment: Duration,
        /// Upper bound on the computed delay.
        max: Duration,
    },
    /// Uses the retry-after duration supplied to `calculate` when there is
    /// one, otherwise delegates to `fallback`.
    RetryAfter {
        /// Strategy used when no retry-after duration is supplied.
        fallback: Box<WaitStrategy>,
        /// Upper bound applied to a supplied retry-after duration.
        max_wait: Duration,
    },
}
159
160impl WaitStrategy {
161 pub fn calculate(&self, attempt: u32, retry_after: Option<Duration>) -> Duration {
163 match self {
164 WaitStrategy::None => Duration::ZERO,
165 WaitStrategy::Fixed(d) => *d,
166 WaitStrategy::ExponentialBackoff {
167 initial,
168 max,
169 multiplier,
170 } => {
171 let delay = initial.as_secs_f64() * multiplier.powi(attempt as i32 - 1);
172 Duration::from_secs_f64(delay.min(max.as_secs_f64()))
173 }
174 WaitStrategy::ExponentialJitter {
175 initial,
176 max,
177 multiplier,
178 jitter,
179 } => {
180 let base = initial.as_secs_f64() * multiplier.powi(attempt as i32 - 1);
181 let jitter_amount = base * jitter * random_jitter();
182 let delay = (base + jitter_amount).min(max.as_secs_f64());
183 Duration::from_secs_f64(delay)
184 }
185 WaitStrategy::Linear {
186 initial,
187 increment,
188 max,
189 } => {
190 let delay = *initial + *increment * (attempt - 1);
191 delay.min(*max)
192 }
193 WaitStrategy::RetryAfter { fallback, max_wait } => retry_after
194 .map(|d| d.min(*max_wait))
195 .unwrap_or_else(|| fallback.calculate(attempt, None)),
196 }
197 }
198}
199
200#[derive(Debug, Clone, Default)]
202pub struct RetryCondition {
203 pub on_status_codes: Vec<u16>,
205 pub custom: Option<fn(&RetryableError) -> bool>,
207}
208
209impl RetryCondition {
210 pub fn new() -> Self {
212 Self::default()
213 }
214
215 pub fn on_status(mut self, codes: impl IntoIterator<Item = u16>) -> Self {
217 self.on_status_codes.extend(codes);
218 self
219 }
220
221 pub fn on_server_errors(mut self) -> Self {
223 self.on_status_codes.extend(500..=599);
224 self
225 }
226
227 pub fn on_rate_limit(mut self) -> Self {
229 self.on_status_codes.push(429);
230 self
231 }
232
233 pub fn on_timeout(self) -> Self {
235 self
237 }
238
239 pub fn with_custom(mut self, predicate: fn(&RetryableError) -> bool) -> Self {
241 self.custom = Some(predicate);
242 self
243 }
244
245 pub fn should_retry(&self, error: &RetryableError) -> bool {
247 if let Some(predicate) = self.custom {
249 return predicate(error);
250 }
251
252 if let Some(status) = error.status() {
254 if self.on_status_codes.contains(&status) {
255 return true;
256 }
257 }
258
259 error.is_retryable()
261 }
262}
263
264fn random_jitter() -> f64 {
266 use rand::Rng;
267 let mut rng = rand::thread_rng();
268 rng.gen_range(-1.0..1.0)
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration allows three retries and has `reraise` set.
    #[test]
    fn test_default_config() {
        let defaults = RetryConfig::default();

        assert_eq!(defaults.max_retries, 3);
        assert!(defaults.reraise);
    }

    /// Builder calls override the corresponding defaults.
    #[test]
    fn test_config_builder() {
        let one_second = Duration::from_secs(1);
        let built = RetryConfig::new()
            .max_retries(5)
            .fixed(one_second)
            .reraise(false);

        assert_eq!(built.max_retries, 5);
        assert!(!built.reraise);
    }

    /// A fixed strategy ignores the attempt number.
    #[test]
    fn test_wait_strategy_fixed() {
        let one_second = Duration::from_secs(1);
        let strategy = WaitStrategy::Fixed(one_second);

        for &attempt in [1u32, 3].iter() {
            assert_eq!(strategy.calculate(attempt, None), Duration::from_secs(1));
        }
    }

    /// Exponential backoff doubles the delay on each attempt.
    #[test]
    fn test_wait_strategy_exponential() {
        let strategy = WaitStrategy::ExponentialBackoff {
            initial: Duration::from_millis(100),
            max: Duration::from_secs(10),
            multiplier: 2.0,
        };

        for &(attempt, millis) in [(1u32, 100u64), (2, 200), (3, 400)].iter() {
            assert_eq!(strategy.calculate(attempt, None), Duration::from_millis(millis));
        }
    }

    /// The linear strategy adds one increment per attempt.
    #[test]
    fn test_wait_strategy_linear() {
        let strategy = WaitStrategy::Linear {
            initial: Duration::from_millis(100),
            increment: Duration::from_millis(100),
            max: Duration::from_secs(10),
        };

        for &(attempt, millis) in [(1u32, 100u64), (2, 200), (3, 300)].iter() {
            assert_eq!(strategy.calculate(attempt, None), Duration::from_millis(millis));
        }
    }

    /// A supplied retry-after value wins; without one the fallback is used.
    #[test]
    fn test_wait_strategy_retry_after() {
        let fallback = Box::new(WaitStrategy::Fixed(Duration::from_secs(1)));
        let strategy = WaitStrategy::RetryAfter {
            fallback,
            max_wait: Duration::from_secs(60),
        };

        let with_hint = strategy.calculate(1, Some(Duration::from_secs(5)));
        assert_eq!(with_hint, Duration::from_secs(5));

        let without_hint = strategy.calculate(1, None);
        assert_eq!(without_hint, Duration::from_secs(1));
    }

    /// 429 and 5xx responses are retried; a 400 is not.
    #[test]
    fn test_retry_condition() {
        let condition = RetryCondition::new().on_rate_limit().on_server_errors();

        let rate_limited = RetryableError::http(429, "");
        let server_error = RetryableError::http(500, "");
        let bad_request = RetryableError::http(400, "");

        assert!(condition.should_retry(&rate_limited));
        assert!(condition.should_retry(&server_error));
        assert!(!condition.should_retry(&bad_request));
    }

    /// The API preset retries three times on 429 and 5xx.
    #[test]
    fn test_api_config() {
        let api = RetryConfig::for_api();
        let codes = &api.retry_on.on_status_codes;

        assert_eq!(api.max_retries, 3);
        assert!(codes.contains(&429));
        assert!(codes.contains(&500));
    }
}