tower_async/util/backoff/
exponential.rs

1use std::sync::Arc;
2use std::time::Duration;
3use std::{fmt::Display, sync::Mutex};
4use tokio::time;
5
6use crate::util::rng::{HasherRng, Rng};
7
8use super::{Backoff, MakeBackoff};
9
10/// A maker type for [`ExponentialBackoff`].
11#[derive(Debug, Clone)]
12pub struct ExponentialBackoffMaker<R = HasherRng> {
13    /// The minimum amount of time to wait before resuming an operation.
14    min: time::Duration,
15    /// The maximum amount of time to wait before resuming an operation.
16    max: time::Duration,
17    /// The ratio of the base timeout that may be randomly added to a backoff.
18    ///
19    /// Must be greater than or equal to 0.0.
20    jitter: f64,
21    rng: R,
22}
23
24/// A jittered [exponential backoff] strategy.
25///
26/// The backoff duration will increase exponentially for every subsequent
27/// backoff, up to a maximum duration. A small amount of [random jitter] is
28/// added to each backoff duration, in order to avoid retry spikes.
29///
30/// [exponential backoff]: https://en.wikipedia.org/wiki/Exponential_backoff
31/// [random jitter]: https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
32#[derive(Debug, Clone)]
33pub struct ExponentialBackoff<R = HasherRng> {
34    min: time::Duration,
35    max: time::Duration,
36    jitter: f64,
37    state: Arc<Mutex<ExponentialBackoffState<R>>>,
38}
39
/// Mutable part of an [`ExponentialBackoff`], stored behind its `Arc<Mutex<_>>`.
#[derive(Debug, Clone)]
struct ExponentialBackoffState<R = HasherRng> {
    /// Random number generator used to draw the jitter factor.
    rng: R,
    /// Number of backoffs performed so far; the exponent of the base delay.
    iterations: u32,
}
45
46impl<R> ExponentialBackoffMaker<R>
47where
48    R: Rng,
49{
50    /// Create a new `ExponentialBackoff`.
51    ///
52    /// # Error
53    ///
54    /// Returns a config validation error if:
55    /// - `min` > `max`
56    /// - `max` > 0
57    /// - `jitter` >= `0.0`
58    /// - `jitter` < `100.0`
59    /// - `jitter` is finite
60    pub fn new(
61        min: time::Duration,
62        max: time::Duration,
63        jitter: f64,
64        rng: R,
65    ) -> Result<Self, InvalidBackoff> {
66        if min > max {
67            return Err(InvalidBackoff("maximum must not be less than minimum"));
68        }
69        if max == time::Duration::from_millis(0) {
70            return Err(InvalidBackoff("maximum must be non-zero"));
71        }
72        if jitter < 0.0 {
73            return Err(InvalidBackoff("jitter must not be negative"));
74        }
75        if jitter > 100.0 {
76            return Err(InvalidBackoff("jitter must not be greater than 100"));
77        }
78        if !jitter.is_finite() {
79            return Err(InvalidBackoff("jitter must be finite"));
80        }
81
82        Ok(ExponentialBackoffMaker {
83            min,
84            max,
85            jitter,
86            rng,
87        })
88    }
89}
90
91impl<R> MakeBackoff for ExponentialBackoffMaker<R>
92where
93    R: Rng + Clone,
94{
95    type Backoff = ExponentialBackoff<R>;
96
97    fn make_backoff(&self) -> Self::Backoff {
98        ExponentialBackoff {
99            max: self.max,
100            min: self.min,
101            jitter: self.jitter,
102            state: Arc::new(Mutex::new(ExponentialBackoffState {
103                rng: self.rng.clone(),
104                iterations: 0,
105            })),
106        }
107    }
108}
109
110impl<R: Rng> ExponentialBackoff<R> {
111    fn base(&self) -> time::Duration {
112        debug_assert!(
113            self.min <= self.max,
114            "maximum backoff must not be less than minimum backoff"
115        );
116        debug_assert!(
117            self.max > time::Duration::from_millis(0),
118            "Maximum backoff must be non-zero"
119        );
120        self.min
121            .checked_mul(2_u32.saturating_pow(self.state.lock().unwrap().iterations))
122            .unwrap_or(self.max)
123            .min(self.max)
124    }
125
126    /// Returns a random, uniform duration on `[0, base*self.jitter]` no greater
127    /// than `self.max`.
128    fn jitter(&self, base: time::Duration) -> time::Duration {
129        if self.jitter == 0.0 {
130            time::Duration::default()
131        } else {
132            let jitter_factor = self.state.lock().unwrap().rng.next_f64();
133            debug_assert!(
134                jitter_factor > 0.0,
135                "rng returns values between 0.0 and 1.0"
136            );
137            let rand_jitter = jitter_factor * self.jitter;
138            let secs = (base.as_secs() as f64) * rand_jitter;
139            let nanos = (base.subsec_nanos() as f64) * rand_jitter;
140            let remaining = self.max - base;
141            time::Duration::new(secs as u64, nanos as u32).min(remaining)
142        }
143    }
144}
145
146impl<R> Backoff for ExponentialBackoff<R>
147where
148    R: Rng,
149{
150    async fn next_backoff(&self) {
151        let base = self.base();
152        let next = base + self.jitter(base);
153
154        self.state.lock().unwrap().iterations += 1;
155
156        tokio::time::sleep(next).await
157    }
158}
159
160impl Default for ExponentialBackoffMaker {
161    fn default() -> Self {
162        ExponentialBackoffMaker::new(
163            Duration::from_millis(50),
164            Duration::from_millis(u64::MAX),
165            0.99,
166            HasherRng::default(),
167        )
168        .expect("Unable to create ExponentialBackoff")
169    }
170}
171
/// Error describing why a backoff configuration failed validation.
///
/// Wraps a static, human-readable reason.
#[derive(Debug)]
pub struct InvalidBackoff(&'static str);

impl Display for InvalidBackoff {
    /// Writes the reason prefixed with `invalid backoff: `.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(reason) = self;
        write!(formatter, "invalid backoff: {}", reason)
    }
}

impl std::error::Error for InvalidBackoff {}
183
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::*;

    quickcheck! {
        // With no iterations recorded, the base delay is exactly `min`.
        fn backoff_base_first(min_ms: u64, max_ms: u64) -> TestResult {
            let min = time::Duration::from_millis(min_ms);
            let max = time::Duration::from_millis(max_ms);
            // Configurations rejected by the constructor are not interesting here.
            let Ok(maker) = ExponentialBackoffMaker::new(min, max, 0.0, HasherRng::default()) else {
                return TestResult::discard();
            };

            let delay = maker.make_backoff().base();
            TestResult::from_bool(delay == min)
        }

        // For any iteration count, the base delay stays within `[min, max]`.
        fn backoff_base(min_ms: u64, max_ms: u64, iterations: u32) -> TestResult {
            let min = time::Duration::from_millis(min_ms);
            let max = time::Duration::from_millis(max_ms);
            let Ok(maker) = ExponentialBackoffMaker::new(min, max, 0.0, HasherRng::default()) else {
                return TestResult::discard();
            };
            let backoff = maker.make_backoff();

            backoff.state.lock().unwrap().iterations = iterations;
            let delay = backoff.base();
            TestResult::from_bool((min..=max).contains(&delay))
        }

        // Jitter is zero exactly when there is no jitter ratio, no base delay,
        // or no room left below the maximum; otherwise it is strictly positive.
        fn backoff_jitter(base_ms: u64, max_ms: u64, jitter: f64) -> TestResult {
            let base = time::Duration::from_millis(base_ms);
            let max = time::Duration::from_millis(max_ms);
            let Ok(maker) = ExponentialBackoffMaker::new(base, max, jitter, HasherRng::default()) else {
                return TestResult::discard();
            };
            let backoff = maker.make_backoff();

            let drawn = backoff.jitter(base);
            let expect_zero = jitter == 0.0 || base_ms == 0 || max_ms == base_ms;
            // A `Duration` is never negative, so "not zero" means "greater than zero".
            TestResult::from_bool(drawn.is_zero() == expect_zero)
        }
    }
}