1mod policy;
2
3pub use policy::RetryPolicy;
4
5use std::future::Future;
6use std::time::Duration;
7
/// Why a request attempt is considered worth retrying.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RetryReason {
    /// The transport reported a connect-level or request-level failure.
    Connect,
    /// The transport reported that the request timed out client-side.
    Timeout,
    /// The server answered `429 Too Many Requests`.
    RateLimited,
    /// The server answered with a 5xx status.
    ServerError,
    /// The server answered `408 Request Timeout`.
    RequestTimeout,
}

impl RetryReason {
    /// Returns the fixed `snake_case` label for this reason.
    ///
    /// The match is exhaustive on purpose: adding a variant without a label
    /// is a compile error.
    pub fn as_str(self) -> &'static str {
        let label = match self {
            RetryReason::Connect => "connect",
            RetryReason::Timeout => "timeout",
            RetryReason::RateLimited => "rate_limited",
            RetryReason::ServerError => "server_error",
            RetryReason::RequestTimeout => "request_timeout",
        };
        label
    }
}
29
30#[derive(Debug, Clone, Copy, PartialEq, Eq)]
32pub enum RetryDecision {
33 Retry {
34 reason: RetryReason,
35 retry_after: Option<Duration>,
36 },
37 DoNotRetry,
38}
39
40#[derive(Debug, Clone, Copy, PartialEq, Eq)]
42pub struct RetryAttempt {
43 pub attempt: u32,
44 pub delay: Duration,
45 pub reason: RetryReason,
46}
47
48pub fn parse_retry_after(headers: &reqwest::header::HeaderMap) -> Option<Duration> {
54 let raw = headers
55 .get(reqwest::header::RETRY_AFTER)?
56 .to_str()
57 .ok()?
58 .trim();
59
60 if let Ok(secs) = raw.parse::<u64>() {
61 return Some(Duration::from_secs(secs));
62 }
63
64 if let Ok(when) = httpdate::parse_http_date(raw) {
65 let now = std::time::SystemTime::now();
66 if let Ok(delay) = when.duration_since(now) {
67 return Some(delay);
68 }
69 return Some(Duration::from_secs(0));
70 }
71
72 None
73}
74
75pub fn classify_reqwest_result(
77 result: &std::result::Result<reqwest::Response, reqwest::Error>,
78) -> RetryDecision {
79 match result {
80 Ok(resp) => {
81 let status = resp.status();
82 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
83 return RetryDecision::Retry {
84 reason: RetryReason::RateLimited,
85 retry_after: parse_retry_after(resp.headers()),
86 };
87 }
88 if status == reqwest::StatusCode::REQUEST_TIMEOUT {
89 return RetryDecision::Retry {
90 reason: RetryReason::RequestTimeout,
91 retry_after: parse_retry_after(resp.headers()),
92 };
93 }
94 if status.is_server_error() {
95 return RetryDecision::Retry {
96 reason: RetryReason::ServerError,
97 retry_after: parse_retry_after(resp.headers()),
98 };
99 }
100 RetryDecision::DoNotRetry
101 }
102 Err(err) => {
103 if err.is_timeout() {
104 return RetryDecision::Retry {
105 reason: RetryReason::Timeout,
106 retry_after: None,
107 };
108 }
109 if err.is_connect() || err.is_request() {
110 return RetryDecision::Retry {
111 reason: RetryReason::Connect,
112 retry_after: None,
113 };
114 }
115 RetryDecision::DoNotRetry
116 }
117 }
118}
119
120pub async fn retry_with_backoff<T, E, Op, Fut, Classify, OnRetry>(
126 policy: &RetryPolicy,
127 mut operation: Op,
128 mut classify: Classify,
129 mut on_retry: OnRetry,
130) -> std::result::Result<T, E>
131where
132 Op: FnMut(u32) -> Fut,
133 Fut: Future<Output = std::result::Result<T, E>>,
134 Classify: FnMut(&std::result::Result<T, E>) -> RetryDecision,
135 OnRetry: FnMut(RetryAttempt),
136{
137 let max_attempts = policy.max_attempts.max(1);
138
139 for attempt in 1..=max_attempts {
140 let result = operation(attempt).await;
141 let decision = if attempt < max_attempts {
142 classify(&result)
143 } else {
144 RetryDecision::DoNotRetry
145 };
146
147 match (decision, result) {
148 (
149 RetryDecision::Retry {
150 reason,
151 retry_after,
152 },
153 Err(err),
154 ) => {
155 let backoff = policy.backoff_delay(attempt);
156 let base_delay = retry_after.unwrap_or(backoff);
157 let delay = policy.with_jitter(base_delay);
158 on_retry(RetryAttempt {
159 attempt,
160 delay,
161 reason,
162 });
163 tokio::time::sleep(delay).await;
164 let _ = err;
165 }
166 (
167 RetryDecision::Retry {
168 reason,
169 retry_after,
170 },
171 Ok(_),
172 ) => {
173 let backoff = policy.backoff_delay(attempt);
174 let base_delay = retry_after.unwrap_or(backoff);
175 let delay = policy.with_jitter(base_delay);
176 on_retry(RetryAttempt {
177 attempt,
178 delay,
179 reason,
180 });
181 tokio::time::sleep(delay).await;
182 }
183 (_, final_result) => return final_result,
184 }
185 }
186
187 unreachable!("retry loop always returns");
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::{HeaderMap, HeaderValue, RETRY_AFTER};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// A purely numeric `Retry-After` value is read as a number of seconds.
    #[test]
    fn parse_retry_after_delta_seconds() {
        let mut header_map = HeaderMap::new();
        header_map.insert(RETRY_AFTER, HeaderValue::from_static("7"));

        let parsed = parse_retry_after(&header_map);

        assert_eq!(parsed, Some(Duration::from_secs(7)));
    }

    /// Two failures followed by a success: the operation must run exactly
    /// three times and the success must be surfaced to the caller.
    #[tokio::test]
    async fn retry_helper_retries_until_success() {
        // Zero delays and no jitter keep the test instantaneous.
        let no_delay = Duration::from_millis(0);
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay: no_delay,
            max_delay: no_delay,
            jitter_ratio: 0.0,
        };
        let call_count = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&call_count);

        let outcome = retry_with_backoff(
            &policy,
            move |_attempt| {
                let counter = Arc::clone(&counter);
                async move {
                    let calls_before = counter.fetch_add(1, Ordering::SeqCst);
                    // Calls one and two fail; the third succeeds.
                    if calls_before + 1 >= 3 {
                        Ok("ok")
                    } else {
                        Err("transient")
                    }
                }
            },
            |attempt_result: &std::result::Result<&str, &str>| {
                if attempt_result.is_err() {
                    RetryDecision::Retry {
                        reason: RetryReason::Connect,
                        retry_after: None,
                    }
                } else {
                    RetryDecision::DoNotRetry
                }
            },
            |_info| {},
        )
        .await;

        assert_eq!(outcome, Ok("ok"));
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }
}