1use crate::client::LlmClient;
7use crate::tool::ToolDef;
8use crate::types::{Message, SgrError, ToolCall};
9use serde_json::Value;
10use std::time::Duration;
11
/// Retry policy used by [`RetryClient`].
///
/// A call is attempted once, then retried up to `max_retries` more times (so at most
/// `max_retries + 1` attempts in total) while the error is classified as retryable.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryConfig {
    /// Number of retries after the initial attempt; `0` disables retrying.
    pub max_retries: usize,
    /// Base delay for exponential backoff: attempt `n` waits roughly `base_delay_ms * 2^n`.
    pub base_delay_ms: u64,
    /// Upper bound applied to the computed backoff delay. A small jitter may be added on
    /// top of it, and it does not cap a server-provided retry-after hint.
    pub max_delay_ms: u64,
}
22
23impl Default for RetryConfig {
24 fn default() -> Self {
25 Self {
26 max_retries: 3,
27 base_delay_ms: 500,
28 max_delay_ms: 30_000,
29 }
30 }
31}
32
33pub fn is_retryable(err: &SgrError) -> bool {
35 match err {
36 SgrError::RateLimit { .. } => true,
37 SgrError::EmptyResponse => true,
38 SgrError::Http(e) => e.is_timeout() || e.is_connect() || e.is_request(),
40 SgrError::Api { status, .. } => {
41 *status == 0 || *status >= 500 || *status == 408 || *status == 429
42 }
43 SgrError::Schema(msg) => msg.contains("Empty response"),
45 SgrError::MaxOutputTokens { .. } | SgrError::PromptTooLong(_) => false,
48 _ => false,
49 }
50}
51
52pub fn delay_for_attempt(attempt: usize, config: &RetryConfig, err: &SgrError) -> Duration {
54 if let Some(info) = err.rate_limit_info()
56 && let Some(secs) = info.retry_after_secs
57 {
58 return Duration::from_secs(secs + 1); }
60
61 let delay_ms = (config.base_delay_ms * (1 << attempt)).min(config.max_delay_ms);
63 let jitter = (delay_ms as f64 * 0.1 * (attempt as f64 % 2.0 - 0.5)) as u64;
65 Duration::from_millis(delay_ms.saturating_add(jitter))
66}
67
68pub struct RetryClient<C: LlmClient> {
70 inner: C,
71 config: RetryConfig,
72}
73
74impl<C: LlmClient> RetryClient<C> {
75 pub fn new(inner: C) -> Self {
76 Self {
77 inner,
78 config: RetryConfig::default(),
79 }
80 }
81
82 pub fn with_config(mut self, config: RetryConfig) -> Self {
83 self.config = config;
84 self
85 }
86
87 pub fn inner(&self) -> &C {
89 &self.inner
90 }
91}
92
93#[async_trait::async_trait]
94impl<C: LlmClient> LlmClient for RetryClient<C> {
95 async fn structured_call(
96 &self,
97 messages: &[Message],
98 schema: &Value,
99 ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
100 let mut last_err = None;
101 for attempt in 0..=self.config.max_retries {
102 match self.inner.structured_call(messages, schema).await {
103 Ok(result) => return Ok(result),
104 Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
105 let delay = delay_for_attempt(attempt, &self.config, &e);
106 tracing::warn!(
107 attempt = attempt + 1,
108 max = self.config.max_retries,
109 delay_ms = delay.as_millis() as u64,
110 "Retrying structured_call: {}",
111 e
112 );
113 tokio::time::sleep(delay).await;
114 last_err = Some(e);
115 }
116 Err(e) => return Err(e),
117 }
118 }
119 Err(last_err.unwrap())
120 }
121
122 async fn tools_call(
123 &self,
124 messages: &[Message],
125 tools: &[ToolDef],
126 ) -> Result<Vec<ToolCall>, SgrError> {
127 let mut last_err = None;
128 for attempt in 0..=self.config.max_retries {
129 match self.inner.tools_call(messages, tools).await {
130 Ok(result) => return Ok(result),
131 Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
132 let delay = delay_for_attempt(attempt, &self.config, &e);
133 tracing::warn!(
134 attempt = attempt + 1,
135 max = self.config.max_retries,
136 delay_ms = delay.as_millis() as u64,
137 "Retrying tools_call: {}",
138 e
139 );
140 tokio::time::sleep(delay).await;
141 last_err = Some(e);
142 }
143 Err(e) => return Err(e),
144 }
145 }
146 Err(last_err.unwrap())
147 }
148
149 async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
150 let mut last_err = None;
151 for attempt in 0..=self.config.max_retries {
152 match self.inner.complete(messages).await {
153 Ok(result) => return Ok(result),
154 Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
155 let delay = delay_for_attempt(attempt, &self.config, &e);
156 tracing::warn!(
157 attempt = attempt + 1,
158 max = self.config.max_retries,
159 delay_ms = delay.as_millis() as u64,
160 "Retrying complete: {}",
161 e
162 );
163 tokio::time::sleep(delay).await;
164 last_err = Some(e);
165 }
166 Err(e) => return Err(e),
167 }
168 }
169 Err(last_err.unwrap())
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double that fails its first `fail_count` calls and succeeds afterwards.
    ///
    /// Calls are counted in `call_count` across all methods. Each method fails with a
    /// different error, so different retry decisions can be exercised:
    /// - `structured_call` fails with `EmptyResponse` (retryable);
    /// - `tools_call` fails with HTTP 500 (retryable);
    /// - `complete` fails with HTTP 400 (not retryable).
    struct FailingClient {
        fail_count: usize,
        call_count: Arc<AtomicUsize>,
    }

    #[async_trait::async_trait]
    impl LlmClient for FailingClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            let n = self.call_count.fetch_add(1, Ordering::SeqCst);
            if n < self.fail_count {
                Err(SgrError::EmptyResponse)
            } else {
                Ok((None, vec![], "ok".into()))
            }
        }

        async fn tools_call(
            &self,
            _: &[Message],
            _: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            let n = self.call_count.fetch_add(1, Ordering::SeqCst);
            if n < self.fail_count {
                Err(SgrError::Api {
                    status: 500,
                    body: "internal error".into(),
                })
            } else {
                Ok(vec![])
            }
        }

        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            let n = self.call_count.fetch_add(1, Ordering::SeqCst);
            if n < self.fail_count {
                Err(SgrError::Api {
                    status: 400,
                    body: "bad request".into(),
                })
            } else {
                Ok("ok".into())
            }
        }
    }

    /// Builds a `RetryClient` around a `FailingClient` with millisecond-scale delays.
    fn failing(
        fail_count: usize,
        max_retries: usize,
        count: &Arc<AtomicUsize>,
    ) -> RetryClient<FailingClient> {
        RetryClient::new(FailingClient {
            fail_count,
            call_count: count.clone(),
        })
        .with_config(RetryConfig {
            max_retries,
            base_delay_ms: 1,
            max_delay_ms: 10,
        })
    }

    #[tokio::test]
    async fn retries_on_empty_response() {
        let count = Arc::new(AtomicUsize::new(0));
        let client = failing(2, 3, &count);

        let result = client
            .structured_call(&[Message::user("hi")], &serde_json::json!({}))
            .await;
        assert!(result.is_ok());
        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retries_on_server_error() {
        let count = Arc::new(AtomicUsize::new(0));
        let client = failing(1, 2, &count);

        let result = client.tools_call(&[Message::user("hi")], &[]).await;
        assert!(result.is_ok());
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn fails_after_max_retries() {
        let count = Arc::new(AtomicUsize::new(0));
        let client = failing(10, 2, &count);

        let result = client
            .structured_call(&[Message::user("hi")], &serde_json::json!({}))
            .await;
        assert!(result.is_err());
        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn does_not_retry_client_errors() {
        let count = Arc::new(AtomicUsize::new(0));
        let client = failing(10, 3, &count);

        let result = client.complete(&[Message::user("hi")]).await;
        assert!(matches!(result, Err(SgrError::Api { status: 400, .. })));
        // A permanent error must surface after the very first attempt.
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn zero_retries_makes_single_attempt() {
        let count = Arc::new(AtomicUsize::new(0));
        let client = failing(10, 0, &count);

        let result = client
            .structured_call(&[Message::user("hi")], &serde_json::json!({}))
            .await;
        assert!(result.is_err());
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn non_retryable_errors() {
        assert!(!is_retryable(&SgrError::Api {
            status: 400,
            body: "bad request".into()
        }));
        assert!(!is_retryable(&SgrError::Schema("parse".into())));
        assert!(is_retryable(&SgrError::Schema(
            "Empty response from model (parts: text)".into()
        )));
        assert!(is_retryable(&SgrError::EmptyResponse));
        assert!(is_retryable(&SgrError::Api {
            status: 503,
            body: "server error".into()
        }));
        assert!(is_retryable(&SgrError::Api {
            status: 429,
            body: "rate limit".into()
        }));
    }

    #[test]
    fn delay_exponential_backoff() {
        let config = RetryConfig {
            max_retries: 5,
            base_delay_ms: 100,
            max_delay_ms: 5000,
        };
        let err = SgrError::EmptyResponse;

        let d0 = delay_for_attempt(0, &config, &err);
        let d1 = delay_for_attempt(1, &config, &err);
        let d2 = delay_for_attempt(2, &config, &err);

        assert!(d0.as_millis() <= 150);
        assert!(d1.as_millis() <= 250);
        assert!(d2.as_millis() <= 500);
    }

    #[test]
    fn delay_capped_at_max() {
        let config = RetryConfig {
            max_retries: 10,
            base_delay_ms: 1000,
            max_delay_ms: 5000,
        };
        let err = SgrError::EmptyResponse;

        let d10 = delay_for_attempt(10, &config, &err);
        assert!(d10.as_millis() <= 5500);
    }
}