1use crate::client::LlmClient;
7use crate::tool::ToolDef;
8use crate::types::{Message, SgrError, ToolCall};
9use serde_json::Value;
10use std::time::Duration;
11
/// Tuning knobs for the retry/backoff behaviour of [`RetryClient`].
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// How many times a failed call may be retried after the initial attempt
    /// (so at most `max_retries + 1` calls reach the wrapped client).
    pub max_retries: usize,
    /// Backoff base in milliseconds; `delay_for_attempt` doubles it per attempt.
    pub base_delay_ms: u64,
    /// Upper bound in milliseconds applied to the exponential delay before
    /// jitter is added.
    pub max_delay_ms: u64,
}

impl Default for RetryConfig {
    /// Defaults: 3 retries, 500 ms base delay, 30 s delay cap.
    fn default() -> Self {
        let max_retries = 3;
        let base_delay_ms = 500;
        let max_delay_ms = 30_000;

        Self {
            max_retries,
            base_delay_ms,
            max_delay_ms,
        }
    }
}
32
33fn is_retryable(err: &SgrError) -> bool {
35 match err {
36 SgrError::RateLimit { .. } => true,
37 SgrError::EmptyResponse => true,
38 SgrError::Http(e) => e.is_timeout() || e.is_connect() || e.is_request(),
40 SgrError::Api { status, .. } => *status >= 500 || *status == 408 || *status == 429,
41 SgrError::Schema(msg) => msg.contains("Empty response"),
43 _ => false,
44 }
45}
46
47fn delay_for_attempt(attempt: usize, config: &RetryConfig, err: &SgrError) -> Duration {
49 if let Some(info) = err.rate_limit_info()
51 && let Some(secs) = info.retry_after_secs
52 {
53 return Duration::from_secs(secs + 1); }
55
56 let delay_ms = (config.base_delay_ms * (1 << attempt)).min(config.max_delay_ms);
58 let jitter = (delay_ms as f64 * 0.1 * (attempt as f64 % 2.0 - 0.5)) as u64;
60 Duration::from_millis(delay_ms.saturating_add(jitter))
61}
62
63pub struct RetryClient<C: LlmClient> {
65 inner: C,
66 config: RetryConfig,
67}
68
69impl<C: LlmClient> RetryClient<C> {
70 pub fn new(inner: C) -> Self {
71 Self {
72 inner,
73 config: RetryConfig::default(),
74 }
75 }
76
77 pub fn with_config(mut self, config: RetryConfig) -> Self {
78 self.config = config;
79 self
80 }
81}
82
83#[async_trait::async_trait]
84impl<C: LlmClient> LlmClient for RetryClient<C> {
85 async fn structured_call(
86 &self,
87 messages: &[Message],
88 schema: &Value,
89 ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
90 let mut last_err = None;
91 for attempt in 0..=self.config.max_retries {
92 match self.inner.structured_call(messages, schema).await {
93 Ok(result) => return Ok(result),
94 Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
95 let delay = delay_for_attempt(attempt, &self.config, &e);
96 tracing::warn!(
97 attempt = attempt + 1,
98 max = self.config.max_retries,
99 delay_ms = delay.as_millis() as u64,
100 "Retrying structured_call: {}",
101 e
102 );
103 tokio::time::sleep(delay).await;
104 last_err = Some(e);
105 }
106 Err(e) => return Err(e),
107 }
108 }
109 Err(last_err.unwrap())
110 }
111
112 async fn tools_call(
113 &self,
114 messages: &[Message],
115 tools: &[ToolDef],
116 ) -> Result<Vec<ToolCall>, SgrError> {
117 let mut last_err = None;
118 for attempt in 0..=self.config.max_retries {
119 match self.inner.tools_call(messages, tools).await {
120 Ok(result) => return Ok(result),
121 Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
122 let delay = delay_for_attempt(attempt, &self.config, &e);
123 tracing::warn!(
124 attempt = attempt + 1,
125 max = self.config.max_retries,
126 delay_ms = delay.as_millis() as u64,
127 "Retrying tools_call: {}",
128 e
129 );
130 tokio::time::sleep(delay).await;
131 last_err = Some(e);
132 }
133 Err(e) => return Err(e),
134 }
135 }
136 Err(last_err.unwrap())
137 }
138
139 async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
140 let mut last_err = None;
141 for attempt in 0..=self.config.max_retries {
142 match self.inner.complete(messages).await {
143 Ok(result) => return Ok(result),
144 Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
145 let delay = delay_for_attempt(attempt, &self.config, &e);
146 tracing::warn!(
147 attempt = attempt + 1,
148 max = self.config.max_retries,
149 delay_ms = delay.as_millis() as u64,
150 "Retrying complete: {}",
151 e
152 );
153 tokio::time::sleep(delay).await;
154 last_err = Some(e);
155 }
156 Err(e) => return Err(e),
157 }
158 }
159 Err(last_err.unwrap())
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Stub client: `structured_call` and `tools_call` fail while the shared
    /// call counter is below `fail_count` and succeed afterwards; `complete`
    /// always succeeds.
    struct FailingClient {
        fail_count: usize,
        call_count: Arc<AtomicUsize>,
    }

    impl FailingClient {
        /// Counts the current invocation and tells whether it should fail.
        fn record_call_should_fail(&self) -> bool {
            let earlier_calls = self.call_count.fetch_add(1, Ordering::SeqCst);
            earlier_calls < self.fail_count
        }
    }

    #[async_trait::async_trait]
    impl LlmClient for FailingClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            if self.record_call_should_fail() {
                return Err(SgrError::EmptyResponse);
            }
            Ok((None, vec![], "ok".into()))
        }

        async fn tools_call(
            &self,
            _: &[Message],
            _: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            if self.record_call_should_fail() {
                return Err(SgrError::Api {
                    status: 500,
                    body: "internal error".into(),
                });
            }
            Ok(vec![])
        }

        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok("ok".into())
        }
    }

    /// Builds a retrying client over a `FailingClient` with millisecond-scale
    /// delays, returning it together with the shared call counter.
    fn quick_retry_client(
        fail_count: usize,
        max_retries: usize,
    ) -> (RetryClient<FailingClient>, Arc<AtomicUsize>) {
        let calls = Arc::new(AtomicUsize::new(0));
        let stub = FailingClient {
            fail_count,
            call_count: calls.clone(),
        };
        let config = RetryConfig {
            max_retries,
            base_delay_ms: 1,
            max_delay_ms: 10,
        };
        (RetryClient::new(stub).with_config(config), calls)
    }

    #[tokio::test]
    async fn retries_on_empty_response() {
        let (client, calls) = quick_retry_client(2, 3);
        let prompt = [Message::user("hi")];
        let schema = serde_json::json!({});

        let outcome = client.structured_call(&prompt, &schema).await;

        assert!(outcome.is_ok());
        // Two failing calls followed by the successful one.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retries_on_server_error() {
        let (client, calls) = quick_retry_client(1, 2);
        let prompt = [Message::user("hi")];

        let outcome = client.tools_call(&prompt, &[]).await;

        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn fails_after_max_retries() {
        let (client, calls) = quick_retry_client(10, 2);
        let prompt = [Message::user("hi")];
        let schema = serde_json::json!({});

        let outcome = client.structured_call(&prompt, &schema).await;

        assert!(outcome.is_err());
        // The initial attempt plus two retries.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn non_retryable_errors() {
        let permanent = [
            SgrError::Api {
                status: 400,
                body: "bad request".into(),
            },
            SgrError::Schema("parse".into()),
        ];
        for err in &permanent {
            assert!(!is_retryable(err), "expected non-retryable: {}", err);
        }

        let transient = [
            SgrError::Schema("Empty response from model (parts: text)".into()),
            SgrError::EmptyResponse,
            SgrError::Api {
                status: 503,
                body: "server error".into(),
            },
            SgrError::Api {
                status: 429,
                body: "rate limit".into(),
            },
        ];
        for err in &transient {
            assert!(is_retryable(err), "expected retryable: {}", err);
        }
    }

    #[test]
    fn delay_exponential_backoff() {
        let config = RetryConfig {
            max_retries: 5,
            base_delay_ms: 100,
            max_delay_ms: 5000,
        };
        let err = SgrError::EmptyResponse;

        // (attempt, upper bound in ms) — bounds leave headroom for jitter.
        for (attempt, ceiling_ms) in [(0, 150), (1, 250), (2, 500)] {
            let delay = delay_for_attempt(attempt, &config, &err);
            assert!(delay.as_millis() <= ceiling_ms);
        }
    }

    #[test]
    fn delay_capped_at_max() {
        let config = RetryConfig {
            max_retries: 10,
            base_delay_ms: 1000,
            max_delay_ms: 5000,
        };
        let err = SgrError::EmptyResponse;

        let capped = delay_for_attempt(10, &config, &err);
        // The cap plus at most 10% jitter.
        assert!(capped.as_millis() <= 5500);
    }
}