1use std::time::{Duration, Instant};
2
3use chrono::{DateTime, Utc};
4use tokio::sync::RwLock;
5
6#[derive(Debug, Clone)]
8pub struct RateLimitInfo {
9 pub limit: u32,
11 pub remaining: u32,
13 pub reset: DateTime<Utc>,
15 pub retry_after: Option<Duration>,
17}
18
19#[derive(Debug, Clone)]
21pub struct RateLimitStatus {
22 pub limit: u32,
24 pub remaining: u32,
26 pub reset_time: DateTime<Utc>,
28 pub reset_in: Duration,
30}
31
/// Construction-time settings for a [`RateLimiter`].
///
/// Invalid values (zero limit, non-positive multiplier, zero backoff) are replaced
/// with the defaults by [`RateLimiter::new`].
#[derive(Debug, Clone, PartialEq)]
pub struct RateLimiterConfig {
    /// Request quota assumed before the API reports its own limit.
    pub initial_limit: u32,
    /// Growth factor of the exponential backoff applied to consecutive rate limits.
    pub backoff_multiplier: f64,
    /// Upper bound on the exponential backoff delay.
    pub max_backoff: Duration,
}

impl Default for RateLimiterConfig {
    /// Returns a limit of 100 requests, a backoff multiplier of 2.0 and a maximum
    /// backoff of 300 seconds.
    fn default() -> Self {
        Self {
            initial_limit: 100,
            backoff_multiplier: 2.0,
            max_backoff: Duration::from_secs(300),
        }
    }
}
52
53struct Inner {
54 limit: u32,
55 remaining: u32,
56 reset_time: DateTime<Utc>,
57 last_request_time: Option<Instant>,
58 backoff_multiplier: f64,
59 max_backoff: Duration,
60 rate_limited: bool,
61 last_rate_limit_time: Option<Instant>,
62 consecutive_rate_limits: u32,
63 enabled: bool,
64}
65
66pub struct RateLimiter {
70 inner: RwLock<Inner>,
71}
72
73impl RateLimiter {
74 pub fn new(config: &RateLimiterConfig) -> Self {
76 let limit = if config.initial_limit == 0 {
77 100
78 } else {
79 config.initial_limit
80 };
81 let backoff = if config.backoff_multiplier <= 0.0 {
82 2.0
83 } else {
84 config.backoff_multiplier
85 };
86 let max_backoff = if config.max_backoff.is_zero() {
87 Duration::from_secs(300)
88 } else {
89 config.max_backoff
90 };
91
92 Self {
93 inner: RwLock::new(Inner {
94 limit,
95 remaining: limit,
96 reset_time: Utc::now() + chrono::Duration::hours(1),
97 last_request_time: None,
98 backoff_multiplier: backoff,
99 max_backoff,
100 rate_limited: false,
101 last_rate_limit_time: None,
102 consecutive_rate_limits: 0,
103 enabled: true,
104 }),
105 }
106 }
107
108 pub async fn disable(&self) {
110 let mut inner = self.inner.write().await;
111 inner.enabled = false;
112 }
113
114 pub async fn enable(&self) {
116 let mut inner = self.inner.write().await;
117 inner.enabled = true;
118 }
119
120 pub async fn should_wait(&self) -> bool {
124 let inner = self.inner.read().await;
125 inner.enabled && inner.rate_limited && Utc::now() < inner.reset_time
126 }
127
128 pub async fn wait(&self) -> crate::Result<()> {
131 {
133 let mut inner = self.inner.write().await;
134 if Utc::now() >= inner.reset_time {
135 inner.remaining = inner.limit;
136 inner.reset_time = Utc::now() + chrono::Duration::hours(1);
137 inner.rate_limited = false;
138 inner.consecutive_rate_limits = 0;
139 tracing::debug!(limit = inner.limit, "Rate limit window reset");
140 return Ok(());
141 }
142 if !inner.rate_limited {
143 inner.last_request_time = Some(Instant::now());
144 return Ok(());
145 }
146 }
147
148 loop {
150 let (wait_duration, original_reset) = {
151 let inner = self.inner.read().await;
152 let mut wait_time = (inner.reset_time - Utc::now())
153 .to_std()
154 .unwrap_or(Duration::from_secs(1));
155
156 if inner.consecutive_rate_limits > 1 {
158 let base_delay = Duration::from_secs(1);
159 let exponent = (inner.consecutive_rate_limits - 1).min(10);
160 let backoff_secs =
161 base_delay.as_secs_f64() * inner.backoff_multiplier.powi(exponent as i32);
162 let backoff = Duration::from_secs_f64(backoff_secs);
163 if backoff > wait_time {
164 wait_time = backoff;
165 }
166 if wait_time > inner.max_backoff {
167 wait_time = inner.max_backoff;
168 }
169 }
170
171 tracing::info!(
172 wait_ms = wait_time.as_millis() as u64,
173 remaining = inner.remaining,
174 "API rate limit enforced, waiting"
175 );
176
177 (wait_time, inner.reset_time)
178 };
179
180 tokio::time::sleep(wait_duration).await;
181
182 let mut inner = self.inner.write().await;
184 if inner.reset_time > original_reset {
185 continue;
186 }
187 inner.rate_limited = false;
188 inner.last_request_time = Some(Instant::now());
189 return Ok(());
190 }
191 }
192
193 pub async fn update_from_headers(&self, info: &RateLimitInfo) {
199 let mut inner = self.inner.write().await;
200 if info.limit > 0 {
201 inner.limit = info.limit;
202 }
203 inner.remaining = info.remaining;
204 if info.reset > Utc::now() {
205 inner.reset_time = info.reset;
206 }
207 inner.rate_limited = false;
209 inner.consecutive_rate_limits = 0;
210 tracing::debug!(
211 limit = info.limit,
212 remaining = info.remaining,
213 "Rate limit updated from headers"
214 );
215 }
216
217 pub async fn mark_rate_limited(&self, reset_time: DateTime<Utc>) {
219 let mut inner = self.inner.write().await;
220 inner.rate_limited = true;
221 inner.last_rate_limit_time = Some(Instant::now());
222 inner.consecutive_rate_limits += 1;
223 if reset_time > Utc::now() {
224 inner.reset_time = reset_time;
225 } else {
226 inner.reset_time = Utc::now() + chrono::Duration::seconds(60);
228 }
229 tracing::info!(
230 consecutive = inner.consecutive_rate_limits,
231 "Marked as rate limited by API"
232 );
233 }
234
235 pub async fn get_status(&self) -> RateLimitStatus {
237 let inner = self.inner.read().await;
238 let reset_in = (inner.reset_time - Utc::now())
239 .to_std()
240 .unwrap_or(Duration::ZERO);
241 RateLimitStatus {
242 limit: inner.limit,
243 remaining: inner.remaining,
244 reset_time: inner.reset_time,
245 reset_in,
246 }
247 }
248
249 pub async fn is_near_limit(&self, threshold: f64) -> bool {
251 let inner = self.inner.read().await;
252 if inner.limit == 0 {
253 return false;
254 }
255 let used = inner.limit.saturating_sub(inner.remaining) as f64 / inner.limit as f64;
256 used >= threshold
257 }
258
259 pub async fn is_rate_limited(&self) -> bool {
261 let inner = self.inner.read().await;
262 inner.rate_limited && Utc::now() < inner.reset_time
263 }
264
265 pub async fn reset(&self) {
267 let mut inner = self.inner.write().await;
268 inner.remaining = inner.limit;
269 inner.reset_time = Utc::now() + chrono::Duration::hours(1);
270 inner.last_request_time = None;
271 inner.rate_limited = false;
272 inner.consecutive_rate_limits = 0;
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Limiter built from the default configuration.
    fn default_limiter() -> RateLimiter {
        RateLimiter::new(&RateLimiterConfig::default())
    }

    /// Header info whose window resets one hour from now.
    fn header_info(limit: u32, remaining: u32) -> RateLimitInfo {
        RateLimitInfo {
            limit,
            remaining,
            reset: Utc::now() + chrono::Duration::hours(1),
            retry_after: None,
        }
    }

    /// A reset time five minutes in the future.
    fn five_minutes_from_now() -> DateTime<Utc> {
        Utc::now() + chrono::Duration::minutes(5)
    }

    #[tokio::test]
    async fn test_new_default_config() {
        let RateLimitStatus {
            limit, remaining, ..
        } = default_limiter().get_status().await;
        assert_eq!(limit, 100);
        assert_eq!(remaining, 100);
    }

    #[tokio::test]
    async fn test_should_wait_initially_false() {
        let limiter = default_limiter();
        let must_wait = limiter.should_wait().await;
        assert!(!must_wait);
    }

    #[tokio::test]
    async fn test_mark_rate_limited() {
        let limiter = default_limiter();
        assert!(!limiter.is_rate_limited().await);

        limiter.mark_rate_limited(five_minutes_from_now()).await;

        assert!(limiter.is_rate_limited().await);
        assert!(limiter.should_wait().await);
    }

    #[tokio::test]
    async fn test_update_from_headers() {
        let limiter = default_limiter();
        limiter.update_from_headers(&header_info(200, 150)).await;

        let status = limiter.get_status().await;
        assert_eq!(status.limit, 200);
        assert_eq!(status.remaining, 150);
    }

    #[tokio::test]
    async fn test_is_near_limit() {
        let limiter = RateLimiter::new(&RateLimiterConfig {
            initial_limit: 100,
            ..Default::default()
        });
        assert!(!limiter.is_near_limit(0.8).await);

        limiter.update_from_headers(&header_info(100, 10)).await;
        assert!(limiter.is_near_limit(0.8).await);
    }

    #[tokio::test]
    async fn test_reset() {
        let limiter = default_limiter();
        limiter.mark_rate_limited(five_minutes_from_now()).await;
        assert!(limiter.is_rate_limited().await);

        limiter.reset().await;
        assert!(!limiter.is_rate_limited().await);

        let RateLimitStatus {
            limit, remaining, ..
        } = limiter.get_status().await;
        assert_eq!(remaining, limit);
    }

    #[tokio::test]
    async fn test_wait_not_rate_limited() {
        // Nothing has marked the limiter as rate limited, so this returns promptly.
        default_limiter().wait().await.unwrap();
    }

    #[tokio::test]
    async fn test_disable_enable() {
        let limiter = default_limiter();
        limiter.mark_rate_limited(five_minutes_from_now()).await;
        assert!(limiter.should_wait().await);

        limiter.disable().await;
        assert!(!limiter.should_wait().await);

        limiter.enable().await;
        assert!(limiter.should_wait().await);
    }
}