subx_cli/services/ai/
retry.rs1use crate::Result;
2use crate::error::SubXError;
3use tokio::time::{Duration, sleep};
4
/// Tuning knobs for [`retry_with_backoff`].
///
/// After the `k`-th failed attempt (0-based) the wait before the next attempt
/// is `base_delay * backoff_multiplier^k`, capped at `max_delay`.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Total number of times the operation is invoked, including the first
    /// call. Must be at least 1; `retry_with_backoff` rejects 0.
    pub max_attempts: usize,
    /// Wait applied after the first failed attempt.
    pub base_delay: Duration,
    /// Upper bound on any single wait between attempts.
    pub max_delay: Duration,
    /// Factor by which the wait grows after each further failure.
    pub backoff_multiplier: f64,
}
19
20impl Default for RetryConfig {
21 fn default() -> Self {
22 Self {
23 max_attempts: 3,
24 base_delay: Duration::from_millis(1000),
25 max_delay: Duration::from_secs(30),
26 backoff_multiplier: 2.0,
27 }
28 }
29}
30
31pub async fn retry_with_backoff<F, Fut, T>(operation: F, config: &RetryConfig) -> Result<T>
33where
34 F: Fn() -> Fut,
35 Fut: std::future::Future<Output = Result<T>>,
36{
37 if config.max_attempts == 0 {
38 return Err(SubXError::AiService(
39 "Retry configuration invalid: max_attempts must be at least 1".to_string(),
40 ));
41 }
42
43 let mut last_error = None;
44
45 for attempt in 0..config.max_attempts {
46 match operation().await {
47 Ok(result) => return Ok(result),
48 Err(e) => {
49 last_error = Some(e);
50
51 if attempt < config.max_attempts - 1 {
52 let delay = std::cmp::min(
53 Duration::from_millis(
54 (config.base_delay.as_millis() as f64
55 * config.backoff_multiplier.powi(attempt as i32))
56 as u64,
57 ),
58 config.max_delay,
59 );
60 sleep(delay).await;
61 }
62 }
63 }
64 }
65
66 Err(last_error
69 .unwrap_or_else(|| SubXError::AiService("Retry loop produced no error state".to_string())))
70}
71
72#[allow(async_fn_in_trait)]
74pub trait HttpRetryClient {
75 fn retry_attempts(&self) -> u32;
77 fn retry_delay_ms(&self) -> u64;
79
80 async fn make_request_with_retry(
82 &self,
83 request: reqwest::RequestBuilder,
84 ) -> Result<reqwest::Response> {
85 make_http_request_with_retry_impl(request, self.retry_attempts(), self.retry_delay_ms())
86 .await
87 }
88}
89
90async fn make_http_request_with_retry_impl(
92 request: reqwest::RequestBuilder,
93 retry_attempts: u32,
94 retry_delay_ms: u64,
95) -> Result<reqwest::Response> {
96 let mut attempts = 0;
97 loop {
98 let cloned = request.try_clone().ok_or_else(|| {
99 SubXError::AiService("Request body cannot be cloned for retry".to_string())
100 })?;
101 match cloned.send().await {
102 Ok(resp) => match resp.error_for_status() {
103 Ok(success) => return Ok(success),
104 Err(err) if attempts + 1 >= retry_attempts => return Err(err.into()),
105 Err(_) => {}
106 },
107 Err(err) if attempts + 1 >= retry_attempts => return Err(err.into()),
108 Err(_) => {}
109 }
110 attempts += 1;
111 sleep(Duration::from_millis(retry_delay_ms)).await;
112 }
113}
114
/// Unit tests for `retry_with_backoff` and `RetryConfig`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::SubXError;
    use std::sync::{Arc, Mutex};
    use std::time::Instant;

    #[tokio::test]
    async fn test_retry_success_on_second_attempt() {
        let config = RetryConfig {
            max_attempts: 3,
            base_delay: Duration::from_millis(10),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 2.0,
        };

        let attempt_count = Arc::new(Mutex::new(0));
        let attempt_count_clone = attempt_count.clone();

        let operation = || async {
            let mut count = attempt_count_clone.lock().unwrap();
            *count += 1;
            if *count == 1 {
                Err(SubXError::AiService("First attempt fails".to_string()))
            } else {
                Ok("Success on second attempt".to_string())
            }
        };

        let result = retry_with_backoff(operation, &config).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Success on second attempt");
        assert_eq!(*attempt_count.lock().unwrap(), 2);
    }

    #[tokio::test]
    async fn test_retry_exhaust_max_attempts() {
        let config = RetryConfig {
            max_attempts: 2,
            base_delay: Duration::from_millis(10),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 2.0,
        };

        let attempt_count = Arc::new(Mutex::new(0));
        let attempt_count_clone = attempt_count.clone();

        let operation = || async {
            let mut count = attempt_count_clone.lock().unwrap();
            *count += 1;
            Err(SubXError::AiService("Always fails".to_string()))
        };

        let result: Result<String> = retry_with_backoff(operation, &config).await;
        assert!(result.is_err());
        assert_eq!(*attempt_count.lock().unwrap(), 2);
    }

    #[tokio::test]
    async fn test_exponential_backoff_timing() {
        let config = RetryConfig {
            max_attempts: 3,
            base_delay: Duration::from_millis(50),
            max_delay: Duration::from_millis(200),
            backoff_multiplier: 2.0,
        };

        let attempt_times = Arc::new(Mutex::new(Vec::new()));
        let attempt_times_clone = attempt_times.clone();

        let operation = || async {
            let start_time = Instant::now();
            attempt_times_clone.lock().unwrap().push(start_time);
            Err(SubXError::AiService(
                "Always fails for timing test".to_string(),
            ))
        };

        let _result: Result<String> = retry_with_backoff(operation, &config).await;

        let times = attempt_times.lock().unwrap();
        assert_eq!(times.len(), 3);

        // First wait is base_delay (50ms), with slack for scheduler jitter.
        let delay1 = times[1].duration_since(times[0]);
        assert!(delay1 >= Duration::from_millis(30));
        assert!(delay1 <= Duration::from_millis(100));

        // Second wait is base_delay * 2 (100ms): proves the wait actually grows.
        let delay2 = times[2].duration_since(times[1]);
        assert!(delay2 >= Duration::from_millis(80));
    }

    #[tokio::test]
    async fn test_max_delay_cap() {
        let config = RetryConfig {
            max_attempts: 5,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_millis(200),
            backoff_multiplier: 3.0,
        };

        let attempt_times = Arc::new(Mutex::new(Vec::new()));
        let attempt_times_clone = attempt_times.clone();

        let operation = || async {
            attempt_times_clone.lock().unwrap().push(Instant::now());
            Err(SubXError::AiService("Always fails".to_string()))
        };

        let _result: Result<String> = retry_with_backoff(operation, &config).await;

        let times = attempt_times.lock().unwrap();
        assert_eq!(times.len(), 5);

        // Uncapped, the waits after the 2nd, 3rd and 4th failures would be
        // 300ms, 900ms and 2700ms; each must be clamped to max_delay (200ms).
        for pair in times.windows(2).skip(1) {
            let delay = pair[1].duration_since(pair[0]);
            assert!(
                delay <= Duration::from_millis(250),
                "delay not capped: {delay:?}"
            );
        }
    }

    #[tokio::test]
    async fn test_retry_rejects_zero_max_attempts() {
        let config = RetryConfig {
            max_attempts: 0,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(1),
            backoff_multiplier: 2.0,
        };

        let called = Arc::new(Mutex::new(false));
        let called_clone = called.clone();
        let operation = || {
            let called = called_clone.clone();
            async move {
                *called.lock().unwrap() = true;
                Ok::<_, SubXError>("should not run".to_string())
            }
        };

        let result: Result<String> = retry_with_backoff(operation, &config).await;
        assert!(result.is_err());
        assert!(!*called.lock().unwrap(), "operation must not be invoked");
        match result {
            Err(SubXError::AiService(msg)) => assert!(msg.contains("max_attempts")),
            other => panic!("unexpected result: {:?}", other),
        }
    }

    #[test]
    fn test_retry_config_validation() {
        let valid_config = RetryConfig {
            max_attempts: 3,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 2.0,
        };
        assert!(valid_config.base_delay <= valid_config.max_delay);
        assert!(valid_config.max_attempts > 0);
        assert!(valid_config.backoff_multiplier > 1.0);
    }

    #[tokio::test]
    async fn test_ai_service_integration_simulation() {
        let config = RetryConfig {
            max_attempts: 3,
            base_delay: Duration::from_millis(10),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 2.0,
        };

        let request_count = Arc::new(Mutex::new(0));
        let request_count_clone = request_count.clone();

        let mock_ai_request = || async {
            let mut count = request_count_clone.lock().unwrap();
            *count += 1;

            match *count {
                1 => Err(SubXError::AiService("Network timeout".to_string())),
                2 => Err(SubXError::AiService("Rate limit exceeded".to_string())),
                3 => Ok("AI analysis complete".to_string()),
                _ => unreachable!(),
            }
        };

        let result = retry_with_backoff(mock_ai_request, &config).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "AI analysis complete");
        assert_eq!(*request_count.lock().unwrap(), 3);
    }
}