// rustyclaw_core/retry/mod.rs

mod policy;

pub use policy::RetryPolicy;

use std::future::Future;
use std::time::Duration;

/// Classification of transient retry causes.
///
/// Each variant has a fixed snake_case label available via [`RetryReason::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetryReason {
    /// Transport failure reported by reqwest as a connect or request error
    /// (as produced by [`classify_reqwest_result`]).
    Connect,
    /// Transport-level timeout (reqwest `is_timeout()`).
    Timeout,
    /// HTTP `429 Too Many Requests`.
    RateLimited,
    /// Any HTTP 5xx response.
    ServerError,
    /// HTTP `408 Request Timeout`.
    RequestTimeout,
}

impl RetryReason {
    /// Returns the snake_case label for this reason, e.g. `"rate_limited"`.
    pub fn as_str(self) -> &'static str {
        match self {
            RetryReason::Connect => "connect",
            RetryReason::Timeout => "timeout",
            RetryReason::RateLimited => "rate_limited",
            RetryReason::ServerError => "server_error",
            RetryReason::RequestTimeout => "request_timeout",
        }
    }
}

30/// Retry decision for one attempt result.
31#[derive(Debug, Clone, Copy, PartialEq, Eq)]
32pub enum RetryDecision {
33    Retry {
34        reason: RetryReason,
35        retry_after: Option<Duration>,
36    },
37    DoNotRetry,
38}
39
40/// Metadata for one scheduled retry.
41#[derive(Debug, Clone, Copy, PartialEq, Eq)]
42pub struct RetryAttempt {
43    pub attempt: u32,
44    pub delay: Duration,
45    pub reason: RetryReason,
46}
47
48/// Parse `Retry-After` header value as a delay.
49///
50/// Supports:
51/// - Delta-seconds (`Retry-After: 5`)
52/// - HTTP-date (`Retry-After: Wed, 21 Oct 2015 07:28:00 GMT`)
53pub fn parse_retry_after(headers: &reqwest::header::HeaderMap) -> Option<Duration> {
54    let raw = headers.get(reqwest::header::RETRY_AFTER)?.to_str().ok()?.trim();
55
56    if let Ok(secs) = raw.parse::<u64>() {
57        return Some(Duration::from_secs(secs));
58    }
59
60    if let Ok(when) = httpdate::parse_http_date(raw) {
61        let now = std::time::SystemTime::now();
62        if let Ok(delay) = when.duration_since(now) {
63            return Some(delay);
64        }
65        return Some(Duration::from_secs(0));
66    }
67
68    None
69}
70
71/// Classify reqwest result into retry/no-retry.
72pub fn classify_reqwest_result(
73    result: &std::result::Result<reqwest::Response, reqwest::Error>,
74) -> RetryDecision {
75    match result {
76        Ok(resp) => {
77            let status = resp.status();
78            if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
79                return RetryDecision::Retry {
80                    reason: RetryReason::RateLimited,
81                    retry_after: parse_retry_after(resp.headers()),
82                };
83            }
84            if status == reqwest::StatusCode::REQUEST_TIMEOUT {
85                return RetryDecision::Retry {
86                    reason: RetryReason::RequestTimeout,
87                    retry_after: parse_retry_after(resp.headers()),
88                };
89            }
90            if status.is_server_error() {
91                return RetryDecision::Retry {
92                    reason: RetryReason::ServerError,
93                    retry_after: parse_retry_after(resp.headers()),
94                };
95            }
96            RetryDecision::DoNotRetry
97        }
98        Err(err) => {
99            if err.is_timeout() {
100                return RetryDecision::Retry {
101                    reason: RetryReason::Timeout,
102                    retry_after: None,
103                };
104            }
105            if err.is_connect() || err.is_request() {
106                return RetryDecision::Retry {
107                    reason: RetryReason::Connect,
108                    retry_after: None,
109                };
110            }
111            RetryDecision::DoNotRetry
112        }
113    }
114}
115
116/// Retry an async operation with backoff according to `policy`.
117///
118/// - `operation(attempt)` is called with a 1-based attempt number.
119/// - `classify(result)` decides whether to retry.
120/// - `on_retry(info)` is called right before sleeping.
121pub async fn retry_with_backoff<T, E, Op, Fut, Classify, OnRetry>(
122    policy: &RetryPolicy,
123    mut operation: Op,
124    mut classify: Classify,
125    mut on_retry: OnRetry,
126) -> std::result::Result<T, E>
127where
128    Op: FnMut(u32) -> Fut,
129    Fut: Future<Output = std::result::Result<T, E>>,
130    Classify: FnMut(&std::result::Result<T, E>) -> RetryDecision,
131    OnRetry: FnMut(RetryAttempt),
132{
133    let max_attempts = policy.max_attempts.max(1);
134
135    for attempt in 1..=max_attempts {
136        let result = operation(attempt).await;
137        let decision = if attempt < max_attempts {
138            classify(&result)
139        } else {
140            RetryDecision::DoNotRetry
141        };
142
143        match (decision, result) {
144            (RetryDecision::Retry { reason, retry_after }, Err(err)) => {
145                let backoff = policy.backoff_delay(attempt);
146                let base_delay = retry_after.unwrap_or(backoff);
147                let delay = policy.with_jitter(base_delay);
148                on_retry(RetryAttempt {
149                    attempt,
150                    delay,
151                    reason,
152                });
153                tokio::time::sleep(delay).await;
154                let _ = err;
155            }
156            (RetryDecision::Retry { reason, retry_after }, Ok(_)) => {
157                let backoff = policy.backoff_delay(attempt);
158                let base_delay = retry_after.unwrap_or(backoff);
159                let delay = policy.with_jitter(base_delay);
160                on_retry(RetryAttempt {
161                    attempt,
162                    delay,
163                    reason,
164                });
165                tokio::time::sleep(delay).await;
166            }
167            (_, final_result) => return final_result,
168        }
169    }
170
171    unreachable!("retry loop always returns");
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::{HeaderMap, HeaderValue, RETRY_AFTER};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Policy with zero delays and no jitter, so retry tests run instantly.
    fn instant_policy(max_attempts: u32) -> RetryPolicy {
        RetryPolicy {
            max_attempts,
            base_delay: Duration::from_millis(0),
            max_delay: Duration::from_millis(0),
            jitter_ratio: 0.0,
        }
    }

    #[test]
    fn parse_retry_after_delta_seconds() {
        let mut headers = HeaderMap::new();
        headers.insert(RETRY_AFTER, HeaderValue::from_static("7"));
        assert_eq!(parse_retry_after(&headers), Some(Duration::from_secs(7)));
    }

    #[test]
    fn parse_retry_after_past_http_date_is_zero() {
        let mut headers = HeaderMap::new();
        headers.insert(
            RETRY_AFTER,
            HeaderValue::from_static("Wed, 21 Oct 2015 07:28:00 GMT"),
        );
        assert_eq!(parse_retry_after(&headers), Some(Duration::from_secs(0)));
    }

    #[test]
    fn parse_retry_after_rejects_missing_and_garbage() {
        assert_eq!(parse_retry_after(&HeaderMap::new()), None);

        let mut headers = HeaderMap::new();
        headers.insert(RETRY_AFTER, HeaderValue::from_static("soon"));
        assert_eq!(parse_retry_after(&headers), None);
    }

    #[tokio::test]
    async fn retry_helper_retries_until_success() {
        let policy = instant_policy(3);
        let attempts = Arc::new(AtomicU32::new(0));
        let seen = attempts.clone();

        let result = retry_with_backoff(
            &policy,
            move |_attempt| {
                let seen = seen.clone();
                async move {
                    let n = seen.fetch_add(1, Ordering::SeqCst) + 1;
                    if n < 3 {
                        Err("transient")
                    } else {
                        Ok("ok")
                    }
                }
            },
            |r: &std::result::Result<&str, &str>| match r {
                Err(_) => RetryDecision::Retry {
                    reason: RetryReason::Connect,
                    retry_after: None,
                },
                Ok(_) => RetryDecision::DoNotRetry,
            },
            |_info| {},
        )
        .await;

        assert_eq!(result, Ok("ok"));
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn retry_helper_returns_last_error_when_attempts_exhausted() {
        let policy = instant_policy(2);
        let calls = Arc::new(AtomicU32::new(0));
        let counter = calls.clone();
        let mut retries = Vec::new();

        let result = retry_with_backoff(
            &policy,
            move |_attempt| {
                let counter = counter.clone();
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err::<(), &str>("still down")
                }
            },
            |_r: &std::result::Result<(), &str>| RetryDecision::Retry {
                reason: RetryReason::Timeout,
                retry_after: None,
            },
            |info: RetryAttempt| retries.push(info),
        )
        .await;

        assert_eq!(result, Err("still down"));
        assert_eq!(calls.load(Ordering::SeqCst), 2);
        // The final attempt is never classified, so only one retry is reported.
        assert_eq!(retries.len(), 1);
        assert_eq!(retries[0].attempt, 1);
        assert_eq!(retries[0].reason, RetryReason::Timeout);
    }

    #[tokio::test]
    async fn retry_helper_stops_on_do_not_retry() {
        let policy = instant_policy(5);
        let mut retried = false;

        let result = retry_with_backoff(
            &policy,
            |attempt| async move { Err::<(), u32>(attempt) },
            |_r: &std::result::Result<(), u32>| RetryDecision::DoNotRetry,
            |_info| retried = true,
        )
        .await;

        assert_eq!(result, Err(1));
        assert!(!retried);
    }
}