1use crate::error::{TushareError, TushareResult};
2use crate::types::{TushareEntityList, TushareRequest, TushareResponse};
3use crate::{Api, TushareClient};
4use rand::Rng;
5use std::collections::HashMap;
6use std::time::{Duration, Instant};
7use tokio::sync::Mutex;
8use tokio::time::sleep;
9
/// Retry policy used by [`TushareClientEx`] when a call fails with a retryable error.
///
/// The delay before each retry starts at `base_delay`, doubles with every
/// further attempt, is capped at `max_delay`, and is then jittered.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryConfig {
    /// Maximum number of retries performed after the initial attempt.
    pub max_retries: usize,
    /// Backoff delay for the first retry, before capping and jitter.
    pub base_delay: Duration,
    /// Upper bound applied to the exponentially growing delay.
    pub max_delay: Duration,
}
24
25impl Default for RetryConfig {
26 fn default() -> Self {
27 Self {
28 max_retries: 3,
29 base_delay: Duration::from_millis(200),
30 max_delay: Duration::from_secs(5),
31 }
32 }
33}
34
35#[derive(Debug)]
50pub struct TushareClientEx {
51 inner: TushareClient,
52 api_min_intervals: HashMap<String, Duration>,
53 api_next_allowed_at: Mutex<HashMap<String, Instant>>,
54 retry: Option<RetryConfig>,
55}
56
57impl TushareClientEx {
58 pub fn new(inner: TushareClient) -> Self {
62 Self {
63 inner,
64 api_min_intervals: HashMap::new(),
65 api_next_allowed_at: Mutex::new(HashMap::new()),
66 retry: None,
67 }
68 }
69
70 pub fn with_api_min_interval(mut self, api: Api, min_interval: Duration) -> Self {
87 self.api_min_intervals.insert(api.name(), min_interval);
88 self
89 }
90
91 pub fn with_retry_config(mut self, config: RetryConfig) -> Self {
100 self.retry = Some(config);
101 self
102 }
103
104 pub fn inner(&self) -> &TushareClient {
106 &self.inner
107 }
108
109 pub fn into_inner(self) -> TushareClient {
111 self.inner
112 }
113
114 pub async fn call_api<T>(&self, request: &T) -> TushareResult<TushareResponse>
116 where
117 for<'a> &'a T: TryInto<TushareRequest>,
118 for<'a> <&'a T as TryInto<TushareRequest>>::Error: Into<TushareError>,
119 {
120 let request = request.try_into().map_err(Into::into)?;
121
122 self.apply_api_min_interval_rate_limit(&request.api_name.name()).await;
123
124 self.call_api_with_retry(request).await
125 }
126
127 pub async fn call_api_as<T, R>(&self, request: &R) -> TushareResult<TushareEntityList<T>>
128 where
129 T: crate::traits::FromTushareData,
130 for<'a> &'a R: TryInto<TushareRequest>,
131 for<'a> <&'a R as TryInto<TushareRequest>>::Error: Into<TushareError>,
132 {
133 let response = self.call_api(request).await?;
134 TushareEntityList::try_from(response).map_err(Into::into)
135 }
136
137 async fn call_api_with_retry(&self, request: TushareRequest) -> TushareResult<TushareResponse> {
138 let Some(cfg) = self.retry.clone() else {
139 return self.inner.call_api_request(&request).await;
140 };
141
142 let mut attempt = 0usize;
143
144 loop {
145 match self.inner.call_api_request(&request).await {
146 Ok(resp) => return Ok(resp),
147 Err(err) => {
148 let should_retry = attempt < cfg.max_retries && is_retryable_error(&err);
149 if !should_retry {
150 return Err(err);
151 }
152
153 let delay = compute_backoff_delay(&cfg, attempt);
154 sleep(delay).await;
155 attempt += 1;
156 }
157 }
158 }
159 }
160
161 async fn apply_api_min_interval_rate_limit(&self, api_name: &str) {
162 let Some(min_interval) = self.api_min_intervals.get(api_name).copied() else {
163 return;
164 };
165
166 let now = Instant::now();
167 let wait = {
168 let mut guard = self.api_next_allowed_at.lock().await;
169 let next_allowed_at = guard.get(api_name).copied().unwrap_or(now);
170 let base = if next_allowed_at > now { next_allowed_at } else { now };
171 guard.insert(api_name.to_string(), base + min_interval);
172 if base > now {
173 base - now
174 } else {
175 Duration::from_secs(0)
176 }
177 };
178
179 if !wait.is_zero() {
180 sleep(wait).await;
181 }
182 }
183}
184
185fn is_retryable_error(err: &TushareError) -> bool {
186 matches!(
187 err,
188 TushareError::HttpError(_) | TushareError::TimeoutError
189 )
190}
191
192fn compute_backoff_delay(cfg: &RetryConfig, attempt: usize) -> Duration {
193 let shift = attempt.min(31) as u32;
194 let factor = 1u64.checked_shl(shift).unwrap_or(u64::MAX);
195 let base = cfg.base_delay.saturating_mul(factor as u32);
196 let capped = if base > cfg.max_delay { cfg.max_delay } else { base };
197
198 let capped_ms = capped.as_millis().min(u64::MAX as u128) as u64;
201 if capped_ms == 0 {
202 return Duration::from_millis(0);
203 }
204
205 let half = capped_ms / 2;
206 let jitter_ms = rand::thread_rng().gen_range(0..=half);
207 Duration::from_millis(half + jitter_ms)
208}
209