rustio_admin/middleware/
rate_limit.rs1use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use dashmap::DashMap;
12
13use crate::error::{Error, Result};
14use crate::http::{Request, Response};
15use crate::router::Next;
16
17#[derive(Clone)]
18pub struct RateLimiter {
19 inner: Arc<Inner>,
20}
21
22struct Inner {
23 capacity: u32,
24 refill_per_second: f64,
25 buckets: DashMap<String, Bucket>,
26}
27
/// Token balance for a single key.
#[derive(Debug)]
struct Bucket {
    /// Tokens currently available; fractional because refill is continuous.
    tokens: f64,
    /// When `tokens` was last brought up to date.
    updated: Instant,
}
33
34impl RateLimiter {
35 pub fn new(capacity: u32, window: Duration) -> Self {
36 let refill = capacity as f64 / window.as_secs_f64().max(0.001);
37 Self {
38 inner: Arc::new(Inner {
39 capacity,
40 refill_per_second: refill,
41 buckets: DashMap::new(),
42 }),
43 }
44 }
45
46 pub fn default_limits() -> Self {
47 Self::new(120, Duration::from_secs(60))
48 }
49
50 pub(crate) fn allow(&self, key: &str) -> bool {
60 let now = Instant::now();
61 let mut entry = self.inner.buckets.entry(key.to_string()).or_insert(Bucket {
62 tokens: self.inner.capacity as f64,
63 updated: now,
64 });
65 let elapsed = now.duration_since(entry.updated).as_secs_f64();
66 let refill = elapsed * self.inner.refill_per_second;
67 entry.tokens = (entry.tokens + refill).min(self.inner.capacity as f64);
68 entry.updated = now;
69 if entry.tokens >= 1.0 {
70 entry.tokens -= 1.0;
71 true
72 } else {
73 false
74 }
75 }
76}
77
78pub fn rate_limit(
81 limiter: RateLimiter,
82) -> impl Fn(
83 Request,
84 Next,
85)
86 -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response>> + Send + 'static>>
87 + Clone
88 + Send
89 + Sync
90 + 'static {
91 move |req: Request, next: Next| {
92 let limiter = limiter.clone();
93 Box::pin(async move {
94 let key = req
99 .header("x-forwarded-for")
100 .and_then(|v| v.split(',').next())
101 .map(|s| s.trim().to_string())
102 .unwrap_or_else(|| "anon".to_string());
103
104 if !limiter.allow(&key) {
105 return Err(Error::BadRequest("rate limit exceeded".into()));
106 }
107 next.run(req).await
108 })
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bucket_allows_burst_then_blocks() {
        let limiter = RateLimiter::new(3, Duration::from_secs(60));
        assert!(limiter.allow("1.2.3.4"));
        assert!(limiter.allow("1.2.3.4"));
        assert!(limiter.allow("1.2.3.4"));
        assert!(!limiter.allow("1.2.3.4"));
    }

    #[test]
    fn different_keys_tracked_separately() {
        let limiter = RateLimiter::new(1, Duration::from_secs(60));
        assert!(limiter.allow("a"));
        assert!(limiter.allow("b"));
        assert!(!limiter.allow("a"));
    }

    #[test]
    fn bucket_refills_after_window() {
        // Capacity 1 over 100 ms: a drained bucket holds a full token again after 100 ms.
        let limiter = RateLimiter::new(1, Duration::from_millis(100));
        assert!(limiter.allow("c"));
        assert!(!limiter.allow("c"));
        std::thread::sleep(Duration::from_millis(150));
        assert!(limiter.allow("c"));
    }
}