subx_cli/services/ai/
retry.rs1use crate::Result;
2use tokio::time::{Duration, sleep};
3
/// Settings that control how [`retry_with_backoff`] repeats a failing operation.
///
/// The wait before retry number `n` (counting from zero) is
/// `base_delay * backoff_multiplier^n`, capped at `max_delay`.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Total number of times the operation may be invoked, including the first call.
    pub max_attempts: usize,
    /// Wait applied before the first retry.
    pub base_delay: Duration,
    /// Upper bound for any single wait between attempts.
    pub max_delay: Duration,
    /// Factor by which the wait grows after each failed attempt.
    pub backoff_multiplier: f64,
}
18
19impl Default for RetryConfig {
20 fn default() -> Self {
21 Self {
22 max_attempts: 3,
23 base_delay: Duration::from_millis(1000),
24 max_delay: Duration::from_secs(30),
25 backoff_multiplier: 2.0,
26 }
27 }
28}
29
30pub async fn retry_with_backoff<F, Fut, T>(operation: F, config: &RetryConfig) -> Result<T>
32where
33 F: Fn() -> Fut,
34 Fut: std::future::Future<Output = Result<T>>,
35{
36 let mut last_error = None;
37
38 for attempt in 0..config.max_attempts {
39 match operation().await {
40 Ok(result) => return Ok(result),
41 Err(e) => {
42 last_error = Some(e);
43
44 if attempt < config.max_attempts - 1 {
45 let delay = std::cmp::min(
46 Duration::from_millis(
47 (config.base_delay.as_millis() as f64
48 * config.backoff_multiplier.powi(attempt as i32))
49 as u64,
50 ),
51 config.max_delay,
52 );
53 sleep(delay).await;
54 }
55 }
56 }
57 }
58
59 Err(last_error.unwrap())
60}
61
62#[allow(async_fn_in_trait)]
64pub trait HttpRetryClient {
65 fn retry_attempts(&self) -> u32;
67 fn retry_delay_ms(&self) -> u64;
69
70 async fn make_request_with_retry(
72 &self,
73 request: reqwest::RequestBuilder,
74 ) -> reqwest::Result<reqwest::Response> {
75 make_http_request_with_retry_impl(request, self.retry_attempts(), self.retry_delay_ms())
76 .await
77 }
78}
79
80async fn make_http_request_with_retry_impl(
82 request: reqwest::RequestBuilder,
83 retry_attempts: u32,
84 retry_delay_ms: u64,
85) -> reqwest::Result<reqwest::Response> {
86 let mut attempts = 0;
87 loop {
88 let cloned = request.try_clone().unwrap();
89 match cloned.send().await {
90 Ok(resp) => match resp.error_for_status() {
91 Ok(success) => return Ok(success),
92 Err(err) if attempts + 1 >= retry_attempts => return Err(err),
93 Err(_) => {}
94 },
95 Err(err) if attempts + 1 >= retry_attempts => return Err(err),
96 Err(_) => {}
97 }
98 attempts += 1;
99 sleep(Duration::from_millis(retry_delay_ms)).await;
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::SubXError;
    use std::sync::{Arc, Mutex};
    use std::time::Instant;

    /// One failure followed by a success yields the success after two calls.
    #[tokio::test]
    async fn test_retry_success_on_second_attempt() {
        let config = RetryConfig {
            max_attempts: 3,
            base_delay: Duration::from_millis(10),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 2.0,
        };

        let attempt_count = Arc::new(Mutex::new(0));
        let attempt_count_clone = attempt_count.clone();

        let operation = || async {
            let mut count = attempt_count_clone.lock().unwrap();
            *count += 1;
            if *count == 1 {
                Err(SubXError::AiService("First attempt fails".to_string()))
            } else {
                Ok("Success on second attempt".to_string())
            }
        };

        let result = retry_with_backoff(operation, &config).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Success on second attempt");
        assert_eq!(*attempt_count.lock().unwrap(), 2);
    }

    /// An operation that always fails is called exactly `max_attempts` times.
    #[tokio::test]
    async fn test_retry_exhaust_max_attempts() {
        let config = RetryConfig {
            max_attempts: 2,
            base_delay: Duration::from_millis(10),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 2.0,
        };

        let attempt_count = Arc::new(Mutex::new(0));
        let attempt_count_clone = attempt_count.clone();

        let operation = || async {
            let mut count = attempt_count_clone.lock().unwrap();
            *count += 1;
            Err(SubXError::AiService("Always fails".to_string()))
        };

        let result: Result<String> = retry_with_backoff(operation, &config).await;
        assert!(result.is_err());
        assert_eq!(*attempt_count.lock().unwrap(), 2);
    }

    /// The waits between attempts follow `base_delay * multiplier^n`.
    #[tokio::test]
    async fn test_exponential_backoff_timing() {
        let config = RetryConfig {
            max_attempts: 3,
            base_delay: Duration::from_millis(50),
            max_delay: Duration::from_millis(200),
            backoff_multiplier: 2.0,
        };

        let attempt_times = Arc::new(Mutex::new(Vec::new()));
        let attempt_times_clone = attempt_times.clone();

        let operation = || async {
            let start_time = Instant::now();
            attempt_times_clone.lock().unwrap().push(start_time);
            Err(SubXError::AiService(
                "Always fails for timing test".to_string(),
            ))
        };

        let result: Result<String> = retry_with_backoff(operation, &config).await;
        assert!(result.is_err());

        let times = attempt_times.lock().unwrap();
        assert_eq!(times.len(), 3);

        // First retry waits `base_delay` (50ms); the bounds leave room for
        // timer granularity and scheduling jitter.
        let delay1 = times[1].duration_since(times[0]);
        assert!(delay1 >= Duration::from_millis(30));
        assert!(delay1 <= Duration::from_millis(100));

        // Second retry waits `base_delay * 2` (100ms). Only a lower bound is
        // checked: it proves the delay grew without being sensitive to load.
        let delay2 = times[2].duration_since(times[1]);
        assert!(delay2 >= Duration::from_millis(80));
    }

    /// No single wait exceeds `max_delay`, however large the exponent gets.
    #[tokio::test]
    async fn test_max_delay_cap() {
        let config = RetryConfig {
            max_attempts: 5,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_millis(200),
            backoff_multiplier: 3.0,
        };

        let attempt_times = Arc::new(Mutex::new(Vec::new()));
        let attempt_times_clone = attempt_times.clone();

        let operation = || async {
            attempt_times_clone.lock().unwrap().push(Instant::now());
            Err(SubXError::AiService("Always fails".to_string()))
        };

        let _result: Result<String> = retry_with_backoff(operation, &config).await;

        let times = attempt_times.lock().unwrap();
        // Assert the count up front so the checks below can never be skipped.
        assert_eq!(times.len(), 5);

        // Uncapped, this wait would be 100ms * 3^1 = 300ms.
        let delay2 = times[2].duration_since(times[1]);
        assert!(delay2 <= Duration::from_millis(250));

        // Uncapped, the last wait would be 100ms * 3^3 = 2700ms; the generous
        // bound still separates "capped" from "uncapped" on a busy machine.
        let delay4 = times[4].duration_since(times[3]);
        assert!(delay4 < Duration::from_millis(1000));
    }

    /// The default configuration satisfies the basic sanity invariants.
    #[test]
    fn test_retry_config_validation() {
        let config = RetryConfig::default();
        assert!(config.base_delay <= config.max_delay);
        assert!(config.max_attempts > 0);
        assert!(config.backoff_multiplier > 1.0);
    }

    /// Simulates an AI request that fails twice before succeeding.
    #[tokio::test]
    async fn test_ai_service_integration_simulation() {
        let config = RetryConfig {
            max_attempts: 3,
            base_delay: Duration::from_millis(10),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 2.0,
        };

        let request_count = Arc::new(Mutex::new(0));
        let request_count_clone = request_count.clone();

        let mock_ai_request = || async {
            let mut count = request_count_clone.lock().unwrap();
            *count += 1;

            match *count {
                1 => Err(SubXError::AiService("Network timeout".to_string())),
                2 => Err(SubXError::AiService("Rate limit exceeded".to_string())),
                3 => Ok("AI analysis complete".to_string()),
                // `max_attempts` is 3, so a fourth call would be a retry-loop bug.
                _ => unreachable!(),
            }
        };

        let result = retry_with_backoff(mock_ai_request, &config).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "AI analysis complete");
        assert_eq!(*request_count.lock().unwrap(), 3);
    }
}