rustio_admin/middleware/
rate_limit.rs1use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use dashmap::DashMap;
12
13use crate::error::{Error, Result};
14use crate::http::{Request, Response};
15use crate::router::Next;
16
17#[derive(Clone)]
18pub struct RateLimiter {
19 inner: Arc<Inner>,
20}
21
22struct Inner {
23 capacity: u32,
24 refill_per_second: f64,
25 buckets: DashMap<String, Bucket>,
26}
27
/// Token-bucket state tracked for a single key.
#[derive(Debug)]
struct Bucket {
    /// Tokens currently available; fractional because refill accrues continuously.
    tokens: f64,
    /// Instant at which `tokens` was last brought up to date.
    updated: Instant,
}
33
34impl RateLimiter {
35 pub fn new(capacity: u32, window: Duration) -> Self {
36 let refill = capacity as f64 / window.as_secs_f64().max(0.001);
37 Self {
38 inner: Arc::new(Inner {
39 capacity,
40 refill_per_second: refill,
41 buckets: DashMap::new(),
42 }),
43 }
44 }
45
46 pub fn default_limits() -> Self {
47 Self::new(120, Duration::from_secs(60))
48 }
49
50 fn allow(&self, key: &str) -> bool {
51 let now = Instant::now();
52 let mut entry = self.inner.buckets.entry(key.to_string()).or_insert(Bucket {
53 tokens: self.inner.capacity as f64,
54 updated: now,
55 });
56 let elapsed = now.duration_since(entry.updated).as_secs_f64();
57 let refill = elapsed * self.inner.refill_per_second;
58 entry.tokens = (entry.tokens + refill).min(self.inner.capacity as f64);
59 entry.updated = now;
60 if entry.tokens >= 1.0 {
61 entry.tokens -= 1.0;
62 true
63 } else {
64 false
65 }
66 }
67}
68
69pub fn rate_limit(
72 limiter: RateLimiter,
73) -> impl Fn(
74 Request,
75 Next,
76)
77 -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response>> + Send + 'static>>
78 + Clone
79 + Send
80 + Sync
81 + 'static {
82 move |req: Request, next: Next| {
83 let limiter = limiter.clone();
84 Box::pin(async move {
85 let key = req
90 .header("x-forwarded-for")
91 .and_then(|v| v.split(',').next())
92 .map(|s| s.trim().to_string())
93 .unwrap_or_else(|| "anon".to_string());
94
95 if !limiter.allow(&key) {
96 return Err(Error::BadRequest("rate limit exceeded".into()));
97 }
98 next.run(req).await
99 })
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// Long enough that no meaningful refill happens while a test runs.
    const WINDOW: Duration = Duration::from_secs(60);

    #[test]
    fn bucket_allows_burst_then_blocks() {
        let limiter = RateLimiter::new(3, WINDOW);
        let client = "1.2.3.4";
        for _ in 0..3 {
            assert!(limiter.allow(client));
        }
        assert!(!limiter.allow(client));
    }

    #[test]
    fn different_keys_tracked_separately() {
        let limiter = RateLimiter::new(1, WINDOW);
        let (first, second) = ("a", "b");
        assert!(limiter.allow(first));
        assert!(limiter.allow(second));
        assert!(!limiter.allow(first));
    }
}