smarty_rust_sdk/sdk/
retry_strategy.rs1use std::time::Duration;
2
3use hyper::header::RETRY_AFTER;
4use log::warn;
5use reqwest::{Request, Response, StatusCode};
6use reqwest_middleware::{Error, Middleware};
7
/// Upper bound, in seconds, on the linear back-off slept between retries of a
/// transient failure (retry `n` waits `min(n, MAX_RETRY_DURATION)` seconds).
const MAX_RETRY_DURATION: u64 = 10;
9
/// Outcome of classifying a single request attempt; the retry loop uses it to
/// decide whether to sleep and re-send or to hand the outcome back.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RetryResult {
    /// Temporary failure: retry after a linear back-off.
    Transient,
    /// The server rate-limited the request (HTTP 429) and supplied, via the
    /// `Retry-After` header, how long to wait before retrying.
    RateLimit(Duration),
    /// Non-retryable failure: the outcome is returned to the caller as-is.
    Fatal,
    /// The attempt succeeded; no retry is needed.
    Success,
}
16
/// `reqwest_middleware` middleware that re-sends requests which fail with a
/// transient error or are rate limited by the server.
#[derive(Debug, Clone)]
pub struct SmartyRetryMiddleware {
    /// Maximum number of retries made after the initial attempt; `0` means the
    /// request is sent exactly once and never retried.
    pub retry_count: u64,
}
20
21impl SmartyRetryMiddleware {
22 pub fn new(max_retries: u64) -> Self {
23 Self {
24 retry_count: max_retries,
25 }
26 }
27}
28
29impl Default for SmartyRetryMiddleware {
30 fn default() -> Self {
31 Self::new(10)
32 }
33}
34
35#[async_trait::async_trait]
36impl Middleware for SmartyRetryMiddleware {
37 async fn handle(
38 &self,
39 req: Request,
40 extensions: &mut task_local_extensions::Extensions,
41 next: reqwest_middleware::Next<'_>,
42 ) -> reqwest_middleware::Result<Response> {
43 self.handle_retry(req, extensions, next).await
44 }
45}
46
47impl SmartyRetryMiddleware {
48 async fn handle_retry<'a>(
49 &'a self,
50 req: Request,
51 extensions: &'a mut task_local_extensions::Extensions,
52 next: reqwest_middleware::Next<'a>,
53 ) -> reqwest_middleware::Result<Response> {
54 let mut cur_retries = 0;
55 loop {
56 let duplicate_request = req.try_clone().ok_or_else(|| {
57 Error::Middleware(anyhow!(
58 "Request object is not cloneable. Are you passing a streaming body?"
59 .to_string()
60 ))
61 })?;
62
63 let res = next.clone().run(duplicate_request, extensions).await;
64
65 let retry = match &res {
66 Ok(res) => retry_success(res),
67 Err(err) => retry_failure(err),
68 };
69
70 if cur_retries >= self.retry_count {
71 return res;
72 }
73
74 break match retry {
75 RetryResult::Transient => {
76 cur_retries += 1;
77
78 warn!(
79 "Retry Attempt #{}, Sleeping {} seconds before the next attempt",
80 cur_retries,
81 cur_retries.min(MAX_RETRY_DURATION)
82 );
83 tokio::time::sleep(Duration::from_secs(cur_retries.min(MAX_RETRY_DURATION)))
84 .await;
85
86 continue;
87 }
88 RetryResult::RateLimit(time) => {
89 cur_retries += 1;
90 warn!(
91 "Retry Attempt #{} resulted in rate limit. Waiting for {}",
92 cur_retries,
93 time.as_secs()
94 );
95
96 tokio::time::sleep(time).await;
97
98 continue;
99 }
100 _ => res,
101 };
102 }
103 }
104}
105
106fn retry_success(res: &Response) -> RetryResult {
107 let status = res.status();
108
109 if status.is_success() {
110 return RetryResult::Success;
111 }
112
113 match status {
114 StatusCode::REQUEST_TIMEOUT
115 | StatusCode::INTERNAL_SERVER_ERROR
116 | StatusCode::BAD_GATEWAY
117 | StatusCode::SERVICE_UNAVAILABLE
118 | StatusCode::GATEWAY_TIMEOUT => RetryResult::Transient,
119 StatusCode::TOO_MANY_REQUESTS => {
120 match res.headers().get(RETRY_AFTER) {
121 Some(time) => {
122 if let Ok(time) = time.to_str() {
123 if let Ok(time) = time.parse::<u64>() {
124 RetryResult::RateLimit(Duration::from_secs(time))
125 } else {
126 warn!(
127 "Server Returned Too Many Requests Status Code, but the RETRY_AFTER header was unable to be parsed"
128 );
129 RetryResult::Transient
130 }
131 } else {
132 warn!("Server Returned Too Many Requests Status Code, but the RETRY_AFTER header was unable to be turned into a valid utf-8 string");
133 RetryResult::Transient
134 }
135 }
136 _ => {
137 warn!("Server Returned Too Many Requests Status Code, but the RETRY_AFTER header was non-existent");
138 RetryResult::Transient
139 }
140 }
141 }
142 _ => {
143 RetryResult::Fatal
145 }
146 }
147}
148
149fn retry_failure(error: &reqwest_middleware::Error) -> RetryResult {
150 match error {
151 Error::Middleware(_) => RetryResult::Fatal,
153 Error::Reqwest(error) => {
154 #[cfg(not(target_arch = "wasm32"))]
155 let is_connect = error.is_connect();
156 #[cfg(target_arch = "wasm32")]
157 let is_connect = false;
158 if error.is_body()
159 || error.is_decode()
160 || error.is_builder()
161 || error.is_redirect()
162 || error.is_timeout()
163 || is_connect
164 {
165 RetryResult::Fatal
166 } else if error.is_request() {
167 #[cfg(not(target_arch = "wasm32"))]
170 if let Some(hyper_error) = get_source_error_type::<hyper::Error>(&error) {
171 if hyper_error.is_incomplete_message() || hyper_error.is_canceled() {
177 RetryResult::Transient
178
179 } else if let Some(io_error) =
182 get_source_error_type::<std::io::Error>(hyper_error)
183 {
184 classify_io_error(io_error)
185 } else {
186 RetryResult::Fatal
187 }
188 } else {
189 RetryResult::Fatal
190 }
191 #[cfg(target_arch = "wasm32")]
192 RetryResult::Fatal
193 } else {
194 RetryResult::Success
198 }
199 }
200 }
201}
202
203#[cfg(not(target_arch = "wasm32"))]
204fn classify_io_error(error: &std::io::Error) -> RetryResult {
205 match error.kind() {
206 std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::ConnectionAborted => {
207 RetryResult::Transient
208 }
209 _ => RetryResult::Fatal,
210 }
211}
212
/// Walks the `source()` chain of `err` and returns the first cause that
/// downcasts to `T`, or `None` if no cause has that type.
///
/// `err` itself is not inspected; only the errors beneath it are.
#[cfg(not(target_arch = "wasm32"))]
fn get_source_error_type<T: std::error::Error + 'static>(
    err: &dyn std::error::Error,
) -> Option<&T> {
    std::iter::successors(err.source(), |&cause| cause.source())
        .find_map(|cause| cause.downcast_ref::<T>())
}