Skip to main content

synwire_core/runnables/
retry.rs

1//! Retry configuration types and retry-wrapped runnables.
2
use crate::error::SynwireErrorKind;
use std::collections::hash_map::RandomState;
use std::hash::{BuildHasher, Hasher};
use std::time::Duration;
5
6/// Configuration for retry behaviour.
7#[derive(Debug, Clone)]
8pub struct RetryConfig {
9    /// Error kinds to retry on.
10    pub retry_on: Vec<SynwireErrorKind>,
11    /// Maximum number of attempts.
12    pub max_attempts: u32,
13    /// Whether to use exponential backoff with jitter.
14    pub wait_exponential_jitter: bool,
15    /// Initial interval between retries.
16    pub initial_interval: Duration,
17    /// Maximum interval between retries.
18    pub max_interval: Duration,
19}
20
21impl Default for RetryConfig {
22    fn default() -> Self {
23        Self {
24            retry_on: Vec::new(),
25            max_attempts: 3,
26            wait_exponential_jitter: true,
27            initial_interval: Duration::from_secs(1),
28            max_interval: Duration::from_secs(60),
29        }
30    }
31}
32
33/// State tracked during retry attempts.
34#[derive(Debug)]
35pub struct RetryState {
36    /// Current attempt number.
37    pub attempt: u32,
38    /// The error that triggered the retry.
39    pub error: crate::error::SynwireError,
40    /// Total elapsed time since first attempt.
41    pub elapsed: Duration,
42}
43
44// --- RunnableRetry implementation ---
45
46use crate::BoxFuture;
47use crate::error::SynwireError;
48use crate::runnables::config::RunnableConfig;
49use crate::runnables::core::RunnableCore;
50use serde_json::Value;
51
52/// A runnable that retries an inner runnable on matching errors.
53///
54/// Uses exponential backoff with optional jitter. Only errors whose
55/// [`SynwireErrorKind`] appears in the `retry_on` list are retried;
56/// all other errors propagate immediately.
57pub struct RunnableRetry {
58    inner: Box<dyn RunnableCore>,
59    config: RetryConfig,
60}
61
62impl RunnableRetry {
63    /// Wrap a runnable with retry behaviour.
64    pub fn new(inner: Box<dyn RunnableCore>, config: RetryConfig) -> Self {
65        Self { inner, config }
66    }
67
68    /// Determine whether an error should be retried.
69    fn should_retry(&self, err: &SynwireError) -> bool {
70        if self.config.retry_on.is_empty() {
71            return true;
72        }
73        self.config.retry_on.contains(&err.kind())
74    }
75
76    /// Compute backoff duration for a given attempt (0-indexed).
77    fn backoff_duration(&self, attempt: u32) -> Duration {
78        let base = self
79            .config
80            .initial_interval
81            .saturating_mul(1u32.checked_shl(attempt).unwrap_or(u32::MAX));
82        base.min(self.config.max_interval)
83    }
84}
85
86impl RunnableCore for RunnableRetry {
87    fn invoke<'a>(
88        &'a self,
89        input: Value,
90        config: Option<&'a RunnableConfig>,
91    ) -> BoxFuture<'a, Result<Value, SynwireError>> {
92        Box::pin(async move {
93            let mut last_error: Option<SynwireError> = None;
94
95            for attempt in 0..self.config.max_attempts {
96                match self.inner.invoke(input.clone(), config).await {
97                    Ok(v) => return Ok(v),
98                    Err(e) => {
99                        if !self.should_retry(&e) || attempt + 1 >= self.config.max_attempts {
100                            return Err(e);
101                        }
102                        let delay = self.backoff_duration(attempt);
103                        tokio::time::sleep(delay).await;
104                        last_error = Some(e);
105                    }
106                }
107            }
108
109            // This branch is only reachable if max_attempts == 0.
110            Err(last_error
111                .unwrap_or_else(|| SynwireError::Other("retry exhausted with no attempts".into())))
112        })
113    }
114
115    #[allow(clippy::unnecessary_literal_bound)]
116    fn name(&self) -> &str {
117        "RunnableRetry"
118    }
119}
120
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::runnables::lambda::RunnableLambda;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[tokio::test]
    async fn test_retry_on_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let count = Arc::clone(&call_count);

        let flaky = RunnableLambda::new(move |v: Value| {
            let count = Arc::clone(&count);
            Box::pin(async move {
                let n = count.fetch_add(1, Ordering::SeqCst);
                if n < 2 {
                    Err(SynwireError::Other("transient".into()))
                } else {
                    Ok(v)
                }
            })
        });

        let retry_config = RetryConfig {
            max_attempts: 5,
            initial_interval: Duration::from_millis(1),
            max_interval: Duration::from_millis(10),
            ..RetryConfig::default()
        };

        let retried = RunnableRetry::new(Box::new(flaky), retry_config);
        let result = retried.invoke(Value::from("ok"), None).await.unwrap();
        assert_eq!(result, Value::from("ok"));
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_respects_max_attempts() {
        let call_count = Arc::new(AtomicU32::new(0));
        let count = Arc::clone(&call_count);

        let always_fail = RunnableLambda::new(move |_: Value| {
            let count = Arc::clone(&count);
            Box::pin(async move {
                count.fetch_add(1, Ordering::SeqCst);
                Err(SynwireError::Other("always fails".into()))
            })
        });

        let retry_config = RetryConfig {
            max_attempts: 2,
            initial_interval: Duration::from_millis(1),
            max_interval: Duration::from_millis(1),
            ..RetryConfig::default()
        };

        let retried = RunnableRetry::new(Box::new(always_fail), retry_config);
        let result = retried.invoke(Value::from("input"), None).await;
        assert!(matches!(result, Err(SynwireError::Other(_))));
        // Exactly `max_attempts` calls: no more, no fewer.
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_skips_non_matching_errors() {
        let call_count = Arc::new(AtomicU32::new(0));
        let count = Arc::clone(&call_count);

        let prompt_err = RunnableLambda::new(move |_: Value| {
            let count = Arc::clone(&count);
            Box::pin(async move {
                count.fetch_add(1, Ordering::SeqCst);
                Err(SynwireError::Prompt {
                    message: "bad prompt".into(),
                })
            })
        });

        let retry_config = RetryConfig {
            retry_on: vec![SynwireErrorKind::Model], // only retry model errors
            max_attempts: 3,
            initial_interval: Duration::from_millis(1),
            max_interval: Duration::from_millis(1),
            ..RetryConfig::default()
        };

        let retried = RunnableRetry::new(Box::new(prompt_err), retry_config);
        let result = retried.invoke(Value::from("input"), None).await;
        assert!(matches!(result, Err(SynwireError::Prompt { .. })));
        // A non-matching error must propagate after the first call.
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_backoff_grows_exponentially_and_is_capped() {
        let retry_config = RetryConfig {
            wait_exponential_jitter: false,
            initial_interval: Duration::from_millis(10),
            max_interval: Duration::from_millis(50),
            ..RetryConfig::default()
        };
        let noop = RunnableLambda::new(|v: Value| Box::pin(async move { Ok(v) }));
        let retried = RunnableRetry::new(Box::new(noop), retry_config);

        assert_eq!(retried.backoff_duration(0), Duration::from_millis(10));
        assert_eq!(retried.backoff_duration(1), Duration::from_millis(20));
        assert_eq!(retried.backoff_duration(2), Duration::from_millis(40));
        assert_eq!(retried.backoff_duration(3), Duration::from_millis(50));
        // Shift amounts past the integer width must saturate, not panic.
        assert_eq!(retried.backoff_duration(40), Duration::from_millis(50));
    }

    #[test]
    fn test_backoff_with_jitter_stays_within_bounds() {
        let retry_config = RetryConfig {
            wait_exponential_jitter: true,
            initial_interval: Duration::from_millis(10),
            max_interval: Duration::from_millis(50),
            ..RetryConfig::default()
        };
        let noop = RunnableLambda::new(|v: Value| Box::pin(async move { Ok(v) }));
        let retried = RunnableRetry::new(Box::new(noop), retry_config);

        for _ in 0..32 {
            let first = retried.backoff_duration(0);
            assert!(first >= Duration::from_millis(10));
            assert!(first < Duration::from_millis(20));
            assert_eq!(retried.backoff_duration(10), Duration::from_millis(50));
        }
    }
}