Skip to main content

tower_resilience_retry/
lib.rs

1//! Enhanced retry middleware for Tower services.
2//!
3//! This crate provides advanced retry functionality beyond Tower's built-in retry,
4//! with flexible backoff strategies, retry predicates, and comprehensive event system.
5//!
6//! # Features
7//!
8//! - **IntervalFunction abstraction**: Pluggable backoff strategies
9//!   - Fixed interval
10//!   - Exponential backoff with configurable multiplier
11//!   - Exponential random backoff with randomization factor
12//!   - Custom function-based backoff
13//! - **Per-request configuration**: Extract max attempts from the request
14//! - **Retry predicates**: Control which errors should be retried
15//! - **Event system**: Observability through retry events
16//! - **Flexible configuration**: Builder API with sensible defaults
17//!
18//! # Examples
19//!
20//! ## Basic Retry with Exponential Backoff
21//!
22//! ```
23//! use tower_resilience_retry::RetryLayer;
24//! use tower::ServiceBuilder;
25//! use std::time::Duration;
26//!
27//! # #[derive(Debug, Clone)]
28//! # struct MyError;
29//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
30//! // Create retry layer with exponential backoff
31//! let retry_layer = RetryLayer::<String, String, MyError>::builder()
32//!     .max_attempts(5)
33//!     .exponential_backoff(Duration::from_millis(100))
34//!     .on_retry(|attempt, delay| {
35//!         println!("Retry attempt {} after {:?}", attempt, delay);
36//!     })
37//!     .build();
38//!
39//! // Apply to a service
40//! let service = ServiceBuilder::new()
41//!     .layer(retry_layer)
42//!     .service(tower::service_fn(|req: String| async move {
43//!         Ok::<_, MyError>(format!("Response: {}", req))
44//!     }));
45//! # Ok(())
46//! # }
47//! ```
48//!
49//! ## Per-Request Max Attempts
50//!
51//! Extract retry configuration from the request itself:
52//!
53//! ```
54//! use tower_resilience_retry::RetryLayer;
55//! use tower::ServiceBuilder;
56//! use std::time::Duration;
57//!
58//! #[derive(Clone)]
59//! struct MyRequest {
60//!     is_idempotent: bool,
61//!     data: String,
62//! }
63//!
64//! # #[derive(Debug, Clone)]
65//! # struct MyError;
66//! # async fn example() {
67//! // Idempotent requests can retry more aggressively
68//! let retry_layer = RetryLayer::<MyRequest, (), MyError>::builder()
69//!     .max_attempts_fn(|req: &MyRequest| {
70//!         if req.is_idempotent { 5 } else { 1 }
71//!     })
72//!     .exponential_backoff(Duration::from_millis(100))
73//!     .build();
74//! # }
75//! ```
76//!
77//! ## Fallback After Retry Exhaustion
78//!
79//! When retries are exhausted, you can provide a fallback response using standard error handling:
80//!
81//! ```
82//! use tower_resilience_retry::RetryLayer;
83//! use tower::{Service, ServiceBuilder, ServiceExt};
84//! use std::time::Duration;
85//!
86//! # #[derive(Debug, Clone)]
87//! # struct MyError;
88//! # impl std::fmt::Display for MyError {
89//! #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90//! #         write!(f, "MyError")
91//! #     }
92//! # }
93//! # impl std::error::Error for MyError {}
94//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
95//! let retry_layer = RetryLayer::<String, String, MyError>::builder()
96//!     .max_attempts(3)
97//!     .exponential_backoff(Duration::from_millis(100))
98//!     .build();
99//!
100//! let mut service = ServiceBuilder::new()
101//!     .layer(retry_layer)
//!     .service(tower::service_fn(|_req: String| async move {
103//!         Err::<String, MyError>(MyError) // Always fails
104//!     }));
105//!
106//! // Handle retry exhaustion with fallback
107//! let result = service.ready().await?.call("request".to_string()).await
108//!     .unwrap_or_else(|_| "Fallback: Service unavailable".to_string());
109//! # Ok(())
110//! # }
111//! ```
112//!
113//! ### Fallback with Cached Data
114//!
115//! ```
116//! use tower_resilience_retry::RetryLayer;
117//! use tower::{Service, ServiceBuilder, ServiceExt};
118//! use std::time::Duration;
119//! use std::sync::Arc;
120//! use std::collections::HashMap;
121//!
122//! # #[derive(Debug, Clone)]
123//! # struct MyError;
124//! # impl std::fmt::Display for MyError {
125//! #     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126//! #         write!(f, "MyError")
127//! #     }
128//! # }
129//! # impl std::error::Error for MyError {}
130//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
131//! let cache = Arc::new(std::sync::RwLock::new(HashMap::new()));
132//! cache.write().unwrap().insert("key", "cached value");
133//!
134//! let retry_layer = RetryLayer::<String, String, MyError>::builder()
135//!     .max_attempts(3)
136//!     .exponential_backoff(Duration::from_millis(50))
137//!     .build();
138//!
139//! let mut service = ServiceBuilder::new()
140//!     .layer(retry_layer)
//!     .service(tower::service_fn(|_req: String| async move {
142//!         Err::<String, MyError>(MyError)
143//!     }));
144//!
145//! let cache_clone = Arc::clone(&cache);
146//! let result = service.ready().await?.call("key".to_string()).await
147//!     .unwrap_or_else(|_| {
148//!         cache_clone.read().unwrap()
149//!             .get("key")
150//!             .map(|s| s.to_string())
151//!             .unwrap_or_else(|| "Default value".to_string())
152//!     });
153//! # Ok(())
154//! # }
155//! ```
156
157mod backoff;
158mod budget;
159mod config;
160mod events;
161mod layer;
162mod policy;
163
164pub use backoff::{
165    ExponentialBackoff, ExponentialRandomBackoff, FixedInterval, FnInterval, IntervalFunction,
166};
167pub use budget::{AimdBudget, RetryBudget, RetryBudgetBuilder, TokenBucketBudget};
168pub use config::{MaxAttemptsSource, RetryConfig, RetryConfigBuilder};
169pub use events::RetryEvent;
170pub use layer::RetryLayer;
171pub use policy::{ResponsePredicate, RetryPolicy, RetryPredicate};
172
173use futures::future::BoxFuture;
174use std::marker::PhantomData;
175use std::sync::Arc;
176use std::task::{Context, Poll};
177use std::time::Instant;
178use tower::Service;
179
180#[cfg(feature = "metrics")]
181use metrics::{counter, describe_counter, describe_histogram, histogram};
182
183#[cfg(feature = "tracing")]
184use tracing::{debug, info, warn};
185
/// A Tower [`Service`] that retries failed requests.
///
/// This service wraps an inner service and automatically retries requests
/// that fail, according to the configured retry policy and backoff strategy.
pub struct Retry<S, Req, Res, E> {
    /// The wrapped inner service; `call` clones it per request so the
    /// returned future can own it.
    inner: S,
    /// Shared retry configuration: policy, backoff, optional budget, and
    /// event listeners. Shared via `Arc` so clones are cheap.
    config: Arc<RetryConfig<Req, Res, E>>,
    // NOTE(review): `Req` already appears in `config`'s type parameters, so
    // this PhantomData looks redundant — confirm before removing (removal
    // would change the struct literal and `new`'s signature used by layer.rs).
    _phantom: PhantomData<Req>,
}
195
196impl<S, Req, Res, E> Retry<S, Req, Res, E> {
197    /// Creates a new `Retry` service wrapping the given service.
198    pub fn new(
199        inner: S,
200        config: Arc<RetryConfig<Req, Res, E>>,
201        _phantom: PhantomData<Req>,
202    ) -> Self {
203        #[cfg(feature = "metrics")]
204        {
205            describe_counter!(
206                "retry_calls_total",
207                "Total number of retry operations (success or exhausted)"
208            );
209            describe_counter!(
210                "retry_attempts_total",
211                "Total number of retry attempts across all calls"
212            );
213            describe_histogram!("retry_attempts", "Number of attempts per successful call");
214        }
215
216        Self {
217            inner,
218            config,
219            _phantom,
220        }
221    }
222}
223
224impl<S, Req, Res, E> Clone for Retry<S, Req, Res, E>
225where
226    S: Clone,
227{
228    fn clone(&self) -> Self {
229        Self {
230            inner: self.inner.clone(),
231            config: Arc::clone(&self.config),
232            _phantom: PhantomData,
233        }
234    }
235}
236
impl<S, Req, Res, E> Service<Req> for Retry<S, Req, Res, E>
where
    S: Service<Req, Response = Res, Error = E> + Clone + Send + 'static,
    S::Future: Send + 'static,
    Req: Clone + Send + 'static,
    Res: Send + 'static,
    E: Clone + Send + 'static,
{
    type Response = Res;
    type Error = E;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Delegates readiness to the inner service.
    ///
    /// NOTE(review): `call` clones the inner service, so readiness is
    /// checked on the original while the clone does the work — the usual
    /// Tower clone-per-call caveat; confirm the inner service tolerates it.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Runs the retry loop for one request.
    ///
    /// `attempt` is 0-based: the first call is attempt 0, and the loop
    /// stops retrying once `attempt + 1` reaches `max_attempts` (which may
    /// be derived from the request itself). The request is cloned for every
    /// attempt.
    fn call(&mut self, req: Req) -> Self::Future {
        // Clone so the boxed future is 'static and does not borrow `self`.
        let mut service = self.inner.clone();
        let config = Arc::clone(&self.config);

        // Extract max_attempts from request before moving it
        let max_attempts = config.max_attempts_source.get_max_attempts(&req);

        Box::pin(async move {
            let mut attempt = 0;

            loop {
                let result = service.call(req.clone()).await;

                match result {
                    Ok(response) => {
                        // Check if the response should be retried
                        if config.policy.should_retry_response(&response) {
                            // Treat as a retryable failure
                            if attempt + 1 >= max_attempts {
                                #[cfg(feature = "metrics")]
                                {
                                    counter!("retry_calls_total", "retry" => config.name.clone(), "result" => "exhausted").increment(1);
                                }

                                #[cfg(feature = "tracing")]
                                warn!(retry = %config.name, attempts = attempt + 1, max_attempts = max_attempts, "Retry attempts exhausted (response predicate)");

                                let event = RetryEvent::Error {
                                    pattern_name: config.name.clone(),
                                    timestamp: Instant::now(),
                                    attempts: attempt + 1,
                                };
                                config.event_listeners.emit(&event);
                                // NOTE(review): exhaustion on a response-predicate
                                // match emits an Error event but still returns the
                                // last response as Ok — there is no error value to
                                // return here. Confirm callers expect this.
                                return Ok(response);
                            }

                            // Check retry budget if configured
                            if let Some(ref budget) = config.budget {
                                if !budget.try_withdraw() {
                                    #[cfg(feature = "metrics")]
                                    {
                                        counter!("retry_calls_total", "retry" => config.name.clone(), "result" => "budget_exhausted").increment(1);
                                    }

                                    #[cfg(feature = "tracing")]
                                    warn!(retry = %config.name, attempt = attempt + 1, "Retry budget exhausted (response predicate)");

                                    let event = RetryEvent::BudgetExhausted {
                                        pattern_name: config.name.clone(),
                                        timestamp: Instant::now(),
                                        attempt: attempt + 1,
                                    };
                                    config.event_listeners.emit(&event);
                                    // Budget blocked the retry: return the last
                                    // (predicate-matched) response as-is.
                                    return Ok(response);
                                }
                            }

                            // Calculate backoff and retry
                            let delay = config.policy.next_backoff(attempt);

                            #[cfg(feature = "metrics")]
                            {
                                counter!("retry_attempts_total", "retry" => config.name.clone())
                                    .increment(1);
                            }

                            #[cfg(feature = "tracing")]
                            debug!(retry = %config.name, attempt = attempt + 1, delay_ms = delay.as_millis(), "Retrying after response predicate match");

                            // NOTE(review): the Retry event carries the 0-based
                            // `attempt`, while the log above reports `attempt + 1`
                            // — confirm which convention listeners expect.
                            let event = RetryEvent::Retry {
                                pattern_name: config.name.clone(),
                                timestamp: Instant::now(),
                                attempt,
                                delay,
                            };
                            config.event_listeners.emit(&event);

                            tokio::time::sleep(delay).await;
                            attempt += 1;
                            continue;
                        }

                        // Success - deposit to budget if configured
                        if let Some(ref budget) = config.budget {
                            budget.deposit();
                        }

                        #[cfg(feature = "metrics")]
                        {
                            counter!("retry_calls_total", "retry" => config.name.clone(), "result" => "success").increment(1);
                            histogram!("retry_attempts", "retry" => config.name.clone())
                                .record((attempt + 1) as f64);
                        }

                        #[cfg(feature = "tracing")]
                        {
                            if attempt > 0 {
                                info!(retry = %config.name, attempts = attempt + 1, "Request succeeded after retries");
                            } else {
                                debug!(retry = %config.name, "Request succeeded on first attempt");
                            }
                        }

                        let event = RetryEvent::Success {
                            pattern_name: config.name.clone(),
                            timestamp: Instant::now(),
                            attempts: attempt + 1,
                        };
                        config.event_listeners.emit(&event);
                        return Ok(response);
                    }
                    Err(error) => {
                        // Check if we should retry this error
                        if !config.policy.should_retry(&error) {
                            #[cfg(feature = "tracing")]
                            debug!(retry = %config.name, "Error not retryable, failing immediately");

                            let event = RetryEvent::IgnoredError {
                                pattern_name: config.name.clone(),
                                timestamp: Instant::now(),
                            };
                            config.event_listeners.emit(&event);
                            return Err(error);
                        }

                        // Check if we've exhausted retries (use per-request max_attempts)
                        if attempt + 1 >= max_attempts {
                            #[cfg(feature = "metrics")]
                            {
                                counter!("retry_calls_total", "retry" => config.name.clone(), "result" => "exhausted").increment(1);
                            }

                            #[cfg(feature = "tracing")]
                            warn!(retry = %config.name, attempts = attempt + 1, max_attempts = max_attempts, "Retry attempts exhausted");

                            let event = RetryEvent::Error {
                                pattern_name: config.name.clone(),
                                timestamp: Instant::now(),
                                attempts: attempt + 1,
                            };
                            config.event_listeners.emit(&event);
                            return Err(error);
                        }

                        // Check retry budget if configured
                        if let Some(ref budget) = config.budget {
                            if !budget.try_withdraw() {
                                #[cfg(feature = "metrics")]
                                {
                                    counter!("retry_calls_total", "retry" => config.name.clone(), "result" => "budget_exhausted").increment(1);
                                }

                                #[cfg(feature = "tracing")]
                                warn!(retry = %config.name, attempt = attempt + 1, "Retry budget exhausted, failing immediately");

                                let event = RetryEvent::BudgetExhausted {
                                    pattern_name: config.name.clone(),
                                    timestamp: Instant::now(),
                                    attempt: attempt + 1,
                                };
                                config.event_listeners.emit(&event);
                                return Err(error);
                            }
                        }

                        // Calculate backoff and retry
                        let delay = config.policy.next_backoff(attempt);

                        #[cfg(feature = "metrics")]
                        {
                            counter!("retry_attempts_total", "retry" => config.name.clone())
                                .increment(1);
                        }

                        #[cfg(feature = "tracing")]
                        debug!(retry = %config.name, attempt = attempt + 1, delay_ms = delay.as_millis(), "Retrying after delay");

                        // Same 0-based `attempt` convention as the response-
                        // predicate branch above.
                        let event = RetryEvent::Retry {
                            pattern_name: config.name.clone(),
                            timestamp: Instant::now(),
                            attempt,
                            delay,
                        };
                        config.event_listeners.emit(&event);

                        tokio::time::sleep(delay).await;
                        attempt += 1;
                    }
                }
            }
        })
    }
}
446
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tower::service_fn;
    use tower::{Layer, ServiceExt};

    /// Minimal cloneable error type for exercising the retry machinery.
    #[derive(Debug, Clone)]
    struct TestError {
        #[allow(dead_code)]
        message: String,
    }

    impl TestError {
        fn new(message: &str) -> Self {
            Self {
                message: message.to_string(),
            }
        }
    }

    /// A service that succeeds immediately should be called exactly once.
    #[tokio::test]
    async fn successful_request_no_retry() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, TestError>(format!("Response: {}", req))
            }
        });

        let layer = RetryLayer::<String, String, TestError>::builder()
            .max_attempts(3)
            .fixed_backoff(Duration::from_millis(10))
            .build();

        let mut service = layer.layer(service);

        let response = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();

        assert_eq!(response, "Response: test");
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    /// Fails twice then succeeds: with max_attempts(3) the third call wins.
    #[tokio::test]
    async fn retries_on_failure() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |_req: String| {
            let cc = Arc::clone(&cc);
            async move {
                // fetch_add returns the *previous* value, so the first two
                // calls (0 and 1) fail and the third (2) succeeds.
                let count = cc.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err(TestError::new("temporary failure"))
                } else {
                    Ok::<_, TestError>("success".to_string())
                }
            }
        });

        let layer = RetryLayer::<String, String, TestError>::builder()
            .max_attempts(3)
            .fixed_backoff(Duration::from_millis(10))
            .build();

        let mut service = layer.layer(service);

        let response = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await
            .unwrap();

        assert_eq!(response, "success");
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    /// A permanently failing service is called exactly max_attempts times
    /// before the final error is surfaced.
    #[tokio::test]
    async fn exhausts_retries() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |_req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(TestError::new("permanent failure"))
            }
        });

        let layer = RetryLayer::<String, String, TestError>::builder()
            .max_attempts(3)
            .fixed_backoff(Duration::from_millis(10))
            .build();

        let mut service = layer.layer(service);

        let result = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    /// A retry predicate that rejects every error short-circuits the loop:
    /// the service is called once and the error is returned immediately.
    #[tokio::test]
    async fn retry_predicate_filters_errors() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |_req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(TestError::new("non-retryable"))
            }
        });

        let layer = RetryLayer::<String, String, TestError>::builder()
            .max_attempts(3)
            .fixed_backoff(Duration::from_millis(10))
            .retry_on(|_: &TestError| false) // Never retry
            .build();

        let mut service = layer.layer(service);

        let result = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1); // Only called once
    }

    /// on_retry fires once per retry and on_success once per overall success.
    #[tokio::test]
    async fn event_listeners_called() {
        let retry_count = Arc::new(AtomicUsize::new(0));
        let success_count = Arc::new(AtomicUsize::new(0));

        let rc = Arc::clone(&retry_count);
        let sc = Arc::clone(&success_count);

        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |_req: String| {
            let cc = Arc::clone(&cc);
            async move {
                let count = cc.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err(TestError::new("temporary"))
                } else {
                    Ok::<_, TestError>("success".to_string())
                }
            }
        });

        let layer = RetryLayer::<String, String, TestError>::builder()
            .max_attempts(3)
            .fixed_backoff(Duration::from_millis(10))
            .on_retry(move |_, _| {
                rc.fetch_add(1, Ordering::SeqCst);
            })
            .on_success(move |_| {
                sc.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        let mut service = layer.layer(service);

        let _ = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await;

        assert_eq!(retry_count.load(Ordering::SeqCst), 2); // 2 retries
        assert_eq!(success_count.load(Ordering::SeqCst), 1); // 1 success
    }

    /// With a single-token budget, only one retry is permitted even though
    /// max_attempts would allow four; the second retry is blocked and the
    /// budget-exhausted listener fires.
    #[tokio::test]
    async fn budget_limits_retries() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let budget_exhausted_count = Arc::new(AtomicUsize::new(0));

        let cc = Arc::clone(&call_count);
        let bec = Arc::clone(&budget_exhausted_count);

        // Create a budget with only 1 token
        let budget = RetryBudgetBuilder::new()
            .token_bucket()
            .max_tokens(1)
            .initial_tokens(1)
            .build();

        let service = service_fn(move |_req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(TestError::new("always fails"))
            }
        });

        let layer = RetryLayer::<String, String, TestError>::builder()
            .max_attempts(5)
            .fixed_backoff(Duration::from_millis(1))
            .budget(budget)
            .on_budget_exhausted(move |_| {
                bec.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        let mut service = layer.layer(service);

        let result = service
            .ready()
            .await
            .unwrap()
            .call("test".to_string())
            .await;

        assert!(result.is_err());
        // Should have called twice: 1 initial + 1 retry (budget allows 1 retry)
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
        // Budget exhausted should be called once (when 2nd retry was blocked)
        assert_eq!(budget_exhausted_count.load(Ordering::SeqCst), 1);
    }

    /// Exercises the budget directly (no service): deposits replenish
    /// tokens that withdrawals consume.
    #[tokio::test]
    async fn budget_replenishes_on_success() {
        let budget = RetryBudgetBuilder::new()
            .token_bucket()
            .max_tokens(10)
            .initial_tokens(0) // Start empty
            .build();

        // Budget starts empty
        assert_eq!(budget.balance(), 0);
        assert!(!budget.try_withdraw());

        // Deposit (simulating successful request)
        budget.deposit();
        assert_eq!(budget.balance(), 1);

        // Now withdrawal should work
        assert!(budget.try_withdraw());
        assert_eq!(budget.balance(), 0);
    }

    /// max_attempts_fn derives the attempt limit from each request:
    /// non-idempotent requests get 1 attempt, idempotent ones get 5.
    #[tokio::test]
    async fn per_request_max_attempts() {
        #[derive(Clone)]
        struct Request {
            is_idempotent: bool,
        }

        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |_req: Request| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>(TestError::new("always fails"))
            }
        });

        let layer = RetryLayer::<Request, String, TestError>::builder()
            .max_attempts_fn(|req: &Request| if req.is_idempotent { 5 } else { 1 })
            .fixed_backoff(Duration::from_millis(1))
            .build();

        let mut service = layer.layer(service);

        // Non-idempotent request - should only try once
        call_count.store(0, Ordering::SeqCst);
        let _ = service
            .ready()
            .await
            .unwrap()
            .call(Request {
                is_idempotent: false,
            })
            .await;
        assert_eq!(call_count.load(Ordering::SeqCst), 1);

        // Idempotent request - should try 5 times
        call_count.store(0, Ordering::SeqCst);
        let _ = service
            .ready()
            .await
            .unwrap()
            .call(Request {
                is_idempotent: true,
            })
            .await;
        assert_eq!(call_count.load(Ordering::SeqCst), 5);
    }

    /// Per-request limits interact with success timing: a request succeeds
    /// if its own max_retries budget reaches its succeed_on_attempt, and
    /// fails otherwise (succeed_on_attempt is compared against the 0-based
    /// previous call count, so success needs succeed_on_attempt + 1 calls).
    #[tokio::test]
    async fn per_request_max_attempts_with_success() {
        #[derive(Clone)]
        struct Request {
            max_retries: usize,
            succeed_on_attempt: usize,
        }

        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: Request| {
            let cc = Arc::clone(&cc);
            async move {
                let attempt = cc.fetch_add(1, Ordering::SeqCst);
                if attempt >= req.succeed_on_attempt {
                    Ok::<_, TestError>("success".to_string())
                } else {
                    Err(TestError::new("not yet"))
                }
            }
        });

        let layer = RetryLayer::<Request, String, TestError>::builder()
            .max_attempts_fn(|req: &Request| req.max_retries)
            .fixed_backoff(Duration::from_millis(1))
            .build();

        let mut service = layer.layer(service);

        // Request that succeeds on 3rd attempt with 5 max retries
        call_count.store(0, Ordering::SeqCst);
        let result = service
            .ready()
            .await
            .unwrap()
            .call(Request {
                max_retries: 5,
                succeed_on_attempt: 2,
            })
            .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);

        // Request that would need 3 attempts but only has 2 max
        call_count.store(0, Ordering::SeqCst);
        let result = service
            .ready()
            .await
            .unwrap()
            .call(Request {
                max_retries: 2,
                succeed_on_attempt: 2,
            })
            .await;
        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    // Note: Backoff behavior is tested in tests/retry/retry_backoff.rs
}