tencent_sdk/middleware/
retry_blocking.rs

1use crate::{core::TencentCloudError, transport::blocking_impl::BlockingTransport};
2use fastrand;
3use http::{Method, StatusCode};
4use std::{collections::HashMap, thread::sleep, time::Duration};
5use url::Url;
6
7#[derive(Clone)]
8pub struct RetryBlocking<T> {
9    inner: T,
10    max: usize,
11    base_delay: Duration,
12}
13
14impl<T> RetryBlocking<T> {
15    pub fn new(inner: T, max: usize, base_delay: Duration) -> Self {
16        Self {
17            inner,
18            max,
19            base_delay,
20        }
21    }
22
23    fn delay_for(&self, attempt: usize) -> Duration {
24        if attempt == 0 {
25            Duration::from_secs(0)
26        } else {
27            let pow = 2f64.powi((attempt - 1) as i32);
28            let base = self.base_delay.mul_f64(pow);
29            let jitter = 0.5 + fastrand::f64();
30            base.mul_f64(jitter)
31        }
32    }
33}
34
35impl<T: BlockingTransport> BlockingTransport for RetryBlocking<T> {
36    fn send(
37        &self,
38        method: Method,
39        url: Url,
40        headers: HashMap<String, String>,
41        body: Option<String>,
42        timeout: Duration,
43    ) -> Result<(StatusCode, String), TencentCloudError> {
44        let mut attempt = 0usize;
45        loop {
46            match self.inner.send(
47                method.clone(),
48                url.clone(),
49                headers.clone(),
50                body.clone(),
51                timeout,
52            ) {
53                Ok((status, payload)) => {
54                    if status.is_server_error() && attempt < self.max {
55                        attempt += 1;
56                        let delay = self.delay_for(attempt);
57                        if !delay.is_zero() {
58                            sleep(delay);
59                        }
60                        continue;
61                    }
62
63                    return Ok((status, payload));
64                }
65                Err(err) => {
66                    let should_retry =
67                        attempt < self.max && matches!(err, TencentCloudError::Transport { .. });
68
69                    if should_retry {
70                        attempt += 1;
71                        let delay = self.delay_for(attempt);
72                        if !delay.is_zero() {
73                            sleep(delay);
74                        }
75                        continue;
76                    }
77
78                    return Err(err);
79                }
80            }
81        }
82    }
83}
84
#[cfg(test)]
mod tests {
    use super::*;
    use http::Method;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Fails with a transport error for the first `fail_times` calls, then
    /// answers `200 OK` with body `{}`; every call is counted in `attempts`.
    #[derive(Clone)]
    struct FlakyBlockingTransport {
        attempts: Arc<AtomicUsize>,
        fail_times: usize,
    }

    impl FlakyBlockingTransport {
        fn new(fail_times: usize) -> Self {
            Self {
                attempts: Arc::new(AtomicUsize::new(0)),
                fail_times,
            }
        }
    }

    impl BlockingTransport for FlakyBlockingTransport {
        fn send(
            &self,
            method: Method,
            url: Url,
            _headers: HashMap<String, String>,
            _body: Option<String>,
            _timeout: Duration,
        ) -> Result<(StatusCode, String), TencentCloudError> {
            let current = self.attempts.fetch_add(1, Ordering::SeqCst);
            if current < self.fail_times {
                return Err(make_transport_error(method, url));
            }

            Ok((StatusCode::OK, "{}".to_string()))
        }
    }

    /// Answers `status` for the first `bad_times` calls, then `200 OK`;
    /// every call is counted in `attempts`.
    #[derive(Clone)]
    struct StatusBlockingTransport {
        attempts: Arc<AtomicUsize>,
        status: StatusCode,
        bad_times: usize,
    }

    impl StatusBlockingTransport {
        fn new(status: StatusCode, bad_times: usize) -> Self {
            Self {
                attempts: Arc::new(AtomicUsize::new(0)),
                status,
                bad_times,
            }
        }
    }

    impl BlockingTransport for StatusBlockingTransport {
        fn send(
            &self,
            _method: Method,
            _url: Url,
            _headers: HashMap<String, String>,
            _body: Option<String>,
            _timeout: Duration,
        ) -> Result<(StatusCode, String), TencentCloudError> {
            let current = self.attempts.fetch_add(1, Ordering::SeqCst);
            let status = if current < self.bad_times {
                self.status
            } else {
                StatusCode::OK
            };
            Ok((status, "{}".to_string()))
        }
    }

    /// Sends a fixed POST request through `transport`.
    fn send_test_request<T: BlockingTransport>(
        transport: &T,
    ) -> Result<(StatusCode, String), TencentCloudError> {
        transport.send(
            Method::POST,
            Url::parse("https://example.com").unwrap(),
            HashMap::new(),
            None,
            Duration::from_secs(1),
        )
    }

    #[test]
    fn retries_transport_errors_before_succeeding() {
        let transport = FlakyBlockingTransport::new(1);
        let retry = RetryBlocking::new(transport.clone(), 2, Duration::from_millis(1));
        let result = send_test_request(&retry);

        assert!(
            result.is_ok(),
            "expected retry to eventually succeed: {result:?}"
        );
        assert_eq!(
            transport.attempts.load(Ordering::SeqCst),
            2,
            "expected one retry plus final success"
        );
    }

    #[test]
    fn gives_up_after_max_retries() {
        let transport = FlakyBlockingTransport::new(usize::MAX);
        let retry = RetryBlocking::new(transport.clone(), 2, Duration::ZERO);
        let result = send_test_request(&retry);

        assert!(
            matches!(result, Err(TencentCloudError::Transport { .. })),
            "expected the last transport error to be returned: {result:?}"
        );
        assert_eq!(
            transport.attempts.load(Ordering::SeqCst),
            3,
            "expected the initial attempt plus two retries"
        );
    }

    #[test]
    fn zero_max_disables_retries() {
        let transport = FlakyBlockingTransport::new(1);
        let retry = RetryBlocking::new(transport.clone(), 0, Duration::ZERO);

        assert!(send_test_request(&retry).is_err());
        assert_eq!(transport.attempts.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn retries_server_errors_before_succeeding() {
        let transport = StatusBlockingTransport::new(StatusCode::SERVICE_UNAVAILABLE, 2);
        let retry = RetryBlocking::new(transport.clone(), 3, Duration::ZERO);
        let (status, _) = send_test_request(&retry).expect("expected retry to eventually succeed");

        assert_eq!(status, StatusCode::OK);
        assert_eq!(transport.attempts.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn returns_last_server_error_when_retries_exhausted() {
        let transport = StatusBlockingTransport::new(StatusCode::INTERNAL_SERVER_ERROR, usize::MAX);
        let retry = RetryBlocking::new(transport.clone(), 1, Duration::ZERO);
        let (status, _) = send_test_request(&retry).expect("5xx responses are returned as Ok");

        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(transport.attempts.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn does_not_retry_client_errors() {
        let transport = StatusBlockingTransport::new(StatusCode::BAD_REQUEST, usize::MAX);
        let retry = RetryBlocking::new(transport.clone(), 3, Duration::ZERO);
        let (status, _) = send_test_request(&retry).expect("4xx responses are returned as Ok");

        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(transport.attempts.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn delay_doubles_per_attempt_within_jitter_bounds() {
        let retry = RetryBlocking::new((), 4, Duration::from_secs(1));
        assert_eq!(retry.delay_for(0), Duration::ZERO);

        for attempt in 1..=4usize {
            let nominal = Duration::from_secs(1 << (attempt - 1));
            let delay = retry.delay_for(attempt);
            assert!(
                delay >= nominal / 2 && delay <= nominal * 3 / 2,
                "attempt {attempt}: {delay:?} outside jitter bounds around {nominal:?}"
            );
        }
    }

    /// Builds a genuine `TencentCloudError::Transport` without network I/O by
    /// making reqwest reject an invalid header name at request-build time.
    fn make_transport_error(method: Method, url: Url) -> TencentCloudError {
        let client = reqwest::blocking::Client::builder()
            .build()
            .expect("build test reqwest client");

        let error = client
            .get("http://example.com")
            .header("\n", "value")
            .build()
            .expect_err("invalid header should fail before network IO");

        TencentCloudError::transport(error, method, url)
    }
}