// rustio_admin/middleware/rate_limit.rs
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

use dashmap::DashMap;

use crate::error::{Error, Result};
use crate::http::{Request, Response};
use crate::router::Next;

17#[derive(Clone)]
19pub struct RateLimiter {
20 inner: Arc<Inner>,
21}
22
23struct Inner {
24 capacity: u32,
25 refill_per_second: f64,
26 buckets: DashMap<String, Bucket>,
27}
28
29#[derive(Debug)]
30struct Bucket {
31 tokens: f64,
32 updated: Instant,
33}
34
35impl RateLimiter {
36 pub fn new(capacity: u32, window: Duration) -> Self {
38 let refill = capacity as f64 / window.as_secs_f64().max(0.001);
39 Self {
40 inner: Arc::new(Inner {
41 capacity,
42 refill_per_second: refill,
43 buckets: DashMap::new(),
44 }),
45 }
46 }
47
48 pub fn default_limits() -> Self {
50 Self::new(120, Duration::from_secs(60))
51 }
52
53 pub(crate) fn allow(&self, key: &str) -> bool {
63 let now = Instant::now();
64 let mut entry = self.inner.buckets.entry(key.to_string()).or_insert(Bucket {
65 tokens: self.inner.capacity as f64,
66 updated: now,
67 });
68 let elapsed = now.duration_since(entry.updated).as_secs_f64();
69 let refill = elapsed * self.inner.refill_per_second;
70 entry.tokens = (entry.tokens + refill).min(self.inner.capacity as f64);
71 entry.updated = now;
72 if entry.tokens >= 1.0 {
73 entry.tokens -= 1.0;
74 true
75 } else {
76 false
77 }
78 }
79}
80
81pub fn rate_limit(
85 limiter: RateLimiter,
86) -> impl Fn(
87 Request,
88 Next,
89)
90 -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Response>> + Send + 'static>>
91 + Clone
92 + Send
93 + Sync
94 + 'static {
95 move |req: Request, next: Next| {
96 let limiter = limiter.clone();
97 Box::pin(async move {
98 let key = req
103 .header("x-forwarded-for")
104 .and_then(|v| v.split(',').next())
105 .map(|s| s.trim().to_string())
106 .unwrap_or_else(|| "anon".to_string());
107
108 if !limiter.allow(&key) {
109 return Err(Error::BadRequest("rate limit exceeded".into()));
110 }
111 next.run(req).await
112 })
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bucket_allows_burst_then_blocks() {
        let limiter = RateLimiter::new(3, Duration::from_secs(60));
        assert!(limiter.allow("1.2.3.4"));
        assert!(limiter.allow("1.2.3.4"));
        assert!(limiter.allow("1.2.3.4"));
        assert!(!limiter.allow("1.2.3.4"));
    }

    #[test]
    fn different_keys_tracked_separately() {
        let limiter = RateLimiter::new(1, Duration::from_secs(60));
        assert!(limiter.allow("a"));
        assert!(limiter.allow("b"));
        assert!(!limiter.allow("a"));
    }

    #[test]
    fn zero_capacity_never_allows() {
        let limiter = RateLimiter::new(0, Duration::from_secs(60));
        assert!(!limiter.allow("a"));
        assert!(!limiter.allow("a"));
    }

    #[test]
    fn clones_share_buckets() {
        let limiter = RateLimiter::new(1, Duration::from_secs(60));
        let clone = limiter.clone();
        assert!(limiter.allow("a"));
        assert!(!clone.allow("a"));
    }

    #[test]
    fn tokens_refill_over_time() {
        // One token per 200 ms: the immediate retry finds the bucket empty, and after waiting
        // longer than one refill period the token is back.
        let limiter = RateLimiter::new(1, Duration::from_millis(200));
        assert!(limiter.allow("a"));
        assert!(!limiter.allow("a"));
        std::thread::sleep(Duration::from_millis(250));
        assert!(limiter.allow("a"));
    }
}