1use crate::config::RetryConfig;
35use crate::error::Error;
36use rand::Rng;
37use std::future::Future;
38use std::time::Duration;
39
/// Classifies a failure as transient (worth retrying) or permanent.
///
/// [`download_with_retry`] consults this after every failed attempt: `true`
/// allows another attempt (subject to the configured retry budget), while
/// `false` hands the error straight back to the caller.
pub trait IsRetryable {
    /// Returns `true` if repeating the failed operation might succeed.
    fn is_retryable(&self) -> bool;
}
48
49impl IsRetryable for Error {
51 fn is_retryable(&self) -> bool {
52 match self {
53 Error::Network(e) => {
55 e.is_timeout() || e.is_connect()
57 }
58 Error::Io(e) => matches!(
60 e.kind(),
61 std::io::ErrorKind::TimedOut
62 | std::io::ErrorKind::ConnectionRefused
63 | std::io::ErrorKind::ConnectionReset
64 | std::io::ErrorKind::ConnectionAborted
65 | std::io::ErrorKind::NotConnected
66 | std::io::ErrorKind::BrokenPipe
67 | std::io::ErrorKind::Interrupted
68 ),
69 Error::Nntp(msg) => {
72 msg.contains("timeout")
74 || msg.contains("busy")
75 || msg.contains("connection")
76 || msg.contains("temporary")
77 || msg.contains("503") || msg.contains("400") }
80 Error::Download(_) => false,
82 Error::PostProcess(_) => false,
84 Error::Database(_) | Error::Sqlx(_) => false,
86 Error::Config { .. } => false,
88 Error::InvalidNzb(_) => false,
90 Error::NotFound(_) => false,
92 Error::ShuttingDown => false,
94 Error::Serialization(_) => false,
96 Error::ApiServerError(_) => false,
98 Error::FolderWatch(_) => false,
100 Error::Duplicate(_) => false,
102 Error::InsufficientSpace { .. } => false,
104 Error::DiskSpaceCheckFailed(_) => false,
106 Error::ExternalTool(msg) => {
108 msg.contains("timeout") || msg.contains("busy") || msg.contains("temporary")
110 }
111 Error::NotSupported(_) => false,
113 Error::Other(_) => false,
115 }
116 }
117}
118
119pub async fn download_with_retry<F, Fut, T, E>(
147 config: &RetryConfig,
148 mut operation: F,
149) -> Result<T, E>
150where
151 F: FnMut() -> Fut,
152 Fut: Future<Output = Result<T, E>>,
153 E: IsRetryable + std::fmt::Display,
154{
155 let mut attempt = 0;
156 let mut delay = config.initial_delay;
157
158 loop {
159 match operation().await {
160 Ok(result) => {
161 if attempt > 0 {
162 tracing::info!(attempts = attempt + 1, "Operation succeeded after retry");
163 }
164 return Ok(result);
165 }
166 Err(e) if e.is_retryable() && attempt < config.max_attempts => {
167 attempt += 1;
168
169 tracing::warn!(
170 error = %e,
171 attempt = attempt,
172 max_attempts = config.max_attempts,
173 delay_ms = delay.as_millis(),
174 "Operation failed, retrying"
175 );
176
177 let jittered_delay = if config.jitter {
179 add_jitter(delay)
180 } else {
181 delay
182 };
183
184 tokio::time::sleep(jittered_delay).await;
186
187 let next_delay =
189 Duration::from_secs_f64(delay.as_secs_f64() * config.backoff_multiplier);
190 delay = next_delay.min(config.max_delay);
191 }
192 Err(e) => {
193 if e.is_retryable() {
194 tracing::error!(
195 error = %e,
196 attempts = attempt + 1,
197 "Operation failed after all retry attempts exhausted"
198 );
199 } else {
200 tracing::error!(
201 error = %e,
202 "Operation failed with non-retryable error"
203 );
204 }
205 return Err(e);
206 }
207 }
208 }
209}
210
211fn add_jitter(delay: Duration) -> Duration {
224 let mut rng = rand::thread_rng();
225 let jitter_factor: f64 = rng.gen_range(0.0..=1.0);
226 let jittered_secs = delay.as_secs_f64() * (1.0 + jitter_factor);
227 Duration::from_secs_f64(jittered_secs)
228}
229
// Unwrap/expect are fine in tests: a panic is exactly how a test should fail.
#[allow(clippy::unwrap_used, clippy::expect_used)]
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Minimal error type with one retryable and one permanent variant.
    #[derive(Debug)]
    enum TestError {
        Transient,
        Permanent,
    }

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                TestError::Transient => write!(f, "transient error"),
                TestError::Permanent => write!(f, "permanent error"),
            }
        }
    }

    impl IsRetryable for TestError {
        fn is_retryable(&self) -> bool {
            matches!(self, TestError::Transient)
        }
    }

    #[tokio::test]
    async fn test_success_no_retry() {
        let config = RetryConfig::default();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = download_with_retry(&config, || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Ok::<_, TestError>(42)
            }
        })
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::SeqCst), 1, "should only call once");
    }

    #[tokio::test]
    async fn test_retry_transient_then_succeed() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = download_with_retry(&config, || {
            let counter = counter_clone.clone();
            async move {
                let count = counter.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err(TestError::Transient)
                } else {
                    Ok(42)
                }
            }
        })
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(
            counter.load(Ordering::SeqCst),
            3,
            "should retry twice before success"
        );
    }

    #[tokio::test]
    async fn test_retry_exhausted() {
        let config = RetryConfig {
            max_attempts: 2,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = download_with_retry(&config, || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(TestError::Transient)
            }
        })
        .await;

        assert!(result.is_err());
        assert_eq!(
            counter.load(Ordering::SeqCst),
            3,
            "should try initial + 2 retries"
        );
    }

    #[tokio::test]
    async fn test_permanent_error_no_retry() {
        let config = RetryConfig::default();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = download_with_retry(&config, || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(TestError::Permanent)
            }
        })
        .await;

        assert!(result.is_err());
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "should not retry permanent error"
        );
    }

    #[tokio::test]
    async fn test_exponential_backoff() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        let start = std::time::Instant::now();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let _result = download_with_retry(&config, || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(TestError::Transient)
            }
        })
        .await;

        let elapsed = start.elapsed();

        // Three sleeps of 10ms, 20ms and 40ms.
        assert!(
            elapsed >= Duration::from_millis(70),
            "should wait at least 70ms, waited {:?}",
            elapsed
        );
        assert!(
            elapsed < Duration::from_secs(2),
            "should not wait too long, waited {:?}",
            elapsed
        );
    }

    // `add_jitter` is synchronous, so no async runtime is needed here.
    #[test]
    fn test_jitter_adds_randomness() {
        let delay = Duration::from_millis(100);

        let jittered1 = add_jitter(delay);
        let jittered2 = add_jitter(delay);

        assert!(jittered1 >= delay);
        assert!(jittered1 <= delay * 2);
        assert!(jittered2 >= delay);
        assert!(jittered2 <= delay * 2);
    }

    #[tokio::test]
    async fn test_max_delay_cap() {
        // Millisecond scale keeps this fast while preserving the shape. Uncapped
        // delays would be 10ms, 100ms, 1s, 10s, 100s. With the 30ms cap they are
        // 10ms + 4 x 30ms = 130ms in total.
        let config = RetryConfig {
            max_attempts: 5,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(30),
            backoff_multiplier: 10.0,
            jitter: false,
        };

        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let start = std::time::Instant::now();

        let _result = download_with_retry(&config, || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(TestError::Transient)
            }
        })
        .await;

        let elapsed = start.elapsed();

        assert_eq!(
            counter.load(Ordering::SeqCst),
            6,
            "should try initial + 5 retries"
        );
        assert!(
            elapsed >= Duration::from_millis(130),
            "should wait at least 130ms with max_delay cap, waited {:?}",
            elapsed
        );
        // Without the cap the third sleep alone would be 1s.
        assert!(
            elapsed < Duration::from_secs(1),
            "max_delay cap should keep the total wait well under 1s, waited {:?}",
            elapsed
        );
    }

    #[tokio::test]
    async fn test_individual_retry_delays_never_exceed_max_delay() {
        // Uncapped delays would be 50ms, 500ms, 5s, 50s; the cap holds them at 200ms.
        let config = RetryConfig {
            max_attempts: 4,
            initial_delay: Duration::from_millis(50),
            max_delay: Duration::from_millis(200),
            backoff_multiplier: 10.0,
            jitter: false,
        };

        let timestamps = Arc::new(tokio::sync::Mutex::new(Vec::new()));
        let ts_clone = timestamps.clone();

        let _result = download_with_retry(&config, || {
            let ts = ts_clone.clone();
            async move {
                ts.lock().await.push(std::time::Instant::now());
                Err::<i32, _>(TestError::Transient)
            }
        })
        .await;

        let ts = timestamps.lock().await;
        assert_eq!(ts.len(), 5, "should have initial + 4 retries = 5 calls");

        // 150ms of tolerance above the 200ms cap absorbs scheduler noise.
        let max_allowed = Duration::from_millis(350);
        for i in 1..ts.len() {
            let gap = ts[i].duration_since(ts[i - 1]);
            assert!(
                gap <= max_allowed,
                "delay between attempt {} and {} was {:?}, which exceeds max_delay (200ms) + tolerance ({:?})",
                i,
                i + 1,
                gap,
                max_allowed
            );
        }

        // The later gaps must actually reach the cap, proving it is a cap and not a reset.
        let gap_3_to_4 = ts[3].duration_since(ts[2]);
        let gap_4_to_5 = ts[4].duration_since(ts[3]);

        assert!(
            gap_3_to_4 >= Duration::from_millis(150),
            "third delay should be ~200ms (capped), was {:?}",
            gap_3_to_4
        );
        assert!(
            gap_4_to_5 >= Duration::from_millis(150),
            "fourth delay should be ~200ms (capped), was {:?}",
            gap_4_to_5
        );
    }

    #[test]
    fn test_error_is_retryable_io() {
        let timeout_err = Error::Io(std::io::Error::new(std::io::ErrorKind::TimedOut, "timeout"));
        assert!(timeout_err.is_retryable());

        let connection_refused = Error::Io(std::io::Error::new(
            std::io::ErrorKind::ConnectionRefused,
            "refused",
        ));
        assert!(connection_refused.is_retryable());

        let not_found = Error::Io(std::io::Error::new(
            std::io::ErrorKind::NotFound,
            "not found",
        ));
        assert!(!not_found.is_retryable());
    }

    #[test]
    fn test_error_is_retryable_nntp() {
        let timeout = Error::Nntp("connection timeout".to_string());
        assert!(timeout.is_retryable());

        let busy = Error::Nntp("server busy (400)".to_string());
        assert!(busy.is_retryable());

        let auth_failed = Error::Nntp("authentication failed".to_string());
        assert!(!auth_failed.is_retryable());
    }

    #[test]
    fn test_error_is_retryable_permanent() {
        use crate::error::{DatabaseError, DownloadError};

        assert!(
            !Error::Config {
                message: "bad config".to_string(),
                key: None,
            }
            .is_retryable()
        );
        assert!(
            !Error::Database(DatabaseError::QueryFailed("db error".to_string())).is_retryable()
        );
        assert!(!Error::InvalidNzb("bad nzb".to_string()).is_retryable());
        assert!(!Error::NotFound("not found".to_string()).is_retryable());
        assert!(!Error::Download(DownloadError::NotFound { id: 123 }).is_retryable());
    }

    #[test]
    fn add_jitter_stays_within_bounds_over_many_iterations() {
        let delay = Duration::from_millis(50);
        for i in 0..200 {
            let jittered = add_jitter(delay);
            assert!(
                jittered >= delay,
                "iteration {i}: jittered {jittered:?} < base delay {delay:?}"
            );
            assert!(
                jittered <= delay * 2,
                "iteration {i}: jittered {jittered:?} > 2x base delay {:?}",
                delay * 2
            );
        }
    }

    #[test]
    fn add_jitter_on_zero_delay_returns_zero() {
        let jittered = add_jitter(Duration::ZERO);
        assert_eq!(
            jittered,
            Duration::ZERO,
            "jitter on zero delay should remain zero"
        );
    }

    #[tokio::test]
    async fn zero_max_attempts_fails_on_first_transient_error() {
        let config = RetryConfig {
            max_attempts: 0,
            initial_delay: Duration::from_millis(1),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = counter.clone();

        let result = download_with_retry(&config, || {
            let counter = counter_clone.clone();
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(TestError::Transient)
            }
        })
        .await;

        assert!(
            matches!(result, Err(TestError::Transient)),
            "should return the transient error without retrying"
        );
        assert_eq!(
            counter.load(Ordering::SeqCst),
            1,
            "should call the operation exactly once (no retries when max_attempts=0)"
        );
    }

    #[tokio::test]
    async fn backoff_delays_increase_exponentially() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_delay: Duration::from_millis(50),
            max_delay: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            jitter: false,
        };

        let timestamps = Arc::new(tokio::sync::Mutex::new(Vec::new()));
        let ts_clone = timestamps.clone();

        let _result = download_with_retry(&config, || {
            let ts = ts_clone.clone();
            async move {
                ts.lock().await.push(std::time::Instant::now());
                Err::<i32, _>(TestError::Transient)
            }
        })
        .await;

        let ts = timestamps.lock().await;
        assert_eq!(ts.len(), 4, "initial + 3 retries = 4 calls");

        // Expected gaps: 50ms, 100ms, 200ms.
        let gap1 = ts[1].duration_since(ts[0]);
        let gap2 = ts[2].duration_since(ts[1]);
        let gap3 = ts[3].duration_since(ts[2]);

        assert!(
            gap1 >= Duration::from_millis(40),
            "first delay should be ~50ms, was {:?}",
            gap1
        );
        assert!(
            gap2 >= Duration::from_millis(80),
            "second delay should be ~100ms, was {:?}",
            gap2
        );
        assert!(
            gap3 >= Duration::from_millis(160),
            "third delay should be ~200ms, was {:?}",
            gap3
        );

        // The ratio, not just the absolute values, shows the growth is geometric.
        let ratio = gap2.as_secs_f64() / gap1.as_secs_f64();
        assert!(
            (1.5..=2.5).contains(&ratio),
            "gap2/gap1 ratio should be ~2.0, was {ratio:.2}"
        );
    }

    #[tokio::test]
    async fn jitter_enabled_produces_delay_within_expected_range() {
        let config = RetryConfig {
            max_attempts: 1,
            initial_delay: Duration::from_millis(50),
            max_delay: Duration::from_secs(10),
            backoff_multiplier: 2.0,
            jitter: true,
        };

        let start = std::time::Instant::now();

        let _result =
            download_with_retry(&config, || async { Err::<i32, _>(TestError::Transient) }).await;

        let elapsed = start.elapsed();

        // One jittered sleep lands in [50ms, 100ms]; the bounds leave room for scheduler noise.
        assert!(
            elapsed >= Duration::from_millis(40),
            "should wait at least the base delay, waited {:?}",
            elapsed
        );
        assert!(
            elapsed < Duration::from_secs(2),
            "should not wait longer than expected, waited {:?}",
            elapsed
        );
    }

    #[test]
    fn io_connection_reset_is_retryable() {
        let err = Error::Io(std::io::Error::new(
            std::io::ErrorKind::ConnectionReset,
            "reset by peer",
        ));
        assert!(
            err.is_retryable(),
            "ConnectionReset should be retryable for transient network glitches"
        );
    }

    #[test]
    fn io_connection_aborted_is_retryable() {
        let err = Error::Io(std::io::Error::new(
            std::io::ErrorKind::ConnectionAborted,
            "aborted",
        ));
        assert!(err.is_retryable());
    }

    #[test]
    fn io_not_connected_is_retryable() {
        let err = Error::Io(std::io::Error::new(
            std::io::ErrorKind::NotConnected,
            "not connected",
        ));
        assert!(err.is_retryable());
    }

    #[test]
    fn io_broken_pipe_is_retryable() {
        let err = Error::Io(std::io::Error::new(
            std::io::ErrorKind::BrokenPipe,
            "broken pipe",
        ));
        assert!(err.is_retryable());
    }

    #[test]
    fn io_interrupted_is_retryable() {
        let err = Error::Io(std::io::Error::new(
            std::io::ErrorKind::Interrupted,
            "interrupted",
        ));
        assert!(err.is_retryable());
    }

    #[test]
    fn io_permission_denied_is_not_retryable() {
        let err = Error::Io(std::io::Error::new(
            std::io::ErrorKind::PermissionDenied,
            "denied",
        ));
        assert!(
            !err.is_retryable(),
            "PermissionDenied is permanent, not transient"
        );
    }

    #[test]
    fn nntp_503_service_unavailable_is_retryable() {
        let err = Error::Nntp("503 service temporarily unavailable".to_string());
        assert!(err.is_retryable());
    }

    #[test]
    fn nntp_400_server_busy_is_retryable() {
        let err = Error::Nntp("400 server too busy".to_string());
        assert!(err.is_retryable());
    }

    #[test]
    fn nntp_temporary_failure_is_retryable() {
        let err = Error::Nntp("temporary failure, please retry".to_string());
        assert!(err.is_retryable());
    }

    #[test]
    fn nntp_unknown_error_without_keywords_is_not_retryable() {
        let err = Error::Nntp("430 no such article".to_string());
        assert!(
            !err.is_retryable(),
            "NNTP error without transient keywords should not be retried"
        );
    }

    #[test]
    fn external_tool_timeout_is_retryable() {
        let err = Error::ExternalTool("timeout waiting for par2".to_string());
        assert!(err.is_retryable());
    }

    #[test]
    fn external_tool_busy_is_retryable() {
        let err = Error::ExternalTool("process busy, try again".to_string());
        assert!(err.is_retryable());
    }

    #[test]
    fn external_tool_temporary_is_retryable() {
        let err = Error::ExternalTool("temporary failure in unrar".to_string());
        assert!(err.is_retryable());
    }

    #[test]
    fn external_tool_not_found_is_not_retryable() {
        let err = Error::ExternalTool("par2 not found in PATH".to_string());
        assert!(
            !err.is_retryable(),
            "missing binary is permanent, not transient"
        );
    }

    #[test]
    fn post_process_error_is_never_retryable() {
        use crate::error::PostProcessError;
        let err = Error::PostProcess(PostProcessError::ExtractionFailed {
            archive: std::path::PathBuf::from("test.rar"),
            reason: "CRC error".to_string(),
        });
        assert!(!err.is_retryable(), "post-processing errors are permanent");
    }

    #[test]
    fn shutting_down_is_not_retryable() {
        assert!(
            !Error::ShuttingDown.is_retryable(),
            "shutdown should not trigger retries"
        );
    }

    #[test]
    fn serialization_error_is_not_retryable() {
        let err = Error::Serialization(serde_json::from_str::<String>("bad json").unwrap_err());
        assert!(!err.is_retryable());
    }

    #[test]
    fn api_server_error_is_not_retryable() {
        let err = Error::ApiServerError("bind failed".to_string());
        assert!(!err.is_retryable());
    }

    #[test]
    fn folder_watch_error_is_not_retryable() {
        let err = Error::FolderWatch("inotify error".to_string());
        assert!(!err.is_retryable());
    }

    #[test]
    fn duplicate_error_is_not_retryable() {
        let err = Error::Duplicate("already exists".to_string());
        assert!(!err.is_retryable());
    }

    #[test]
    fn insufficient_space_is_not_retryable() {
        let err = Error::InsufficientSpace {
            required: 1_000_000,
            available: 500,
        };
        assert!(
            !err.is_retryable(),
            "disk space issues require user action, not retries"
        );
    }

    #[test]
    fn disk_space_check_failed_is_not_retryable() {
        let err = Error::DiskSpaceCheckFailed("statvfs failed".to_string());
        assert!(!err.is_retryable());
    }

    #[test]
    fn not_supported_is_not_retryable() {
        let err = Error::NotSupported("feature unavailable".to_string());
        assert!(!err.is_retryable());
    }

    #[test]
    fn other_error_is_not_retryable() {
        let err = Error::Other("unknown problem".to_string());
        assert!(!err.is_retryable());
    }
}