1use std::sync::Arc;
6use std::time::Duration;
7
8use async_trait::async_trait;
9
10use crate::traits::provider::Provider;
11use crate::types::completion::{CompletionRequest, CompletionResponse};
12use crate::types::model_info::ModelInfo;
13use crate::types::stream::CompletionStream;
14
/// Settings that control how [`RetryProvider`] retries failed provider calls.
///
/// The delay before retry `n` (zero-based) is `initial_delay * 2^n`, capped at
/// `max_delay`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryConfig {
    /// Number of retries allowed after the initial attempt; `0` disables retrying.
    pub max_retries: usize,
    /// Delay before the first retry; doubled for each subsequent retry.
    pub initial_delay: Duration,
    /// Upper bound applied to every computed delay.
    pub max_delay: Duration,
}

impl Default for RetryConfig {
    /// Three retries, starting at 500 ms and capped at 30 s.
    fn default() -> Self {
        Self::new(3, Duration::from_millis(500), Duration::from_secs(30))
    }
}

impl RetryConfig {
    /// Creates a configuration from explicit values.
    ///
    /// No validation is performed: if `initial_delay` exceeds `max_delay`, every
    /// retry simply waits `max_delay`.
    #[must_use]
    pub const fn new(max_retries: usize, initial_delay: Duration, max_delay: Duration) -> Self {
        Self {
            max_retries,
            initial_delay,
            max_delay,
        }
    }
}
47
48pub struct RetryProvider {
52 inner: Arc<dyn Provider>,
53 config: RetryConfig,
54}
55
56impl RetryProvider {
57 #[must_use]
59 pub fn new(inner: Arc<dyn Provider>, config: RetryConfig) -> Self {
60 Self { inner, config }
61 }
62
63 #[allow(clippy::cast_possible_truncation)]
65 fn delay_for_attempt(&self, attempt: usize) -> Duration {
66 let delay = self
67 .config
68 .initial_delay
69 .saturating_mul(1u32.wrapping_shl(attempt as u32));
70 delay.min(self.config.max_delay)
71 }
72}
73
74#[async_trait]
75impl Provider for RetryProvider {
76 async fn complete(&self, req: CompletionRequest) -> crate::Result<CompletionResponse> {
77 let mut last_error = None;
78
79 for attempt in 0..=self.config.max_retries {
80 let result = self.inner.complete(req.clone()).await;
81 match result {
82 Ok(response) => return Ok(response),
83 Err(e) => {
84 if !e.is_retryable() || attempt == self.config.max_retries {
85 return Err(e);
86 }
87 let delay = self.delay_for_attempt(attempt);
88 tracing::warn!(
89 attempt = attempt + 1,
90 max_retries = self.config.max_retries,
91 delay_ms = delay.as_millis(),
92 error = %e,
93 "Retrying provider call"
94 );
95 tokio::time::sleep(delay).await;
96 last_error = Some(e);
97 }
98 }
99 }
100
101 Err(last_error.unwrap_or_else(|| crate::Error::provider("retry exhausted")))
102 }
103
104 async fn stream(&self, req: CompletionRequest) -> crate::Result<CompletionStream> {
105 let mut last_error = None;
106
107 for attempt in 0..=self.config.max_retries {
108 let result = self.inner.stream(req.clone()).await;
109 match result {
110 Ok(stream) => return Ok(stream),
111 Err(e) => {
112 if !e.is_retryable() || attempt == self.config.max_retries {
113 return Err(e);
114 }
115 let delay = self.delay_for_attempt(attempt);
116 tracing::warn!(
117 attempt = attempt + 1,
118 max_retries = self.config.max_retries,
119 delay_ms = delay.as_millis(),
120 error = %e,
121 "Retrying provider stream"
122 );
123 tokio::time::sleep(delay).await;
124 last_error = Some(e);
125 }
126 }
127 }
128
129 Err(last_error.unwrap_or_else(|| crate::Error::provider("retry exhausted")))
130 }
131
132 fn model_info(&self) -> &ModelInfo {
133 self.inner.model_info()
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::completion::{ResponseContent, Usage};
    use crate::types::model_info::ModelTier;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double whose `complete` fails a fixed number of times and succeeds
    /// on every call after that.
    struct FailThenSucceedProvider {
        /// Failures still owed before `complete` starts returning `Ok`.
        fail_count: AtomicUsize,
        /// Number of `complete` calls received so far.
        calls: AtomicUsize,
        info: ModelInfo,
    }

    impl FailThenSucceedProvider {
        fn new(fail_n_times: usize) -> Self {
            Self {
                fail_count: AtomicUsize::new(fail_n_times),
                calls: AtomicUsize::new(0),
                info: ModelInfo::new("test", ModelTier::Small, 4096, false, false, false),
            }
        }

        /// Number of `complete` calls observed so far.
        fn calls(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl Provider for FailThenSucceedProvider {
        async fn complete(&self, _req: CompletionRequest) -> crate::Result<CompletionResponse> {
            self.calls.fetch_add(1, Ordering::SeqCst);

            // Decrement only while failures remain. A plain `fetch_sub` wraps from
            // 0 to `usize::MAX`, which would make the call after the first success
            // fail again.
            let should_fail = self
                .fail_count
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |left| {
                    left.checked_sub(1)
                })
                .is_ok();

            if should_fail {
                Err(crate::Error::provider_with_status("server error", 500))
            } else {
                Ok(CompletionResponse {
                    content: ResponseContent::Text("ok".into()),
                    usage: Usage {
                        prompt_tokens: 1,
                        completion_tokens: 1,
                        total_tokens: 2,
                    },
                })
            }
        }

        async fn stream(&self, _req: CompletionRequest) -> crate::Result<CompletionStream> {
            unimplemented!()
        }

        fn model_info(&self) -> &ModelInfo {
            &self.info
        }
    }

    fn make_request() -> CompletionRequest {
        CompletionRequest {
            model: "test".into(),
            messages: vec![],
            tools: vec![],
            max_tokens: None,
            temperature: None,
            response_format: None,
            stream: false,
        }
    }

    /// Millisecond-scale delays so the async tests stay fast.
    fn fast_config(max_retries: usize) -> RetryConfig {
        RetryConfig {
            max_retries,
            initial_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
        }
    }

    #[tokio::test]
    async fn test_retry_succeeds_on_second_attempt() {
        let inner = Arc::new(FailThenSucceedProvider::new(1));
        let provider = RetryProvider::new(inner.clone(), fast_config(3));

        let result = provider.complete(make_request()).await;
        assert!(result.is_ok());
        // One failure plus the successful retry.
        assert_eq!(inner.calls(), 2);
    }

    #[tokio::test]
    async fn test_max_retries_exhausted() {
        let inner = Arc::new(FailThenSucceedProvider::new(10));
        let provider = RetryProvider::new(inner.clone(), fast_config(2));

        let result = provider.complete(make_request()).await;
        assert!(result.is_err());
        // The initial attempt plus `max_retries` retries, and nothing more.
        assert_eq!(inner.calls(), 3);
    }

    #[tokio::test]
    async fn test_non_retryable_error_propagated_immediately() {
        struct NonRetryableProvider {
            calls: AtomicUsize,
            info: ModelInfo,
        }

        #[async_trait]
        impl Provider for NonRetryableProvider {
            async fn complete(&self, _req: CompletionRequest) -> crate::Result<CompletionResponse> {
                self.calls.fetch_add(1, Ordering::SeqCst);
                Err(crate::Error::provider_with_status("unauthorized", 401))
            }
            async fn stream(&self, _req: CompletionRequest) -> crate::Result<CompletionStream> {
                unimplemented!()
            }
            fn model_info(&self) -> &ModelInfo {
                &self.info
            }
        }

        let inner = Arc::new(NonRetryableProvider {
            calls: AtomicUsize::new(0),
            info: ModelInfo::new("test", ModelTier::Small, 4096, false, false, false),
        });
        let provider = RetryProvider::new(inner.clone(), fast_config(3));

        let result = provider.complete(make_request()).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("unauthorized"));
        // "Immediately" means the wrapped provider was called exactly once.
        assert_eq!(inner.calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_exponential_backoff_timing() {
        let config = RetryConfig {
            max_retries: 5,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(5),
        };
        let provider = RetryProvider::new(Arc::new(FailThenSucceedProvider::new(0)), config);

        assert_eq!(provider.delay_for_attempt(0), Duration::from_millis(100));
        assert_eq!(provider.delay_for_attempt(1), Duration::from_millis(200));
        assert_eq!(provider.delay_for_attempt(2), Duration::from_millis(400));
        assert_eq!(provider.delay_for_attempt(3), Duration::from_millis(800));
    }

    #[test]
    fn test_max_delay_cap() {
        let config = RetryConfig {
            max_retries: 5,
            initial_delay: Duration::from_secs(10),
            max_delay: Duration::from_secs(30),
        };
        let provider = RetryProvider::new(Arc::new(FailThenSucceedProvider::new(0)), config);

        // Below the cap the delay still doubles.
        assert_eq!(provider.delay_for_attempt(0), Duration::from_secs(10));
        assert_eq!(provider.delay_for_attempt(1), Duration::from_secs(20));
        // 40 s and 80 s are both clamped to `max_delay`.
        assert_eq!(provider.delay_for_attempt(2), Duration::from_secs(30));
        assert_eq!(provider.delay_for_attempt(3), Duration::from_secs(30));
    }
}