1use std::future::Future;
7use std::time::Duration;
8
9use rand::Rng;
10use tracing::{debug, warn};
11
12use crate::error::{Result, SeerError};
13
/// Configuration for retrying a fallible operation with exponential backoff.
///
/// The wait before each retry is derived from `initial_delay`, grown by
/// `multiplier` per attempt and bounded by `max_delay`; see
/// `RetryPolicy::delay_for_attempt`.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Total number of attempts, including the very first call.
    pub max_attempts: usize,
    /// Starting point of the backoff sequence.
    pub initial_delay: Duration,
    /// Upper bound for the exponentially grown delay.
    pub max_delay: Duration,
    /// Growth factor applied to the delay per attempt.
    pub multiplier: f64,
    /// Whether computed delays are randomised with jitter.
    pub jitter: bool,
}

impl Default for RetryPolicy {
    /// Three attempts, starting at 100 ms, doubling each time up to 5 s, with
    /// jitter enabled.
    fn default() -> Self {
        let initial_delay = Duration::from_millis(100);
        let max_delay = Duration::from_secs(5);
        RetryPolicy {
            jitter: true,
            multiplier: 2.0,
            max_delay,
            initial_delay,
            max_attempts: 3,
        }
    }
}
40
41impl RetryPolicy {
42 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub fn with_max_attempts(mut self, attempts: usize) -> Self {
49 self.max_attempts = attempts.max(1);
50 self
51 }
52
53 pub fn with_initial_delay(mut self, delay: Duration) -> Self {
55 self.initial_delay = delay;
56 self
57 }
58
59 pub fn with_max_delay(mut self, delay: Duration) -> Self {
61 self.max_delay = delay;
62 self
63 }
64
65 pub fn with_multiplier(mut self, multiplier: f64) -> Self {
67 self.multiplier = multiplier.max(1.0);
68 self
69 }
70
71 pub fn with_jitter(mut self, jitter: bool) -> Self {
73 self.jitter = jitter;
74 self
75 }
76
77 pub fn no_retry() -> Self {
79 Self {
80 max_attempts: 1,
81 ..Self::default()
82 }
83 }
84
85 pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
90 if attempt == 0 {
91 return self.initial_delay;
92 }
93
94 let safe_attempt = attempt.min(20) as i32;
97
98 let base_delay = self.initial_delay.as_millis() as f64 * self.multiplier.powi(safe_attempt);
99 let capped_delay = base_delay.min(self.max_delay.as_millis() as f64);
100
101 let final_delay = if self.jitter {
102 let mut rng = rand::thread_rng();
104 let jitter_factor = rng.gen_range(0.5..1.0);
105 capped_delay * jitter_factor
106 } else {
107 capped_delay
108 };
109
110 Duration::from_millis(final_delay as u64)
111 }
112}
113
/// Decides whether a failed operation is worth retrying.
///
/// `RetryExecutor::execute` consults the classifier after every failed
/// attempt. Returning `false` stops the retry loop immediately.
/// Implementors must be `Send + Sync`.
pub trait RetryClassifier: Send + Sync {
    /// Returns `true` if `error` is considered transient, i.e. a later attempt
    /// might succeed.
    fn is_retryable(&self, error: &SeerError) -> bool;
}
119
/// Default retry classifier: retries network-level failures (timeouts,
/// connection problems, rate limiting, server-side errors) and treats
/// validation/parsing failures as permanent.
#[derive(Debug, Clone, Default)]
pub struct NetworkRetryClassifier;

impl NetworkRetryClassifier {
    /// Creates the classifier. It is a stateless unit value.
    pub fn new() -> Self {
        NetworkRetryClassifier
    }
}
140
141impl RetryClassifier for NetworkRetryClassifier {
142 fn is_retryable(&self, error: &SeerError) -> bool {
143 match error {
144 SeerError::Timeout(_) => true,
146 SeerError::WhoisConnectionFailed(_) => true,
147 SeerError::RateLimited(_) => true,
148
149 SeerError::ReqwestError(e) => is_transient_reqwest_error(e),
151
152 SeerError::WhoisError(msg) => {
154 let lower = msg.to_lowercase();
155 lower.contains("connection")
156 || lower.contains("timeout")
157 || lower.contains("refused")
158 || lower.contains("reset")
159 }
160
161 SeerError::RdapError(msg) => {
163 let lower = msg.to_lowercase();
164 lower.contains("status 5")
165 || lower.contains("status 429")
166 || lower.contains("timeout")
167 }
168
169 SeerError::RdapBootstrapError(msg) => {
171 let lower = msg.to_lowercase();
172 lower.contains("timeout") || lower.contains("connection")
173 }
174
175 SeerError::DnsError(msg) => {
177 let lower = msg.to_lowercase();
178 lower.contains("timeout") || lower.contains("temporary")
179 }
180
181 SeerError::HttpError(msg) => {
183 let lower = msg.to_lowercase();
184 lower.contains("timeout") || lower.contains("connection") || lower.contains("5")
185 }
186
187 SeerError::InvalidDomain(_) => false,
189 SeerError::DomainNotAllowed { .. } => false,
190 SeerError::InvalidIpAddress(_) => false,
191 SeerError::InvalidRecordType(_) => false,
192 SeerError::WhoisServerNotFound(_) => false,
193 SeerError::JsonError(_) => false,
194 SeerError::CertificateError(_) => false,
195 SeerError::DnsResolverError(_) => false,
196 SeerError::BulkOperationError { .. } => false,
197 SeerError::LookupFailed { .. } => false,
198 SeerError::RetryExhausted { .. } => false,
199 SeerError::Other(_) => false,
200 }
201 }
202}
203
204fn is_transient_reqwest_error(error: &reqwest::Error) -> bool {
206 if error.is_connect() {
208 return true;
209 }
210
211 if error.is_timeout() {
213 return true;
214 }
215
216 if let Some(status) = error.status() {
218 if status.as_u16() == 429 {
220 return true;
221 }
222 if status.is_server_error() {
224 return true;
225 }
226 return false;
228 }
229
230 if error.is_request() || error.is_body() {
232 return false;
233 }
234
235 true
237}
238
239#[derive(Debug, Clone)]
241pub struct RetryExecutor<C: RetryClassifier> {
242 policy: RetryPolicy,
243 classifier: C,
244}
245
246impl RetryExecutor<NetworkRetryClassifier> {
247 pub fn new(policy: RetryPolicy) -> Self {
249 Self {
250 policy,
251 classifier: NetworkRetryClassifier::new(),
252 }
253 }
254}
255
256impl<C: RetryClassifier> RetryExecutor<C> {
257 pub fn with_classifier(policy: RetryPolicy, classifier: C) -> Self {
259 Self { policy, classifier }
260 }
261
262 pub async fn execute<F, Fut, T>(&self, mut operation: F) -> Result<T>
268 where
269 F: FnMut() -> Fut,
270 Fut: Future<Output = Result<T>>,
271 {
272 let mut last_error: Option<SeerError> = None;
273 let mut attempt = 0;
274
275 while attempt < self.policy.max_attempts {
276 match operation().await {
277 Ok(result) => return Ok(result),
278 Err(e) => {
279 let is_retryable = self.classifier.is_retryable(&e);
280 let attempts_remaining = self.policy.max_attempts - attempt - 1;
281
282 if !is_retryable || attempts_remaining == 0 {
283 if attempt > 0 {
284 warn!(
285 attempt = attempt + 1,
286 max_attempts = self.policy.max_attempts,
287 error = %e,
288 "Operation failed after retries"
289 );
290 }
291 return Err(if attempt > 0 {
292 SeerError::RetryExhausted {
293 attempts: attempt + 1,
294 last_error: e.to_string(),
295 }
296 } else {
297 e
298 });
299 }
300
301 let delay = self.policy.delay_for_attempt(attempt);
302 debug!(
303 attempt = attempt + 1,
304 max_attempts = self.policy.max_attempts,
305 delay_ms = delay.as_millis(),
306 error = %e,
307 "Retrying after transient error"
308 );
309
310 last_error = Some(e);
311 tokio::time::sleep(delay).await;
312 attempt += 1;
313 }
314 }
315 }
316
317 Err(last_error.unwrap_or_else(|| SeerError::Other("retry loop exited unexpectedly".into())))
319 }
320
321 pub async fn execute_once<F, Fut, T>(&self, operation: F) -> Result<T>
324 where
325 F: FnOnce() -> Fut,
326 Fut: Future<Output = Result<T>>,
327 {
328 operation().await
329 }
330}
331
#[cfg(test)]
mod tests {
    // Unit tests for `RetryPolicy`, `NetworkRetryClassifier` and `RetryExecutor`.
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    // `RetryPolicy::default()` yields the expected default values.
    #[test]
    fn test_retry_policy_defaults() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.initial_delay, Duration::from_millis(100));
        assert_eq!(policy.max_delay, Duration::from_secs(5));
        assert_eq!(policy.multiplier, 2.0);
        assert!(policy.jitter);
    }

    // Every builder method stores the value it was given.
    #[test]
    fn test_retry_policy_builder() {
        let policy = RetryPolicy::new()
            .with_max_attempts(5)
            .with_initial_delay(Duration::from_millis(200))
            .with_max_delay(Duration::from_secs(10))
            .with_multiplier(3.0)
            .with_jitter(false);

        assert_eq!(policy.max_attempts, 5);
        assert_eq!(policy.initial_delay, Duration::from_millis(200));
        assert_eq!(policy.max_delay, Duration::from_secs(10));
        assert_eq!(policy.multiplier, 3.0);
        assert!(!policy.jitter);
    }

    // Without jitter, delays grow exactly as initial_delay * multiplier^attempt.
    #[test]
    fn test_delay_calculation_no_jitter() {
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(2.0)
            .with_max_delay(Duration::from_secs(10))
            .with_jitter(false);

        assert_eq!(policy.delay_for_attempt(0), Duration::from_millis(100));
        assert_eq!(policy.delay_for_attempt(1), Duration::from_millis(200));
        assert_eq!(policy.delay_for_attempt(2), Duration::from_millis(400));
        assert_eq!(policy.delay_for_attempt(3), Duration::from_millis(800));
    }

    // 1s * 10^2 = 100s exceeds max_delay, so the result is clamped to 5s.
    #[test]
    fn test_delay_capped_at_max() {
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_secs(1))
            .with_multiplier(10.0)
            .with_max_delay(Duration::from_secs(5))
            .with_jitter(false);

        assert_eq!(policy.delay_for_attempt(2), Duration::from_secs(5));
    }

    // Timeouts are transient and should be retried.
    #[test]
    fn test_classifier_timeout_is_retryable() {
        let classifier = NetworkRetryClassifier::new();
        assert!(classifier.is_retryable(&SeerError::Timeout("test".to_string())));
    }

    // Invalid input cannot be fixed by retrying.
    #[test]
    fn test_classifier_invalid_domain_not_retryable() {
        let classifier = NetworkRetryClassifier::new();
        assert!(!classifier.is_retryable(&SeerError::InvalidDomain("test".to_string())));
    }

    // A missing WHOIS server is a permanent lookup failure.
    #[test]
    fn test_classifier_server_not_found_not_retryable() {
        let classifier = NetworkRetryClassifier::new();
        assert!(!classifier.is_retryable(&SeerError::WhoisServerNotFound("test".to_string())));
    }

    // Rate limiting is expected to clear up, so it is retried.
    #[test]
    fn test_classifier_rate_limited_is_retryable() {
        let classifier = NetworkRetryClassifier::new();
        assert!(classifier.is_retryable(&SeerError::RateLimited("test".to_string())));
    }

    // A successful first call returns immediately, with exactly one invocation.
    #[tokio::test]
    async fn test_executor_success_on_first_try() {
        let policy = RetryPolicy::new().with_max_attempts(3);
        let executor = RetryExecutor::new(policy);
        let attempts = Arc::new(AtomicUsize::new(0));

        let attempts_clone = attempts.clone();
        let result: Result<&str> = executor
            .execute(|| {
                let a = attempts_clone.clone();
                async move {
                    a.fetch_add(1, Ordering::SeqCst);
                    Ok("success")
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    // Two transient failures followed by success: the third attempt's value
    // is returned.
    #[tokio::test]
    async fn test_executor_retries_on_transient_error() {
        let policy = RetryPolicy::new()
            .with_max_attempts(3)
            .with_initial_delay(Duration::from_millis(1))
            .with_jitter(false);
        let executor = RetryExecutor::new(policy);
        let attempts = Arc::new(AtomicUsize::new(0));

        let attempts_clone = attempts.clone();
        let result: Result<&str> = executor
            .execute(|| {
                let a = attempts_clone.clone();
                async move {
                    // `fetch_add` returns the previous count: fail on calls 0 and 1.
                    let count = a.fetch_add(1, Ordering::SeqCst);
                    if count < 2 {
                        Err(SeerError::Timeout("test timeout".to_string()))
                    } else {
                        Ok("success after retries")
                    }
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success after retries");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    // A non-retryable error stops the loop after the first attempt.
    #[tokio::test]
    async fn test_executor_no_retry_on_non_retryable_error() {
        let policy = RetryPolicy::new()
            .with_max_attempts(3)
            .with_initial_delay(Duration::from_millis(1));
        let executor = RetryExecutor::new(policy);
        let attempts = Arc::new(AtomicUsize::new(0));

        let attempts_clone = attempts.clone();
        let result: Result<&str> = executor
            .execute(|| {
                let a = attempts_clone.clone();
                async move {
                    a.fetch_add(1, Ordering::SeqCst);
                    Err(SeerError::InvalidDomain("bad.".to_string()))
                }
            })
            .await;

        // Only one call was made; no retries for permanent errors.
        assert!(result.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    // When every attempt fails transiently, the error is wrapped in
    // `RetryExhausted` carrying the attempt count.
    #[tokio::test]
    async fn test_executor_exhausts_retries() {
        let policy = RetryPolicy::new()
            .with_max_attempts(3)
            .with_initial_delay(Duration::from_millis(1))
            .with_jitter(false);
        let executor = RetryExecutor::new(policy);
        let attempts = Arc::new(AtomicUsize::new(0));

        let attempts_clone = attempts.clone();
        let result: Result<&str> = executor
            .execute(|| {
                let a = attempts_clone.clone();
                async move {
                    a.fetch_add(1, Ordering::SeqCst);
                    Err(SeerError::Timeout("always fails".to_string()))
                }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 3);

        // The final error must report all three attempts.
        match result.unwrap_err() {
            SeerError::RetryExhausted { attempts, .. } => {
                assert_eq!(attempts, 3);
            }
            other => panic!("Expected RetryExhausted, got {:?}", other),
        }
    }

    // `no_retry()` allows exactly one attempt.
    #[test]
    fn test_no_retry_policy() {
        let policy = RetryPolicy::no_retry();
        assert_eq!(policy.max_attempts, 1);
    }

    // Very large attempt numbers must not overflow and must stay within max_delay.
    #[test]
    fn test_delay_overflow_protection() {
        let policy = RetryPolicy::new()
            .with_initial_delay(Duration::from_millis(100))
            .with_multiplier(2.0)
            .with_max_delay(Duration::from_secs(5))
            .with_jitter(false);

        let delay_50 = policy.delay_for_attempt(50);
        let delay_100 = policy.delay_for_attempt(100);
        let delay_1000 = policy.delay_for_attempt(1000);

        // All three are far past the point where the raw value exceeds max_delay.
        assert!(delay_50 <= Duration::from_secs(5));
        assert!(delay_100 <= Duration::from_secs(5));
        assert!(delay_1000 <= Duration::from_secs(5));
    }
}