Skip to main content

threads_rs/
rate_limit.rs

1use std::time::{Duration, Instant};
2
3use chrono::{DateTime, Utc};
4use tokio::sync::RwLock;
5
6/// Rate limit information extracted from API response headers.
7#[derive(Debug, Clone)]
8pub struct RateLimitInfo {
9    /// Maximum requests allowed per window.
10    pub limit: u32,
11    /// Requests remaining in the current window.
12    pub remaining: u32,
13    /// When the rate limit window resets.
14    pub reset: DateTime<Utc>,
15    /// Suggested wait time from the API.
16    pub retry_after: Option<Duration>,
17}
18
19/// Snapshot of the current rate limit status.
20#[derive(Debug, Clone)]
21pub struct RateLimitStatus {
22    /// Maximum requests allowed per window.
23    pub limit: u32,
24    /// Requests remaining in the current window.
25    pub remaining: u32,
26    /// When the rate limit window resets.
27    pub reset_time: DateTime<Utc>,
28    /// Duration until the window resets.
29    pub reset_in: Duration,
30}
31
/// Configuration for the rate limiter.
///
/// `RateLimiter::new` replaces unusable values with the defaults listed on
/// each field, so a partially filled configuration is still accepted.
#[derive(Debug, Clone, PartialEq)]
pub struct RateLimiterConfig {
    /// Starting request limit. `0` is replaced by `100`.
    pub initial_limit: u32,
    /// Exponential backoff multiplier. Non-positive values are replaced by `2.0`.
    pub backoff_multiplier: f64,
    /// Maximum backoff duration. A zero duration is replaced by 300 seconds.
    pub max_backoff: Duration,
}

impl Default for RateLimiterConfig {
    /// Returns the default configuration: a limit of 100 requests, a backoff
    /// multiplier of 2.0 and a maximum backoff of 300 seconds.
    fn default() -> Self {
        Self {
            initial_limit: 100,
            backoff_multiplier: 2.0,
            max_backoff: Duration::from_secs(300),
        }
    }
}
52
53struct Inner {
54    limit: u32,
55    remaining: u32,
56    reset_time: DateTime<Utc>,
57    last_request_time: Option<Instant>,
58    backoff_multiplier: f64,
59    max_backoff: Duration,
60    rate_limited: bool,
61    last_rate_limit_time: Option<Instant>,
62    consecutive_rate_limits: u32,
63    enabled: bool,
64}
65
66/// Manages API rate limiting with intelligent backoff.
67///
68/// Thread-safe via internal `RwLock`. All methods take `&self`.
69pub struct RateLimiter {
70    inner: RwLock<Inner>,
71}
72
73impl RateLimiter {
74    /// Create a new rate limiter with the given configuration.
75    pub fn new(config: &RateLimiterConfig) -> Self {
76        let limit = if config.initial_limit == 0 {
77            100
78        } else {
79            config.initial_limit
80        };
81        let backoff = if config.backoff_multiplier <= 0.0 {
82            2.0
83        } else {
84            config.backoff_multiplier
85        };
86        let max_backoff = if config.max_backoff.is_zero() {
87            Duration::from_secs(300)
88        } else {
89            config.max_backoff
90        };
91
92        Self {
93            inner: RwLock::new(Inner {
94                limit,
95                remaining: limit,
96                reset_time: Utc::now() + chrono::Duration::hours(1),
97                last_request_time: None,
98                backoff_multiplier: backoff,
99                max_backoff,
100                rate_limited: false,
101                last_rate_limit_time: None,
102                consecutive_rate_limits: 0,
103                enabled: true,
104            }),
105        }
106    }
107
108    /// Disable rate limiting. Requests will not be throttled.
109    pub async fn disable(&self) {
110        let mut inner = self.inner.write().await;
111        inner.enabled = false;
112    }
113
114    /// Enable rate limiting.
115    pub async fn enable(&self) {
116        let mut inner = self.inner.write().await;
117        inner.enabled = true;
118    }
119
120    /// Returns `true` if a request should wait before proceeding.
121    /// Only returns `true` when the API has explicitly rate-limited us
122    /// and the rate limiter is enabled.
123    pub async fn should_wait(&self) -> bool {
124        let inner = self.inner.read().await;
125        inner.enabled && inner.rate_limited && Utc::now() < inner.reset_time
126    }
127
128    /// Blocks until it's safe to make a request.
129    /// Only blocks when actually rate-limited by the API.
130    pub async fn wait(&self) -> crate::Result<()> {
131        // Check if window has reset
132        {
133            let mut inner = self.inner.write().await;
134            if Utc::now() >= inner.reset_time {
135                inner.remaining = inner.limit;
136                inner.reset_time = Utc::now() + chrono::Duration::hours(1);
137                inner.rate_limited = false;
138                inner.consecutive_rate_limits = 0;
139                tracing::debug!(limit = inner.limit, "Rate limit window reset");
140                return Ok(());
141            }
142            if !inner.rate_limited {
143                inner.last_request_time = Some(Instant::now());
144                return Ok(());
145            }
146        }
147
148        // Rate-limited: sleep until reset
149        loop {
150            let (wait_duration, original_reset) = {
151                let inner = self.inner.read().await;
152                let mut wait_time = (inner.reset_time - Utc::now())
153                    .to_std()
154                    .unwrap_or(Duration::from_secs(1));
155
156                // Apply exponential backoff if hitting limits repeatedly
157                if inner.consecutive_rate_limits > 1 {
158                    let base_delay = Duration::from_secs(1);
159                    let exponent = (inner.consecutive_rate_limits - 1).min(10);
160                    let backoff_secs =
161                        base_delay.as_secs_f64() * inner.backoff_multiplier.powi(exponent as i32);
162                    let backoff = Duration::from_secs_f64(backoff_secs);
163                    if backoff > wait_time {
164                        wait_time = backoff;
165                    }
166                    if wait_time > inner.max_backoff {
167                        wait_time = inner.max_backoff;
168                    }
169                }
170
171                tracing::info!(
172                    wait_ms = wait_time.as_millis() as u64,
173                    remaining = inner.remaining,
174                    "API rate limit enforced, waiting"
175                );
176
177                (wait_time, inner.reset_time)
178            };
179
180            tokio::time::sleep(wait_duration).await;
181
182            // Check if rate limit was extended while sleeping
183            let mut inner = self.inner.write().await;
184            if inner.reset_time > original_reset {
185                continue;
186            }
187            inner.rate_limited = false;
188            inner.last_request_time = Some(Instant::now());
189            return Ok(());
190        }
191    }
192
193    /// Updates rate limit state from API response headers.
194    ///
195    /// A successful response (with rate limit headers) confirms the client
196    /// is no longer in a consecutive rate-limit run, so the backoff counter
197    /// is reset.
198    pub async fn update_from_headers(&self, info: &RateLimitInfo) {
199        let mut inner = self.inner.write().await;
200        if info.limit > 0 {
201            inner.limit = info.limit;
202        }
203        inner.remaining = info.remaining;
204        if info.reset > Utc::now() {
205            inner.reset_time = info.reset;
206        }
207        // A successful response means we are no longer in a consecutive rate-limit run.
208        inner.rate_limited = false;
209        inner.consecutive_rate_limits = 0;
210        tracing::debug!(
211            limit = info.limit,
212            remaining = info.remaining,
213            "Rate limit updated from headers"
214        );
215    }
216
217    /// Marks that the API has returned a 429 rate-limit response.
218    pub async fn mark_rate_limited(&self, reset_time: DateTime<Utc>) {
219        let mut inner = self.inner.write().await;
220        inner.rate_limited = true;
221        inner.last_rate_limit_time = Some(Instant::now());
222        inner.consecutive_rate_limits += 1;
223        if reset_time > Utc::now() {
224            inner.reset_time = reset_time;
225        } else {
226            // No valid reset time from the API — use a safe default
227            inner.reset_time = Utc::now() + chrono::Duration::seconds(60);
228        }
229        tracing::info!(
230            consecutive = inner.consecutive_rate_limits,
231            "Marked as rate limited by API"
232        );
233    }
234
235    /// Returns a snapshot of the current rate limit status.
236    pub async fn get_status(&self) -> RateLimitStatus {
237        let inner = self.inner.read().await;
238        let reset_in = (inner.reset_time - Utc::now())
239            .to_std()
240            .unwrap_or(Duration::ZERO);
241        RateLimitStatus {
242            limit: inner.limit,
243            remaining: inner.remaining,
244            reset_time: inner.reset_time,
245            reset_in,
246        }
247    }
248
249    /// Returns `true` if usage has exceeded the given threshold (0.0–1.0).
250    pub async fn is_near_limit(&self, threshold: f64) -> bool {
251        let inner = self.inner.read().await;
252        if inner.limit == 0 {
253            return false;
254        }
255        let used = inner.limit.saturating_sub(inner.remaining) as f64 / inner.limit as f64;
256        used >= threshold
257    }
258
259    /// Returns `true` if the API has rate-limited us and the window hasn't reset.
260    pub async fn is_rate_limited(&self) -> bool {
261        let inner = self.inner.read().await;
262        inner.rate_limited && Utc::now() < inner.reset_time
263    }
264
265    /// Resets the rate limiter to its initial state.
266    pub async fn reset(&self) {
267        let mut inner = self.inner.write().await;
268        inner.remaining = inner.limit;
269        inner.reset_time = Utc::now() + chrono::Duration::hours(1);
270        inner.last_request_time = None;
271        inner.rate_limited = false;
272        inner.consecutive_rate_limits = 0;
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a limiter from the default configuration.
    fn default_limiter() -> RateLimiter {
        RateLimiter::new(&RateLimiterConfig::default())
    }

    /// Builds header info with the given counters and a reset one hour away.
    fn headers(limit: u32, remaining: u32) -> RateLimitInfo {
        RateLimitInfo {
            limit,
            remaining,
            reset: Utc::now() + chrono::Duration::hours(1),
            retry_after: None,
        }
    }

    /// A reset time five minutes in the future.
    fn in_five_minutes() -> DateTime<Utc> {
        Utc::now() + chrono::Duration::minutes(5)
    }

    #[tokio::test]
    async fn test_new_default_config() {
        let status = default_limiter().get_status().await;
        assert_eq!(status.limit, 100);
        assert_eq!(status.remaining, 100);
    }

    #[tokio::test]
    async fn test_should_wait_initially_false() {
        let limiter = default_limiter();
        assert!(!limiter.should_wait().await);
    }

    #[tokio::test]
    async fn test_mark_rate_limited() {
        let limiter = default_limiter();
        assert!(!limiter.is_rate_limited().await);

        limiter.mark_rate_limited(in_five_minutes()).await;

        assert!(limiter.is_rate_limited().await);
        assert!(limiter.should_wait().await);
    }

    #[tokio::test]
    async fn test_update_from_headers() {
        let limiter = default_limiter();
        limiter.update_from_headers(&headers(200, 150)).await;

        let status = limiter.get_status().await;
        assert_eq!(status.limit, 200);
        assert_eq!(status.remaining, 150);
    }

    #[tokio::test]
    async fn test_is_near_limit() {
        let config = RateLimiterConfig {
            initial_limit: 100,
            ..Default::default()
        };
        let limiter = RateLimiter::new(&config);
        assert!(!limiter.is_near_limit(0.8).await);

        limiter.update_from_headers(&headers(100, 10)).await;
        assert!(limiter.is_near_limit(0.8).await);
    }

    #[tokio::test]
    async fn test_reset() {
        let limiter = default_limiter();
        limiter.mark_rate_limited(in_five_minutes()).await;
        assert!(limiter.is_rate_limited().await);

        limiter.reset().await;

        assert!(!limiter.is_rate_limited().await);
        let status = limiter.get_status().await;
        assert_eq!(status.remaining, status.limit);
    }

    #[tokio::test]
    async fn test_wait_not_rate_limited() {
        // Should return immediately when not rate-limited
        default_limiter().wait().await.unwrap();
    }

    #[tokio::test]
    async fn test_disable_enable() {
        let limiter = default_limiter();
        limiter.mark_rate_limited(in_five_minutes()).await;
        assert!(limiter.should_wait().await);

        // Disable bypasses rate limiting
        limiter.disable().await;
        assert!(!limiter.should_wait().await);

        // Re-enable restores rate limiting
        limiter.enable().await;
        assert!(limiter.should_wait().await);
    }
}
370}