// seer_core/retry.rs
1//! Retry logic with exponential backoff for transient failures.
2//!
3//! This module provides configurable retry policies and executors for handling
4//! transient network failures in WHOIS and RDAP lookups.
5
6use std::future::Future;
7use std::time::Duration;
8
9use rand::Rng;
10use tracing::{debug, warn};
11
12use crate::error::{Result, SeerError};
13
/// Configuration for retry behavior with exponential backoff.
///
/// Delays grow geometrically (`initial_delay * multiplier^attempt`), are
/// capped at `max_delay`, and may be randomized with jitter so many clients
/// retrying at once do not synchronize their requests.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Maximum number of attempts (including the initial attempt).
    pub max_attempts: usize,
    /// Initial delay before the first retry.
    pub initial_delay: Duration,
    /// Maximum delay between retries (caps exponential growth).
    pub max_delay: Duration,
    /// Multiplier for exponential backoff (delay *= multiplier after each retry).
    pub multiplier: f64,
    /// Whether to add random jitter to delays to avoid thundering herd.
    pub jitter: bool,
}
28
29impl Default for RetryPolicy {
30    fn default() -> Self {
31        Self {
32            max_attempts: 3,
33            initial_delay: Duration::from_millis(100),
34            max_delay: Duration::from_secs(5),
35            multiplier: 2.0,
36            jitter: true,
37        }
38    }
39}
40
41impl RetryPolicy {
42    /// Creates a new retry policy with default settings.
43    pub fn new() -> Self {
44        Self::default()
45    }
46
47    /// Sets the maximum number of attempts.
48    pub fn with_max_attempts(mut self, attempts: usize) -> Self {
49        self.max_attempts = attempts.max(1);
50        self
51    }
52
53    /// Sets the initial delay before the first retry.
54    pub fn with_initial_delay(mut self, delay: Duration) -> Self {
55        self.initial_delay = delay;
56        self
57    }
58
59    /// Sets the maximum delay between retries.
60    pub fn with_max_delay(mut self, delay: Duration) -> Self {
61        self.max_delay = delay;
62        self
63    }
64
65    /// Sets the multiplier for exponential backoff.
66    pub fn with_multiplier(mut self, multiplier: f64) -> Self {
67        self.multiplier = multiplier.max(1.0);
68        self
69    }
70
71    /// Enables or disables jitter.
72    pub fn with_jitter(mut self, jitter: bool) -> Self {
73        self.jitter = jitter;
74        self
75    }
76
77    /// Creates a policy that disables retries (single attempt only).
78    pub fn no_retry() -> Self {
79        Self {
80            max_attempts: 1,
81            ..Self::default()
82        }
83    }
84
85    /// Calculates the delay for a given attempt number (0-indexed).
86    ///
87    /// The attempt number is internally capped to prevent integer overflow
88    /// in the exponential calculation.
89    pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
90        if attempt == 0 {
91            return self.initial_delay;
92        }
93
94        // Cap attempt to prevent overflow in powi() - 20 attempts with multiplier 2.0
95        // gives 2^20 = ~1 million, which is safe for f64 and reasonable for delays
96        let safe_attempt = attempt.min(20) as i32;
97
98        let base_delay = self.initial_delay.as_millis() as f64 * self.multiplier.powi(safe_attempt);
99        let capped_delay = base_delay.min(self.max_delay.as_millis() as f64);
100
101        let final_delay = if self.jitter {
102            // Add jitter: random value between 50% and 100% of the delay
103            let mut rng = rand::thread_rng();
104            let jitter_factor = rng.gen_range(0.5..1.0);
105            capped_delay * jitter_factor
106        } else {
107            capped_delay
108        };
109
110        Duration::from_millis(final_delay as u64)
111    }
112}
113
/// Trait for classifying whether an error is retryable.
///
/// Implementations decide, per error value, whether a failed operation is
/// worth repeating. `Send + Sync` so a classifier can be shared across
/// concurrent tasks.
pub trait RetryClassifier: Send + Sync {
    /// Returns true if the error is transient and the operation should be retried.
    fn is_retryable(&self, error: &SeerError) -> bool;
}
119
/// Default classifier for network operations (WHOIS/RDAP).
///
/// Classifies the following as retryable:
/// - Timeouts
/// - Connection failures (IO errors)
/// - Rate limiting (429)
/// - Server errors (5xx)
///
/// Non-retryable errors:
/// - Invalid input (domain, IP, record type)
/// - Server not found
/// - Parse errors (JSON, WHOIS format)
// Zero-sized unit struct: carries no state, so it is free to construct and copy.
#[derive(Debug, Clone, Default)]
pub struct NetworkRetryClassifier;
134
135impl NetworkRetryClassifier {
136    pub fn new() -> Self {
137        Self
138    }
139}
140
141impl RetryClassifier for NetworkRetryClassifier {
142    fn is_retryable(&self, error: &SeerError) -> bool {
143        match error {
144            // Transient errors - worth retrying
145            SeerError::Timeout(_) => true,
146            SeerError::WhoisConnectionFailed(_) => true,
147            SeerError::RateLimited(_) => true,
148
149            // Reqwest errors need deeper inspection
150            SeerError::ReqwestError(e) => is_transient_reqwest_error(e),
151
152            // WHOIS errors might be transient if they're connection-related
153            SeerError::WhoisError(msg) => {
154                let lower = msg.to_lowercase();
155                lower.contains("connection")
156                    || lower.contains("timeout")
157                    || lower.contains("refused")
158                    || lower.contains("reset")
159            }
160
161            // RDAP errors might be transient server errors
162            SeerError::RdapError(msg) => {
163                let lower = msg.to_lowercase();
164                lower.contains("status 5")
165                    || lower.contains("status 429")
166                    || lower.contains("timeout")
167            }
168
169            // Bootstrap errors could be transient if IANA is temporarily unavailable
170            SeerError::RdapBootstrapError(msg) => {
171                let lower = msg.to_lowercase();
172                lower.contains("timeout") || lower.contains("connection")
173            }
174
175            // DNS errors can be transient
176            SeerError::DnsError(msg) => {
177                let lower = msg.to_lowercase();
178                lower.contains("timeout") || lower.contains("temporary")
179            }
180
181            // HTTP errors might be transient
182            SeerError::HttpError(msg) => {
183                let lower = msg.to_lowercase();
184                lower.contains("timeout") || lower.contains("connection") || lower.contains("5")
185            }
186
187            // Not retryable - permanent failures
188            SeerError::InvalidDomain(_) => false,
189            SeerError::DomainNotAllowed { .. } => false,
190            SeerError::InvalidIpAddress(_) => false,
191            SeerError::InvalidRecordType(_) => false,
192            SeerError::WhoisServerNotFound(_) => false,
193            SeerError::JsonError(_) => false,
194            SeerError::CertificateError(_) => false,
195            SeerError::DnsResolverError(_) => false,
196            SeerError::BulkOperationError { .. } => false,
197            SeerError::LookupFailed { .. } => false,
198            SeerError::RetryExhausted { .. } => false,
199            SeerError::Other(_) => false,
200        }
201    }
202}
203
204/// Checks if a reqwest error is transient and worth retrying.
205fn is_transient_reqwest_error(error: &reqwest::Error) -> bool {
206    // Connection errors are transient
207    if error.is_connect() {
208        return true;
209    }
210
211    // Timeout errors are transient
212    if error.is_timeout() {
213        return true;
214    }
215
216    // Check HTTP status codes
217    if let Some(status) = error.status() {
218        // 429 Too Many Requests - rate limited, retry with backoff
219        if status.as_u16() == 429 {
220            return true;
221        }
222        // 5xx Server errors are transient
223        if status.is_server_error() {
224            return true;
225        }
226        // 4xx Client errors (except 429) are not retryable
227        return false;
228    }
229
230    // Request/body errors are generally not retryable
231    if error.is_request() || error.is_body() {
232        return false;
233    }
234
235    // Default: assume transient for unknown errors
236    true
237}
238
239/// Executes operations with retry logic using exponential backoff.
240#[derive(Debug, Clone)]
241pub struct RetryExecutor<C: RetryClassifier> {
242    policy: RetryPolicy,
243    classifier: C,
244}
245
246impl RetryExecutor<NetworkRetryClassifier> {
247    /// Creates a new executor with the default network retry classifier.
248    pub fn new(policy: RetryPolicy) -> Self {
249        Self {
250            policy,
251            classifier: NetworkRetryClassifier::new(),
252        }
253    }
254}
255
impl<C: RetryClassifier> RetryExecutor<C> {
    /// Creates a new executor with a custom classifier.
    pub fn with_classifier(policy: RetryPolicy, classifier: C) -> Self {
        Self { policy, classifier }
    }

    /// Executes an async operation with retry logic.
    ///
    /// The operation will be retried up to `max_attempts` times if it fails
    /// with a retryable error. Delays between retries follow exponential
    /// backoff with optional jitter.
    ///
    /// # Errors
    ///
    /// A failure on the very first attempt returns the original error
    /// unchanged; once at least one retry has happened, the final error is
    /// wrapped in `SeerError::RetryExhausted` with the attempt count and the
    /// last error's message.
    pub async fn execute<F, Fut, T>(&self, mut operation: F) -> Result<T>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = Result<T>>,
    {
        let mut last_error: Option<SeerError> = None;
        let mut attempt = 0;

        while attempt < self.policy.max_attempts {
            match operation().await {
                Ok(result) => return Ok(result),
                Err(e) => {
                    let is_retryable = self.classifier.is_retryable(&e);
                    let attempts_remaining = self.policy.max_attempts - attempt - 1;

                    // Give up when the error is permanent or this was the last try.
                    if !is_retryable || attempts_remaining == 0 {
                        if attempt > 0 {
                            warn!(
                                attempt = attempt + 1,
                                max_attempts = self.policy.max_attempts,
                                error = %e,
                                "Operation failed after retries"
                            );
                        }
                        // attempt > 0 means retries actually happened, so report
                        // RetryExhausted; otherwise surface the original error.
                        return Err(if attempt > 0 {
                            SeerError::RetryExhausted {
                                attempts: attempt + 1,
                                last_error: e.to_string(),
                            }
                        } else {
                            e
                        });
                    }

                    // Backoff delay is computed from the 0-indexed attempt number.
                    let delay = self.policy.delay_for_attempt(attempt);
                    debug!(
                        attempt = attempt + 1,
                        max_attempts = self.policy.max_attempts,
                        delay_ms = delay.as_millis(),
                        error = %e,
                        "Retrying after transient error"
                    );

                    last_error = Some(e);
                    tokio::time::sleep(delay).await;
                    attempt += 1;
                }
            }
        }

        // Should not reach here, but handle it gracefully
        // (only possible when max_attempts is 0, e.g. via direct field mutation,
        // since the builder clamps it to at least 1).
        Err(last_error.unwrap_or_else(|| SeerError::Other("retry loop exited unexpectedly".into())))
    }

    /// Executes an async operation once without retries.
    /// Useful for operations that should not be retried.
    pub async fn execute_once<F, Fut, T>(&self, operation: F) -> Result<T>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<T>>,
    {
        operation().await
    }
}
331
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    // --- RetryPolicy construction and builder tests ---

    #[test]
    fn test_retry_policy_defaults() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.initial_delay, Duration::from_millis(100));
        assert_eq!(policy.max_delay, Duration::from_secs(5));
        assert_eq!(policy.multiplier, 2.0);
        assert!(policy.jitter);
    }

    #[test]
    fn test_retry_policy_builder() {
        let policy = RetryPolicy::new()
            .with_max_attempts(5)
            .with_initial_delay(Duration::from_millis(200))
            .with_max_delay(Duration::from_secs(10))
            .with_multiplier(3.0)
            .with_jitter(false);

        assert_eq!(policy.max_attempts, 5);
        assert_eq!(policy.initial_delay, Duration::from_millis(200));
        assert_eq!(policy.max_delay, Duration::from_secs(10));
        assert_eq!(policy.multiplier, 3.0);
        assert!(!policy.jitter);
    }

    // --- Delay calculation tests (jitter disabled so results are deterministic) ---

    #[test]
    fn test_delay_calculation_no_jitter() {
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(2.0)
            .with_max_delay(Duration::from_secs(10))
            .with_jitter(false);

        assert_eq!(policy.delay_for_attempt(0), Duration::from_millis(100));
        assert_eq!(policy.delay_for_attempt(1), Duration::from_millis(200));
        assert_eq!(policy.delay_for_attempt(2), Duration::from_millis(400));
        assert_eq!(policy.delay_for_attempt(3), Duration::from_millis(800));
    }

    #[test]
    fn test_delay_capped_at_max() {
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_secs(1))
            .with_multiplier(10.0)
            .with_max_delay(Duration::from_secs(5))
            .with_jitter(false);

        // 1s * 10^2 = 100s, but capped at 5s
        assert_eq!(policy.delay_for_attempt(2), Duration::from_secs(5));
    }

    // --- Classifier tests: spot-check retryable vs. permanent variants ---

    #[test]
    fn test_classifier_timeout_is_retryable() {
        let classifier = NetworkRetryClassifier::new();
        assert!(classifier.is_retryable(&SeerError::Timeout("test".to_string())));
    }

    #[test]
    fn test_classifier_invalid_domain_not_retryable() {
        let classifier = NetworkRetryClassifier::new();
        assert!(!classifier.is_retryable(&SeerError::InvalidDomain("test".to_string())));
    }

    #[test]
    fn test_classifier_server_not_found_not_retryable() {
        let classifier = NetworkRetryClassifier::new();
        assert!(!classifier.is_retryable(&SeerError::WhoisServerNotFound("test".to_string())));
    }

    #[test]
    fn test_classifier_rate_limited_is_retryable() {
        let classifier = NetworkRetryClassifier::new();
        assert!(classifier.is_retryable(&SeerError::RateLimited("test".to_string())));
    }

    // --- Executor tests: shared AtomicUsize counts how many times the closure ran ---

    #[tokio::test]
    async fn test_executor_success_on_first_try() {
        let policy = RetryPolicy::new().with_max_attempts(3);
        let executor = RetryExecutor::new(policy);
        let attempts = Arc::new(AtomicUsize::new(0));

        let attempts_clone = attempts.clone();
        let result: Result<&str> = executor
            .execute(|| {
                let a = attempts_clone.clone();
                async move {
                    a.fetch_add(1, Ordering::SeqCst);
                    Ok("success")
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_executor_retries_on_transient_error() {
        let policy = RetryPolicy::new()
            .with_max_attempts(3)
            .with_initial_delay(Duration::from_millis(1))
            .with_jitter(false);
        let executor = RetryExecutor::new(policy);
        let attempts = Arc::new(AtomicUsize::new(0));

        let attempts_clone = attempts.clone();
        let result: Result<&str> = executor
            .execute(|| {
                let a = attempts_clone.clone();
                async move {
                    // Fail the first two attempts, succeed on the third.
                    let count = a.fetch_add(1, Ordering::SeqCst);
                    if count < 2 {
                        Err(SeerError::Timeout("test timeout".to_string()))
                    } else {
                        Ok("success after retries")
                    }
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success after retries");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_executor_no_retry_on_non_retryable_error() {
        let policy = RetryPolicy::new()
            .with_max_attempts(3)
            .with_initial_delay(Duration::from_millis(1));
        let executor = RetryExecutor::new(policy);
        let attempts = Arc::new(AtomicUsize::new(0));

        let attempts_clone = attempts.clone();
        let result: Result<&str> = executor
            .execute(|| {
                let a = attempts_clone.clone();
                async move {
                    a.fetch_add(1, Ordering::SeqCst);
                    Err(SeerError::InvalidDomain("bad.".to_string()))
                }
            })
            .await;

        assert!(result.is_err());
        // Should only attempt once since InvalidDomain is not retryable
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_executor_exhausts_retries() {
        let policy = RetryPolicy::new()
            .with_max_attempts(3)
            .with_initial_delay(Duration::from_millis(1))
            .with_jitter(false);
        let executor = RetryExecutor::new(policy);
        let attempts = Arc::new(AtomicUsize::new(0));

        let attempts_clone = attempts.clone();
        let result: Result<&str> = executor
            .execute(|| {
                let a = attempts_clone.clone();
                async move {
                    a.fetch_add(1, Ordering::SeqCst);
                    Err(SeerError::Timeout("always fails".to_string()))
                }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 3);

        // Check that we get RetryExhausted error
        match result.unwrap_err() {
            SeerError::RetryExhausted { attempts, .. } => {
                assert_eq!(attempts, 3);
            }
            other => panic!("Expected RetryExhausted, got {:?}", other),
        }
    }

    #[test]
    fn test_no_retry_policy() {
        let policy = RetryPolicy::no_retry();
        assert_eq!(policy.max_attempts, 1);
    }

    #[test]
    fn test_delay_overflow_protection() {
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(2.0)
            .with_max_delay(Duration::from_secs(5))
            .with_jitter(false);

        // Test with very large attempt numbers - should not panic or produce invalid durations
        let delay_50 = policy.delay_for_attempt(50);
        let delay_100 = policy.delay_for_attempt(100);
        let delay_1000 = policy.delay_for_attempt(1000);

        // All should be capped at max_delay due to our overflow protection
        assert!(delay_50 <= Duration::from_secs(5));
        assert!(delay_100 <= Duration::from_secs(5));
        assert!(delay_1000 <= Duration::from_secs(5));
    }
}
545}