1use std::future::Future;
7use std::time::Duration;
8
9use rand::Rng;
10use tracing::{debug, warn};
11
12use crate::error::{Result, SeerError};
13
/// Tunable parameters for exponential back-off between retry attempts.
///
/// Construct one with [`RetryPolicy::new`] (or `Default`) and refine it with the
/// `with_*` builder methods, which clamp out-of-range values. The fields are public,
/// so values assigned directly are not validated.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Total number of attempts, counting the first one (`1` means "never retry").
    pub max_attempts: usize,
    /// Base back-off used before the first retry.
    pub initial_delay: Duration,
    /// Upper limit for the exponentially growing back-off.
    pub max_delay: Duration,
    /// Factor by which the back-off grows after each failed attempt.
    pub multiplier: f64,
    /// Whether back-offs are randomly scaled down so concurrent callers spread out.
    pub jitter: bool,
}
28
29impl Default for RetryPolicy {
30 fn default() -> Self {
31 Self {
32 max_attempts: 3,
33 initial_delay: Duration::from_millis(100),
34 max_delay: Duration::from_secs(5),
35 multiplier: 2.0,
36 jitter: true,
37 }
38 }
39}
40
41impl RetryPolicy {
42 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub fn with_max_attempts(mut self, attempts: usize) -> Self {
49 self.max_attempts = attempts.max(1);
50 self
51 }
52
53 pub fn with_initial_delay(mut self, delay: Duration) -> Self {
55 self.initial_delay = delay;
56 self
57 }
58
59 pub fn with_max_delay(mut self, delay: Duration) -> Self {
61 self.max_delay = delay;
62 self
63 }
64
65 pub fn with_multiplier(mut self, multiplier: f64) -> Self {
67 self.multiplier = multiplier.max(1.0);
68 self
69 }
70
71 pub fn with_jitter(mut self, jitter: bool) -> Self {
73 self.jitter = jitter;
74 self
75 }
76
77 pub fn no_retry() -> Self {
79 Self {
80 max_attempts: 1,
81 ..Self::default()
82 }
83 }
84
85 pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
90 if attempt == 0 {
91 return self.initial_delay;
92 }
93
94 let safe_attempt = attempt.min(20) as i32;
97
98 let base_delay = self.initial_delay.as_millis() as f64 * self.multiplier.powi(safe_attempt);
99 let capped_delay = base_delay.min(self.max_delay.as_millis() as f64);
100
101 let final_delay = if self.jitter {
102 let mut rng = rand::thread_rng();
104 let jitter_factor = rng.gen_range(0.5..1.0);
105 capped_delay * jitter_factor
106 } else {
107 capped_delay
108 };
109
110 Duration::from_millis(final_delay as u64)
111 }
112}
113
114pub trait RetryClassifier: Send + Sync {
116 fn is_retryable(&self, error: &SeerError) -> bool;
118}
119
/// Stateless [`RetryClassifier`] for network-facing lookups (WHOIS, RDAP, DNS, HTTP).
#[derive(Clone, Debug, Default)]
pub struct NetworkRetryClassifier;

impl NetworkRetryClassifier {
    /// Returns the classifier. It carries no configuration, so every instance is identical.
    pub fn new() -> Self {
        NetworkRetryClassifier
    }
}
140
141impl RetryClassifier for NetworkRetryClassifier {
142 fn is_retryable(&self, error: &SeerError) -> bool {
143 match error {
144 SeerError::Timeout(_) => true,
146 SeerError::WhoisConnectionFailed(_) => true,
147 SeerError::RateLimited(_) => true,
148
149 SeerError::ReqwestError(e) => is_transient_reqwest_error(e),
151
152 SeerError::WhoisError(msg) => {
154 let lower = msg.to_lowercase();
155 lower.contains("connection")
156 || lower.contains("timeout")
157 || lower.contains("refused")
158 || lower.contains("reset")
159 }
160
161 SeerError::RdapError(msg) => {
163 let lower = msg.to_lowercase();
164 lower.contains("status 5")
165 || lower.contains("status 429")
166 || lower.contains("timeout")
167 }
168
169 SeerError::RdapBootstrapError(msg) => {
171 let lower = msg.to_lowercase();
172 lower.contains("timeout") || lower.contains("connection")
173 }
174
175 SeerError::DnsError(msg) => {
177 let lower = msg.to_lowercase();
178 lower.contains("timeout") || lower.contains("temporary")
179 }
180
181 SeerError::HttpError(msg) => {
183 let lower = msg.to_lowercase();
184 lower.contains("timeout")
185 || lower.contains("connection")
186 || lower.contains("status 5")
187 || lower.contains("status 429")
188 }
189
190 SeerError::InvalidDomain(_) => false,
192 SeerError::DomainNotAllowed { .. } => false,
193 SeerError::InvalidIpAddress(_) => false,
194 SeerError::InvalidRecordType(_) => false,
195 SeerError::WhoisServerNotFound(_) => false,
196 SeerError::JsonError(_) => false,
197 SeerError::CertificateError(_) => false,
198 SeerError::SslError(_) => false,
199 SeerError::DnsResolverError(_) => false,
200 SeerError::BulkOperationError { .. } => false,
201 SeerError::LookupFailed { .. } => false,
202 SeerError::ConfigError(_) => false,
203 SeerError::RetryExhausted { .. } => false,
204 SeerError::Other(_) => false,
205 }
206 }
207}
208
209fn is_transient_reqwest_error(error: &reqwest::Error) -> bool {
211 if error.is_connect() {
213 return true;
214 }
215
216 if error.is_timeout() {
218 return true;
219 }
220
221 if let Some(status) = error.status() {
223 if status.as_u16() == 429 {
225 return true;
226 }
227 if status.is_server_error() {
229 return true;
230 }
231 return false;
233 }
234
235 if error.is_request() || error.is_body() {
237 return false;
238 }
239
240 true
242}
243
244#[derive(Debug, Clone)]
246pub struct RetryExecutor<C: RetryClassifier> {
247 policy: RetryPolicy,
248 classifier: C,
249}
250
251impl RetryExecutor<NetworkRetryClassifier> {
252 pub fn new(policy: RetryPolicy) -> Self {
254 Self {
255 policy,
256 classifier: NetworkRetryClassifier::new(),
257 }
258 }
259}
260
261impl<C: RetryClassifier> RetryExecutor<C> {
262 pub fn with_classifier(policy: RetryPolicy, classifier: C) -> Self {
264 Self { policy, classifier }
265 }
266
267 pub async fn execute<F, Fut, T>(&self, mut operation: F) -> Result<T>
273 where
274 F: FnMut() -> Fut,
275 Fut: Future<Output = Result<T>>,
276 {
277 let mut last_error: Option<SeerError> = None;
278 let mut attempt = 0;
279
280 while attempt < self.policy.max_attempts {
281 match operation().await {
282 Ok(result) => return Ok(result),
283 Err(e) => {
284 let is_retryable = self.classifier.is_retryable(&e);
285 let attempts_remaining = self.policy.max_attempts - attempt - 1;
286
287 if !is_retryable || attempts_remaining == 0 {
288 if attempt > 0 {
289 warn!(
290 attempt = attempt + 1,
291 max_attempts = self.policy.max_attempts,
292 error = %e,
293 "Operation failed after retries"
294 );
295 }
296 return Err(if attempt > 0 {
297 SeerError::RetryExhausted {
298 attempts: attempt + 1,
299 last_error: e.to_string(),
300 }
301 } else {
302 e
303 });
304 }
305
306 let delay = self.policy.delay_for_attempt(attempt);
307 debug!(
308 attempt = attempt + 1,
309 max_attempts = self.policy.max_attempts,
310 delay_ms = delay.as_millis(),
311 error = %e,
312 "Retrying after transient error"
313 );
314
315 last_error = Some(e);
316 tokio::time::sleep(delay).await;
317 attempt += 1;
318 }
319 }
320 }
321
322 Err(last_error.unwrap_or_else(|| SeerError::Other("retry loop exited unexpectedly".into())))
324 }
325
326 pub async fn execute_once<F, Fut, T>(&self, operation: F) -> Result<T>
329 where
330 F: FnOnce() -> Fut,
331 Fut: Future<Output = Result<T>>,
332 {
333 operation().await
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_retry_policy_defaults() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.initial_delay, Duration::from_millis(100));
        assert_eq!(policy.max_delay, Duration::from_secs(5));
        assert_eq!(policy.multiplier, 2.0);
        assert!(policy.jitter);
    }

    #[test]
    fn test_retry_policy_builder() {
        let policy = RetryPolicy::new()
            .with_max_attempts(5)
            .with_initial_delay(Duration::from_millis(200))
            .with_max_delay(Duration::from_secs(10))
            .with_multiplier(3.0)
            .with_jitter(false);

        assert_eq!(policy.max_attempts, 5);
        assert_eq!(policy.initial_delay, Duration::from_millis(200));
        assert_eq!(policy.max_delay, Duration::from_secs(10));
        assert_eq!(policy.multiplier, 3.0);
        assert!(!policy.jitter);
    }

    #[test]
    fn test_builder_clamps_out_of_range_values() {
        let policy = RetryPolicy::new().with_max_attempts(0).with_multiplier(0.5);
        assert_eq!(policy.max_attempts, 1);
        assert_eq!(policy.multiplier, 1.0);
    }

    #[test]
    fn test_delay_calculation_no_jitter() {
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(2.0)
            .with_max_delay(Duration::from_secs(10))
            .with_jitter(false);

        assert_eq!(policy.delay_for_attempt(0), Duration::from_millis(100));
        assert_eq!(policy.delay_for_attempt(1), Duration::from_millis(200));
        assert_eq!(policy.delay_for_attempt(2), Duration::from_millis(400));
        assert_eq!(policy.delay_for_attempt(3), Duration::from_millis(800));
    }

    #[test]
    fn test_delay_capped_at_max() {
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_secs(1))
            .with_multiplier(10.0)
            .with_max_delay(Duration::from_secs(5))
            .with_jitter(false);

        assert_eq!(policy.delay_for_attempt(2), Duration::from_secs(5));
    }

    #[test]
    fn test_delay_is_monotonic_and_capped() {
        let max = Duration::from_secs(5);
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(3.0)
            .with_max_delay(max)
            .with_jitter(false);

        let mut previous = Duration::ZERO;
        for attempt in 0..64 {
            let delay = policy.delay_for_attempt(attempt);
            assert!(delay <= max, "attempt {} exceeded max: {:?}", attempt, delay);
            assert!(delay >= previous, "attempt {} shrank: {:?}", attempt, delay);
            previous = delay;
        }
    }

    #[test]
    fn test_jittered_delay_stays_within_bounds() {
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(2.0)
            .with_max_delay(Duration::from_secs(10))
            .with_jitter(true);

        // The un-jittered delay for attempt 3 is 800 ms; jitter scales it by [0.5, 1.0).
        for _ in 0..100 {
            let delay = policy.delay_for_attempt(3);
            assert!(delay >= Duration::from_millis(400), "delay too short: {:?}", delay);
            assert!(delay <= Duration::from_millis(800), "delay too long: {:?}", delay);
        }
    }

    #[test]
    fn test_classifier_timeout_is_retryable() {
        let classifier = NetworkRetryClassifier::new();
        assert!(classifier.is_retryable(&SeerError::Timeout("test".to_string())));
    }

    #[test]
    fn test_classifier_invalid_domain_not_retryable() {
        let classifier = NetworkRetryClassifier::new();
        assert!(!classifier.is_retryable(&SeerError::InvalidDomain("test".to_string())));
    }

    #[test]
    fn test_classifier_server_not_found_not_retryable() {
        let classifier = NetworkRetryClassifier::new();
        assert!(!classifier.is_retryable(&SeerError::WhoisServerNotFound("test".to_string())));
    }

    #[test]
    fn test_classifier_rate_limited_is_retryable() {
        let classifier = NetworkRetryClassifier::new();
        assert!(classifier.is_retryable(&SeerError::RateLimited("test".to_string())));
    }

    #[test]
    fn test_classifier_message_heuristics() {
        let classifier = NetworkRetryClassifier::new();
        assert!(classifier.is_retryable(&SeerError::WhoisError(
            "Connection reset by peer".to_string()
        )));
        assert!(classifier.is_retryable(&SeerError::RdapError(
            "unexpected Status 503".to_string()
        )));
        assert!(!classifier.is_retryable(&SeerError::RdapError(
            "unexpected status 404".to_string()
        )));
        assert!(classifier.is_retryable(&SeerError::HttpError("status 429".to_string())));
        assert!(classifier.is_retryable(&SeerError::DnsError(
            "Temporary failure in name resolution".to_string()
        )));
        assert!(!classifier.is_retryable(&SeerError::DnsError("NXDOMAIN".to_string())));
    }

    #[tokio::test]
    async fn test_executor_success_on_first_try() {
        let policy = RetryPolicy::new().with_max_attempts(3);
        let executor = RetryExecutor::new(policy);
        let attempts = Arc::new(AtomicUsize::new(0));

        let attempts_clone = attempts.clone();
        let result: Result<&str> = executor
            .execute(|| {
                let a = attempts_clone.clone();
                async move {
                    a.fetch_add(1, Ordering::SeqCst);
                    Ok("success")
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_executor_retries_on_transient_error() {
        let policy = RetryPolicy::new()
            .with_max_attempts(3)
            .with_initial_delay(Duration::from_millis(1))
            .with_jitter(false);
        let executor = RetryExecutor::new(policy);
        let attempts = Arc::new(AtomicUsize::new(0));

        let attempts_clone = attempts.clone();
        let result: Result<&str> = executor
            .execute(|| {
                let a = attempts_clone.clone();
                async move {
                    let count = a.fetch_add(1, Ordering::SeqCst);
                    if count < 2 {
                        Err(SeerError::Timeout("test timeout".to_string()))
                    } else {
                        Ok("success after retries")
                    }
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success after retries");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_executor_no_retry_on_non_retryable_error() {
        let policy = RetryPolicy::new()
            .with_max_attempts(3)
            .with_initial_delay(Duration::from_millis(1));
        let executor = RetryExecutor::new(policy);
        let attempts = Arc::new(AtomicUsize::new(0));

        let attempts_clone = attempts.clone();
        let result: Result<&str> = executor
            .execute(|| {
                let a = attempts_clone.clone();
                async move {
                    a.fetch_add(1, Ordering::SeqCst);
                    Err(SeerError::InvalidDomain("bad.".to_string()))
                }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_executor_single_attempt_returns_original_error() {
        let executor = RetryExecutor::new(RetryPolicy::no_retry());
        let attempts = Arc::new(AtomicUsize::new(0));

        let attempts_clone = attempts.clone();
        let result: Result<&str> = executor
            .execute(|| {
                let a = attempts_clone.clone();
                async move {
                    a.fetch_add(1, Ordering::SeqCst);
                    Err(SeerError::Timeout("boom".to_string()))
                }
            })
            .await;

        // A retryable error on the only attempt is returned as-is, not wrapped.
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
        assert!(matches!(result, Err(SeerError::Timeout(_))));
    }

    #[tokio::test]
    async fn test_executor_exhausts_retries() {
        let policy = RetryPolicy::new()
            .with_max_attempts(3)
            .with_initial_delay(Duration::from_millis(1))
            .with_jitter(false);
        let executor = RetryExecutor::new(policy);
        let attempts = Arc::new(AtomicUsize::new(0));

        let attempts_clone = attempts.clone();
        let result: Result<&str> = executor
            .execute(|| {
                let a = attempts_clone.clone();
                async move {
                    a.fetch_add(1, Ordering::SeqCst);
                    Err(SeerError::Timeout("always fails".to_string()))
                }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 3);

        match result.unwrap_err() {
            SeerError::RetryExhausted { attempts, .. } => {
                assert_eq!(attempts, 3);
            }
            other => panic!("Expected RetryExhausted, got {:?}", other),
        }
    }

    #[test]
    fn test_no_retry_policy() {
        let policy = RetryPolicy::no_retry();
        assert_eq!(policy.max_attempts, 1);
    }

    #[test]
    fn test_delay_overflow_protection() {
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(2.0)
            .with_max_delay(Duration::from_secs(5))
            .with_jitter(false);

        let delay_50 = policy.delay_for_attempt(50);
        let delay_100 = policy.delay_for_attempt(100);
        let delay_1000 = policy.delay_for_attempt(1000);

        assert!(delay_50 <= Duration::from_secs(5));
        assert!(delay_100 <= Duration::from_secs(5));
        assert!(delay_1000 <= Duration::from_secs(5));
    }
}