Skip to main content

seer_core/
retry.rs

1//! Retry logic with exponential backoff for transient failures.
2//!
3//! This module provides configurable retry policies and executors for handling
4//! transient network failures in WHOIS and RDAP lookups.
5
6use std::future::Future;
7use std::time::Duration;
8
9use rand::Rng;
10use tracing::debug;
11
12use crate::error::{Result, SeerError};
13
/// Tunable parameters for retrying an operation with exponential backoff.
///
/// A policy is normally obtained from `RetryPolicy::new()` / `Default` and
/// then adjusted through the `with_*` builder methods. All fields are public,
/// so a policy can also be written as a struct literal.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Total number of attempts allowed, counting the very first try.
    pub max_attempts: usize,
    /// Base delay: how long to wait before the first retry.
    pub initial_delay: Duration,
    /// Ceiling for the delay between retries; stops exponential growth.
    pub max_delay: Duration,
    /// Backoff growth factor (the delay is scaled by this for each further retry).
    pub multiplier: f64,
    /// Whether delays are randomized with jitter to avoid a thundering herd.
    pub jitter: bool,
}
28
29impl Default for RetryPolicy {
30    fn default() -> Self {
31        Self {
32            max_attempts: 3,
33            initial_delay: Duration::from_millis(100),
34            max_delay: Duration::from_secs(5),
35            multiplier: 2.0,
36            jitter: true,
37        }
38    }
39}
40
41impl RetryPolicy {
42    /// Creates a new retry policy with default settings.
43    pub fn new() -> Self {
44        Self::default()
45    }
46
47    /// Sets the maximum number of attempts.
48    pub fn with_max_attempts(mut self, attempts: usize) -> Self {
49        self.max_attempts = attempts.max(1);
50        self
51    }
52
53    /// Sets the initial delay before the first retry.
54    pub fn with_initial_delay(mut self, delay: Duration) -> Self {
55        self.initial_delay = delay;
56        self
57    }
58
59    /// Sets the maximum delay between retries.
60    pub fn with_max_delay(mut self, delay: Duration) -> Self {
61        self.max_delay = delay;
62        self
63    }
64
65    /// Sets the multiplier for exponential backoff.
66    pub fn with_multiplier(mut self, multiplier: f64) -> Self {
67        self.multiplier = multiplier.max(1.0);
68        self
69    }
70
71    /// Enables or disables jitter.
72    pub fn with_jitter(mut self, jitter: bool) -> Self {
73        self.jitter = jitter;
74        self
75    }
76
77    /// Creates a policy that disables retries (single attempt only).
78    pub fn no_retry() -> Self {
79        Self {
80            max_attempts: 1,
81            ..Self::default()
82        }
83    }
84
85    /// Calculates the delay for a given attempt number (0-indexed).
86    ///
87    /// The attempt number is internally capped to prevent integer overflow
88    /// in the exponential calculation.
89    pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
90        if attempt == 0 {
91            return self.initial_delay;
92        }
93
94        // Cap attempt to prevent overflow in powi() - 20 attempts with multiplier 2.0
95        // gives 2^20 = ~1 million, which is safe for f64 and reasonable for delays
96        let safe_attempt = attempt.min(20) as i32;
97
98        let base_delay = self.initial_delay.as_millis() as f64 * self.multiplier.powi(safe_attempt);
99        let capped_delay = base_delay.min(self.max_delay.as_millis() as f64);
100
101        let final_delay = if self.jitter {
102            // Add jitter: random value between 50% and 100% of the delay
103            let mut rng = rand::thread_rng();
104            let jitter_factor = rng.gen_range(0.5..1.0);
105            capped_delay * jitter_factor
106        } else {
107            capped_delay
108        };
109
110        Duration::from_millis(final_delay as u64)
111    }
112}
113
114/// Trait for classifying whether an error is retryable.
115pub trait RetryClassifier: Send + Sync {
116    /// Returns true if the error is transient and the operation should be retried.
117    fn is_retryable(&self, error: &SeerError) -> bool;
118}
119
/// Stock [`RetryClassifier`] for network lookups (WHOIS/RDAP).
///
/// Treated as retryable:
/// - Timeouts
/// - Connection failures (IO errors)
/// - Rate limiting (429)
/// - Server errors (5xx)
///
/// Treated as permanent (never retried):
/// - Invalid input (domain, IP, record type)
/// - Server not found
/// - Parse errors (JSON, WHOIS format)
#[derive(Clone, Debug, Default)]
pub struct NetworkRetryClassifier;
134
135impl NetworkRetryClassifier {
136    pub fn new() -> Self {
137        Self
138    }
139}
140
141impl RetryClassifier for NetworkRetryClassifier {
142    fn is_retryable(&self, error: &SeerError) -> bool {
143        match error {
144            // Transient errors - worth retrying
145            SeerError::Timeout(_) => true,
146            SeerError::WhoisConnectionFailed(_) => true,
147            SeerError::RateLimited(_) => true,
148
149            // Reqwest errors need deeper inspection
150            SeerError::ReqwestError(e) => is_transient_reqwest_error(e),
151
152            // WHOIS errors might be transient if they're connection-related
153            SeerError::WhoisError(msg) => {
154                let lower = msg.to_lowercase();
155                lower.contains("connection")
156                    || lower.contains("timeout")
157                    || lower.contains("refused")
158                    || lower.contains("reset")
159            }
160
161            // RDAP errors might be transient server errors
162            SeerError::RdapError(msg) => {
163                let lower = msg.to_lowercase();
164                lower.contains("status 5")
165                    || lower.contains("status 429")
166                    || lower.contains("timeout")
167            }
168
169            // Bootstrap errors could be transient if IANA is temporarily unavailable
170            SeerError::RdapBootstrapError(msg) => {
171                let lower = msg.to_lowercase();
172                lower.contains("timeout") || lower.contains("connection")
173            }
174
175            // DNS errors can be transient
176            SeerError::DnsError(msg) => {
177                let lower = msg.to_lowercase();
178                lower.contains("timeout") || lower.contains("temporary")
179            }
180
181            // HTTP errors might be transient (server errors 5xx or 429 Too Many Requests)
182            SeerError::HttpError(msg) => {
183                let lower = msg.to_lowercase();
184                lower.contains("timeout")
185                    || lower.contains("connection")
186                    || lower.contains("status 5")
187                    || lower.contains("status 429")
188            }
189
190            // Not retryable - permanent failures
191            SeerError::InvalidDomain(_) => false,
192            SeerError::DomainNotAllowed { .. } => false,
193            SeerError::InvalidIpAddress(_) => false,
194            SeerError::InvalidRecordType(_) => false,
195            SeerError::WhoisServerNotFound(_) => false,
196            SeerError::JsonError(_) => false,
197            SeerError::CertificateError(_) => false,
198            SeerError::SslError(_) => false,
199            SeerError::DnsResolverError(_) => false,
200            SeerError::BulkOperationError { .. } => false,
201            SeerError::LookupFailed { .. } => false,
202            SeerError::ConfigError(_) => false,
203            SeerError::InvalidInput(_) => false,
204            // Transparent pass-through: if a prior attempt was wrapped into
205            // RetryExhausted upstream, defer the retryable decision to the
206            // underlying cause so a caller layering retries still sees the
207            // true fault classification instead of a non-retryable wrapper.
208            SeerError::RetryExhausted { last_error, .. } => self.is_retryable(last_error),
209            SeerError::Other(_) => false,
210        }
211    }
212}
213
214/// Checks if a reqwest error is transient and worth retrying.
215fn is_transient_reqwest_error(error: &reqwest::Error) -> bool {
216    // Connection errors are transient
217    if error.is_connect() {
218        return true;
219    }
220
221    // Timeout errors are transient
222    if error.is_timeout() {
223        return true;
224    }
225
226    // Check HTTP status codes
227    if let Some(status) = error.status() {
228        // 429 Too Many Requests - rate limited, retry with backoff
229        if status.as_u16() == 429 {
230            return true;
231        }
232        // 5xx Server errors are transient
233        if status.is_server_error() {
234            return true;
235        }
236        // 4xx Client errors (except 429) are not retryable
237        return false;
238    }
239
240    // Request/body errors are generally not retryable
241    if error.is_request() || error.is_body() {
242        return false;
243    }
244
245    // Default: assume transient for unknown errors
246    true
247}
248
249/// Executes operations with retry logic using exponential backoff.
250#[derive(Debug, Clone)]
251pub struct RetryExecutor<C: RetryClassifier> {
252    policy: RetryPolicy,
253    classifier: C,
254}
255
256impl RetryExecutor<NetworkRetryClassifier> {
257    /// Creates a new executor with the default network retry classifier.
258    pub fn new(policy: RetryPolicy) -> Self {
259        Self {
260            policy,
261            classifier: NetworkRetryClassifier::new(),
262        }
263    }
264}
265
266impl<C: RetryClassifier> RetryExecutor<C> {
267    /// Creates a new executor with a custom classifier.
268    pub fn with_classifier(policy: RetryPolicy, classifier: C) -> Self {
269        Self { policy, classifier }
270    }
271
272    /// Executes an async operation with retry logic.
273    ///
274    /// The operation will be retried up to `max_attempts` times if it fails
275    /// with a retryable error. Delays between retries follow exponential
276    /// backoff with optional jitter.
277    pub async fn execute<F, Fut, T>(&self, mut operation: F) -> Result<T>
278    where
279        F: FnMut() -> Fut,
280        Fut: Future<Output = Result<T>>,
281    {
282        let mut last_error: Option<SeerError> = None;
283        let mut attempt = 0;
284
285        while attempt < self.policy.max_attempts {
286            match operation().await {
287                Ok(result) => return Ok(result),
288                Err(e) => {
289                    let is_retryable = self.classifier.is_retryable(&e);
290                    let attempts_remaining = self.policy.max_attempts - attempt - 1;
291
292                    if !is_retryable || attempts_remaining == 0 {
293                        if attempt > 0 {
294                            debug!(
295                                attempt = attempt + 1,
296                                max_attempts = self.policy.max_attempts,
297                                error = %e,
298                                "Operation failed after retries"
299                            );
300                        }
301                        return Err(if attempt > 0 {
302                            SeerError::RetryExhausted {
303                                attempts: attempt + 1,
304                                last_error: Box::new(e),
305                            }
306                        } else {
307                            e
308                        });
309                    }
310
311                    let delay = self.policy.delay_for_attempt(attempt);
312                    debug!(
313                        attempt = attempt + 1,
314                        max_attempts = self.policy.max_attempts,
315                        delay_ms = delay.as_millis(),
316                        error = %e,
317                        "Retrying after transient error"
318                    );
319
320                    last_error = Some(e);
321                    tokio::time::sleep(delay).await;
322                    attempt += 1;
323                }
324            }
325        }
326
327        // Should not reach here, but handle it gracefully
328        Err(last_error.unwrap_or_else(|| SeerError::Other("retry loop exited unexpectedly".into())))
329    }
330
331    /// Executes an async operation once without retries.
332    /// Useful for operations that should not be retried.
333    pub async fn execute_once<F, Fut, T>(&self, operation: F) -> Result<T>
334    where
335        F: FnOnce() -> Fut,
336        Fut: Future<Output = Result<T>>,
337    {
338        operation().await
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// The default policy carries the stock values.
    #[test]
    fn test_retry_policy_defaults() {
        let RetryPolicy {
            max_attempts,
            initial_delay,
            max_delay,
            multiplier,
            jitter,
        } = RetryPolicy::default();

        assert_eq!(max_attempts, 3);
        assert_eq!(initial_delay, Duration::from_millis(100));
        assert_eq!(max_delay, Duration::from_secs(5));
        assert_eq!(multiplier, 2.0);
        assert!(jitter);
    }

    /// Every builder method stores the value it was given.
    #[test]
    fn test_retry_policy_builder() {
        let built = RetryPolicy::new()
            .with_max_attempts(5)
            .with_initial_delay(Duration::from_millis(200))
            .with_max_delay(Duration::from_secs(10))
            .with_multiplier(3.0)
            .with_jitter(false);

        let RetryPolicy {
            max_attempts,
            initial_delay,
            max_delay,
            multiplier,
            jitter,
        } = built;

        assert_eq!(max_attempts, 5);
        assert_eq!(initial_delay, Duration::from_millis(200));
        assert_eq!(max_delay, Duration::from_secs(10));
        assert_eq!(multiplier, 3.0);
        assert!(!jitter);
    }

    /// Without jitter the delay doubles with every attempt.
    #[test]
    fn test_delay_calculation_no_jitter() {
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(2.0)
            .with_max_delay(Duration::from_secs(10))
            .with_jitter(false);

        // Index in the table == attempt number.
        let expected_ms = [100u64, 200, 400, 800];
        for (attempt, ms) in expected_ms.iter().enumerate() {
            assert_eq!(policy.delay_for_attempt(attempt), Duration::from_millis(*ms));
        }
    }

    /// Exponential growth stops at `max_delay`.
    #[test]
    fn test_delay_capped_at_max() {
        let ceiling = Duration::from_secs(5);
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_secs(1))
            .with_multiplier(10.0)
            .with_max_delay(ceiling)
            .with_jitter(false);

        // 1s * 10^2 = 100s, but capped at 5s
        assert_eq!(policy.delay_for_attempt(2), ceiling);
    }

    #[test]
    fn test_classifier_timeout_is_retryable() {
        let error = SeerError::Timeout("test".to_string());
        assert!(NetworkRetryClassifier::new().is_retryable(&error));
    }

    #[test]
    fn test_classifier_invalid_domain_not_retryable() {
        let error = SeerError::InvalidDomain("test".to_string());
        assert!(!NetworkRetryClassifier::new().is_retryable(&error));
    }

    #[test]
    fn test_classifier_server_not_found_not_retryable() {
        let error = SeerError::WhoisServerNotFound("test".to_string());
        assert!(!NetworkRetryClassifier::new().is_retryable(&error));
    }

    #[test]
    fn test_classifier_rate_limited_is_retryable() {
        let error = SeerError::RateLimited("test".to_string());
        assert!(NetworkRetryClassifier::new().is_retryable(&error));
    }

    /// A successful first call is returned immediately, with no retry.
    #[tokio::test]
    async fn test_executor_success_on_first_try() {
        let executor = RetryExecutor::new(RetryPolicy::new().with_max_attempts(3));
        let calls = Arc::new(AtomicUsize::new(0));

        let shared = Arc::clone(&calls);
        let outcome: Result<&str> = executor
            .execute(move || {
                let calls = Arc::clone(&shared);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Ok("success")
                }
            })
            .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "success");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Two transient failures followed by a success take three calls.
    #[tokio::test]
    async fn test_executor_retries_on_transient_error() {
        let executor = RetryExecutor::new(
            RetryPolicy::new()
                .with_max_attempts(3)
                .with_initial_delay(Duration::from_millis(1))
                .with_jitter(false),
        );
        let calls = Arc::new(AtomicUsize::new(0));

        let shared = Arc::clone(&calls);
        let outcome: Result<&str> = executor
            .execute(move || {
                let calls = Arc::clone(&shared);
                async move {
                    let previous = calls.fetch_add(1, Ordering::SeqCst);
                    if previous < 2 {
                        Err(SeerError::Timeout("test timeout".to_string()))
                    } else {
                        Ok("success after retries")
                    }
                }
            })
            .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "success after retries");
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// A permanent error stops the executor after a single call.
    #[tokio::test]
    async fn test_executor_no_retry_on_non_retryable_error() {
        let executor = RetryExecutor::new(
            RetryPolicy::new()
                .with_max_attempts(3)
                .with_initial_delay(Duration::from_millis(1)),
        );
        let calls = Arc::new(AtomicUsize::new(0));

        let shared = Arc::clone(&calls);
        let outcome: Result<&str> = executor
            .execute(move || {
                let calls = Arc::clone(&shared);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err(SeerError::InvalidDomain("bad.".to_string()))
                }
            })
            .await;

        assert!(outcome.is_err());
        // Should only attempt once since InvalidDomain is not retryable
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// A permanently failing transient error uses up every attempt and is
    /// reported as `RetryExhausted`.
    #[tokio::test]
    async fn test_executor_exhausts_retries() {
        let executor = RetryExecutor::new(
            RetryPolicy::new()
                .with_max_attempts(3)
                .with_initial_delay(Duration::from_millis(1))
                .with_jitter(false),
        );
        let calls = Arc::new(AtomicUsize::new(0));

        let shared = Arc::clone(&calls);
        let outcome: Result<&str> = executor
            .execute(move || {
                let calls = Arc::clone(&shared);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err(SeerError::Timeout("always fails".to_string()))
                }
            })
            .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 3);

        // Check that we get RetryExhausted error
        match outcome.unwrap_err() {
            SeerError::RetryExhausted { attempts, .. } => assert_eq!(attempts, 3),
            other => panic!("Expected RetryExhausted, got {:?}", other),
        }
    }

    #[test]
    fn test_no_retry_policy() {
        assert_eq!(RetryPolicy::no_retry().max_attempts, 1);
    }

    /// Very large attempt numbers neither panic nor exceed the cap.
    #[test]
    fn test_delay_overflow_protection() {
        let ceiling = Duration::from_secs(5);
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(2.0)
            .with_max_delay(ceiling)
            .with_jitter(false);

        // All should be capped at max_delay due to our overflow protection
        for attempt in [50usize, 100, 1000] {
            assert!(policy.delay_for_attempt(attempt) <= ceiling);
        }
    }

    #[test]
    fn retry_exhausted_is_retryable_if_inner_is() {
        // Regression for H5: the retry classifier must look through a
        // RetryExhausted wrapper so a caller layering retries can still see
        // the true underlying fault classification instead of collapsing to
        // a non-retryable wrapper.
        let classifier = NetworkRetryClassifier::new();
        let wrap = |inner: SeerError| SeerError::RetryExhausted {
            attempts: 3,
            last_error: Box::new(inner),
        };

        let wrapped_retryable = wrap(SeerError::Timeout("inner timed out".to_string()));
        assert!(
            classifier.is_retryable(&wrapped_retryable),
            "RetryExhausted wrapping a retryable Timeout should be retryable",
        );

        let wrapped_non_retryable = wrap(SeerError::InvalidDomain("bad.".to_string()));
        assert!(
            !classifier.is_retryable(&wrapped_non_retryable),
            "RetryExhausted wrapping a non-retryable InvalidDomain must not be retryable",
        );
    }
}