Skip to main content

subx_cli/services/ai/
retry.rs

1use crate::Result;
2use crate::error::SubXError;
3use tokio::time::{Duration, sleep};
4
/// Retry policy for AI service operations.
///
/// Controls how many times an operation is attempted and how the pause
/// between consecutive attempts grows.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Total number of attempts, including the very first one.
    pub max_attempts: usize,
    /// Delay applied after the first failed attempt.
    pub base_delay: Duration,
    /// Upper bound for any single delay between attempts.
    pub max_delay: Duration,
    /// Factor by which the delay grows after each further failed attempt.
    pub backoff_multiplier: f64,
}
19
20impl Default for RetryConfig {
21    fn default() -> Self {
22        Self {
23            max_attempts: 3,
24            base_delay: Duration::from_millis(1000),
25            max_delay: Duration::from_secs(30),
26            backoff_multiplier: 2.0,
27        }
28    }
29}
30
31/// Retries an operation with an exponential backoff mechanism.
32pub async fn retry_with_backoff<F, Fut, T>(operation: F, config: &RetryConfig) -> Result<T>
33where
34    F: Fn() -> Fut,
35    Fut: std::future::Future<Output = Result<T>>,
36{
37    if config.max_attempts == 0 {
38        return Err(SubXError::AiService(
39            "Retry configuration invalid: max_attempts must be at least 1".to_string(),
40        ));
41    }
42
43    let mut last_error = None;
44
45    for attempt in 0..config.max_attempts {
46        match operation().await {
47            Ok(result) => return Ok(result),
48            Err(e) => {
49                last_error = Some(e);
50
51                if attempt < config.max_attempts - 1 {
52                    let delay = std::cmp::min(
53                        Duration::from_millis(
54                            (config.base_delay.as_millis() as f64
55                                * config.backoff_multiplier.powi(attempt as i32))
56                                as u64,
57                        ),
58                        config.max_delay,
59                    );
60                    sleep(delay).await;
61                }
62            }
63        }
64    }
65
66    // `last_error` is guaranteed to be `Some` here because `max_attempts >= 1`
67    // was verified above and the loop always populates it on failure.
68    Err(last_error
69        .unwrap_or_else(|| SubXError::AiService("Retry loop produced no error state".to_string())))
70}
71
72/// HTTP request retry trait for AI clients.
73#[allow(async_fn_in_trait)]
74pub trait HttpRetryClient {
75    /// Number of retry attempts.
76    fn retry_attempts(&self) -> u32;
77    /// Delay between retries in milliseconds.
78    fn retry_delay_ms(&self) -> u64;
79
80    /// Make an HTTP request with retry logic.
81    async fn make_request_with_retry(
82        &self,
83        request: reqwest::RequestBuilder,
84    ) -> Result<reqwest::Response> {
85        make_http_request_with_retry_impl(request, self.retry_attempts(), self.retry_delay_ms())
86            .await
87    }
88}
89
90/// Internal implementation of HTTP request retry with backoff.
91async fn make_http_request_with_retry_impl(
92    request: reqwest::RequestBuilder,
93    retry_attempts: u32,
94    retry_delay_ms: u64,
95) -> Result<reqwest::Response> {
96    let mut attempts = 0;
97    loop {
98        let cloned = request.try_clone().ok_or_else(|| {
99            SubXError::AiService("Request body cannot be cloned for retry".to_string())
100        })?;
101        match cloned.send().await {
102            Ok(resp) => match resp.error_for_status() {
103                Ok(success) => return Ok(success),
104                Err(err) if attempts + 1 >= retry_attempts => return Err(err.into()),
105                Err(_) => {}
106            },
107            Err(err) if attempts + 1 >= retry_attempts => return Err(err.into()),
108            Err(_) => {}
109        }
110        attempts += 1;
111        sleep(Duration::from_millis(retry_delay_ms)).await;
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::SubXError;
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::time::Instant;

    /// Builds a retry configuration with delays given in milliseconds.
    fn config(max_attempts: usize, base_ms: u64, max_ms: u64, backoff_multiplier: f64) -> RetryConfig {
        RetryConfig {
            max_attempts,
            base_delay: Duration::from_millis(base_ms),
            max_delay: Duration::from_millis(max_ms),
            backoff_multiplier,
        }
    }

    /// A failure on the first attempt is retried and the second attempt's
    /// success is returned.
    #[tokio::test]
    async fn test_retry_success_on_second_attempt() {
        let config = config(3, 10, 1000, 2.0);
        let attempts = AtomicUsize::new(0);

        let operation = || async {
            // `fetch_add` returns the previous value, so 0 marks the first call.
            if attempts.fetch_add(1, Ordering::SeqCst) == 0 {
                Err(SubXError::AiService("First attempt fails".to_string()))
            } else {
                Ok("Success on second attempt".to_string())
            }
        };

        let result = retry_with_backoff(operation, &config).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Success on second attempt");
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    /// An operation that always fails is invoked exactly `max_attempts` times
    /// and its error is handed back to the caller.
    #[tokio::test]
    async fn test_retry_exhaust_max_attempts() {
        let config = config(2, 10, 1000, 2.0);
        let attempts = AtomicUsize::new(0);

        let operation = || async {
            attempts.fetch_add(1, Ordering::SeqCst);
            Err(SubXError::AiService("Always fails".to_string()))
        };

        let result: Result<String> = retry_with_backoff(operation, &config).await;
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
        match result {
            Err(SubXError::AiService(msg)) => assert_eq!(msg, "Always fails"),
            other => panic!("unexpected result: {:?}", other),
        }
    }

    /// The pause between attempts starts at `base_delay` and grows with the
    /// backoff multiplier.
    #[tokio::test]
    async fn test_exponential_backoff_timing() {
        let config = config(3, 50, 200, 2.0);
        let attempt_times = Mutex::new(Vec::new());

        let operation = || async {
            attempt_times.lock().unwrap().push(Instant::now());
            Err(SubXError::AiService(
                "Always fails for timing test".to_string(),
            ))
        };

        let _result: Result<String> = retry_with_backoff(operation, &config).await;

        let times = attempt_times.lock().unwrap();
        assert_eq!(times.len(), 3);

        // First delay should be approximately 50ms (tolerating scheduler jitter).
        let delay1 = times[1].duration_since(times[0]);
        assert!(delay1 >= Duration::from_millis(30));
        assert!(delay1 <= Duration::from_millis(100));

        // Second delay is 50ms * 2 = 100ms. Only a lower bound is asserted so
        // that a slow machine cannot make the test flaky.
        let delay2 = times[2].duration_since(times[1]);
        assert!(delay2 >= Duration::from_millis(80));
    }

    /// A delay that would exceed `max_delay` is clamped to it.
    #[tokio::test]
    async fn test_max_delay_cap() {
        // Low cap combined with a high multiplier: the uncapped second delay
        // would be 100ms * 3 = 300ms.
        let config = config(5, 100, 200, 3.0);
        let attempt_times = Mutex::new(Vec::new());

        let operation = || async {
            attempt_times.lock().unwrap().push(Instant::now());
            Err(SubXError::AiService("Always fails".to_string()))
        };

        let _result: Result<String> = retry_with_backoff(operation, &config).await;

        let times = attempt_times.lock().unwrap();
        // Without this check the timing assertions below could be skipped silently.
        assert_eq!(times.len(), 5);

        // Second delay should be capped at max_delay (200ms, +50ms tolerance).
        let delay2 = times[2].duration_since(times[1]);
        assert!(delay2 >= Duration::from_millis(150));
        assert!(delay2 <= Duration::from_millis(250));
    }

    /// When `max_attempts == 0` the retry loop must fail fast without ever
    /// invoking the operation.
    #[tokio::test]
    async fn test_retry_rejects_zero_max_attempts() {
        let config = config(0, 1, 1, 2.0);
        let called = AtomicBool::new(false);

        let operation = || async {
            called.store(true, Ordering::SeqCst);
            Ok::<_, SubXError>("should not run".to_string())
        };

        let result: Result<String> = retry_with_backoff(operation, &config).await;
        assert!(result.is_err());
        assert!(!called.load(Ordering::SeqCst), "operation must not be invoked");
        match result {
            Err(SubXError::AiService(msg)) => assert!(msg.contains("max_attempts")),
            other => panic!("unexpected result: {:?}", other),
        }
    }

    /// The default configuration must satisfy the invariants the retry loop
    /// relies on.
    #[test]
    fn test_retry_config_validation() {
        let default_config = RetryConfig::default();
        assert!(default_config.base_delay <= default_config.max_delay);
        assert!(default_config.max_attempts > 0);
        assert!(default_config.backoff_multiplier > 1.0);
    }

    /// Simulates an AI service that fails twice with transient errors before
    /// succeeding on the third request.
    #[tokio::test]
    async fn test_ai_service_integration_simulation() {
        let config = config(3, 10, 1000, 2.0);
        let requests = AtomicUsize::new(0);

        let mock_ai_request = || async {
            // One-based number of the current request.
            match requests.fetch_add(1, Ordering::SeqCst) + 1 {
                1 => Err(SubXError::AiService("Network timeout".to_string())),
                2 => Err(SubXError::AiService("Rate limit exceeded".to_string())),
                3 => Ok("AI analysis complete".to_string()),
                _ => unreachable!(),
            }
        };

        let result = retry_with_backoff(mock_ai_request, &config).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "AI analysis complete");
        assert_eq!(requests.load(Ordering::SeqCst), 3);
    }
}