webull_rs/utils/rate_limit.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use std::time::{Duration, Instant};
4use tokio::time::sleep;
5
6/// Rate limiter for API requests.
7pub struct RateLimiter {
8    /// Maximum number of requests per minute
9    requests_per_minute: u32,
10
11    /// Request timestamps
12    timestamps: Arc<Mutex<HashMap<String, Vec<Instant>>>>,
13
14    /// Backoff strategy
15    backoff_strategy: BackoffStrategy,
16}
17
18impl RateLimiter {
19    /// Create a new rate limiter.
20    pub fn new(requests_per_minute: u32) -> Self {
21        Self {
22            requests_per_minute,
23            timestamps: Arc::new(Mutex::new(HashMap::new())),
24            backoff_strategy: BackoffStrategy::default(),
25        }
26    }
27
28    /// Set the backoff strategy.
29    pub fn with_backoff_strategy(mut self, strategy: BackoffStrategy) -> Self {
30        self.backoff_strategy = strategy;
31        self
32    }
33
34    /// Wait for rate limit to allow a request.
35    pub async fn wait(&self, endpoint: &str) {
36        // Get the current time
37        let now = Instant::now();
38
39        // Check if we need to wait
40        let wait_time = {
41            // Get the timestamps for this endpoint
42            let mut timestamps = self.timestamps.lock().unwrap();
43            let endpoint_timestamps = timestamps.entry(endpoint.to_string()).or_insert_with(Vec::new);
44
45            // Remove timestamps older than 1 minute
46            endpoint_timestamps.retain(|t| now.duration_since(*t) < Duration::from_secs(60));
47
48            // Check if we've exceeded the rate limit
49            if endpoint_timestamps.len() >= self.requests_per_minute as usize {
50                // Calculate how long to wait
51                let oldest = endpoint_timestamps[0];
52                Some(Duration::from_secs(60) - now.duration_since(oldest))
53            } else {
54                // Add the current timestamp
55                endpoint_timestamps.push(now);
56                None
57            }
58        };
59
60        // Wait if necessary
61        if let Some(duration) = wait_time {
62            // Wait for the rate limit to reset
63            sleep(duration).await;
64
65            // Add the current timestamp
66            let mut timestamps = self.timestamps.lock().unwrap();
67            let endpoint_timestamps = timestamps.entry(endpoint.to_string()).or_insert_with(Vec::new);
68            endpoint_timestamps.push(Instant::now());
69        }
70    }
71
72    /// Handle a rate limit error.
73    pub async fn handle_rate_limit_error(&self, attempt: u32) -> Duration {
74        self.backoff_strategy.get_backoff_duration(attempt)
75    }
76}
77
/// Backoff strategy for rate limiting.
///
/// Maps a retry attempt number to how long to wait before retrying.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BackoffStrategy {
    /// Constant backoff: always wait the same duration.
    Constant(Duration),

    /// Linear backoff: wait `initial + increment * attempt` (saturating).
    Linear {
        /// Initial backoff duration
        initial: Duration,

        /// Increment per attempt
        increment: Duration,
    },

    /// Exponential backoff: wait `initial * multiplier^attempt`, capped at `max`.
    Exponential {
        /// Initial backoff duration
        initial: Duration,

        /// Multiplier per attempt
        multiplier: f64,

        /// Maximum backoff duration
        max: Duration,
    },
}

impl BackoffStrategy {
    /// Get the backoff duration for an attempt.
    ///
    /// Never panics:
    /// - `Linear` saturates at `Duration::MAX` instead of overflowing.
    /// - `Exponential` returns `max` when the computed value is NaN, infinite, or at
    ///   least `max`, and returns `Duration::ZERO` when it is negative or zero (for
    ///   example, with a negative `multiplier`).
    pub fn get_backoff_duration(&self, attempt: u32) -> Duration {
        match *self {
            Self::Constant(duration) => duration,
            Self::Linear { initial, increment } => {
                initial.saturating_add(increment.saturating_mul(attempt))
            }
            Self::Exponential {
                initial,
                multiplier,
                max,
            } => {
                let secs = initial.as_secs_f64() * multiplier.powf(f64::from(attempt));
                // Comparing against `max` here avoids round-tripping `max` through f64.
                // That round-trip can round above u64::MAX seconds and make
                // `from_secs_f64` panic.
                if secs.is_nan() || secs >= max.as_secs_f64() {
                    max
                } else if secs <= 0.0 {
                    Duration::ZERO
                } else {
                    Duration::from_secs_f64(secs)
                }
            }
        }
    }
}

impl Default for BackoffStrategy {
    /// Exponential backoff starting at 1s, doubling per attempt, capped at 60s.
    fn default() -> Self {
        Self::Exponential {
            initial: Duration::from_secs(1),
            multiplier: 2.0,
            max: Duration::from_secs(60),
        }
    }
}