1use std::collections::hash_map::RandomState;
4use std::hash::{BuildHasher, Hasher};
5use std::time::Duration;
6
7use http::StatusCode;
8
/// Configuration for retrying failed requests with exponential backoff and jitter.
///
/// Start from `RetryPolicy::default()` or `RetryPolicy::none()` and adjust it with the
/// chained builder methods.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Maximum number of retries.
    /// NOTE(review): presumably excludes the initial attempt — confirm at the call site.
    pub max_retries: u32,
    /// Base delay: attempt 0 waits this long before jitter is applied.
    pub initial_backoff: Duration,
    /// Upper bound on every computed delay, including jitter.
    pub max_backoff: Duration,
    /// Factor by which the delay grows with each successive attempt.
    pub backoff_multiplier: f64,
    /// Response status codes considered retryable.
    pub retry_on_status: Vec<StatusCode>,
    /// Whether a timed-out request should be retried (consulted by the caller's retry loop).
    pub retry_on_timeout: bool,
}
38
39impl Default for RetryPolicy {
40 fn default() -> Self {
41 Self {
42 max_retries: 3,
43 initial_backoff: Duration::from_millis(100),
44 max_backoff: Duration::from_secs(10),
45 backoff_multiplier: 2.0,
46 retry_on_status: vec![
47 StatusCode::TOO_MANY_REQUESTS,
48 StatusCode::BAD_GATEWAY,
49 StatusCode::SERVICE_UNAVAILABLE,
50 StatusCode::GATEWAY_TIMEOUT,
51 ],
52 retry_on_timeout: true,
53 }
54 }
55}
56
57impl RetryPolicy {
58 pub fn none() -> Self {
60 Self {
61 max_retries: 0,
62 initial_backoff: Duration::ZERO,
63 max_backoff: Duration::ZERO,
64 backoff_multiplier: 1.0,
65 retry_on_status: Vec::new(),
66 retry_on_timeout: false,
67 }
68 }
69
70 pub fn max_retries(mut self, n: u32) -> Self {
72 self.max_retries = n;
73 self
74 }
75
76 pub fn initial_backoff(mut self, d: Duration) -> Self {
78 self.initial_backoff = d;
79 self
80 }
81
82 pub fn max_backoff(mut self, d: Duration) -> Self {
84 self.max_backoff = d;
85 self
86 }
87
88 pub fn backoff_multiplier(mut self, f: f64) -> Self {
90 self.backoff_multiplier = f;
91 self
92 }
93
94 pub fn retry_on_status(mut self, codes: Vec<StatusCode>) -> Self {
96 self.retry_on_status = codes;
97 self
98 }
99
100 pub fn retry_on_timeout(mut self, enabled: bool) -> Self {
102 self.retry_on_timeout = enabled;
103 self
104 }
105
106 pub(crate) fn should_retry_status(&self, status: StatusCode) -> bool {
108 self.retry_on_status.contains(&status)
109 }
110
111 pub(crate) fn backoff_for_attempt(&self, attempt: u32) -> Duration {
116 let base =
117 self.initial_backoff.as_secs_f64() * self.backoff_multiplier.powi(attempt as i32);
118 let capped = base.min(self.max_backoff.as_secs_f64());
119
120 let jitter_frac = random_fraction() * 0.25;
122 let with_jitter = capped * (1.0 + jitter_frac);
123
124 Duration::from_secs_f64(with_jitter.min(self.max_backoff.as_secs_f64()))
125 }
126}
127
/// Returns a pseudo-random `f64` in `[0, 1)`, used only to jitter backoff delays.
///
/// Uses std's `RandomState` as a dependency-free randomness source. Each
/// `RandomState::new()` is initialized with random keys, so hashing a fixed input yields a
/// fresh 64-bit value. The top 53 bits (the precision of an `f64` mantissa) are scaled
/// into the unit interval. Not intended for cryptographic use.
fn random_fraction() -> f64 {
    const MANTISSA_BITS: u32 = 53;

    let raw = {
        let mut h = RandomState::new().build_hasher();
        h.write_u64(0);
        h.finish()
    };

    let top = raw >> (u64::BITS - MANTISSA_BITS);
    top as f64 / (1u64 << MANTISSA_BITS) as f64
}
137
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_policy_values() {
        let policy = RetryPolicy::default();

        assert_eq!(policy.max_retries, 3);
        assert_eq!(policy.initial_backoff, Duration::from_millis(100));
        assert_eq!(policy.max_backoff, Duration::from_secs(10));
        assert!((policy.backoff_multiplier - 2.0).abs() < f64::EPSILON);
        assert!(policy.retry_on_timeout);
        for code in [StatusCode::TOO_MANY_REQUESTS, StatusCode::SERVICE_UNAVAILABLE] {
            assert!(policy.retry_on_status.contains(&code));
        }
    }

    #[test]
    fn none_policy_disables_everything() {
        let policy = RetryPolicy::none();

        assert_eq!(policy.max_retries, 0);
        assert!(policy.retry_on_status.is_empty());
        assert!(!policy.retry_on_timeout);
    }

    #[test]
    fn backoff_grows_exponentially() {
        let policy = RetryPolicy::default();
        // Attempts 0, 1 and 2, computed in that order.
        let [b0, b1, b2] = [0, 1, 2].map(|attempt| policy.backoff_for_attempt(attempt));

        assert!(b1 > b0, "b1 ({b1:?}) should be > b0 ({b0:?})");
        assert!(b2 > b1, "b2 ({b2:?}) should be > b1 ({b1:?})");
    }

    #[test]
    fn backoff_capped_at_max() {
        let cap = Duration::from_millis(500);
        let policy = RetryPolicy::default().max_backoff(cap);

        assert!(policy.backoff_for_attempt(10) <= cap);
    }

    #[test]
    fn builder_methods_chain() {
        let policy = RetryPolicy::none()
            .max_retries(5)
            .initial_backoff(Duration::from_millis(50))
            .max_backoff(Duration::from_secs(5))
            .backoff_multiplier(3.0)
            .retry_on_status(vec![StatusCode::INTERNAL_SERVER_ERROR])
            .retry_on_timeout(true);

        assert_eq!(policy.max_retries, 5);
        assert_eq!(policy.initial_backoff, Duration::from_millis(50));
        assert_eq!(policy.max_backoff, Duration::from_secs(5));
        assert!((policy.backoff_multiplier - 3.0).abs() < f64::EPSILON);
        assert_eq!(policy.retry_on_status, vec![StatusCode::INTERNAL_SERVER_ERROR]);
        assert!(policy.retry_on_timeout);
    }
}