// sift_rs/retry.rs
//! Generic retry extension for gRPC wrapper services.
//!
//! This module provides a retry mechanism that can be applied to any wrapper service
//! without modifying the wrapper traits themselves. The retry logic intelligently
//! extracts `tonic::Status` from `sift_error::Error` types to make retry decisions.
//!
//! ## Usage
//!
//! The retry mechanism uses a closure pattern to work around Rust's borrow checker
//! and the `&mut self` requirement of tonic clients. Each retry attempt clones the
//! wrapper and calls the closure, allowing the closure to use `&mut self` internally.
//!
//! ```no_run
//! use sift_rs::retry::{RetryExt, RetryConfig};
//! use sift_rs::wrappers::assets::new_asset_service;
//! use sift_rs::wrappers::assets::AssetServiceWrapper;
//! use std::time::Duration;
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! # let channel = todo!();
//! let wrapper = new_asset_service(channel);
//! let cfg = RetryConfig {
//!     max_attempts: 3,
//!     base_delay: Duration::from_millis(100),
//!     max_delay: Duration::from_secs(5),
//!     backoff_multiplier: 2.0,
//! };
//!
//! let svc = wrapper.retrying(cfg);
//! let asset = svc.call(|mut w| async move {
//!     w.try_get_asset_by_id("asset-123").await
//! }).await?;
//! # Ok(())
//! # }
//! ```
//!
//! ## Important Notes
//!
//! - **Idempotency**: Only use retries for idempotent operations. Non-idempotent
//!   operations may be executed multiple times if retries occur.
//! - **Streaming RPCs**: This retry mechanism does not support streaming RPCs.
//!   Streaming calls require recreating the stream and may have side effects.

use std::error::Error as StdError;
use std::future::Future;
use std::result::Result as StdResult;
use std::time::Duration;

use sift_error::prelude::*;
use tonic;
51
/// Configuration for retry behavior.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of attempts, including the initial call.
    pub max_attempts: usize,
    /// Base delay used as the starting point for exponential backoff.
    pub base_delay: Duration,
    /// Upper bound applied to every computed backoff delay.
    pub max_delay: Duration,
    /// Growth factor for exponential backoff (e.g. 2.0 doubles the delay per attempt).
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// Conservative defaults: 3 total attempts, 100ms base delay,
    /// 5s delay cap, and a 2.0 (doubling) backoff multiplier.
    fn default() -> Self {
        RetryConfig {
            max_attempts: 3,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(5),
            backoff_multiplier: 2.0,
        }
    }
}

impl RetryConfig {
    /// Returns the backoff delay for the given 1-indexed attempt number.
    ///
    /// Computed as `base_delay * backoff_multiplier^(attempt - 1)` and capped
    /// at `max_delay`. Attempts 0 and 1 always yield `base_delay` unscaled.
    pub fn backoff(&self, attempt: usize) -> Duration {
        match attempt {
            0 | 1 => self.base_delay,
            n => {
                let scale = self.backoff_multiplier.powf((n - 1) as f64);
                let scaled_ms = self.base_delay.as_millis() as f64 * scale;
                // `as u64` saturates in Rust, so even an overflowing f64
                // product clamps safely before the min() with max_delay.
                Duration::from_millis(scaled_ms as u64).min(self.max_delay)
            }
        }
    }
}
102
/// Trait for deciding whether a failed call should be attempted again.
pub trait RetryDecider<E> {
    /// Returns `true` if `err` is considered transient and worth retrying.
    fn should_retry(&self, err: &E) -> bool;
}
108
109/// Default retry decider for gRPC errors wrapped in `sift_error::Error`.
110///
111/// This decider uses a two-strategy approach:
112/// 1. First, attempts to extract `tonic::Status` from the error's source chain
113/// 2. Falls back to `ErrorKind`-based heuristics if no `tonic::Status` is found
114pub struct DefaultGrpcRetry;
115
116impl RetryDecider<sift_error::Error> for DefaultGrpcRetry {
117    fn should_retry(&self, err: &sift_error::Error) -> bool {
118        // Strategy 1: Try to extract tonic::Status from error source chain
119        let mut source = err.source();
120        while let Some(err_ref) = source {
121            if let Some(status) = err_ref.downcast_ref::<tonic::Status>() {
122                return matches!(
123                    status.code(),
124                    tonic::Code::Unavailable
125                        | tonic::Code::ResourceExhausted
126                        | tonic::Code::DeadlineExceeded
127                );
128            }
129            source = err_ref.source();
130        }
131
132        // Strategy 2: Fallback to ErrorKind-based heuristics
133        matches!(
134            err.kind(),
135            ErrorKind::GrpcConnectError
136                | ErrorKind::RetrieveAssetError
137                | ErrorKind::RetrieveIngestionConfigError
138                | ErrorKind::RetrieveRunError
139        )
140    }
141}
142
143/// Adapter that wraps a type and provides retry functionality.
144#[derive(Clone, Debug)]
145pub struct Retrying<T, D = DefaultGrpcRetry> {
146    inner: T,
147    cfg: RetryConfig,
148    decider: D,
149}
150
151impl<T> Retrying<T> {
152    /// Creates a new `Retrying` adapter with the default gRPC retry decider.
153    pub fn new(inner: T, cfg: RetryConfig) -> Self {
154        Self {
155            inner,
156            cfg,
157            decider: DefaultGrpcRetry,
158        }
159    }
160}
161
162impl<T, D> Retrying<T, D> {
163    /// Replaces the retry decider with a custom one.
164    pub fn with_decider<D2>(self, decider: D2) -> Retrying<T, D2> {
165        Retrying {
166            inner: self.inner,
167            cfg: self.cfg,
168            decider,
169        }
170    }
171
172    /// Returns a reference to the inner wrapped value.
173    pub fn inner(&self) -> &T {
174        &self.inner
175    }
176}
177
178impl<T, D> Retrying<T, D>
179where
180    T: Clone,
181{
182    /// Executes a closure with retry logic.
183    ///
184    /// The closure is called up to `max_attempts` times. If it returns an error
185    /// and the decider indicates the error is retryable, the function waits for
186    /// the calculated backoff delay before retrying.
187    ///
188    /// # Arguments
189    ///
190    /// * `f` - A closure that takes a cloned wrapper and returns a future
191    ///   that produces a `Result`. The closure can use `&mut self` internally
192    ///   since each attempt gets a fresh clone.
193    ///
194    /// # Returns
195    ///
196    /// Returns `Ok(result)` if any attempt succeeds, or `Err(error)` if all
197    /// attempts fail or the error is not retryable.
198    pub async fn call<F, Fut, R, E>(&self, mut f: F) -> StdResult<R, E>
199    where
200        F: FnMut(T) -> Fut,
201        Fut: Future<Output = StdResult<R, E>>,
202        D: RetryDecider<E>,
203    {
204        let mut last_err = None;
205
206        for attempt in 1..=self.cfg.max_attempts {
207            let wrapper = self.inner.clone();
208            match f(wrapper).await {
209                Ok(result) => return Ok(result),
210                Err(e) => {
211                    last_err = Some(e);
212                    if attempt < self.cfg.max_attempts
213                        && self.decider.should_retry(last_err.as_ref().unwrap())
214                    {
215                        let delay = self.cfg.backoff(attempt);
216                        tokio::time::sleep(delay).await;
217                        continue;
218                    }
219                    break;
220                }
221            }
222        }
223
224        Err(last_err.expect("retry loop invariant violated"))
225    }
226}
227
228/// Extension trait that provides the `.retrying()` method for any type.
229pub trait RetryExt: Sized {
230    /// Wraps `self` in a `Retrying` adapter with the given configuration.
231    fn retrying(self, cfg: RetryConfig) -> Retrying<Self> {
232        Retrying::new(self, cfg)
233    }
234}
235
236impl<T> RetryExt for T {}
237
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Shared config for the async tests: 3 attempts with a short base delay.
    fn fast_cfg() -> RetryConfig {
        RetryConfig {
            max_attempts: 3,
            base_delay: Duration::from_millis(10),
            max_delay: Duration::from_secs(5),
            backoff_multiplier: 2.0,
        }
    }

    #[test]
    fn test_backoff_calculation() {
        let cfg = RetryConfig::default();

        // 100ms base, doubling on each subsequent attempt.
        for (attempt, expected_ms) in [(1, 100), (2, 200), (3, 400), (4, 800)] {
            assert_eq!(cfg.backoff(attempt), Duration::from_millis(expected_ms));
        }
    }

    #[test]
    fn test_backoff_caps_at_max() {
        let cfg = RetryConfig {
            max_attempts: 10,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_millis(500),
            backoff_multiplier: 2.0,
        };

        // Far past the cap: the delay must clamp to max_delay.
        assert_eq!(cfg.backoff(10), Duration::from_millis(500));
    }

    #[tokio::test]
    async fn test_retry_loop_succeeds_after_failures() {
        let calls = Arc::new(AtomicUsize::new(0));
        let retrying = Retrying::new((), fast_cfg());

        let result = retrying
            .call(|_| {
                let calls = calls.clone();
                async move {
                    // Fail the first two attempts, succeed on the third.
                    if calls.fetch_add(1, Ordering::SeqCst) + 1 < 3 {
                        Err::<(), sift_error::Error>(Error::new_msg(
                            ErrorKind::RetrieveAssetError,
                            "temporary failure",
                        ))
                    } else {
                        Ok(())
                    }
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_loop_exhausts_attempts() {
        let calls = Arc::new(AtomicUsize::new(0));
        let retrying = Retrying::new((), fast_cfg());

        let result = retrying
            .call(|_| {
                let calls = calls.clone();
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    // Every attempt fails with a retryable kind.
                    Err::<(), sift_error::Error>(Error::new_msg(
                        ErrorKind::RetrieveAssetError,
                        "persistent failure",
                    ))
                }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_default_grpc_retry_with_tonic_status() {
        let decider = DefaultGrpcRetry;

        // Transient status codes must be retried.
        let transient = [
            tonic::Status::unavailable("service unavailable"),
            tonic::Status::resource_exhausted("resource exhausted"),
            tonic::Status::deadline_exceeded("deadline exceeded"),
        ];
        for status in transient {
            let err = Error::new(ErrorKind::RetrieveAssetError, status);
            assert!(decider.should_retry(&err));
        }

        // Non-transient status codes must not be retried; the status code
        // takes precedence over the ErrorKind.
        let invalid = Error::new(
            ErrorKind::ArgumentValidationError,
            tonic::Status::invalid_argument("invalid argument"),
        );
        assert!(!decider.should_retry(&invalid));

        let missing = Error::new(
            ErrorKind::NotFoundError,
            tonic::Status::not_found("not found"),
        );
        assert!(!decider.should_retry(&missing));
    }

    #[test]
    fn test_default_grpc_retry_with_error_kind_fallback() {
        let decider = DefaultGrpcRetry;

        // Without a tonic::Status in the chain, retryability comes from ErrorKind.
        assert!(decider.should_retry(&Error::new_msg(
            ErrorKind::GrpcConnectError,
            "connection failed"
        )));
        assert!(decider.should_retry(&Error::new_msg(
            ErrorKind::RetrieveAssetError,
            "retrieval failed"
        )));
        assert!(decider.should_retry(&Error::new_msg(
            ErrorKind::RetrieveIngestionConfigError,
            "retrieval failed"
        )));
        assert!(decider.should_retry(&Error::new_msg(
            ErrorKind::RetrieveRunError,
            "retrieval failed"
        )));

        // Non-retryable kinds.
        assert!(!decider.should_retry(&Error::new_msg(
            ErrorKind::ArgumentValidationError,
            "bad argument"
        )));
        assert!(!decider.should_retry(&Error::new_msg(
            ErrorKind::NotFoundError,
            "not found"
        )));
    }

    #[tokio::test]
    async fn test_no_retry_on_non_retryable_error() {
        let calls = Arc::new(AtomicUsize::new(0));
        let retrying = Retrying::new((), fast_cfg());

        let result = retrying
            .call(|_| {
                let calls = calls.clone();
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    // InvalidArgument is never retried.
                    Err::<(), sift_error::Error>(Error::new(
                        ErrorKind::ArgumentValidationError,
                        tonic::Status::invalid_argument("invalid argument"),
                    ))
                }
            })
            .await;

        assert!(result.is_err());
        // Exactly one attempt: the decider rejects the error immediately.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}