Skip to main content

serdes_ai_retries/
config.rs

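//! Retry configuration: how many times to retry, how long to wait between attempts, and which errors qualify for another attempt.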
2
3use crate::error::RetryableError;
4use std::time::Duration;
5
6/// Configuration for retry behavior.
7#[derive(Debug, Clone)]
8pub struct RetryConfig {
9    /// Maximum number of retries.
10    pub max_retries: u32,
11    /// Wait strategy.
12    pub wait: WaitStrategy,
13    /// Retry condition.
14    pub retry_on: RetryCondition,
15    /// Whether to reraise the last error if all retries fail.
16    pub reraise: bool,
17}
18
19impl Default for RetryConfig {
20    fn default() -> Self {
21        Self {
22            max_retries: 3,
23            wait: WaitStrategy::ExponentialBackoff {
24                initial: Duration::from_millis(500),
25                max: Duration::from_secs(60),
26                multiplier: 2.0,
27            },
28            retry_on: RetryCondition::default(),
29            reraise: true,
30        }
31    }
32}
33
34impl RetryConfig {
35    /// Create a new default config.
36    pub fn new() -> Self {
37        Self::default()
38    }
39
40    /// Set max retries.
41    pub fn max_retries(mut self, n: u32) -> Self {
42        self.max_retries = n;
43        self
44    }
45
46    /// Set the wait strategy.
47    pub fn wait(mut self, strategy: WaitStrategy) -> Self {
48        self.wait = strategy;
49        self
50    }
51
52    /// Use exponential backoff.
53    pub fn exponential(mut self, initial: Duration, max: Duration) -> Self {
54        self.wait = WaitStrategy::ExponentialBackoff {
55            initial,
56            max,
57            multiplier: 2.0,
58        };
59        self
60    }
61
62    /// Use exponential backoff with jitter.
63    pub fn exponential_jitter(mut self, initial: Duration, max: Duration, jitter: f64) -> Self {
64        self.wait = WaitStrategy::ExponentialJitter {
65            initial,
66            max,
67            multiplier: 2.0,
68            jitter,
69        };
70        self
71    }
72
73    /// Use fixed delay.
74    pub fn fixed(mut self, delay: Duration) -> Self {
75        self.wait = WaitStrategy::Fixed(delay);
76        self
77    }
78
79    /// Use linear backoff.
80    pub fn linear(mut self, initial: Duration, increment: Duration, max: Duration) -> Self {
81        self.wait = WaitStrategy::Linear {
82            initial,
83            increment,
84            max,
85        };
86        self
87    }
88
89    /// Set retry condition.
90    pub fn retry_on(mut self, condition: RetryCondition) -> Self {
91        self.retry_on = condition;
92        self
93    }
94
95    /// Set whether to reraise the last error.
96    pub fn reraise(mut self, reraise: bool) -> Self {
97        self.reraise = reraise;
98        self
99    }
100
101    /// Create config for API calls with sensible defaults.
102    pub fn for_api() -> Self {
103        Self::new()
104            .max_retries(3)
105            .exponential_jitter(Duration::from_millis(500), Duration::from_secs(60), 0.1)
106            .retry_on(RetryCondition::new().on_rate_limit().on_server_errors())
107    }
108
109    /// Create config that never retries.
110    pub fn no_retry() -> Self {
111        Self::new().max_retries(0)
112    }
113}
114
/// How long to pause before each retry attempt.
#[derive(Clone, Debug)]
pub enum WaitStrategy {
    /// Retry immediately, with no pause.
    None,
    /// Pause for the same duration before every retry.
    Fixed(Duration),
    /// Pause grows geometrically with each attempt.
    ExponentialBackoff {
        /// Pause before the first retry.
        initial: Duration,
        /// Ceiling for any single pause.
        max: Duration,
        /// Growth factor applied per attempt.
        multiplier: f64,
    },
    /// Geometric growth plus a random perturbation of each pause.
    ExponentialJitter {
        /// Pause before the first retry (before jitter).
        initial: Duration,
        /// Ceiling for any single pause.
        max: Duration,
        /// Growth factor applied per attempt.
        multiplier: f64,
        /// Fraction of the pause that may be randomly added or removed (0.0 to 1.0).
        jitter: f64,
    },
    /// Pause grows by a constant step with each attempt.
    Linear {
        /// Pause before the first retry.
        initial: Duration,
        /// Amount added to the pause for every further attempt.
        increment: Duration,
        /// Ceiling for any single pause.
        max: Duration,
    },
    /// Honour a server-provided `Retry-After` delay when one is available.
    RetryAfter {
        /// Strategy used when no `Retry-After` delay was supplied.
        fallback: Box<WaitStrategy>,
        /// Ceiling applied to a supplied `Retry-After` delay.
        max_wait: Duration,
    },
}
159
160impl WaitStrategy {
161    /// Calculate the wait duration for a given attempt.
162    pub fn calculate(&self, attempt: u32, retry_after: Option<Duration>) -> Duration {
163        match self {
164            WaitStrategy::None => Duration::ZERO,
165            WaitStrategy::Fixed(d) => *d,
166            WaitStrategy::ExponentialBackoff {
167                initial,
168                max,
169                multiplier,
170            } => {
171                let delay = initial.as_secs_f64() * multiplier.powi(attempt as i32 - 1);
172                Duration::from_secs_f64(delay.min(max.as_secs_f64()))
173            }
174            WaitStrategy::ExponentialJitter {
175                initial,
176                max,
177                multiplier,
178                jitter,
179            } => {
180                let base = initial.as_secs_f64() * multiplier.powi(attempt as i32 - 1);
181                let jitter_amount = base * jitter * random_jitter();
182                let delay = (base + jitter_amount).min(max.as_secs_f64());
183                Duration::from_secs_f64(delay)
184            }
185            WaitStrategy::Linear {
186                initial,
187                increment,
188                max,
189            } => {
190                let delay = *initial + *increment * (attempt - 1);
191                delay.min(*max)
192            }
193            WaitStrategy::RetryAfter { fallback, max_wait } => retry_after
194                .map(|d| d.min(*max_wait))
195                .unwrap_or_else(|| fallback.calculate(attempt, None)),
196        }
197    }
198}
199
200/// Condition for retrying.
201#[derive(Debug, Clone, Default)]
202pub struct RetryCondition {
203    /// HTTP status codes to retry on.
204    pub on_status_codes: Vec<u16>,
205    /// Custom predicate function.
206    pub custom: Option<fn(&RetryableError) -> bool>,
207}
208
209impl RetryCondition {
210    /// Create a new empty condition.
211    pub fn new() -> Self {
212        Self::default()
213    }
214
215    /// Add status codes to retry on.
216    pub fn on_status(mut self, codes: impl IntoIterator<Item = u16>) -> Self {
217        self.on_status_codes.extend(codes);
218        self
219    }
220
221    /// Retry on server errors (5xx).
222    pub fn on_server_errors(mut self) -> Self {
223        self.on_status_codes.extend(500..=599);
224        self
225    }
226
227    /// Retry on rate limit (429).
228    pub fn on_rate_limit(mut self) -> Self {
229        self.on_status_codes.push(429);
230        self
231    }
232
233    /// Retry on timeout errors.
234    pub fn on_timeout(self) -> Self {
235        // Timeout is always retryable by default
236        self
237    }
238
239    /// Set a custom predicate.
240    pub fn with_custom(mut self, predicate: fn(&RetryableError) -> bool) -> Self {
241        self.custom = Some(predicate);
242        self
243    }
244
245    /// Check if an error should be retried.
246    pub fn should_retry(&self, error: &RetryableError) -> bool {
247        // Check custom predicate first
248        if let Some(predicate) = self.custom {
249            return predicate(error);
250        }
251
252        // Check status codes
253        if let Some(status) = error.status() {
254            if self.on_status_codes.contains(&status) {
255                return true;
256            }
257        }
258
259        // Default to error's own retryable check
260        error.is_retryable()
261    }
262}
263
264/// Generate a random jitter factor between -1.0 and 1.0.
265fn random_jitter() -> f64 {
266    use rand::Rng;
267    let mut rng = rand::thread_rng();
268    rng.gen_range(-1.0..1.0)
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let RetryConfig {
            max_retries,
            reraise,
            ..
        } = RetryConfig::default();
        assert_eq!(max_retries, 3);
        assert!(reraise);
    }

    #[test]
    fn test_config_builder() {
        let built = RetryConfig::new()
            .max_retries(5)
            .fixed(Duration::from_secs(1))
            .reraise(false);

        assert_eq!(built.max_retries, 5);
        assert!(!built.reraise);
    }

    #[test]
    fn test_wait_strategy_fixed() {
        let one_sec = Duration::from_secs(1);
        let strategy = WaitStrategy::Fixed(one_sec);
        for attempt in [1, 3] {
            assert_eq!(strategy.calculate(attempt, None), one_sec);
        }
    }

    #[test]
    fn test_wait_strategy_exponential() {
        let strategy = WaitStrategy::ExponentialBackoff {
            initial: Duration::from_millis(100),
            max: Duration::from_secs(10),
            multiplier: 2.0,
        };

        for (attempt, millis) in [(1, 100), (2, 200), (3, 400)] {
            assert_eq!(
                strategy.calculate(attempt, None),
                Duration::from_millis(millis)
            );
        }
    }

    #[test]
    fn test_wait_strategy_linear() {
        let strategy = WaitStrategy::Linear {
            initial: Duration::from_millis(100),
            increment: Duration::from_millis(100),
            max: Duration::from_secs(10),
        };

        for (attempt, millis) in [(1, 100), (2, 200), (3, 300)] {
            assert_eq!(
                strategy.calculate(attempt, None),
                Duration::from_millis(millis)
            );
        }
    }

    #[test]
    fn test_wait_strategy_retry_after() {
        let strategy = WaitStrategy::RetryAfter {
            fallback: Box::new(WaitStrategy::Fixed(Duration::from_secs(1))),
            max_wait: Duration::from_secs(60),
        };

        // A supplied Retry-After delay below the cap is used verbatim.
        let hinted = strategy.calculate(1, Some(Duration::from_secs(5)));
        assert_eq!(hinted, Duration::from_secs(5));

        // Without a Retry-After delay the fallback strategy decides.
        let unhinted = strategy.calculate(1, None);
        assert_eq!(unhinted, Duration::from_secs(1));
    }

    #[test]
    fn test_retry_condition() {
        let condition = RetryCondition::new().on_rate_limit().on_server_errors();

        for (status, expected) in [(429, true), (500, true), (400, false)] {
            assert_eq!(
                condition.should_retry(&RetryableError::http(status, "")),
                expected
            );
        }
    }

    #[test]
    fn test_api_config() {
        let RetryConfig {
            max_retries,
            retry_on,
            ..
        } = RetryConfig::for_api();
        assert_eq!(max_retries, 3);
        for code in [429, 500] {
            assert!(retry_on.on_status_codes.contains(&code));
        }
    }
}