1use std::time::Duration;
6
7use rand::thread_rng;
8use rand::Rng;
9
/// Backoff named for region-level retries (call sites are outside this view —
/// confirm usage). No jitter: first delay 2 ms, doubling each attempt up to a
/// 500 ms cap, at most 10 retries.
pub const DEFAULT_REGION_BACKOFF: Backoff = Backoff::no_jitter_backoff(2, 500, 10);
/// Backoff named for store-level retries (confirm usage at call sites). No
/// jitter: first delay 2 ms, doubling each attempt up to a 1000 ms cap, at most
/// 10 retries.
pub const DEFAULT_STORE_BACKOFF: Backoff = Backoff::no_jitter_backoff(2, 1000, 10);
/// Backoff named for optimistic retries (confirm usage at call sites). Same
/// schedule as `DEFAULT_REGION_BACKOFF`: 2 ms doubling to 500 ms, 10 retries.
pub const OPTIMISTIC_BACKOFF: Backoff = Backoff::no_jitter_backoff(2, 500, 10);
/// Backoff named for pessimistic retries (confirm usage at call sites). Same
/// schedule as `DEFAULT_REGION_BACKOFF`: 2 ms doubling to 500 ms, 10 retries.
pub const PESSIMISTIC_BACKOFF: Backoff = Backoff::no_jitter_backoff(2, 500, 10);
14
/// Stateful retry-delay generator.
///
/// Each call to `next_delay_duration` consumes one attempt and yields the delay
/// to wait before retrying, computed according to `kind`, until `max_attempts`
/// attempts have been consumed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Backoff {
    /// Which delay formula to apply.
    kind: BackoffKind,
    /// Number of attempts consumed so far.
    current_attempts: u32,
    /// Number of attempts allowed; once reached, no further delays are produced.
    max_attempts: u32,
    /// Initial delay in milliseconds; also the lower bound for decorrelated jitter.
    base_delay_ms: u64,
    /// Evolving delay state in milliseconds: doubled after each attempt for the
    /// no/full/equal jitter kinds, replaced by the last delay for decorrelated jitter.
    current_delay_ms: u64,
    /// Upper bound in milliseconds applied to every computed delay.
    max_delay_ms: u64,
}
27
28impl Backoff {
29 pub fn next_delay_duration(&mut self) -> Option<Duration> {
31 if self.current_attempts >= self.max_attempts {
32 return None;
33 }
34 self.current_attempts += 1;
35
36 match self.kind {
37 BackoffKind::None => None,
38 BackoffKind::NoJitter => {
39 let delay_ms = self.max_delay_ms.min(self.current_delay_ms);
40 self.current_delay_ms <<= 1;
41
42 Some(Duration::from_millis(delay_ms))
43 }
44 BackoffKind::FullJitter => {
45 let delay_ms = self.max_delay_ms.min(self.current_delay_ms);
46
47 let mut rng = thread_rng();
48 let delay_ms: u64 = rng.gen_range(0..delay_ms);
49 self.current_delay_ms <<= 1;
50
51 Some(Duration::from_millis(delay_ms))
52 }
53 BackoffKind::EqualJitter => {
54 let delay_ms = self.max_delay_ms.min(self.current_delay_ms);
55 let half_delay_ms = delay_ms >> 1;
56
57 let mut rng = thread_rng();
58 let delay_ms: u64 = rng.gen_range(0..half_delay_ms) + half_delay_ms;
59 self.current_delay_ms <<= 1;
60
61 Some(Duration::from_millis(delay_ms))
62 }
63 BackoffKind::DecorrelatedJitter => {
64 let mut rng = thread_rng();
65 let delay_ms: u64 = rng
66 .gen_range(0..self.current_delay_ms * 3 - self.base_delay_ms)
67 + self.base_delay_ms;
68
69 let delay_ms = delay_ms.min(self.max_delay_ms);
70 self.current_delay_ms = delay_ms;
71
72 Some(Duration::from_millis(delay_ms))
73 }
74 }
75 }
76
77 pub fn is_none(&self) -> bool {
79 self.kind == BackoffKind::None
80 }
81
82 pub fn current_attempts(&self) -> u32 {
84 self.current_attempts
85 }
86
87 pub const fn no_backoff() -> Backoff {
89 Backoff {
90 kind: BackoffKind::None,
91 current_attempts: 0,
92 max_attempts: 0,
93 base_delay_ms: 0,
94 current_delay_ms: 0,
95 max_delay_ms: 0,
96 }
97 }
98
99 pub const fn no_jitter_backoff(
105 base_delay_ms: u64,
106 max_delay_ms: u64,
107 max_attempts: u32,
108 ) -> Backoff {
109 Backoff {
110 kind: BackoffKind::NoJitter,
111 current_attempts: 0,
112 max_attempts,
113 base_delay_ms,
114 current_delay_ms: base_delay_ms,
115 max_delay_ms,
116 }
117 }
118
119 pub fn full_jitter_backoff(
125 base_delay_ms: u64,
126 max_delay_ms: u64,
127 max_attempts: u32,
128 ) -> Backoff {
129 assert!(
130 base_delay_ms > 0 && max_delay_ms > 0,
131 "Both base_delay_ms and max_delay_ms must be positive"
132 );
133
134 Backoff {
135 kind: BackoffKind::FullJitter,
136 current_attempts: 0,
137 max_attempts,
138 base_delay_ms,
139 current_delay_ms: base_delay_ms,
140 max_delay_ms,
141 }
142 }
143
144 pub fn equal_jitter_backoff(
150 base_delay_ms: u64,
151 max_delay_ms: u64,
152 max_attempts: u32,
153 ) -> Backoff {
154 assert!(
155 base_delay_ms > 1 && max_delay_ms > 1,
156 "Both base_delay_ms and max_delay_ms must be greater than 1"
157 );
158
159 Backoff {
160 kind: BackoffKind::EqualJitter,
161 current_attempts: 0,
162 max_attempts,
163 base_delay_ms,
164 current_delay_ms: base_delay_ms,
165 max_delay_ms,
166 }
167 }
168
169 pub fn decorrelated_jitter_backoff(
175 base_delay_ms: u64,
176 max_delay_ms: u64,
177 max_attempts: u32,
178 ) -> Backoff {
179 assert!(base_delay_ms > 0, "base_delay_ms must be positive");
180
181 Backoff {
182 kind: BackoffKind::DecorrelatedJitter,
183 current_attempts: 0,
184 max_attempts,
185 base_delay_ms,
186 current_delay_ms: base_delay_ms,
187 max_delay_ms,
188 }
189 }
190}
191
/// Delay formula used by a [`Backoff`]. In the descriptions below, `current` is
/// `Backoff::current_delay_ms`, `base` is `base_delay_ms` and `max` is
/// `max_delay_ms`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum BackoffKind {
    /// No retries: `next_delay_duration` always yields `None`.
    None,
    /// Deterministic exponential backoff: `min(max, current)`, with `current`
    /// doubled after each attempt.
    NoJitter,
    /// Uniform random delay in `[0, min(max, current))`, with `current` doubled
    /// after each attempt.
    FullJitter,
    /// With `d = min(max, current)` and `h = d / 2`: `h` plus a uniform random
    /// value in `[0, h)`, with `current` doubled after each attempt.
    EqualJitter,
    /// `base` plus a uniform random value in `[0, 3 * current - base)`, capped at
    /// `max`; the result becomes the next `current`.
    DecorrelatedJitter,
}
201
#[cfg(test)]
mod test {
    use std::convert::TryInto;

    use super::*;

    #[test]
    fn test_no_backoff() {
        let mut backoff = Backoff::no_backoff();
        assert!(backoff.is_none());
        assert_eq!(backoff.next_delay_duration(), None);
        // A refused attempt is not counted.
        assert_eq!(backoff.current_attempts(), 0);
    }

    #[test]
    fn test_no_jitter_backoff() {
        let mut backoff = Backoff::no_jitter_backoff(0, 0, 0);
        assert_eq!(backoff.next_delay_duration(), None);

        let mut backoff = Backoff::no_jitter_backoff(2, 7, 3);
        assert!(!backoff.is_none());
        assert_eq!(
            backoff.next_delay_duration(),
            Some(Duration::from_millis(2))
        );
        assert_eq!(
            backoff.next_delay_duration(),
            Some(Duration::from_millis(4))
        );
        assert_eq!(
            backoff.next_delay_duration(),
            Some(Duration::from_millis(7))
        );
        assert_eq!(backoff.current_attempts(), 3);
        assert_eq!(backoff.next_delay_duration(), None);
        // Exhausted backoffs do not keep counting.
        assert_eq!(backoff.current_attempts(), 3);
    }

    #[test]
    fn test_full_jitter_backoff() {
        let mut backoff = Backoff::full_jitter_backoff(2, 7, 3);
        assert!(backoff.next_delay_duration().unwrap() <= Duration::from_millis(2));
        assert!(backoff.next_delay_duration().unwrap() <= Duration::from_millis(4));
        assert!(backoff.next_delay_duration().unwrap() <= Duration::from_millis(7));
        assert_eq!(backoff.next_delay_duration(), None);
    }

    #[test]
    #[should_panic(expected = "Both base_delay_ms and max_delay_ms must be positive")]
    fn test_full_jitter_backoff_with_invalid_base_delay_ms() {
        Backoff::full_jitter_backoff(0, 7, 3);
    }

    #[test]
    #[should_panic(expected = "Both base_delay_ms and max_delay_ms must be positive")]
    fn test_full_jitter_backoff_with_invalid_max_delay_ms() {
        Backoff::full_jitter_backoff(2, 0, 3);
    }

    #[test]
    fn test_equal_jitter_backoff() {
        let mut backoff = Backoff::equal_jitter_backoff(2, 7, 3);

        let first_delay_dur = backoff.next_delay_duration().unwrap();
        assert!(first_delay_dur >= Duration::from_millis(1));
        assert!(first_delay_dur <= Duration::from_millis(2));

        let second_delay_dur = backoff.next_delay_duration().unwrap();
        assert!(second_delay_dur >= Duration::from_millis(2));
        assert!(second_delay_dur <= Duration::from_millis(4));

        let third_delay_dur = backoff.next_delay_duration().unwrap();
        assert!(third_delay_dur >= Duration::from_millis(3));
        assert!(third_delay_dur <= Duration::from_millis(6));

        assert_eq!(backoff.next_delay_duration(), None);
    }

    #[test]
    #[should_panic(expected = "Both base_delay_ms and max_delay_ms must be greater than 1")]
    fn test_equal_jitter_backoff_with_invalid_base_delay_ms() {
        Backoff::equal_jitter_backoff(1, 7, 3);
    }

    #[test]
    #[should_panic(expected = "Both base_delay_ms and max_delay_ms must be greater than 1")]
    fn test_equal_jitter_backoff_with_invalid_max_delay_ms() {
        Backoff::equal_jitter_backoff(2, 1, 3);
    }

    #[test]
    fn test_decorrelated_jitter_backoff() {
        let mut backoff = Backoff::decorrelated_jitter_backoff(2, 7, 3);

        let first_delay_dur = backoff.next_delay_duration().unwrap();
        assert!(first_delay_dur >= Duration::from_millis(2));
        assert!(first_delay_dur <= Duration::from_millis(6));

        let second_delay_dur = backoff.next_delay_duration().unwrap();
        assert!(second_delay_dur >= Duration::from_millis(2));
        let cap_ms = 7u64.min((first_delay_dur.as_millis() * 3).try_into().unwrap());
        assert!(second_delay_dur <= Duration::from_millis(cap_ms));

        let third_delay_dur = backoff.next_delay_duration().unwrap();
        assert!(third_delay_dur >= Duration::from_millis(2));
        let cap_ms = 7u64.min((second_delay_dur.as_millis() * 3).try_into().unwrap());
        // Bound the *third* delay by the cap derived from the second one.
        assert!(third_delay_dur <= Duration::from_millis(cap_ms));

        assert_eq!(backoff.next_delay_duration(), None);
    }

    #[test]
    #[should_panic(expected = "base_delay_ms must be positive")]
    fn test_decorrelated_jitter_backoff_with_invalid_base_delay_ms() {
        Backoff::decorrelated_jitter_backoff(0, 7, 3);
    }
}