Skip to main content

umbral_core/
ratelimit.rs

1//! A dependency-light, in-memory **sliding-window rate limiter**.
2//!
3//! The single sliding-window limiter in the tree: it backs umbral-rest's
4//! API throttles ([`umbral_rest::throttle`]) AND umbral-auth's
5//! login/register brute-force throttle (`plugins/umbral-auth/src/throttle.rs`,
6//! consolidated onto this primitive — see the note below). It
7//! tracks per-key timestamps in a `Mutex<HashMap<String, VecDeque<Instant>>>`
8//! and answers one question: *is this key under its rate right now?*
9//!
10//! ```ignore
11//! use std::time::Duration;
12//! use umbral::ratelimit::{Rate, RateLimiter};
13//!
14//! let limiter = RateLimiter::new(Rate::parse("100/hour").unwrap());
15//! let decision = limiter.check("203.0.113.7");
16//! if !decision.allowed {
17//!     // 429; tell the client when to come back
18//!     let secs = decision.retry_after.map(|d| d.as_secs()).unwrap_or(0);
19//! }
20//! ```
21//!
22//! ## The window
23//!
//! "Sliding window" means each `check` first prunes every recorded
//! timestamp that is at least `rate.period` old, then counts what's left.
26//! If the count is below `rate.num`, the call is allowed *and recorded*;
27//! otherwise it's denied and the limiter computes `retry_after` as the
28//! time until the oldest still-in-window entry ages out (the moment a
29//! slot frees up). There's no fixed-window edge burst: the window moves
30//! continuously with the clock.
31//!
32//! ## Scope and limits
33//!
34//! - **In-memory, single-process.** State lives in this process's heap.
35//!   A multi-instance deployment behind a load balancer gives each
36//!   replica its own counters; the effective limit is `num × replicas`.
37//!   A Redis-backed store is the multi-instance follow-up (mirrors the
38//!   same gap `umbral-auth`'s throttle has).
39//! - **Unbounded key set.** The `HashMap` grows one entry per distinct
40//!   key and entries are pruned lazily on next `check` of that key, never
41//!   swept globally. For IP/user keys on a normal app this is bounded by
42//!   the active client set; an adversarial key explosion is a known edge
43//!   (the same shape `umbral-auth`'s throttle has) — a periodic sweep is a
44//!   future hardening.
45//!
46//! ## Consolidated: `umbral-auth::throttle` adopts this primitive
47//!
48//! `umbral-auth` once shipped its own bespoke login/register throttle
49//! (`plugins/umbral-auth/src/throttle.rs`) written before this primitive
50//! existed, with a hand-rolled copy of the same sliding-window-per-key idea.
51//! That duplicate is gone: `umbral-auth::throttle::Throttle` is now a thin
52//! wrapper over [`RateLimiter`], so there's a single limiter implementation
53//! in the tree. The "success forgives" path (clear a login counter after a
54//! successful login) drove the [`RateLimiter::clear`] method added here.
55//! Done in `planning/gaps2.md` (#90).
56
57use std::collections::{HashMap, VecDeque};
58use std::sync::Mutex;
59use std::time::{Duration, Instant};
60
/// A rate limit expressed as "at most `num` events per `period`".
///
/// Build one by hand with [`Rate::new`], or parse the textual
/// `"<num>/<period>"` form with [`Rate::parse`]. The same parser is also
/// reachable through the standard [`FromStr`](std::str::FromStr) trait, so
/// `"100/hour".parse::<Rate>()` works too.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Rate {
    /// Maximum number of events allowed within one `period`.
    pub num: u32,
    /// The sliding window length.
    pub period: Duration,
}

impl Rate {
    /// Construct directly from a count and a window.
    ///
    /// No validation happens here: unlike [`Rate::parse`], a `num` of zero
    /// is accepted. `const` so a rate can be declared as a constant.
    pub const fn new(num: u32, period: Duration) -> Self {
        Self { num, period }
    }

    /// Parse a rate string: `"<num>/<period>"`.
    ///
    /// `num` is a positive integer; `period` is one of (case-insensitive):
    ///
    /// | period token | window |
    /// |---|---|
    /// | `sec`, `s`, `second` | 1 second |
    /// | `min`, `m`, `minute` | 60 seconds |
    /// | `hour`, `h` | 3600 seconds |
    /// | `day`, `d` | 86400 seconds |
    ///
    /// Whitespace around the whole string and around each side of the `/`
    /// is ignored. A bare number with no separator is also accepted as a
    /// per-second rate (the `"<num>"` shorthand), e.g. `"5"` ≡ `"5/sec"`.
    ///
    /// # Errors
    ///
    /// Returns `Err` with a short human-readable message when the input is
    /// empty, the count is not a valid `u32`, the count is zero, or the
    /// period token is not one of those listed above.
    ///
    /// ```
    /// # use std::time::Duration;
    /// # use umbral_core::ratelimit::Rate;
    /// assert_eq!(Rate::parse("100/hour").unwrap().num, 100);
    /// assert_eq!(Rate::parse("10/min").unwrap().period, Duration::from_secs(60));
    /// assert_eq!(Rate::parse("5").unwrap(), Rate::parse("5/sec").unwrap());
    /// assert!(Rate::parse("oops").is_err());
    /// ```
    pub fn parse(s: &str) -> Result<Self, String> {
        let s = s.trim();
        if s.is_empty() {
            return Err("empty rate string".to_string());
        }

        // Only the first `/` separates count from period; anything after it
        // (including further slashes) is treated as the period token.
        let (num_part, period_part) = match s.split_once('/') {
            Some((n, p)) => (n.trim(), p.trim()),
            // Bare number → per-second (shorthand).
            None => (s, "sec"),
        };

        let num: u32 = num_part
            .parse()
            .map_err(|_| format!("invalid rate count `{num_part}` in `{s}`"))?;
        if num == 0 {
            return Err(format!("rate count must be positive in `{s}`"));
        }

        let period = match period_part.to_ascii_lowercase().as_str() {
            "sec" | "s" | "second" => Duration::from_secs(1),
            "min" | "m" | "minute" => Duration::from_secs(60),
            "hour" | "h" => Duration::from_secs(3600),
            "day" | "d" => Duration::from_secs(86_400),
            other => return Err(format!("unknown rate period `{other}` in `{s}`")),
        };

        Ok(Self { num, period })
    }
}

/// Lets a [`Rate`] be produced with `str::parse`; delegates to
/// [`Rate::parse`], so the accepted syntax and the error messages are
/// exactly the same.
impl std::str::FromStr for Rate {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::parse(s)
    }
}
126
/// Outcome of a single [`RateLimiter::check`] call: the allow/deny verdict
/// plus the numbers a caller needs to build rate-limit response headers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RateDecision {
    /// Whether the request fits under the limit. An allowed request has
    /// been recorded against the key; a denied one has not.
    pub allowed: bool,
    /// For a denied request, the wait until the oldest entry still inside
    /// the window ages out and a slot frees up. Always `None` for an
    /// allowed request.
    pub retry_after: Option<Duration>,
    /// The configured ceiling (`Rate::num`), e.g. for an
    /// `X-RateLimit-Limit` header.
    pub limit: u32,
    /// Requests still available in the current window once this one has
    /// been accounted for. Always `0` for a denied request.
    pub remaining: u32,
}
143
144/// An in-memory sliding-window rate limiter, keyed by an arbitrary
145/// string (IP, user id, scope-qualified key — the caller decides).
146///
147/// Cheap to clone the configured [`Rate`]; the shared counter map sits
148/// behind a `Mutex` so a single `RateLimiter` can back many concurrent
149/// requests. Wrap in an `Arc` to share across handlers.
150#[derive(Debug)]
151pub struct RateLimiter {
152    rate: Rate,
153    buckets: Mutex<HashMap<String, VecDeque<Instant>>>,
154}
155
156impl RateLimiter {
157    /// Build a limiter enforcing `rate`.
158    pub fn new(rate: Rate) -> Self {
159        Self {
160            rate,
161            buckets: Mutex::new(HashMap::new()),
162        }
163    }
164
165    /// The configured rate.
166    pub fn rate(&self) -> Rate {
167        self.rate
168    }
169
170    /// Check (and, if allowed, record) one request for `key` against the
171    /// configured rate, using the real wall clock.
172    ///
173    /// See [`Self::check_at`] for the deterministic, clock-injectable
174    /// variant the tests drive.
175    pub fn check(&self, key: &str) -> RateDecision {
176        self.check_at(key, Instant::now())
177    }
178
179    /// Clock-injectable core: identical to [`Self::check`] but the caller
180    /// supplies `now`. Private-ish (crate-visible) so deterministic tests
181    /// can advance time without sleeping; production always routes through
182    /// [`Self::check`] with `Instant::now()`.
183    pub fn check_at(&self, key: &str, now: Instant) -> RateDecision {
184        let window = self.rate.period;
185        let mut buckets = self.buckets.lock().unwrap_or_else(|e| e.into_inner());
186        let entries = buckets.entry(key.to_string()).or_default();
187
188        // Prune everything older than the window — the "sliding" step.
189        // `now.checked_duration_since` guards against a clock that didn't
190        // advance (or a stamp in the future); treat un-orderable stamps
191        // as in-window (conservative: never silently drop a recent hit).
192        while let Some(front) = entries.front() {
193            match now.checked_duration_since(*front) {
194                Some(age) if age >= window => {
195                    entries.pop_front();
196                }
197                _ => break,
198            }
199        }
200
201        let count = entries.len() as u32;
202        if count < self.rate.num {
203            entries.push_back(now);
204            RateDecision {
205                allowed: true,
206                retry_after: None,
207                limit: self.rate.num,
208                remaining: self.rate.num - count - 1,
209            }
210        } else {
211            // Over the limit. A slot frees when the OLDEST in-window entry
212            // ages out: that's `window - (now - oldest)`. The prune above
213            // guarantees the front is still within the window, so the
214            // subtraction is non-negative; saturate to be safe.
215            let retry_after = entries
216                .front()
217                .and_then(|oldest| now.checked_duration_since(*oldest))
218                .map(|age| window.saturating_sub(age))
219                .unwrap_or(window);
220            RateDecision {
221                allowed: false,
222                retry_after: Some(retry_after),
223                limit: self.rate.num,
224                remaining: 0,
225            }
226        }
227    }
228
229    /// Forget every recorded request for `key`, resetting its window so the
230    /// next [`check`](Self::check) starts from a clean budget.
231    ///
232    /// The "success forgives" primitive: a caller that wants a prior burst of
233    /// denied attempts to stop counting after some positive outcome (e.g.
234    /// umbral-auth clears the login counter on a SUCCESSFUL login so a user who
235    /// fat-fingered their password isn't locked out) calls this to drop the
236    /// key's history. A no-op if the key was never seen.
237    pub fn clear(&self, key: &str) {
238        let mut buckets = self.buckets.lock().unwrap_or_else(|e| e.into_inner());
239        buckets.remove(key);
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Every period token from the `Rate::parse` documentation table maps
    /// to its documented window length.
    #[test]
    fn parse_each_period() {
        let cases = [
            ("1/sec", 1),
            ("1/s", 1),
            ("1/second", 1),
            ("1/min", 60),
            ("1/m", 60),
            ("1/minute", 60),
            ("1/hour", 3600),
            ("1/h", 3600),
            ("1/day", 86_400),
            ("1/d", 86_400),
        ];
        for (input, secs) in cases {
            assert_eq!(
                Rate::parse(input).unwrap().period,
                Duration::from_secs(secs),
                "period for `{input}`"
            );
        }
    }

    /// A bare count is shorthand for a per-second rate.
    #[test]
    fn parse_bare_number_is_per_second() {
        assert_eq!(
            Rate::parse("5").unwrap(),
            Rate::new(5, Duration::from_secs(1))
        );
        assert_eq!(Rate::parse("5").unwrap(), Rate::parse("5/sec").unwrap());
    }

    /// Period tokens are case-insensitive and surrounding whitespace is
    /// ignored, both around the string and around the `/`.
    #[test]
    fn parse_ignores_case_and_whitespace() {
        assert_eq!(
            Rate::parse("  10 / HOUR ").unwrap(),
            Rate::new(10, Duration::from_secs(3600))
        );
        assert_eq!(
            Rate::parse("3/Min").unwrap().period,
            Duration::from_secs(60)
        );
    }

    #[test]
    fn parse_rejects_garbage() {
        assert!(Rate::parse("").is_err());
        assert!(Rate::parse("   ").is_err());
        assert!(Rate::parse("oops").is_err());
        assert!(Rate::parse("10/fortnight").is_err());
        assert!(Rate::parse("0/sec").is_err());
        assert!(Rate::parse("abc/min").is_err());
        // Missing halves and negative counts are errors too.
        assert!(Rate::parse("5/").is_err());
        assert!(Rate::parse("/sec").is_err());
        assert!(Rate::parse("-1/sec").is_err());
    }

    #[test]
    fn third_request_in_window_denied() {
        let limiter = RateLimiter::new(Rate::parse("2/min").unwrap());
        let t0 = Instant::now();
        let d1 = limiter.check_at("a", t0);
        assert!(d1.allowed);
        assert_eq!(d1.remaining, 1);
        assert_eq!(d1.retry_after, None);
        let d2 = limiter.check_at("a", t0 + Duration::from_secs(1));
        assert!(d2.allowed);
        assert_eq!(d2.remaining, 0);
        let d3 = limiter.check_at("a", t0 + Duration::from_secs(2));
        assert!(!d3.allowed);
        assert_eq!(d3.remaining, 0);
        assert!(d3.retry_after.is_some());
        // Slot frees 60s after the FIRST hit, i.e. 58s from t0+2s.
        assert_eq!(d3.retry_after.unwrap(), Duration::from_secs(58));
    }

    /// Both verdicts carry the configured ceiling.
    #[test]
    fn decision_reports_configured_limit() {
        let limiter = RateLimiter::new(Rate::parse("1/min").unwrap());
        let t0 = Instant::now();
        assert_eq!(limiter.check_at("a", t0).limit, 1);
        assert_eq!(limiter.check_at("a", t0).limit, 1);
        assert_eq!(limiter.rate(), Rate::new(1, Duration::from_secs(60)));
    }

    #[test]
    fn distinct_keys_are_independent() {
        let limiter = RateLimiter::new(Rate::parse("1/min").unwrap());
        let t0 = Instant::now();
        assert!(limiter.check_at("a", t0).allowed);
        // Key "b" has its own bucket — not affected by "a" being full.
        assert!(limiter.check_at("b", t0).allowed);
        // "a" is now over its 1/min.
        assert!(!limiter.check_at("a", t0).allowed);
    }

    #[test]
    fn allowed_again_after_window_elapses() {
        let limiter = RateLimiter::new(Rate::parse("1/min").unwrap());
        let t0 = Instant::now();
        assert!(limiter.check_at("a", t0).allowed);
        assert!(!limiter.check_at("a", t0 + Duration::from_secs(30)).allowed);
        // 61s later the original hit has aged out of the 60s window.
        assert!(limiter.check_at("a", t0 + Duration::from_secs(61)).allowed);
    }

    /// A hit that is exactly one period old no longer counts.
    #[test]
    fn hit_expires_exactly_at_window_boundary() {
        let limiter = RateLimiter::new(Rate::parse("1/min").unwrap());
        let t0 = Instant::now();
        assert!(limiter.check_at("a", t0).allowed);
        let just_before = t0 + Duration::from_secs(60) - Duration::from_nanos(1);
        assert!(!limiter.check_at("a", just_before).allowed);
        assert!(limiter.check_at("a", t0 + Duration::from_secs(60)).allowed);
    }

    #[test]
    fn clear_forgets_a_key() {
        let limiter = RateLimiter::new(Rate::parse("1/min").unwrap());
        let t0 = Instant::now();
        assert!(limiter.check_at("a", t0).allowed);
        // Over budget within the window.
        assert!(!limiter.check_at("a", t0).allowed);
        // Clearing the key drops its history, so the next check is allowed.
        limiter.clear("a");
        assert!(limiter.check_at("a", t0).allowed);
        // A clear on an unknown key is a harmless no-op.
        limiter.clear("never-seen");
    }
}