tencent_sdk/middleware/retry_async.rs

1use crate::{core::TencentCloudError, transport::async_impl::AsyncTransport};
2use async_trait::async_trait;
3use fastrand;
4use http::{Method, StatusCode};
5use std::{collections::HashMap, time::Duration};
6use tokio::time::sleep;
7use url::Url;
8
/// Middleware that retries transient failures of an inner [`AsyncTransport`].
///
/// A request is re-sent when the inner transport returns a
/// `TencentCloudError::Transport` error or a response with a 5xx status.
/// At most `max` retries are made (so at most `max + 1` requests in total).
/// Before each retry the middleware waits a delay derived from `base_delay`
/// that grows exponentially and is randomly jittered.
#[derive(Clone, Debug)]
pub struct RetryAsync<T> {
    /// Transport that actually performs each request.
    inner: T,
    /// Maximum number of retries after the initial attempt.
    max: usize,
    /// Nominal delay before the first retry; later retries double it.
    base_delay: Duration,
}
15
16impl<T> RetryAsync<T> {
17    pub fn new(inner: T, max: usize, base_delay: Duration) -> Self {
18        Self {
19            inner,
20            max,
21            base_delay,
22        }
23    }
24
25    fn delay_for(&self, attempt: usize) -> Duration {
26        if attempt == 0 {
27            Duration::from_secs(0)
28        } else {
29            let pow = 2f64.powi((attempt - 1) as i32);
30            let base = self.base_delay.mul_f64(pow);
31            let jitter = 0.5 + fastrand::f64();
32            base.mul_f64(jitter)
33        }
34    }
35}
36
37#[async_trait]
38impl<T: AsyncTransport> AsyncTransport for RetryAsync<T> {
39    async fn send(
40        &self,
41        method: Method,
42        url: Url,
43        headers: HashMap<String, String>,
44        body: Option<String>,
45        timeout: Duration,
46    ) -> Result<(StatusCode, String), TencentCloudError> {
47        let mut attempt = 0usize;
48        loop {
49            match self
50                .inner
51                .send(
52                    method.clone(),
53                    url.clone(),
54                    headers.clone(),
55                    body.clone(),
56                    timeout,
57                )
58                .await
59            {
60                Ok((status, payload)) => {
61                    if status.is_server_error() && attempt < self.max {
62                        attempt += 1;
63                        let delay = self.delay_for(attempt);
64                        if !delay.is_zero() {
65                            sleep(delay).await;
66                        }
67                        continue;
68                    }
69                    return Ok((status, payload));
70                }
71                Err(err) => {
72                    let should_retry =
73                        attempt < self.max && matches!(err, TencentCloudError::Transport { .. });
74
75                    if should_retry {
76                        attempt += 1;
77                        let delay = self.delay_for(attempt);
78                        if !delay.is_zero() {
79                            sleep(delay).await;
80                        }
81                        continue;
82                    }
83
84                    return Err(err);
85                }
86            }
87        }
88    }
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use http::Method;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use tokio::task;

    /// How [`FlakyAsyncTransport`] fails during its first `fail_times` calls.
    #[derive(Clone, Copy)]
    enum Failure {
        /// Return a `TencentCloudError::Transport` error.
        Transport,
        /// Return `503 Service Unavailable` with an empty JSON body.
        ServerError,
    }

    /// Test transport that fails its first `fail_times` calls, then answers `200 {}`.
    #[derive(Clone)]
    struct FlakyAsyncTransport {
        /// Shared call counter, so clones handed to `RetryAsync` report back here.
        attempts: Arc<AtomicUsize>,
        fail_times: usize,
        failure: Failure,
    }

    impl FlakyAsyncTransport {
        /// Fails the first `fail_times` calls with transport errors.
        fn new(fail_times: usize) -> Self {
            Self::with_failure(fail_times, Failure::Transport)
        }

        /// Fails the first `fail_times` calls in the given way.
        fn with_failure(fail_times: usize, failure: Failure) -> Self {
            Self {
                attempts: Arc::new(AtomicUsize::new(0)),
                fail_times,
                failure,
            }
        }

        /// Number of times `send` has been called so far.
        fn attempts(&self) -> usize {
            self.attempts.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl AsyncTransport for FlakyAsyncTransport {
        async fn send(
            &self,
            method: Method,
            url: Url,
            _headers: HashMap<String, String>,
            _body: Option<String>,
            _timeout: Duration,
        ) -> Result<(StatusCode, String), TencentCloudError> {
            let current = self.attempts.fetch_add(1, Ordering::SeqCst);
            if current < self.fail_times {
                return match self.failure {
                    Failure::Transport => {
                        // `make_transport_error` uses reqwest's blocking client,
                        // so it runs off the async worker via `spawn_blocking`.
                        let error = task::spawn_blocking(move || make_transport_error(method, url))
                            .await
                            .expect("spawn blocking for transport error");
                        Err(error)
                    }
                    Failure::ServerError => {
                        Ok((StatusCode::SERVICE_UNAVAILABLE, "{}".to_string()))
                    }
                };
            }

            Ok((StatusCode::OK, "{}".to_string()))
        }
    }

    /// Sends a fixed POST request through `retry`.
    async fn send_post(
        retry: &RetryAsync<FlakyAsyncTransport>,
    ) -> Result<(StatusCode, String), TencentCloudError> {
        retry
            .send(
                Method::POST,
                Url::parse("https://example.com").unwrap(),
                HashMap::new(),
                None,
                Duration::from_secs(1),
            )
            .await
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn retries_transport_errors_before_succeeding() {
        let transport = FlakyAsyncTransport::new(2);
        let retry = RetryAsync::new(transport.clone(), 3, Duration::from_millis(1));
        let result = send_post(&retry).await;

        assert!(
            result.is_ok(),
            "expected retry to eventually succeed: {result:?}"
        );
        assert_eq!(
            transport.attempts(),
            3,
            "expected two retries plus final success"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn gives_up_on_transport_errors_after_max_retries() {
        let transport = FlakyAsyncTransport::new(usize::MAX);
        let retry = RetryAsync::new(transport.clone(), 2, Duration::from_millis(1));
        let result = send_post(&retry).await;

        assert!(
            matches!(result, Err(TencentCloudError::Transport { .. })),
            "expected the last transport error once retries are exhausted: {result:?}"
        );
        assert_eq!(
            transport.attempts(),
            3,
            "expected the initial attempt plus two retries"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn retries_server_errors_before_succeeding() {
        let transport = FlakyAsyncTransport::with_failure(2, Failure::ServerError);
        let retry = RetryAsync::new(transport.clone(), 3, Duration::from_millis(1));
        let result = send_post(&retry).await;

        match result {
            Ok((status, _)) => assert_eq!(status, StatusCode::OK),
            Err(err) => panic!("expected a successful response, got {err:?}"),
        }
        assert_eq!(
            transport.attempts(),
            3,
            "expected two retries plus final success"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn returns_server_error_once_retries_are_exhausted() {
        let transport = FlakyAsyncTransport::with_failure(usize::MAX, Failure::ServerError);
        let retry = RetryAsync::new(transport.clone(), 2, Duration::from_millis(1));
        let result = send_post(&retry).await;

        match result {
            Ok((status, _)) => assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE),
            Err(err) => panic!("expected the final 5xx response, got {err:?}"),
        }
        assert_eq!(
            transport.attempts(),
            3,
            "expected the initial attempt plus two retries"
        );
    }

    #[test]
    fn delay_is_zero_for_first_attempt_and_jittered_exponential_after() {
        let retry = RetryAsync::new((), 3, Duration::from_secs(1));
        assert_eq!(retry.delay_for(0), Duration::ZERO);

        // Jitter is random, so sample a few times and check the bounds.
        for _ in 0..32 {
            let first = retry.delay_for(1);
            assert!(
                first >= Duration::from_millis(500) && first <= Duration::from_millis(1500),
                "first retry delay out of range: {first:?}"
            );
            let third = retry.delay_for(3);
            assert!(
                third >= Duration::from_secs(2) && third <= Duration::from_secs(6),
                "third retry delay out of range: {third:?}"
            );
        }
    }

    /// Builds a genuine `reqwest::Error` without network I/O.
    ///
    /// An invalid header name makes request building fail locally, and that
    /// error is wrapped as a `TencentCloudError` transport error.
    fn make_transport_error(method: Method, url: Url) -> TencentCloudError {
        let client = reqwest::blocking::Client::builder()
            .build()
            .expect("build test reqwest client");

        let error = client
            .get("http://example.com")
            .header("\n", "value")
            .build()
            .expect_err("invalid header should fail before network IO");

        TencentCloudError::transport(error, method, url)
    }
}