Skip to main content

tower_retry_plus/
policy.rs

1use crate::backoff::IntervalFunction;
2use std::sync::Arc;
3use std::time::Duration;
4
/// Determines whether an error should be retried.
///
/// The predicate receives a shared reference to the error and returns `true`
/// when the error should be retried, `false` otherwise. It is stored behind
/// an [`Arc`] and bounded by `Send + Sync`, so one predicate can be shared
/// between threads and between clones of the handle.
pub type RetryPredicate<E> = Arc<dyn Fn(&E) -> bool + Send + Sync>;
7
8/// Policy for retry behavior.
9///
10/// This policy combines the interval function (backoff strategy),
11/// maximum attempts, and retry predicate (which errors to retry).
12pub struct RetryPolicy<E> {
13    pub(crate) max_attempts: usize,
14    pub(crate) interval_fn: Arc<dyn IntervalFunction>,
15    pub(crate) retry_predicate: Option<RetryPredicate<E>>,
16}
17
18impl<E> RetryPolicy<E> {
19    /// Creates a new retry policy.
20    pub fn new(max_attempts: usize, interval_fn: Arc<dyn IntervalFunction>) -> Self {
21        Self {
22            max_attempts,
23            interval_fn,
24            retry_predicate: None,
25        }
26    }
27
28    /// Sets a predicate to determine which errors should be retried.
29    pub fn with_retry_predicate<F>(mut self, predicate: F) -> Self
30    where
31        F: Fn(&E) -> bool + Send + Sync + 'static,
32    {
33        self.retry_predicate = Some(Arc::new(predicate));
34        self
35    }
36
37    /// Checks if the given error should be retried.
38    pub fn should_retry(&self, error: &E) -> bool {
39        if let Some(predicate) = &self.retry_predicate {
40            predicate(error)
41        } else {
42            true // Retry all errors by default
43        }
44    }
45
46    /// Computes the delay before the next retry attempt.
47    pub fn next_backoff(&self, attempt: usize) -> Duration {
48        self.interval_fn.next_interval(attempt)
49    }
50}
51
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backoff::FixedInterval;

    /// Minimal error type that carries its own retry verdict.
    #[derive(Debug)]
    struct TestError {
        retryable: bool,
    }

    /// Builds a three-attempt policy whose backoff is a fixed `secs` seconds.
    fn fixed_policy(secs: u64) -> RetryPolicy<TestError> {
        RetryPolicy::new(3, Arc::new(FixedInterval::new(Duration::from_secs(secs))))
    }

    #[test]
    fn test_new_stores_configuration() {
        let policy = fixed_policy(1);

        assert_eq!(policy.max_attempts, 3);
        assert!(policy.retry_predicate.is_none());
    }

    #[test]
    fn test_retry_all_by_default() {
        let policy = fixed_policy(1);

        // With no predicate installed, even a "non-retryable" error is retried.
        let error = TestError { retryable: false };
        assert!(policy.should_retry(&error));
    }

    #[test]
    fn test_retry_predicate() {
        let policy = fixed_policy(1).with_retry_predicate(|e: &TestError| e.retryable);

        assert!(policy.should_retry(&TestError { retryable: true }));
        assert!(!policy.should_retry(&TestError { retryable: false }));
    }

    #[test]
    fn test_backoff_computation() {
        let policy = fixed_policy(2);

        // A fixed interval yields the same delay regardless of the attempt.
        assert_eq!(policy.next_backoff(0), Duration::from_secs(2));
        assert_eq!(policy.next_backoff(1), Duration::from_secs(2));
    }
}