tushare_api/
client_ex.rs

1use crate::error::{TushareError, TushareResult};
2use crate::types::{TushareEntityList, TushareRequest, TushareResponse};
3use crate::{Api, TushareClient};
4use rand::Rng;
5use std::collections::HashMap;
6use std::time::{Duration, Instant};
7use tokio::sync::Mutex;
8use tokio::time::sleep;
9
/// Retry configuration for [`TushareClientEx`].
///
/// The retry logic is implemented at the wrapper layer so that [`TushareClient`]
/// can stay focused on a single HTTP request + response parsing.
///
/// Notes:
/// - Only retryable errors are retried (currently HTTP/network and timeout errors).
///   Business-level API errors are returned to the caller immediately.
/// - Before retry number `k` (0-based), the nominal delay is `base_delay * 2^k`
///   (saturating), capped by `max_delay`. "Equal jitter" is then applied, so the
///   actual sleep is a random duration in `[nominal / 2, nominal]`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryConfig {
    /// Maximum number of retries performed after the initial attempt, so the total
    /// number of attempts is at most `max_retries + 1`. `0` disables retrying.
    pub max_retries: usize,
    /// Nominal (pre-jitter) delay before the first retry; it doubles on every
    /// subsequent retry.
    pub base_delay: Duration,
    /// Upper bound on the nominal (pre-jitter) delay between two attempts.
    pub max_delay: Duration,
}
24
25impl Default for RetryConfig {
26    fn default() -> Self {
27        Self {
28            max_retries: 3,
29            base_delay: Duration::from_millis(200),
30            max_delay: Duration::from_secs(5),
31        }
32    }
33}
34
35/// Extended client wrapper that adds advanced behaviors on top of [`TushareClient`].
36///
37/// Currently supported:
38/// - **Per-API minimum interval rate limiting (default: sleep)**
39///   If an API is configured with a minimum interval (e.g. 10 seconds), repeated
40///   calls to the same API will be automatically delayed so that two calls are at
41///   least `min_interval` apart. Callers do not need to implement any sleep logic.
42///
43/// - **Retry with exponential backoff (optional)**
44///   When enabled via [`Self::with_retry_config`], network/timeout failures will be
45///   retried with exponential backoff.
46///
47/// This wrapper is designed to keep the core client stable while allowing you to
48/// opt into additional behaviors.
49#[derive(Debug)]
50pub struct TushareClientEx {
51    inner: TushareClient,
52    api_min_intervals: HashMap<String, Duration>,
53    api_next_allowed_at: Mutex<HashMap<String, Instant>>,
54    retry: Option<RetryConfig>,
55}
56
57impl TushareClientEx {
58    /// Create a new wrapper client.
59    ///
60    /// By default, no per-API interval limit is applied and retry is disabled.
61    pub fn new(inner: TushareClient) -> Self {
62        Self {
63            inner,
64            api_min_intervals: HashMap::new(),
65            api_next_allowed_at: Mutex::new(HashMap::new()),
66            retry: None,
67        }
68    }
69
70    /// Configure a minimum interval between two calls of the same API.
71    ///
72    /// If the interval is not met, the wrapper will `sleep` until it becomes
73    /// eligible to call.
74    ///
75    /// Example:
76    ///
77    /// ```rust,no_run
78    /// use std::time::Duration;
79    /// use tushare_api::{Api, TushareClient, TushareClientEx};
80    ///
81    /// # fn build(inner: TushareClient) -> TushareClientEx {
82    /// TushareClientEx::new(inner)
83    ///     .with_api_min_interval(Api::Daily, Duration::from_secs(10))
84    /// # }
85    /// ```
86    pub fn with_api_min_interval(mut self, api: Api, min_interval: Duration) -> Self {
87        self.api_min_intervals.insert(api.name(), min_interval);
88        self
89    }
90
91    /// Enable retry with exponential backoff.
92    ///
93    /// Retryable errors:
94    /// - [`TushareError::HttpError`]
95    /// - [`TushareError::TimeoutError`]
96    ///
97    /// Non-retryable errors (by design):
98    /// - [`TushareError::ApiError`] (business-level errors returned by Tushare)
99    pub fn with_retry_config(mut self, config: RetryConfig) -> Self {
100        self.retry = Some(config);
101        self
102    }
103
104    /// Borrow the underlying [`TushareClient`].
105    pub fn inner(&self) -> &TushareClient {
106        &self.inner
107    }
108
109    /// Consume the wrapper and return the underlying [`TushareClient`].
110    pub fn into_inner(self) -> TushareClient {
111        self.inner
112    }
113
114    /// Call API with configured rate limiting (sleep) and optional retry.
115    pub async fn call_api<T>(&self, request: &T) -> TushareResult<TushareResponse>
116    where
117        for<'a> &'a T: TryInto<TushareRequest>,
118        for<'a> <&'a T as TryInto<TushareRequest>>::Error: Into<TushareError>,
119    {
120        let request = request.try_into().map_err(Into::into)?;
121
122        self.apply_api_min_interval_rate_limit(&request.api_name.name()).await;
123
124        self.call_api_with_retry(request).await
125    }
126
127    pub async fn call_api_as<T, R>(&self, request: &R) -> TushareResult<TushareEntityList<T>>
128    where
129        T: crate::traits::FromTushareData,
130        for<'a> &'a R: TryInto<TushareRequest>,
131        for<'a> <&'a R as TryInto<TushareRequest>>::Error: Into<TushareError>,
132    {
133        let response = self.call_api(request).await?;
134        TushareEntityList::try_from(response).map_err(Into::into)
135    }
136
137    async fn call_api_with_retry(&self, request: TushareRequest) -> TushareResult<TushareResponse> {
138        let Some(cfg) = self.retry.clone() else {
139            return self.inner.call_api_request(&request).await;
140        };
141
142        let mut attempt = 0usize;
143
144        loop {
145            match self.inner.call_api_request(&request).await {
146                Ok(resp) => return Ok(resp),
147                Err(err) => {
148                    let should_retry = attempt < cfg.max_retries && is_retryable_error(&err);
149                    if !should_retry {
150                        return Err(err);
151                    }
152
153                    let delay = compute_backoff_delay(&cfg, attempt);
154                    sleep(delay).await;
155                    attempt += 1;
156                }
157            }
158        }
159    }
160
161    async fn apply_api_min_interval_rate_limit(&self, api_name: &str) {
162        let Some(min_interval) = self.api_min_intervals.get(api_name).copied() else {
163            return;
164        };
165
166        let now = Instant::now();
167        let wait = {
168            let mut guard = self.api_next_allowed_at.lock().await;
169            let next_allowed_at = guard.get(api_name).copied().unwrap_or(now);
170            let base = if next_allowed_at > now { next_allowed_at } else { now };
171            guard.insert(api_name.to_string(), base + min_interval);
172            if base > now {
173                base - now
174            } else {
175                Duration::from_secs(0)
176            }
177        };
178
179        if !wait.is_zero() {
180            sleep(wait).await;
181        }
182    }
183}
184
185fn is_retryable_error(err: &TushareError) -> bool {
186    matches!(
187        err,
188        TushareError::HttpError(_) | TushareError::TimeoutError
189    )
190}
191
192fn compute_backoff_delay(cfg: &RetryConfig, attempt: usize) -> Duration {
193    let shift = attempt.min(31) as u32;
194    let factor = 1u64.checked_shl(shift).unwrap_or(u64::MAX);
195    let base = cfg.base_delay.saturating_mul(factor as u32);
196    let capped = if base > cfg.max_delay { cfg.max_delay } else { base };
197
198    // Equal jitter: capped/2 + random(0..=capped/2)
199    // Compared to full jitter, this is less volatile while still spreading retries.
200    let capped_ms = capped.as_millis().min(u64::MAX as u128) as u64;
201    if capped_ms == 0 {
202        return Duration::from_millis(0);
203    }
204
205    let half = capped_ms / 2;
206    let jitter_ms = rand::thread_rng().gen_range(0..=half);
207    Duration::from_millis(half + jitter_ms)
208}
209