Skip to main content

traitclaw_core/
retry.rs

1//! Retry wrapper for LLM providers with exponential backoff.
2//!
3//! Wraps any [`Provider`] and automatically retries transient errors.
4
5use std::sync::Arc;
6use std::time::Duration;
7
8use async_trait::async_trait;
9
10use crate::traits::provider::Provider;
11use crate::types::completion::{CompletionRequest, CompletionResponse};
12use crate::types::model_info::ModelInfo;
13use crate::types::stream::CompletionStream;
14
/// Tunable parameters for the retry loop of [`RetryProvider`].
///
/// A call is attempted once and then retried up to `max_retries` more times,
/// so at most `max_retries + 1` calls reach the wrapped provider.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// How many retries may follow the first failed attempt (default: 3).
    pub max_retries: usize,
    /// Backoff before the first retry; later retries double it (default: 500ms).
    pub initial_delay: Duration,
    /// Upper bound applied to every individual backoff delay (default: 30s).
    pub max_delay: Duration,
}

impl Default for RetryConfig {
    /// Three retries, starting at 500ms and never waiting longer than 30s.
    fn default() -> Self {
        Self::new(3, Duration::from_millis(500), Duration::from_secs(30))
    }
}

impl RetryConfig {
    /// Build a config from explicit values.
    ///
    /// The values are stored as given; no validation is performed.
    #[must_use]
    pub fn new(max_retries: usize, initial_delay: Duration, max_delay: Duration) -> Self {
        Self {
            initial_delay,
            max_delay,
            max_retries,
        }
    }
}
47
48/// A provider decorator that retries transient errors with exponential backoff.
49///
50/// Uses [`Error::is_retryable()`](crate::Error::is_retryable) to classify errors.
51pub struct RetryProvider {
52    inner: Arc<dyn Provider>,
53    config: RetryConfig,
54}
55
56impl RetryProvider {
57    /// Wrap a provider with retry behavior.
58    #[must_use]
59    pub fn new(inner: Arc<dyn Provider>, config: RetryConfig) -> Self {
60        Self { inner, config }
61    }
62
63    /// Calculate delay for a given attempt (0-indexed), capped at `max_delay`.
64    #[allow(clippy::cast_possible_truncation)]
65    fn delay_for_attempt(&self, attempt: usize) -> Duration {
66        let delay = self
67            .config
68            .initial_delay
69            .saturating_mul(1u32.wrapping_shl(attempt as u32));
70        delay.min(self.config.max_delay)
71    }
72}
73
74#[async_trait]
75impl Provider for RetryProvider {
76    async fn complete(&self, req: CompletionRequest) -> crate::Result<CompletionResponse> {
77        let mut last_error = None;
78
79        for attempt in 0..=self.config.max_retries {
80            let result = self.inner.complete(req.clone()).await;
81            match result {
82                Ok(response) => return Ok(response),
83                Err(e) => {
84                    if !e.is_retryable() || attempt == self.config.max_retries {
85                        return Err(e);
86                    }
87                    let delay = self.delay_for_attempt(attempt);
88                    tracing::warn!(
89                        attempt = attempt + 1,
90                        max_retries = self.config.max_retries,
91                        delay_ms = delay.as_millis(),
92                        error = %e,
93                        "Retrying provider call"
94                    );
95                    tokio::time::sleep(delay).await;
96                    last_error = Some(e);
97                }
98            }
99        }
100
101        Err(last_error.unwrap_or_else(|| crate::Error::provider("retry exhausted")))
102    }
103
104    async fn stream(&self, req: CompletionRequest) -> crate::Result<CompletionStream> {
105        let mut last_error = None;
106
107        for attempt in 0..=self.config.max_retries {
108            let result = self.inner.stream(req.clone()).await;
109            match result {
110                Ok(stream) => return Ok(stream),
111                Err(e) => {
112                    if !e.is_retryable() || attempt == self.config.max_retries {
113                        return Err(e);
114                    }
115                    let delay = self.delay_for_attempt(attempt);
116                    tracing::warn!(
117                        attempt = attempt + 1,
118                        max_retries = self.config.max_retries,
119                        delay_ms = delay.as_millis(),
120                        error = %e,
121                        "Retrying provider stream"
122                    );
123                    tokio::time::sleep(delay).await;
124                    last_error = Some(e);
125                }
126            }
127        }
128
129        Err(last_error.unwrap_or_else(|| crate::Error::provider("retry exhausted")))
130    }
131
132    fn model_info(&self) -> &ModelInfo {
133        self.inner.model_info()
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::completion::{ResponseContent, Usage};
    use crate::types::model_info::ModelTier;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double: `complete` fails with a retryable (HTTP 500) error a fixed
    /// number of times, then succeeds on every later call. `stream` always fails
    /// with the same retryable error. Every call to either method is counted.
    struct FailThenSucceedProvider {
        /// Failures remaining before `complete` starts succeeding.
        fail_count: AtomicUsize,
        /// Total number of `complete`/`stream` invocations observed.
        calls: AtomicUsize,
        info: ModelInfo,
    }

    impl FailThenSucceedProvider {
        fn new(fail_n_times: usize) -> Self {
            Self {
                fail_count: AtomicUsize::new(fail_n_times),
                calls: AtomicUsize::new(0),
                info: ModelInfo::new("test", ModelTier::Small, 4096, false, false, false),
            }
        }

        /// Number of calls the retry wrapper has made so far.
        fn calls(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl Provider for FailThenSucceedProvider {
        async fn complete(&self, _req: CompletionRequest) -> crate::Result<CompletionResponse> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            // Decrement only while failures remain. A plain `fetch_sub` would wrap
            // the counter to `usize::MAX` on the first success, making every later
            // call fail again.
            let should_fail = self
                .fail_count
                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| n.checked_sub(1))
                .is_ok();
            if should_fail {
                Err(crate::Error::provider_with_status("server error", 500))
            } else {
                Ok(CompletionResponse {
                    content: ResponseContent::Text("ok".into()),
                    usage: Usage {
                        prompt_tokens: 1,
                        completion_tokens: 1,
                        total_tokens: 2,
                    },
                })
            }
        }

        async fn stream(&self, _req: CompletionRequest) -> crate::Result<CompletionStream> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Err(crate::Error::provider_with_status("server error", 500))
        }

        fn model_info(&self) -> &ModelInfo {
            &self.info
        }
    }

    fn make_request() -> CompletionRequest {
        CompletionRequest {
            model: "test".into(),
            messages: vec![],
            tools: vec![],
            max_tokens: None,
            temperature: None,
            response_format: None,
            stream: false,
        }
    }

    /// Config with tiny delays so the async tests finish quickly.
    fn fast_config(max_retries: usize) -> RetryConfig {
        RetryConfig {
            max_retries,
            initial_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
        }
    }

    #[tokio::test]
    async fn test_retry_succeeds_on_second_attempt() {
        let inner = Arc::new(FailThenSucceedProvider::new(1));
        let provider = RetryProvider::new(inner.clone(), fast_config(3));

        let result = provider.complete(make_request()).await;
        assert!(result.is_ok());
        assert_eq!(inner.calls(), 2);
    }

    #[tokio::test]
    async fn test_max_retries_exhausted() {
        let inner = Arc::new(FailThenSucceedProvider::new(10));
        let provider = RetryProvider::new(inner.clone(), fast_config(2));

        let result = provider.complete(make_request()).await;
        assert!(result.is_err());
        // One initial attempt plus `max_retries` retries.
        assert_eq!(inner.calls(), 3);
    }

    #[tokio::test]
    async fn test_zero_max_retries_makes_single_attempt() {
        let inner = Arc::new(FailThenSucceedProvider::new(1));
        let provider = RetryProvider::new(inner.clone(), fast_config(0));

        let result = provider.complete(make_request()).await;
        assert!(result.is_err());
        assert_eq!(inner.calls(), 1);
    }

    #[tokio::test]
    async fn test_later_requests_succeed_after_recovery() {
        let inner = Arc::new(FailThenSucceedProvider::new(1));
        let provider = RetryProvider::new(inner.clone(), fast_config(3));

        assert!(provider.complete(make_request()).await.is_ok());
        assert_eq!(inner.calls(), 2);
        // Once recovered, the next request succeeds on its first attempt.
        assert!(provider.complete(make_request()).await.is_ok());
        assert_eq!(inner.calls(), 3);
    }

    #[tokio::test]
    async fn test_stream_retries_retryable_errors() {
        let inner = Arc::new(FailThenSucceedProvider::new(0));
        let provider = RetryProvider::new(inner.clone(), fast_config(2));

        let result = provider.stream(make_request()).await;
        assert!(result.is_err());
        assert_eq!(inner.calls(), 3);
    }

    #[tokio::test]
    async fn test_non_retryable_error_propagated_immediately() {
        struct NonRetryableProvider {
            calls: AtomicUsize,
            info: ModelInfo,
        }

        #[async_trait]
        impl Provider for NonRetryableProvider {
            async fn complete(&self, _req: CompletionRequest) -> crate::Result<CompletionResponse> {
                self.calls.fetch_add(1, Ordering::SeqCst);
                Err(crate::Error::provider_with_status("unauthorized", 401))
            }
            async fn stream(&self, _req: CompletionRequest) -> crate::Result<CompletionStream> {
                unimplemented!()
            }
            fn model_info(&self) -> &ModelInfo {
                &self.info
            }
        }

        let inner = Arc::new(NonRetryableProvider {
            calls: AtomicUsize::new(0),
            info: ModelInfo::new("test", ModelTier::Small, 4096, false, false, false),
        });
        let provider = RetryProvider::new(inner.clone(), fast_config(3));

        let result = provider.complete(make_request()).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("unauthorized"));
        // A non-retryable error must not trigger any retries.
        assert_eq!(inner.calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_exponential_backoff_timing() {
        let config = RetryConfig {
            max_retries: 5,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(5),
        };
        let provider = RetryProvider::new(Arc::new(FailThenSucceedProvider::new(0)), config);

        assert_eq!(provider.delay_for_attempt(0), Duration::from_millis(100));
        assert_eq!(provider.delay_for_attempt(1), Duration::from_millis(200));
        assert_eq!(provider.delay_for_attempt(2), Duration::from_millis(400));
        assert_eq!(provider.delay_for_attempt(3), Duration::from_millis(800));
    }

    #[test]
    fn test_max_delay_cap() {
        let config = RetryConfig {
            max_retries: 5,
            initial_delay: Duration::from_secs(10),
            max_delay: Duration::from_secs(30),
        };
        let provider = RetryProvider::new(Arc::new(FailThenSucceedProvider::new(0)), config);

        // 10 * 2^2 = 40s, but capped at 30s
        assert_eq!(provider.delay_for_attempt(2), Duration::from_secs(30));
    }
}