Skip to main content

seer_core/
retry.rs

1//! Retry logic with exponential backoff for transient failures.
2//!
3//! This module provides configurable retry policies and executors for handling
4//! transient network failures in WHOIS and RDAP lookups.
5
6use std::future::Future;
7use std::time::Duration;
8
9use rand::Rng;
10use tracing::debug;
11
12use crate::error::{Result, SeerError};
13
/// Tunable parameters for retrying a failed operation with exponential backoff.
///
/// All fields are public, so a policy can be built either through the
/// `with_*` builder methods or with a plain struct literal.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Total attempt budget; the very first try counts towards it.
    pub max_attempts: usize,
    /// Pause taken ahead of the first retry.
    pub initial_delay: Duration,
    /// Upper bound used to cap the exponentially growing pause.
    pub max_delay: Duration,
    /// Growth factor: the pause is multiplied by this after each retry.
    pub multiplier: f64,
    /// When `true`, pauses are randomised so that many clients do not retry
    /// in lockstep (thundering herd).
    pub jitter: bool,
}
28
29impl Default for RetryPolicy {
30    fn default() -> Self {
31        Self {
32            max_attempts: 3,
33            initial_delay: Duration::from_millis(100),
34            max_delay: Duration::from_secs(5),
35            multiplier: 2.0,
36            jitter: true,
37        }
38    }
39}
40
41impl RetryPolicy {
42    /// Creates a new retry policy with default settings.
43    pub fn new() -> Self {
44        Self::default()
45    }
46
47    /// Sets the maximum number of attempts.
48    pub fn with_max_attempts(mut self, attempts: usize) -> Self {
49        self.max_attempts = attempts.max(1);
50        self
51    }
52
53    /// Sets the initial delay before the first retry.
54    pub fn with_initial_delay(mut self, delay: Duration) -> Self {
55        self.initial_delay = delay;
56        self
57    }
58
59    /// Sets the maximum delay between retries.
60    pub fn with_max_delay(mut self, delay: Duration) -> Self {
61        self.max_delay = delay;
62        self
63    }
64
65    /// Sets the multiplier for exponential backoff.
66    pub fn with_multiplier(mut self, multiplier: f64) -> Self {
67        self.multiplier = multiplier.max(1.0);
68        self
69    }
70
71    /// Enables or disables jitter.
72    pub fn with_jitter(mut self, jitter: bool) -> Self {
73        self.jitter = jitter;
74        self
75    }
76
77    /// Creates a policy that disables retries (single attempt only).
78    pub fn no_retry() -> Self {
79        Self {
80            max_attempts: 1,
81            ..Self::default()
82        }
83    }
84
85    /// Calculates the delay for a given attempt number (0-indexed).
86    ///
87    /// The attempt number is internally capped to prevent integer overflow
88    /// in the exponential calculation.
89    pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
90        if attempt == 0 {
91            return self.initial_delay;
92        }
93
94        // Cap attempt to prevent overflow in powi() - 20 attempts with multiplier 2.0
95        // gives 2^20 = ~1 million, which is safe for f64 and reasonable for delays
96        let safe_attempt = attempt.min(20) as i32;
97
98        let base_delay = self.initial_delay.as_millis() as f64 * self.multiplier.powi(safe_attempt);
99        let capped_delay = base_delay.min(self.max_delay.as_millis() as f64);
100
101        let final_delay = if self.jitter {
102            // Full jitter (AWS "Exponential Backoff and Jitter"):
103            // sleep ∈ [0, capped_delay]. Half-jitter ([0.5..1.0]) clusters
104            // retries in the upper half of the window, which leaves
105            // measurable thundering-herd behaviour after a brief outage.
106            // Full jitter spreads retries uniformly across the window.
107            let mut rng = rand::thread_rng();
108            let jitter_factor = rng.gen_range(0.0..1.0);
109            capped_delay * jitter_factor
110        } else {
111            capped_delay
112        };
113
114        Duration::from_millis(final_delay as u64)
115    }
116}
117
118/// Trait for classifying whether an error is retryable.
119pub trait RetryClassifier: Send + Sync {
120    /// Returns true if the error is transient and the operation should be retried.
121    fn is_retryable(&self, error: &SeerError) -> bool;
122}
123
/// Stock classifier for network lookups (WHOIS/RDAP).
///
/// Treated as retryable:
/// - Timeouts
/// - Connection failures
/// - Rate limiting (429)
/// - Server errors (5xx)
///
/// Treated as permanent:
/// - Invalid input (domain, IP, record type)
/// - Server not found
/// - Parse errors (JSON, WHOIS format)
#[derive(Debug, Clone, Default)]
pub struct NetworkRetryClassifier;

impl NetworkRetryClassifier {
    /// Creates the classifier; it carries no state.
    pub fn new() -> Self {
        NetworkRetryClassifier
    }
}
144
145impl RetryClassifier for NetworkRetryClassifier {
146    fn is_retryable(&self, error: &SeerError) -> bool {
147        match error {
148            // Transient errors - worth retrying
149            SeerError::Timeout(_) => true,
150            SeerError::WhoisConnectionFailed(_) => true,
151            SeerError::RateLimited(_) => true,
152
153            // Reqwest errors need deeper inspection
154            SeerError::ReqwestError(e) => is_transient_reqwest_error(e),
155
156            // WHOIS errors might be transient if they're connection-related
157            SeerError::WhoisError(msg) => {
158                let lower = msg.to_lowercase();
159                lower.contains("connection")
160                    || lower.contains("timeout")
161                    || lower.contains("refused")
162                    || lower.contains("reset")
163            }
164
165            // RDAP errors might be transient server errors
166            SeerError::RdapError(msg) => {
167                let lower = msg.to_lowercase();
168                lower.contains("status 5")
169                    || lower.contains("status 429")
170                    || lower.contains("timeout")
171            }
172
173            // Bootstrap errors could be transient if IANA is temporarily unavailable
174            SeerError::RdapBootstrapError(msg) => {
175                let lower = msg.to_lowercase();
176                lower.contains("timeout") || lower.contains("connection")
177            }
178
179            // DNS errors can be transient
180            SeerError::DnsError(msg) => {
181                let lower = msg.to_lowercase();
182                lower.contains("timeout") || lower.contains("temporary")
183            }
184
185            // HTTP errors might be transient (server errors 5xx or 429 Too Many Requests)
186            SeerError::HttpError(msg) => {
187                let lower = msg.to_lowercase();
188                lower.contains("timeout")
189                    || lower.contains("connection")
190                    || lower.contains("status 5")
191                    || lower.contains("status 429")
192            }
193
194            // Not retryable - permanent failures
195            SeerError::InvalidDomain(_) => false,
196            SeerError::DomainNotAllowed { .. } => false,
197            SeerError::InvalidIpAddress(_) => false,
198            SeerError::InvalidRecordType(_) => false,
199            SeerError::WhoisServerNotFound(_) => false,
200            SeerError::JsonError(_) => false,
201            SeerError::CertificateError(_) => false,
202            SeerError::SslError(_) => false,
203            SeerError::DnsResolverError(_) => false,
204            SeerError::BulkOperationError { .. } => false,
205            SeerError::LookupFailed { .. } => false,
206            SeerError::ConfigError(_) => false,
207            SeerError::InvalidInput(_) => false,
208            // Transparent pass-through: if a prior attempt was wrapped into
209            // RetryExhausted upstream, defer the retryable decision to the
210            // underlying cause so a caller layering retries still sees the
211            // true fault classification instead of a non-retryable wrapper.
212            SeerError::RetryExhausted { last_error, .. } => self.is_retryable(last_error),
213            SeerError::Other(_) => false,
214        }
215    }
216}
217
218/// Checks if a reqwest error is transient and worth retrying.
219fn is_transient_reqwest_error(error: &reqwest::Error) -> bool {
220    // Connection errors are transient
221    if error.is_connect() {
222        return true;
223    }
224
225    // Timeout errors are transient
226    if error.is_timeout() {
227        return true;
228    }
229
230    // Check HTTP status codes
231    if let Some(status) = error.status() {
232        // 429 Too Many Requests - rate limited, retry with backoff
233        if status.as_u16() == 429 {
234            return true;
235        }
236        // 5xx Server errors are transient
237        if status.is_server_error() {
238            return true;
239        }
240        // 4xx Client errors (except 429) are not retryable
241        return false;
242    }
243
244    // Request/body errors are generally not retryable
245    if error.is_request() || error.is_body() {
246        return false;
247    }
248
249    // Default: assume transient for unknown errors
250    true
251}
252
253/// Executes operations with retry logic using exponential backoff.
254#[derive(Debug, Clone)]
255pub struct RetryExecutor<C: RetryClassifier> {
256    policy: RetryPolicy,
257    classifier: C,
258}
259
260impl RetryExecutor<NetworkRetryClassifier> {
261    /// Creates a new executor with the default network retry classifier.
262    pub fn new(policy: RetryPolicy) -> Self {
263        Self {
264            policy,
265            classifier: NetworkRetryClassifier::new(),
266        }
267    }
268}
269
270impl<C: RetryClassifier> RetryExecutor<C> {
271    /// Creates a new executor with a custom classifier.
272    pub fn with_classifier(policy: RetryPolicy, classifier: C) -> Self {
273        Self { policy, classifier }
274    }
275
276    /// Executes an async operation with retry logic.
277    ///
278    /// The operation will be retried up to `max_attempts` times if it fails
279    /// with a retryable error. Delays between retries follow exponential
280    /// backoff with optional jitter.
281    pub async fn execute<F, Fut, T>(&self, mut operation: F) -> Result<T>
282    where
283        F: FnMut() -> Fut,
284        Fut: Future<Output = Result<T>>,
285    {
286        let mut last_error: Option<SeerError> = None;
287        let mut attempt = 0;
288
289        while attempt < self.policy.max_attempts {
290            match operation().await {
291                Ok(result) => return Ok(result),
292                Err(e) => {
293                    let is_retryable = self.classifier.is_retryable(&e);
294                    let attempts_remaining = self.policy.max_attempts - attempt - 1;
295
296                    if !is_retryable || attempts_remaining == 0 {
297                        if attempt > 0 {
298                            debug!(
299                                attempt = attempt + 1,
300                                max_attempts = self.policy.max_attempts,
301                                error = %e,
302                                "Operation failed after retries"
303                            );
304                        }
305                        return Err(if attempt > 0 {
306                            SeerError::RetryExhausted {
307                                attempts: attempt + 1,
308                                last_error: Box::new(e),
309                            }
310                        } else {
311                            e
312                        });
313                    }
314
315                    let delay = self.policy.delay_for_attempt(attempt);
316                    debug!(
317                        attempt = attempt + 1,
318                        max_attempts = self.policy.max_attempts,
319                        delay_ms = delay.as_millis(),
320                        error = %e,
321                        "Retrying after transient error"
322                    );
323
324                    last_error = Some(e);
325                    tokio::time::sleep(delay).await;
326                    attempt += 1;
327                }
328            }
329        }
330
331        // Should not reach here, but handle it gracefully
332        Err(last_error.unwrap_or_else(|| SeerError::Other("retry loop exited unexpectedly".into())))
333    }
334
335    /// Executes an async operation once without retries.
336    /// Useful for operations that should not be retried.
337    pub async fn execute_once<F, Fut, T>(&self, operation: F) -> Result<T>
338    where
339        F: FnOnce() -> Fut,
340        Fut: Future<Output = Result<T>>,
341    {
342        operation().await
343    }
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Three attempts, 1 ms initial delay, no jitter: keeps executor tests fast
    /// and deterministic.
    fn fast_deterministic_policy() -> RetryPolicy {
        RetryPolicy::new()
            .with_max_attempts(3)
            .with_initial_delay(Duration::from_millis(1))
            .with_jitter(false)
    }

    #[test]
    fn test_retry_policy_defaults() {
        let RetryPolicy {
            max_attempts,
            initial_delay,
            max_delay,
            multiplier,
            jitter,
        } = RetryPolicy::default();

        assert_eq!(max_attempts, 3);
        assert_eq!(initial_delay, Duration::from_millis(100));
        assert_eq!(max_delay, Duration::from_secs(5));
        assert_eq!(multiplier, 2.0);
        assert!(jitter);
    }

    #[test]
    fn test_retry_policy_builder() {
        let built = RetryPolicy::new()
            .with_max_attempts(5)
            .with_initial_delay(Duration::from_millis(200))
            .with_max_delay(Duration::from_secs(10))
            .with_multiplier(3.0)
            .with_jitter(false);

        assert_eq!(built.max_attempts, 5);
        assert_eq!(built.initial_delay, Duration::from_millis(200));
        assert_eq!(built.max_delay, Duration::from_secs(10));
        assert_eq!(built.multiplier, 3.0);
        assert!(!built.jitter);
    }

    #[test]
    fn test_delay_calculation_no_jitter() {
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(2.0)
            .with_max_delay(Duration::from_secs(10))
            .with_jitter(false);

        // (attempt, expected delay in ms): doubles on every step.
        let expectations: [(usize, u64); 4] = [(0, 100), (1, 200), (2, 400), (3, 800)];
        for (attempt, expected_ms) in expectations {
            assert_eq!(
                policy.delay_for_attempt(attempt),
                Duration::from_millis(expected_ms)
            );
        }
    }

    #[test]
    fn test_delay_capped_at_max() {
        let cap = Duration::from_secs(5);
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_secs(1))
            .with_multiplier(10.0)
            .with_max_delay(cap)
            .with_jitter(false);

        // 1s * 10^2 = 100s, which must be clamped down to the 5s cap.
        assert_eq!(policy.delay_for_attempt(2), cap);
    }

    #[test]
    fn test_classifier_timeout_is_retryable() {
        let error = SeerError::Timeout("test".to_string());
        assert!(NetworkRetryClassifier::new().is_retryable(&error));
    }

    #[test]
    fn test_classifier_invalid_domain_not_retryable() {
        let error = SeerError::InvalidDomain("test".to_string());
        assert!(!NetworkRetryClassifier::new().is_retryable(&error));
    }

    #[test]
    fn test_classifier_server_not_found_not_retryable() {
        let error = SeerError::WhoisServerNotFound("test".to_string());
        assert!(!NetworkRetryClassifier::new().is_retryable(&error));
    }

    #[test]
    fn test_classifier_rate_limited_is_retryable() {
        let error = SeerError::RateLimited("test".to_string());
        assert!(NetworkRetryClassifier::new().is_retryable(&error));
    }

    #[tokio::test]
    async fn test_executor_success_on_first_try() {
        let executor = RetryExecutor::new(RetryPolicy::new().with_max_attempts(3));
        let calls = Arc::new(AtomicUsize::new(0));

        let counter = Arc::clone(&calls);
        let outcome: Result<&str> = executor
            .execute(|| {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok("success")
                }
            })
            .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "success");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_executor_retries_on_transient_error() {
        let executor = RetryExecutor::new(fast_deterministic_policy());
        let calls = Arc::new(AtomicUsize::new(0));

        let counter = Arc::clone(&calls);
        let outcome: Result<&str> = executor
            .execute(|| {
                let counter = Arc::clone(&counter);
                async move {
                    // The first two calls time out; the third one succeeds.
                    let previous_calls = counter.fetch_add(1, Ordering::SeqCst);
                    if previous_calls < 2 {
                        Err(SeerError::Timeout("test timeout".to_string()))
                    } else {
                        Ok("success after retries")
                    }
                }
            })
            .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "success after retries");
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_executor_no_retry_on_non_retryable_error() {
        let policy = RetryPolicy::new()
            .with_max_attempts(3)
            .with_initial_delay(Duration::from_millis(1));
        let executor = RetryExecutor::new(policy);
        let calls = Arc::new(AtomicUsize::new(0));

        let counter = Arc::clone(&calls);
        let outcome: Result<&str> = executor
            .execute(|| {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(SeerError::InvalidDomain("bad.".to_string()))
                }
            })
            .await;

        assert!(outcome.is_err());
        // InvalidDomain is permanent, so exactly one call must have been made.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_executor_exhausts_retries() {
        let executor = RetryExecutor::new(fast_deterministic_policy());
        let calls = Arc::new(AtomicUsize::new(0));

        let counter = Arc::clone(&calls);
        let outcome: Result<&str> = executor
            .execute(|| {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err(SeerError::Timeout("always fails".to_string()))
                }
            })
            .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 3);

        // The final error must be the RetryExhausted wrapper with the full count.
        let failure = outcome.unwrap_err();
        match failure {
            SeerError::RetryExhausted { attempts, .. } => assert_eq!(attempts, 3),
            other => panic!("Expected RetryExhausted, got {:?}", other),
        }
    }

    #[test]
    fn test_no_retry_policy() {
        assert_eq!(RetryPolicy::no_retry().max_attempts, 1);
    }

    #[test]
    fn test_delay_overflow_protection() {
        let cap = Duration::from_secs(5);
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(2.0)
            .with_max_delay(cap)
            .with_jitter(false);

        // Absurdly large attempt numbers must neither panic nor exceed the cap.
        for attempt in [50usize, 100, 1000] {
            assert!(policy.delay_for_attempt(attempt) <= cap);
        }
    }

    #[test]
    fn retry_exhausted_is_retryable_if_inner_is() {
        // Regression for H5: the classifier has to look through a
        // RetryExhausted wrapper so that layered retries still see the real
        // classification of the underlying fault rather than a wrapper that
        // is never retryable.
        let classifier = NetworkRetryClassifier::new();
        let wrap = |inner: SeerError| SeerError::RetryExhausted {
            attempts: 3,
            last_error: Box::new(inner),
        };

        let wrapped_retryable = wrap(SeerError::Timeout("inner timed out".to_string()));
        assert!(
            classifier.is_retryable(&wrapped_retryable),
            "RetryExhausted wrapping a retryable Timeout should be retryable",
        );

        let wrapped_non_retryable = wrap(SeerError::InvalidDomain("bad.".to_string()));
        assert!(
            !classifier.is_retryable(&wrapped_non_retryable),
            "RetryExhausted wrapping a non-retryable InvalidDomain must not be retryable",
        );
    }
}