Skip to main content

rustio_admin/middleware/
rate_limit.rs

//! Per-IP rate limiting using a token bucket. Buckets are kept in memory
//! in a `DashMap`, which suits single-node deployments; for multi-node
//! deployments, implement the same interface on top of Redis or Postgres.
//!
//! Default: 120 requests per 60 seconds (`RateLimiter::default_limits`).
//! To override, construct a `RateLimiter` with `RateLimiter::new` and pass
//! it to `rate_limit(limiter)`.
7
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use dashmap::DashMap;
12
13use crate::error::{Error, Result};
14use crate::http::{Request, Response};
15use crate::router::Next;
16
17#[derive(Clone)]
18pub struct RateLimiter {
19    inner: Arc<Inner>,
20}
21
22struct Inner {
23    capacity: u32,
24    refill_per_second: f64,
25    buckets: DashMap<String, Bucket>,
26}
27
28#[derive(Debug)]
29struct Bucket {
30    tokens: f64,
31    updated: Instant,
32}
33
34impl RateLimiter {
35    pub fn new(capacity: u32, window: Duration) -> Self {
36        let refill = capacity as f64 / window.as_secs_f64().max(0.001);
37        Self {
38            inner: Arc::new(Inner {
39                capacity,
40                refill_per_second: refill,
41                buckets: DashMap::new(),
42            }),
43        }
44    }
45
46    pub fn default_limits() -> Self {
47        Self::new(120, Duration::from_secs(60))
48    }
49
50    fn allow(&self, key: &str) -> bool {
51        let now = Instant::now();
52        let mut entry = self.inner.buckets.entry(key.to_string()).or_insert(Bucket {
53            tokens: self.inner.capacity as f64,
54            updated: now,
55        });
56        let elapsed = now.duration_since(entry.updated).as_secs_f64();
57        let refill = elapsed * self.inner.refill_per_second;
58        entry.tokens = (entry.tokens + refill).min(self.inner.capacity as f64);
59        entry.updated = now;
60        if entry.tokens >= 1.0 {
61            entry.tokens -= 1.0;
62            true
63        } else {
64            false
65        }
66    }
67}
68
69/// The middleware function. Wrap a limiter into a closure and hand
70/// it to `Router::middleware(...)`.
71pub fn rate_limit(
72    limiter: RateLimiter,
73) -> impl Fn(
74    Request,
75    Next,
76)
77    -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response>> + Send + 'static>>
78       + Clone
79       + Send
80       + Sync
81       + 'static {
82    move |req: Request, next: Next| {
83        let limiter = limiter.clone();
84        Box::pin(async move {
85            // Prefer X-Forwarded-For (one-layer-deep) for deployments
86            // behind a reverse proxy. Falls back to a fixed "anon" key
87            // when no client identifier is available — we still rate
88            // limit in that case, just globally.
89            let key = req
90                .header("x-forwarded-for")
91                .and_then(|v| v.split(',').next())
92                .map(|s| s.trim().to_string())
93                .unwrap_or_else(|| "anon".to_string());
94
95            if !limiter.allow(&key) {
96                return Err(Error::BadRequest("rate limit exceeded".into()));
97            }
98            next.run(req).await
99        })
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bucket_allows_burst_then_blocks() {
        let limiter = RateLimiter::new(3, Duration::from_secs(60));
        assert!(limiter.allow("1.2.3.4"));
        assert!(limiter.allow("1.2.3.4"));
        assert!(limiter.allow("1.2.3.4"));
        assert!(!limiter.allow("1.2.3.4"));
    }

    #[test]
    fn different_keys_tracked_separately() {
        let limiter = RateLimiter::new(1, Duration::from_secs(60));
        assert!(limiter.allow("a"));
        assert!(limiter.allow("b"));
        assert!(!limiter.allow("a"));
    }

    #[test]
    fn tokens_refill_after_window() {
        // 1 token per 100 ms.
        let limiter = RateLimiter::new(1, Duration::from_millis(100));
        assert!(limiter.allow("k"));
        assert!(!limiter.allow("k"));
        std::thread::sleep(Duration::from_millis(150));
        assert!(limiter.allow("k"));
    }

    #[test]
    fn refill_is_capped_at_capacity() {
        let limiter = RateLimiter::new(2, Duration::from_millis(100));
        assert!(limiter.allow("k"));
        assert!(limiter.allow("k"));
        // Idle for several windows: the bucket must still top out at 2.
        std::thread::sleep(Duration::from_millis(300));
        assert!(limiter.allow("k"));
        assert!(limiter.allow("k"));
        assert!(!limiter.allow("k"));
    }
}
123}