umbral_core/ratelimit.rs
1//! A dependency-light, in-memory **sliding-window rate limiter**.
2//!
3//! The single sliding-window limiter in the tree: it backs umbral-rest's
4//! API throttles ([`umbral_rest::throttle`]) AND umbral-auth's
5//! login/register brute-force throttle (`plugins/umbral-auth/src/throttle.rs`,
6//! consolidated onto this primitive — see the note below). It
7//! tracks per-key timestamps in a `Mutex<HashMap<String, VecDeque<Instant>>>`
8//! and answers one question: *is this key under its rate right now?*
9//!
10//! ```ignore
11//! use std::time::Duration;
12//! use umbral::ratelimit::{Rate, RateLimiter};
13//!
14//! let limiter = RateLimiter::new(Rate::parse("100/hour").unwrap());
15//! let decision = limiter.check("203.0.113.7");
16//! if !decision.allowed {
17//! // 429; tell the client when to come back
18//! let secs = decision.retry_after.map(|d| d.as_secs()).unwrap_or(0);
19//! }
20//! ```
21//!
22//! ## The window
23//!
24//! "Sliding window" means each `check` first prunes every recorded
//! timestamp that is at least `rate.period` old, then counts what's left.
26//! If the count is below `rate.num`, the call is allowed *and recorded*;
27//! otherwise it's denied and the limiter computes `retry_after` as the
28//! time until the oldest still-in-window entry ages out (the moment a
29//! slot frees up). There's no fixed-window edge burst: the window moves
30//! continuously with the clock.
31//!
32//! ## Scope and limits
33//!
34//! - **In-memory, single-process.** State lives in this process's heap.
35//! A multi-instance deployment behind a load balancer gives each
36//! replica its own counters; the effective limit is `num × replicas`.
//!   A Redis-backed store is the multi-instance follow-up; `umbral-auth`'s
//!   throttle wraps this limiter, so it inherits the same gap.
39//! - **Unbounded key set.** The `HashMap` grows one entry per distinct
//!   key. Timestamps are pruned lazily on the next `check` of that key and
//!   are never swept globally; a key itself leaves the map only through
//!   [`RateLimiter::clear`]. For IP/user keys on a normal app this is bounded
//!   by the active client set; an adversarial key explosion is a known edge
//!   (which `umbral-auth`'s throttle inherits by wrapping this limiter) — a
//!   periodic sweep is a future hardening.
45//!
46//! ## Consolidated: `umbral-auth::throttle` adopts this primitive
47//!
48//! `umbral-auth` once shipped its own bespoke login/register throttle
49//! (`plugins/umbral-auth/src/throttle.rs`) written before this primitive
50//! existed, with a hand-rolled copy of the same sliding-window-per-key idea.
51//! That duplicate is gone: `umbral-auth::throttle::Throttle` is now a thin
52//! wrapper over [`RateLimiter`], so there's a single limiter implementation
53//! in the tree. The "success forgives" path (clear a login counter after a
54//! successful login) drove the [`RateLimiter::clear`] method added here.
55//! Done in `planning/gaps2.md` (#90).
56
57use std::collections::{HashMap, VecDeque};
58use std::sync::Mutex;
59use std::time::{Duration, Instant};
60
61/// A rate: `num` events per `period`. Build by hand or parse the
62/// `"<num>/<period>"` string with [`Rate::parse`].
63#[derive(Debug, Clone, Copy, PartialEq, Eq)]
64pub struct Rate {
65 /// Maximum number of events allowed within one `period`.
66 pub num: u32,
67 /// The sliding window length.
68 pub period: Duration,
69}
70
71impl Rate {
72 /// Construct directly from a count and a window.
73 pub fn new(num: u32, period: Duration) -> Self {
74 Self { num, period }
75 }
76
77 /// Parse a rate string: `"<num>/<period>"`.
78 ///
79 /// `num` is a positive integer; `period` is one of (case-insensitive):
80 ///
81 /// | period token | window |
82 /// |---|---|
83 /// | `sec`, `s`, `second` | 1 second |
84 /// | `min`, `m`, `minute` | 60 seconds |
85 /// | `hour`, `h` | 3600 seconds |
86 /// | `day`, `d` | 86400 seconds |
87 ///
88 /// A bare number with no separator is also accepted as a per-second
89 /// rate (the `"<num>"` shorthand), e.g. `"5"` ≡ `"5/sec"`. Anything
90 /// else — empty string, non-numeric count, zero count, unknown period
91 /// — returns `Err` with a short message.
92 ///
93 /// ```
94 /// # use std::time::Duration;
95 /// # use umbral_core::ratelimit::Rate;
96 /// assert_eq!(Rate::parse("100/hour").unwrap().num, 100);
97 /// assert_eq!(Rate::parse("10/min").unwrap().period, Duration::from_secs(60));
98 /// assert!(Rate::parse("oops").is_err());
99 /// ```
100 pub fn parse(s: &str) -> Result<Self, String> {
101 let s = s.trim();
102 if s.is_empty() {
103 return Err("empty rate string".to_string());
104 }
105 let (num_part, period_part) = match s.split_once('/') {
106 Some((n, p)) => (n.trim(), p.trim()),
107 // Bare number → per-second (shorthand).
108 None => (s, "sec"),
109 };
110 let num: u32 = num_part
111 .parse()
112 .map_err(|_| format!("invalid rate count `{num_part}` in `{s}`"))?;
113 if num == 0 {
114 return Err(format!("rate count must be positive in `{s}`"));
115 }
116 let period = match period_part.to_ascii_lowercase().as_str() {
117 "sec" | "s" | "second" => Duration::from_secs(1),
118 "min" | "m" | "minute" => Duration::from_secs(60),
119 "hour" | "h" => Duration::from_secs(3600),
120 "day" | "d" => Duration::from_secs(86_400),
121 other => return Err(format!("unknown rate period `{other}` in `{s}`")),
122 };
123 Ok(Self { num, period })
124 }
125}
126
127/// The verdict for one [`RateLimiter::check`].
128#[derive(Debug, Clone, Copy, PartialEq, Eq)]
129pub struct RateDecision {
130 /// `true` when the request is under the limit (and was recorded);
131 /// `false` when it's over (and was NOT recorded).
132 pub allowed: bool,
133 /// On a denial, how long until a slot frees up — the time until the
134 /// oldest in-window entry ages out. `None` when `allowed` is `true`.
135 pub retry_after: Option<Duration>,
136 /// The configured ceiling (`Rate::num`). Useful for an
137 /// `X-RateLimit-Limit` header.
138 pub limit: u32,
139 /// How many requests remain in the current window AFTER this one.
140 /// `0` on a denial.
141 pub remaining: u32,
142}
143
144/// An in-memory sliding-window rate limiter, keyed by an arbitrary
145/// string (IP, user id, scope-qualified key — the caller decides).
146///
147/// Cheap to clone the configured [`Rate`]; the shared counter map sits
148/// behind a `Mutex` so a single `RateLimiter` can back many concurrent
149/// requests. Wrap in an `Arc` to share across handlers.
150#[derive(Debug)]
151pub struct RateLimiter {
152 rate: Rate,
153 buckets: Mutex<HashMap<String, VecDeque<Instant>>>,
154}
155
156impl RateLimiter {
157 /// Build a limiter enforcing `rate`.
158 pub fn new(rate: Rate) -> Self {
159 Self {
160 rate,
161 buckets: Mutex::new(HashMap::new()),
162 }
163 }
164
165 /// The configured rate.
166 pub fn rate(&self) -> Rate {
167 self.rate
168 }
169
170 /// Check (and, if allowed, record) one request for `key` against the
171 /// configured rate, using the real wall clock.
172 ///
173 /// See [`Self::check_at`] for the deterministic, clock-injectable
174 /// variant the tests drive.
175 pub fn check(&self, key: &str) -> RateDecision {
176 self.check_at(key, Instant::now())
177 }
178
179 /// Clock-injectable core: identical to [`Self::check`] but the caller
180 /// supplies `now`. Private-ish (crate-visible) so deterministic tests
181 /// can advance time without sleeping; production always routes through
182 /// [`Self::check`] with `Instant::now()`.
183 pub fn check_at(&self, key: &str, now: Instant) -> RateDecision {
184 let window = self.rate.period;
185 let mut buckets = self.buckets.lock().unwrap_or_else(|e| e.into_inner());
186 let entries = buckets.entry(key.to_string()).or_default();
187
188 // Prune everything older than the window — the "sliding" step.
189 // `now.checked_duration_since` guards against a clock that didn't
190 // advance (or a stamp in the future); treat un-orderable stamps
191 // as in-window (conservative: never silently drop a recent hit).
192 while let Some(front) = entries.front() {
193 match now.checked_duration_since(*front) {
194 Some(age) if age >= window => {
195 entries.pop_front();
196 }
197 _ => break,
198 }
199 }
200
201 let count = entries.len() as u32;
202 if count < self.rate.num {
203 entries.push_back(now);
204 RateDecision {
205 allowed: true,
206 retry_after: None,
207 limit: self.rate.num,
208 remaining: self.rate.num - count - 1,
209 }
210 } else {
211 // Over the limit. A slot frees when the OLDEST in-window entry
212 // ages out: that's `window - (now - oldest)`. The prune above
213 // guarantees the front is still within the window, so the
214 // subtraction is non-negative; saturate to be safe.
215 let retry_after = entries
216 .front()
217 .and_then(|oldest| now.checked_duration_since(*oldest))
218 .map(|age| window.saturating_sub(age))
219 .unwrap_or(window);
220 RateDecision {
221 allowed: false,
222 retry_after: Some(retry_after),
223 limit: self.rate.num,
224 remaining: 0,
225 }
226 }
227 }
228
229 /// Forget every recorded request for `key`, resetting its window so the
230 /// next [`check`](Self::check) starts from a clean budget.
231 ///
232 /// The "success forgives" primitive: a caller that wants a prior burst of
233 /// denied attempts to stop counting after some positive outcome (e.g.
234 /// umbral-auth clears the login counter on a SUCCESSFUL login so a user who
235 /// fat-fingered their password isn't locked out) calls this to drop the
236 /// key's history. A no-op if the key was never seen.
237 pub fn clear(&self, key: &str) {
238 let mut buckets = self.buckets.lock().unwrap_or_else(|e| e.into_inner());
239 buckets.remove(key);
240 }
241}
242
243#[cfg(test)]
244mod tests {
245 use super::*;
246
247 #[test]
248 fn parse_each_period() {
249 assert_eq!(
250 Rate::parse("1/sec").unwrap().period,
251 Duration::from_secs(1)
252 );
253 assert_eq!(Rate::parse("1/s").unwrap().period, Duration::from_secs(1));
254 assert_eq!(
255 Rate::parse("1/second").unwrap().period,
256 Duration::from_secs(1)
257 );
258 assert_eq!(
259 Rate::parse("1/min").unwrap().period,
260 Duration::from_secs(60)
261 );
262 assert_eq!(
263 Rate::parse("1/hour").unwrap().period,
264 Duration::from_secs(3600)
265 );
266 assert_eq!(
267 Rate::parse("1/day").unwrap().period,
268 Duration::from_secs(86_400)
269 );
270 }
271
272 #[test]
273 fn parse_rejects_garbage() {
274 assert!(Rate::parse("").is_err());
275 assert!(Rate::parse("oops").is_err());
276 assert!(Rate::parse("10/fortnight").is_err());
277 assert!(Rate::parse("0/sec").is_err());
278 assert!(Rate::parse("abc/min").is_err());
279 }
280
281 #[test]
282 fn third_request_in_window_denied() {
283 let limiter = RateLimiter::new(Rate::parse("2/min").unwrap());
284 let t0 = Instant::now();
285 let d1 = limiter.check_at("a", t0);
286 assert!(d1.allowed);
287 assert_eq!(d1.remaining, 1);
288 let d2 = limiter.check_at("a", t0 + Duration::from_secs(1));
289 assert!(d2.allowed);
290 assert_eq!(d2.remaining, 0);
291 let d3 = limiter.check_at("a", t0 + Duration::from_secs(2));
292 assert!(!d3.allowed);
293 assert!(d3.retry_after.is_some());
294 // Slot frees 60s after the FIRST hit, i.e. 58s from t0+2s.
295 assert_eq!(d3.retry_after.unwrap(), Duration::from_secs(58));
296 }
297
298 #[test]
299 fn distinct_keys_are_independent() {
300 let limiter = RateLimiter::new(Rate::parse("1/min").unwrap());
301 let t0 = Instant::now();
302 assert!(limiter.check_at("a", t0).allowed);
303 // Key "b" has its own bucket — not affected by "a" being full.
304 assert!(limiter.check_at("b", t0).allowed);
305 // "a" is now over its 1/min.
306 assert!(!limiter.check_at("a", t0).allowed);
307 }
308
309 #[test]
310 fn allowed_again_after_window_elapses() {
311 let limiter = RateLimiter::new(Rate::parse("1/min").unwrap());
312 let t0 = Instant::now();
313 assert!(limiter.check_at("a", t0).allowed);
314 assert!(!limiter.check_at("a", t0 + Duration::from_secs(30)).allowed);
315 // 61s later the original hit has aged out of the 60s window.
316 assert!(limiter.check_at("a", t0 + Duration::from_secs(61)).allowed);
317 }
318
319 #[test]
320 fn clear_forgets_a_key() {
321 let limiter = RateLimiter::new(Rate::parse("1/min").unwrap());
322 let t0 = Instant::now();
323 assert!(limiter.check_at("a", t0).allowed);
324 // Over budget within the window.
325 assert!(!limiter.check_at("a", t0).allowed);
326 // Clearing the key drops its history, so the next check is allowed.
327 limiter.clear("a");
328 assert!(limiter.check_at("a", t0).allowed);
329 // A clear on an unknown key is a harmless no-op.
330 limiter.clear("never-seen");
331 }
332}