Skip to main content

wesichain_core/
retry.rs

1use futures::stream::BoxStream;
2use rand::Rng;
3
4use crate::{Runnable, StreamEvent, WesichainError};
5
/// Wraps a [`Runnable`] and re-runs it when it fails with a retryable error.
///
/// Which errors are retried is decided by [`is_retryable`]. Between attempts the
/// wrapper sleeps with exponential backoff plus random jitter.
///
/// `max_attempts` is the total number of calls made to the wrapped runnable,
/// including the first one. A value of `0` makes every call fail immediately with
/// `WesichainError::MaxRetriesExceeded { max: 0 }` without touching the runnable.
#[derive(Debug, Clone)]
pub struct Retrying<R> {
    /// The runnable whose failures are retried.
    runnable: R,
    /// Total number of attempts allowed, first call included.
    max_attempts: usize,
}

impl<R> Retrying<R> {
    /// Creates a wrapper that calls `runnable` at most `max_attempts` times per request.
    pub fn new(runnable: R, max_attempts: usize) -> Self {
        Self {
            runnable,
            max_attempts,
        }
    }
}
19
20pub fn is_retryable(error: &WesichainError) -> bool {
21    matches!(
22        error,
23        WesichainError::LlmProvider(_)
24            | WesichainError::ToolCallFailed { .. }
25            | WesichainError::Timeout(_)
26            | WesichainError::RateLimitExceeded { .. }
27    )
28}
29
30#[async_trait::async_trait]
31impl<Input, Output, R> Runnable<Input, Output> for Retrying<R>
32where
33    Input: Send + Clone + 'static,
34    Output: Send + 'static,
35    R: Runnable<Input, Output> + Send + Sync,
36{
37    async fn invoke(&self, input: Input) -> Result<Output, WesichainError> {
38        if self.max_attempts == 0 {
39            return Err(WesichainError::MaxRetriesExceeded { max: 0 });
40        }
41
42        let mut attempt = 0;
43        loop {
44            attempt += 1;
45            match self.runnable.invoke(input.clone()).await {
46                Ok(output) => return Ok(output),
47                Err(error) => {
48                    if !is_retryable(&error) || attempt >= self.max_attempts {
49                        if attempt >= self.max_attempts {
50                            return Err(WesichainError::MaxRetriesExceeded {
51                                max: self.max_attempts,
52                            });
53                        }
54                        return Err(error);
55                    }
56
57                    // Exponential backoff: base 100ms * 2^(attempt-1)
58                    // Cap at ~10s (attempt 7+) to avoid excessive delays in interactive apps
59                    let base_delay_ms = 100u64 * (1u64 << (attempt - 1).min(7));
60                    let jitter_ms = rand::thread_rng().gen_range(0..100);
61                    let delay = std::time::Duration::from_millis(base_delay_ms + jitter_ms);
62
63                    tokio::time::sleep(delay).await;
64                }
65            }
66        }
67    }
68
69    /// Retry-on-stream-start: if the stream errors before its first item is emitted,
70    /// apply exponential backoff and re-attempt (up to `max_attempts`).
71    /// Once streaming is in progress (first item emitted), errors pass through as-is.
72    fn stream<'a>(&'a self, input: Input) -> BoxStream<'a, Result<StreamEvent, WesichainError>> {
73        use futures::StreamExt as _;
74        let runnable = &self.runnable;
75        let max_attempts = self.max_attempts;
76
77        async_stream::stream! {
78            if max_attempts == 0 {
79                yield Err(WesichainError::MaxRetriesExceeded { max: 0 });
80                return;
81            }
82
83            let mut attempt = 0usize;
84            loop {
85                attempt += 1;
86                let mut inner = runnable.stream(input.clone());
87
88                match inner.next().await {
89                    None => break,
90                    Some(first) => {
91                        if matches!(&first, Err(e) if is_retryable(e) && attempt < max_attempts) {
92                            let base_delay_ms = 100u64 * (1u64 << (attempt - 1).min(7));
93                            let jitter_ms = rand::thread_rng().gen_range(0..100u64);
94                            let delay = std::time::Duration::from_millis(base_delay_ms + jitter_ms);
95                            tokio::time::sleep(delay).await;
96                            continue;
97                        }
98
99                        // Exhausted retries on a retryable error → emit MaxRetriesExceeded
100                        let item = match first {
101                            Err(ref e) if is_retryable(e) => {
102                                Err(WesichainError::MaxRetriesExceeded { max: max_attempts })
103                            }
104                            item => item,
105                        };
106                        yield item;
107                        while let Some(event) = inner.next().await {
108                            yield event;
109                        }
110                        break;
111                    }
112                }
113            }
114        }
115        .boxed()
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;
    use futures::StreamExt as _;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Test double that fails with a retryable `Timeout` on its first `failures`
    /// calls (to either `invoke` or `stream`) and succeeds afterwards.
    /// `count` records the total number of calls made.
    struct FailRunnable {
        failures: usize,
        count: Arc<AtomicUsize>,
    }

    #[async_trait::async_trait]
    impl Runnable<(), ()> for FailRunnable {
        async fn invoke(&self, _: ()) -> Result<(), WesichainError> {
            let current = self.count.fetch_add(1, Ordering::SeqCst);
            if current < self.failures {
                Err(WesichainError::Timeout(std::time::Duration::from_millis(1)))
            } else {
                Ok(())
            }
        }

        fn stream<'a>(&'a self, _: ()) -> BoxStream<'a, Result<StreamEvent, WesichainError>> {
            let current = self.count.fetch_add(1, Ordering::SeqCst);
            if current < self.failures {
                stream::iter(vec![Err(WesichainError::Timeout(
                    std::time::Duration::from_millis(1),
                ))])
                .boxed()
            } else {
                stream::iter(vec![Ok(StreamEvent::ContentChunk("ok".to_string()))]).boxed()
            }
        }
    }

    #[tokio::test]
    async fn test_retry_success() {
        let count = Arc::new(AtomicUsize::new(0));
        let runnable = FailRunnable {
            failures: 2,
            count: count.clone(),
        };
        let retrying = Retrying::new(runnable, 3);

        let start = std::time::Instant::now();
        retrying.invoke(()).await.unwrap();
        let elapsed = start.elapsed();

        // 2 failures + 1 success.
        assert_eq!(count.load(Ordering::SeqCst), 3);
        // Base delays: 100ms (after attempt 1) + 200ms (after attempt 2) = 300ms minimum.
        assert!(elapsed.as_millis() >= 300);
    }

    #[tokio::test]
    async fn test_max_retries_exceeded() {
        let count = Arc::new(AtomicUsize::new(0));
        let runnable = FailRunnable {
            failures: 5,
            count: count.clone(),
        };
        let retrying = Retrying::new(runnable, 3);

        let result = retrying.invoke(()).await;
        assert!(matches!(
            result,
            Err(WesichainError::MaxRetriesExceeded { max: 3 })
        ));
        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_zero_max_attempts_never_invokes() {
        let count = Arc::new(AtomicUsize::new(0));
        let runnable = FailRunnable {
            failures: 0,
            count: count.clone(),
        };
        let retrying = Retrying::new(runnable, 0);

        let result = retrying.invoke(()).await;
        assert!(matches!(
            result,
            Err(WesichainError::MaxRetriesExceeded { max: 0 })
        ));
        // The wrapped runnable must not be called at all.
        assert_eq!(count.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_stream_retry_on_first_item_error() {
        // Stream fails on the first 2 attempts and succeeds on the 3rd.
        let count = Arc::new(AtomicUsize::new(0));
        let runnable = FailRunnable {
            failures: 2,
            count: count.clone(),
        };
        let retrying = Retrying::new(runnable, 3);

        let events: Vec<_> = retrying.stream(()).collect().await;
        // Should succeed on the 3rd attempt with one ContentChunk.
        assert_eq!(events.len(), 1);
        assert!(matches!(events[0], Ok(StreamEvent::ContentChunk(_))));
        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_stream_max_retries_exceeded_yields_error() {
        // Stream always fails.
        let count = Arc::new(AtomicUsize::new(0));
        let runnable = FailRunnable {
            failures: 10,
            count: count.clone(),
        };
        let retrying = Retrying::new(runnable, 3);

        let events: Vec<_> = retrying.stream(()).collect().await;
        assert_eq!(events.len(), 1);
        assert!(matches!(
            events[0],
            Err(WesichainError::MaxRetriesExceeded { max: 3 })
        ));
        // Exactly `max_attempts` streams were opened, no more.
        assert_eq!(count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_stream_zero_max_attempts_yields_error() {
        let count = Arc::new(AtomicUsize::new(0));
        let runnable = FailRunnable {
            failures: 0,
            count: count.clone(),
        };
        let retrying = Retrying::new(runnable, 0);

        let events: Vec<_> = retrying.stream(()).collect().await;
        assert_eq!(events.len(), 1);
        assert!(matches!(
            events[0],
            Err(WesichainError::MaxRetriesExceeded { max: 0 })
        ));
        // The wrapped runnable's stream must never be opened.
        assert_eq!(count.load(Ordering::SeqCst), 0);
    }
}