1use std::time::Duration;
29use tokio::time::sleep;
30use tracing::{warn, info};
31use std::future::Future;
32
/// Backoff parameters consumed by [`retry`].
///
/// `PartialEq` is derived so configurations can be compared directly, e.g. in
/// assertions or to detect whether a reloaded configuration differs.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Delay slept after the first failed attempt. It is used as-is; `max_delay`
    /// only bounds the delays computed afterwards.
    pub initial_delay: Duration,
    /// Upper bound applied to the delay each time it is multiplied by `factor`.
    pub max_delay: Duration,
    /// Multiplier applied to the current delay after every failed attempt.
    pub factor: f64,
    /// Number of failed attempts after which [`retry`] gives up and returns the
    /// last error. `None` means the operation is retried without limit.
    pub max_retries: Option<usize>,
}
46
47impl Default for RetryConfig {
48 fn default() -> Self {
49 Self::daemon()
50 }
51}
52
53impl RetryConfig {
54 #[must_use]
58 pub fn startup() -> Self {
59 Self {
60 max_retries: Some(5),
61 initial_delay: Duration::from_millis(200),
62 max_delay: Duration::from_secs(2),
63 factor: 2.0,
64 }
65 }
66
67 #[must_use]
71 pub fn daemon() -> Self {
72 Self {
73 max_retries: None, initial_delay: Duration::from_secs(1),
75 max_delay: Duration::from_secs(300), factor: 2.0,
77 }
78 }
79
80 #[must_use]
83 pub fn query() -> Self {
84 Self {
85 max_retries: Some(3),
86 initial_delay: Duration::from_millis(100),
87 max_delay: Duration::from_secs(2),
88 factor: 2.0,
89 }
90 }
91
92 #[must_use]
98 pub fn batch_write() -> Self {
99 Self {
100 max_retries: Some(5),
101 initial_delay: Duration::from_millis(50 + (rand::random::<u64>() % 100)), max_delay: Duration::from_secs(5),
103 factor: 2.0,
104 }
105 }
106
107 #[cfg(test)]
109 pub fn test() -> Self {
110 Self {
111 max_retries: Some(3),
112 initial_delay: Duration::from_millis(1),
113 max_delay: Duration::from_millis(10),
114 factor: 2.0,
115 }
116 }
117}
118
119pub async fn retry<F, Fut, T, E>(
120 operation_name: &str,
121 config: &RetryConfig,
122 mut operation: F,
123) -> Result<T, E>
124where
125 F: FnMut() -> Fut,
126 Fut: Future<Output = Result<T, E>>,
127 E: std::fmt::Display,
128{
129 let mut delay = config.initial_delay;
130 let mut attempts = 0;
131
132 loop {
133 match operation().await {
134 Ok(val) => {
135 if attempts > 0 {
136 info!("Operation '{}' succeeded after {} retries", operation_name, attempts);
137 }
138 return Ok(val);
139 }
140 Err(err) => {
141 attempts += 1;
142
143 if let Some(max) = config.max_retries {
144 if attempts >= max {
145 return Err(err);
146 }
147 }
148
149 if config.max_retries.is_none() {
150 warn!(
152 "Operation '{}' failed (attempt {}, will retry forever): {}. Next retry in {:?}...",
153 operation_name, attempts, err, delay
154 );
155 } else {
156 warn!(
157 "Operation '{}' failed (attempt {}/{}): {}. Retrying in {:?}...",
158 operation_name, attempts, config.max_retries.unwrap(), err, delay
159 );
160 }
161
162 sleep(delay).await;
163 delay = (delay.mul_f64(config.factor)).min(config.max_delay);
164 }
165 }
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Minimal error type: `retry` only requires `Display` on the error.
    #[derive(Debug)]
    struct TestError(String);

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str(&self.0)
        }
    }

    /// An operation that succeeds immediately yields its value.
    #[tokio::test]
    async fn test_retry_succeeds_first_try() {
        let config = RetryConfig::test();

        let outcome: Result<i32, TestError> =
            retry("test_op", &config, || async { Ok(42) }).await;

        assert_eq!(outcome.unwrap(), 42);
    }

    /// Two failures followed by a success: three calls in total.
    #[tokio::test]
    async fn test_retry_succeeds_after_failures() {
        let calls = Arc::new(AtomicUsize::new(0));
        let shared = Arc::clone(&calls);
        let config = RetryConfig::test();

        let outcome: Result<i32, TestError> = retry("test_op", &config, || {
            let counter = Arc::clone(&shared);
            async move {
                match counter.fetch_add(1, Ordering::SeqCst) + 1 {
                    n if n < 3 => Err(TestError(format!("fail {}", n))),
                    _ => Ok(42),
                }
            }
        })
        .await;

        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// With `max_retries: Some(3)` a permanently failing operation runs three times
    /// and the last error is returned.
    #[tokio::test]
    async fn test_retry_exhausts_retries() {
        let calls = Arc::new(AtomicUsize::new(0));
        let shared = Arc::clone(&calls);

        let config = RetryConfig {
            initial_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
            factor: 2.0,
            max_retries: Some(3),
        };

        let outcome: Result<i32, TestError> = retry("test_op", &config, || {
            let counter = Arc::clone(&shared);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err(TestError("always fail".to_string()))
            }
        })
        .await;

        assert!(outcome.is_err());
        let TestError(message) = outcome.unwrap_err();
        assert!(message.contains("always fail"));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// Pins the retry budget of the non-test presets.
    #[test]
    fn test_retry_config_presets() {
        assert_eq!(RetryConfig::startup().max_retries, Some(5));
        assert_eq!(RetryConfig::daemon().max_retries, None);
        assert_eq!(RetryConfig::query().max_retries, Some(3));
    }

    /// The delay doubles on each step while it stays below the cap.
    #[test]
    fn test_delay_exponential_backoff() {
        let config = RetryConfig {
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(10),
            factor: 2.0,
            max_retries: Some(5),
        };
        let advance = |d: Duration| d.mul_f64(config.factor).min(config.max_delay);

        let first = config.initial_delay;
        assert_eq!(first, Duration::from_millis(100));

        let second = advance(first);
        assert_eq!(second, Duration::from_millis(200));

        let third = advance(second);
        assert_eq!(third, Duration::from_millis(400));
    }

    /// A step that would overshoot `max_delay` is clamped to it.
    #[test]
    fn test_delay_caps_at_max() {
        let config = RetryConfig {
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(5),
            factor: 10.0,
            max_retries: Some(5),
        };

        let capped = config
            .initial_delay
            .mul_f64(config.factor)
            .min(config.max_delay);

        assert_eq!(capped, Duration::from_secs(5));
    }
}