// seer_core/retry.rs
1//! Retry logic with exponential backoff for transient failures.
2//!
3//! This module provides configurable retry policies and executors for handling
4//! transient network failures in WHOIS and RDAP lookups.
5
6use std::future::Future;
7use std::time::Duration;
8
9use rand::Rng;
10use tracing::{debug, warn};
11
12use crate::error::{Result, SeerError};
13
/// Configuration for retry behavior with exponential backoff.
///
/// Delays grow as `initial_delay * multiplier^attempt`, capped at
/// `max_delay`, with optional random jitter (see `delay_for_attempt`).
/// Build via the `with_*` methods or use `Default`
/// (3 attempts, 100 ms initial delay, 5 s cap, 2.0 multiplier, jitter on).
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Maximum number of attempts (including the initial attempt).
    pub max_attempts: usize,
    /// Initial delay before the first retry.
    pub initial_delay: Duration,
    /// Maximum delay between retries (caps exponential growth).
    pub max_delay: Duration,
    /// Multiplier for exponential backoff (delay *= multiplier after each retry).
    pub multiplier: f64,
    /// Whether to add random jitter to delays to avoid thundering herd.
    pub jitter: bool,
}
28
29impl Default for RetryPolicy {
30    fn default() -> Self {
31        Self {
32            max_attempts: 3,
33            initial_delay: Duration::from_millis(100),
34            max_delay: Duration::from_secs(5),
35            multiplier: 2.0,
36            jitter: true,
37        }
38    }
39}
40
41impl RetryPolicy {
42    /// Creates a new retry policy with default settings.
43    pub fn new() -> Self {
44        Self::default()
45    }
46
47    /// Sets the maximum number of attempts.
48    pub fn with_max_attempts(mut self, attempts: usize) -> Self {
49        self.max_attempts = attempts.max(1);
50        self
51    }
52
53    /// Sets the initial delay before the first retry.
54    pub fn with_initial_delay(mut self, delay: Duration) -> Self {
55        self.initial_delay = delay;
56        self
57    }
58
59    /// Sets the maximum delay between retries.
60    pub fn with_max_delay(mut self, delay: Duration) -> Self {
61        self.max_delay = delay;
62        self
63    }
64
65    /// Sets the multiplier for exponential backoff.
66    pub fn with_multiplier(mut self, multiplier: f64) -> Self {
67        self.multiplier = multiplier.max(1.0);
68        self
69    }
70
71    /// Enables or disables jitter.
72    pub fn with_jitter(mut self, jitter: bool) -> Self {
73        self.jitter = jitter;
74        self
75    }
76
77    /// Creates a policy that disables retries (single attempt only).
78    pub fn no_retry() -> Self {
79        Self {
80            max_attempts: 1,
81            ..Self::default()
82        }
83    }
84
85    /// Calculates the delay for a given attempt number (0-indexed).
86    ///
87    /// The attempt number is internally capped to prevent integer overflow
88    /// in the exponential calculation.
89    pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
90        if attempt == 0 {
91            return self.initial_delay;
92        }
93
94        // Cap attempt to prevent overflow in powi() - 20 attempts with multiplier 2.0
95        // gives 2^20 = ~1 million, which is safe for f64 and reasonable for delays
96        let safe_attempt = attempt.min(20) as i32;
97
98        let base_delay = self.initial_delay.as_millis() as f64 * self.multiplier.powi(safe_attempt);
99        let capped_delay = base_delay.min(self.max_delay.as_millis() as f64);
100
101        let final_delay = if self.jitter {
102            // Add jitter: random value between 50% and 100% of the delay
103            let mut rng = rand::thread_rng();
104            let jitter_factor = rng.gen_range(0.5..1.0);
105            capped_delay * jitter_factor
106        } else {
107            capped_delay
108        };
109
110        Duration::from_millis(final_delay as u64)
111    }
112}
113
/// Trait for classifying whether an error is retryable.
///
/// Bounded by `Send + Sync` so classifiers can be shared across async tasks.
pub trait RetryClassifier: Send + Sync {
    /// Returns true if the error is transient and the operation should be retried.
    fn is_retryable(&self, error: &SeerError) -> bool;
}
119
/// Default classifier for network operations (WHOIS/RDAP).
///
/// Classifies the following as retryable:
/// - Timeouts
/// - Connection failures (IO errors)
/// - Rate limiting (429)
/// - Server errors (5xx)
///
/// Non-retryable errors:
/// - Invalid input (domain, IP, record type)
/// - Server not found
/// - Parse errors (JSON, WHOIS format)
///
/// Unit struct: carries no state, so it is free to clone and share.
#[derive(Debug, Clone, Default)]
pub struct NetworkRetryClassifier;
134
impl NetworkRetryClassifier {
    /// Creates a new classifier instance (equivalent to `Self::default()`).
    pub fn new() -> Self {
        Self
    }
}
140
141impl RetryClassifier for NetworkRetryClassifier {
142    fn is_retryable(&self, error: &SeerError) -> bool {
143        match error {
144            // Transient errors - worth retrying
145            SeerError::Timeout(_) => true,
146            SeerError::WhoisConnectionFailed(_) => true,
147            SeerError::RateLimited(_) => true,
148
149            // Reqwest errors need deeper inspection
150            SeerError::ReqwestError(e) => is_transient_reqwest_error(e),
151
152            // WHOIS errors might be transient if they're connection-related
153            SeerError::WhoisError(msg) => {
154                let lower = msg.to_lowercase();
155                lower.contains("connection")
156                    || lower.contains("timeout")
157                    || lower.contains("refused")
158                    || lower.contains("reset")
159            }
160
161            // RDAP errors might be transient server errors
162            SeerError::RdapError(msg) => {
163                let lower = msg.to_lowercase();
164                lower.contains("status 5")
165                    || lower.contains("status 429")
166                    || lower.contains("timeout")
167            }
168
169            // Bootstrap errors could be transient if IANA is temporarily unavailable
170            SeerError::RdapBootstrapError(msg) => {
171                let lower = msg.to_lowercase();
172                lower.contains("timeout") || lower.contains("connection")
173            }
174
175            // DNS errors can be transient
176            SeerError::DnsError(msg) => {
177                let lower = msg.to_lowercase();
178                lower.contains("timeout") || lower.contains("temporary")
179            }
180
181            // HTTP errors might be transient (server errors 5xx or 429 Too Many Requests)
182            SeerError::HttpError(msg) => {
183                let lower = msg.to_lowercase();
184                lower.contains("timeout")
185                    || lower.contains("connection")
186                    || lower.contains("status 5")
187                    || lower.contains("status 429")
188            }
189
190            // Not retryable - permanent failures
191            SeerError::InvalidDomain(_) => false,
192            SeerError::DomainNotAllowed { .. } => false,
193            SeerError::InvalidIpAddress(_) => false,
194            SeerError::InvalidRecordType(_) => false,
195            SeerError::WhoisServerNotFound(_) => false,
196            SeerError::JsonError(_) => false,
197            SeerError::CertificateError(_) => false,
198            SeerError::SslError(_) => false,
199            SeerError::DnsResolverError(_) => false,
200            SeerError::BulkOperationError { .. } => false,
201            SeerError::LookupFailed { .. } => false,
202            SeerError::ConfigError(_) => false,
203            SeerError::RetryExhausted { .. } => false,
204            SeerError::Other(_) => false,
205        }
206    }
207}
208
209/// Checks if a reqwest error is transient and worth retrying.
210fn is_transient_reqwest_error(error: &reqwest::Error) -> bool {
211    // Connection errors are transient
212    if error.is_connect() {
213        return true;
214    }
215
216    // Timeout errors are transient
217    if error.is_timeout() {
218        return true;
219    }
220
221    // Check HTTP status codes
222    if let Some(status) = error.status() {
223        // 429 Too Many Requests - rate limited, retry with backoff
224        if status.as_u16() == 429 {
225            return true;
226        }
227        // 5xx Server errors are transient
228        if status.is_server_error() {
229            return true;
230        }
231        // 4xx Client errors (except 429) are not retryable
232        return false;
233    }
234
235    // Request/body errors are generally not retryable
236    if error.is_request() || error.is_body() {
237        return false;
238    }
239
240    // Default: assume transient for unknown errors
241    true
242}
243
/// Executes operations with retry logic using exponential backoff.
///
/// Pairs a [`RetryPolicy`] (how many attempts, how long to wait) with a
/// [`RetryClassifier`] (which errors are worth retrying).
#[derive(Debug, Clone)]
pub struct RetryExecutor<C: RetryClassifier> {
    policy: RetryPolicy,
    classifier: C,
}
250
251impl RetryExecutor<NetworkRetryClassifier> {
252    /// Creates a new executor with the default network retry classifier.
253    pub fn new(policy: RetryPolicy) -> Self {
254        Self {
255            policy,
256            classifier: NetworkRetryClassifier::new(),
257        }
258    }
259}
260
impl<C: RetryClassifier> RetryExecutor<C> {
    /// Creates a new executor with a custom classifier.
    pub fn with_classifier(policy: RetryPolicy, classifier: C) -> Self {
        Self { policy, classifier }
    }

    /// Executes an async operation with retry logic.
    ///
    /// The operation will be retried up to `max_attempts` times if it fails
    /// with a retryable error. Delays between retries follow exponential
    /// backoff with optional jitter.
    ///
    /// # Errors
    ///
    /// - A failure on the very first attempt (no retries performed yet) is
    ///   returned unchanged.
    /// - A failure after at least one retry is wrapped in
    ///   `SeerError::RetryExhausted`, recording the total attempt count and
    ///   the last error's message.
    pub async fn execute<F, Fut, T>(&self, mut operation: F) -> Result<T>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = Result<T>>,
    {
        let mut last_error: Option<SeerError> = None;
        // 0-indexed: attempt 0 is the initial try, 1.. are retries.
        let mut attempt = 0;

        while attempt < self.policy.max_attempts {
            match operation().await {
                Ok(result) => return Ok(result),
                Err(e) => {
                    let is_retryable = self.classifier.is_retryable(&e);
                    let attempts_remaining = self.policy.max_attempts - attempt - 1;

                    // Give up when the error is permanent or the budget is spent.
                    if !is_retryable || attempts_remaining == 0 {
                        if attempt > 0 {
                            warn!(
                                attempt = attempt + 1,
                                max_attempts = self.policy.max_attempts,
                                error = %e,
                                "Operation failed after retries"
                            );
                        }
                        // Only wrap in RetryExhausted if we actually retried;
                        // a first-attempt failure propagates as-is so callers
                        // can still match on the concrete error variant.
                        return Err(if attempt > 0 {
                            SeerError::RetryExhausted {
                                attempts: attempt + 1,
                                last_error: e.to_string(),
                            }
                        } else {
                            e
                        });
                    }

                    let delay = self.policy.delay_for_attempt(attempt);
                    debug!(
                        attempt = attempt + 1,
                        max_attempts = self.policy.max_attempts,
                        delay_ms = delay.as_millis(),
                        error = %e,
                        "Retrying after transient error"
                    );

                    last_error = Some(e);
                    tokio::time::sleep(delay).await;
                    attempt += 1;
                }
            }
        }

        // Should not reach here (the loop always returns), but handle it
        // gracefully instead of panicking if the invariant is ever broken.
        Err(last_error.unwrap_or_else(|| SeerError::Other("retry loop exited unexpectedly".into())))
    }

    /// Executes an async operation once without retries.
    /// Useful for operations that should not be retried.
    pub async fn execute_once<F, Fut, T>(&self, operation: F) -> Result<T>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T>>,
    {
        operation().await
    }
}
336
//! Unit tests: policy construction, delay math (no-jitter and overflow
//! paths), error classification, and executor retry/exhaustion behavior.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_retry_policy_defaults() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.initial_delay, Duration::from_millis(100));
        assert_eq!(policy.max_delay, Duration::from_secs(5));
        assert_eq!(policy.multiplier, 2.0);
        assert!(policy.jitter);
    }

    #[test]
    fn test_retry_policy_builder() {
        let policy = RetryPolicy::new()
            .with_max_attempts(5)
            .with_initial_delay(Duration::from_millis(200))
            .with_max_delay(Duration::from_secs(10))
            .with_multiplier(3.0)
            .with_jitter(false);

        assert_eq!(policy.max_attempts, 5);
        assert_eq!(policy.initial_delay, Duration::from_millis(200));
        assert_eq!(policy.max_delay, Duration::from_secs(10));
        assert_eq!(policy.multiplier, 3.0);
        assert!(!policy.jitter);
    }

    // Jitter is disabled here so delays are deterministic: 100ms * 2^attempt.
    #[test]
    fn test_delay_calculation_no_jitter() {
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(2.0)
            .with_max_delay(Duration::from_secs(10))
            .with_jitter(false);

        assert_eq!(policy.delay_for_attempt(0), Duration::from_millis(100));
        assert_eq!(policy.delay_for_attempt(1), Duration::from_millis(200));
        assert_eq!(policy.delay_for_attempt(2), Duration::from_millis(400));
        assert_eq!(policy.delay_for_attempt(3), Duration::from_millis(800));
    }

    #[test]
    fn test_delay_capped_at_max() {
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_secs(1))
            .with_multiplier(10.0)
            .with_max_delay(Duration::from_secs(5))
            .with_jitter(false);

        // 1s * 10^2 = 100s, but capped at 5s
        assert_eq!(policy.delay_for_attempt(2), Duration::from_secs(5));
    }

    #[test]
    fn test_classifier_timeout_is_retryable() {
        let classifier = NetworkRetryClassifier::new();
        assert!(classifier.is_retryable(&SeerError::Timeout("test".to_string())));
    }

    #[test]
    fn test_classifier_invalid_domain_not_retryable() {
        let classifier = NetworkRetryClassifier::new();
        assert!(!classifier.is_retryable(&SeerError::InvalidDomain("test".to_string())));
    }

    #[test]
    fn test_classifier_server_not_found_not_retryable() {
        let classifier = NetworkRetryClassifier::new();
        assert!(!classifier.is_retryable(&SeerError::WhoisServerNotFound("test".to_string())));
    }

    #[test]
    fn test_classifier_rate_limited_is_retryable() {
        let classifier = NetworkRetryClassifier::new();
        assert!(classifier.is_retryable(&SeerError::RateLimited("test".to_string())));
    }

    // Executor tests share a pattern: an Arc<AtomicUsize> counts how many
    // times the closure actually ran, which is the observable retry count.
    #[tokio::test]
    async fn test_executor_success_on_first_try() {
        let policy = RetryPolicy::new().with_max_attempts(3);
        let executor = RetryExecutor::new(policy);
        let attempts = Arc::new(AtomicUsize::new(0));

        let attempts_clone = attempts.clone();
        let result: Result<&str> = executor
            .execute(|| {
                let a = attempts_clone.clone();
                async move {
                    a.fetch_add(1, Ordering::SeqCst);
                    Ok("success")
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_executor_retries_on_transient_error() {
        // 1ms delay keeps the test fast; jitter off for determinism.
        let policy = RetryPolicy::new()
            .with_max_attempts(3)
            .with_initial_delay(Duration::from_millis(1))
            .with_jitter(false);
        let executor = RetryExecutor::new(policy);
        let attempts = Arc::new(AtomicUsize::new(0));

        let attempts_clone = attempts.clone();
        let result: Result<&str> = executor
            .execute(|| {
                let a = attempts_clone.clone();
                async move {
                    let count = a.fetch_add(1, Ordering::SeqCst);
                    // Fail the first two attempts, succeed on the third.
                    if count < 2 {
                        Err(SeerError::Timeout("test timeout".to_string()))
                    } else {
                        Ok("success after retries")
                    }
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success after retries");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_executor_no_retry_on_non_retryable_error() {
        let policy = RetryPolicy::new()
            .with_max_attempts(3)
            .with_initial_delay(Duration::from_millis(1));
        let executor = RetryExecutor::new(policy);
        let attempts = Arc::new(AtomicUsize::new(0));

        let attempts_clone = attempts.clone();
        let result: Result<&str> = executor
            .execute(|| {
                let a = attempts_clone.clone();
                async move {
                    a.fetch_add(1, Ordering::SeqCst);
                    Err(SeerError::InvalidDomain("bad.".to_string()))
                }
            })
            .await;

        assert!(result.is_err());
        // Should only attempt once since InvalidDomain is not retryable
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_executor_exhausts_retries() {
        let policy = RetryPolicy::new()
            .with_max_attempts(3)
            .with_initial_delay(Duration::from_millis(1))
            .with_jitter(false);
        let executor = RetryExecutor::new(policy);
        let attempts = Arc::new(AtomicUsize::new(0));

        let attempts_clone = attempts.clone();
        let result: Result<&str> = executor
            .execute(|| {
                let a = attempts_clone.clone();
                async move {
                    a.fetch_add(1, Ordering::SeqCst);
                    Err(SeerError::Timeout("always fails".to_string()))
                }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 3);

        // Check that we get RetryExhausted error
        match result.unwrap_err() {
            SeerError::RetryExhausted { attempts, .. } => {
                assert_eq!(attempts, 3);
            }
            other => panic!("Expected RetryExhausted, got {:?}", other),
        }
    }

    #[test]
    fn test_no_retry_policy() {
        let policy = RetryPolicy::no_retry();
        assert_eq!(policy.max_attempts, 1);
    }

    #[test]
    fn test_delay_overflow_protection() {
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(2.0)
            .with_max_delay(Duration::from_secs(5))
            .with_jitter(false);

        // Test with very large attempt numbers - should not panic or produce invalid durations
        let delay_50 = policy.delay_for_attempt(50);
        let delay_100 = policy.delay_for_attempt(100);
        let delay_1000 = policy.delay_for_attempt(1000);

        // All should be capped at max_delay due to our overflow protection
        assert!(delay_50 <= Duration::from_secs(5));
        assert!(delay_100 <= Duration::from_secs(5));
        assert!(delay_1000 <= Duration::from_secs(5));
    }
}
550}