Skip to main content

rustio_admin/middleware/
rate_limit.rs

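//! Per-IP rate limiting using a token bucket. Clients are keyed by the
//! first `X-Forwarded-For` entry; requests without that header share a
//! single "anon" bucket. Buckets are kept in memory in a `DashMap`,
//! which suits single-node deployments. For multi-node deployments,
//! back the same shape with a shared store such as Redis or Postgres.
//!
//! Default: bursts of up to 120 requests per key, refilled at 120
//! tokens per 60s. Override by constructing a `RateLimiter` yourself
//! and calling `rate_limit(limiter)`.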
7
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use dashmap::DashMap;
12
13use crate::error::{Error, Result};
14use crate::http::{Request, Response};
15use crate::router::Next;
16
17// public:
18#[derive(Clone)]
19pub struct RateLimiter {
20    inner: Arc<Inner>,
21}
22
23struct Inner {
24    capacity: u32,
25    refill_per_second: f64,
26    buckets: DashMap<String, Bucket>,
27}
28
29#[derive(Debug)]
30struct Bucket {
31    tokens: f64,
32    updated: Instant,
33}
34
35impl RateLimiter {
36    // public:
37    pub fn new(capacity: u32, window: Duration) -> Self {
38        let refill = capacity as f64 / window.as_secs_f64().max(0.001);
39        Self {
40            inner: Arc::new(Inner {
41                capacity,
42                refill_per_second: refill,
43                buckets: DashMap::new(),
44            }),
45        }
46    }
47
48    // public:
49    pub fn default_limits() -> Self {
50        Self::new(120, Duration::from_secs(60))
51    }
52
53    /// Try to consume one token from the bucket keyed by `key`.
54    /// Returns `true` when the request is allowed (token consumed)
55    /// and `false` when the bucket is empty.
56    ///
57    /// `pub(crate)` so the recovery module can drive its own
58    /// scoped buckets (per-IP request + consume limits) without
59    /// going through the global middleware path. The middleware
60    /// closure in [`rate_limit`] continues to be the only public
61    /// way to plug a limiter into the router.
62    pub(crate) fn allow(&self, key: &str) -> bool {
63        let now = Instant::now();
64        let mut entry = self.inner.buckets.entry(key.to_string()).or_insert(Bucket {
65            tokens: self.inner.capacity as f64,
66            updated: now,
67        });
68        let elapsed = now.duration_since(entry.updated).as_secs_f64();
69        let refill = elapsed * self.inner.refill_per_second;
70        entry.tokens = (entry.tokens + refill).min(self.inner.capacity as f64);
71        entry.updated = now;
72        if entry.tokens >= 1.0 {
73            entry.tokens -= 1.0;
74            true
75        } else {
76            false
77        }
78    }
79}
80
81// public:
82/// The middleware function. Wrap a limiter into a closure and hand
83/// it to `Router::middleware(...)`.
84pub fn rate_limit(
85    limiter: RateLimiter,
86) -> impl Fn(
87    Request,
88    Next,
89)
90    -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response>> + Send + 'static>>
91       + Clone
92       + Send
93       + Sync
94       + 'static {
95    move |req: Request, next: Next| {
96        let limiter = limiter.clone();
97        Box::pin(async move {
98            // Prefer X-Forwarded-For (one-layer-deep) for deployments
99            // behind a reverse proxy. Falls back to a fixed "anon" key
100            // when no client identifier is available — we still rate
101            // limit in that case, just globally.
102            let key = req
103                .header("x-forwarded-for")
104                .and_then(|v| v.split(',').next())
105                .map(|s| s.trim().to_string())
106                .unwrap_or_else(|| "anon".to_string());
107
108            if !limiter.allow(&key) {
109                return Err(Error::BadRequest("rate limit exceeded".into()));
110            }
111            next.run(req).await
112        })
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    const IP: &str = "1.2.3.4";

    #[test]
    fn bucket_allows_burst_then_blocks() {
        let limiter = RateLimiter::new(3, Duration::from_secs(60));
        for _ in 0..3 {
            assert!(limiter.allow(IP));
        }
        assert!(!limiter.allow(IP));
    }

    #[test]
    fn different_keys_tracked_separately() {
        let limiter = RateLimiter::new(1, Duration::from_secs(60));
        assert!(limiter.allow("a"));
        assert!(limiter.allow("b"));
        assert!(!limiter.allow("a"));
    }

    #[test]
    fn empty_bucket_refills_over_time() {
        // One token per 100 ms, so a 150 ms pause earns it back.
        let limiter = RateLimiter::new(1, Duration::from_millis(100));
        assert!(limiter.allow(IP));
        assert!(!limiter.allow(IP));
        std::thread::sleep(Duration::from_millis(150));
        assert!(limiter.allow(IP));
    }

    #[test]
    fn clones_share_buckets() {
        let limiter = RateLimiter::new(1, Duration::from_secs(60));
        let clone = limiter.clone();
        assert!(limiter.allow(IP));
        assert!(!clone.allow(IP));
    }

    #[test]
    fn zero_capacity_denies_everything() {
        let limiter = RateLimiter::new(0, Duration::from_secs(60));
        assert!(!limiter.allow(IP));
        assert!(!limiter.allow("other"));
    }
}