rust_web_server/rate_limit/mod.rs
1#[cfg(test)]
2mod tests;
3
4use std::collections::{HashMap, VecDeque};
5use std::sync::{Mutex, OnceLock};
6use std::time::{Duration, Instant};
7
/// A sliding-window per-key rate limiter.
///
/// Each *permitted* call to [`RateLimiter::check`] records a timestamp for the
/// given key. `check` returns `true` while fewer than `max_requests` permitted
/// calls fall inside the current window, and `false` once that limit has been
/// reached. Rejected calls are not recorded.
///
/// Thread-safe: the internal state is behind a `Mutex` so it can be shared
/// across threads via [`global`] or wrapped in an `Arc`.
///
/// Note: an entry is only dropped when [`RateLimiter::remaining`] finds it
/// fully expired or when [`RateLimiter::reset`] is called; keys that are never
/// queried again keep their (bounded, at most `max_requests` long) history.
///
/// # Example
///
/// ```rust,no_run
/// use rust_web_server::rate_limit::RateLimiter;
///
/// let limiter = RateLimiter::new(100, 60); // 100 req / 60 s
///
/// if limiter.check("192.168.1.1") {
///     // process request
/// } else {
///     // return 429 Too Many Requests
/// }
/// ```
pub struct RateLimiter {
    /// Timestamps of permitted calls per key, oldest at the front.
    state: Mutex<HashMap<String, VecDeque<Instant>>>,
    /// Maximum number of permitted calls per key inside one window.
    max_requests: u32,
    /// Length of the sliding window.
    window: Duration,
}

impl RateLimiter {
    /// Create a new limiter allowing `max_requests` per `window_secs`-second window.
    ///
    /// With `max_requests == 0` every call to [`RateLimiter::check`] is rejected.
    pub fn new(max_requests: u32, window_secs: u64) -> Self {
        RateLimiter {
            state: Mutex::new(HashMap::new()),
            max_requests,
            window: Duration::from_secs(window_secs),
        }
    }

    /// Returns `true` if `key` (typically a client IP) is within the rate limit,
    /// or `false` if the limit has been exceeded.
    ///
    /// A permitted call is always recorded so it counts toward future limits;
    /// a rejected call leaves the recorded history unchanged.
    pub fn check(&self, key: &str) -> bool {
        // A zero limit can never admit a request. Answer without touching the
        // map so such a limiter does not accumulate one empty entry per key.
        if self.max_requests == 0 {
            return false;
        }

        let mut state = self.lock_state();
        // Read the clock while holding the lock so timestamps are appended in
        // non-decreasing order, which front-only pruning relies on.
        let now = Instant::now();

        if let Some(timestamps) = state.get_mut(key) {
            Self::prune(timestamps, now, self.window);
            if Self::used(timestamps) < self.max_requests {
                timestamps.push_back(now);
                return true;
            }
            return false;
        }

        // First request seen for this key: the key is only allocated here,
        // not on every call.
        state.insert(key.to_owned(), VecDeque::from([now]));
        true
    }

    /// Number of remaining requests `key` may make within the current window.
    ///
    /// This is a query: it never creates an entry for an unknown key, and it
    /// drops the entry of a key whose recorded calls have all expired.
    pub fn remaining(&self, key: &str) -> u32 {
        let mut state = self.lock_state();
        let now = Instant::now();

        let used = match state.get_mut(key) {
            Some(timestamps) => {
                Self::prune(timestamps, now, self.window);
                Self::used(timestamps)
            }
            None => return self.max_requests,
        };

        if used == 0 {
            // Nothing left inside the window; release the key's storage.
            state.remove(key);
        }
        self.max_requests.saturating_sub(used)
    }

    /// Remove all tracked state for `key`. Useful in tests.
    pub fn reset(&self, key: &str) {
        self.lock_state().remove(key);
    }

    /// Lock the shared state, recovering the guard if the mutex was poisoned.
    ///
    /// Every critical section leaves the map in a consistent state, so a panic
    /// in another thread must not disable rate limiting for the whole process.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, HashMap<String, VecDeque<Instant>>> {
        self.state
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Drop timestamps at the front of `timestamps` that are strictly older
    /// than `window` relative to `now`.
    fn prune(timestamps: &mut VecDeque<Instant>, now: Instant, window: Duration) {
        while let Some(&oldest) = timestamps.front() {
            if now.duration_since(oldest) <= window {
                break;
            }
            timestamps.pop_front();
        }
    }

    /// Number of recorded calls as a `u32`, saturating instead of truncating.
    fn used(timestamps: &VecDeque<Instant>) -> u32 {
        u32::try_from(timestamps.len()).unwrap_or(u32::MAX)
    }
}
86
87static GLOBAL_LIMITER: OnceLock<RateLimiter> = OnceLock::new();
88
89/// Return the process-wide rate limiter, initialized from environment variables.
90///
91/// | Variable | Default | Meaning |
92/// |---|---|---|
93/// | `RWS_CONFIG_RATE_LIMIT_MAX_REQUESTS` | `1000` | Requests allowed per window |
94/// | `RWS_CONFIG_RATE_LIMIT_WINDOW_SECS` | `60` | Window length in seconds |
95///
96/// Returns `None` when rate limiting is disabled (`RWS_CONFIG_RATE_LIMIT_MAX_REQUESTS=0`).
97pub fn global() -> &'static RateLimiter {
98 GLOBAL_LIMITER.get_or_init(|| {
99 let max: u32 = std::env::var("RWS_CONFIG_RATE_LIMIT_MAX_REQUESTS")
100 .ok()
101 .and_then(|v| v.parse().ok())
102 .unwrap_or(1000);
103 let window: u64 = std::env::var("RWS_CONFIG_RATE_LIMIT_WINDOW_SECS")
104 .ok()
105 .and_then(|v| v.parse().ok())
106 .unwrap_or(60);
107 RateLimiter::new(max, window)
108 })
109}