1use std::fmt::Display;
6use std::future::Future;
7use std::time::Duration;
8
9use tokio::time::sleep;
10use tracing::{debug, warn};
11
/// Number of attempts used by [`RetryConfig::default`].
const DEFAULT_MAX_ATTEMPTS: u32 = 3;

/// Initial delay, in milliseconds, used by [`RetryConfig::default`].
const DEFAULT_INITIAL_DELAY_MS: u64 = 100;

/// Delay cap, in milliseconds, used by [`RetryConfig::default`].
const DEFAULT_MAX_DELAY_MS: u64 = 2000;

/// Settings that control how often an operation is attempted and how long to
/// wait between attempts.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Total number of attempts, counting the first one.
    pub max_attempts: u32,
    /// Delay before the first retry, in milliseconds.
    pub initial_delay_ms: u64,
    /// Upper bound applied to every computed delay, in milliseconds.
    pub max_delay_ms: u64,
}

impl Default for RetryConfig {
    /// Builds a configuration from the `DEFAULT_*` constants of this module.
    fn default() -> Self {
        Self {
            max_attempts: DEFAULT_MAX_ATTEMPTS,
            initial_delay_ms: DEFAULT_INITIAL_DELAY_MS,
            max_delay_ms: DEFAULT_MAX_DELAY_MS,
        }
    }
}

impl RetryConfig {
    /// Creates a configuration from explicit values; no validation is performed.
    pub fn new(max_attempts: u32, initial_delay_ms: u64, max_delay_ms: u64) -> Self {
        Self {
            max_attempts,
            initial_delay_ms,
            max_delay_ms,
        }
    }

    /// Returns the delay to wait after the zero-based `attempt` has failed.
    ///
    /// The delay is `initial_delay_ms * 2^attempt`, clamped to `max_delay_ms`.
    /// The arithmetic saturates, so arbitrarily large attempt numbers or
    /// initial delays yield the cap instead of overflowing.
    fn delay_for_attempt(&self, attempt: u32) -> Duration {
        // 2^attempt does not fit in a u64 once attempt >= 64; treat that as
        // "as large as possible" and let the cap below take over.
        let factor = 2u64.checked_pow(attempt).unwrap_or(u64::MAX);
        let delay_ms = self
            .initial_delay_ms
            .saturating_mul(factor)
            .min(self.max_delay_ms);
        Duration::from_millis(delay_ms)
    }
}
59
60pub async fn with_retry<T, E, F, Fut>(config: &RetryConfig, operation: F) -> Result<T, E>
86where
87 F: Fn() -> Fut,
88 Fut: Future<Output = Result<T, E>>,
89 E: Display,
90{
91 let mut last_error: Option<E> = None;
92
93 for attempt in 0..config.max_attempts {
94 match operation().await {
95 Ok(result) => {
96 if attempt > 0 {
97 debug!("Operation succeeded on attempt {}", attempt + 1);
98 }
99 return Ok(result);
100 }
101 Err(e) => {
102 let is_last_attempt = attempt + 1 >= config.max_attempts;
103
104 if is_last_attempt {
105 warn!(
106 "Operation failed after {} attempts: {}",
107 config.max_attempts, e
108 );
109 last_error = Some(e);
110 } else {
111 let delay = config.delay_for_attempt(attempt);
112 warn!(
113 "Operation failed (attempt {}/{}): {}. Retrying in {:?}...",
114 attempt + 1,
115 config.max_attempts,
116 e,
117 delay
118 );
119 sleep(delay).await;
120 last_error = Some(e);
121 }
122 }
123 }
124 }
125
126 Err(last_error.expect("at least one attempt should have been made"))
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// A first-try success is passed straight through.
    #[tokio::test]
    async fn test_retry_success_first_attempt() {
        let cfg = RetryConfig::default();

        let outcome: Result<&str, &str> = with_retry(&cfg, || async { Ok("success") }).await;

        assert_eq!(outcome, Ok("success"));
    }

    /// Two failures followed by a success yield the success after three calls.
    #[tokio::test]
    async fn test_retry_success_after_failures() {
        let cfg = RetryConfig::new(3, 10, 100);
        let calls = Arc::new(AtomicU32::new(0));
        let shared = Arc::clone(&calls);

        let outcome: Result<&str, &str> = with_retry(&cfg, || {
            let counter = Arc::clone(&shared);
            async move {
                // The first two invocations (previous count 0 and 1) fail.
                let previous = counter.fetch_add(1, Ordering::SeqCst);
                if previous >= 2 {
                    Ok("success")
                } else {
                    Err("transient error")
                }
            }
        })
        .await;

        assert_eq!(outcome, Ok("success"));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// When every attempt fails, the error is returned after three calls.
    #[tokio::test]
    async fn test_retry_all_failures() {
        let cfg = RetryConfig::new(3, 10, 100);
        let calls = Arc::new(AtomicU32::new(0));
        let shared = Arc::clone(&calls);

        let outcome: Result<&str, &str> = with_retry(&cfg, || {
            let counter = Arc::clone(&shared);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err("persistent error")
            }
        })
        .await;

        assert_eq!(outcome, Err("persistent error"));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// Delays double per attempt and are clamped to the configured maximum.
    #[test]
    fn test_delay_calculation() {
        let cfg = RetryConfig::new(5, 100, 1000);
        // Expected delay in milliseconds for attempts 0 through 4.
        let expected_ms: [u64; 5] = [100, 200, 400, 800, 1000];

        for (attempt, ms) in (0u32..).zip(expected_ms) {
            assert_eq!(cfg.delay_for_attempt(attempt), Duration::from_millis(ms));
        }
    }
}