1use std::fmt::Display;
6use std::future::Future;
7use std::time::Duration;
8
9use tokio::time::sleep;
10use tracing::{debug, warn};
11
/// Number of attempts (including the first) used by [`RetryConfig::default`].
const DEFAULT_MAX_ATTEMPTS: u32 = 3;

/// Backoff delay, in milliseconds, after the first failed attempt under [`RetryConfig::default`].
const DEFAULT_INITIAL_DELAY_MS: u64 = 100;

/// Upper bound, in milliseconds, on any single backoff delay under [`RetryConfig::default`].
const DEFAULT_MAX_DELAY_MS: u64 = 2000;

/// Retry policy: how many times to try an operation and how long to wait between tries.
///
/// The wait after the failed attempt with zero-based index `n` is
/// `initial_delay_ms * 2^n`, capped at `max_delay_ms`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryConfig {
    /// Total number of attempts, counting the first one.
    pub max_attempts: u32,
    /// Delay in milliseconds after the first failed attempt; doubles after each further failure.
    pub initial_delay_ms: u64,
    /// Ceiling in milliseconds applied to every computed delay.
    pub max_delay_ms: u64,
}

impl Default for RetryConfig {
    /// Three attempts, starting at 100 ms and capped at 2000 ms.
    fn default() -> Self {
        Self {
            max_attempts: DEFAULT_MAX_ATTEMPTS,
            initial_delay_ms: DEFAULT_INITIAL_DELAY_MS,
            max_delay_ms: DEFAULT_MAX_DELAY_MS,
        }
    }
}

impl RetryConfig {
    /// Builds a policy from explicit values; no validation is performed.
    #[must_use]
    pub fn new(max_attempts: u32, initial_delay_ms: u64, max_delay_ms: u64) -> Self {
        Self {
            max_attempts,
            initial_delay_ms,
            max_delay_ms,
        }
    }

    /// Delay to wait after the failed attempt with zero-based index `attempt`.
    ///
    /// Computes `initial_delay_ms * 2^attempt` without ever overflowing, then caps
    /// the result at `max_delay_ms`.
    fn delay_for_attempt(&self, attempt: u32) -> Duration {
        // `checked_shl` yields `None` once the shift reaches 64 bits; treat that as
        // "infinitely large" and let the cap below bring it back down.
        let multiplier = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
        let delay_ms = self.initial_delay_ms.saturating_mul(multiplier);
        Duration::from_millis(delay_ms.min(self.max_delay_ms))
    }
}
61
62pub async fn with_retry<T, E, F, Fut>(config: &RetryConfig, operation: F) -> Result<T, E>
88where
89 F: Fn() -> Fut,
90 Fut: Future<Output = Result<T, E>>,
91 E: Display,
92{
93 let mut last_error: Option<E> = None;
94
95 for attempt in 0..config.max_attempts {
96 match operation().await {
97 Ok(result) => {
98 if attempt > 0 {
99 debug!("Operation succeeded on attempt {}", attempt + 1);
100 }
101 return Ok(result);
102 }
103 Err(e) => {
104 let is_last_attempt = attempt + 1 >= config.max_attempts;
105
106 if is_last_attempt {
107 warn!(
108 "Operation failed after {} attempts: {}",
109 config.max_attempts, e
110 );
111 last_error = Some(e);
112 } else {
113 let delay = config.delay_for_attempt(attempt);
114 warn!(
115 "Operation failed (attempt {}/{}): {}. Retrying in {:?}...",
116 attempt + 1,
117 config.max_attempts,
118 e,
119 delay
120 );
121 sleep(delay).await;
122 last_error = Some(e);
123 }
124 }
125 }
126 }
127
128 Err(last_error.expect("at least one attempt should have been made"))
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[tokio::test]
    async fn test_retry_success_first_attempt() {
        let config = RetryConfig::default();
        let result: Result<&str, &str> = with_retry(&config, || async { Ok("success") }).await;
        assert_eq!(result, Ok("success"));
    }

    #[tokio::test]
    async fn test_retry_success_after_failures() {
        // Short delays keep the test fast; the third call is the first to succeed.
        let config = RetryConfig::new(3, 10, 100);
        let attempt_count = Arc::new(AtomicU32::new(0));
        let attempt_count_clone = Arc::clone(&attempt_count);

        let result: Result<&str, &str> = with_retry(&config, || {
            let count = Arc::clone(&attempt_count_clone);
            async move {
                let current = count.fetch_add(1, Ordering::SeqCst);
                if current < 2 {
                    Err("transient error")
                } else {
                    Ok("success")
                }
            }
        })
        .await;

        assert_eq!(result, Ok("success"));
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_all_failures() {
        let config = RetryConfig::new(3, 10, 100);
        let attempt_count = Arc::new(AtomicU32::new(0));
        let attempt_count_clone = Arc::clone(&attempt_count);

        let result: Result<&str, &str> = with_retry(&config, || {
            let count = Arc::clone(&attempt_count_clone);
            async move {
                count.fetch_add(1, Ordering::SeqCst);
                Err("persistent error")
            }
        })
        .await;

        // The error from the last attempt is returned once the budget is spent.
        assert_eq!(result, Err("persistent error"));
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_delay_for_attempt_no_overflow() {
        let config = RetryConfig::new(100, 100, 2000);
        // 1 << 63 is representable, so here the multiplication itself must saturate.
        let d63 = config.delay_for_attempt(63);
        // From 64 upward the shift overflows and the multiplier falls back to u64::MAX.
        let d64 = config.delay_for_attempt(64);
        let d99 = config.delay_for_attempt(99);
        let d_max = config.delay_for_attempt(u32::MAX);
        assert_eq!(d63, Duration::from_millis(2000));
        assert_eq!(d64, Duration::from_millis(2000));
        assert_eq!(d99, Duration::from_millis(2000));
        assert_eq!(d_max, Duration::from_millis(2000));
    }

    #[test]
    fn test_delay_zero_initial_delay() {
        let config = RetryConfig::new(3, 0, 2000);
        assert_eq!(config.delay_for_attempt(0), Duration::from_millis(0));
        assert_eq!(config.delay_for_attempt(64), Duration::from_millis(0));
    }

    #[test]
    fn test_delay_cap_below_initial_delay() {
        let config = RetryConfig::new(3, 500, 200);
        assert_eq!(config.delay_for_attempt(0), Duration::from_millis(200));
        assert_eq!(config.delay_for_attempt(5), Duration::from_millis(200));
    }

    #[test]
    fn test_delay_calculation() {
        let config = RetryConfig::new(5, 100, 1000);

        assert_eq!(config.delay_for_attempt(0), Duration::from_millis(100));
        assert_eq!(config.delay_for_attempt(1), Duration::from_millis(200));
        assert_eq!(config.delay_for_attempt(2), Duration::from_millis(400));
        assert_eq!(config.delay_for_attempt(3), Duration::from_millis(800));
        assert_eq!(config.delay_for_attempt(4), Duration::from_millis(1000));
    }
}