rusty_commit/utils/
retry.rs1use anyhow::Result;
2use backoff::{future::retry, ExponentialBackoff, ExponentialBackoffBuilder};
3use std::time::Duration;
4
/// Upper bound, in seconds, on the total time the backoff policy keeps retrying
/// one operation (two minutes).
const MAX_RETRY_TIMEOUT_SECS: u64 = 2 * 60;
7
8pub fn is_retryable_error(error: &anyhow::Error) -> bool {
10 let error_msg = error.to_string().to_lowercase();
11
12 error_msg.contains("429") || error_msg.contains("rate_limit") ||
15 error_msg.contains("rate limit") ||
16 error_msg.contains("500") || error_msg.contains("502") || error_msg.contains("503") || error_msg.contains("504") || error_msg.contains("timeout") ||
21 error_msg.contains("connection") ||
22 error_msg.contains("network") ||
23 error_msg.contains("dns") ||
24 error_msg.contains("overloaded")
25}
26
27pub fn is_permanent_error(error: &anyhow::Error) -> bool {
29 let error_msg = error.to_string().to_lowercase();
30
31 error_msg.contains("401") || error_msg.contains("403") || error_msg.contains("invalid api key") ||
35 error_msg.contains("insufficient quota") ||
36 error_msg.contains("quota exceeded") ||
37 error_msg.contains("invalid request") ||
38 error_msg.contains("model not found") ||
39 error_msg.contains("400") }
41
42pub fn create_backoff() -> ExponentialBackoff {
44 ExponentialBackoffBuilder::new()
45 .with_initial_interval(Duration::from_millis(500))
46 .with_max_interval(Duration::from_secs(30))
47 .with_multiplier(2.0)
48 .with_max_elapsed_time(Some(Duration::from_secs(MAX_RETRY_TIMEOUT_SECS)))
49 .build()
50}
51
52pub async fn retry_async<F, Fut, T>(operation: F) -> Result<T>
54where
55 F: Fn() -> Fut + Send + Sync,
56 Fut: std::future::Future<Output = Result<T>> + Send,
57{
58 let backoff = create_backoff();
59
60 retry(backoff, || async {
61 match operation().await {
62 Ok(result) => Ok(result),
63 Err(error) => {
64 if is_permanent_error(&error) {
65 Err(backoff::Error::permanent(error))
67 } else if is_retryable_error(&error) {
68 Err(backoff::Error::transient(error))
70 } else {
71 Err(backoff::Error::permanent(error))
73 }
74 }
75 }
76 })
77 .await
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_is_retryable_error() {
        let retryable = [
            "429 Rate limit exceeded",
            "500 Internal server error",
            "Connection timeout",
            "Network error",
            "Model overloaded",
        ];
        for &msg in &retryable {
            assert!(is_retryable_error(&anyhow!(msg)), "expected retryable: {}", msg);
        }

        let not_retryable = ["401 Unauthorized", "Invalid API key"];
        for &msg in &not_retryable {
            assert!(!is_retryable_error(&anyhow!(msg)), "expected not retryable: {}", msg);
        }
    }

    #[test]
    fn test_is_permanent_error() {
        let permanent = [
            "401 Unauthorized",
            "Invalid API key",
            "Insufficient quota",
            "400 Bad request",
        ];
        for &msg in &permanent {
            assert!(is_permanent_error(&anyhow!(msg)), "expected permanent: {}", msg);
        }

        let not_permanent = ["429 Rate limit", "500 Server error"];
        for &msg in &not_permanent {
            assert!(!is_permanent_error(&anyhow!(msg)), "expected not permanent: {}", msg);
        }
    }

    #[tokio::test]
    async fn test_retry_success() {
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome = retry_async(|| {
            let calls = Arc::clone(&calls);
            async move {
                // fetch_add returns the previous count; +1 is this attempt's number.
                let attempt = calls.fetch_add(1, Ordering::SeqCst) + 1;
                if attempt < 3 {
                    Err(anyhow!("429 Rate limit"))
                } else {
                    Ok("success".to_string())
                }
            }
        })
        .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), "success");
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_permanent_error() {
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome: Result<String, _> = retry_async(|| {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err(anyhow!("401 Unauthorized"))
            }
        })
        .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}