Skip to main content

sgr_agent/
retry.rs

1//! RetryClient — wraps LlmClient with exponential backoff for transient errors.
2//!
3//! Retries on: rate limits (429), server errors (5xx), empty responses, network errors.
4//! Honors `retry_after_secs` from rate limit headers when available.
5
6use crate::client::LlmClient;
7use crate::tool::ToolDef;
8use crate::types::{Message, SgrError, ToolCall};
9use serde_json::Value;
10use std::time::Duration;
11
/// Tunable limits for the retry loop: how many times to retry and how long to wait.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Number of retries after the initial attempt (0 disables retrying).
    pub max_retries: usize,
    /// Starting backoff delay, in milliseconds.
    pub base_delay_ms: u64,
    /// Upper bound applied to the backoff delay, in milliseconds.
    pub max_delay_ms: u64,
}

impl Default for RetryConfig {
    /// Three retries, starting at 500 ms and capped at 30 s.
    fn default() -> Self {
        let max_retries = 3;
        let base_delay_ms = 500;
        let max_delay_ms = 30_000;
        RetryConfig {
            max_retries,
            base_delay_ms,
            max_delay_ms,
        }
    }
}
32
33/// Determine if an error is retryable (transient: rate limit, timeout, server errors).
34pub fn is_retryable(err: &SgrError) -> bool {
35    match err {
36        SgrError::RateLimit { .. } => true,
37        SgrError::EmptyResponse => true,
38        // reqwest::Error — retryable if timeout or connect error
39        SgrError::Http(e) => e.is_timeout() || e.is_connect() || e.is_request(),
40        SgrError::Api { status, .. } => {
41            *status == 0 || *status >= 500 || *status == 408 || *status == 429
42        }
43        // Empty response wrapped as Schema error — transient model behavior
44        SgrError::Schema(msg) => msg.contains("Empty response"),
45        // MaxOutputTokens and PromptTooLong are NOT retryable at this level —
46        // they are handled by the agent loop with special recovery logic
47        SgrError::MaxOutputTokens { .. } | SgrError::PromptTooLong(_) => false,
48        _ => false,
49    }
50}
51
52/// Calculate delay for attempt N, honoring rate limit headers.
53pub fn delay_for_attempt(attempt: usize, config: &RetryConfig, err: &SgrError) -> Duration {
54    // Honor retry-after header from rate limit
55    if let Some(info) = err.rate_limit_info()
56        && let Some(secs) = info.retry_after_secs
57    {
58        return Duration::from_secs(secs + 1); // +1s safety margin
59    }
60
61    // Exponential backoff: base * 2^attempt, capped at max
62    let delay_ms = (config.base_delay_ms * (1 << attempt)).min(config.max_delay_ms);
63    // Add jitter ±10%
64    let jitter = (delay_ms as f64 * 0.1 * (attempt as f64 % 2.0 - 0.5)) as u64;
65    Duration::from_millis(delay_ms.saturating_add(jitter))
66}
67
68/// LLM client wrapper with automatic retry on transient errors.
69pub struct RetryClient<C: LlmClient> {
70    inner: C,
71    config: RetryConfig,
72}
73
74impl<C: LlmClient> RetryClient<C> {
75    pub fn new(inner: C) -> Self {
76        Self {
77            inner,
78            config: RetryConfig::default(),
79        }
80    }
81
82    pub fn with_config(mut self, config: RetryConfig) -> Self {
83        self.config = config;
84        self
85    }
86
87    /// Access inner client (e.g. for connect_ws on OxideClient).
88    pub fn inner(&self) -> &C {
89        &self.inner
90    }
91}
92
93#[async_trait::async_trait]
94impl<C: LlmClient> LlmClient for RetryClient<C> {
95    async fn structured_call(
96        &self,
97        messages: &[Message],
98        schema: &Value,
99    ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
100        let mut last_err = None;
101        for attempt in 0..=self.config.max_retries {
102            match self.inner.structured_call(messages, schema).await {
103                Ok(result) => return Ok(result),
104                Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
105                    let delay = delay_for_attempt(attempt, &self.config, &e);
106                    tracing::warn!(
107                        attempt = attempt + 1,
108                        max = self.config.max_retries,
109                        delay_ms = delay.as_millis() as u64,
110                        "Retrying structured_call: {}",
111                        e
112                    );
113                    tokio::time::sleep(delay).await;
114                    last_err = Some(e);
115                }
116                Err(e) => return Err(e),
117            }
118        }
119        Err(last_err.unwrap())
120    }
121
122    async fn tools_call(
123        &self,
124        messages: &[Message],
125        tools: &[ToolDef],
126    ) -> Result<Vec<ToolCall>, SgrError> {
127        let mut last_err = None;
128        for attempt in 0..=self.config.max_retries {
129            match self.inner.tools_call(messages, tools).await {
130                Ok(result) => return Ok(result),
131                Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
132                    let delay = delay_for_attempt(attempt, &self.config, &e);
133                    tracing::warn!(
134                        attempt = attempt + 1,
135                        max = self.config.max_retries,
136                        delay_ms = delay.as_millis() as u64,
137                        "Retrying tools_call: {}",
138                        e
139                    );
140                    tokio::time::sleep(delay).await;
141                    last_err = Some(e);
142                }
143                Err(e) => return Err(e),
144            }
145        }
146        Err(last_err.unwrap())
147    }
148
149    async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
150        let mut last_err = None;
151        for attempt in 0..=self.config.max_retries {
152            match self.inner.complete(messages).await {
153                Ok(result) => return Ok(result),
154                Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
155                    let delay = delay_for_attempt(attempt, &self.config, &e);
156                    tracing::warn!(
157                        attempt = attempt + 1,
158                        max = self.config.max_retries,
159                        delay_ms = delay.as_millis() as u64,
160                        "Retrying complete: {}",
161                        e
162                    );
163                    tokio::time::sleep(delay).await;
164                    last_err = Some(e);
165                }
166                Err(e) => return Err(e),
167            }
168        }
169        Err(last_err.unwrap())
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double whose `structured_call` / `tools_call` fail until the shared
    /// call counter reaches `fail_count`, then succeed. `complete` always succeeds
    /// and does not touch the counter.
    struct FailingClient {
        fail_count: usize,
        call_count: Arc<AtomicUsize>,
    }

    impl FailingClient {
        /// Record one more call and report whether that call should still fail.
        fn should_fail(&self) -> bool {
            let seen_before = self.call_count.fetch_add(1, Ordering::SeqCst);
            seen_before < self.fail_count
        }
    }

    #[async_trait::async_trait]
    impl LlmClient for FailingClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            if self.should_fail() {
                return Err(SgrError::EmptyResponse);
            }
            Ok((None, vec![], "ok".into()))
        }

        async fn tools_call(
            &self,
            _: &[Message],
            _: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            if self.should_fail() {
                return Err(SgrError::Api {
                    status: 500,
                    body: "internal error".into(),
                });
            }
            Ok(vec![])
        }

        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok("ok".into())
        }
    }

    /// Build a `RetryClient` over a `FailingClient` with millisecond-scale delays,
    /// returning it together with the shared call counter.
    fn flaky_client(
        fail_count: usize,
        max_retries: usize,
    ) -> (RetryClient<FailingClient>, Arc<AtomicUsize>) {
        let calls = Arc::new(AtomicUsize::new(0));
        let backend = FailingClient {
            fail_count,
            call_count: calls.clone(),
        };
        let client = RetryClient::new(backend).with_config(RetryConfig {
            max_retries,
            base_delay_ms: 1,
            max_delay_ms: 10,
        });
        (client, calls)
    }

    #[tokio::test]
    async fn retries_on_empty_response() {
        let (client, calls) = flaky_client(2, 3);

        let outcome = client
            .structured_call(&[Message::user("hi")], &serde_json::json!({}))
            .await;

        assert!(outcome.is_ok());
        // 2 fails + 1 success
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retries_on_server_error() {
        let (client, calls) = flaky_client(1, 2);

        let outcome = client.tools_call(&[Message::user("hi")], &[]).await;

        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn fails_after_max_retries() {
        let (client, calls) = flaky_client(10, 2);

        let outcome = client
            .structured_call(&[Message::user("hi")], &serde_json::json!({}))
            .await;

        assert!(outcome.is_err());
        // 1 initial + 2 retries
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn non_retryable_errors() {
        // Permanent failures.
        let bad_request = SgrError::Api {
            status: 400,
            body: "bad request".into(),
        };
        assert!(!is_retryable(&bad_request));
        assert!(!is_retryable(&SgrError::Schema("parse".into())));

        // Transient failures.
        let empty_as_schema = SgrError::Schema("Empty response from model (parts: text)".into());
        assert!(is_retryable(&empty_as_schema));
        assert!(is_retryable(&SgrError::EmptyResponse));
        let unavailable = SgrError::Api {
            status: 503,
            body: "server error".into(),
        };
        assert!(is_retryable(&unavailable));
        let throttled = SgrError::Api {
            status: 429,
            body: "rate limit".into(),
        };
        assert!(is_retryable(&throttled));
    }

    #[test]
    fn delay_exponential_backoff() {
        let config = RetryConfig {
            max_retries: 5,
            base_delay_ms: 100,
            max_delay_ms: 5000,
        };
        let err = SgrError::EmptyResponse;

        // Roughly 100ms, 200ms, 400ms (with jitter)
        let upper_bounds_ms = [150, 250, 500];
        for (attempt, limit_ms) in upper_bounds_ms.into_iter().enumerate() {
            let delay = delay_for_attempt(attempt, &config, &err);
            assert!(delay.as_millis() <= limit_ms);
        }
    }

    #[test]
    fn delay_capped_at_max() {
        let config = RetryConfig {
            max_retries: 10,
            base_delay_ms: 1000,
            max_delay_ms: 5000,
        };
        let err = SgrError::EmptyResponse;

        let delay = delay_for_attempt(10, &config, &err);
        // max + jitter
        assert!(delay.as_millis() <= 5500);
    }
}
328}