// rustyclaw_core/retry/mod.rs

1mod policy;
2
3pub use policy::RetryPolicy;
4
5use std::future::Future;
6use std::time::Duration;
7
/// Why an attempt's outcome was judged transient, i.e. worth retrying.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetryReason {
    /// Connecting to the peer or sending the request failed.
    Connect,
    /// The client-side transport timed out.
    Timeout,
    /// The server answered 429 Too Many Requests.
    RateLimited,
    /// The server answered with a 5xx status.
    ServerError,
    /// The server answered 408 Request Timeout.
    RequestTimeout,
}

impl RetryReason {
    /// Returns the snake_case label for this reason.
    pub fn as_str(self) -> &'static str {
        let label = match self {
            RetryReason::Connect => "connect",
            RetryReason::Timeout => "timeout",
            RetryReason::RateLimited => "rate_limited",
            RetryReason::ServerError => "server_error",
            RetryReason::RequestTimeout => "request_timeout",
        };
        label
    }
}
29
30/// Retry decision for one attempt result.
31#[derive(Debug, Clone, Copy, PartialEq, Eq)]
32pub enum RetryDecision {
33    Retry {
34        reason: RetryReason,
35        retry_after: Option<Duration>,
36    },
37    DoNotRetry,
38}
39
40/// Metadata for one scheduled retry.
41#[derive(Debug, Clone, Copy, PartialEq, Eq)]
42pub struct RetryAttempt {
43    pub attempt: u32,
44    pub delay: Duration,
45    pub reason: RetryReason,
46}
47
48/// Parse `Retry-After` header value as a delay.
49///
50/// Supports:
51/// - Delta-seconds (`Retry-After: 5`)
52/// - HTTP-date (`Retry-After: Wed, 21 Oct 2015 07:28:00 GMT`)
53pub fn parse_retry_after(headers: &reqwest::header::HeaderMap) -> Option<Duration> {
54    let raw = headers
55        .get(reqwest::header::RETRY_AFTER)?
56        .to_str()
57        .ok()?
58        .trim();
59
60    if let Ok(secs) = raw.parse::<u64>() {
61        return Some(Duration::from_secs(secs));
62    }
63
64    if let Ok(when) = httpdate::parse_http_date(raw) {
65        let now = std::time::SystemTime::now();
66        if let Ok(delay) = when.duration_since(now) {
67            return Some(delay);
68        }
69        return Some(Duration::from_secs(0));
70    }
71
72    None
73}
74
75/// Classify reqwest result into retry/no-retry.
76pub fn classify_reqwest_result(
77    result: &std::result::Result<reqwest::Response, reqwest::Error>,
78) -> RetryDecision {
79    match result {
80        Ok(resp) => {
81            let status = resp.status();
82            if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
83                return RetryDecision::Retry {
84                    reason: RetryReason::RateLimited,
85                    retry_after: parse_retry_after(resp.headers()),
86                };
87            }
88            if status == reqwest::StatusCode::REQUEST_TIMEOUT {
89                return RetryDecision::Retry {
90                    reason: RetryReason::RequestTimeout,
91                    retry_after: parse_retry_after(resp.headers()),
92                };
93            }
94            if status.is_server_error() {
95                return RetryDecision::Retry {
96                    reason: RetryReason::ServerError,
97                    retry_after: parse_retry_after(resp.headers()),
98                };
99            }
100            RetryDecision::DoNotRetry
101        }
102        Err(err) => {
103            if err.is_timeout() {
104                return RetryDecision::Retry {
105                    reason: RetryReason::Timeout,
106                    retry_after: None,
107                };
108            }
109            if err.is_connect() || err.is_request() {
110                return RetryDecision::Retry {
111                    reason: RetryReason::Connect,
112                    retry_after: None,
113                };
114            }
115            RetryDecision::DoNotRetry
116        }
117    }
118}
119
120/// Retry an async operation with backoff according to `policy`.
121///
122/// - `operation(attempt)` is called with a 1-based attempt number.
123/// - `classify(result)` decides whether to retry.
124/// - `on_retry(info)` is called right before sleeping.
125pub async fn retry_with_backoff<T, E, Op, Fut, Classify, OnRetry>(
126    policy: &RetryPolicy,
127    mut operation: Op,
128    mut classify: Classify,
129    mut on_retry: OnRetry,
130) -> std::result::Result<T, E>
131where
132    Op: FnMut(u32) -> Fut,
133    Fut: Future<Output = std::result::Result<T, E>>,
134    Classify: FnMut(&std::result::Result<T, E>) -> RetryDecision,
135    OnRetry: FnMut(RetryAttempt),
136{
137    let max_attempts = policy.max_attempts.max(1);
138
139    for attempt in 1..=max_attempts {
140        let result = operation(attempt).await;
141        let decision = if attempt < max_attempts {
142            classify(&result)
143        } else {
144            RetryDecision::DoNotRetry
145        };
146
147        match (decision, result) {
148            (
149                RetryDecision::Retry {
150                    reason,
151                    retry_after,
152                },
153                Err(err),
154            ) => {
155                let backoff = policy.backoff_delay(attempt);
156                let base_delay = retry_after.unwrap_or(backoff);
157                let delay = policy.with_jitter(base_delay);
158                on_retry(RetryAttempt {
159                    attempt,
160                    delay,
161                    reason,
162                });
163                tokio::time::sleep(delay).await;
164                let _ = err;
165            }
166            (
167                RetryDecision::Retry {
168                    reason,
169                    retry_after,
170                },
171                Ok(_),
172            ) => {
173                let backoff = policy.backoff_delay(attempt);
174                let base_delay = retry_after.unwrap_or(backoff);
175                let delay = policy.with_jitter(base_delay);
176                on_retry(RetryAttempt {
177                    attempt,
178                    delay,
179                    reason,
180                });
181                tokio::time::sleep(delay).await;
182            }
183            (_, final_result) => return final_result,
184        }
185    }
186
187    unreachable!("retry loop always returns");
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::{HeaderMap, HeaderValue, RETRY_AFTER};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Policy with zero base/max delay and no jitter, so backoff sleeps are expected
    /// to be negligible in tests.
    fn instant_policy(max_attempts: u32) -> RetryPolicy {
        RetryPolicy {
            max_attempts,
            base_delay: Duration::from_millis(0),
            max_delay: Duration::from_millis(0),
            jitter_ratio: 0.0,
        }
    }

    /// Header map containing only a `Retry-After` header with the given value.
    fn headers_with_retry_after(value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(RETRY_AFTER, HeaderValue::from_static(value));
        headers
    }

    #[test]
    fn parse_retry_after_delta_seconds() {
        let headers = headers_with_retry_after("7");
        assert_eq!(parse_retry_after(&headers), Some(Duration::from_secs(7)));
    }

    #[test]
    fn parse_retry_after_http_date_in_past_is_zero() {
        let headers = headers_with_retry_after("Wed, 21 Oct 2015 07:28:00 GMT");
        assert_eq!(parse_retry_after(&headers), Some(Duration::ZERO));
    }

    #[test]
    fn parse_retry_after_missing_or_malformed_is_none() {
        assert_eq!(parse_retry_after(&HeaderMap::new()), None);
        assert_eq!(parse_retry_after(&headers_with_retry_after("soon")), None);
    }

    #[tokio::test]
    async fn retry_helper_retries_until_success() {
        let policy = instant_policy(3);
        let attempts = Arc::new(AtomicU32::new(0));
        let seen = attempts.clone();

        let result = retry_with_backoff(
            &policy,
            move |_attempt| {
                let seen = seen.clone();
                async move {
                    let n = seen.fetch_add(1, Ordering::SeqCst) + 1;
                    if n < 3 { Err("transient") } else { Ok("ok") }
                }
            },
            |r: &std::result::Result<&str, &str>| match r {
                Err(_) => RetryDecision::Retry {
                    reason: RetryReason::Connect,
                    retry_after: None,
                },
                Ok(_) => RetryDecision::DoNotRetry,
            },
            |_info| {},
        )
        .await;

        assert_eq!(result, Ok("ok"));
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retry_helper_returns_last_result_when_attempts_exhausted() {
        let policy = instant_policy(3);
        let calls = Arc::new(AtomicU32::new(0));
        let counter = calls.clone();
        let mut classified = 0u32;
        let mut retries = Vec::new();

        let result = retry_with_backoff(
            &policy,
            move |attempt| {
                let counter = counter.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<(), &str>(if attempt == 3 { "last" } else { "transient" })
                }
            },
            |_r: &std::result::Result<(), &str>| {
                classified += 1;
                RetryDecision::Retry {
                    reason: RetryReason::Timeout,
                    retry_after: None,
                }
            },
            |info| retries.push(info),
        )
        .await;

        assert_eq!(result, Err("last"));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
        // The final attempt is returned without consulting the classifier.
        assert_eq!(classified, 2);
        let retried: Vec<(u32, RetryReason)> =
            retries.iter().map(|r| (r.attempt, r.reason)).collect();
        assert_eq!(
            retried,
            vec![(1, RetryReason::Timeout), (2, RetryReason::Timeout)]
        );
    }

    #[tokio::test]
    async fn retry_helper_stops_when_classifier_declines() {
        let policy = instant_policy(5);
        let calls = Arc::new(AtomicU32::new(0));
        let counter = calls.clone();
        let mut retries = 0u32;

        let result = retry_with_backoff(
            &policy,
            move |_attempt| {
                let counter = counter.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<(), &str>("fatal")
                }
            },
            |_r: &std::result::Result<(), &str>| RetryDecision::DoNotRetry,
            |_info| retries += 1,
        )
        .await;

        assert_eq!(result, Err("fatal"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(retries, 0);
    }

    #[tokio::test]
    async fn retry_helper_runs_once_when_max_attempts_is_zero() {
        let policy = instant_policy(0);
        let calls = Arc::new(AtomicU32::new(0));
        let counter = calls.clone();

        let result = retry_with_backoff(
            &policy,
            move |_attempt| {
                let counter = counter.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<(), &str>("boom")
                }
            },
            |_r: &std::result::Result<(), &str>| RetryDecision::Retry {
                reason: RetryReason::Connect,
                retry_after: None,
            },
            |_info| {},
        )
        .await;

        assert_eq!(result, Err("boom"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}