Skip to main content

tower_retry_plus/
lib.rs

1//! Enhanced retry middleware for Tower services.
2//!
3//! This crate provides advanced retry functionality beyond Tower's built-in retry,
4//! with flexible backoff strategies, retry predicates, and a comprehensive event system.
5//!
6//! # Features
7//!
8//! - **IntervalFunction abstraction**: Pluggable backoff strategies
9//!   - Fixed interval
10//!   - Exponential backoff with configurable multiplier
11//!   - Exponential random backoff with randomization factor
12//!   - Custom function-based backoff
13//! - **Retry predicates**: Control which errors should be retried
14//! - **Event system**: Observability through retry events
15//! - **Flexible configuration**: Builder API with sensible defaults
16//!
17//! # Examples
18//!
19//! ```
20//! use tower_retry_plus::RetryConfig;
21//! use tower::ServiceBuilder;
22//! use std::time::Duration;
23//!
24//! # #[derive(Debug, Clone)]
25//! # struct MyError;
26//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
27//! // Create retry configuration with exponential backoff
28//! let retry_config: RetryConfig<MyError> = RetryConfig::builder()
29//!     .max_attempts(5)
30//!     .exponential_backoff(Duration::from_millis(100))
31//!     .on_retry(|attempt, delay| {
32//!         println!("Retry attempt {} after {:?}", attempt, delay);
33//!     })
34//!     .build();
35//!
36//! // Apply to a service
37//! let service = ServiceBuilder::new()
38//!     .layer(retry_config.layer())
39//!     .service(tower::service_fn(|req: String| async move {
40//!         Ok::<_, MyError>(format!("Response: {}", req))
41//!     }));
42//! # Ok(())
43//! # }
44//! ```
45
46mod backoff;
47mod config;
48mod events;
49mod layer;
50mod policy;
51
52pub use backoff::{
53    ExponentialBackoff, ExponentialRandomBackoff, FixedInterval, FnInterval, IntervalFunction,
54};
55pub use config::{RetryConfig, RetryConfigBuilder};
56pub use events::RetryEvent;
57pub use layer::RetryLayer;
58pub use policy::{RetryPolicy, RetryPredicate};
59
60use futures::future::BoxFuture;
61use std::sync::Arc;
62use std::task::{Context, Poll};
63use std::time::Instant;
64use tower::Service;
65
66/// A Tower [`Service`] that retries failed requests.
67///
68/// This service wraps an inner service and automatically retries requests
69/// that fail, according to the configured retry policy and backoff strategy.
70pub struct Retry<S, E> {
71    inner: S,
72    config: Arc<RetryConfig<E>>,
73}
74
75impl<S, E> Retry<S, E> {
76    /// Creates a new `Retry` service wrapping the given service.
77    pub fn new(inner: S, config: Arc<RetryConfig<E>>) -> Self {
78        Self { inner, config }
79    }
80}
81
82impl<S, E> Clone for Retry<S, E>
83where
84    S: Clone,
85{
86    fn clone(&self) -> Self {
87        Self {
88            inner: self.inner.clone(),
89            config: Arc::clone(&self.config),
90        }
91    }
92}
93
94impl<S, Req, E> Service<Req> for Retry<S, E>
95where
96    S: Service<Req, Error = E> + Clone + Send + 'static,
97    S::Future: Send + 'static,
98    Req: Clone + Send + 'static,
99    E: Clone + Send + 'static,
100    S::Response: Send + 'static,
101{
102    type Response = S::Response;
103    type Error = E;
104    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
105
106    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
107        self.inner.poll_ready(cx)
108    }
109
110    fn call(&mut self, req: Req) -> Self::Future {
111        let mut service = self.inner.clone();
112        let config = Arc::clone(&self.config);
113
114        Box::pin(async move {
115            let mut attempt = 0;
116
117            loop {
118                let result = service.call(req.clone()).await;
119
120                match result {
121                    Ok(response) => {
122                        // Success
123                        let event = RetryEvent::Success {
124                            pattern_name: config.name.clone(),
125                            timestamp: Instant::now(),
126                            attempts: attempt + 1,
127                        };
128                        config.event_listeners.emit(&event);
129                        return Ok(response);
130                    }
131                    Err(error) => {
132                        // Check if we should retry this error
133                        if !config.policy.should_retry(&error) {
134                            let event = RetryEvent::IgnoredError {
135                                pattern_name: config.name.clone(),
136                                timestamp: Instant::now(),
137                            };
138                            config.event_listeners.emit(&event);
139                            return Err(error);
140                        }
141
142                        // Check if we've exhausted retries
143                        if attempt + 1 >= config.policy.max_attempts {
144                            let event = RetryEvent::Error {
145                                pattern_name: config.name.clone(),
146                                timestamp: Instant::now(),
147                                attempts: attempt + 1,
148                            };
149                            config.event_listeners.emit(&event);
150                            return Err(error);
151                        }
152
153                        // Calculate backoff and retry
154                        let delay = config.policy.next_backoff(attempt);
155                        let event = RetryEvent::Retry {
156                            pattern_name: config.name.clone(),
157                            timestamp: Instant::now(),
158                            attempt,
159                            delay,
160                        };
161                        config.event_listeners.emit(&event);
162
163                        tokio::time::sleep(delay).await;
164                        attempt += 1;
165                    }
166                }
167            }
168        })
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tower::service_fn;
    use tower::{Layer, ServiceExt};

    /// Minimal cloneable error type returned by the inner test services.
    #[derive(Debug, Clone)]
    struct TestError {
        // Never read directly by the tests, hence the lint allowance.
        #[allow(dead_code)]
        message: String,
    }

    impl TestError {
        fn new(message: &str) -> Self {
            TestError {
                message: message.to_owned(),
            }
        }
    }

    /// A request that succeeds immediately reaches the inner service once.
    #[tokio::test]
    async fn successful_request_no_retry() {
        let calls = Arc::new(AtomicUsize::new(0));

        let inner = {
            let calls = calls.clone();
            service_fn(move |req: String| {
                let calls = calls.clone();
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, TestError>(format!("Response: {}", req))
                }
            })
        };

        let config: RetryConfig<TestError> = RetryConfig::builder()
            .max_attempts(3)
            .fixed_backoff(Duration::from_millis(10))
            .build();
        let mut service = config.layer().layer(inner);

        let response = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();

        assert_eq!(response, "Response: test");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Two failures followed by a success yield the success after three calls.
    #[tokio::test]
    async fn retries_on_failure() {
        let calls = Arc::new(AtomicUsize::new(0));

        let inner = {
            let calls = calls.clone();
            service_fn(move |_req: String| {
                let calls = calls.clone();
                async move {
                    let earlier_calls = calls.fetch_add(1, Ordering::SeqCst);
                    if earlier_calls >= 2 {
                        Ok::<_, TestError>("success".to_string())
                    } else {
                        Err(TestError::new("temporary failure"))
                    }
                }
            })
        };

        let config: RetryConfig<TestError> = RetryConfig::builder()
            .max_attempts(3)
            .fixed_backoff(Duration::from_millis(10))
            .build();
        let mut service = config.layer().layer(inner);

        let response = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();

        assert_eq!(response, "success");
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// A service that always fails is called `max_attempts` times, then errors.
    #[tokio::test]
    async fn exhausts_retries() {
        let calls = Arc::new(AtomicUsize::new(0));

        let inner = {
            let calls = calls.clone();
            service_fn(move |_req: String| {
                let calls = calls.clone();
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err::<String, _>(TestError::new("permanent failure"))
                }
            })
        };

        let config: RetryConfig<TestError> = RetryConfig::builder()
            .max_attempts(3)
            .fixed_backoff(Duration::from_millis(10))
            .build();
        let mut service = config.layer().layer(inner);

        let outcome = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// A predicate that rejects every error stops after the first call.
    #[tokio::test]
    async fn retry_predicate_filters_errors() {
        let calls = Arc::new(AtomicUsize::new(0));

        let inner = {
            let calls = calls.clone();
            service_fn(move |_req: String| {
                let calls = calls.clone();
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err::<String, _>(TestError::new("non-retryable"))
                }
            })
        };

        let config: RetryConfig<TestError> = RetryConfig::builder()
            .max_attempts(3)
            .fixed_backoff(Duration::from_millis(10))
            .retry_on(|_: &TestError| false) // no error is ever retryable
            .build();
        let mut service = config.layer().layer(inner);

        let outcome = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await;

        assert!(outcome.is_err());
        // The single initial call is the only one made.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Retry and success listeners fire once per corresponding event.
    #[tokio::test]
    async fn event_listeners_called() {
        let retries_seen = Arc::new(AtomicUsize::new(0));
        let successes_seen = Arc::new(AtomicUsize::new(0));
        let calls = Arc::new(AtomicUsize::new(0));

        let inner = {
            let calls = calls.clone();
            service_fn(move |_req: String| {
                let calls = calls.clone();
                async move {
                    let earlier_calls = calls.fetch_add(1, Ordering::SeqCst);
                    if earlier_calls >= 2 {
                        Ok::<_, TestError>("success".to_string())
                    } else {
                        Err(TestError::new("temporary"))
                    }
                }
            })
        };

        let retry_hits = retries_seen.clone();
        let success_hits = successes_seen.clone();
        let config: RetryConfig<TestError> = RetryConfig::builder()
            .max_attempts(3)
            .fixed_backoff(Duration::from_millis(10))
            .on_retry(move |_, _| {
                retry_hits.fetch_add(1, Ordering::SeqCst);
            })
            .on_success(move |_| {
                success_hits.fetch_add(1, Ordering::SeqCst);
            })
            .build();
        let mut service = config.layer().layer(inner);

        let _ = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await;

        // Two failed calls lead to two retries, then one success.
        assert_eq!(retries_seen.load(Ordering::SeqCst), 2);
        assert_eq!(successes_seen.load(Ordering::SeqCst), 1);
    }

    /// Exponential backoff starting at 100 ms gives 100, 200 and 400 ms.
    #[tokio::test]
    async fn exponential_backoff_increases_delay() {
        let config: RetryConfig<TestError> = RetryConfig::builder()
            .max_attempts(5)
            .backoff(ExponentialBackoff::new(Duration::from_millis(100)))
            .build();

        for (attempt, expected_millis) in [(0, 100), (1, 200), (2, 400)] {
            assert_eq!(
                config.policy.next_backoff(attempt),
                Duration::from_millis(expected_millis)
            );
        }
    }

    /// A closure-based interval function is consulted with the attempt index.
    #[tokio::test]
    async fn custom_interval_function() {
        let config: RetryConfig<TestError> = RetryConfig::builder()
            .max_attempts(3)
            .backoff(FnInterval::new(|attempt| {
                Duration::from_secs((attempt + 1) as u64)
            }))
            .build();

        for (attempt, expected_secs) in [(0, 1), (1, 2), (2, 3)] {
            assert_eq!(
                config.policy.next_backoff(attempt),
                Duration::from_secs(expected_secs)
            );
        }
    }
}
398        assert_eq!(config.policy.next_backoff(0), Duration::from_secs(1));
399        assert_eq!(config.policy.next_backoff(1), Duration::from_secs(2));
400        assert_eq!(config.policy.next_backoff(2), Duration::from_secs(3));
401    }
402}