1use std::{cmp, fmt::Debug, time};
2
3use rand::RngExt;
4
/// Outcome of asking a [`RetryPolicy`] whether another attempt may be made.
// Variant names are part of the public API; renaming them would break callers.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, PartialEq, Eq)]
pub enum RetryAction {
    /// An attempt may be made right now.
    OKAY,
    /// An attempt is not allowed yet; the caller should wait and ask again.
    WAIT,
}
10
11pub trait RetryPolicy: Debug + PartialEq + Eq {
12 fn max_tries(&self) -> usize;
13 fn current_tries(&self) -> usize;
14
15 fn fail(&mut self);
16 fn succeed(&mut self);
17
18 fn can_try(&self) -> Option<RetryAction> {
19 if self.current_tries() >= self.max_tries() {
20 None
21 } else {
22 Some(RetryAction::OKAY)
23 }
24 }
25
26 fn is_down(&self) -> bool;
27}
28
29#[derive(Debug, PartialEq, Eq, Clone)]
30pub enum RetryPolicyWrapper {
31 ExponentialBackoff(ExponentialBackoffPolicy),
32}
33
/// Retry policy whose wait between attempts grows exponentially (with random
/// jitter) as failures accumulate.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ExponentialBackoffPolicy {
    /// Ceiling for `current_tries`; the policy is "down" once it is reached.
    max_tries: usize,
    /// Failures counted since the last success (never above `max_tries`).
    current_tries: usize,
    /// When the most recent counted failure or success was recorded; the
    /// back-off window is measured from here.
    last_try: time::Instant,
    /// Length of the back-off window that starts at `last_try`.
    wait: time::Duration,
}
41
42impl ExponentialBackoffPolicy {
43 pub fn new(max_tries: usize) -> Self {
44 ExponentialBackoffPolicy {
45 max_tries,
46 current_tries: 0,
47 last_try: time::Instant::now(),
48 wait: time::Duration::from_secs(0),
49 }
50 }
51}
52
53impl RetryPolicy for ExponentialBackoffPolicy {
54 fn max_tries(&self) -> usize {
55 self.max_tries
56 }
57
58 fn current_tries(&self) -> usize {
59 self.current_tries
60 }
61
62 fn fail(&mut self) {
63 if self.last_try.elapsed().lt(&self.wait) {
64 return;
66 }
67
68 let max_secs = cmp::max(
69 1,
70 1u64.checked_shl(self.current_tries as u32)
71 .unwrap_or(u64::MAX),
72 );
73 let wait = if max_secs == 1 {
74 1
75 } else {
76 let mut rng = rand::rng();
77 rng.random_range(1..max_secs)
78 };
79
80 self.wait = time::Duration::from_secs(wait);
81 self.last_try = time::Instant::now();
82 self.current_tries = cmp::min(self.current_tries + 1, self.max_tries);
83 }
84
85 fn succeed(&mut self) {
86 self.wait = time::Duration::default();
87 self.last_try = time::Instant::now();
88 self.current_tries = 0;
89 }
90
91 fn can_try(&self) -> Option<RetryAction> {
92 let action = if self.last_try.elapsed().ge(&self.wait) {
93 RetryAction::OKAY
94 } else {
95 RetryAction::WAIT
96 };
97
98 Some(action)
99 }
100
101 fn is_down(&self) -> bool {
102 self.current_tries() >= self.max_tries()
103 }
104}
105
#[cfg(test)]
impl ExponentialBackoffPolicy {
    /// Test-only: jumps the failure counter straight to `max_tries` so the
    /// policy reports itself as down. This avoids waiting out real back-off
    /// windows, because `fail` ignores failures while a window is open.
    pub(crate) fn force_down(&mut self) {
        let ceiling = self.max_tries;
        self.current_tries = ceiling;
    }
}
119
#[cfg(test)]
impl RetryPolicyWrapper {
    /// Test-only: forwards to the wrapped policy's `force_down`.
    pub(crate) fn force_down(&mut self) {
        // Irrefutable today; adding a variant turns this into a compile error,
        // just like a non-exhaustive `match` would.
        let Self::ExponentialBackoff(policy) = self;
        policy.force_down();
    }
}
129
130impl From<ExponentialBackoffPolicy> for RetryPolicyWrapper {
131 fn from(val: ExponentialBackoffPolicy) -> Self {
132 RetryPolicyWrapper::ExponentialBackoff(val)
133 }
134}
135
136impl RetryPolicy for RetryPolicyWrapper {
137 fn max_tries(&self) -> usize {
138 match *self {
139 RetryPolicyWrapper::ExponentialBackoff(ref policy) => policy,
140 }
141 .max_tries()
142 }
143
144 fn current_tries(&self) -> usize {
145 match *self {
146 RetryPolicyWrapper::ExponentialBackoff(ref policy) => policy,
147 }
148 .current_tries()
149 }
150
151 fn fail(&mut self) {
152 match *self {
153 RetryPolicyWrapper::ExponentialBackoff(ref mut policy) => policy,
154 }
155 .fail()
156 }
157
158 fn succeed(&mut self) {
159 match *self {
160 RetryPolicyWrapper::ExponentialBackoff(ref mut policy) => policy,
161 }
162 .succeed()
163 }
164
165 fn can_try(&self) -> Option<RetryAction> {
166 match *self {
167 RetryPolicyWrapper::ExponentialBackoff(ref policy) => policy,
168 }
169 .can_try()
170 }
171
172 fn is_down(&self) -> bool {
173 match *self {
174 RetryPolicyWrapper::ExponentialBackoff(ref policy) => policy,
175 }
176 .is_down()
177 }
178}
179
#[cfg(test)]
mod tests {
    use serial_test::serial;

    use super::{ExponentialBackoffPolicy, RetryAction, RetryPolicy, RetryPolicyWrapper};

    const MAX_FAILS: usize = 10;

    #[serial]
    #[test]
    fn no_fail() {
        let policy = ExponentialBackoffPolicy::new(MAX_FAILS);
        let can_try = policy.can_try();

        assert_eq!(Some(RetryAction::OKAY), can_try);
        assert!(!policy.is_down());
    }

    #[serial]
    #[test]
    fn single_fail() {
        let mut policy = ExponentialBackoffPolicy::new(MAX_FAILS);
        policy.fail();
        let can_try = policy.can_try();

        assert_eq!(Some(RetryAction::WAIT), can_try);
        assert_eq!(1, policy.current_tries());
    }

    /// Repeated failures inside one back-off window (every wait is >= 1 s)
    /// are debounced: only the first is counted, so the policy is not down.
    #[serial]
    #[test]
    fn max_fails() {
        let mut policy = ExponentialBackoffPolicy::new(MAX_FAILS);

        for _ in 0..MAX_FAILS {
            policy.fail();
        }

        let can_try = policy.can_try();

        assert_eq!(Some(RetryAction::WAIT), can_try);
        assert_eq!(1, policy.current_tries());
        assert!(!policy.is_down());
    }

    #[serial]
    #[test]
    fn recover_from_fail() {
        let mut policy = ExponentialBackoffPolicy::new(MAX_FAILS);

        for _ in 0..(MAX_FAILS - 1) {
            policy.fail();
        }

        policy.succeed();
        policy.fail();
        policy.fail();
        policy.fail();

        let can_try = policy.can_try();

        assert_eq!(Some(RetryAction::WAIT), can_try);
        // succeed() reset the counter; the three follow-up failures share one window.
        assert_eq!(1, policy.current_tries());
    }

    #[serial]
    #[test]
    fn succeed_allows_immediate_retry() {
        let mut policy = ExponentialBackoffPolicy::new(MAX_FAILS);
        policy.fail();
        policy.succeed();

        assert_eq!(Some(RetryAction::OKAY), policy.can_try());
        assert_eq!(0, policy.current_tries());
    }

    #[serial]
    #[test]
    fn force_down_marks_policy_down() {
        let mut policy = ExponentialBackoffPolicy::new(MAX_FAILS);
        policy.force_down();

        assert!(policy.is_down());
        assert_eq!(MAX_FAILS, policy.current_tries());
    }

    #[serial]
    #[test]
    fn wrapper_delegates_to_policy() {
        let mut wrapper = RetryPolicyWrapper::from(ExponentialBackoffPolicy::new(MAX_FAILS));

        assert_eq!(MAX_FAILS, wrapper.max_tries());
        assert_eq!(Some(RetryAction::OKAY), wrapper.can_try());

        wrapper.fail();
        assert_eq!(1, wrapper.current_tries());
        assert_eq!(Some(RetryAction::WAIT), wrapper.can_try());

        wrapper.force_down();
        assert!(wrapper.is_down());
    }
}