tower_async/util/backoff/
exponential.rs1use std::sync::Arc;
2use std::time::Duration;
3use std::{fmt::Display, sync::Mutex};
4use tokio::time;
5
6use crate::util::rng::{HasherRng, Rng};
7
8use super::{Backoff, MakeBackoff};
9
/// Factory for [`ExponentialBackoff`] values.
///
/// Holds the configuration validated by [`ExponentialBackoffMaker::new`] and a
/// prototype random number generator that is cloned into every backoff it
/// makes.
#[derive(Debug, Clone)]
pub struct ExponentialBackoffMaker<R = HasherRng> {
    /// Base delay at iteration zero; doubled for each later iteration.
    min: time::Duration,
    /// Upper bound on any single delay, jitter included.
    max: time::Duration,
    /// Jitter multiplier, validated to lie in `0.0..=100.0`; `0.0` disables jitter.
    jitter: f64,
    /// Generator cloned into each backoff to draw jitter factors.
    rng: R,
}
23
/// A stateful exponential backoff produced by [`ExponentialBackoffMaker`].
///
/// Each `next_backoff` call sleeps for `min * 2^iterations` (capped at `max`)
/// plus a random jitter that is clamped so the total never exceeds `max`, then
/// increments the iteration count.
///
/// The mutable state sits behind an `Arc<Mutex<_>>`, so clones of this value
/// share a single iteration counter and RNG rather than getting their own.
#[derive(Debug, Clone)]
pub struct ExponentialBackoff<R = HasherRng> {
    /// Base delay at iteration zero.
    min: time::Duration,
    /// Upper bound on any single delay, jitter included.
    max: time::Duration,
    /// Jitter multiplier applied to the RNG draw.
    jitter: f64,
    /// Iteration counter and RNG, shared between clones.
    state: Arc<Mutex<ExponentialBackoffState<R>>>,
}
39
/// Mutable state of an [`ExponentialBackoff`], kept behind its `Arc<Mutex<_>>`.
#[derive(Debug, Clone)]
struct ExponentialBackoffState<R = HasherRng> {
    /// Generator used to draw jitter factors.
    rng: R,
    /// Incremented once per `next_backoff` call; used as the power-of-two
    /// exponent when computing the base delay.
    iterations: u32,
}
45
46impl<R> ExponentialBackoffMaker<R>
47where
48 R: Rng,
49{
50 pub fn new(
61 min: time::Duration,
62 max: time::Duration,
63 jitter: f64,
64 rng: R,
65 ) -> Result<Self, InvalidBackoff> {
66 if min > max {
67 return Err(InvalidBackoff("maximum must not be less than minimum"));
68 }
69 if max == time::Duration::from_millis(0) {
70 return Err(InvalidBackoff("maximum must be non-zero"));
71 }
72 if jitter < 0.0 {
73 return Err(InvalidBackoff("jitter must not be negative"));
74 }
75 if jitter > 100.0 {
76 return Err(InvalidBackoff("jitter must not be greater than 100"));
77 }
78 if !jitter.is_finite() {
79 return Err(InvalidBackoff("jitter must be finite"));
80 }
81
82 Ok(ExponentialBackoffMaker {
83 min,
84 max,
85 jitter,
86 rng,
87 })
88 }
89}
90
91impl<R> MakeBackoff for ExponentialBackoffMaker<R>
92where
93 R: Rng + Clone,
94{
95 type Backoff = ExponentialBackoff<R>;
96
97 fn make_backoff(&self) -> Self::Backoff {
98 ExponentialBackoff {
99 max: self.max,
100 min: self.min,
101 jitter: self.jitter,
102 state: Arc::new(Mutex::new(ExponentialBackoffState {
103 rng: self.rng.clone(),
104 iterations: 0,
105 })),
106 }
107 }
108}
109
110impl<R: Rng> ExponentialBackoff<R> {
111 fn base(&self) -> time::Duration {
112 debug_assert!(
113 self.min <= self.max,
114 "maximum backoff must not be less than minimum backoff"
115 );
116 debug_assert!(
117 self.max > time::Duration::from_millis(0),
118 "Maximum backoff must be non-zero"
119 );
120 self.min
121 .checked_mul(2_u32.saturating_pow(self.state.lock().unwrap().iterations))
122 .unwrap_or(self.max)
123 .min(self.max)
124 }
125
126 fn jitter(&self, base: time::Duration) -> time::Duration {
129 if self.jitter == 0.0 {
130 time::Duration::default()
131 } else {
132 let jitter_factor = self.state.lock().unwrap().rng.next_f64();
133 debug_assert!(
134 jitter_factor > 0.0,
135 "rng returns values between 0.0 and 1.0"
136 );
137 let rand_jitter = jitter_factor * self.jitter;
138 let secs = (base.as_secs() as f64) * rand_jitter;
139 let nanos = (base.subsec_nanos() as f64) * rand_jitter;
140 let remaining = self.max - base;
141 time::Duration::new(secs as u64, nanos as u32).min(remaining)
142 }
143 }
144}
145
146impl<R> Backoff for ExponentialBackoff<R>
147where
148 R: Rng,
149{
150 async fn next_backoff(&self) {
151 let base = self.base();
152 let next = base + self.jitter(base);
153
154 self.state.lock().unwrap().iterations += 1;
155
156 tokio::time::sleep(next).await
157 }
158}
159
160impl Default for ExponentialBackoffMaker {
161 fn default() -> Self {
162 ExponentialBackoffMaker::new(
163 Duration::from_millis(50),
164 Duration::from_millis(u64::MAX),
165 0.99,
166 HasherRng::default(),
167 )
168 .expect("Unable to create ExponentialBackoff")
169 }
170}
171
/// Error returned by [`ExponentialBackoffMaker::new`] when a parameter is
/// rejected; the wrapped static string names the check that failed.
#[derive(Debug)]
pub struct InvalidBackoff(&'static str);
175
176impl Display for InvalidBackoff {
177 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
178 write!(f, "invalid backoff: {}", self.0)
179 }
180}
181
// Marker impl so `InvalidBackoff` can be used as a standard error (e.g. boxed
// as `Box<dyn std::error::Error>`); it relies on the derived `Debug` and the
// hand-written `Display` impl for `InvalidBackoff`.
impl std::error::Error for InvalidBackoff {}
183
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::*;

    // Builds a backoff with a fresh `HasherRng`, or `None` when
    // `ExponentialBackoffMaker::new` rejects the parameters so that the
    // property can discard that input.
    fn build(
        min: time::Duration,
        max: time::Duration,
        jitter: f64,
    ) -> Option<ExponentialBackoff<HasherRng>> {
        ExponentialBackoffMaker::new(min, max, jitter, HasherRng::default())
            .ok()
            .map(|maker| maker.make_backoff())
    }

    quickcheck! {
        // With no iterations recorded, the base delay is exactly `min`.
        fn backoff_base_first(min_ms: u64, max_ms: u64) -> TestResult {
            let min = time::Duration::from_millis(min_ms);
            let max = time::Duration::from_millis(max_ms);
            let Some(backoff) = build(min, max, 0.0) else {
                return TestResult::discard();
            };

            TestResult::from_bool(backoff.base() == min)
        }

        // For any iteration count, the base delay stays within `[min, max]`.
        fn backoff_base(min_ms: u64, max_ms: u64, iterations: u32) -> TestResult {
            let min = time::Duration::from_millis(min_ms);
            let max = time::Duration::from_millis(max_ms);
            let Some(backoff) = build(min, max, 0.0) else {
                return TestResult::discard();
            };

            backoff.state.lock().unwrap().iterations = iterations;
            TestResult::from_bool((min..=max).contains(&backoff.base()))
        }

        // Jitter is zero when disabled or when there is no room below `max`;
        // otherwise it is strictly positive.
        fn backoff_jitter(base_ms: u64, max_ms: u64, jitter: f64) -> TestResult {
            let base = time::Duration::from_millis(base_ms);
            let max = time::Duration::from_millis(max_ms);
            let Some(backoff) = build(base, max, jitter) else {
                return TestResult::discard();
            };

            let applied = backoff.jitter(base);
            let none = time::Duration::default();
            let expect_none = jitter == 0.0 || base_ms == 0 || max_ms == base_ms;
            TestResult::from_bool(if expect_none { applied == none } else { applied > none })
        }
    }
}