1mod policy;
2
3pub use policy::RetryPolicy;
4
5use std::future::Future;
6use std::time::Duration;
7
/// Category of failure that made an attempt eligible for a retry.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RetryReason {
    /// Transport failure: `classify_reqwest_result` maps both
    /// `reqwest::Error::is_connect()` and `is_request()` errors here.
    Connect,
    /// The client-side request timed out (`reqwest::Error::is_timeout()`).
    Timeout,
    /// The server answered `429 Too Many Requests`.
    RateLimited,
    /// The server answered with a 5xx status.
    ServerError,
    /// The server answered `408 Request Timeout`.
    RequestTimeout,
}
17
18impl RetryReason {
19 pub fn as_str(self) -> &'static str {
20 match self {
21 Self::Connect => "connect",
22 Self::Timeout => "timeout",
23 Self::RateLimited => "rate_limited",
24 Self::ServerError => "server_error",
25 Self::RequestTimeout => "request_timeout",
26 }
27 }
28}
29
30#[derive(Debug, Clone, Copy, PartialEq, Eq)]
32pub enum RetryDecision {
33 Retry {
34 reason: RetryReason,
35 retry_after: Option<Duration>,
36 },
37 DoNotRetry,
38}
39
40#[derive(Debug, Clone, Copy, PartialEq, Eq)]
42pub struct RetryAttempt {
43 pub attempt: u32,
44 pub delay: Duration,
45 pub reason: RetryReason,
46}
47
48pub fn parse_retry_after(headers: &reqwest::header::HeaderMap) -> Option<Duration> {
54 let raw = headers.get(reqwest::header::RETRY_AFTER)?.to_str().ok()?.trim();
55
56 if let Ok(secs) = raw.parse::<u64>() {
57 return Some(Duration::from_secs(secs));
58 }
59
60 if let Ok(when) = httpdate::parse_http_date(raw) {
61 let now = std::time::SystemTime::now();
62 if let Ok(delay) = when.duration_since(now) {
63 return Some(delay);
64 }
65 return Some(Duration::from_secs(0));
66 }
67
68 None
69}
70
71pub fn classify_reqwest_result(
73 result: &std::result::Result<reqwest::Response, reqwest::Error>,
74) -> RetryDecision {
75 match result {
76 Ok(resp) => {
77 let status = resp.status();
78 if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
79 return RetryDecision::Retry {
80 reason: RetryReason::RateLimited,
81 retry_after: parse_retry_after(resp.headers()),
82 };
83 }
84 if status == reqwest::StatusCode::REQUEST_TIMEOUT {
85 return RetryDecision::Retry {
86 reason: RetryReason::RequestTimeout,
87 retry_after: parse_retry_after(resp.headers()),
88 };
89 }
90 if status.is_server_error() {
91 return RetryDecision::Retry {
92 reason: RetryReason::ServerError,
93 retry_after: parse_retry_after(resp.headers()),
94 };
95 }
96 RetryDecision::DoNotRetry
97 }
98 Err(err) => {
99 if err.is_timeout() {
100 return RetryDecision::Retry {
101 reason: RetryReason::Timeout,
102 retry_after: None,
103 };
104 }
105 if err.is_connect() || err.is_request() {
106 return RetryDecision::Retry {
107 reason: RetryReason::Connect,
108 retry_after: None,
109 };
110 }
111 RetryDecision::DoNotRetry
112 }
113 }
114}
115
116pub async fn retry_with_backoff<T, E, Op, Fut, Classify, OnRetry>(
122 policy: &RetryPolicy,
123 mut operation: Op,
124 mut classify: Classify,
125 mut on_retry: OnRetry,
126) -> std::result::Result<T, E>
127where
128 Op: FnMut(u32) -> Fut,
129 Fut: Future<Output = std::result::Result<T, E>>,
130 Classify: FnMut(&std::result::Result<T, E>) -> RetryDecision,
131 OnRetry: FnMut(RetryAttempt),
132{
133 let max_attempts = policy.max_attempts.max(1);
134
135 for attempt in 1..=max_attempts {
136 let result = operation(attempt).await;
137 let decision = if attempt < max_attempts {
138 classify(&result)
139 } else {
140 RetryDecision::DoNotRetry
141 };
142
143 match (decision, result) {
144 (RetryDecision::Retry { reason, retry_after }, Err(err)) => {
145 let backoff = policy.backoff_delay(attempt);
146 let base_delay = retry_after.unwrap_or(backoff);
147 let delay = policy.with_jitter(base_delay);
148 on_retry(RetryAttempt {
149 attempt,
150 delay,
151 reason,
152 });
153 tokio::time::sleep(delay).await;
154 let _ = err;
155 }
156 (RetryDecision::Retry { reason, retry_after }, Ok(_)) => {
157 let backoff = policy.backoff_delay(attempt);
158 let base_delay = retry_after.unwrap_or(backoff);
159 let delay = policy.with_jitter(base_delay);
160 on_retry(RetryAttempt {
161 attempt,
162 delay,
163 reason,
164 });
165 tokio::time::sleep(delay).await;
166 }
167 (_, final_result) => return final_result,
168 }
169 }
170
171 unreachable!("retry loop always returns");
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::{HeaderMap, HeaderValue, RETRY_AFTER};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Policy with all delays and jitter at zero so retry tests never wait.
    fn zero_delay_policy(max_attempts: u32) -> RetryPolicy {
        RetryPolicy {
            max_attempts,
            base_delay: Duration::from_millis(0),
            max_delay: Duration::from_millis(0),
            jitter_ratio: 0.0,
        }
    }

    /// Header map containing only a `Retry-After` header with `value`.
    fn headers_with_retry_after(value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(RETRY_AFTER, HeaderValue::from_static(value));
        headers
    }

    #[test]
    fn parse_retry_after_delta_seconds() {
        let headers = headers_with_retry_after("7");
        assert_eq!(parse_retry_after(&headers), Some(Duration::from_secs(7)));
    }

    #[test]
    fn parse_retry_after_missing_header_is_none() {
        assert_eq!(parse_retry_after(&HeaderMap::new()), None);
    }

    #[test]
    fn parse_retry_after_past_http_date_is_zero() {
        let headers = headers_with_retry_after("Wed, 21 Oct 2015 07:28:00 GMT");
        assert_eq!(parse_retry_after(&headers), Some(Duration::from_secs(0)));
    }

    #[test]
    fn parse_retry_after_unparseable_is_none() {
        let headers = headers_with_retry_after("soon");
        assert_eq!(parse_retry_after(&headers), None);
    }

    #[test]
    fn retry_reason_labels() {
        assert_eq!(RetryReason::Connect.as_str(), "connect");
        assert_eq!(RetryReason::Timeout.as_str(), "timeout");
        assert_eq!(RetryReason::RateLimited.as_str(), "rate_limited");
        assert_eq!(RetryReason::ServerError.as_str(), "server_error");
        assert_eq!(RetryReason::RequestTimeout.as_str(), "request_timeout");
    }

    #[tokio::test]
    async fn retry_helper_retries_until_success() {
        let policy = zero_delay_policy(3);
        let attempts = Arc::new(AtomicU32::new(0));
        let seen = attempts.clone();

        let result = retry_with_backoff(
            &policy,
            move |_attempt| {
                let seen = seen.clone();
                async move {
                    let n = seen.fetch_add(1, Ordering::SeqCst) + 1;
                    if n < 3 {
                        Err("transient")
                    } else {
                        Ok("ok")
                    }
                }
            },
            |r: &std::result::Result<&str, &str>| match r {
                Err(_) => RetryDecision::Retry {
                    reason: RetryReason::Connect,
                    retry_after: None,
                },
                Ok(_) => RetryDecision::DoNotRetry,
            },
            |_info| {},
        )
        .await;

        assert_eq!(result, Ok("ok"));
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retry_helper_returns_last_error_when_attempts_exhausted() {
        let policy = zero_delay_policy(3);
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);
        let mut reported = Vec::new();

        let result = retry_with_backoff(
            &policy,
            move |_attempt| {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<&str, &str>("still failing")
                }
            },
            |_r: &std::result::Result<&str, &str>| RetryDecision::Retry {
                reason: RetryReason::ServerError,
                retry_after: None,
            },
            |info| reported.push(info),
        )
        .await;

        assert_eq!(result, Err("still failing"));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
        // The final attempt is not classified, so only attempts 1 and 2 are reported.
        let reported_attempts: Vec<u32> = reported.iter().map(|r| r.attempt).collect();
        assert_eq!(reported_attempts, vec![1, 2]);
        assert!(reported.iter().all(|r| r.reason == RetryReason::ServerError));
    }
}