Skip to main content

serdes_ai_retries/
backoff.rs

1//! Backoff strategies.
2
3use crate::error::RetryableError;
4use crate::strategy::RetryStrategy;
5use async_trait::async_trait;
6use std::time::Duration;
7
/// Retry strategy whose delay grows geometrically with each attempt, is capped
/// at a maximum, and can be randomized by a jitter fraction.
#[derive(Debug, Clone)]
pub struct ExponentialBackoff {
    /// Largest attempt number for which a retry is still granted.
    pub max_retries: u32,
    /// Base delay, scaled by `multiplier` raised to the attempt number.
    pub initial_delay: Duration,
    /// Upper bound applied to every computed delay.
    pub max_delay: Duration,
    /// Fraction of the delay used as random jitter; intended range `0.0..=1.0`.
    pub jitter: f64,
    /// Growth factor applied once per attempt.
    pub multiplier: f64,
}

impl Default for ExponentialBackoff {
    /// Three retries from a 100 ms base, doubling per attempt, capped at 30 s,
    /// with a jitter fraction of 0.1.
    fn default() -> Self {
        let initial_delay = Duration::from_millis(100);
        let max_delay = Duration::from_secs(30);
        Self {
            multiplier: 2.0,
            jitter: 0.1,
            max_delay,
            initial_delay,
            max_retries: 3,
        }
    }
}
34
35impl ExponentialBackoff {
36    /// Create a new exponential backoff.
37    #[must_use]
38    pub fn new() -> Self {
39        Self::default()
40    }
41
42    /// Create a builder.
43    #[must_use]
44    pub fn builder() -> ExponentialBackoffBuilder {
45        ExponentialBackoffBuilder::default()
46    }
47
48    /// Calculate delay for an attempt.
49    pub fn calculate_delay(&self, attempt: u32) -> Duration {
50        let base_delay = self.initial_delay.as_secs_f64() * self.multiplier.powi(attempt as i32);
51        let jitter = base_delay * self.jitter * rand_jitter();
52        let delay = (base_delay + jitter).min(self.max_delay.as_secs_f64());
53        Duration::from_secs_f64(delay.max(0.0))
54    }
55}
56
57#[async_trait]
58impl RetryStrategy for ExponentialBackoff {
59    fn should_retry(&self, error: &RetryableError, attempt: u32) -> Option<Duration> {
60        if attempt > self.max_retries || !error.is_retryable() {
61            return None;
62        }
63        Some(self.calculate_delay(attempt))
64    }
65
66    fn max_retries(&self) -> u32 {
67        self.max_retries
68    }
69}
70
71/// Builder for ExponentialBackoff.
72#[derive(Debug, Default)]
73pub struct ExponentialBackoffBuilder {
74    max_retries: Option<u32>,
75    initial_delay: Option<Duration>,
76    max_delay: Option<Duration>,
77    jitter: Option<f64>,
78    multiplier: Option<f64>,
79}
80
81impl ExponentialBackoffBuilder {
82    /// Set max retries.
83    #[must_use]
84    pub fn max_retries(mut self, n: u32) -> Self {
85        self.max_retries = Some(n);
86        self
87    }
88
89    /// Set initial delay.
90    #[must_use]
91    pub fn initial_delay(mut self, d: Duration) -> Self {
92        self.initial_delay = Some(d);
93        self
94    }
95
96    /// Set max delay.
97    #[must_use]
98    pub fn max_delay(mut self, d: Duration) -> Self {
99        self.max_delay = Some(d);
100        self
101    }
102
103    /// Set jitter factor.
104    #[must_use]
105    pub fn jitter(mut self, j: f64) -> Self {
106        self.jitter = Some(j);
107        self
108    }
109
110    /// Set multiplier.
111    #[must_use]
112    pub fn multiplier(mut self, m: f64) -> Self {
113        self.multiplier = Some(m);
114        self
115    }
116
117    /// Build the backoff strategy.
118    #[must_use]
119    pub fn build(self) -> ExponentialBackoff {
120        let mut backoff = ExponentialBackoff::default();
121        if let Some(v) = self.max_retries {
122            backoff.max_retries = v;
123        }
124        if let Some(v) = self.initial_delay {
125            backoff.initial_delay = v;
126        }
127        if let Some(v) = self.max_delay {
128            backoff.max_delay = v;
129        }
130        if let Some(v) = self.jitter {
131            backoff.jitter = v;
132        }
133        if let Some(v) = self.multiplier {
134            backoff.multiplier = v;
135        }
136        backoff
137    }
138}
139
/// Retry strategy that waits the same constant delay before every retry.
#[derive(Debug, Clone)]
pub struct FixedDelay {
    /// Pause inserted before each retry.
    pub delay: Duration,
    /// Largest attempt number for which a retry is still granted.
    pub max_retries: u32,
}

impl FixedDelay {
    /// Build a strategy that waits `delay` before each of up to
    /// `max_retries` retries.
    #[must_use]
    pub fn new(delay: Duration, max_retries: u32) -> Self {
        FixedDelay { max_retries, delay }
    }
}
156
157#[async_trait]
158impl RetryStrategy for FixedDelay {
159    fn should_retry(&self, error: &RetryableError, attempt: u32) -> Option<Duration> {
160        if attempt > self.max_retries || !error.is_retryable() {
161            None
162        } else {
163            Some(self.delay)
164        }
165    }
166
167    fn max_retries(&self) -> u32 {
168        self.max_retries
169    }
170}
171
/// Retry strategy whose delay grows by a fixed increment per attempt, capped
/// at a maximum.
#[derive(Debug, Clone)]
pub struct LinearBackoff {
    /// Delay used for the first retry (attempts 0 and 1).
    pub initial_delay: Duration,
    /// Amount added to the delay for each attempt after the first.
    pub increment: Duration,
    /// Upper bound applied to every computed delay.
    pub max_delay: Duration,
    /// Largest attempt number for which a retry is still granted.
    pub max_retries: u32,
}

impl LinearBackoff {
    /// Create a new linear backoff.
    #[must_use]
    pub fn new(
        initial_delay: Duration,
        increment: Duration,
        max_delay: Duration,
        max_retries: u32,
    ) -> Self {
        Self {
            initial_delay,
            increment,
            max_delay,
            max_retries,
        }
    }

    /// Calculate the delay for `attempt`.
    ///
    /// The delay is `initial_delay + increment * (attempt - 1)`, with
    /// `attempt - 1` saturating at 0, and the result is capped at `max_delay`.
    ///
    /// If the multiplication or addition overflows `Duration`, the true value
    /// necessarily exceeds `max_delay`, so `max_delay` is returned. The
    /// original unchecked arithmetic panicked in that case.
    #[must_use]
    pub fn calculate_delay(&self, attempt: u32) -> Duration {
        let steps = attempt.saturating_sub(1);
        self.increment
            .checked_mul(steps)
            .and_then(|growth| self.initial_delay.checked_add(growth))
            .map_or(self.max_delay, |delay| delay.min(self.max_delay))
    }
}
208
209#[async_trait]
210impl RetryStrategy for LinearBackoff {
211    fn should_retry(&self, error: &RetryableError, attempt: u32) -> Option<Duration> {
212        if attempt > self.max_retries || !error.is_retryable() {
213            None
214        } else {
215            Some(self.calculate_delay(attempt))
216        }
217    }
218
219    fn max_retries(&self) -> u32 {
220        self.max_retries
221    }
222}
223
224/// Generate a random jitter factor between -1.0 and 1.0.
225fn rand_jitter() -> f64 {
226    use rand::Rng;
227    let mut rng = rand::thread_rng();
228    rng.gen_range(-1.0..1.0)
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the HTTP 500 error shared by the retry-decision tests.
    fn server_error() -> RetryableError {
        RetryableError::http(500, "error")
    }

    #[test]
    fn test_exponential_backoff_default() {
        let defaults = ExponentialBackoff::new();

        assert_eq!(defaults.max_retries, 3);
        assert_eq!(defaults.initial_delay, Duration::from_millis(100));
        assert_eq!(defaults.multiplier, 2.0);
    }

    #[test]
    fn test_exponential_backoff_builder() {
        let built = ExponentialBackoff::builder()
            .max_retries(5)
            .initial_delay(Duration::from_millis(50))
            .max_delay(Duration::from_secs(10))
            .jitter(0.2)
            .build();

        assert_eq!(built.max_retries, 5);
        assert_eq!(built.initial_delay, Duration::from_millis(50));
        assert_eq!(built.max_delay, Duration::from_secs(10));
        assert_eq!(built.jitter, 0.2);
    }

    #[test]
    fn test_exponential_backoff_delay() {
        let deterministic = ExponentialBackoff::builder()
            .initial_delay(Duration::from_millis(100))
            .multiplier(2.0)
            .jitter(0.0)
            .build();

        // With jitter disabled the delay doubles on every attempt.
        for (attempt, millis) in [(1, 200), (2, 400), (3, 800)] {
            assert_eq!(
                deterministic.calculate_delay(attempt),
                Duration::from_millis(millis)
            );
        }
    }

    #[test]
    fn test_exponential_backoff_max_delay() {
        let capped = ExponentialBackoff::builder()
            .initial_delay(Duration::from_secs(1))
            .max_delay(Duration::from_secs(5))
            .multiplier(10.0)
            .jitter(0.0)
            .build();

        // 1 s * 10^5 is far past the cap, so the cap must win.
        assert!(capped.calculate_delay(5) <= Duration::from_secs(5));
    }

    #[test]
    fn test_exponential_backoff_should_retry() {
        let backoff = ExponentialBackoff::builder().max_retries(3).build();

        let transient = server_error();
        for attempt in [1, 3] {
            assert!(backoff.should_retry(&transient, attempt).is_some());
        }
        assert!(backoff.should_retry(&transient, 4).is_none());

        // A 400 response is treated as non-retryable.
        let bad_request = RetryableError::http(400, "bad request");
        assert!(backoff.should_retry(&bad_request, 1).is_none());
    }

    #[test]
    fn test_fixed_delay() {
        let one_second = Duration::from_secs(1);
        let fixed = FixedDelay::new(one_second, 3);

        let transient = server_error();
        assert_eq!(fixed.should_retry(&transient, 1), Some(one_second));
        assert_eq!(fixed.should_retry(&transient, 3), Some(one_second));
        assert_eq!(fixed.should_retry(&transient, 4), None);
    }

    #[test]
    fn test_linear_backoff() {
        let linear = LinearBackoff::new(
            Duration::from_millis(100),
            Duration::from_millis(100),
            Duration::from_secs(1),
            5,
        );

        for (attempt, millis) in [(1, 100), (2, 200), (3, 300)] {
            assert_eq!(linear.calculate_delay(attempt), Duration::from_millis(millis));
        }

        // Far beyond the cap the delay saturates at `max_delay`.
        assert_eq!(linear.calculate_delay(20), Duration::from_secs(1));
    }
}