surf_retry/
lib.rs

#![forbid(unsafe_code, future_incompatible)]
#![deny(
    missing_docs,
    missing_debug_implementations,
    missing_copy_implementations,
    nonstandard_style,
    unused_qualifications,
    unused_import_braces,
    unused_extern_crates,
    trivial_casts,
    trivial_numeric_casts
)]
#![cfg_attr(docsrs, feature(doc_cfg))]
//! A [surf] middleware that retries requests rejected with `429 Too Many Requests` or
//! `408 Request Timeout`.
//!
//! # Example
//! ```no_run
//! use surf_retry::{ExponentialBackoff, RetryMiddleware};
//! use surf_governor::GovernorMiddleware;
//! use surf::{Client, Request, http::Method};
//! use url::Url;
//!
//! #[async_std::main]
//! async fn main() -> surf::Result<()> {
//!     let req = Request::new(Method::Get, Url::parse("https://example.api")?);
//!     // Construct the retry middleware with at most 3 retries, an exponential backoff
//!     // policy that is also capped at 3 retries, and a fallback interval of 1 second.
//!     let retry = RetryMiddleware::new(
//!         3,
//!         ExponentialBackoff::builder().build_with_max_retries(3),
//!         1,
//!     );
//!     // Construct a surf client with the retry middleware and a limit of 1 request per
//!     // second. Requests sent faster than that are rejected with 429 by the governor,
//!     // and the retry middleware then waits and resends them.
//!     let client = Client::new().with(retry).with(GovernorMiddleware::per_second(1)?);
//!     let _res = client.send(req).await?;
//!     Ok(())
//! }
//! ```
37use async_std::task;
38use chrono::Utc;
39use httpdate::parse_http_date;
40pub use retry_policies::{policies::ExponentialBackoff, RetryPolicy};
41use std::time::{Duration, SystemTime};
42use surf::{
43    http::{headers, StatusCode},
44    middleware::{Middleware, Next},
45    Client, Request, Response, Result,
46};
47
/// A [surf] middleware that resends requests answered with a retryable status.
///
/// `max_retries` caps how many times a request is resent after the initial attempt, so at
/// most `max_retries + 1` requests are made in total, whether or not the server sends a
/// [`Retry-After`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After)
/// header.
///
/// Before each retry the middleware waits for the interval given by the response's
/// `Retry-After` header. If no usable `Retry-After` header is present, the configured
/// [policy](https://docs.rs/retry-policies) is asked for the interval instead.
///
/// Should conditions for a retry be met but no retry interval can be determined, the
/// provided `fallback_interval` (in seconds) is used.
#[derive(Debug)]
pub struct RetryMiddleware<T> {
    /// Maximum number of retries performed after the initial request.
    max_retries: u32,
    /// Policy consulted for the wait interval when `Retry-After` is absent or unusable.
    policy: T,
    /// Wait, in seconds, used when neither the header nor the policy yields an interval.
    fallback_interval: u64,
}
61
62impl Default for RetryMiddleware<ExponentialBackoff> {
63    fn default() -> Self {
64        Self::new(
65            3,
66            ExponentialBackoff::builder().build_with_max_retries(3),
67            1,
68        )
69    }
70}
71
72impl<T: RetryPolicy + Send + Sync + 'static> RetryMiddleware<T> {
73    /// Construct the retry middleware with provided options.
74    pub fn new(max_retries: u32, policy: T, fallback_interval: u64) -> Self {
75        Self {
76            max_retries,
77            policy,
78            fallback_interval,
79        }
80    }
81
82    fn use_policy(&self, retry_count: u32) -> u64 {
83        let should_retry = self.policy.should_retry(retry_count);
84        if let retry_policies::RetryDecision::Retry { execute_after } = should_retry {
85            match (execute_after - Utc::now()).to_std() {
86                Ok(duration) => duration.as_secs(),
87                Err(_) => self.fallback_interval,
88            }
89        } else {
90            self.fallback_interval
91        }
92    }
93}
94
95const RETRY_CODES: &[StatusCode] = &[StatusCode::TooManyRequests, StatusCode::RequestTimeout];
96
97fn retry_to_seconds(header: &headers::HeaderValue) -> Result<u64> {
98    let mut secs = match header.as_str().parse::<u64>() {
99        Ok(s) => s,
100        Err(_) => {
101            let date = parse_http_date(header.as_str())?;
102            let sys_time = SystemTime::now();
103            let difference = date.duration_since(sys_time)?;
104            difference.as_secs()
105        }
106    };
107    if secs < 1 {
108        secs = 1;
109    }
110    Ok(secs)
111}
112
113#[surf::utils::async_trait]
114impl<T: RetryPolicy + Send + Sync + 'static> Middleware for RetryMiddleware<T> {
115    async fn handle(&self, mut req: Request, client: Client, next: Next<'_>) -> Result<Response> {
116        let mut retries: u32 = 0;
117
118        let mut r: Request = req.clone();
119        let request_body = req.take_body().into_bytes().await?;
120        r.set_body(request_body.clone());
121
122        let res = next.run(r, client.clone()).await?;
123        if RETRY_CODES.contains(&res.status()) {
124            while retries < self.max_retries {
125                retries += 1;
126
127                let secs: u64;
128                if let Some(retry_after) = res.header(headers::RETRY_AFTER) {
129                    match retry_to_seconds(retry_after) {
130                        Ok(s) => {
131                            secs = s;
132                        }
133                        Err(_e) => {
134                            secs = self.use_policy(retries);
135                        }
136                    };
137                } else {
138                    secs = self.use_policy(retries);
139                };
140
141                task::sleep(Duration::from_secs(secs)).await;
142
143                let mut r: Request = req.clone();
144                r.set_body(request_body.clone());
145
146                let res = next.run(r, client.clone()).await?;
147                if !RETRY_CODES.contains(&res.status()) {
148                    return Ok(res);
149                }
150            }
151        }
152        Ok(res)
153    }
154}
155
#[cfg(test)]
mod tests {
    use crate::*;
    use surf::{http::Method, Client, Request};
    use surf_governor::GovernorMiddleware;
    use url::Url;
    use wiremock::{matchers::method, Mock, MockServer, ResponseTemplate};

    /// A second request inside the governor's one-second window is rejected with 429.
    /// The retry middleware must wait and resend it, so the mock server sees exactly two
    /// GET requests and both sends end in 200.
    #[async_std::test]
    async fn will_retry_request() -> Result<()> {
        let server = MockServer::start().await;
        let hello = ResponseTemplate::new(200).set_body_string("Hello!".to_string());
        let mock = Mock::given(method("GET")).respond_with(hello).expect(2);
        let _guard = server.register_as_scoped(mock).await;

        let endpoint = Url::parse(&format!("{}/", &server.uri())).unwrap();
        let request = Request::new(Method::Get, endpoint);
        let client = Client::new()
            .with(RetryMiddleware::default())
            .with(GovernorMiddleware::per_second(1)?);

        let first = client.send(request.clone()).await?;
        assert_eq!(first.status(), 200);

        let second = client.send(request).await?;
        assert_eq!(second.status(), 200);

        Ok(())
    }
}