1use std::time::Duration;
15
16use crate::error_category::{ErrorCategory, classify_anyhow_error};
17
18#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)]
20pub struct RetryPolicy {
21 pub max_attempts: u32,
23 pub initial_delay: Duration,
24 pub max_delay: Duration,
25 pub multiplier: f64,
26 pub jitter: f64,
27}
28
29impl RetryPolicy {
30 pub fn new(
31 max_attempts: u32,
32 initial_delay: Duration,
33 max_delay: Duration,
34 multiplier: f64,
35 ) -> Self {
36 Self {
37 max_attempts: max_attempts.max(1),
38 initial_delay,
39 max_delay,
40 multiplier: multiplier.max(1.0),
41 jitter: 0.0,
42 }
43 }
44
45 pub fn from_retries(
46 max_retries: u32,
47 initial_delay: Duration,
48 max_delay: Duration,
49 multiplier: f64,
50 ) -> Self {
51 Self::new(
52 max_retries.saturating_add(1),
53 initial_delay,
54 max_delay,
55 multiplier,
56 )
57 }
58
59 pub fn simple(max_retries: u32, base_delay_ms: u64, max_delay_ms: u64) -> Self {
65 Self::from_retries(
66 max_retries,
67 Duration::from_millis(base_delay_ms),
68 Duration::from_millis(max_delay_ms),
69 2.0,
70 )
71 }
72
73 pub fn delay_for_attempt(&self, attempt_index: u32) -> Duration {
74 let multiplier = self.multiplier.powi(attempt_index as i32);
75 let base_delay = Duration::try_from_secs_f64(self.initial_delay.as_secs_f64() * multiplier)
76 .unwrap_or(self.max_delay)
77 .min(self.max_delay);
78
79 if !self.jitter.is_finite() || self.jitter <= 0.0 {
80 return base_delay;
81 }
82
83 #[allow(clippy::cast_sign_loss)]
84 let max_jitter_ms = (base_delay.as_millis() as f64 * self.jitter)
85 .round()
86 .clamp(0.0, u64::MAX as f64) as u64;
87 if max_jitter_ms == 0 {
88 return base_delay;
89 }
90
91 let offset = (u64::from(attempt_index) * 31) % max_jitter_ms.saturating_add(1);
92 base_delay.saturating_add(Duration::from_millis(offset))
93 }
94
95 pub fn decision_for_category(
96 &self,
97 category: ErrorCategory,
98 attempt_index: u32,
99 retry_after: Option<Duration>,
100 ) -> RetryDecision {
101 let has_remaining_attempts = attempt_index.saturating_add(1) < self.max_attempts;
102 if !category.is_retryable() || !has_remaining_attempts {
103 return RetryDecision {
104 category,
105 retryable: false,
106 delay: None,
107 retry_after,
108 };
109 }
110
111 let delay = retry_after.unwrap_or_else(|| self.delay_for_attempt(attempt_index));
112 RetryDecision {
113 category,
114 retryable: true,
115 delay: Some(delay),
116 retry_after,
117 }
118 }
119
120 pub fn classify_anyhow(&self, error: &anyhow::Error) -> RetryDecision {
127 let category = classify_anyhow_error(error);
128 RetryDecision {
129 category,
130 retryable: category.is_retryable(),
131 delay: None,
132 retry_after: None,
133 }
134 }
135
136 pub fn classify_status(&self, status: u16) -> RetryDecision {
140 let category = match status {
141 429 => ErrorCategory::RateLimit,
142 500 | 502 | 504 => ErrorCategory::Network,
143 503 => ErrorCategory::ServiceUnavailable,
144 401 | 403 => ErrorCategory::Authentication,
145 _ => ErrorCategory::ExecutionError,
146 };
147 RetryDecision {
148 category,
149 retryable: category.is_retryable(),
150 delay: None,
151 retry_after: None,
152 }
153 }
154}
155
156impl Default for RetryPolicy {
157 fn default() -> Self {
158 Self::from_retries(2, Duration::from_secs(1), Duration::from_secs(60), 2.0)
159 }
160}
161
162#[derive(Debug, Clone, PartialEq, Eq)]
164pub struct RetryDecision {
165 pub category: ErrorCategory,
166 pub retryable: bool,
167 pub delay: Option<Duration>,
168 pub retry_after: Option<Duration>,
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Doubling policy with a 1s initial delay and an 8s cap, shared by several tests.
    fn doubling_policy(max_retries: u32) -> RetryPolicy {
        RetryPolicy::from_retries(max_retries, Duration::from_secs(1), Duration::from_secs(8), 2.0)
    }

    #[test]
    fn default_policy_allows_two_retries() {
        let RetryPolicy {
            max_attempts,
            initial_delay,
            max_delay,
            ..
        } = RetryPolicy::default();
        assert_eq!(max_attempts, 3);
        assert_eq!(initial_delay, Duration::from_secs(1));
        assert_eq!(max_delay, Duration::from_secs(60));
    }

    #[test]
    fn classify_status_rate_limit() {
        let decision = RetryPolicy::default().classify_status(429);
        assert!(decision.retryable);
        assert_eq!(decision.category, ErrorCategory::RateLimit);
    }

    #[test]
    fn classify_status_server_error() {
        let decision = RetryPolicy::default().classify_status(503);
        assert!(decision.retryable);
        assert_eq!(decision.category, ErrorCategory::ServiceUnavailable);
    }

    #[test]
    fn classify_status_auth_not_retryable() {
        let decision = RetryPolicy::default().classify_status(401);
        assert!(!decision.retryable);
        assert_eq!(decision.category, ErrorCategory::Authentication);
    }

    #[test]
    fn classify_anyhow_network_error() {
        let err = anyhow::anyhow!("connection refused");
        assert!(RetryPolicy::default().classify_anyhow(&err).retryable);
    }

    #[test]
    fn simple_policy_matches_bit_shift_doubling() {
        // Legacy shift-based backoff that `RetryPolicy::simple` must reproduce.
        fn legacy_delay_ms(attempt: u32) -> u64 {
            1000u64.saturating_mul(1u64 << attempt.min(16)).min(5000)
        }

        let policy = RetryPolicy::simple(10, 1000, 5000);
        for attempt in 0..6u32 {
            assert_eq!(
                policy.delay_for_attempt(attempt),
                Duration::from_millis(legacy_delay_ms(attempt)),
                "delay mismatch at attempt {attempt}"
            );
        }
    }

    #[test]
    fn delay_for_attempt_clamps_overflowing_backoff_to_max_delay() {
        let policy =
            RetryPolicy::from_retries(3, Duration::from_secs(1), Duration::from_secs(8), f64::MAX);
        assert_eq!(policy.delay_for_attempt(2), Duration::from_secs(8));
    }

    #[test]
    fn delay_for_attempt_ignores_non_finite_jitter() {
        let policy = RetryPolicy {
            jitter: f64::INFINITY,
            ..doubling_policy(3)
        };
        assert_eq!(policy.delay_for_attempt(1), Duration::from_secs(2));
    }

    #[test]
    fn delay_for_attempt_handles_huge_finite_jitter() {
        let policy = RetryPolicy {
            jitter: f64::MAX,
            ..doubling_policy(3)
        };
        assert!(policy.delay_for_attempt(1) >= Duration::from_secs(2));
    }

    #[test]
    fn decision_for_category_respects_attempt_budget() {
        let policy = doubling_policy(1);

        let first = policy.decision_for_category(ErrorCategory::Network, 0, None);
        assert!(first.retryable);
        assert_eq!(first.delay, Some(Duration::from_secs(1)));

        let exhausted = policy.decision_for_category(ErrorCategory::Network, 1, None);
        assert!(!exhausted.retryable);
        assert!(exhausted.delay.is_none());
    }

    #[test]
    fn decision_for_category_prefers_retry_after() {
        let hint = Duration::from_secs(7);
        let decision =
            doubling_policy(3).decision_for_category(ErrorCategory::RateLimit, 0, Some(hint));
        assert!(decision.retryable);
        assert_eq!(decision.delay, Some(hint));
        assert_eq!(decision.retry_after, Some(hint));
    }
}