// rusty_commit/utils/retry.rs

1use anyhow::Result;
2use backoff::{future::retry, ExponentialBackoff, ExponentialBackoffBuilder};
3use std::time::Duration;
4
/// Upper bound, in seconds, on the total time spent retrying before giving up.
const MAX_RETRY_TIMEOUT_SECS: u64 = 2 * 60;
7
8/// Determines if an error is retryable
9pub fn is_retryable_error(error: &anyhow::Error) -> bool {
10    let error_msg = error.to_string().to_lowercase();
11
12    // Retryable errors: network issues, timeouts, rate limits, server errors
13    error_msg.contains("429") ||  // Rate limit
14    error_msg.contains("rate_limit") ||
15    error_msg.contains("rate limit") ||
16    error_msg.contains("500") ||  // Internal server error
17    error_msg.contains("502") ||  // Bad gateway
18    error_msg.contains("503") ||  // Service unavailable
19    error_msg.contains("504") ||  // Gateway timeout
20    error_msg.contains("timeout") ||
21    error_msg.contains("connection") ||
22    error_msg.contains("network") ||
23    error_msg.contains("dns") ||
24    error_msg.contains("overloaded")
25}
26
27/// Determines if an error is permanent (should not retry)
28pub fn is_permanent_error(error: &anyhow::Error) -> bool {
29    let error_msg = error.to_string().to_lowercase();
30
31    // Permanent errors: auth issues, invalid requests, quota exceeded
32    error_msg.contains("401") ||  // Unauthorized
33    error_msg.contains("403") ||  // Forbidden
34    error_msg.contains("invalid api key") ||
35    error_msg.contains("insufficient quota") ||
36    error_msg.contains("quota exceeded") ||
37    error_msg.contains("invalid request") ||
38    error_msg.contains("model not found") ||
39    error_msg.contains("400") // Bad request
40}
41
42/// Create a backoff policy for API retries
43pub fn create_backoff() -> ExponentialBackoff {
44    ExponentialBackoffBuilder::new()
45        .with_initial_interval(Duration::from_millis(500))
46        .with_max_interval(Duration::from_secs(30))
47        .with_multiplier(2.0)
48        .with_max_elapsed_time(Some(Duration::from_secs(MAX_RETRY_TIMEOUT_SECS)))
49        .build()
50}
51
52/// Retry an async operation with exponential backoff
53pub async fn retry_async<F, Fut, T>(operation: F) -> Result<T>
54where
55    F: Fn() -> Fut + Send + Sync,
56    Fut: std::future::Future<Output = Result<T>> + Send,
57{
58    let backoff = create_backoff();
59
60    retry(backoff, || async {
61        match operation().await {
62            Ok(result) => Ok(result),
63            Err(error) => {
64                if is_permanent_error(&error) {
65                    // Don't retry permanent errors
66                    Err(backoff::Error::permanent(error))
67                } else if is_retryable_error(&error) {
68                    // Retry transient errors
69                    Err(backoff::Error::transient(error))
70                } else {
71                    // Unknown errors - treat as permanent to be safe
72                    Err(backoff::Error::permanent(error))
73                }
74            }
75        }
76    })
77    .await
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Transient-looking messages are retryable; auth failures are not.
    #[test]
    fn test_is_retryable_error() {
        let retryable = [
            "429 Rate limit exceeded",
            "500 Internal server error",
            "Connection timeout",
            "Network error",
            "Model overloaded",
        ];
        for msg in retryable {
            assert!(is_retryable_error(&anyhow!(msg)));
        }

        let not_retryable = ["401 Unauthorized", "Invalid API key"];
        for msg in not_retryable {
            assert!(!is_retryable_error(&anyhow!(msg)));
        }
    }

    /// Auth, quota and bad-request messages are permanent; rate limits and
    /// server errors are not.
    #[test]
    fn test_is_permanent_error() {
        let permanent = [
            "401 Unauthorized",
            "Invalid API key",
            "Insufficient quota",
            "400 Bad request",
        ];
        for msg in permanent {
            assert!(is_permanent_error(&anyhow!(msg)));
        }

        let not_permanent = ["429 Rate limit", "500 Server error"];
        for msg in not_permanent {
            assert!(!is_permanent_error(&anyhow!(msg)));
        }
    }

    /// Two rate-limit failures followed by a success yield `Ok` after three calls.
    #[tokio::test]
    async fn test_retry_success() {
        let attempts = Arc::new(AtomicUsize::new(0));

        let result = retry_async(|| {
            let counter = Arc::clone(&attempts);
            async move {
                // `fetch_add` returns the previous value; +1 gives this attempt's number.
                let attempt_no = counter.fetch_add(1, Ordering::SeqCst) + 1;
                if attempt_no < 3 {
                    Err(anyhow!("429 Rate limit"))
                } else {
                    Ok("success".to_string())
                }
            }
        })
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    /// A permanent error is returned after a single call.
    #[tokio::test]
    async fn test_retry_permanent_error() {
        let attempts = Arc::new(AtomicUsize::new(0));

        let result: Result<String> = retry_async(|| {
            let counter = Arc::clone(&attempts);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err(anyhow!("401 Unauthorized"))
            }
        })
        .await;

        assert!(result.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 1); // Should not retry permanent errors
    }
}