tencent_sdk/middleware/
retry_async.rs1use crate::{core::TencentCloudError, transport::async_impl::AsyncTransport};
2use async_trait::async_trait;
3use fastrand;
4use http::{Method, StatusCode};
5use std::{collections::HashMap, time::Duration};
6use tokio::time::sleep;
7use url::Url;
8
/// Middleware that retries requests sent through an inner [`AsyncTransport`].
///
/// A request is re-sent, at most `max` additional times, when the inner
/// transport answers with a 5xx status or fails with a
/// `TencentCloudError::Transport` error. The wait between attempts grows
/// exponentially from `base_delay`, with random jitter.
#[derive(Clone, Debug)]
pub struct RetryAsync<T> {
    /// Transport that actually performs each request.
    inner: T,
    /// Maximum number of retries after the initial attempt.
    max: usize,
    /// Nominal backoff before the first retry; doubled for each later retry.
    base_delay: Duration,
}

16impl<T> RetryAsync<T> {
17 pub fn new(inner: T, max: usize, base_delay: Duration) -> Self {
18 Self {
19 inner,
20 max,
21 base_delay,
22 }
23 }
24
25 fn delay_for(&self, attempt: usize) -> Duration {
26 if attempt == 0 {
27 Duration::from_secs(0)
28 } else {
29 let pow = 2f64.powi((attempt - 1) as i32);
30 let base = self.base_delay.mul_f64(pow);
31 let jitter = 0.5 + fastrand::f64();
32 base.mul_f64(jitter)
33 }
34 }
35}
36
37#[async_trait]
38impl<T: AsyncTransport> AsyncTransport for RetryAsync<T> {
39 async fn send(
40 &self,
41 method: Method,
42 url: Url,
43 headers: HashMap<String, String>,
44 body: Option<String>,
45 timeout: Duration,
46 ) -> Result<(StatusCode, String), TencentCloudError> {
47 let mut attempt = 0usize;
48 loop {
49 match self
50 .inner
51 .send(
52 method.clone(),
53 url.clone(),
54 headers.clone(),
55 body.clone(),
56 timeout,
57 )
58 .await
59 {
60 Ok((status, payload)) => {
61 if status.is_server_error() && attempt < self.max {
62 attempt += 1;
63 let delay = self.delay_for(attempt);
64 if !delay.is_zero() {
65 sleep(delay).await;
66 }
67 continue;
68 }
69 return Ok((status, payload));
70 }
71 Err(err) => {
72 let should_retry =
73 attempt < self.max && matches!(err, TencentCloudError::Transport { .. });
74
75 if should_retry {
76 attempt += 1;
77 let delay = self.delay_for(attempt);
78 if !delay.is_zero() {
79 sleep(delay).await;
80 }
81 continue;
82 }
83
84 return Err(err);
85 }
86 }
87 }
88 }
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use http::Method;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use tokio::task;

    /// Fails with a `TencentCloudError::Transport` error for the first
    /// `fail_times` calls, then answers `200 OK` with `"{}"`.
    #[derive(Clone)]
    struct FlakyAsyncTransport {
        /// Number of `send` calls observed so far (shared between clones).
        attempts: Arc<AtomicUsize>,
        fail_times: usize,
    }

    impl FlakyAsyncTransport {
        fn new(fail_times: usize) -> Self {
            Self {
                attempts: Arc::new(AtomicUsize::new(0)),
                fail_times,
            }
        }
    }

    #[async_trait]
    impl AsyncTransport for FlakyAsyncTransport {
        async fn send(
            &self,
            method: Method,
            url: Url,
            _headers: HashMap<String, String>,
            _body: Option<String>,
            _timeout: Duration,
        ) -> Result<(StatusCode, String), TencentCloudError> {
            let current = self.attempts.fetch_add(1, Ordering::SeqCst);
            if current < self.fail_times {
                // `make_transport_error` uses reqwest's blocking client. reqwest
                // documents that the blocking client must not be used from
                // inside an async runtime, so the error is built on the
                // blocking pool instead.
                let error = task::spawn_blocking(move || make_transport_error(method, url))
                    .await
                    .expect("spawn blocking for transport error");
                return Err(error);
            }

            Ok((StatusCode::OK, "{}".to_string()))
        }
    }

    /// Answers `503 Service Unavailable` for the first `fail_times` calls,
    /// then `200 OK` with `"{}"`.
    #[derive(Clone)]
    struct UnavailableAsyncTransport {
        /// Number of `send` calls observed so far (shared between clones).
        attempts: Arc<AtomicUsize>,
        fail_times: usize,
    }

    impl UnavailableAsyncTransport {
        fn new(fail_times: usize) -> Self {
            Self {
                attempts: Arc::new(AtomicUsize::new(0)),
                fail_times,
            }
        }
    }

    #[async_trait]
    impl AsyncTransport for UnavailableAsyncTransport {
        async fn send(
            &self,
            _method: Method,
            _url: Url,
            _headers: HashMap<String, String>,
            _body: Option<String>,
            _timeout: Duration,
        ) -> Result<(StatusCode, String), TencentCloudError> {
            let current = self.attempts.fetch_add(1, Ordering::SeqCst);
            if current < self.fail_times {
                return Ok((StatusCode::SERVICE_UNAVAILABLE, "unavailable".to_string()));
            }

            Ok((StatusCode::OK, "{}".to_string()))
        }
    }

    /// Sends a fixed POST request through `retry`.
    async fn post<T: AsyncTransport>(
        retry: &RetryAsync<T>,
    ) -> Result<(StatusCode, String), TencentCloudError> {
        retry
            .send(
                Method::POST,
                Url::parse("https://example.com").unwrap(),
                HashMap::new(),
                None,
                Duration::from_secs(1),
            )
            .await
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn retries_transport_errors_before_succeeding() {
        let transport = FlakyAsyncTransport::new(2);
        let retry = RetryAsync::new(transport.clone(), 3, Duration::from_millis(1));
        let result = post(&retry).await;

        assert!(
            result.is_ok(),
            "expected retry to eventually succeed: {result:?}"
        );
        assert_eq!(
            transport.attempts.load(Ordering::SeqCst),
            3,
            "expected two retries plus final success"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn gives_up_after_max_retries_on_transport_errors() {
        let transport = FlakyAsyncTransport::new(usize::MAX);
        let retry = RetryAsync::new(transport.clone(), 2, Duration::from_millis(1));
        let result = post(&retry).await;

        assert!(
            matches!(result, Err(TencentCloudError::Transport { .. })),
            "expected the last transport error to be returned: {result:?}"
        );
        assert_eq!(
            transport.attempts.load(Ordering::SeqCst),
            3,
            "expected the initial attempt plus two retries"
        );
    }

    #[tokio::test]
    async fn retries_server_errors_before_succeeding() {
        let transport = UnavailableAsyncTransport::new(2);
        let retry = RetryAsync::new(transport.clone(), 3, Duration::from_millis(1));
        let (status, _) = post(&retry)
            .await
            .expect("5xx responses should be retried until success");

        assert_eq!(status, StatusCode::OK);
        assert_eq!(
            transport.attempts.load(Ordering::SeqCst),
            3,
            "expected two retries plus final success"
        );
    }

    #[tokio::test]
    async fn returns_last_server_error_when_retries_are_exhausted() {
        let transport = UnavailableAsyncTransport::new(usize::MAX);
        let retry = RetryAsync::new(transport.clone(), 1, Duration::from_millis(1));
        let (status, _) = post(&retry)
            .await
            .expect("a final 5xx response is still an Ok transport result");

        assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(
            transport.attempts.load(Ordering::SeqCst),
            2,
            "expected the initial attempt plus one retry"
        );
    }

    #[test]
    fn delay_grows_exponentially_within_jitter_bounds() {
        let retry = RetryAsync::new((), 5, Duration::from_millis(10));
        assert_eq!(retry.delay_for(0), Duration::ZERO);

        for (attempt, nominal_ms) in [(1, 10u64), (2, 20), (3, 40)] {
            let delay = retry.delay_for(attempt);
            let nominal = Duration::from_millis(nominal_ms);
            // Jitter scales the nominal delay by a factor in [0.5, 1.5).
            assert!(
                delay >= nominal / 2 && delay <= nominal * 3 / 2,
                "attempt {attempt}: {delay:?} outside jitter bounds of {nominal:?}"
            );
        }
    }

    /// Builds a `TencentCloudError::Transport` without touching the network.
    ///
    /// `"\n"` is not a valid header name, so reqwest's request builder fails
    /// in `build()` before any I/O happens.
    fn make_transport_error(method: Method, url: Url) -> TencentCloudError {
        let client = reqwest::blocking::Client::builder()
            .build()
            .expect("build test reqwest client");

        let error = client
            .get("http://example.com")
            .header("\n", "value")
            .build()
            .expect_err("invalid header should fail before network IO");

        TencentCloudError::transport(error, method, url)
    }
}
175}