1use crate::client::LlmClient;
7use crate::tool::ToolDef;
8use crate::types::{Message, SgrError, ToolCall};
9use serde_json::Value;
10use std::time::Duration;
11
/// Tuning knobs for exponential-backoff retries.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retries after the initial attempt.
    pub max_retries: usize,
    /// Delay before the first retry, in milliseconds.
    pub base_delay_ms: u64,
    /// Upper bound on any single backoff delay, in milliseconds.
    pub max_delay_ms: u64,
}

impl Default for RetryConfig {
    /// Defaults: 3 retries, starting at 500 ms, capped at 30 s.
    fn default() -> Self {
        RetryConfig {
            max_retries: 3,
            base_delay_ms: 500,
            max_delay_ms: 30_000,
        }
    }
}
32
33pub fn is_retryable(err: &SgrError) -> bool {
35 match err {
36 SgrError::RateLimit { .. } => true,
37 SgrError::EmptyResponse => true,
38 SgrError::Http(e) => e.is_timeout() || e.is_connect() || e.is_request(),
40 SgrError::Api { status, body } => {
41 *status >= 500
42 || *status == 408
43 || *status == 429
44 || (*status == 400 && body.contains("AiError"))
46 }
47 SgrError::Schema(msg) => msg.contains("Empty response"),
49 _ => false,
50 }
51}
52
53pub fn delay_for_attempt(attempt: usize, config: &RetryConfig, err: &SgrError) -> Duration {
55 if let Some(info) = err.rate_limit_info()
57 && let Some(secs) = info.retry_after_secs
58 {
59 return Duration::from_secs(secs + 1); }
61
62 let delay_ms = (config.base_delay_ms * (1 << attempt)).min(config.max_delay_ms);
64 let jitter = (delay_ms as f64 * 0.1 * (attempt as f64 % 2.0 - 0.5)) as u64;
66 Duration::from_millis(delay_ms.saturating_add(jitter))
67}
68
69pub struct RetryClient<C: LlmClient> {
71 inner: C,
72 config: RetryConfig,
73}
74
75impl<C: LlmClient> RetryClient<C> {
76 pub fn new(inner: C) -> Self {
77 Self {
78 inner,
79 config: RetryConfig::default(),
80 }
81 }
82
83 pub fn with_config(mut self, config: RetryConfig) -> Self {
84 self.config = config;
85 self
86 }
87}
88
89#[async_trait::async_trait]
90impl<C: LlmClient> LlmClient for RetryClient<C> {
91 async fn structured_call(
92 &self,
93 messages: &[Message],
94 schema: &Value,
95 ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
96 let mut last_err = None;
97 for attempt in 0..=self.config.max_retries {
98 match self.inner.structured_call(messages, schema).await {
99 Ok(result) => return Ok(result),
100 Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
101 let delay = delay_for_attempt(attempt, &self.config, &e);
102 tracing::warn!(
103 attempt = attempt + 1,
104 max = self.config.max_retries,
105 delay_ms = delay.as_millis() as u64,
106 "Retrying structured_call: {}",
107 e
108 );
109 tokio::time::sleep(delay).await;
110 last_err = Some(e);
111 }
112 Err(e) => return Err(e),
113 }
114 }
115 Err(last_err.unwrap())
116 }
117
118 async fn tools_call(
119 &self,
120 messages: &[Message],
121 tools: &[ToolDef],
122 ) -> Result<Vec<ToolCall>, SgrError> {
123 let mut last_err = None;
124 for attempt in 0..=self.config.max_retries {
125 match self.inner.tools_call(messages, tools).await {
126 Ok(result) => return Ok(result),
127 Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
128 let delay = delay_for_attempt(attempt, &self.config, &e);
129 tracing::warn!(
130 attempt = attempt + 1,
131 max = self.config.max_retries,
132 delay_ms = delay.as_millis() as u64,
133 "Retrying tools_call: {}",
134 e
135 );
136 tokio::time::sleep(delay).await;
137 last_err = Some(e);
138 }
139 Err(e) => return Err(e),
140 }
141 }
142 Err(last_err.unwrap())
143 }
144
145 async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
146 let mut last_err = None;
147 for attempt in 0..=self.config.max_retries {
148 match self.inner.complete(messages).await {
149 Ok(result) => return Ok(result),
150 Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
151 let delay = delay_for_attempt(attempt, &self.config, &e);
152 tracing::warn!(
153 attempt = attempt + 1,
154 max = self.config.max_retries,
155 delay_ms = delay.as_millis() as u64,
156 "Retrying complete: {}",
157 e
158 );
159 tokio::time::sleep(delay).await;
160 last_err = Some(e);
161 }
162 Err(e) => return Err(e),
163 }
164 }
165 Err(last_err.unwrap())
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double whose first `fail_count` calls (counted across all
    /// methods via the shared counter) fail; later calls succeed.
    struct FailingClient {
        fail_count: usize,
        call_count: Arc<AtomicUsize>,
    }

    impl FailingClient {
        /// Increments the shared call counter and reports whether this call
        /// should still fail.
        fn should_fail(&self) -> bool {
            self.call_count.fetch_add(1, Ordering::SeqCst) < self.fail_count
        }
    }

    /// Retry config with near-zero delays so the async tests stay fast.
    fn quick_config(max_retries: usize) -> RetryConfig {
        RetryConfig {
            max_retries,
            base_delay_ms: 1,
            max_delay_ms: 10,
        }
    }

    #[async_trait::async_trait]
    impl LlmClient for FailingClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            if self.should_fail() {
                Err(SgrError::EmptyResponse)
            } else {
                Ok((None, vec![], "ok".into()))
            }
        }

        async fn tools_call(
            &self,
            _: &[Message],
            _: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            if self.should_fail() {
                Err(SgrError::Api {
                    status: 500,
                    body: "internal error".into(),
                })
            } else {
                Ok(vec![])
            }
        }

        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok("ok".into())
        }
    }

    #[tokio::test]
    async fn retries_on_empty_response() {
        let calls = Arc::new(AtomicUsize::new(0));
        let inner = FailingClient {
            fail_count: 2,
            call_count: Arc::clone(&calls),
        };
        let client = RetryClient::new(inner).with_config(quick_config(3));

        let outcome = client
            .structured_call(&[Message::user("hi")], &serde_json::json!({}))
            .await;
        assert!(outcome.is_ok());
        // Two failures plus the final success.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retries_on_server_error() {
        let calls = Arc::new(AtomicUsize::new(0));
        let inner = FailingClient {
            fail_count: 1,
            call_count: Arc::clone(&calls),
        };
        let client = RetryClient::new(inner).with_config(quick_config(2));

        let outcome = client.tools_call(&[Message::user("hi")], &[]).await;
        assert!(outcome.is_ok());
        // One failure plus the success.
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn fails_after_max_retries() {
        let calls = Arc::new(AtomicUsize::new(0));
        let inner = FailingClient {
            fail_count: 10,
            call_count: Arc::clone(&calls),
        };
        let client = RetryClient::new(inner).with_config(quick_config(2));

        let outcome = client
            .structured_call(&[Message::user("hi")], &serde_json::json!({}))
            .await;
        assert!(outcome.is_err());
        // Initial attempt + 2 retries, all failing.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn non_retryable_errors() {
        let plain_400 = SgrError::Api {
            status: 400,
            body: "bad request".into(),
        };
        assert!(!is_retryable(&plain_400));
        assert!(!is_retryable(&SgrError::Schema("parse".into())));

        let empty_schema = SgrError::Schema("Empty response from model (parts: text)".into());
        assert!(is_retryable(&empty_schema));
        assert!(is_retryable(&SgrError::EmptyResponse));
        assert!(is_retryable(&SgrError::Api {
            status: 503,
            body: "server error".into(),
        }));
        assert!(is_retryable(&SgrError::Api {
            status: 429,
            body: "rate limit".into(),
        }));
    }

    #[test]
    fn delay_exponential_backoff() {
        let config = RetryConfig {
            max_retries: 5,
            base_delay_ms: 100,
            max_delay_ms: 5000,
        };
        let err = SgrError::EmptyResponse;

        // Delays grow roughly as base * 2^attempt, with at most 5% jitter.
        assert!(delay_for_attempt(0, &config, &err).as_millis() <= 150);
        assert!(delay_for_attempt(1, &config, &err).as_millis() <= 250);
        assert!(delay_for_attempt(2, &config, &err).as_millis() <= 500);
    }

    #[test]
    fn delay_capped_at_max() {
        let config = RetryConfig {
            max_retries: 10,
            base_delay_ms: 1000,
            max_delay_ms: 5000,
        };
        let err = SgrError::EmptyResponse;

        // Even at a high attempt count the delay stays near max_delay_ms.
        assert!(delay_for_attempt(10, &config, &err).as_millis() <= 5500);
    }
}