synwire_core/runnables/
retry.rs1use crate::error::SynwireErrorKind;
4use std::time::Duration;
5
6#[derive(Debug, Clone)]
8pub struct RetryConfig {
9 pub retry_on: Vec<SynwireErrorKind>,
11 pub max_attempts: u32,
13 pub wait_exponential_jitter: bool,
15 pub initial_interval: Duration,
17 pub max_interval: Duration,
19}
20
21impl Default for RetryConfig {
22 fn default() -> Self {
23 Self {
24 retry_on: Vec::new(),
25 max_attempts: 3,
26 wait_exponential_jitter: true,
27 initial_interval: Duration::from_secs(1),
28 max_interval: Duration::from_secs(60),
29 }
30 }
31}
32
33#[derive(Debug)]
35pub struct RetryState {
36 pub attempt: u32,
38 pub error: crate::error::SynwireError,
40 pub elapsed: Duration,
42}
43
44use crate::BoxFuture;
47use crate::error::SynwireError;
48use crate::runnables::config::RunnableConfig;
49use crate::runnables::core::RunnableCore;
50use serde_json::Value;
51
52pub struct RunnableRetry {
58 inner: Box<dyn RunnableCore>,
59 config: RetryConfig,
60}
61
62impl RunnableRetry {
63 pub fn new(inner: Box<dyn RunnableCore>, config: RetryConfig) -> Self {
65 Self { inner, config }
66 }
67
68 fn should_retry(&self, err: &SynwireError) -> bool {
70 if self.config.retry_on.is_empty() {
71 return true;
72 }
73 self.config.retry_on.contains(&err.kind())
74 }
75
76 fn backoff_duration(&self, attempt: u32) -> Duration {
78 let base = self
79 .config
80 .initial_interval
81 .saturating_mul(1u32.checked_shl(attempt).unwrap_or(u32::MAX));
82 base.min(self.config.max_interval)
83 }
84}
85
86impl RunnableCore for RunnableRetry {
87 fn invoke<'a>(
88 &'a self,
89 input: Value,
90 config: Option<&'a RunnableConfig>,
91 ) -> BoxFuture<'a, Result<Value, SynwireError>> {
92 Box::pin(async move {
93 let mut last_error: Option<SynwireError> = None;
94
95 for attempt in 0..self.config.max_attempts {
96 match self.inner.invoke(input.clone(), config).await {
97 Ok(v) => return Ok(v),
98 Err(e) => {
99 if !self.should_retry(&e) || attempt + 1 >= self.config.max_attempts {
100 return Err(e);
101 }
102 let delay = self.backoff_duration(attempt);
103 tokio::time::sleep(delay).await;
104 last_error = Some(e);
105 }
106 }
107 }
108
109 Err(last_error
111 .unwrap_or_else(|| SynwireError::Other("retry exhausted with no attempts".into())))
112 })
113 }
114
115 #[allow(clippy::unnecessary_literal_bound)]
116 fn name(&self) -> &str {
117 "RunnableRetry"
118 }
119}
120
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::runnables::lambda::RunnableLambda;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// A transient failure that clears on the third call is retried until it succeeds.
    #[tokio::test]
    async fn test_retry_on_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let count = Arc::clone(&call_count);

        let flaky = RunnableLambda::new(move |v: Value| {
            let count = Arc::clone(&count);
            Box::pin(async move {
                let n = count.fetch_add(1, Ordering::SeqCst);
                if n < 2 {
                    Err(SynwireError::Other("transient".into()))
                } else {
                    Ok(v)
                }
            })
        });

        let retry_config = RetryConfig {
            max_attempts: 5,
            initial_interval: Duration::from_millis(1),
            max_interval: Duration::from_millis(10),
            ..RetryConfig::default()
        };

        let retried = RunnableRetry::new(Box::new(flaky), retry_config);
        let result = retried.invoke(Value::from("ok"), None).await.unwrap();
        assert_eq!(result, Value::from("ok"));
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    /// A permanently failing runnable is called exactly `max_attempts` times, and its
    /// last error is surfaced.
    #[tokio::test]
    async fn test_retry_respects_max_attempts() {
        let call_count = Arc::new(AtomicU32::new(0));
        let count = Arc::clone(&call_count);

        let always_fail = RunnableLambda::new(move |_: Value| {
            let count = Arc::clone(&count);
            Box::pin(async move {
                count.fetch_add(1, Ordering::SeqCst);
                Err(SynwireError::Other("always fails".into()))
            })
        });

        let retry_config = RetryConfig {
            max_attempts: 2,
            initial_interval: Duration::from_millis(1),
            max_interval: Duration::from_millis(1),
            ..RetryConfig::default()
        };

        let retried = RunnableRetry::new(Box::new(always_fail), retry_config);
        let result = retried.invoke(Value::from("input"), None).await;
        assert!(matches!(result, Err(SynwireError::Other(_))));
        // One initial call plus one retry, and no more.
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    /// Errors whose kind is not listed in `retry_on` are returned after a single call.
    #[tokio::test]
    async fn test_retry_skips_non_matching_errors() {
        let call_count = Arc::new(AtomicU32::new(0));
        let count = Arc::clone(&call_count);

        let prompt_err = RunnableLambda::new(move |_: Value| {
            let count = Arc::clone(&count);
            Box::pin(async move {
                count.fetch_add(1, Ordering::SeqCst);
                Err(SynwireError::Prompt {
                    message: "bad prompt".into(),
                })
            })
        });

        let retry_config = RetryConfig {
            retry_on: vec![SynwireErrorKind::Model],
            max_attempts: 3,
            initial_interval: Duration::from_millis(1),
            max_interval: Duration::from_millis(1),
            ..RetryConfig::default()
        };

        let retried = RunnableRetry::new(Box::new(prompt_err), retry_config);
        let result = retried.invoke(Value::from("input"), None).await;
        assert!(matches!(result, Err(SynwireError::Prompt { .. })));
        // Only `Model` errors are retryable here, so the prompt error must not be retried.
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    /// With `max_attempts == 0`, the inner runnable is never called and an error is
    /// returned.
    #[tokio::test]
    async fn test_zero_max_attempts_never_invokes_inner() {
        let call_count = Arc::new(AtomicU32::new(0));
        let count = Arc::clone(&call_count);

        let counted = RunnableLambda::new(move |_: Value| {
            let count = Arc::clone(&count);
            Box::pin(async move {
                count.fetch_add(1, Ordering::SeqCst);
                Err(SynwireError::Other("should not run".into()))
            })
        });

        let retry_config = RetryConfig {
            max_attempts: 0,
            ..RetryConfig::default()
        };

        let retried = RunnableRetry::new(Box::new(counted), retry_config);
        let result = retried.invoke(Value::from("input"), None).await;
        assert!(matches!(result, Err(SynwireError::Other(_))));
        assert_eq!(call_count.load(Ordering::SeqCst), 0);
    }

    /// Backoff doubles per attempt and is clamped to `max_interval`, even for huge
    /// attempt indices.
    #[test]
    fn test_backoff_doubles_and_caps() {
        let unused = RunnableLambda::new(|_: Value| {
            Box::pin(async { Err(SynwireError::Other("unused".into())) })
        });

        let retry_config = RetryConfig {
            initial_interval: Duration::from_millis(1),
            max_interval: Duration::from_millis(10),
            ..RetryConfig::default()
        };

        let retried = RunnableRetry::new(Box::new(unused), retry_config);
        assert_eq!(retried.backoff_duration(0), Duration::from_millis(1));
        assert_eq!(retried.backoff_duration(1), Duration::from_millis(2));
        assert_eq!(retried.backoff_duration(3), Duration::from_millis(8));
        // 16 ms is clamped to the 10 ms cap.
        assert_eq!(retried.backoff_duration(4), Duration::from_millis(10));
        // A shift of 32 or more falls back to `u32::MAX` before clamping.
        assert_eq!(retried.backoff_duration(40), Duration::from_millis(10));
    }
}