Skip to main content

sgr_agent/
retry.rs

//! RetryClient — wraps an `LlmClient` with exponential backoff for transient errors.
//!
//! Retries on: rate limits (429), request timeouts (408), server errors (5xx), empty
//! responses, and network errors (timeouts, connection failures, failed requests).
//! Honors `retry_after_secs` from rate-limit headers when available.
5
6use crate::client::LlmClient;
7use crate::tool::ToolDef;
8use crate::types::{Message, SgrError, ToolCall};
9use serde_json::Value;
10use std::time::Duration;
11
/// Tuning knobs for the retry/backoff policy used by [`RetryClient`].
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// How many times a failed call may be retried after the first attempt;
    /// `0` disables retrying entirely.
    pub max_retries: usize,
    /// Starting backoff delay in milliseconds; the backoff doubles with each retry.
    pub base_delay_ms: u64,
    /// Ceiling, in milliseconds, applied to the exponential backoff delay.
    pub max_delay_ms: u64,
}
22
23impl Default for RetryConfig {
24    fn default() -> Self {
25        Self {
26            max_retries: 3,
27            base_delay_ms: 500,
28            max_delay_ms: 30_000,
29        }
30    }
31}
32
33/// Determine if an error is retryable.
34fn is_retryable(err: &SgrError) -> bool {
35    match err {
36        SgrError::RateLimit { .. } => true,
37        SgrError::EmptyResponse => true,
38        // reqwest::Error — retryable if timeout or connect error
39        SgrError::Http(e) => e.is_timeout() || e.is_connect() || e.is_request(),
40        SgrError::Api { status, .. } => *status >= 500 || *status == 408 || *status == 429,
41        // Empty response wrapped as Schema error — transient model behavior
42        SgrError::Schema(msg) => msg.contains("Empty response"),
43        _ => false,
44    }
45}
46
47/// Calculate delay for attempt N, honoring rate limit headers.
48fn delay_for_attempt(attempt: usize, config: &RetryConfig, err: &SgrError) -> Duration {
49    // Honor retry-after header from rate limit
50    if let Some(info) = err.rate_limit_info()
51        && let Some(secs) = info.retry_after_secs
52    {
53        return Duration::from_secs(secs + 1); // +1s safety margin
54    }
55
56    // Exponential backoff: base * 2^attempt, capped at max
57    let delay_ms = (config.base_delay_ms * (1 << attempt)).min(config.max_delay_ms);
58    // Add jitter ±10%
59    let jitter = (delay_ms as f64 * 0.1 * (attempt as f64 % 2.0 - 0.5)) as u64;
60    Duration::from_millis(delay_ms.saturating_add(jitter))
61}
62
63/// LLM client wrapper with automatic retry on transient errors.
64pub struct RetryClient<C: LlmClient> {
65    inner: C,
66    config: RetryConfig,
67}
68
69impl<C: LlmClient> RetryClient<C> {
70    pub fn new(inner: C) -> Self {
71        Self {
72            inner,
73            config: RetryConfig::default(),
74        }
75    }
76
77    pub fn with_config(mut self, config: RetryConfig) -> Self {
78        self.config = config;
79        self
80    }
81}
82
83#[async_trait::async_trait]
84impl<C: LlmClient> LlmClient for RetryClient<C> {
85    async fn structured_call(
86        &self,
87        messages: &[Message],
88        schema: &Value,
89    ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
90        let mut last_err = None;
91        for attempt in 0..=self.config.max_retries {
92            match self.inner.structured_call(messages, schema).await {
93                Ok(result) => return Ok(result),
94                Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
95                    let delay = delay_for_attempt(attempt, &self.config, &e);
96                    tracing::warn!(
97                        attempt = attempt + 1,
98                        max = self.config.max_retries,
99                        delay_ms = delay.as_millis() as u64,
100                        "Retrying structured_call: {}",
101                        e
102                    );
103                    tokio::time::sleep(delay).await;
104                    last_err = Some(e);
105                }
106                Err(e) => return Err(e),
107            }
108        }
109        Err(last_err.unwrap())
110    }
111
112    async fn tools_call(
113        &self,
114        messages: &[Message],
115        tools: &[ToolDef],
116    ) -> Result<Vec<ToolCall>, SgrError> {
117        let mut last_err = None;
118        for attempt in 0..=self.config.max_retries {
119            match self.inner.tools_call(messages, tools).await {
120                Ok(result) => return Ok(result),
121                Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
122                    let delay = delay_for_attempt(attempt, &self.config, &e);
123                    tracing::warn!(
124                        attempt = attempt + 1,
125                        max = self.config.max_retries,
126                        delay_ms = delay.as_millis() as u64,
127                        "Retrying tools_call: {}",
128                        e
129                    );
130                    tokio::time::sleep(delay).await;
131                    last_err = Some(e);
132                }
133                Err(e) => return Err(e),
134            }
135        }
136        Err(last_err.unwrap())
137    }
138
139    async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
140        let mut last_err = None;
141        for attempt in 0..=self.config.max_retries {
142            match self.inner.complete(messages).await {
143                Ok(result) => return Ok(result),
144                Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
145                    let delay = delay_for_attempt(attempt, &self.config, &e);
146                    tracing::warn!(
147                        attempt = attempt + 1,
148                        max = self.config.max_retries,
149                        delay_ms = delay.as_millis() as u64,
150                        "Retrying complete: {}",
151                        e
152                    );
153                    tokio::time::sleep(delay).await;
154                    last_err = Some(e);
155                }
156                Err(e) => return Err(e),
157            }
158        }
159        Err(last_err.unwrap())
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double whose first `fail_count` calls fail and whose later calls succeed.
    /// Calls are counted across all methods. Each method fails with a different class
    /// of error:
    /// * `structured_call` → `EmptyResponse` (retryable);
    /// * `tools_call` → HTTP 500 (retryable);
    /// * `complete` → HTTP 400 (not retryable).
    struct FailingClient {
        fail_count: usize,
        call_count: Arc<AtomicUsize>,
    }

    impl FailingClient {
        /// Record one call and report whether it should fail.
        fn should_fail(&self) -> bool {
            self.call_count.fetch_add(1, Ordering::SeqCst) < self.fail_count
        }
    }

    #[async_trait::async_trait]
    impl LlmClient for FailingClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            if self.should_fail() {
                Err(SgrError::EmptyResponse)
            } else {
                Ok((None, vec![], "ok".into()))
            }
        }
        async fn tools_call(
            &self,
            _: &[Message],
            _: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            if self.should_fail() {
                Err(SgrError::Api {
                    status: 500,
                    body: "internal error".into(),
                })
            } else {
                Ok(vec![])
            }
        }
        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            if self.should_fail() {
                Err(SgrError::Api {
                    status: 400,
                    body: "bad request".into(),
                })
            } else {
                Ok("ok".into())
            }
        }
    }

    /// Build a `RetryClient` around a `FailingClient` with millisecond-scale delays so
    /// the tests stay fast. Returns the client and its shared call counter.
    fn client_with(
        fail_count: usize,
        max_retries: usize,
    ) -> (RetryClient<FailingClient>, Arc<AtomicUsize>) {
        let count = Arc::new(AtomicUsize::new(0));
        let client = RetryClient::new(FailingClient {
            fail_count,
            call_count: Arc::clone(&count),
        })
        .with_config(RetryConfig {
            max_retries,
            base_delay_ms: 1,
            max_delay_ms: 10,
        });
        (client, count)
    }

    #[tokio::test]
    async fn retries_on_empty_response() {
        let (client, count) = client_with(2, 3);
        let result = client
            .structured_call(&[Message::user("hi")], &serde_json::json!({}))
            .await;
        assert!(result.is_ok());
        assert_eq!(count.load(Ordering::SeqCst), 3); // 2 fails + 1 success
    }

    #[tokio::test]
    async fn retries_on_server_error() {
        let (client, count) = client_with(1, 2);
        let result = client.tools_call(&[Message::user("hi")], &[]).await;
        assert!(result.is_ok());
        assert_eq!(count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn fails_after_max_retries() {
        let (client, count) = client_with(10, 2);
        let result = client
            .structured_call(&[Message::user("hi")], &serde_json::json!({}))
            .await;
        assert!(result.is_err());
        assert_eq!(count.load(Ordering::SeqCst), 3); // 1 initial + 2 retries
    }

    #[tokio::test]
    async fn does_not_retry_non_retryable_error() {
        let (client, count) = client_with(10, 3);
        let result = client.complete(&[Message::user("hi")]).await;
        assert!(matches!(result, Err(SgrError::Api { status: 400, .. })));
        assert_eq!(count.load(Ordering::SeqCst), 1); // no retries on a 400
    }

    #[tokio::test]
    async fn zero_max_retries_makes_a_single_call() {
        let (client, count) = client_with(10, 0);
        let result = client
            .structured_call(&[Message::user("hi")], &serde_json::json!({}))
            .await;
        assert!(matches!(result, Err(SgrError::EmptyResponse)));
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn non_retryable_errors() {
        assert!(!is_retryable(&SgrError::Api {
            status: 400,
            body: "bad request".into()
        }));
        assert!(!is_retryable(&SgrError::Schema("parse".into())));
        assert!(is_retryable(&SgrError::Schema(
            "Empty response from model (parts: text)".into()
        )));
        assert!(is_retryable(&SgrError::EmptyResponse));
        assert!(is_retryable(&SgrError::Api {
            status: 503,
            body: "server error".into()
        }));
        assert!(is_retryable(&SgrError::Api {
            status: 429,
            body: "rate limit".into()
        }));
        assert!(is_retryable(&SgrError::Api {
            status: 408,
            body: "request timeout".into()
        }));
    }

    #[test]
    fn delay_exponential_backoff() {
        let config = RetryConfig {
            max_retries: 5,
            base_delay_ms: 100,
            max_delay_ms: 5000,
        };
        let err = SgrError::EmptyResponse;

        let d0 = delay_for_attempt(0, &config, &err).as_millis();
        let d1 = delay_for_attempt(1, &config, &err).as_millis();
        let d2 = delay_for_attempt(2, &config, &err).as_millis();

        // base * 2^attempt with at most ±10% jitter: ~100ms, ~200ms, ~400ms.
        assert!((90..=110).contains(&d0), "d0 = {d0}");
        assert!((180..=220).contains(&d1), "d1 = {d1}");
        assert!((360..=440).contains(&d2), "d2 = {d2}");
    }

    #[test]
    fn delay_capped_at_max() {
        let config = RetryConfig {
            max_retries: 10,
            base_delay_ms: 1000,
            max_delay_ms: 5000,
        };
        let err = SgrError::EmptyResponse;

        let d10 = delay_for_attempt(10, &config, &err).as_millis();
        assert!((4500..=5500).contains(&d10), "d10 = {d10}"); // max ± jitter
    }
}
318}