webull_rs/utils/rate_limit.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use std::time::{Duration, Instant};
4use tokio::time::sleep;
5
6/// Rate limiter for API requests.
7pub struct RateLimiter {
8    /// Maximum number of requests per minute
9    requests_per_minute: u32,
10
11    /// Request timestamps
12    timestamps: Arc<Mutex<HashMap<String, Vec<Instant>>>>,
13
14    /// Backoff strategy
15    backoff_strategy: BackoffStrategy,
16}
17
18impl RateLimiter {
19    /// Create a new rate limiter.
20    pub fn new(requests_per_minute: u32) -> Self {
21        Self {
22            requests_per_minute,
23            timestamps: Arc::new(Mutex::new(HashMap::new())),
24            backoff_strategy: BackoffStrategy::default(),
25        }
26    }
27
28    /// Set the backoff strategy.
29    pub fn with_backoff_strategy(mut self, strategy: BackoffStrategy) -> Self {
30        self.backoff_strategy = strategy;
31        self
32    }
33
34    /// Wait for rate limit to allow a request.
35    pub async fn wait(&self, endpoint: &str) {
36        // Get the current time
37        let now = Instant::now();
38
39        // Check if we need to wait
40        let wait_time = {
41            // Get the timestamps for this endpoint
42            let mut timestamps = self.timestamps.lock().unwrap();
43            let endpoint_timestamps = timestamps
44                .entry(endpoint.to_string())
45                .or_insert_with(Vec::new);
46
47            // Remove timestamps older than 1 minute
48            endpoint_timestamps.retain(|t| now.duration_since(*t) < Duration::from_secs(60));
49
50            // Check if we've exceeded the rate limit
51            if endpoint_timestamps.len() >= self.requests_per_minute as usize {
52                // Calculate how long to wait
53                let oldest = endpoint_timestamps[0];
54                Some(Duration::from_secs(60) - now.duration_since(oldest))
55            } else {
56                // Add the current timestamp
57                endpoint_timestamps.push(now);
58                None
59            }
60        };
61
62        // Wait if necessary
63        if let Some(duration) = wait_time {
64            // Wait for the rate limit to reset
65            sleep(duration).await;
66
67            // Add the current timestamp
68            let mut timestamps = self.timestamps.lock().unwrap();
69            let endpoint_timestamps = timestamps
70                .entry(endpoint.to_string())
71                .or_insert_with(Vec::new);
72            endpoint_timestamps.push(Instant::now());
73        }
74    }
75
76    /// Handle a rate limit error.
77    pub async fn handle_rate_limit_error(&self, attempt: u32) -> Duration {
78        self.backoff_strategy.get_backoff_duration(attempt)
79    }
80}
81
/// Backoff strategy for retrying after a rate-limit error.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BackoffStrategy {
    /// Always wait the same duration.
    Constant(Duration),

    /// Wait `initial + increment * attempt`, saturating at `Duration::MAX`.
    Linear {
        /// Backoff duration for attempt `0`.
        initial: Duration,

        /// Amount added per attempt.
        increment: Duration,
    },

    /// Wait `initial * multiplier^attempt`, capped at `max`.
    Exponential {
        /// Backoff duration for attempt `0`.
        initial: Duration,

        /// Factor applied per attempt.
        multiplier: f64,

        /// Upper bound on the returned duration.
        max: Duration,
    },
}

impl BackoffStrategy {
    /// Get the backoff duration for retry `attempt` (0-based).
    ///
    /// Never panics:
    /// - `Linear` saturates at `Duration::MAX` instead of overflowing.
    /// - `Exponential` returns `max` when the computed value reaches it, is
    ///   infinite, or is NaN.
    /// - `Exponential` clamps a negative result (e.g. from a negative
    ///   `multiplier`) to zero.
    pub fn get_backoff_duration(&self, attempt: u32) -> Duration {
        match *self {
            Self::Constant(duration) => duration,
            Self::Linear { initial, increment } => {
                initial.saturating_add(increment.saturating_mul(attempt))
            }
            Self::Exponential {
                initial,
                multiplier,
                max,
            } => {
                let secs = initial.as_secs_f64() * multiplier.powf(f64::from(attempt));
                // Return `max` itself rather than round-tripping it through f64:
                // near `Duration::MAX` that conversion rounds past the
                // representable range and `from_secs_f64` would panic.
                if secs.is_nan() || secs >= max.as_secs_f64() {
                    max
                } else {
                    // `min(max)` guards against f64 rounding nudging the result
                    // just above the cap.
                    Duration::from_secs_f64(secs.max(0.0)).min(max)
                }
            }
        }
    }
}

impl Default for BackoffStrategy {
    /// Exponential backoff starting at 1s, doubling each attempt, capped at 60s.
    fn default() -> Self {
        Self::Exponential {
            initial: Duration::from_secs(1),
            multiplier: 2.0,
            max: Duration::from_secs(60),
        }
    }
}