Skip to main content

sgr_agent/
retry.rs

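//! RetryClient — wraps an `LlmClient` with automatic retries and backoff for transient errors.
//!
//! Retries on: rate limits (429), request timeouts (408), server errors (5xx), empty
//! responses, and network errors. Honors `retry_after_secs` from rate-limit headers when
//! available; otherwise rate limits follow a Fibonacci-like schedule and other errors use
//! exponential backoff.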
5
6use crate::client::LlmClient;
7use crate::tool::ToolDef;
8use crate::types::{Message, SgrError, ToolCall};
9use serde_json::Value;
10use std::time::Duration;
11
/// Tuning knobs for the retry/backoff behaviour of `RetryClient`.
///
/// Defaults: 5 retries, 500 ms base delay, 30 s delay cap.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// How many times a failed call may be retried; `0` disables retrying entirely.
    pub max_retries: usize,
    /// Starting delay, in milliseconds, for exponential backoff.
    pub base_delay_ms: u64,
    /// Delay cap, in milliseconds, used by the backoff calculation.
    pub max_delay_ms: u64,
}

impl Default for RetryConfig {
    /// Five retries, starting at 500 ms and capped at 30 s.
    fn default() -> Self {
        RetryConfig {
            base_delay_ms: 500,
            max_delay_ms: 30_000,
            max_retries: 5,
        }
    }
}
32
33/// Determine if an error is retryable (transient: rate limit, timeout, server errors).
34pub fn is_retryable(err: &SgrError) -> bool {
35    match err {
36        SgrError::RateLimit { .. } => true,
37        SgrError::EmptyResponse => true,
38        // reqwest::Error — retryable if timeout or connect error
39        SgrError::Http(e) => e.is_timeout() || e.is_connect() || e.is_request(),
40        SgrError::Api { status, .. } => {
41            *status == 0 || *status >= 500 || *status == 408 || *status == 429
42        }
43        // Empty response wrapped as Schema error — transient model behavior
44        SgrError::Schema(msg) => msg.contains("Empty response"),
45        // MaxOutputTokens and PromptTooLong are NOT retryable at this level —
46        // they are handled by the agent loop with special recovery logic
47        SgrError::MaxOutputTokens { .. } | SgrError::PromptTooLong(_) => false,
48        _ => false,
49    }
50}
51
/// Rate-limit backoff schedule in milliseconds: 1 s, 2 s, 3 s, 5 s, 8 s, 13 s, 21 s, then a
/// final 30 s step (the Fibonacci pattern, truncated instead of 34 s). It grows more
/// gently than doubling, so a throttled API is not hammered.
const FIBO_DELAYS_MS: &[u64] = &[
    1_000, 2_000, 3_000, 5_000, 8_000, 13_000, 21_000, 30_000,
];
55
56/// Calculate delay for attempt N, honoring rate limit headers.
57pub fn delay_for_attempt(attempt: usize, config: &RetryConfig, err: &SgrError) -> Duration {
58    // Honor retry-after header from rate limit
59    if let Some(info) = err.rate_limit_info()
60        && let Some(secs) = info.retry_after_secs
61    {
62        return Duration::from_secs(secs + 1); // +1s safety margin
63    }
64
65    // Rate limit (429): use Fibonacci-like delays (more patient)
66    let is_rate_limit = matches!(err, SgrError::RateLimit { .. })
67        || matches!(err, SgrError::Api { status: 429, .. })
68        || matches!(err, SgrError::Api { status: 0, body } if body.contains("429") || body.contains("rate limit"));
69    if is_rate_limit {
70        let delay_ms = FIBO_DELAYS_MS
71            .get(attempt)
72            .copied()
73            .unwrap_or(config.max_delay_ms);
74        let jitter = (delay_ms as f64 * 0.15 * fastrand()) as u64;
75        return Duration::from_millis(delay_ms + jitter);
76    }
77
78    // Other errors: exponential backoff (faster retry for transient failures)
79    let delay_ms = (config.base_delay_ms * (1 << attempt)).min(config.max_delay_ms);
80    let jitter = (delay_ms as f64 * 0.1 * (attempt as f64 % 2.0 - 0.5)) as u64;
81    Duration::from_millis(delay_ms.saturating_add(jitter))
82}
83
/// Cheap pseudo-random number in `[0.0, 1.0)` without an extra dependency.
///
/// Derived from the sub-second nanoseconds of the wall clock, so it is not
/// cryptographically random, and calls made in quick succession are correlated. This is
/// good enough for retry jitter. If the clock reads earlier than the UNIX epoch, it
/// returns `0.0`.
fn fastrand() -> f64 {
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|elapsed| elapsed.subsec_nanos())
        .unwrap_or(0);
    (f64::from(nanos) / 1e9).fract()
}
92
93/// LLM client wrapper with automatic retry on transient errors.
94pub struct RetryClient<C: LlmClient> {
95    inner: C,
96    config: RetryConfig,
97}
98
99impl<C: LlmClient> RetryClient<C> {
100    pub fn new(inner: C) -> Self {
101        Self {
102            inner,
103            config: RetryConfig::default(),
104        }
105    }
106
107    pub fn with_config(mut self, config: RetryConfig) -> Self {
108        self.config = config;
109        self
110    }
111
112    /// Access inner client (e.g. for connect_ws on OxideClient).
113    pub fn inner(&self) -> &C {
114        &self.inner
115    }
116}
117
118#[async_trait::async_trait]
119impl<C: LlmClient> LlmClient for RetryClient<C> {
120    async fn structured_call(
121        &self,
122        messages: &[Message],
123        schema: &Value,
124    ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
125        let mut last_err = None;
126        for attempt in 0..=self.config.max_retries {
127            match self.inner.structured_call(messages, schema).await {
128                Ok(result) => return Ok(result),
129                Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
130                    let delay = delay_for_attempt(attempt, &self.config, &e);
131                    tracing::warn!(
132                        attempt = attempt + 1,
133                        max = self.config.max_retries,
134                        delay_ms = delay.as_millis() as u64,
135                        "Retrying structured_call: {}",
136                        e
137                    );
138                    tokio::time::sleep(delay).await;
139                    last_err = Some(e);
140                }
141                Err(e) => return Err(e),
142            }
143        }
144        Err(last_err.unwrap())
145    }
146
147    async fn tools_call(
148        &self,
149        messages: &[Message],
150        tools: &[ToolDef],
151    ) -> Result<Vec<ToolCall>, SgrError> {
152        let mut last_err = None;
153        for attempt in 0..=self.config.max_retries {
154            match self.inner.tools_call(messages, tools).await {
155                Ok(result) => return Ok(result),
156                Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
157                    let delay = delay_for_attempt(attempt, &self.config, &e);
158                    tracing::warn!(
159                        attempt = attempt + 1,
160                        max = self.config.max_retries,
161                        delay_ms = delay.as_millis() as u64,
162                        "Retrying tools_call: {}",
163                        e
164                    );
165                    tokio::time::sleep(delay).await;
166                    last_err = Some(e);
167                }
168                Err(e) => return Err(e),
169            }
170        }
171        Err(last_err.unwrap())
172    }
173
174    async fn tools_call_with_text(
175        &self,
176        messages: &[Message],
177        tools: &[ToolDef],
178    ) -> Result<(Vec<ToolCall>, String), SgrError> {
179        let mut last_err = None;
180        for attempt in 0..=self.config.max_retries {
181            match self.inner.tools_call_with_text(messages, tools).await {
182                Ok(result) => return Ok(result),
183                Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
184                    let delay = delay_for_attempt(attempt, &self.config, &e);
185                    tracing::warn!(
186                        attempt = attempt + 1,
187                        max = self.config.max_retries,
188                        delay_ms = delay.as_millis() as u64,
189                        "Retrying tools_call_with_text: {}",
190                        e
191                    );
192                    tokio::time::sleep(delay).await;
193                    last_err = Some(e);
194                }
195                Err(e) => return Err(e),
196            }
197        }
198        Err(last_err.unwrap())
199    }
200
201    async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
202        let mut last_err = None;
203        for attempt in 0..=self.config.max_retries {
204            match self.inner.complete(messages).await {
205                Ok(result) => return Ok(result),
206                Err(e) if is_retryable(&e) && attempt < self.config.max_retries => {
207                    let delay = delay_for_attempt(attempt, &self.config, &e);
208                    tracing::warn!(
209                        attempt = attempt + 1,
210                        max = self.config.max_retries,
211                        delay_ms = delay.as_millis() as u64,
212                        "Retrying complete: {}",
213                        e
214                    );
215                    tokio::time::sleep(delay).await;
216                    last_err = Some(e);
217                }
218                Err(e) => return Err(e),
219            }
220        }
221        Err(last_err.unwrap())
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Mock backend. `structured_call` and `tools_call` share one call counter: the first
    /// `fail_count` calls fail with a retryable error, and later calls succeed.
    /// `complete` always succeeds and is not counted.
    struct FailingClient {
        fail_count: usize,
        call_count: Arc<AtomicUsize>,
    }

    impl FailingClient {
        /// Record one call and report whether it should fail.
        fn next_call_fails(&self) -> bool {
            self.call_count.fetch_add(1, Ordering::SeqCst) < self.fail_count
        }
    }

    #[async_trait::async_trait]
    impl LlmClient for FailingClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            if self.next_call_fails() {
                return Err(SgrError::EmptyResponse);
            }
            Ok((None, Vec::new(), "ok".into()))
        }

        async fn tools_call(
            &self,
            _: &[Message],
            _: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            if self.next_call_fails() {
                return Err(SgrError::Api {
                    status: 500,
                    body: "internal error".into(),
                });
            }
            Ok(Vec::new())
        }

        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok("ok".into())
        }
    }

    /// Wrap a `FailingClient` in a `RetryClient` with 1 ms base / 10 ms max delays and
    /// return it together with the shared call counter.
    fn retrying(
        fail_count: usize,
        max_retries: usize,
    ) -> (RetryClient<FailingClient>, Arc<AtomicUsize>) {
        let calls = Arc::new(AtomicUsize::new(0));
        let mock = FailingClient {
            fail_count,
            call_count: Arc::clone(&calls),
        };
        let config = RetryConfig {
            max_retries,
            base_delay_ms: 1,
            max_delay_ms: 10,
        };
        (RetryClient::new(mock).with_config(config), calls)
    }

    #[tokio::test]
    async fn retries_on_empty_response() {
        let (client, calls) = retrying(2, 3);
        let outcome = client
            .structured_call(&[Message::user("hi")], &serde_json::json!({}))
            .await;
        assert!(outcome.is_ok());
        // Two failures, then the successful third call.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retries_on_server_error() {
        let (client, calls) = retrying(1, 2);
        let outcome = client.tools_call(&[Message::user("hi")], &[]).await;
        assert!(outcome.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn fails_after_max_retries() {
        let (client, calls) = retrying(10, 2);
        let outcome = client
            .structured_call(&[Message::user("hi")], &serde_json::json!({}))
            .await;
        assert!(outcome.is_err());
        // One initial attempt plus two retries.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn non_retryable_errors() {
        let permanent = [
            SgrError::Api {
                status: 400,
                body: "bad request".into(),
            },
            SgrError::Schema("parse".into()),
        ];
        for err in &permanent {
            assert!(!is_retryable(err), "expected permanent: {err}");
        }

        let transient = [
            SgrError::Schema("Empty response from model (parts: text)".into()),
            SgrError::EmptyResponse,
            SgrError::Api {
                status: 503,
                body: "server error".into(),
            },
            SgrError::Api {
                status: 429,
                body: "rate limit".into(),
            },
        ];
        for err in &transient {
            assert!(is_retryable(err), "expected retryable: {err}");
        }
    }

    #[test]
    fn delay_exponential_backoff() {
        let config = RetryConfig {
            max_retries: 5,
            base_delay_ms: 100,
            max_delay_ms: 5000,
        };
        let err = SgrError::EmptyResponse;
        // Roughly 100 ms, 200 ms, 400 ms, each with room for jitter.
        for (attempt, ceiling_ms) in [(0, 150), (1, 250), (2, 500)] {
            let delay = delay_for_attempt(attempt, &config, &err);
            assert!(
                delay.as_millis() <= ceiling_ms,
                "attempt {attempt}: {delay:?}"
            );
        }
    }

    #[test]
    fn delay_capped_at_max() {
        let config = RetryConfig {
            max_retries: 10,
            base_delay_ms: 1000,
            max_delay_ms: 5000,
        };
        let d10 = delay_for_attempt(10, &config, &SgrError::EmptyResponse);
        // Cap plus jitter headroom.
        assert!(d10.as_millis() <= 5500);
    }
}