Skip to main content

rustio_admin/middleware/
rate_limit.rs

//! Per-client rate limiting using a token bucket. Buckets are kept in
//! memory in a `DashMap`, so limits are enforced per process — fine for
//! single-node deployments. For multi-node deployments, back the same
//! interface with shared storage such as Redis or Postgres.
//!
//! `RateLimiter::default_limits()` allows 120 requests per 60 s per client.
//! For other limits, build one with `RateLimiter::new` and pass it to
//! `rate_limit(limiter)`.
7
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use dashmap::DashMap;
12
13use crate::error::{Error, Result};
14use crate::http::{Request, Response};
15use crate::router::Next;
16
17#[derive(Clone)]
18pub struct RateLimiter {
19    inner: Arc<Inner>,
20}
21
22struct Inner {
23    capacity: u32,
24    refill_per_second: f64,
25    buckets: DashMap<String, Bucket>,
26}
27
28#[derive(Debug)]
29struct Bucket {
30    tokens: f64,
31    updated: Instant,
32}
33
34impl RateLimiter {
35    pub fn new(capacity: u32, window: Duration) -> Self {
36        let refill = capacity as f64 / window.as_secs_f64().max(0.001);
37        Self {
38            inner: Arc::new(Inner {
39                capacity,
40                refill_per_second: refill,
41                buckets: DashMap::new(),
42            }),
43        }
44    }
45
46    pub fn default_limits() -> Self {
47        Self::new(120, Duration::from_secs(60))
48    }
49
50    /// Try to consume one token from the bucket keyed by `key`.
51    /// Returns `true` when the request is allowed (token consumed)
52    /// and `false` when the bucket is empty.
53    ///
54    /// `pub(crate)` so the recovery module can drive its own
55    /// scoped buckets (per-IP request + consume limits) without
56    /// going through the global middleware path. The middleware
57    /// closure in [`rate_limit`] continues to be the only public
58    /// way to plug a limiter into the router.
59    pub(crate) fn allow(&self, key: &str) -> bool {
60        let now = Instant::now();
61        let mut entry = self.inner.buckets.entry(key.to_string()).or_insert(Bucket {
62            tokens: self.inner.capacity as f64,
63            updated: now,
64        });
65        let elapsed = now.duration_since(entry.updated).as_secs_f64();
66        let refill = elapsed * self.inner.refill_per_second;
67        entry.tokens = (entry.tokens + refill).min(self.inner.capacity as f64);
68        entry.updated = now;
69        if entry.tokens >= 1.0 {
70            entry.tokens -= 1.0;
71            true
72        } else {
73            false
74        }
75    }
76}
77
78/// The middleware function. Wrap a limiter into a closure and hand
79/// it to `Router::middleware(...)`.
80pub fn rate_limit(
81    limiter: RateLimiter,
82) -> impl Fn(
83    Request,
84    Next,
85)
86    -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response>> + Send + 'static>>
87       + Clone
88       + Send
89       + Sync
90       + 'static {
91    move |req: Request, next: Next| {
92        let limiter = limiter.clone();
93        Box::pin(async move {
94            // Prefer X-Forwarded-For (one-layer-deep) for deployments
95            // behind a reverse proxy. Falls back to a fixed "anon" key
96            // when no client identifier is available — we still rate
97            // limit in that case, just globally.
98            let key = req
99                .header("x-forwarded-for")
100                .and_then(|v| v.split(',').next())
101                .map(|s| s.trim().to_string())
102                .unwrap_or_else(|| "anon".to_string());
103
104            if !limiter.allow(&key) {
105                return Err(Error::BadRequest("rate limit exceeded".into()));
106            }
107            next.run(req).await
108        })
109    }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bucket_allows_burst_then_blocks() {
        let limiter = RateLimiter::new(3, Duration::from_secs(60));
        assert!(limiter.allow("1.2.3.4"));
        assert!(limiter.allow("1.2.3.4"));
        assert!(limiter.allow("1.2.3.4"));
        assert!(!limiter.allow("1.2.3.4"));
    }

    #[test]
    fn different_keys_tracked_separately() {
        let limiter = RateLimiter::new(1, Duration::from_secs(60));
        assert!(limiter.allow("a"));
        assert!(limiter.allow("b"));
        assert!(!limiter.allow("a"));
    }

    #[test]
    fn zero_capacity_always_blocks() {
        let limiter = RateLimiter::new(0, Duration::from_secs(60));
        assert!(!limiter.allow("k"));
        assert!(!limiter.allow("k"));
    }

    #[test]
    fn tokens_refill_over_time() {
        // A zero window is clamped to 1 ms, i.e. 1000 tokens/s, so an
        // emptied bucket should become usable again within milliseconds.
        let limiter = RateLimiter::new(1, Duration::ZERO);
        assert!(limiter.allow("k"));
        let deadline = Instant::now() + Duration::from_secs(1);
        while !limiter.allow("k") {
            assert!(Instant::now() < deadline, "bucket never refilled");
        }
    }
}