subx_cli/services/ai/
retry.rs

1use crate::Result;
2use tokio::time::{Duration, sleep};
3
/// Retry configuration for AI service operations.
///
/// Controls how [`retry_with_backoff`] behaves: how many times the operation
/// is attempted and how long to wait between consecutive attempts.
///
/// The wait before each retry starts at `base_delay`, is multiplied by
/// `backoff_multiplier` after every failure, and never exceeds `max_delay`.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Maximum number of times the operation is attempted, counting the first try.
    pub max_attempts: usize,
    /// Delay before the first retry.
    pub base_delay: Duration,
    /// Upper bound applied to every computed delay.
    pub max_delay: Duration,
    /// Factor by which the delay grows after each failed attempt.
    pub backoff_multiplier: f64,
}
18
19impl Default for RetryConfig {
20    fn default() -> Self {
21        Self {
22            max_attempts: 3,
23            base_delay: Duration::from_millis(1000),
24            max_delay: Duration::from_secs(30),
25            backoff_multiplier: 2.0,
26        }
27    }
28}
29
30/// Retries an operation with an exponential backoff mechanism.
31pub async fn retry_with_backoff<F, Fut, T>(operation: F, config: &RetryConfig) -> Result<T>
32where
33    F: Fn() -> Fut,
34    Fut: std::future::Future<Output = Result<T>>,
35{
36    let mut last_error = None;
37
38    for attempt in 0..config.max_attempts {
39        match operation().await {
40            Ok(result) => return Ok(result),
41            Err(e) => {
42                last_error = Some(e);
43
44                if attempt < config.max_attempts - 1 {
45                    let delay = std::cmp::min(
46                        Duration::from_millis(
47                            (config.base_delay.as_millis() as f64
48                                * config.backoff_multiplier.powi(attempt as i32))
49                                as u64,
50                        ),
51                        config.max_delay,
52                    );
53                    sleep(delay).await;
54                }
55            }
56        }
57    }
58
59    Err(last_error.unwrap())
60}
61
62/// HTTP request retry trait for AI clients.
63#[allow(async_fn_in_trait)]
64pub trait HttpRetryClient {
65    /// Number of retry attempts.
66    fn retry_attempts(&self) -> u32;
67    /// Delay between retries in milliseconds.
68    fn retry_delay_ms(&self) -> u64;
69
70    /// Make an HTTP request with retry logic.
71    async fn make_request_with_retry(
72        &self,
73        request: reqwest::RequestBuilder,
74    ) -> reqwest::Result<reqwest::Response> {
75        make_http_request_with_retry_impl(request, self.retry_attempts(), self.retry_delay_ms())
76            .await
77    }
78}
79
80/// Internal implementation of HTTP request retry with backoff.
81async fn make_http_request_with_retry_impl(
82    request: reqwest::RequestBuilder,
83    retry_attempts: u32,
84    retry_delay_ms: u64,
85) -> reqwest::Result<reqwest::Response> {
86    let mut attempts = 0;
87    loop {
88        let cloned = request.try_clone().unwrap();
89        match cloned.send().await {
90            Ok(resp) => match resp.error_for_status() {
91                Ok(success) => return Ok(success),
92                Err(err) if attempts + 1 >= retry_attempts => return Err(err),
93                Err(_) => {}
94            },
95            Err(err) if attempts + 1 >= retry_attempts => return Err(err),
96            Err(_) => {}
97        }
98        attempts += 1;
99        sleep(Duration::from_millis(retry_delay_ms)).await;
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::SubXError;
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Instant;

    /// Builds a config with short delays so the retry tests stay fast.
    fn fast_config(max_attempts: usize) -> RetryConfig {
        RetryConfig {
            max_attempts,
            base_delay: Duration::from_millis(10),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 2.0,
        }
    }

    /// A failure on the first call is retried and the second call's value is returned.
    #[tokio::test]
    async fn test_retry_success_on_second_attempt() {
        let config = fast_config(3);
        let calls = AtomicUsize::new(0);

        let operation = || async {
            // `fetch_add` returns the previous value, so 0 marks the first call.
            if calls.fetch_add(1, Ordering::SeqCst) == 0 {
                Err(SubXError::AiService("First attempt fails".to_string()))
            } else {
                Ok("Success on second attempt".to_string())
            }
        };

        let result = retry_with_backoff(operation, &config).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Success on second attempt");
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    /// An operation that always fails is called exactly `max_attempts` times.
    #[tokio::test]
    async fn test_retry_exhaust_max_attempts() {
        let config = fast_config(2);
        let calls = AtomicUsize::new(0);

        let operation = || async {
            calls.fetch_add(1, Ordering::SeqCst);
            Err(SubXError::AiService("Always fails".to_string()))
        };

        let result: Result<String> = retry_with_backoff(operation, &config).await;
        assert!(result.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    /// Delays between attempts follow the exponential schedule.
    #[tokio::test]
    async fn test_exponential_backoff_timing() {
        let config = RetryConfig {
            max_attempts: 3,
            base_delay: Duration::from_millis(50),
            max_delay: Duration::from_millis(200),
            backoff_multiplier: 2.0,
        };

        let attempt_times = Mutex::new(Vec::new());

        let operation = || async {
            attempt_times.lock().unwrap().push(Instant::now());
            Err(SubXError::AiService(
                "Always fails for timing test".to_string(),
            ))
        };

        let result: Result<String> = retry_with_backoff(operation, &config).await;
        assert!(result.is_err());

        let times = attempt_times.lock().unwrap();
        assert_eq!(times.len(), 3);

        // First delay should be approximately 50ms (generous tolerance for scheduling).
        let delay1 = times[1].duration_since(times[0]);
        assert!(delay1 >= Duration::from_millis(30));
        assert!(delay1 <= Duration::from_millis(100));

        // Second delay is nominally 100ms; only the lower bound is checked,
        // because a loaded machine can make a sleep overshoot.
        let delay2 = times[2].duration_since(times[1]);
        assert!(delay2 >= Duration::from_millis(80));
    }

    /// Computed delays are clamped to `max_delay`.
    #[tokio::test]
    async fn test_max_delay_cap() {
        let config = RetryConfig {
            max_attempts: 5,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_millis(200), // Low cap
            backoff_multiplier: 3.0,               // High multiplier
        };

        let attempt_times = Mutex::new(Vec::new());

        let operation = || async {
            attempt_times.lock().unwrap().push(Instant::now());
            Err(SubXError::AiService("Always fails".to_string()))
        };

        let result: Result<String> = retry_with_backoff(operation, &config).await;
        assert!(result.is_err());

        let times = attempt_times.lock().unwrap();
        // Fail loudly if attempts are missing instead of skipping the checks below.
        assert_eq!(times.len(), 5);

        // Uncapped, the delays after the first would be 300ms, 900ms and 2700ms;
        // each must be held near the 200ms cap (±50ms tolerance).
        for pair in times[1..].windows(2) {
            let delay = pair[1].duration_since(pair[0]);
            assert!(delay >= Duration::from_millis(150));
            assert!(delay <= Duration::from_millis(250));
        }
    }

    /// Sanity-checks the invariants of an explicit config and of the default one.
    #[test]
    fn test_retry_config_validation() {
        // Valid explicit configuration
        let valid_config = RetryConfig {
            max_attempts: 3,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(1),
            backoff_multiplier: 2.0,
        };
        assert!(valid_config.base_delay <= valid_config.max_delay);
        assert!(valid_config.max_attempts > 0);
        assert!(valid_config.backoff_multiplier > 1.0);

        // The shipped defaults must satisfy the same invariants.
        let default_config = RetryConfig::default();
        assert!(default_config.base_delay <= default_config.max_delay);
        assert!(default_config.max_attempts > 0);
        assert!(default_config.backoff_multiplier > 1.0);
    }

    /// Simulates an AI service that fails twice before succeeding.
    #[tokio::test]
    async fn test_ai_service_integration_simulation() {
        let config = fast_config(3);
        let requests = AtomicUsize::new(0);

        let mock_ai_request = || async {
            // One-based number of the current call.
            match requests.fetch_add(1, Ordering::SeqCst) + 1 {
                1 => Err(SubXError::AiService("Network timeout".to_string())),
                2 => Err(SubXError::AiService("Rate limit exceeded".to_string())),
                3 => Ok("AI analysis complete".to_string()),
                _ => unreachable!(),
            }
        };

        let result = retry_with_backoff(mock_ai_request, &config).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "AI analysis complete");
        assert_eq!(requests.load(Ordering::SeqCst), 3);
    }
}
275}