Skip to main content

synaptic_runnables/
retry.rs

1use std::time::Duration;
2
3use async_trait::async_trait;
4use synaptic_core::{RunnableConfig, SynapseError};
5
6use crate::runnable::{BoxRunnable, Runnable};
7
8/// Retry policy configuration for `RunnableRetry`.
9///
10/// Controls how many times to retry, the backoff schedule, and which errors
11/// are eligible for retrying.
12pub struct RetryPolicy {
13    /// Maximum number of attempts (including the initial attempt).
14    pub max_attempts: usize,
15    /// Base delay for exponential backoff. The actual delay for attempt `n` is
16    /// `min(base_delay * 2^n, max_delay)`.
17    pub base_delay: Duration,
18    /// Upper bound on the backoff delay.
19    pub max_delay: Duration,
20    /// Optional predicate to decide if an error is retryable.
21    /// When `None`, all errors are retried.
22    #[allow(clippy::type_complexity)]
23    retry_on: Option<Box<dyn Fn(&SynapseError) -> bool + Send + Sync>>,
24}
25
26impl std::fmt::Debug for RetryPolicy {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct("RetryPolicy")
29            .field("max_attempts", &self.max_attempts)
30            .field("base_delay", &self.base_delay)
31            .field("max_delay", &self.max_delay)
32            .field("retry_on", &self.retry_on.as_ref().map(|_| "..."))
33            .finish()
34    }
35}
36
37impl Default for RetryPolicy {
38    fn default() -> Self {
39        Self {
40            max_attempts: 3,
41            base_delay: Duration::from_millis(100),
42            max_delay: Duration::from_secs(10),
43            retry_on: None,
44        }
45    }
46}
47
48impl RetryPolicy {
49    /// Set the maximum number of attempts (including the initial attempt).
50    pub fn with_max_attempts(mut self, max_attempts: usize) -> Self {
51        self.max_attempts = max_attempts;
52        self
53    }
54
55    /// Set the base delay for exponential backoff.
56    pub fn with_base_delay(mut self, base_delay: Duration) -> Self {
57        self.base_delay = base_delay;
58        self
59    }
60
61    /// Set the upper bound on the backoff delay.
62    pub fn with_max_delay(mut self, max_delay: Duration) -> Self {
63        self.max_delay = max_delay;
64        self
65    }
66
67    /// Set a predicate to decide which errors are retryable.
68    /// When not set, all errors are retried.
69    pub fn with_retry_on(
70        mut self,
71        predicate: impl Fn(&SynapseError) -> bool + Send + Sync + 'static,
72    ) -> Self {
73        self.retry_on = Some(Box::new(predicate));
74        self
75    }
76
77    /// Compute the backoff delay for the given attempt (0-indexed).
78    fn delay_for_attempt(&self, attempt: usize) -> Duration {
79        let delay = self.base_delay.saturating_mul(1 << attempt);
80        std::cmp::min(delay, self.max_delay)
81    }
82
83    /// Check whether the given error should be retried.
84    fn should_retry(&self, error: &SynapseError) -> bool {
85        match &self.retry_on {
86            Some(predicate) => predicate(error),
87            None => true,
88        }
89    }
90}
91
92/// Wraps a runnable with configurable retry logic and exponential backoff.
93///
94/// The input type must be `Clone` because the input is re-used for each retry attempt.
95///
96/// ```ignore
97/// let policy = RetryPolicy::default()
98///     .with_max_attempts(5)
99///     .with_base_delay(Duration::from_millis(200));
100/// let retrying = RunnableRetry::new(flaky_step.boxed(), policy);
101/// let result = retrying.invoke(input, &config).await?;
102/// ```
103pub struct RunnableRetry<I: Send + Clone + 'static, O: Send + 'static> {
104    inner: BoxRunnable<I, O>,
105    policy: RetryPolicy,
106}
107
108impl<I: Send + Clone + 'static, O: Send + 'static> RunnableRetry<I, O> {
109    pub fn new(inner: BoxRunnable<I, O>, policy: RetryPolicy) -> Self {
110        Self { inner, policy }
111    }
112}
113
114#[async_trait]
115impl<I: Send + Clone + 'static, O: Send + 'static> Runnable<I, O> for RunnableRetry<I, O> {
116    async fn invoke(&self, input: I, config: &RunnableConfig) -> Result<O, SynapseError> {
117        let mut last_error: Option<SynapseError> = None;
118
119        for attempt in 0..self.policy.max_attempts {
120            let input_clone = input.clone();
121            match self.inner.invoke(input_clone, config).await {
122                Ok(output) => return Ok(output),
123                Err(e) => {
124                    let is_last_attempt = attempt + 1 >= self.policy.max_attempts;
125                    if is_last_attempt || !self.policy.should_retry(&e) {
126                        return Err(e);
127                    }
128
129                    let delay = self.policy.delay_for_attempt(attempt);
130                    tokio::time::sleep(delay).await;
131                    last_error = Some(e);
132                }
133            }
134        }
135
136        // This is only reached when max_attempts is 0.
137        Err(last_error.unwrap_or_else(|| {
138            SynapseError::Config("RunnableRetry: max_attempts must be >= 1".into())
139        }))
140    }
141}