Skip to main content

seer_core/
retry.rs

1//! Retry logic with exponential backoff for transient failures.
2//!
3//! This module provides configurable retry policies and executors for handling
4//! transient network failures in WHOIS and RDAP lookups.
5
6use std::future::Future;
7use std::time::Duration;
8
9use rand::Rng;
10use tracing::{debug, warn};
11
12use crate::error::{Result, SeerError};
13
/// Retry settings that drive exponential backoff between attempts.
///
/// Build one with [`RetryPolicy::new`] and the `with_*` methods, or fill in
/// the public fields directly.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Total number of tries allowed, counting the very first one.
    pub max_attempts: usize,
    /// How long to wait ahead of the first retry.
    pub initial_delay: Duration,
    /// Upper bound on the wait between retries; stops exponential growth.
    pub max_delay: Duration,
    /// Backoff growth factor: the delay is multiplied by this after each retry.
    pub multiplier: f64,
    /// When `true`, delays are randomized to avoid a thundering herd.
    pub jitter: bool,
}
28
29impl Default for RetryPolicy {
30    fn default() -> Self {
31        Self {
32            max_attempts: 3,
33            initial_delay: Duration::from_millis(100),
34            max_delay: Duration::from_secs(5),
35            multiplier: 2.0,
36            jitter: true,
37        }
38    }
39}
40
41impl RetryPolicy {
42    /// Creates a new retry policy with default settings.
43    pub fn new() -> Self {
44        Self::default()
45    }
46
47    /// Sets the maximum number of attempts.
48    pub fn with_max_attempts(mut self, attempts: usize) -> Self {
49        self.max_attempts = attempts.max(1);
50        self
51    }
52
53    /// Sets the initial delay before the first retry.
54    pub fn with_initial_delay(mut self, delay: Duration) -> Self {
55        self.initial_delay = delay;
56        self
57    }
58
59    /// Sets the maximum delay between retries.
60    pub fn with_max_delay(mut self, delay: Duration) -> Self {
61        self.max_delay = delay;
62        self
63    }
64
65    /// Sets the multiplier for exponential backoff.
66    pub fn with_multiplier(mut self, multiplier: f64) -> Self {
67        self.multiplier = multiplier.max(1.0);
68        self
69    }
70
71    /// Enables or disables jitter.
72    pub fn with_jitter(mut self, jitter: bool) -> Self {
73        self.jitter = jitter;
74        self
75    }
76
77    /// Creates a policy that disables retries (single attempt only).
78    pub fn no_retry() -> Self {
79        Self {
80            max_attempts: 1,
81            ..Self::default()
82        }
83    }
84
85    /// Calculates the delay for a given attempt number (0-indexed).
86    ///
87    /// The attempt number is internally capped to prevent integer overflow
88    /// in the exponential calculation.
89    pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
90        if attempt == 0 {
91            return self.initial_delay;
92        }
93
94        // Cap attempt to prevent overflow in powi() - 20 attempts with multiplier 2.0
95        // gives 2^20 = ~1 million, which is safe for f64 and reasonable for delays
96        let safe_attempt = attempt.min(20) as i32;
97
98        let base_delay = self.initial_delay.as_millis() as f64 * self.multiplier.powi(safe_attempt);
99        let capped_delay = base_delay.min(self.max_delay.as_millis() as f64);
100
101        let final_delay = if self.jitter {
102            // Add jitter: random value between 50% and 100% of the delay
103            let mut rng = rand::thread_rng();
104            let jitter_factor = rng.gen_range(0.5..1.0);
105            capped_delay * jitter_factor
106        } else {
107            capped_delay
108        };
109
110        Duration::from_millis(final_delay as u64)
111    }
112}
113
114/// Trait for classifying whether an error is retryable.
115pub trait RetryClassifier: Send + Sync {
116    /// Returns true if the error is transient and the operation should be retried.
117    fn is_retryable(&self, error: &SeerError) -> bool;
118}
119
/// Stock [`RetryClassifier`] for network lookups (WHOIS/RDAP).
///
/// Treated as retryable:
/// - timeouts, failed WHOIS connections and rate limiting
/// - reqwest errors that look transient (connect/timeout, HTTP 429, HTTP 5xx)
/// - WHOIS, RDAP, RDAP bootstrap, DNS and HTTP errors whose message contains
///   a transient-looking keyword (e.g. "timeout", "connection", "status 5")
///
/// Treated as permanent:
/// - invalid input (domain, IP address, record type)
/// - unknown WHOIS server
/// - parse, certificate/SSL and configuration errors, and everything else
#[derive(Clone, Debug, Default)]
pub struct NetworkRetryClassifier;
134
135impl NetworkRetryClassifier {
136    pub fn new() -> Self {
137        Self
138    }
139}
140
141impl RetryClassifier for NetworkRetryClassifier {
142    fn is_retryable(&self, error: &SeerError) -> bool {
143        match error {
144            // Transient errors - worth retrying
145            SeerError::Timeout(_) => true,
146            SeerError::WhoisConnectionFailed(_) => true,
147            SeerError::RateLimited(_) => true,
148
149            // Reqwest errors need deeper inspection
150            SeerError::ReqwestError(e) => is_transient_reqwest_error(e),
151
152            // WHOIS errors might be transient if they're connection-related
153            SeerError::WhoisError(msg) => {
154                let lower = msg.to_lowercase();
155                lower.contains("connection")
156                    || lower.contains("timeout")
157                    || lower.contains("refused")
158                    || lower.contains("reset")
159            }
160
161            // RDAP errors might be transient server errors
162            SeerError::RdapError(msg) => {
163                let lower = msg.to_lowercase();
164                lower.contains("status 5")
165                    || lower.contains("status 429")
166                    || lower.contains("timeout")
167            }
168
169            // Bootstrap errors could be transient if IANA is temporarily unavailable
170            SeerError::RdapBootstrapError(msg) => {
171                let lower = msg.to_lowercase();
172                lower.contains("timeout") || lower.contains("connection")
173            }
174
175            // DNS errors can be transient
176            SeerError::DnsError(msg) => {
177                let lower = msg.to_lowercase();
178                lower.contains("timeout") || lower.contains("temporary")
179            }
180
181            // HTTP errors might be transient (server errors 5xx or 429 Too Many Requests)
182            SeerError::HttpError(msg) => {
183                let lower = msg.to_lowercase();
184                lower.contains("timeout")
185                    || lower.contains("connection")
186                    || lower.contains("status 5")
187                    || lower.contains("status 429")
188            }
189
190            // Not retryable - permanent failures
191            SeerError::InvalidDomain(_) => false,
192            SeerError::DomainNotAllowed { .. } => false,
193            SeerError::InvalidIpAddress(_) => false,
194            SeerError::InvalidRecordType(_) => false,
195            SeerError::WhoisServerNotFound(_) => false,
196            SeerError::JsonError(_) => false,
197            SeerError::CertificateError(_) => false,
198            SeerError::SslError(_) => false,
199            SeerError::DnsResolverError(_) => false,
200            SeerError::BulkOperationError { .. } => false,
201            SeerError::LookupFailed { .. } => false,
202            SeerError::ConfigError(_) => false,
203            SeerError::InvalidInput(_) => false,
204            SeerError::RetryExhausted { .. } => false,
205            SeerError::Other(_) => false,
206        }
207    }
208}
209
210/// Checks if a reqwest error is transient and worth retrying.
211fn is_transient_reqwest_error(error: &reqwest::Error) -> bool {
212    // Connection errors are transient
213    if error.is_connect() {
214        return true;
215    }
216
217    // Timeout errors are transient
218    if error.is_timeout() {
219        return true;
220    }
221
222    // Check HTTP status codes
223    if let Some(status) = error.status() {
224        // 429 Too Many Requests - rate limited, retry with backoff
225        if status.as_u16() == 429 {
226            return true;
227        }
228        // 5xx Server errors are transient
229        if status.is_server_error() {
230            return true;
231        }
232        // 4xx Client errors (except 429) are not retryable
233        return false;
234    }
235
236    // Request/body errors are generally not retryable
237    if error.is_request() || error.is_body() {
238        return false;
239    }
240
241    // Default: assume transient for unknown errors
242    true
243}
244
245/// Executes operations with retry logic using exponential backoff.
246#[derive(Debug, Clone)]
247pub struct RetryExecutor<C: RetryClassifier> {
248    policy: RetryPolicy,
249    classifier: C,
250}
251
252impl RetryExecutor<NetworkRetryClassifier> {
253    /// Creates a new executor with the default network retry classifier.
254    pub fn new(policy: RetryPolicy) -> Self {
255        Self {
256            policy,
257            classifier: NetworkRetryClassifier::new(),
258        }
259    }
260}
261
262impl<C: RetryClassifier> RetryExecutor<C> {
263    /// Creates a new executor with a custom classifier.
264    pub fn with_classifier(policy: RetryPolicy, classifier: C) -> Self {
265        Self { policy, classifier }
266    }
267
268    /// Executes an async operation with retry logic.
269    ///
270    /// The operation will be retried up to `max_attempts` times if it fails
271    /// with a retryable error. Delays between retries follow exponential
272    /// backoff with optional jitter.
273    pub async fn execute<F, Fut, T>(&self, mut operation: F) -> Result<T>
274    where
275        F: FnMut() -> Fut,
276        Fut: Future<Output = Result<T>>,
277    {
278        let mut last_error: Option<SeerError> = None;
279        let mut attempt = 0;
280
281        while attempt < self.policy.max_attempts {
282            match operation().await {
283                Ok(result) => return Ok(result),
284                Err(e) => {
285                    let is_retryable = self.classifier.is_retryable(&e);
286                    let attempts_remaining = self.policy.max_attempts - attempt - 1;
287
288                    if !is_retryable || attempts_remaining == 0 {
289                        if attempt > 0 {
290                            warn!(
291                                attempt = attempt + 1,
292                                max_attempts = self.policy.max_attempts,
293                                error = %e,
294                                "Operation failed after retries"
295                            );
296                        }
297                        return Err(if attempt > 0 {
298                            SeerError::RetryExhausted {
299                                attempts: attempt + 1,
300                                last_error: e.to_string(),
301                            }
302                        } else {
303                            e
304                        });
305                    }
306
307                    let delay = self.policy.delay_for_attempt(attempt);
308                    debug!(
309                        attempt = attempt + 1,
310                        max_attempts = self.policy.max_attempts,
311                        delay_ms = delay.as_millis(),
312                        error = %e,
313                        "Retrying after transient error"
314                    );
315
316                    last_error = Some(e);
317                    tokio::time::sleep(delay).await;
318                    attempt += 1;
319                }
320            }
321        }
322
323        // Should not reach here, but handle it gracefully
324        Err(last_error.unwrap_or_else(|| SeerError::Other("retry loop exited unexpectedly".into())))
325    }
326
327    /// Executes an async operation once without retries.
328    /// Useful for operations that should not be retried.
329    pub async fn execute_once<F, Fut, T>(&self, operation: F) -> Result<T>
330    where
331        F: FnOnce() -> Fut,
332        Fut: Future<Output = Result<T>>,
333    {
334        operation().await
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// The stock policy exposes the documented default values.
    #[test]
    fn test_retry_policy_defaults() {
        let RetryPolicy {
            max_attempts,
            initial_delay,
            max_delay,
            multiplier,
            jitter,
        } = RetryPolicy::default();

        assert_eq!(max_attempts, 3);
        assert_eq!(initial_delay, Duration::from_millis(100));
        assert_eq!(max_delay, Duration::from_secs(5));
        assert_eq!(multiplier, 2.0);
        assert!(jitter);
    }

    /// Every builder method stores the value it was given.
    #[test]
    fn test_retry_policy_builder() {
        let RetryPolicy {
            max_attempts,
            initial_delay,
            max_delay,
            multiplier,
            jitter,
        } = RetryPolicy::new()
            .with_max_attempts(5)
            .with_initial_delay(Duration::from_millis(200))
            .with_max_delay(Duration::from_secs(10))
            .with_multiplier(3.0)
            .with_jitter(false);

        assert_eq!(max_attempts, 5);
        assert_eq!(initial_delay, Duration::from_millis(200));
        assert_eq!(max_delay, Duration::from_secs(10));
        assert_eq!(multiplier, 3.0);
        assert!(!jitter);
    }

    /// Without jitter the delay doubles on every attempt.
    #[test]
    fn test_delay_calculation_no_jitter() {
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(2.0)
            .with_max_delay(Duration::from_secs(10))
            .with_jitter(false);

        let expected_millis = [100, 200, 400, 800];
        for (attempt, millis) in expected_millis.iter().enumerate() {
            assert_eq!(
                policy.delay_for_attempt(attempt),
                Duration::from_millis(*millis)
            );
        }
    }

    /// Growth stops at `max_delay`.
    #[test]
    fn test_delay_capped_at_max() {
        let ceiling = Duration::from_secs(5);
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_secs(1))
            .with_multiplier(10.0)
            .with_max_delay(ceiling)
            .with_jitter(false);

        // 1s * 10^2 = 100s, which is far above the 5s ceiling.
        assert_eq!(policy.delay_for_attempt(2), ceiling);
    }

    #[test]
    fn test_classifier_timeout_is_retryable() {
        let error = SeerError::Timeout(String::from("test"));
        assert!(NetworkRetryClassifier::new().is_retryable(&error));
    }

    #[test]
    fn test_classifier_invalid_domain_not_retryable() {
        let error = SeerError::InvalidDomain(String::from("test"));
        assert!(!NetworkRetryClassifier::new().is_retryable(&error));
    }

    #[test]
    fn test_classifier_server_not_found_not_retryable() {
        let error = SeerError::WhoisServerNotFound(String::from("test"));
        assert!(!NetworkRetryClassifier::new().is_retryable(&error));
    }

    #[test]
    fn test_classifier_rate_limited_is_retryable() {
        let error = SeerError::RateLimited(String::from("test"));
        assert!(NetworkRetryClassifier::new().is_retryable(&error));
    }

    /// A successful first call is returned as-is and never repeated.
    #[tokio::test]
    async fn test_executor_success_on_first_try() {
        let executor = RetryExecutor::new(RetryPolicy::new().with_max_attempts(3));
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome: Result<&str> = executor
            .execute(|| {
                let calls = Arc::clone(&calls);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Ok("success")
                }
            })
            .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "success");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Two transient failures followed by a success take three calls.
    #[tokio::test]
    async fn test_executor_retries_on_transient_error() {
        let executor = RetryExecutor::new(
            RetryPolicy::new()
                .with_max_attempts(3)
                .with_initial_delay(Duration::from_millis(1))
                .with_jitter(false),
        );
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome: Result<&str> = executor
            .execute(|| {
                let calls = Arc::clone(&calls);
                async move {
                    let previous_calls = calls.fetch_add(1, Ordering::SeqCst);
                    if previous_calls >= 2 {
                        Ok("success after retries")
                    } else {
                        Err(SeerError::Timeout("test timeout".to_string()))
                    }
                }
            })
            .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "success after retries");
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// A permanent error stops the executor after a single call.
    #[tokio::test]
    async fn test_executor_no_retry_on_non_retryable_error() {
        let executor = RetryExecutor::new(
            RetryPolicy::new()
                .with_max_attempts(3)
                .with_initial_delay(Duration::from_millis(1)),
        );
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome: Result<&str> = executor
            .execute(|| {
                let calls = Arc::clone(&calls);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err(SeerError::InvalidDomain("bad.".to_string()))
                }
            })
            .await;

        assert!(outcome.is_err());
        // InvalidDomain is not retryable, so exactly one call is expected.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// A persistently failing operation uses every attempt, then reports it.
    #[tokio::test]
    async fn test_executor_exhausts_retries() {
        let executor = RetryExecutor::new(
            RetryPolicy::new()
                .with_max_attempts(3)
                .with_initial_delay(Duration::from_millis(1))
                .with_jitter(false),
        );
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome: Result<&str> = executor
            .execute(|| {
                let calls = Arc::clone(&calls);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err(SeerError::Timeout("always fails".to_string()))
                }
            })
            .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 3);

        // The final error must be RetryExhausted with the attempt count.
        match outcome.unwrap_err() {
            SeerError::RetryExhausted { attempts, .. } => assert_eq!(attempts, 3),
            other => panic!("Expected RetryExhausted, got {:?}", other),
        }
    }

    #[test]
    fn test_no_retry_policy() {
        assert_eq!(RetryPolicy::no_retry().max_attempts, 1);
    }

    /// Huge attempt numbers must neither panic nor exceed `max_delay`.
    #[test]
    fn test_delay_overflow_protection() {
        let ceiling = Duration::from_secs(5);
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(2.0)
            .with_max_delay(ceiling)
            .with_jitter(false);

        for &attempt in &[50usize, 100, 1000] {
            assert!(policy.delay_for_attempt(attempt) <= ceiling);
        }
    }
}