1use std::thread;
10use std::time::Duration;
11use tracing::{debug, warn};
12
/// Total number of attempts (first try included) used by [`RetryConfig::default`].
const DEFAULT_MAX_ATTEMPTS: u32 = 3;

/// Delay before the first retry, in milliseconds, used by [`RetryConfig::default`].
const DEFAULT_INITIAL_DELAY_MS: u64 = 100;

/// Upper bound on the delay between retries, in seconds, used by [`RetryConfig::default`].
const DEFAULT_MAX_DELAY_SECS: u64 = 5;

/// Factor the delay is multiplied by after every failed attempt (shared by all presets).
const BACKOFF_MULTIPLIER: f64 = 2.0;

/// Total number of attempts used by [`RetryConfig::for_network`].
const NETWORK_MAX_ATTEMPTS: u32 = 4;

/// Delay before the first retry, in milliseconds, used by [`RetryConfig::for_network`].
const NETWORK_INITIAL_DELAY_MS: u64 = 500;

/// Upper bound on the delay between retries, in seconds, used by [`RetryConfig::for_network`].
const NETWORK_MAX_DELAY_SECS: u64 = 5;

/// Total number of attempts used by [`RetryConfig::for_connection`].
const CONNECTION_MAX_ATTEMPTS: u32 = 6;

/// Delay before the first retry, in milliseconds, used by [`RetryConfig::for_connection`].
const CONNECTION_INITIAL_DELAY_MS: u64 = 100;

/// Upper bound on the delay between retries, in seconds, used by
/// [`RetryConfig::for_connection`].
const CONNECTION_MAX_DELAY_SECS: u64 = 2;

/// Parameters of an exponential-backoff retry policy.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Total number of attempts, counting the very first call.
    pub max_attempts: u32,
    /// Delay slept before the first retry.
    pub initial_delay: Duration,
    /// Cap applied to the delay as it grows.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by after each failed attempt.
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// Builds the general-purpose policy from the `DEFAULT_*` constants.
    fn default() -> Self {
        Self {
            max_attempts: DEFAULT_MAX_ATTEMPTS,
            initial_delay: Duration::from_millis(DEFAULT_INITIAL_DELAY_MS),
            max_delay: Duration::from_secs(DEFAULT_MAX_DELAY_SECS),
            backoff_multiplier: BACKOFF_MULTIPLIER,
        }
    }
}

impl RetryConfig {
    /// Preset built from the `NETWORK_*` constants: more attempts and a longer
    /// initial delay than the default policy.
    pub fn for_network() -> Self {
        Self {
            max_attempts: NETWORK_MAX_ATTEMPTS,
            initial_delay: Duration::from_millis(NETWORK_INITIAL_DELAY_MS),
            max_delay: Duration::from_secs(NETWORK_MAX_DELAY_SECS),
            backoff_multiplier: BACKOFF_MULTIPLIER,
        }
    }

    /// Preset built from the `CONNECTION_*` constants: many attempts with short
    /// delays and a low delay cap.
    pub fn for_connection() -> Self {
        Self {
            max_attempts: CONNECTION_MAX_ATTEMPTS,
            initial_delay: Duration::from_millis(CONNECTION_INITIAL_DELAY_MS),
            max_delay: Duration::from_secs(CONNECTION_MAX_DELAY_SECS),
            backoff_multiplier: BACKOFF_MULTIPLIER,
        }
    }
}
113
114pub fn retry_with_backoff<T, E, F, R>(
130 config: RetryConfig,
131 operation_name: &str,
132 mut operation: F,
133 should_retry: R,
134) -> Result<T, E>
135where
136 F: FnMut() -> Result<T, E>,
137 R: Fn(&E) -> bool,
138 E: std::fmt::Display,
139{
140 let mut attempt = 0;
141 let mut delay = config.initial_delay;
142
143 loop {
144 attempt += 1;
145
146 match operation() {
147 Ok(result) => {
148 if attempt > 1 {
149 debug!(
150 operation = %operation_name,
151 attempts = attempt,
152 "operation succeeded after retry"
153 );
154 }
155 return Ok(result);
156 }
157 Err(e) => {
158 if attempt >= config.max_attempts {
159 warn!(
160 operation = %operation_name,
161 attempts = attempt,
162 error = %e,
163 "operation failed after max attempts"
164 );
165 return Err(e);
166 }
167
168 if !should_retry(&e) {
169 debug!(
170 operation = %operation_name,
171 attempt = attempt,
172 error = %e,
173 "operation failed with non-retryable error"
174 );
175 return Err(e);
176 }
177
178 warn!(
179 operation = %operation_name,
180 attempt = attempt,
181 max_attempts = config.max_attempts,
182 delay_ms = delay.as_millis(),
183 error = %e,
184 "operation failed, will retry"
185 );
186
187 thread::sleep(delay);
188
189 delay = Duration::from_secs_f64(
191 (delay.as_secs_f64() * config.backoff_multiplier)
192 .min(config.max_delay.as_secs_f64()),
193 );
194 }
195 }
196 }
197}
198
/// Reports whether `error_msg` contains a marker of a transient network
/// failure.
///
/// The message is lower-cased and searched for each known marker as a plain
/// substring; the first hit makes the function return `true`.
pub fn is_transient_network_error(error_msg: &str) -> bool {
    const TRANSIENT_MARKERS: &[&str] = &[
        // Connection-level failures.
        "connection refused",
        "connection reset",
        "connection timed out",
        "network is unreachable",
        "no route to host",
        "temporary failure",
        "try again",
        "resource temporarily unavailable",
        // Name resolution.
        "name resolution",
        "dns",
        "could not resolve",
        "no such host",
        // HTTP statuses.
        "502 bad gateway",
        "503 service unavailable",
        "504 gateway timeout",
        "429 too many requests",
        // Rate limiting.
        "toomanyrequests",
        "rate limit",
        "quota exceeded",
        // Low-level I/O conditions.
        "broken pipe",
        "interrupted",
        "eagain",
        "ewouldblock",
    ];

    let lowered = error_msg.to_lowercase();
    TRANSIENT_MARKERS
        .iter()
        .any(|marker| lowered.contains(*marker))
}
256
/// Reports whether `error_msg` contains a marker of a failure that retrying
/// will not fix.
///
/// The message is lower-cased and searched for each known marker as a plain
/// substring; the first hit makes the function return `true`.
pub fn is_permanent_error(error_msg: &str) -> bool {
    const PERMANENT_MARKERS: &[&str] = &[
        // Authentication / authorization.
        "401 unauthorized",
        "403 forbidden",
        "authentication required",
        "access denied",
        // Missing resources.
        "404 not found",
        "manifest unknown",
        "name unknown",
        "repository does not exist",
        // Invalid input.
        "invalid reference",
        "invalid image",
        "malformed",
    ];

    let lowered = error_msg.to_lowercase();
    PERMANENT_MARKERS
        .iter()
        .any(|marker| lowered.contains(*marker))
}
289
/// Reports whether the kind of `error` is one of the I/O error kinds this
/// module treats as transient.
pub fn is_transient_io_error(error: &std::io::Error) -> bool {
    use std::io::ErrorKind;

    const TRANSIENT_KINDS: [ErrorKind; 8] = [
        ErrorKind::ConnectionRefused,
        ErrorKind::ConnectionReset,
        ErrorKind::ConnectionAborted,
        ErrorKind::NotConnected,
        ErrorKind::BrokenPipe,
        ErrorKind::TimedOut,
        ErrorKind::Interrupted,
        ErrorKind::WouldBlock,
    ];

    TRANSIENT_KINDS.contains(&error.kind())
}
306
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;

    /// Three-attempt policy with millisecond-scale delays so retries stay fast.
    fn fast_config() -> RetryConfig {
        RetryConfig {
            max_attempts: 3,
            initial_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
            backoff_multiplier: 2.0,
        }
    }

    // An operation that succeeds immediately returns its value untouched.
    #[test]
    fn test_retry_success_first_attempt() {
        let outcome: Result<i32, &str> =
            retry_with_backoff(RetryConfig::default(), "test", || Ok(42), |_| true);
        assert_eq!(outcome.unwrap(), 42);
    }

    // Two failures followed by a success: the third call's value is returned.
    #[test]
    fn test_retry_success_after_failures() {
        let calls = Cell::new(0);
        let outcome: Result<i32, &str> = retry_with_backoff(
            fast_config(),
            "test",
            || {
                calls.set(calls.get() + 1);
                if calls.get() < 3 {
                    Err("transient error")
                } else {
                    Ok(42)
                }
            },
            |_| true,
        );
        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.get(), 3);
    }

    // An operation that never succeeds is called exactly `max_attempts` times.
    #[test]
    fn test_retry_exhausted() {
        let calls = Cell::new(0);
        let outcome: Result<i32, &str> = retry_with_backoff(
            fast_config(),
            "test",
            || {
                calls.set(calls.get() + 1);
                Err("always fails")
            },
            |_| true,
        );
        assert!(outcome.is_err());
        assert_eq!(calls.get(), 3);
    }

    // When the predicate rejects the error, no second attempt is made.
    #[test]
    fn test_retry_non_retryable_error() {
        let calls = Cell::new(0);
        let outcome: Result<i32, &str> = retry_with_backoff(
            RetryConfig::default(),
            "test",
            || {
                calls.set(calls.get() + 1);
                Err("permanent error")
            },
            |_| false,
        );
        assert!(outcome.is_err());
        assert_eq!(calls.get(), 1);
    }

    #[test]
    fn test_transient_network_errors() {
        for message in [
            "connection refused",
            "Connection timed out",
            "503 Service Unavailable",
            "rate limit exceeded",
        ] {
            assert!(is_transient_network_error(message));
        }
        for message in ["404 not found", "some random error"] {
            assert!(!is_transient_network_error(message));
        }
    }

    #[test]
    fn test_permanent_errors() {
        for message in ["401 Unauthorized", "404 Not Found", "manifest unknown"] {
            assert!(is_permanent_error(message));
        }
        for message in ["connection refused", "503 Service Unavailable"] {
            assert!(!is_permanent_error(message));
        }
    }

    #[test]
    fn test_config_presets() {
        let network = RetryConfig::for_network();
        assert_eq!(network.max_attempts, 4);
        assert_eq!(network.initial_delay, Duration::from_millis(500));

        let connection = RetryConfig::for_connection();
        assert_eq!(connection.max_attempts, 6);
        assert_eq!(connection.initial_delay, Duration::from_millis(100));
    }
}