1use std::future::Future;
7use std::time::Duration;
8
9use rand::Rng;
10use tracing::debug;
11
12use crate::error::{Result, SeerError};
13
/// Settings that control how often, and how long apart, a failed operation is retried.
///
/// Delays grow exponentially from `initial_delay` by `multiplier` and are capped at
/// `max_delay`; see `delay_for_attempt` for the exact formula.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Total number of attempts, counting the first call (the executor loops while
    /// `attempt < max_attempts`).
    pub max_attempts: usize,
    /// Delay used before the first retry and as the base of the exponential growth.
    pub initial_delay: Duration,
    /// Upper bound applied to the exponentially grown delay.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by for each further attempt.
    pub multiplier: f64,
    /// When `true`, the computed delay is scaled by a random factor drawn from `0.0..1.0`.
    pub jitter: bool,
}
28
29impl Default for RetryPolicy {
30 fn default() -> Self {
31 Self {
32 max_attempts: 3,
33 initial_delay: Duration::from_millis(100),
34 max_delay: Duration::from_secs(5),
35 multiplier: 2.0,
36 jitter: true,
37 }
38 }
39}
40
41impl RetryPolicy {
42 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub fn with_max_attempts(mut self, attempts: usize) -> Self {
49 self.max_attempts = attempts.max(1);
50 self
51 }
52
53 pub fn with_initial_delay(mut self, delay: Duration) -> Self {
55 self.initial_delay = delay;
56 self
57 }
58
59 pub fn with_max_delay(mut self, delay: Duration) -> Self {
61 self.max_delay = delay;
62 self
63 }
64
65 pub fn with_multiplier(mut self, multiplier: f64) -> Self {
67 self.multiplier = multiplier.max(1.0);
68 self
69 }
70
71 pub fn with_jitter(mut self, jitter: bool) -> Self {
73 self.jitter = jitter;
74 self
75 }
76
77 pub fn no_retry() -> Self {
79 Self {
80 max_attempts: 1,
81 ..Self::default()
82 }
83 }
84
85 pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
90 if attempt == 0 {
91 return self.initial_delay;
92 }
93
94 let safe_attempt = attempt.min(20) as i32;
97
98 let base_delay = self.initial_delay.as_millis() as f64 * self.multiplier.powi(safe_attempt);
99 let capped_delay = base_delay.min(self.max_delay.as_millis() as f64);
100
101 let final_delay = if self.jitter {
102 let mut rng = rand::thread_rng();
108 let jitter_factor = rng.gen_range(0.0..1.0);
109 capped_delay * jitter_factor
110 } else {
111 capped_delay
112 };
113
114 Duration::from_millis(final_delay as u64)
115 }
116}
117
/// Decides whether a failed operation should be attempted again.
///
/// Implementors must be `Send + Sync`.
pub trait RetryClassifier: Send + Sync {
    /// Returns `true` when `error` should be retried, `false` when retrying should stop.
    fn is_retryable(&self, error: &SeerError) -> bool;
}
123
/// Stateless classifier that treats network-style failures (timeouts, connection problems,
/// rate limiting, 5xx/429 statuses) as retryable; see its `RetryClassifier` impl for the
/// exact rules.
#[derive(Debug, Clone, Default)]
pub struct NetworkRetryClassifier;
138
139impl NetworkRetryClassifier {
140 pub fn new() -> Self {
141 Self
142 }
143}
144
145impl RetryClassifier for NetworkRetryClassifier {
146 fn is_retryable(&self, error: &SeerError) -> bool {
147 match error {
148 SeerError::Timeout(_) => true,
150 SeerError::WhoisConnectionFailed(_) => true,
151 SeerError::RateLimited(_) => true,
152
153 SeerError::ReqwestError(e) => is_transient_reqwest_error(e),
155
156 SeerError::WhoisError(msg) => {
158 let lower = msg.to_lowercase();
159 lower.contains("connection")
160 || lower.contains("timeout")
161 || lower.contains("refused")
162 || lower.contains("reset")
163 }
164
165 SeerError::RdapError(msg) => {
167 let lower = msg.to_lowercase();
168 lower.contains("status 5")
169 || lower.contains("status 429")
170 || lower.contains("timeout")
171 }
172
173 SeerError::RdapBootstrapError(msg) => {
175 let lower = msg.to_lowercase();
176 lower.contains("timeout") || lower.contains("connection")
177 }
178
179 SeerError::DnsError(msg) => {
181 let lower = msg.to_lowercase();
182 lower.contains("timeout") || lower.contains("temporary")
183 }
184
185 SeerError::HttpError(msg) => {
187 let lower = msg.to_lowercase();
188 lower.contains("timeout")
189 || lower.contains("connection")
190 || lower.contains("status 5")
191 || lower.contains("status 429")
192 }
193
194 SeerError::InvalidDomain(_) => false,
196 SeerError::DomainNotAllowed { .. } => false,
197 SeerError::InvalidIpAddress(_) => false,
198 SeerError::InvalidRecordType(_) => false,
199 SeerError::WhoisServerNotFound(_) => false,
200 SeerError::JsonError(_) => false,
201 SeerError::CertificateError(_) => false,
202 SeerError::SslError(_) => false,
203 SeerError::DnsResolverError(_) => false,
204 SeerError::BulkOperationError { .. } => false,
205 SeerError::LookupFailed { .. } => false,
206 SeerError::ConfigError(_) => false,
207 SeerError::InvalidInput(_) => false,
208 SeerError::RetryExhausted { last_error, .. } => self.is_retryable(last_error),
213 SeerError::Other(_) => false,
214 }
215 }
216}
217
218fn is_transient_reqwest_error(error: &reqwest::Error) -> bool {
220 if error.is_connect() {
222 return true;
223 }
224
225 if error.is_timeout() {
227 return true;
228 }
229
230 if let Some(status) = error.status() {
232 if status.as_u16() == 429 {
234 return true;
235 }
236 if status.is_server_error() {
238 return true;
239 }
240 return false;
242 }
243
244 if error.is_request() || error.is_body() {
246 return false;
247 }
248
249 true
251}
252
/// Runs fallible async operations under a `RetryPolicy`, asking the classifier `C` which
/// errors are worth another attempt.
#[derive(Debug, Clone)]
pub struct RetryExecutor<C: RetryClassifier> {
    /// Attempt limit and back-off settings.
    policy: RetryPolicy,
    /// Decides whether a given error is retried.
    classifier: C,
}
259
260impl RetryExecutor<NetworkRetryClassifier> {
261 pub fn new(policy: RetryPolicy) -> Self {
263 Self {
264 policy,
265 classifier: NetworkRetryClassifier::new(),
266 }
267 }
268}
269
270impl<C: RetryClassifier> RetryExecutor<C> {
271 pub fn with_classifier(policy: RetryPolicy, classifier: C) -> Self {
273 Self { policy, classifier }
274 }
275
276 pub async fn execute<F, Fut, T>(&self, mut operation: F) -> Result<T>
282 where
283 F: FnMut() -> Fut,
284 Fut: Future<Output = Result<T>>,
285 {
286 let mut last_error: Option<SeerError> = None;
287 let mut attempt = 0;
288
289 while attempt < self.policy.max_attempts {
290 match operation().await {
291 Ok(result) => return Ok(result),
292 Err(e) => {
293 let is_retryable = self.classifier.is_retryable(&e);
294 let attempts_remaining = self.policy.max_attempts - attempt - 1;
295
296 if !is_retryable || attempts_remaining == 0 {
297 if attempt > 0 {
298 debug!(
299 attempt = attempt + 1,
300 max_attempts = self.policy.max_attempts,
301 error = %e,
302 "Operation failed after retries"
303 );
304 }
305 return Err(if attempt > 0 {
306 SeerError::RetryExhausted {
307 attempts: attempt + 1,
308 last_error: Box::new(e),
309 }
310 } else {
311 e
312 });
313 }
314
315 let delay = self.policy.delay_for_attempt(attempt);
316 debug!(
317 attempt = attempt + 1,
318 max_attempts = self.policy.max_attempts,
319 delay_ms = delay.as_millis(),
320 error = %e,
321 "Retrying after transient error"
322 );
323
324 last_error = Some(e);
325 tokio::time::sleep(delay).await;
326 attempt += 1;
327 }
328 }
329 }
330
331 Err(last_error.unwrap_or_else(|| SeerError::Other("retry loop exited unexpectedly".into())))
333 }
334
335 pub async fn execute_once<F, Fut, T>(&self, operation: F) -> Result<T>
338 where
339 F: FnOnce() -> Fut,
340 Fut: Future<Output = Result<T>>,
341 {
342 operation().await
343 }
344}
345
// Unit tests for the retry policy, the network classifier and the executor.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    // `Default` yields 3 attempts, 100 ms initial delay, 5 s cap, x2 growth, jitter on.
    #[test]
    fn test_retry_policy_defaults() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.initial_delay, Duration::from_millis(100));
        assert_eq!(policy.max_delay, Duration::from_secs(5));
        assert_eq!(policy.multiplier, 2.0);
        assert!(policy.jitter);
    }

    // Every builder method stores the value it is given.
    #[test]
    fn test_retry_policy_builder() {
        let policy = RetryPolicy::new()
            .with_max_attempts(5)
            .with_initial_delay(Duration::from_millis(200))
            .with_max_delay(Duration::from_secs(10))
            .with_multiplier(3.0)
            .with_jitter(false);

        assert_eq!(policy.max_attempts, 5);
        assert_eq!(policy.initial_delay, Duration::from_millis(200));
        assert_eq!(policy.max_delay, Duration::from_secs(10));
        assert_eq!(policy.multiplier, 3.0);
        assert!(!policy.jitter);
    }

    // Without jitter the delay doubles deterministically: 100, 200, 400, 800 ms.
    #[test]
    fn test_delay_calculation_no_jitter() {
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(2.0)
            .with_max_delay(Duration::from_secs(10))
            .with_jitter(false);

        assert_eq!(policy.delay_for_attempt(0), Duration::from_millis(100));
        assert_eq!(policy.delay_for_attempt(1), Duration::from_millis(200));
        assert_eq!(policy.delay_for_attempt(2), Duration::from_millis(400));
        assert_eq!(policy.delay_for_attempt(3), Duration::from_millis(800));
    }

    // 1 s * 10^2 would be 100 s, so the 5 s ceiling must win.
    #[test]
    fn test_delay_capped_at_max() {
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_secs(1))
            .with_multiplier(10.0)
            .with_max_delay(Duration::from_secs(5))
            .with_jitter(false);

        assert_eq!(policy.delay_for_attempt(2), Duration::from_secs(5));
    }

    // Timeouts are classified as retryable.
    #[test]
    fn test_classifier_timeout_is_retryable() {
        let classifier = NetworkRetryClassifier::new();
        assert!(classifier.is_retryable(&SeerError::Timeout("test".to_string())));
    }

    // An invalid domain is classified as not retryable.
    #[test]
    fn test_classifier_invalid_domain_not_retryable() {
        let classifier = NetworkRetryClassifier::new();
        assert!(!classifier.is_retryable(&SeerError::InvalidDomain("test".to_string())));
    }

    // A missing WHOIS server is classified as not retryable.
    #[test]
    fn test_classifier_server_not_found_not_retryable() {
        let classifier = NetworkRetryClassifier::new();
        assert!(!classifier.is_retryable(&SeerError::WhoisServerNotFound("test".to_string())));
    }

    // Rate limiting is classified as retryable.
    #[test]
    fn test_classifier_rate_limited_is_retryable() {
        let classifier = NetworkRetryClassifier::new();
        assert!(classifier.is_retryable(&SeerError::RateLimited("test".to_string())));
    }

    // A successful first attempt returns immediately and runs the operation once.
    #[tokio::test]
    async fn test_executor_success_on_first_try() {
        let policy = RetryPolicy::new().with_max_attempts(3);
        let executor = RetryExecutor::new(policy);
        let attempts = Arc::new(AtomicUsize::new(0));

        let attempts_clone = attempts.clone();
        let result: Result<&str> = executor
            .execute(|| {
                let a = attempts_clone.clone();
                async move {
                    a.fetch_add(1, Ordering::SeqCst);
                    Ok("success")
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    // Two timeouts followed by success: the executor retries and reports the success.
    #[tokio::test]
    async fn test_executor_retries_on_transient_error() {
        let policy = RetryPolicy::new()
            .with_max_attempts(3)
            .with_initial_delay(Duration::from_millis(1))
            .with_jitter(false);
        let executor = RetryExecutor::new(policy);
        let attempts = Arc::new(AtomicUsize::new(0));

        let attempts_clone = attempts.clone();
        let result: Result<&str> = executor
            .execute(|| {
                let a = attempts_clone.clone();
                async move {
                    // `fetch_add` returns the previous value, so the first two calls fail.
                    let count = a.fetch_add(1, Ordering::SeqCst);
                    if count < 2 {
                        Err(SeerError::Timeout("test timeout".to_string()))
                    } else {
                        Ok("success after retries")
                    }
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success after retries");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    // A non-retryable error stops after the first attempt.
    #[tokio::test]
    async fn test_executor_no_retry_on_non_retryable_error() {
        let policy = RetryPolicy::new()
            .with_max_attempts(3)
            .with_initial_delay(Duration::from_millis(1));
        let executor = RetryExecutor::new(policy);
        let attempts = Arc::new(AtomicUsize::new(0));

        let attempts_clone = attempts.clone();
        let result: Result<&str> = executor
            .execute(|| {
                let a = attempts_clone.clone();
                async move {
                    a.fetch_add(1, Ordering::SeqCst);
                    Err(SeerError::InvalidDomain("bad.".to_string()))
                }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    // A permanently failing retryable operation uses all attempts and ends in
    // `RetryExhausted` carrying the attempt count.
    #[tokio::test]
    async fn test_executor_exhausts_retries() {
        let policy = RetryPolicy::new()
            .with_max_attempts(3)
            .with_initial_delay(Duration::from_millis(1))
            .with_jitter(false);
        let executor = RetryExecutor::new(policy);
        let attempts = Arc::new(AtomicUsize::new(0));

        let attempts_clone = attempts.clone();
        let result: Result<&str> = executor
            .execute(|| {
                let a = attempts_clone.clone();
                async move {
                    a.fetch_add(1, Ordering::SeqCst);
                    Err(SeerError::Timeout("always fails".to_string()))
                }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 3);

        match result.unwrap_err() {
            SeerError::RetryExhausted { attempts, .. } => {
                assert_eq!(attempts, 3);
            }
            other => panic!("Expected RetryExhausted, got {:?}", other),
        }
    }

    // `no_retry` allows exactly one attempt.
    #[test]
    fn test_no_retry_policy() {
        let policy = RetryPolicy::no_retry();
        assert_eq!(policy.max_attempts, 1);
    }

    // Very large attempt numbers must neither panic nor exceed the configured cap.
    #[test]
    fn test_delay_overflow_protection() {
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(2.0)
            .with_max_delay(Duration::from_secs(5))
            .with_jitter(false);

        let delay_50 = policy.delay_for_attempt(50);
        let delay_100 = policy.delay_for_attempt(100);
        let delay_1000 = policy.delay_for_attempt(1000);

        assert!(delay_50 <= Duration::from_secs(5));
        assert!(delay_100 <= Duration::from_secs(5));
        assert!(delay_1000 <= Duration::from_secs(5));
    }

    // `RetryExhausted` inherits the retryability of the error it wraps.
    #[test]
    fn retry_exhausted_is_retryable_if_inner_is() {
        let classifier = NetworkRetryClassifier::new();

        let retryable_inner = SeerError::Timeout("inner timed out".to_string());
        let wrapped_retryable = SeerError::RetryExhausted {
            attempts: 3,
            last_error: Box::new(retryable_inner),
        };
        assert!(
            classifier.is_retryable(&wrapped_retryable),
            "RetryExhausted wrapping a retryable Timeout should be retryable",
        );

        let non_retryable_inner = SeerError::InvalidDomain("bad.".to_string());
        let wrapped_non_retryable = SeerError::RetryExhausted {
            attempts: 3,
            last_error: Box::new(non_retryable_inner),
        };
        assert!(
            !classifier.is_retryable(&wrapped_non_retryable),
            "RetryExhausted wrapping a non-retryable InvalidDomain must not be retryable",
        );
    }
}