1use std::time::Duration;
26use tokio::time::sleep;
27use tracing::{warn, info};
28use std::future::Future;
29
/// Exponential-backoff policy consumed by [`retry`].
///
/// After each failed attempt the delay is multiplied by `factor` and capped at
/// `max_delay`; the first delay is `initial_delay`.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Delay slept after the first failed attempt.
    pub initial_delay: Duration,
    /// Cap applied to the delay each time it grows by `factor`.
    pub max_delay: Duration,
    /// Multiplier applied to the delay after every failed attempt.
    pub factor: f64,
    /// Maximum number of calls to the operation, *including* the first one.
    ///
    /// Despite the name, `Some(3)` means three attempts in total (not one
    /// attempt plus three retries). `None` retries forever.
    pub max_retries: Option<usize>,
}

impl Default for RetryConfig {
    /// Returns the unbounded [`RetryConfig::daemon`] preset.
    fn default() -> Self {
        Self::daemon()
    }
}

impl RetryConfig {
    /// Bounded preset: up to 5 attempts, starting at 200 ms and doubling up to 2 s.
    #[must_use]
    pub fn startup() -> Self {
        Self {
            max_retries: Some(5),
            initial_delay: Duration::from_millis(200),
            max_delay: Duration::from_secs(2),
            factor: 2.0,
        }
    }

    /// Unbounded preset: retries forever, starting at 1 s and doubling up to 300 s.
    #[must_use]
    pub fn daemon() -> Self {
        Self {
            max_retries: None,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(300),
            factor: 2.0,
        }
    }

    /// Bounded preset: up to 3 attempts, starting at 100 ms and doubling up to 2 s.
    #[must_use]
    pub fn query() -> Self {
        Self {
            max_retries: Some(3),
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(2),
            factor: 2.0,
        }
    }

    /// Test-only preset: up to 3 attempts with millisecond-scale delays (1 ms to 10 ms)
    /// so unit tests do not sleep noticeably.
    #[cfg(test)]
    #[must_use]
    pub fn test() -> Self {
        Self {
            max_retries: Some(3),
            initial_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
            factor: 2.0,
        }
    }
}
100
101pub async fn retry<F, Fut, T, E>(
102 operation_name: &str,
103 config: &RetryConfig,
104 mut operation: F,
105) -> Result<T, E>
106where
107 F: FnMut() -> Fut,
108 Fut: Future<Output = Result<T, E>>,
109 E: std::fmt::Display,
110{
111 let mut delay = config.initial_delay;
112 let mut attempts = 0;
113
114 loop {
115 match operation().await {
116 Ok(val) => {
117 if attempts > 0 {
118 info!("Operation '{}' succeeded after {} retries", operation_name, attempts);
119 }
120 return Ok(val);
121 }
122 Err(err) => {
123 attempts += 1;
124
125 if let Some(max) = config.max_retries {
126 if attempts >= max {
127 return Err(err);
128 }
129 }
130
131 if config.max_retries.is_none() {
132 warn!(
134 "Operation '{}' failed (attempt {}, will retry forever): {}. Next retry in {:?}...",
135 operation_name, attempts, err, delay
136 );
137 } else {
138 warn!(
139 "Operation '{}' failed (attempt {}/{}): {}. Retrying in {:?}...",
140 operation_name, attempts, config.max_retries.unwrap(), err, delay
141 );
142 }
143
144 sleep(delay).await;
145 delay = (delay.mul_f64(config.factor)).min(config.max_delay);
146 }
147 }
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Minimal error type; `retry` only requires `Display`.
    #[derive(Debug)]
    struct TestError(String);

    impl std::fmt::Display for TestError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    #[tokio::test]
    async fn test_retry_succeeds_first_try() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_clone = Arc::clone(&calls);

        let result: Result<i32, TestError> = retry("test_op", &RetryConfig::test(), || {
            let c = Arc::clone(&calls_clone);
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                Ok(42)
            }
        })
        .await;

        assert_eq!(result.unwrap(), 42);
        // A first-try success must not trigger any retry.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_succeeds_after_failures() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let attempts_clone = Arc::clone(&attempts);

        let result: Result<i32, TestError> = retry("test_op", &RetryConfig::test(), || {
            let a = Arc::clone(&attempts_clone);
            async move {
                let count = a.fetch_add(1, Ordering::SeqCst) + 1;
                if count < 3 {
                    Err(TestError(format!("fail {}", count)))
                } else {
                    Ok(42)
                }
            }
        })
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_exhausts_retries() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let attempts_clone = Arc::clone(&attempts);

        let config = RetryConfig {
            max_retries: Some(3),
            initial_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
            factor: 2.0,
        };

        let result: Result<i32, TestError> = retry("test_op", &config, || {
            let a = Arc::clone(&attempts_clone);
            async move {
                a.fetch_add(1, Ordering::SeqCst);
                Err(TestError("always fail".to_string()))
            }
        })
        .await;

        let err = result.expect_err("operation never succeeds, so retry must give up");
        assert!(err.0.contains("always fail"));
        // `max_retries` caps the total number of calls, including the first.
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_single_attempt_budget() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let attempts_clone = Arc::clone(&attempts);

        let config = RetryConfig {
            max_retries: Some(1),
            ..RetryConfig::test()
        };

        let result: Result<i32, TestError> = retry("test_op", &config, || {
            let a = Arc::clone(&attempts_clone);
            async move {
                a.fetch_add(1, Ordering::SeqCst);
                Err(TestError("boom".to_string()))
            }
        })
        .await;

        assert!(result.is_err());
        // A budget of one means exactly one call and no retry.
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_retry_config_presets() {
        assert_eq!(RetryConfig::startup().max_retries, Some(5));
        assert_eq!(RetryConfig::daemon().max_retries, None);
        assert_eq!(RetryConfig::query().max_retries, Some(3));
    }

    // The two tests below mirror the backoff step used by `retry`
    // (multiply by `factor`, then cap at `max_delay`); they check the
    // arithmetic itself rather than calling `retry`.
    #[test]
    fn test_delay_exponential_backoff() {
        let config = RetryConfig {
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(10),
            factor: 2.0,
            max_retries: Some(5),
        };

        let mut delay = config.initial_delay;
        assert_eq!(delay, Duration::from_millis(100));

        delay = delay.mul_f64(config.factor).min(config.max_delay);
        assert_eq!(delay, Duration::from_millis(200));

        delay = delay.mul_f64(config.factor).min(config.max_delay);
        assert_eq!(delay, Duration::from_millis(400));
    }

    #[test]
    fn test_delay_caps_at_max() {
        let config = RetryConfig {
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(5),
            factor: 10.0,
            max_retries: Some(5),
        };

        let delay = config.initial_delay.mul_f64(config.factor).min(config.max_delay);
        assert_eq!(delay, Duration::from_secs(5));
    }
}