1use futures::stream::BoxStream;
2use rand::Rng;
3
4use crate::{Runnable, StreamEvent, WesichainError};
5
/// Wrapper that re-runs an inner runnable when it fails with a retryable error.
///
/// `max_attempts` is the total number of attempts allowed, counting the first
/// one; a value of `0` means the wrapped runnable is never called.
pub struct Retrying<R> {
    /// The wrapped runnable that each attempt is delegated to.
    runnable: R,
    /// Upper bound on the number of attempts, including the initial call.
    max_attempts: usize,
}

impl<R> Retrying<R> {
    /// Wraps `runnable` so that it is attempted at most `max_attempts` times.
    pub fn new(runnable: R, max_attempts: usize) -> Self {
        Retrying {
            runnable,
            max_attempts,
        }
    }
}
19
20pub fn is_retryable(error: &WesichainError) -> bool {
21 matches!(
22 error,
23 WesichainError::LlmProvider(_)
24 | WesichainError::ToolCallFailed { .. }
25 | WesichainError::Timeout(_)
26 | WesichainError::RateLimitExceeded { .. }
27 )
28}
29
30#[async_trait::async_trait]
31impl<Input, Output, R> Runnable<Input, Output> for Retrying<R>
32where
33 Input: Send + Clone + 'static,
34 Output: Send + 'static,
35 R: Runnable<Input, Output> + Send + Sync,
36{
37 async fn invoke(&self, input: Input) -> Result<Output, WesichainError> {
38 if self.max_attempts == 0 {
39 return Err(WesichainError::MaxRetriesExceeded { max: 0 });
40 }
41
42 let mut attempt = 0;
43 loop {
44 attempt += 1;
45 match self.runnable.invoke(input.clone()).await {
46 Ok(output) => return Ok(output),
47 Err(error) => {
48 if !is_retryable(&error) || attempt >= self.max_attempts {
49 if attempt >= self.max_attempts {
50 return Err(WesichainError::MaxRetriesExceeded {
51 max: self.max_attempts,
52 });
53 }
54 return Err(error);
55 }
56
57 let base_delay_ms = 100u64 * (1u64 << (attempt - 1).min(7));
60 let jitter_ms = rand::thread_rng().gen_range(0..100);
61 let delay = std::time::Duration::from_millis(base_delay_ms + jitter_ms);
62
63 tokio::time::sleep(delay).await;
64 }
65 }
66 }
67 }
68
69 fn stream<'a>(&'a self, input: Input) -> BoxStream<'a, Result<StreamEvent, WesichainError>> {
73 use futures::StreamExt as _;
74 let runnable = &self.runnable;
75 let max_attempts = self.max_attempts;
76
77 async_stream::stream! {
78 if max_attempts == 0 {
79 yield Err(WesichainError::MaxRetriesExceeded { max: 0 });
80 return;
81 }
82
83 let mut attempt = 0usize;
84 loop {
85 attempt += 1;
86 let mut inner = runnable.stream(input.clone());
87
88 match inner.next().await {
89 None => break,
90 Some(first) => {
91 if matches!(&first, Err(e) if is_retryable(e) && attempt < max_attempts) {
92 let base_delay_ms = 100u64 * (1u64 << (attempt - 1).min(7));
93 let jitter_ms = rand::thread_rng().gen_range(0..100u64);
94 let delay = std::time::Duration::from_millis(base_delay_ms + jitter_ms);
95 tokio::time::sleep(delay).await;
96 continue;
97 }
98
99 let item = match first {
101 Err(ref e) if is_retryable(e) => {
102 Err(WesichainError::MaxRetriesExceeded { max: max_attempts })
103 }
104 item => item,
105 };
106 yield item;
107 while let Some(event) = inner.next().await {
108 yield event;
109 }
110 break;
111 }
112 }
113 }
114 }
115 .boxed()
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;
    use futures::StreamExt as _;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Test double whose first `failures` calls (to `invoke` or `stream`,
    /// counted together) fail with a timeout error; later calls succeed.
    struct FailRunnable {
        failures: usize,
        count: Arc<AtomicUsize>,
    }

    impl FailRunnable {
        /// Records one call and reports whether that call should fail.
        fn next_call_fails(&self) -> bool {
            let calls_so_far = self.count.fetch_add(1, Ordering::SeqCst);
            calls_so_far < self.failures
        }
    }

    /// The retryable error produced by `FailRunnable`.
    fn timeout_error() -> WesichainError {
        WesichainError::Timeout(std::time::Duration::from_millis(1))
    }

    #[async_trait::async_trait]
    impl Runnable<(), ()> for FailRunnable {
        async fn invoke(&self, _: ()) -> Result<(), WesichainError> {
            if self.next_call_fails() {
                return Err(timeout_error());
            }
            Ok(())
        }

        fn stream<'a>(&'a self, _: ()) -> BoxStream<'a, Result<StreamEvent, WesichainError>> {
            let only_event = if self.next_call_fails() {
                Err(timeout_error())
            } else {
                Ok(StreamEvent::ContentChunk("ok".to_string()))
            };
            stream::iter(vec![only_event]).boxed()
        }
    }

    /// Builds a `Retrying<FailRunnable>` together with a handle on its call counter.
    fn retrying_with(
        failures: usize,
        max_attempts: usize,
    ) -> (Retrying<FailRunnable>, Arc<AtomicUsize>) {
        let calls = Arc::new(AtomicUsize::new(0));
        let inner = FailRunnable {
            failures,
            count: calls.clone(),
        };
        (Retrying::new(inner, max_attempts), calls)
    }

    #[tokio::test]
    async fn test_retry_success() {
        let (retrying, calls) = retrying_with(2, 3);

        let started = std::time::Instant::now();
        retrying.invoke(()).await.unwrap();
        let elapsed = started.elapsed();

        // Two failures followed by one success.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
        // Backoff of at least 100 ms + 200 ms before jitter.
        assert!(elapsed.as_millis() >= 300);
    }

    #[tokio::test]
    async fn test_max_retries_exceeded() {
        let (retrying, calls) = retrying_with(5, 3);

        let outcome = retrying.invoke(()).await;

        assert!(matches!(
            outcome,
            Err(WesichainError::MaxRetriesExceeded { max: 3 })
        ));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_stream_retry_on_first_item_error() {
        let (retrying, calls) = retrying_with(2, 3);

        let events: Vec<_> = retrying.stream(()).collect().await;

        assert_eq!(events.len(), 1);
        assert!(matches!(events[0], Ok(StreamEvent::ContentChunk(_))));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_stream_max_retries_exceeded_yields_error() {
        let (retrying, _calls) = retrying_with(10, 3);

        let events: Vec<_> = retrying.stream(()).collect().await;

        assert_eq!(events.len(), 1);
        assert!(matches!(
            events[0],
            Err(WesichainError::MaxRetriesExceeded { max: 3 })
        ));
    }

    #[tokio::test]
    async fn test_stream_zero_max_attempts_yields_error() {
        let (retrying, _calls) = retrying_with(0, 0);

        let events: Vec<_> = retrying.stream(()).collect().await;

        assert_eq!(events.len(), 1);
        assert!(matches!(
            events[0],
            Err(WesichainError::MaxRetriesExceeded { max: 0 })
        ));
    }
}