Skip to main content

rust_web_server/rate_limit/
mod.rs

1#[cfg(test)]
2mod tests;
3
4use std::collections::{HashMap, VecDeque};
5use std::sync::{Mutex, OnceLock};
6use std::time::{Duration, Instant};
7
/// A sliding-window per-key rate limiter.
///
/// Each call to [`RateLimiter::check`] first discards the key's timestamps that
/// have fallen out of the window, then admits the call when fewer than
/// `max_requests` timestamps remain. Only admitted calls are recorded, so a
/// rejected call does not extend a client's lock-out.
///
/// Thread-safe: the internal state is behind a `Mutex` so it can be shared
/// across threads via [`global`] or wrapped in an `Arc`.
///
/// Keys are only removed by [`RateLimiter::reset`]; a key whose timestamps have
/// all expired keeps an entry, so memory grows with the number of distinct keys
/// admitted. NOTE(review): long-running servers facing many distinct client IPs
/// may want periodic eviction.
///
/// # Example
///
/// ```rust,no_run
/// use rust_web_server::rate_limit::RateLimiter;
///
/// let limiter = RateLimiter::new(100, 60); // 100 req / 60 s
///
/// if limiter.check("192.168.1.1") {
///     // process request
/// } else {
///     // return 429 Too Many Requests
/// }
/// ```
#[derive(Debug)]
pub struct RateLimiter {
    /// Admitted-call timestamps per key, oldest at the front.
    state: Mutex<HashMap<String, VecDeque<Instant>>>,
    /// Maximum number of admitted calls per key within `window`.
    max_requests: u32,
    /// Length of the sliding window.
    window: Duration,
}

impl RateLimiter {
    /// Create a new limiter allowing `max_requests` per `window_secs`-second window.
    ///
    /// With `max_requests == 0` every call to [`RateLimiter::check`] is rejected.
    pub fn new(max_requests: u32, window_secs: u64) -> Self {
        RateLimiter {
            state: Mutex::new(HashMap::new()),
            max_requests,
            window: Duration::from_secs(window_secs),
        }
    }

    /// Returns `true` if `key` (typically a client IP) is within the rate limit,
    /// or `false` if the limit has been exceeded.
    ///
    /// A permitted call is always recorded so it counts toward future limits;
    /// a rejected call is not recorded.
    pub fn check(&self, key: &str) -> bool {
        // Nothing can ever be admitted, so don't allocate a map entry for `key`.
        if self.max_requests == 0 {
            return false;
        }

        let mut guard = self.lock_state();
        // Read the clock only once the lock is held: timestamps are then pushed
        // in lock order, which keeps every deque sorted oldest-first and lets
        // `prune_expired` stop at the first in-window entry.
        let now = Instant::now();
        let timestamps = guard.entry(key.to_string()).or_default();
        Self::prune_expired(timestamps, now, self.window);

        if timestamps.len() < self.max_requests as usize {
            timestamps.push_back(now);
            true
        } else {
            false
        }
    }

    /// Number of remaining requests `key` may make within the current window.
    ///
    /// Querying a key that has never been admitted returns `max_requests`
    /// without creating an entry for it.
    pub fn remaining(&self, key: &str) -> u32 {
        let mut guard = self.lock_state();
        let now = Instant::now();
        match guard.get_mut(key) {
            Some(timestamps) => {
                Self::prune_expired(timestamps, now, self.window);
                // `check` never lets the deque grow past `max_requests`, so the
                // conversion cannot actually saturate.
                let used = u32::try_from(timestamps.len()).unwrap_or(u32::MAX);
                self.max_requests.saturating_sub(used)
            }
            None => self.max_requests,
        }
    }

    /// Remove all tracked state for `key`. Useful in tests.
    pub fn reset(&self, key: &str) {
        self.lock_state().remove(key);
    }

    /// Lock the state map, recovering it if another thread panicked while
    /// holding the lock.
    ///
    /// The map only holds independent per-key timestamp deques with no
    /// cross-entry invariants, so a guard recovered from a poisoned lock is
    /// still usable; this keeps one panic from making every later call panic.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, HashMap<String, VecDeque<Instant>>> {
        self.state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Pop timestamps that are more than `window` older than `now` off the front
    /// of `timestamps`, which must be sorted oldest-first.
    fn prune_expired(timestamps: &mut VecDeque<Instant>, now: Instant, window: Duration) {
        while timestamps
            .front()
            .is_some_and(|&t| now.duration_since(t) > window)
        {
            timestamps.pop_front();
        }
    }
}
86
87static GLOBAL_LIMITER: OnceLock<RateLimiter> = OnceLock::new();
88
89/// Return the process-wide rate limiter, initialized from environment variables.
90///
91/// | Variable | Default | Meaning |
92/// |---|---|---|
93/// | `RWS_CONFIG_RATE_LIMIT_MAX_REQUESTS` | `1000` | Requests allowed per window |
94/// | `RWS_CONFIG_RATE_LIMIT_WINDOW_SECS` | `60` | Window length in seconds |
95///
96/// Returns `None` when rate limiting is disabled (`RWS_CONFIG_RATE_LIMIT_MAX_REQUESTS=0`).
97pub fn global() -> &'static RateLimiter {
98    GLOBAL_LIMITER.get_or_init(|| {
99        let max: u32 = std::env::var("RWS_CONFIG_RATE_LIMIT_MAX_REQUESTS")
100            .ok()
101            .and_then(|v| v.parse().ok())
102            .unwrap_or(1000);
103        let window: u64 = std::env::var("RWS_CONFIG_RATE_LIMIT_WINDOW_SECS")
104            .ok()
105            .and_then(|v| v.parse().ok())
106            .unwrap_or(60);
107        RateLimiter::new(max, window)
108    })
109}