1use std::future::Future;
7use std::time::Duration;
8
9use rand::Rng;
10use tracing::debug;
11
12use crate::error::{Result, SeerError};
13
/// Exponential-backoff configuration used by the retry executor.
///
/// Build one with [`RetryPolicy::new`] / [`RetryPolicy::default`] and the `with_*` builder
/// methods, or fill the public fields directly.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Total number of attempts, counting the initial call (1 means "never retry").
    pub max_attempts: usize,
    /// Base delay that the backoff grows from.
    pub initial_delay: Duration,
    /// Upper bound for the exponentially grown delay computed by
    /// [`RetryPolicy::delay_for_attempt`].
    pub max_delay: Duration,
    /// Factor the delay is multiplied by for each further attempt.
    pub multiplier: f64,
    /// Whether computed delays are randomized to a value in `[50%, 100%)` of themselves;
    /// see [`RetryPolicy::delay_for_attempt`] for exactly which delays this affects.
    pub jitter: bool,
}

impl Default for RetryPolicy {
    /// Three attempts, 100 ms initial delay doubling up to 5 s, with jitter enabled.
    fn default() -> Self {
        let initial_delay = Duration::from_millis(100);
        let max_delay = Duration::from_secs(5);

        Self {
            max_attempts: 3,
            initial_delay,
            max_delay,
            multiplier: 2.0,
            jitter: true,
        }
    }
}
40
41impl RetryPolicy {
42 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub fn with_max_attempts(mut self, attempts: usize) -> Self {
49 self.max_attempts = attempts.max(1);
50 self
51 }
52
53 pub fn with_initial_delay(mut self, delay: Duration) -> Self {
55 self.initial_delay = delay;
56 self
57 }
58
59 pub fn with_max_delay(mut self, delay: Duration) -> Self {
61 self.max_delay = delay;
62 self
63 }
64
65 pub fn with_multiplier(mut self, multiplier: f64) -> Self {
67 self.multiplier = multiplier.max(1.0);
68 self
69 }
70
71 pub fn with_jitter(mut self, jitter: bool) -> Self {
73 self.jitter = jitter;
74 self
75 }
76
77 pub fn no_retry() -> Self {
79 Self {
80 max_attempts: 1,
81 ..Self::default()
82 }
83 }
84
85 pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
90 if attempt == 0 {
91 return self.initial_delay;
92 }
93
94 let safe_attempt = attempt.min(20) as i32;
97
98 let base_delay = self.initial_delay.as_millis() as f64 * self.multiplier.powi(safe_attempt);
99 let capped_delay = base_delay.min(self.max_delay.as_millis() as f64);
100
101 let final_delay = if self.jitter {
102 let mut rng = rand::thread_rng();
104 let jitter_factor = rng.gen_range(0.5..1.0);
105 capped_delay * jitter_factor
106 } else {
107 capped_delay
108 };
109
110 Duration::from_millis(final_delay as u64)
111 }
112}
113
114pub trait RetryClassifier: Send + Sync {
116 fn is_retryable(&self, error: &SeerError) -> bool;
118}
119
/// Stateless [`RetryClassifier`] for network lookups.
///
/// It retries timeouts, connection failures, rate limiting and transient-looking protocol
/// errors, and treats input, validation and configuration errors as permanent.
#[derive(Debug, Clone, Default)]
pub struct NetworkRetryClassifier;

impl NetworkRetryClassifier {
    /// Creates the classifier; equivalent to `NetworkRetryClassifier::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
140
141impl RetryClassifier for NetworkRetryClassifier {
142 fn is_retryable(&self, error: &SeerError) -> bool {
143 match error {
144 SeerError::Timeout(_) => true,
146 SeerError::WhoisConnectionFailed(_) => true,
147 SeerError::RateLimited(_) => true,
148
149 SeerError::ReqwestError(e) => is_transient_reqwest_error(e),
151
152 SeerError::WhoisError(msg) => {
154 let lower = msg.to_lowercase();
155 lower.contains("connection")
156 || lower.contains("timeout")
157 || lower.contains("refused")
158 || lower.contains("reset")
159 }
160
161 SeerError::RdapError(msg) => {
163 let lower = msg.to_lowercase();
164 lower.contains("status 5")
165 || lower.contains("status 429")
166 || lower.contains("timeout")
167 }
168
169 SeerError::RdapBootstrapError(msg) => {
171 let lower = msg.to_lowercase();
172 lower.contains("timeout") || lower.contains("connection")
173 }
174
175 SeerError::DnsError(msg) => {
177 let lower = msg.to_lowercase();
178 lower.contains("timeout") || lower.contains("temporary")
179 }
180
181 SeerError::HttpError(msg) => {
183 let lower = msg.to_lowercase();
184 lower.contains("timeout")
185 || lower.contains("connection")
186 || lower.contains("status 5")
187 || lower.contains("status 429")
188 }
189
190 SeerError::InvalidDomain(_) => false,
192 SeerError::DomainNotAllowed { .. } => false,
193 SeerError::InvalidIpAddress(_) => false,
194 SeerError::InvalidRecordType(_) => false,
195 SeerError::WhoisServerNotFound(_) => false,
196 SeerError::JsonError(_) => false,
197 SeerError::CertificateError(_) => false,
198 SeerError::SslError(_) => false,
199 SeerError::DnsResolverError(_) => false,
200 SeerError::BulkOperationError { .. } => false,
201 SeerError::LookupFailed { .. } => false,
202 SeerError::ConfigError(_) => false,
203 SeerError::InvalidInput(_) => false,
204 SeerError::RetryExhausted { last_error, .. } => self.is_retryable(last_error),
209 SeerError::Other(_) => false,
210 }
211 }
212}
213
214fn is_transient_reqwest_error(error: &reqwest::Error) -> bool {
216 if error.is_connect() {
218 return true;
219 }
220
221 if error.is_timeout() {
223 return true;
224 }
225
226 if let Some(status) = error.status() {
228 if status.as_u16() == 429 {
230 return true;
231 }
232 if status.is_server_error() {
234 return true;
235 }
236 return false;
238 }
239
240 if error.is_request() || error.is_body() {
242 return false;
243 }
244
245 true
247}
248
249#[derive(Debug, Clone)]
251pub struct RetryExecutor<C: RetryClassifier> {
252 policy: RetryPolicy,
253 classifier: C,
254}
255
256impl RetryExecutor<NetworkRetryClassifier> {
257 pub fn new(policy: RetryPolicy) -> Self {
259 Self {
260 policy,
261 classifier: NetworkRetryClassifier::new(),
262 }
263 }
264}
265
266impl<C: RetryClassifier> RetryExecutor<C> {
267 pub fn with_classifier(policy: RetryPolicy, classifier: C) -> Self {
269 Self { policy, classifier }
270 }
271
272 pub async fn execute<F, Fut, T>(&self, mut operation: F) -> Result<T>
278 where
279 F: FnMut() -> Fut,
280 Fut: Future<Output = Result<T>>,
281 {
282 let mut last_error: Option<SeerError> = None;
283 let mut attempt = 0;
284
285 while attempt < self.policy.max_attempts {
286 match operation().await {
287 Ok(result) => return Ok(result),
288 Err(e) => {
289 let is_retryable = self.classifier.is_retryable(&e);
290 let attempts_remaining = self.policy.max_attempts - attempt - 1;
291
292 if !is_retryable || attempts_remaining == 0 {
293 if attempt > 0 {
294 debug!(
295 attempt = attempt + 1,
296 max_attempts = self.policy.max_attempts,
297 error = %e,
298 "Operation failed after retries"
299 );
300 }
301 return Err(if attempt > 0 {
302 SeerError::RetryExhausted {
303 attempts: attempt + 1,
304 last_error: Box::new(e),
305 }
306 } else {
307 e
308 });
309 }
310
311 let delay = self.policy.delay_for_attempt(attempt);
312 debug!(
313 attempt = attempt + 1,
314 max_attempts = self.policy.max_attempts,
315 delay_ms = delay.as_millis(),
316 error = %e,
317 "Retrying after transient error"
318 );
319
320 last_error = Some(e);
321 tokio::time::sleep(delay).await;
322 attempt += 1;
323 }
324 }
325 }
326
327 Err(last_error.unwrap_or_else(|| SeerError::Other("retry loop exited unexpectedly".into())))
329 }
330
331 pub async fn execute_once<F, Fut, T>(&self, operation: F) -> Result<T>
334 where
335 F: FnOnce() -> Fut,
336 Fut: Future<Output = Result<T>>,
337 {
338 operation().await
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Drives `executor` with an operation that counts its invocations; `outcome` maps the
    /// zero-based invocation index to that invocation's result.
    /// Returns the executor's final result and the total number of invocations.
    async fn run_counted<T>(
        executor: &RetryExecutor<NetworkRetryClassifier>,
        outcome: fn(usize) -> Result<T>,
    ) -> (Result<T>, usize) {
        let calls = Arc::new(AtomicUsize::new(0));
        let result = executor
            .execute(|| {
                let calls = Arc::clone(&calls);
                async move { outcome(calls.fetch_add(1, Ordering::SeqCst)) }
            })
            .await;
        let total = calls.load(Ordering::SeqCst);
        (result, total)
    }

    /// Three attempts with a 1 ms base delay, so executor tests finish quickly.
    fn quick_policy() -> RetryPolicy {
        RetryPolicy::new()
            .with_max_attempts(3)
            .with_initial_delay(Duration::from_millis(1))
    }

    #[test]
    fn test_retry_policy_defaults() {
        let RetryPolicy {
            max_attempts,
            initial_delay,
            max_delay,
            multiplier,
            jitter,
        } = RetryPolicy::default();

        assert_eq!(max_attempts, 3);
        assert_eq!(initial_delay, Duration::from_millis(100));
        assert_eq!(max_delay, Duration::from_secs(5));
        assert_eq!(multiplier, 2.0);
        assert!(jitter);
    }

    #[test]
    fn test_retry_policy_builder() {
        let RetryPolicy {
            max_attempts,
            initial_delay,
            max_delay,
            multiplier,
            jitter,
        } = RetryPolicy::new()
            .with_max_attempts(5)
            .with_initial_delay(Duration::from_millis(200))
            .with_max_delay(Duration::from_secs(10))
            .with_multiplier(3.0)
            .with_jitter(false);

        assert_eq!(max_attempts, 5);
        assert_eq!(initial_delay, Duration::from_millis(200));
        assert_eq!(max_delay, Duration::from_secs(10));
        assert_eq!(multiplier, 3.0);
        assert!(!jitter);
    }

    #[test]
    fn test_delay_calculation_no_jitter() {
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(2.0)
            .with_max_delay(Duration::from_secs(10))
            .with_jitter(false);

        for (attempt, expected_ms) in [(0, 100), (1, 200), (2, 400), (3, 800)] {
            assert_eq!(
                policy.delay_for_attempt(attempt),
                Duration::from_millis(expected_ms)
            );
        }
    }

    #[test]
    fn test_delay_capped_at_max() {
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_secs(1))
            .with_multiplier(10.0)
            .with_max_delay(Duration::from_secs(5))
            .with_jitter(false);

        // Uncapped this would be 1 s * 10^2 = 100 s.
        assert_eq!(policy.delay_for_attempt(2), Duration::from_secs(5));
    }

    #[test]
    fn test_classifier_timeout_is_retryable() {
        let error = SeerError::Timeout("test".to_string());
        assert!(NetworkRetryClassifier::new().is_retryable(&error));
    }

    #[test]
    fn test_classifier_invalid_domain_not_retryable() {
        let error = SeerError::InvalidDomain("test".to_string());
        assert!(!NetworkRetryClassifier::new().is_retryable(&error));
    }

    #[test]
    fn test_classifier_server_not_found_not_retryable() {
        let error = SeerError::WhoisServerNotFound("test".to_string());
        assert!(!NetworkRetryClassifier::new().is_retryable(&error));
    }

    #[test]
    fn test_classifier_rate_limited_is_retryable() {
        let error = SeerError::RateLimited("test".to_string());
        assert!(NetworkRetryClassifier::new().is_retryable(&error));
    }

    #[tokio::test]
    async fn test_executor_success_on_first_try() {
        let executor = RetryExecutor::new(RetryPolicy::new().with_max_attempts(3));

        let (result, calls): (Result<&str>, usize) =
            run_counted(&executor, |_| Ok("success")).await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(calls, 1);
    }

    #[tokio::test]
    async fn test_executor_retries_on_transient_error() {
        let executor = RetryExecutor::new(quick_policy().with_jitter(false));

        // The first two calls time out; the third succeeds.
        let (result, calls): (Result<&str>, usize) = run_counted(&executor, |call| {
            if call < 2 {
                Err(SeerError::Timeout("test timeout".to_string()))
            } else {
                Ok("success after retries")
            }
        })
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success after retries");
        assert_eq!(calls, 3);
    }

    #[tokio::test]
    async fn test_executor_no_retry_on_non_retryable_error() {
        let executor = RetryExecutor::new(quick_policy());

        let (result, calls): (Result<&str>, usize) = run_counted(&executor, |_| {
            Err(SeerError::InvalidDomain("bad.".to_string()))
        })
        .await;

        assert!(result.is_err());
        assert_eq!(calls, 1);
    }

    #[tokio::test]
    async fn test_executor_exhausts_retries() {
        let executor = RetryExecutor::new(quick_policy().with_jitter(false));

        let (result, calls): (Result<&str>, usize) = run_counted(&executor, |_| {
            Err(SeerError::Timeout("always fails".to_string()))
        })
        .await;

        assert!(result.is_err());
        assert_eq!(calls, 3);
        match result {
            Err(SeerError::RetryExhausted { attempts, .. }) => assert_eq!(attempts, 3),
            other => panic!("Expected RetryExhausted, got {:?}", other),
        }
    }

    #[test]
    fn test_no_retry_policy() {
        assert_eq!(RetryPolicy::no_retry().max_attempts, 1);
    }

    #[test]
    fn test_delay_overflow_protection() {
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(2.0)
            .with_max_delay(Duration::from_secs(5))
            .with_jitter(false);

        // Attempt numbers far beyond the exponent clamp must still respect the cap.
        for attempt in [50, 100, 1000] {
            assert!(policy.delay_for_attempt(attempt) <= Duration::from_secs(5));
        }
    }

    #[test]
    fn retry_exhausted_is_retryable_if_inner_is() {
        let classifier = NetworkRetryClassifier::new();
        let wrap = |inner: SeerError| SeerError::RetryExhausted {
            attempts: 3,
            last_error: Box::new(inner),
        };

        let wrapped_retryable = wrap(SeerError::Timeout("inner timed out".to_string()));
        assert!(
            classifier.is_retryable(&wrapped_retryable),
            "RetryExhausted wrapping a retryable Timeout should be retryable",
        );

        let wrapped_non_retryable = wrap(SeerError::InvalidDomain("bad.".to_string()));
        assert!(
            !classifier.is_retryable(&wrapped_non_retryable),
            "RetryExhausted wrapping a non-retryable InvalidDomain must not be retryable",
        );
    }
}