1use crate::client::LlmClient;
7use crate::tool::ToolDef;
8use crate::types::{Message, SgrError, ToolCall};
9use serde_json::Value;
10use std::time::Duration;
11
/// Retry policy for a [`RetryClient`]: how many times to retry and how the
/// exponential backoff delay is computed between attempts.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of *re*tries after the initial attempt, so a call may
    /// run up to `max_retries + 1` times in total.
    pub max_retries: usize,
    /// Base backoff delay in milliseconds; attempt `n` waits roughly
    /// `base_delay_ms * 2^n` (see `delay_for_attempt`).
    pub base_delay_ms: u64,
    /// Upper bound, in milliseconds, on the computed backoff delay.
    pub max_delay_ms: u64,
}
22
23impl Default for RetryConfig {
24 fn default() -> Self {
25 Self {
26 max_retries: 3,
27 base_delay_ms: 500,
28 max_delay_ms: 30_000,
29 }
30 }
31}
32
33pub fn is_retryable(err: &SgrError) -> bool {
35 match err {
36 SgrError::RateLimit { .. } => true,
37 SgrError::EmptyResponse => true,
38 SgrError::Http(e) => e.is_timeout() || e.is_connect() || e.is_request(),
40 SgrError::Api { status, .. } => {
41 *status == 0 || *status >= 500 || *status == 408 || *status == 429
42 }
43 SgrError::Schema(msg) => msg.contains("Empty response"),
45 SgrError::MaxOutputTokens { .. } | SgrError::PromptTooLong(_) => false,
48 _ => false,
49 }
50}
51
52pub fn delay_for_attempt(attempt: usize, config: &RetryConfig, err: &SgrError) -> Duration {
54 if let Some(info) = err.rate_limit_info()
56 && let Some(secs) = info.retry_after_secs
57 {
58 return Duration::from_secs(secs + 1); }
60
61 let delay_ms = (config.base_delay_ms * (1 << attempt)).min(config.max_delay_ms);
63 let jitter = (delay_ms as f64 * 0.1 * (attempt as f64 % 2.0 - 0.5)) as u64;
65 Duration::from_millis(delay_ms.saturating_add(jitter))
66}
67
/// Decorator that adds retry-with-backoff behavior to any [`LlmClient`].
pub struct RetryClient<C: LlmClient> {
    // The wrapped client that performs the actual requests.
    inner: C,
    // Retry policy; starts as `RetryConfig::default()` in `new`.
    config: RetryConfig,
}
73
74impl<C: LlmClient> RetryClient<C> {
75 pub fn new(inner: C) -> Self {
76 Self {
77 inner,
78 config: RetryConfig::default(),
79 }
80 }
81
82 pub fn with_config(mut self, config: RetryConfig) -> Self {
83 self.config = config;
84 self
85 }
86}
87
88#[async_trait::async_trait]
89impl<C: LlmClient> LlmClient for RetryClient<C> {
90 async fn structured_call(
91 &self,
92 messages: &[Message],
93 schema: &Value,
94 ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
95 let mut last_err = None;
96 for attempt in 0..=self.config.max_retries {
97 match self.inner.structured_call(messages, schema).await {
98 Ok(result) => return Ok(result),
99 Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
100 let delay = delay_for_attempt(attempt, &self.config, &e);
101 tracing::warn!(
102 attempt = attempt + 1,
103 max = self.config.max_retries,
104 delay_ms = delay.as_millis() as u64,
105 "Retrying structured_call: {}",
106 e
107 );
108 tokio::time::sleep(delay).await;
109 last_err = Some(e);
110 }
111 Err(e) => return Err(e),
112 }
113 }
114 Err(last_err.unwrap())
115 }
116
117 async fn tools_call(
118 &self,
119 messages: &[Message],
120 tools: &[ToolDef],
121 ) -> Result<Vec<ToolCall>, SgrError> {
122 let mut last_err = None;
123 for attempt in 0..=self.config.max_retries {
124 match self.inner.tools_call(messages, tools).await {
125 Ok(result) => return Ok(result),
126 Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
127 let delay = delay_for_attempt(attempt, &self.config, &e);
128 tracing::warn!(
129 attempt = attempt + 1,
130 max = self.config.max_retries,
131 delay_ms = delay.as_millis() as u64,
132 "Retrying tools_call: {}",
133 e
134 );
135 tokio::time::sleep(delay).await;
136 last_err = Some(e);
137 }
138 Err(e) => return Err(e),
139 }
140 }
141 Err(last_err.unwrap())
142 }
143
144 async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
145 let mut last_err = None;
146 for attempt in 0..=self.config.max_retries {
147 match self.inner.complete(messages).await {
148 Ok(result) => return Ok(result),
149 Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
150 let delay = delay_for_attempt(attempt, &self.config, &e);
151 tracing::warn!(
152 attempt = attempt + 1,
153 max = self.config.max_retries,
154 delay_ms = delay.as_millis() as u64,
155 "Retrying complete: {}",
156 e
157 );
158 tokio::time::sleep(delay).await;
159 last_err = Some(e);
160 }
161 Err(e) => return Err(e),
162 }
163 }
164 Err(last_err.unwrap())
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Stub client: the retryable methods fail the first `fail_count`
    /// invocations, then succeed. `call_count` tallies every invocation.
    struct FailingClient {
        fail_count: usize,
        call_count: Arc<AtomicUsize>,
    }

    #[async_trait::async_trait]
    impl LlmClient for FailingClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            // Retryable failure (EmptyResponse) until the quota is used up.
            if self.call_count.fetch_add(1, Ordering::SeqCst) < self.fail_count {
                Err(SgrError::EmptyResponse)
            } else {
                Ok((None, Vec::new(), "ok".into()))
            }
        }

        async fn tools_call(
            &self,
            _: &[Message],
            _: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            // Retryable failure (HTTP 500) until the quota is used up.
            if self.call_count.fetch_add(1, Ordering::SeqCst) < self.fail_count {
                Err(SgrError::Api {
                    status: 500,
                    body: "internal error".into(),
                })
            } else {
                Ok(Vec::new())
            }
        }

        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok("ok".into())
        }
    }

    /// Builds a `RetryClient` over a `FailingClient` with millisecond-scale
    /// delays so the tests stay fast.
    fn retry_client(
        fail_count: usize,
        max_retries: usize,
        call_count: Arc<AtomicUsize>,
    ) -> RetryClient<FailingClient> {
        RetryClient::new(FailingClient {
            fail_count,
            call_count,
        })
        .with_config(RetryConfig {
            max_retries,
            base_delay_ms: 1,
            max_delay_ms: 10,
        })
    }

    #[tokio::test]
    async fn retries_on_empty_response() {
        let calls = Arc::new(AtomicUsize::new(0));
        let client = retry_client(2, 3, calls.clone());

        let result = client
            .structured_call(&[Message::user("hi")], &serde_json::json!({}))
            .await;

        assert!(result.is_ok());
        // Two failures plus the final success: three calls in total.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retries_on_server_error() {
        let calls = Arc::new(AtomicUsize::new(0));
        let client = retry_client(1, 2, calls.clone());

        let result = client.tools_call(&[Message::user("hi")], &[]).await;

        assert!(result.is_ok());
        // One failure plus the final success: two calls in total.
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn fails_after_max_retries() {
        let calls = Arc::new(AtomicUsize::new(0));
        let client = retry_client(10, 2, calls.clone());

        let result = client
            .structured_call(&[Message::user("hi")], &serde_json::json!({}))
            .await;

        assert!(result.is_err());
        // Initial attempt plus 2 retries, all failing.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn non_retryable_errors() {
        // Client errors and generic schema failures are permanent.
        assert!(!is_retryable(&SgrError::Api {
            status: 400,
            body: "bad request".into()
        }));
        assert!(!is_retryable(&SgrError::Schema("parse".into())));

        // Empty-response schema errors, empty responses, 5xx, and 429 are
        // all transient.
        assert!(is_retryable(&SgrError::Schema(
            "Empty response from model (parts: text)".into()
        )));
        assert!(is_retryable(&SgrError::EmptyResponse));
        assert!(is_retryable(&SgrError::Api {
            status: 503,
            body: "server error".into()
        }));
        assert!(is_retryable(&SgrError::Api {
            status: 429,
            body: "rate limit".into()
        }));
    }

    #[test]
    fn delay_exponential_backoff() {
        let config = RetryConfig {
            max_retries: 5,
            base_delay_ms: 100,
            max_delay_ms: 5000,
        };
        let err = SgrError::EmptyResponse;

        // Delays roughly double per attempt (headroom left for jitter).
        assert!(delay_for_attempt(0, &config, &err).as_millis() <= 150);
        assert!(delay_for_attempt(1, &config, &err).as_millis() <= 250);
        assert!(delay_for_attempt(2, &config, &err).as_millis() <= 500);
    }

    #[test]
    fn delay_capped_at_max() {
        let config = RetryConfig {
            max_retries: 10,
            base_delay_ms: 1000,
            max_delay_ms: 5000,
        };
        let err = SgrError::EmptyResponse;

        // A large attempt number stays at the cap (plus jitter headroom).
        assert!(delay_for_attempt(10, &config, &err).as_millis() <= 5500);
    }
}