smarty_rust_sdk/sdk/retry_strategy.rs

1use std::time::Duration;
2
3use hyper::header::RETRY_AFTER;
4use log::warn;
5use reqwest::{Request, Response, StatusCode};
6use reqwest_middleware::{Error, Middleware};
7
/// Ceiling, in seconds, for the linearly growing pause between transient retries.
/// Rate-limited retries wait for the server-provided delay instead and ignore this cap.
const MAX_RETRY_DURATION: u64 = 10;
9
/// Outcome of classifying a single request attempt, used to decide whether to re-send it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RetryResult {
    /// A temporary failure: retry after a linearly growing backoff.
    Transient,
    /// The server rate-limited the request: retry after the contained delay.
    RateLimit(Duration),
    /// Re-sending cannot help: hand the outcome back to the caller unchanged.
    Fatal,
    /// Nothing to retry: hand the outcome back to the caller unchanged. This is also
    /// returned for transport errors that the failure classifier does not recognise.
    Success,
}
16
/// `reqwest_middleware` layer that re-sends requests which fail transiently or are
/// rate-limited by the server.
#[derive(Debug, Clone)]
pub struct SmartyRetryMiddleware {
    /// Maximum number of retries after the initial attempt (a budget, not a running count).
    pub retry_count: u64,
}
20
21impl SmartyRetryMiddleware {
22    pub fn new(max_retries: u64) -> Self {
23        Self {
24            retry_count: max_retries,
25        }
26    }
27}
28
29impl Default for SmartyRetryMiddleware {
30    fn default() -> Self {
31        Self::new(10)
32    }
33}
34
35#[async_trait::async_trait]
36impl Middleware for SmartyRetryMiddleware {
37    async fn handle(
38        &self,
39        req: Request,
40        extensions: &mut task_local_extensions::Extensions,
41        next: reqwest_middleware::Next<'_>,
42    ) -> reqwest_middleware::Result<Response> {
43        self.handle_retry(req, extensions, next).await
44    }
45}
46
47impl SmartyRetryMiddleware {
48    async fn handle_retry<'a>(
49        &'a self,
50        req: Request,
51        extensions: &'a mut task_local_extensions::Extensions,
52        next: reqwest_middleware::Next<'a>,
53    ) -> reqwest_middleware::Result<Response> {
54        let mut cur_retries = 0;
55        loop {
56            let duplicate_request = req.try_clone().ok_or_else(|| {
57                Error::Middleware(anyhow!(
58                    "Request object is not cloneable. Are you passing a streaming body?"
59                        .to_string()
60                ))
61            })?;
62
63            let res = next.clone().run(duplicate_request, extensions).await;
64
65            let retry = match &res {
66                Ok(res) => retry_success(res),
67                Err(err) => retry_failure(err),
68            };
69
70            if cur_retries >= self.retry_count {
71                return res;
72            }
73
74            break match retry {
75                RetryResult::Transient => {
76                    cur_retries += 1;
77
78                    warn!(
79                        "Retry Attempt #{}, Sleeping {} seconds before the next attempt",
80                        cur_retries,
81                        cur_retries.min(MAX_RETRY_DURATION)
82                    );
83                    tokio::time::sleep(Duration::from_secs(cur_retries.min(MAX_RETRY_DURATION)))
84                        .await;
85
86                    continue;
87                }
88                RetryResult::RateLimit(time) => {
89                    cur_retries += 1;
90                    warn!(
91                        "Retry Attempt #{} resulted in rate limit. Waiting for {}",
92                        cur_retries,
93                        time.as_secs()
94                    );
95
96                    tokio::time::sleep(time).await;
97
98                    continue;
99                }
100                _ => res,
101            };
102        }
103    }
104}
105
106fn retry_success(res: &Response) -> RetryResult {
107    let status = res.status();
108
109    if status.is_success() {
110        return RetryResult::Success;
111    }
112
113    match status {
114        StatusCode::REQUEST_TIMEOUT
115        | StatusCode::INTERNAL_SERVER_ERROR
116        | StatusCode::BAD_GATEWAY
117        | StatusCode::SERVICE_UNAVAILABLE
118        | StatusCode::GATEWAY_TIMEOUT => RetryResult::Transient,
119        StatusCode::TOO_MANY_REQUESTS => {
120            return match res.headers().get(RETRY_AFTER) {
121                Some(time) => {
122                    if let Ok(time) = time.to_str() {
123                        if let Ok(time) = time.parse::<u64>() {
124                            RetryResult::RateLimit(Duration::from_secs(time))
125                        } else {
126                            warn!(
127                                "Server Returned Too Many Requests Status Code, but the RETRY_AFTER header was unable to be parsed"
128                            );
129                            RetryResult::Transient
130                        }
131                    } else {
132                        warn!("Server Returned Too Many Requests Status Code, but the RETRY_AFTER header was unable to be turned into a valid utf-8 string");
133                        RetryResult::Transient
134                    }
135                }
136                _ => {
137                    warn!("Server Returned Too Many Requests Status Code, but the RETRY_AFTER header was non-existent");
138                    RetryResult::Transient
139                }
140            }
141        }
142        _ => {
143            // Fatal
144            RetryResult::Fatal
145        }
146    }
147}
148
149fn retry_failure(error: &reqwest_middleware::Error) -> RetryResult {
150    match error {
151        // If something fails in the middleware we're screwed.
152        Error::Middleware(_) => RetryResult::Fatal,
153        Error::Reqwest(error) => {
154            #[cfg(not(target_arch = "wasm32"))]
155            let is_connect = error.is_connect();
156            #[cfg(target_arch = "wasm32")]
157            let is_connect = false;
158            if error.is_body()
159                || error.is_decode()
160                || error.is_builder()
161                || error.is_redirect()
162                || error.is_timeout()
163                || is_connect
164            {
165                RetryResult::Fatal
166            } else if error.is_request() {
167                // It seems that hyper::Error(IncompleteMessage) is not correctly handled by reqwest.
168                // Here we check if the Reqwest error was originated by hyper and map it consistently.
169                #[cfg(not(target_arch = "wasm32"))]
170                if let Some(hyper_error) = get_source_error_type::<hyper::Error>(&error) {
171                    // The hyper::Error(IncompleteMessage) is raised if the HTTP response is well formatted but does not contain all the bytes.
172                    // This can happen when the server has started sending back the response but the connection is cut halfway thorugh.
173                    // We can safely retry the call, hence marking this error as [`Retryable::Transient`].
174                    // Instead hyper::Error(Canceled) is raised when the connection is
175                    // gracefully closed on the server side.
176                    if hyper_error.is_incomplete_message() || hyper_error.is_canceled() {
177                        RetryResult::Transient
178
179                    // Try and downcast the hyper error to io::Error if that is the
180                    // underlying error, and try and classify it.
181                    } else if let Some(io_error) =
182                        get_source_error_type::<std::io::Error>(hyper_error)
183                    {
184                        classify_io_error(io_error)
185                    } else {
186                        RetryResult::Fatal
187                    }
188                } else {
189                    RetryResult::Fatal
190                }
191                #[cfg(target_arch = "wasm32")]
192                RetryResult::Fatal
193            } else {
194                // We omit checking if error.is_status() since we check that already.
195                // However, if Response::error_for_status is used the status will still
196                // remain in the response object.
197                RetryResult::Success
198            }
199        }
200    }
201}
202
203#[cfg(not(target_arch = "wasm32"))]
204fn classify_io_error(error: &std::io::Error) -> RetryResult {
205    match error.kind() {
206        std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::ConnectionAborted => {
207            RetryResult::Transient
208        }
209        _ => RetryResult::Fatal,
210    }
211}
212
/// Walks the `source()` chain of `err`, starting with its immediate cause rather than
/// `err` itself, and returns the first error in that chain whose concrete type is `T`.
#[cfg(not(target_arch = "wasm32"))]
fn get_source_error_type<T: std::error::Error + 'static>(
    err: &dyn std::error::Error,
) -> Option<&T> {
    std::iter::successors(err.source(), |&cause| cause.source())
        .find_map(|cause| cause.downcast_ref::<T>())
}