1use crate::error::{TushareError, TushareResult};
2use crate::types::{TushareEntityList, TushareRequest, TushareResponse};
3use crate::{Api, TushareClient};
4use rand::Rng;
5use std::collections::HashMap;
6use std::time::{Duration, Instant};
7use tokio::sync::Mutex;
8use tokio::time::sleep;
9
/// Retry policy for [`TushareClientEx`].
///
/// A failed call is retried only when its error is classified as retryable. The delay before
/// each retry grows exponentially from `base_delay`, is capped at `max_delay`, and is then
/// randomized ("equal jitter") so that concurrent callers do not retry in lock-step.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryConfig {
    /// Maximum number of retries after the initial attempt, so at most `max_retries + 1`
    /// requests are sent per call. `0` disables retrying.
    pub max_retries: usize,
    /// Nominal (pre-jitter) delay before the first retry; doubled for each subsequent retry.
    pub base_delay: Duration,
    /// Upper bound for the nominal (pre-jitter) backoff delay.
    pub max_delay: Duration,
}
24
25impl Default for RetryConfig {
26 fn default() -> Self {
27 Self {
28 max_retries: 3,
29 base_delay: Duration::from_millis(200),
30 max_delay: Duration::from_secs(5),
31 }
32 }
33}
34
/// Wrapper around [`TushareClient`] that adds an optional per-API minimum spacing between
/// requests and an optional retry policy with exponential backoff.
#[derive(Debug)]
pub struct TushareClientEx {
    /// The wrapped client that actually sends the requests.
    inner: TushareClient,
    /// Minimum spacing between requests, keyed by API name (`Api::name()`).
    /// APIs without an entry are not rate limited.
    api_min_intervals: HashMap<String, Duration>,
    /// Earliest instant at which the next request to each API may be sent.
    /// Guarded by an async mutex because it is updated through `&self` by concurrent calls.
    api_next_allowed_at: Mutex<HashMap<String, Instant>>,
    /// Retry policy; `None` means every call is attempted exactly once.
    retry: Option<RetryConfig>,
}
56
57impl TushareClientEx {
58 pub fn new(inner: TushareClient) -> Self {
62 Self {
63 inner,
64 api_min_intervals: HashMap::new(),
65 api_next_allowed_at: Mutex::new(HashMap::new()),
66 retry: None,
67 }
68 }
69
70 pub fn with_api_min_interval(mut self, api: Api, min_interval: Duration) -> Self {
87 self.api_min_intervals.insert(api.name(), min_interval);
88 self
89 }
90
91 pub fn with_retry_config(mut self, config: RetryConfig) -> Self {
100 self.retry = Some(config);
101 self
102 }
103
104 pub fn inner(&self) -> &TushareClient {
106 &self.inner
107 }
108
109 pub fn into_inner(self) -> TushareClient {
111 self.inner
112 }
113
114 pub async fn call_api<T>(&self, request: &T) -> TushareResult<TushareResponse>
116 where
117 for<'a> &'a T: TryInto<TushareRequest>,
118 for<'a> <&'a T as TryInto<TushareRequest>>::Error: Into<TushareError>,
119 {
120 let request = request.try_into().map_err(Into::into)?;
121
122 self.apply_api_min_interval_rate_limit(&request.api_name.name()).await;
123
124 self.call_api_with_retry(request).await
125 }
126
127 pub async fn call_api_as<T, R>(&self, request: &R) -> TushareResult<TushareEntityList<T>>
128 where
129 T: crate::traits::FromTushareData,
130 for<'a> &'a R: TryInto<TushareRequest>,
131 for<'a> <&'a R as TryInto<TushareRequest>>::Error: Into<TushareError>,
132 {
133 let response = self.call_api(request).await?;
134 TushareEntityList::try_from(response).map_err(Into::into)
135 }
136
137 async fn call_api_with_retry(&self, request: TushareRequest) -> TushareResult<TushareResponse> {
138 let Some(cfg) = self.retry.clone() else {
139 return self.inner.call_api_request(&request).await;
140 };
141
142 let mut attempt = 0usize;
143 let request_id = crate::client::generate_request_id();
144 let api_name = request.api_name.name();
145
146 loop {
147 match self
148 .inner
149 .call_api_request_with_request_id(&request_id, &request)
150 .await
151 {
152 Ok(resp) => return Ok(resp),
153 Err(err) => {
154 let should_retry = attempt < cfg.max_retries && is_retryable_error(&err);
155 if !should_retry {
156 self.inner.logger().log_safe(
157 crate::logging::LogLevel::Error,
158 || {
159 format!(
160 "[{}] tushare_api retry exhausted or non-retryable error; api={}, attempts={}, max_retries={}, err={}",
161 request_id, api_name, attempt, cfg.max_retries, err
162 )
163 },
164 None,
165 );
166 return Err(err);
167 }
168
169 let delay = compute_backoff_delay(&cfg, attempt);
170 self.inner.logger().log_safe(
171 crate::logging::LogLevel::Warn,
172 || {
173 format!(
174 "[{}] tushare_api retrying; api={}, retry={}/{}, delay={:?}, err={}",
175 request_id,
176 api_name,
177 attempt + 1,
178 cfg.max_retries,
179 delay,
180 err
181 )
182 },
183 None,
184 );
185 sleep(delay).await;
186 attempt += 1;
187 }
188 }
189 }
190 }
191
192 async fn apply_api_min_interval_rate_limit(&self, api_name: &str) {
193 let Some(min_interval) = self.api_min_intervals.get(api_name).copied() else {
194 return;
195 };
196
197 let now = Instant::now();
198 let wait = {
199 let mut guard = self.api_next_allowed_at.lock().await;
200 let next_allowed_at = guard.get(api_name).copied().unwrap_or(now);
201 let base = if next_allowed_at > now { next_allowed_at } else { now };
202 guard.insert(api_name.to_string(), base + min_interval);
203 if base > now {
204 base - now
205 } else {
206 Duration::from_secs(0)
207 }
208 };
209
210 if !wait.is_zero() {
211 sleep(wait).await;
212 }
213 }
214}
215
216fn is_retryable_error(err: &TushareError) -> bool {
217 matches!(
218 err,
219 TushareError::HttpError(_) | TushareError::TimeoutError
220 )
221}
222
223fn compute_backoff_delay(cfg: &RetryConfig, attempt: usize) -> Duration {
224 let shift = attempt.min(31) as u32;
225 let factor = 1u64.checked_shl(shift).unwrap_or(u64::MAX);
226 let base = cfg.base_delay.saturating_mul(factor as u32);
227 let capped = if base > cfg.max_delay { cfg.max_delay } else { base };
228
229 let capped_ms = capped.as_millis().min(u64::MAX as u128) as u64;
232 if capped_ms == 0 {
233 return Duration::from_millis(0);
234 }
235
236 let half = capped_ms / 2;
237 let jitter_ms = rand::thread_rng().gen_range(0..=half);
238 Duration::from_millis(half + jitter_ms)
239}
240