1use crate::client::LlmClient;
7use crate::tool::ToolDef;
8use crate::types::{Message, SgrError, ToolCall};
9use serde_json::Value;
10use std::time::Duration;
11
/// Tunable limits for the retry loop driven by [`RetryClient`].
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// How many retries are allowed after the first attempt; a call is tried
    /// at most `max_retries + 1` times in total.
    pub max_retries: usize,
    /// Starting delay in milliseconds for the exponential back-off; it is
    /// multiplied by `2^attempt`.
    pub base_delay_ms: u64,
    /// Ceiling in milliseconds applied to the exponential delay before jitter
    /// is added. Also used as the fallback once the rate-limit schedule runs out.
    pub max_delay_ms: u64,
}
22
23impl Default for RetryConfig {
24 fn default() -> Self {
25 Self {
26 max_retries: 5,
27 base_delay_ms: 500,
28 max_delay_ms: 30_000,
29 }
30 }
31}
32
33pub fn is_retryable(err: &SgrError) -> bool {
35 match err {
36 SgrError::RateLimit { .. } => true,
37 SgrError::EmptyResponse => true,
38 SgrError::Http(e) => e.is_timeout() || e.is_connect() || e.is_request(),
40 SgrError::Api { status, .. } => {
41 *status == 0 || *status >= 500 || *status == 408 || *status == 429
42 }
43 SgrError::Schema(msg) => msg.contains("Empty response"),
45 SgrError::MaxOutputTokens { .. } | SgrError::PromptTooLong(_) => false,
48 _ => false,
49 }
50}
51
/// Back-off schedule in milliseconds for rate-limited calls, indexed by attempt.
/// The steps follow a Fibonacci-like progression, with the final entry at 30 s.
const FIBO_DELAYS_MS: &[u64] = &[
    1_000, 2_000, 3_000, 5_000, 8_000, 13_000, 21_000, 30_000,
];
55
56pub fn delay_for_attempt(attempt: usize, config: &RetryConfig, err: &SgrError) -> Duration {
58 if let Some(info) = err.rate_limit_info()
60 && let Some(secs) = info.retry_after_secs
61 {
62 return Duration::from_secs(secs + 1); }
64
65 let is_rate_limit = matches!(err, SgrError::RateLimit { .. })
67 || matches!(err, SgrError::Api { status: 429, .. })
68 || matches!(err, SgrError::Api { status: 0, body } if body.contains("429") || body.contains("rate limit"));
69 if is_rate_limit {
70 let delay_ms = FIBO_DELAYS_MS
71 .get(attempt)
72 .copied()
73 .unwrap_or(config.max_delay_ms);
74 let jitter = (delay_ms as f64 * 0.15 * fastrand()) as u64;
75 return Duration::from_millis(delay_ms + jitter);
76 }
77
78 let delay_ms = (config.base_delay_ms * (1 << attempt)).min(config.max_delay_ms);
80 let jitter = (delay_ms as f64 * 0.1 * (attempt as f64 % 2.0 - 0.5)) as u64;
81 Duration::from_millis(delay_ms.saturating_add(jitter))
82}
83
/// Returns a cheap pseudo-random value in `[0, 1)` taken from the sub-second
/// nanoseconds of the system clock.
///
/// This is only a jitter source, not a real random number generator. If the
/// clock reads earlier than the Unix epoch the result is `0.0`.
fn fastrand() -> f64 {
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|since_epoch| since_epoch.subsec_nanos())
        .unwrap_or(0);
    (f64::from(nanos) / 1_000_000_000.0).fract()
}
92
93pub struct RetryClient<C: LlmClient> {
95 inner: C,
96 config: RetryConfig,
97}
98
99impl<C: LlmClient> RetryClient<C> {
100 pub fn new(inner: C) -> Self {
101 Self {
102 inner,
103 config: RetryConfig::default(),
104 }
105 }
106
107 pub fn with_config(mut self, config: RetryConfig) -> Self {
108 self.config = config;
109 self
110 }
111
112 pub fn inner(&self) -> &C {
114 &self.inner
115 }
116}
117
118#[async_trait::async_trait]
119impl<C: LlmClient> LlmClient for RetryClient<C> {
120 async fn structured_call(
121 &self,
122 messages: &[Message],
123 schema: &Value,
124 ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
125 let mut last_err = None;
126 for attempt in 0..=self.config.max_retries {
127 match self.inner.structured_call(messages, schema).await {
128 Ok(result) => return Ok(result),
129 Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
130 let delay = delay_for_attempt(attempt, &self.config, &e);
131 tracing::warn!(
132 attempt = attempt + 1,
133 max = self.config.max_retries,
134 delay_ms = delay.as_millis() as u64,
135 "Retrying structured_call: {}",
136 e
137 );
138 tokio::time::sleep(delay).await;
139 last_err = Some(e);
140 }
141 Err(e) => return Err(e),
142 }
143 }
144 Err(last_err.unwrap())
145 }
146
147 async fn tools_call(
148 &self,
149 messages: &[Message],
150 tools: &[ToolDef],
151 ) -> Result<Vec<ToolCall>, SgrError> {
152 let mut last_err = None;
153 for attempt in 0..=self.config.max_retries {
154 match self.inner.tools_call(messages, tools).await {
155 Ok(result) => return Ok(result),
156 Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
157 let delay = delay_for_attempt(attempt, &self.config, &e);
158 tracing::warn!(
159 attempt = attempt + 1,
160 max = self.config.max_retries,
161 delay_ms = delay.as_millis() as u64,
162 "Retrying tools_call: {}",
163 e
164 );
165 tokio::time::sleep(delay).await;
166 last_err = Some(e);
167 }
168 Err(e) => return Err(e),
169 }
170 }
171 Err(last_err.unwrap())
172 }
173
174 async fn tools_call_with_text(
175 &self,
176 messages: &[Message],
177 tools: &[ToolDef],
178 ) -> Result<(Vec<ToolCall>, String), SgrError> {
179 let mut last_err = None;
180 for attempt in 0..=self.config.max_retries {
181 match self.inner.tools_call_with_text(messages, tools).await {
182 Ok(result) => return Ok(result),
183 Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
184 let delay = delay_for_attempt(attempt, &self.config, &e);
185 tracing::warn!(
186 attempt = attempt + 1,
187 max = self.config.max_retries,
188 delay_ms = delay.as_millis() as u64,
189 "Retrying tools_call_with_text: {}",
190 e
191 );
192 tokio::time::sleep(delay).await;
193 last_err = Some(e);
194 }
195 Err(e) => return Err(e),
196 }
197 }
198 Err(last_err.unwrap())
199 }
200
201 async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
202 let mut last_err = None;
203 for attempt in 0..=self.config.max_retries {
204 match self.inner.complete(messages).await {
205 Ok(result) => return Ok(result),
206 Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
207 let delay = delay_for_attempt(attempt, &self.config, &e);
208 tracing::warn!(
209 attempt = attempt + 1,
210 max = self.config.max_retries,
211 delay_ms = delay.as_millis() as u64,
212 "Retrying complete: {}",
213 e
214 );
215 tokio::time::sleep(delay).await;
216 last_err = Some(e);
217 }
218 Err(e) => return Err(e),
219 }
220 }
221 Err(last_err.unwrap())
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double: `structured_call` and `tools_call` fail until the shared
    /// counter reaches `fail_count`, then succeed. `complete` always succeeds
    /// and does not touch the counter.
    struct FailingClient {
        fail_count: usize,
        call_count: Arc<AtomicUsize>,
    }

    impl FailingClient {
        /// Records one call and reports whether that call should fail.
        fn should_fail(&self) -> bool {
            self.call_count.fetch_add(1, Ordering::SeqCst) < self.fail_count
        }
    }

    #[async_trait::async_trait]
    impl LlmClient for FailingClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            if self.should_fail() {
                return Err(SgrError::EmptyResponse);
            }
            Ok((None, vec![], "ok".into()))
        }

        async fn tools_call(
            &self,
            _: &[Message],
            _: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            if self.should_fail() {
                return Err(SgrError::Api {
                    status: 500,
                    body: "internal error".into(),
                });
            }
            Ok(vec![])
        }

        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok("ok".into())
        }
    }

    /// Builds a retrying client with millisecond-scale delays, plus a handle
    /// to the inner client's call counter.
    fn fast_retry_client(
        fail_count: usize,
        max_retries: usize,
    ) -> (RetryClient<FailingClient>, Arc<AtomicUsize>) {
        let calls = Arc::new(AtomicUsize::new(0));
        let inner = FailingClient {
            fail_count,
            call_count: Arc::clone(&calls),
        };
        let client = RetryClient::new(inner).with_config(RetryConfig {
            max_retries,
            base_delay_ms: 1,
            max_delay_ms: 10,
        });
        (client, calls)
    }

    #[tokio::test]
    async fn retries_on_empty_response() {
        let (client, calls) = fast_retry_client(2, 3);

        let outcome = client
            .structured_call(&[Message::user("hi")], &serde_json::json!({}))
            .await;

        assert!(outcome.is_ok());
        // Two failures followed by one success.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retries_on_server_error() {
        let (client, calls) = fast_retry_client(1, 2);

        let outcome = client.tools_call(&[Message::user("hi")], &[]).await;

        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn fails_after_max_retries() {
        let (client, calls) = fast_retry_client(10, 2);

        let outcome = client
            .structured_call(&[Message::user("hi")], &serde_json::json!({}))
            .await;

        assert!(outcome.is_err());
        // The initial attempt plus two retries.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn non_retryable_errors() {
        let bad_request = SgrError::Api {
            status: 400,
            body: "bad request".into(),
        };
        assert!(!is_retryable(&bad_request));
        assert!(!is_retryable(&SgrError::Schema("parse".into())));

        let empty_schema = SgrError::Schema("Empty response from model (parts: text)".into());
        assert!(is_retryable(&empty_schema));
        assert!(is_retryable(&SgrError::EmptyResponse));

        let unavailable = SgrError::Api {
            status: 503,
            body: "server error".into(),
        };
        assert!(is_retryable(&unavailable));

        let throttled = SgrError::Api {
            status: 429,
            body: "rate limit".into(),
        };
        assert!(is_retryable(&throttled));
    }

    #[test]
    fn delay_exponential_backoff() {
        let config = RetryConfig {
            max_retries: 5,
            base_delay_ms: 100,
            max_delay_ms: 5000,
        };
        let err = SgrError::EmptyResponse;

        // (attempt, upper bound in milliseconds including jitter)
        for (attempt, upper_ms) in [(0, 150), (1, 250), (2, 500)] {
            let delay = delay_for_attempt(attempt, &config, &err);
            assert!(
                delay.as_millis() <= upper_ms,
                "attempt {attempt} waited {delay:?}"
            );
        }
    }

    #[test]
    fn delay_capped_at_max() {
        let config = RetryConfig {
            max_retries: 10,
            base_delay_ms: 1000,
            max_delay_ms: 5000,
        };
        let err = SgrError::EmptyResponse;

        let capped = delay_for_attempt(10, &config, &err);
        // The cap plus room for jitter.
        assert!(capped.as_millis() <= 5500);
    }
}