1use std::{cmp, fmt::Debug, time};
2
3use rand::Rng;
4
/// Outcome of asking a [`RetryPolicy`] whether another attempt may be made.
///
/// It is a small fieldless value, so it is `Copy` and `Hash` as well: callers
/// can store it, compare it, and use it as a set/map key without cloning.
///
/// NOTE(review): the variant names are not `UpperCamelCase`; renaming them
/// would break existing callers, so they are left unchanged.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum RetryAction {
    /// An attempt may be made right now.
    OKAY,
    /// The caller should hold off: a backoff window is still open.
    WAIT,
}
10
11pub trait RetryPolicy: Debug + PartialEq + Eq {
12 fn max_tries(&self) -> usize;
13 fn current_tries(&self) -> usize;
14
15 fn fail(&mut self);
16 fn succeed(&mut self);
17
18 fn can_try(&self) -> Option<RetryAction> {
19 if self.current_tries() >= self.max_tries() {
20 None
21 } else {
22 Some(RetryAction::OKAY)
23 }
24 }
25
26 fn is_down(&self) -> bool;
27}
28
29#[derive(Debug, PartialEq, Eq, Clone)]
30pub enum RetryPolicyWrapper {
31 ExponentialBackoff(ExponentialBackoffPolicy),
32}
33
/// State of a randomized exponential backoff retry policy.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ExponentialBackoffPolicy {
    /// Failure count at which the policy reports itself down.
    max_tries: usize,
    /// Failures counted since creation or the last success.
    current_tries: usize,
    /// When the most recent counted failure or success was recorded.
    last_try: std::time::Instant,
    /// Length of the current backoff window, measured from `last_try`.
    wait: std::time::Duration,
}

impl ExponentialBackoffPolicy {
    /// Creates a policy with no failures recorded and no wait window open,
    /// which reports itself down after `max_tries` counted failures.
    pub fn new(max_tries: usize) -> Self {
        let no_wait = std::time::Duration::from_secs(0);
        Self {
            max_tries,
            current_tries: 0,
            last_try: std::time::Instant::now(),
            wait: no_wait,
        }
    }
}
52
53impl RetryPolicy for ExponentialBackoffPolicy {
54 fn max_tries(&self) -> usize {
55 self.max_tries
56 }
57
58 fn current_tries(&self) -> usize {
59 self.current_tries
60 }
61
62 fn fail(&mut self) {
63 if self.last_try.elapsed().lt(&self.wait) {
64 return;
66 }
67
68 let max_secs = cmp::max(1, 1u64.wrapping_shl(self.current_tries as u32));
69 let wait = if max_secs == 1 {
70 1
71 } else {
72 let mut rng = rand::thread_rng();
73 rng.gen_range(1..max_secs)
74 };
75
76 self.wait = time::Duration::from_secs(wait);
77 self.last_try = time::Instant::now();
78 self.current_tries = cmp::min(self.current_tries + 1, self.max_tries);
79 }
80
81 fn succeed(&mut self) {
82 self.wait = time::Duration::default();
83 self.last_try = time::Instant::now();
84 self.current_tries = 0;
85 }
86
87 fn can_try(&self) -> Option<RetryAction> {
88 let action = if self.last_try.elapsed().ge(&self.wait) {
89 RetryAction::OKAY
90 } else {
91 RetryAction::WAIT
92 };
93
94 Some(action)
95 }
96
97 fn is_down(&self) -> bool {
98 self.current_tries() >= self.max_tries()
99 }
100}
101
102impl From<ExponentialBackoffPolicy> for RetryPolicyWrapper {
103 fn from(val: ExponentialBackoffPolicy) -> Self {
104 RetryPolicyWrapper::ExponentialBackoff(val)
105 }
106}
107
108impl RetryPolicy for RetryPolicyWrapper {
109 fn max_tries(&self) -> usize {
110 match *self {
111 RetryPolicyWrapper::ExponentialBackoff(ref policy) => policy,
112 }
113 .max_tries()
114 }
115
116 fn current_tries(&self) -> usize {
117 match *self {
118 RetryPolicyWrapper::ExponentialBackoff(ref policy) => policy,
119 }
120 .current_tries()
121 }
122
123 fn fail(&mut self) {
124 match *self {
125 RetryPolicyWrapper::ExponentialBackoff(ref mut policy) => policy,
126 }
127 .fail()
128 }
129
130 fn succeed(&mut self) {
131 match *self {
132 RetryPolicyWrapper::ExponentialBackoff(ref mut policy) => policy,
133 }
134 .succeed()
135 }
136
137 fn can_try(&self) -> Option<RetryAction> {
138 match *self {
139 RetryPolicyWrapper::ExponentialBackoff(ref policy) => policy,
140 }
141 .can_try()
142 }
143
144 fn is_down(&self) -> bool {
145 match *self {
146 RetryPolicyWrapper::ExponentialBackoff(ref policy) => policy,
147 }
148 .is_down()
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::{ExponentialBackoffPolicy, RetryAction, RetryPolicy};
    use serial_test::serial;
    use std::time::Duration;

    const MAX_FAILS: usize = 10;

    /// Pretends the current backoff window has already elapsed, so the next
    /// `fail()` is counted instead of being ignored as part of a burst.
    fn expire_wait(policy: &mut ExponentialBackoffPolicy) {
        policy.wait = Duration::from_secs(0);
    }

    #[serial]
    #[test]
    fn no_fail() {
        let policy = ExponentialBackoffPolicy::new(MAX_FAILS);
        let can_try = policy.can_try();

        assert_eq!(Some(RetryAction::OKAY), can_try);
        assert_eq!(0, policy.current_tries());
        assert!(!policy.is_down());
    }

    #[serial]
    #[test]
    fn single_fail() {
        let mut policy = ExponentialBackoffPolicy::new(MAX_FAILS);
        policy.fail();
        let can_try = policy.can_try();

        assert_eq!(Some(RetryAction::WAIT), can_try);
        assert_eq!(1, policy.current_tries());
    }

    /// Back-to-back failures inside one backoff window are counted once, so
    /// a burst of `fail()` calls does not by itself exhaust the policy.
    #[serial]
    #[test]
    fn max_fails() {
        let mut policy = ExponentialBackoffPolicy::new(MAX_FAILS);

        for _ in 0..MAX_FAILS {
            policy.fail();
        }

        let can_try = policy.can_try();

        assert_eq!(Some(RetryAction::WAIT), can_try);
        assert_eq!(1, policy.current_tries());
        assert!(!policy.is_down());
    }

    /// Failures counted in separate windows reach `max_tries` and mark the
    /// policy as down; the counter then saturates.
    #[serial]
    #[test]
    fn reaches_max_tries_once_waits_elapse() {
        let mut policy = ExponentialBackoffPolicy::new(MAX_FAILS);

        for expected in 1..=MAX_FAILS {
            expire_wait(&mut policy);
            policy.fail();
            assert_eq!(expected, policy.current_tries());
        }

        assert!(policy.is_down());
        // This policy keeps offering retries after backoff, even when down.
        assert_eq!(Some(RetryAction::WAIT), policy.can_try());

        expire_wait(&mut policy);
        policy.fail();
        assert_eq!(MAX_FAILS, policy.current_tries());
    }

    #[serial]
    #[test]
    fn recover_from_fail() {
        let mut policy = ExponentialBackoffPolicy::new(MAX_FAILS);

        for _ in 0..(MAX_FAILS - 1) {
            expire_wait(&mut policy);
            policy.fail();
        }
        assert_eq!(MAX_FAILS - 1, policy.current_tries());

        policy.succeed();
        assert_eq!(0, policy.current_tries());
        assert!(!policy.is_down());
        assert_eq!(Some(RetryAction::OKAY), policy.can_try());

        policy.fail();
        policy.fail();
        policy.fail();

        let can_try = policy.can_try();

        assert_eq!(Some(RetryAction::WAIT), can_try);
        assert_eq!(1, policy.current_tries());
    }
}