Skip to main content

rust_web_server/rate_limit/
mod.rs

1#[cfg(test)]
2mod tests;
3
4use std::collections::{HashMap, VecDeque};
5use std::sync::{Mutex, OnceLock};
6use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
7use std::time::{Duration, Instant};
8
/// Per-key log of permitted-call timestamps, oldest first.
type KeyLog = HashMap<String, VecDeque<Instant>>;

/// A sliding-window per-key rate limiter.
///
/// A call to [`RateLimiter::check`] is permitted while fewer than
/// `max_requests` earlier permitted calls for the same key fall inside the
/// trailing `window_secs`-second window. Only permitted calls are recorded, so
/// rejected calls do not push a client's lockout further into the future.
///
/// A `max_requests` of `0` disables limiting: every call is permitted and
/// nothing is recorded.
///
/// Keys whose timestamps have all expired are swept out whenever a new key
/// would otherwise force the internal map to grow, so idle clients do not
/// accumulate forever.
///
/// Thread-safe: the internal state is behind a `Mutex` so it can be shared
/// across threads via [`global`] or wrapped in an `Arc`.
///
/// # Example
///
/// ```rust,no_run
/// use rust_web_server::rate_limit::RateLimiter;
///
/// let limiter = RateLimiter::new(100, 60); // 100 req / 60 s
///
/// if limiter.check("192.168.1.1") {
///     // process request
/// } else {
///     // return 429 Too Many Requests
/// }
/// ```
pub struct RateLimiter {
    /// Timestamps of permitted calls, per key.
    state: Mutex<KeyLog>,
    /// Calls permitted per window; `0` disables limiting.
    max_requests: AtomicU32,
    /// Window length in whole seconds.
    window_secs: AtomicU64,
}

impl RateLimiter {
    /// Create a new limiter allowing `max_requests` per `window_secs`-second window.
    ///
    /// A `max_requests` of `0` disables limiting.
    pub fn new(max_requests: u32, window_secs: u64) -> Self {
        RateLimiter {
            state: Mutex::new(HashMap::new()),
            max_requests: AtomicU32::new(max_requests),
            window_secs: AtomicU64::new(window_secs),
        }
    }

    /// Update the limits on a live limiter without restarting.
    ///
    /// Changes take effect on the next call to [`Self::check`] or
    /// [`Self::remaining`]. The two values are stored independently, so a call
    /// racing with this one may briefly observe the new value of one and the
    /// old value of the other.
    /// Called automatically by [`crate::config_reload::reload`] on SIGHUP.
    pub fn set_limits(&self, max_requests: u32, window_secs: u64) {
        self.max_requests.store(max_requests, Ordering::Relaxed);
        self.window_secs.store(window_secs, Ordering::Relaxed);
    }

    fn window(&self) -> Duration {
        Duration::from_secs(self.window_secs.load(Ordering::Relaxed))
    }

    fn max(&self) -> u32 {
        self.max_requests.load(Ordering::Relaxed)
    }

    /// Lock the state map, recovering the data if the mutex is poisoned.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, KeyLog> {
        // Every mutation of the map leaves it consistent, so a panic on another
        // thread while holding the lock cannot corrupt it. Keep serving instead
        // of turning that one panic into a panic on every later request.
        self.state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Pop timestamps older than `window` (relative to `now`) off the front.
    ///
    /// Relies on `timestamps` being in non-decreasing order.
    fn prune(timestamps: &mut VecDeque<Instant>, now: Instant, window: Duration) {
        while timestamps
            .front()
            .is_some_and(|t| now.duration_since(*t) > window)
        {
            timestamps.pop_front();
        }
    }

    /// Remove every key whose timestamps have all expired.
    fn evict_idle(map: &mut KeyLog, now: Instant, window: Duration) {
        map.retain(|_, timestamps| {
            Self::prune(timestamps, now, window);
            !timestamps.is_empty()
        });
    }

    /// Returns `true` if `key` (typically a client IP) may make another request
    /// now, or `false` if it has already made `max_requests` permitted requests
    /// within the current window.
    ///
    /// A permitted call is recorded so it counts toward future limits; a
    /// rejected call is not. While `max_requests` is `0` this always returns
    /// `true` and records nothing.
    pub fn check(&self, key: &str) -> bool {
        let max = self.max();
        if max == 0 {
            return true;
        }
        let window = self.window();
        let mut map = self.lock_state();
        // Read the clock only while holding the lock so timestamps are pushed
        // in non-decreasing order, which `prune` depends on.
        let now = Instant::now();

        if !map.contains_key(key) {
            // Sweep idle keys only when this insert would make the map grow.
            // That ties the O(n) scan to the map's own growth, so it runs
            // rarely compared with inserts, while still reclaiming the entries
            // of clients that have gone quiet.
            if map.len() >= map.capacity() {
                Self::evict_idle(&mut map, now, window);
            }
            map.insert(key.to_owned(), VecDeque::new());
        }
        let timestamps = map
            .get_mut(key)
            .expect("entry for key was inserted above");
        Self::prune(timestamps, now, window);

        let used = u32::try_from(timestamps.len()).unwrap_or(u32::MAX);
        if used < max {
            timestamps.push_back(now);
            true
        } else {
            false
        }
    }

    /// Number of remaining requests `key` may make within the current window.
    ///
    /// Returns `max_requests` for a key with no recorded calls, without
    /// creating an entry for it. Returns `u32::MAX` while limiting is disabled
    /// (`max_requests == 0`).
    pub fn remaining(&self, key: &str) -> u32 {
        let max = self.max();
        if max == 0 {
            return u32::MAX;
        }
        let window = self.window();
        let mut map = self.lock_state();
        let now = Instant::now();

        let Some(timestamps) = map.get_mut(key) else {
            return max;
        };
        Self::prune(timestamps, now, window);
        if timestamps.is_empty() {
            // Nothing left in the window; drop the key rather than keep an
            // empty deque around.
            map.remove(key);
            return max;
        }
        let used = u32::try_from(timestamps.len()).unwrap_or(u32::MAX);
        max.saturating_sub(used)
    }

    /// Remove all tracked state for `key`. Useful in tests.
    pub fn reset(&self, key: &str) {
        self.lock_state().remove(key);
    }
}
106
107static GLOBAL_LIMITER: OnceLock<RateLimiter> = OnceLock::new();
108
109/// Return the process-wide rate limiter, initialized from environment variables.
110///
111/// | Variable | Default | Meaning |
112/// |---|---|---|
113/// | `RWS_CONFIG_RATE_LIMIT_MAX_REQUESTS` | `1000` | Requests allowed per window |
114/// | `RWS_CONFIG_RATE_LIMIT_WINDOW_SECS` | `60` | Window length in seconds |
115///
116/// Returns `None` when rate limiting is disabled (`RWS_CONFIG_RATE_LIMIT_MAX_REQUESTS=0`).
117pub fn global() -> &'static RateLimiter {
118    GLOBAL_LIMITER.get_or_init(|| {
119        let max: u32 = std::env::var("RWS_CONFIG_RATE_LIMIT_MAX_REQUESTS")
120            .ok()
121            .and_then(|v| v.parse().ok())
122            .unwrap_or(1000);
123        let window: u64 = std::env::var("RWS_CONFIG_RATE_LIMIT_WINDOW_SECS")
124            .ok()
125            .and_then(|v| v.parse().ok())
126            .unwrap_or(60);
127        RateLimiter::new(max, window)
128    })
129}