1use std::collections::HashMap;
7use std::net::IpAddr;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use axum::extract::ConnectInfo;
12use axum::http::StatusCode;
13use axum::response::{IntoResponse, Response};
14use tokio::sync::RwLock;
15
/// Settings for the per-IP rate limiter.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// When `false`, every request is allowed and no per-IP state is kept.
    pub enabled: bool,
    /// Sustained rate: tokens added back to a client's bucket per second.
    pub requests_per_second: u32,
    /// Bucket capacity, i.e. how many requests a client may make back-to-back.
    pub burst_size: u32,
    /// Maximum number of client IPs tracked at the same time.
    pub max_tracked_ips: usize,
}

impl RateLimitConfig {
    /// IP-tracking cap shared by every constructor.
    const DEFAULT_MAX_TRACKED_IPS: usize = 10_000;

    /// A configuration with rate limiting switched off (rates and burst are 0).
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            ..Self::with_burst(0, 0)
        }
    }

    /// Enables limiting at `requests_per_second`, allowing bursts of twice
    /// that many requests (saturating at `u32::MAX`).
    pub fn new(requests_per_second: u32) -> Self {
        let doubled = requests_per_second.saturating_mul(2);
        Self::with_burst(requests_per_second, doubled)
    }

    /// Enables limiting at `requests_per_second` with an explicit burst size.
    pub fn with_burst(requests_per_second: u32, burst_size: u32) -> Self {
        Self {
            enabled: true,
            requests_per_second,
            burst_size,
            max_tracked_ips: Self::DEFAULT_MAX_TRACKED_IPS,
        }
    }
}

impl Default for RateLimitConfig {
    /// Rate limiting is off unless explicitly configured.
    fn default() -> Self {
        Self::disabled()
    }
}
74
/// Token-bucket state for a single client.
///
/// The bucket starts full, never holds more than `max_tokens`, and refills
/// continuously at `refill_rate` tokens per second. Each allowed request
/// consumes one token.
#[derive(Debug, Clone)]
struct TokenBucket {
    /// Tokens currently available (fractional, because refill is continuous).
    tokens: f64,
    /// When `tokens` was last brought up to date by `refill`.
    last_update: Instant,
    /// Bucket capacity; refills are capped here.
    max_tokens: f64,
    /// Tokens added per second of elapsed time.
    refill_rate: f64,
}

impl TokenBucket {
    /// Creates a full bucket of capacity `max_tokens` that refills at
    /// `refill_rate` tokens per second.
    fn new(max_tokens: u32, refill_rate: u32) -> Self {
        Self {
            tokens: f64::from(max_tokens),
            last_update: Instant::now(),
            max_tokens: f64::from(max_tokens),
            refill_rate: f64::from(refill_rate),
        }
    }

    /// Refills the bucket, then takes one token if a whole one is available.
    ///
    /// Returns `true` when a token was consumed.
    fn try_consume(&mut self) -> bool {
        self.refill();

        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Adds the tokens accrued since `last_update` (capped at `max_tokens`)
    /// and moves `last_update` to now.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_update);
        let new_tokens = elapsed.as_secs_f64() * self.refill_rate;

        self.tokens = (self.tokens + new_tokens).min(self.max_tokens);
        self.last_update = now;
    }

    /// Whole tokens currently available (fractional part truncated).
    fn remaining(&self) -> u32 {
        self.tokens as u32
    }

    /// Time until at least one whole token is available, based on the token
    /// count as of the last refill.
    ///
    /// Returns `Duration::ZERO` if a token is already available. Returns
    /// `Duration::MAX` if no token will ever arrive (zero refill rate) or if
    /// the wait is too long to represent as a `Duration`.
    fn reset_after(&self) -> Duration {
        if self.tokens >= 1.0 {
            return Duration::ZERO;
        }
        let tokens_needed = 1.0 - self.tokens;
        let seconds = tokens_needed / self.refill_rate;
        // With refill_rate == 0 this is +inf, which `Duration::from_secs_f64`
        // panics on; saturate instead of crashing the caller.
        Duration::try_from_secs_f64(seconds).unwrap_or(Duration::MAX)
    }
}
140
141#[derive(Debug)]
147pub struct RateLimiter {
148 config: RateLimitConfig,
149 buckets: RwLock<HashMap<IpAddr, TokenBucket>>,
150}
151
152impl RateLimiter {
153 pub fn new(config: RateLimitConfig) -> Self {
155 Self {
156 config,
157 buckets: RwLock::new(HashMap::new()),
158 }
159 }
160
161 pub async fn check(&self, ip: IpAddr) -> RateLimitResult {
163 if !self.config.enabled {
164 return RateLimitResult::Allowed {
165 remaining: u32::MAX,
166 reset_after: Duration::ZERO,
167 };
168 }
169
170 let mut buckets = self.buckets.write().await;
171
172 if !buckets.contains_key(&ip) && buckets.len() >= self.config.max_tracked_ips {
174 let oldest_ip = buckets
175 .iter()
176 .min_by_key(|(_, b)| b.last_update)
177 .map(|(ip, _)| *ip);
178 if let Some(ip_to_evict) = oldest_ip {
179 buckets.remove(&ip_to_evict);
180 }
181 }
182
183 let bucket = buckets.entry(ip).or_insert_with(|| {
184 TokenBucket::new(self.config.burst_size, self.config.requests_per_second)
185 });
186
187 if bucket.try_consume() {
188 RateLimitResult::Allowed {
189 remaining: bucket.remaining(),
190 reset_after: bucket.reset_after(),
191 }
192 } else {
193 RateLimitResult::Limited {
194 retry_after: bucket.reset_after(),
195 }
196 }
197 }
198
199 pub async fn cleanup(&self, max_age: Duration) {
201 let now = Instant::now();
202 let mut buckets = self.buckets.write().await;
203 buckets.retain(|_, bucket| now.duration_since(bucket.last_update) < max_age);
204 }
205
206 pub async fn client_count(&self) -> usize {
208 self.buckets.read().await.len()
209 }
210}
211
/// Outcome of `RateLimiter::check` for a single request.
#[derive(Debug, Clone)]
pub enum RateLimitResult {
    /// The request may proceed.
    Allowed {
        /// Whole tokens left in the client's bucket after this request
        /// (`u32::MAX` when rate limiting is disabled).
        remaining: u32,
        /// Time until the bucket next holds a whole token; zero if it already does.
        reset_after: Duration,
    },
    /// The request was rejected because the client's bucket held less than one token.
    Limited {
        /// Time until the bucket will hold a whole token again.
        retry_after: Duration,
    },
}
228
229pub async fn rate_limit_middleware(
238 connect_info: Option<ConnectInfo<std::net::SocketAddr>>,
239 limiter: Option<Arc<RateLimiter>>,
240 req: axum::extract::Request,
241 next: axum::middleware::Next,
242) -> Response {
243 if let Some(ref limiter) = limiter {
244 let ip = connect_info
245 .map(|ci| ci.0.ip())
246 .unwrap_or(IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
247
248 match limiter.check(ip).await {
249 RateLimitResult::Allowed { .. } => {}
250 RateLimitResult::Limited { retry_after } => {
251 let retry_after_secs = retry_after.as_secs().max(1);
252 return rate_limit_error_response(retry_after_secs);
253 }
254 }
255 }
256
257 next.run(req).await
258}
259
260pub fn rate_limit_error_response(retry_after_secs: u64) -> Response {
262 let body = serde_json::json!({
263 "error": "rate_limited",
264 "message": "Too many requests",
265 "retry_after_seconds": retry_after_secs,
266 });
267 (
268 StatusCode::TOO_MANY_REQUESTS,
269 [("retry-after", retry_after_secs.to_string())],
270 axum::Json(body),
271 )
272 .into_response()
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_disabled() {
        let config = RateLimitConfig::disabled();
        assert!(!config.enabled);
    }

    #[test]
    fn test_config_new() {
        let config = RateLimitConfig::new(100);
        assert!(config.enabled);
        assert_eq!(config.requests_per_second, 100);
        assert_eq!(config.burst_size, 200);
    }

    #[test]
    fn test_config_with_burst() {
        let config = RateLimitConfig::with_burst(100, 50);
        assert!(config.enabled);
        assert_eq!(config.requests_per_second, 100);
        assert_eq!(config.burst_size, 50);
    }

    #[test]
    fn test_token_bucket_basic() {
        let mut bucket = TokenBucket::new(10, 10);
        assert_eq!(bucket.remaining(), 10);

        // A full bucket allows exactly `max_tokens` back-to-back requests.
        for _ in 0..10 {
            assert!(bucket.try_consume());
        }

        assert!(!bucket.try_consume());
    }

    #[test]
    fn test_token_bucket_refill() {
        let mut bucket = TokenBucket::new(10, 100);
        for _ in 0..10 {
            bucket.try_consume();
        }

        // Pretend 100ms have passed: at 100 tokens/s that refills the bucket.
        bucket.last_update = Instant::now() - Duration::from_millis(100);

        bucket.refill();
        assert!(bucket.remaining() >= 9);
    }

    #[tokio::test]
    async fn test_rate_limiter_disabled() {
        let config = RateLimitConfig::disabled();
        let limiter = RateLimiter::new(config);

        let ip = "127.0.0.1".parse().unwrap();
        match limiter.check(ip).await {
            RateLimitResult::Allowed { remaining, .. } => {
                assert_eq!(remaining, u32::MAX);
            }
            RateLimitResult::Limited { .. } => panic!("Should not be limited"),
        }
    }

    #[tokio::test]
    async fn test_rate_limiter_allows_burst() {
        let config = RateLimitConfig::with_burst(10, 5);
        let limiter = RateLimiter::new(config);

        let ip = "127.0.0.1".parse().unwrap();

        for i in 0..5 {
            match limiter.check(ip).await {
                RateLimitResult::Allowed { remaining, .. } => {
                    assert_eq!(remaining, 4 - i);
                }
                RateLimitResult::Limited { .. } => panic!("Should not be limited at request {}", i),
            }
        }

        match limiter.check(ip).await {
            RateLimitResult::Allowed { .. } => panic!("Should be limited"),
            RateLimitResult::Limited { retry_after } => {
                // At 10 tokens/s one token takes at most 100ms to refill.
                assert!(retry_after.as_millis() <= 100);
            }
        }
    }

    #[tokio::test]
    async fn test_rate_limiter_per_ip() {
        let config = RateLimitConfig::with_burst(10, 2);
        let limiter = RateLimiter::new(config);

        let ip1: IpAddr = "127.0.0.1".parse().unwrap();
        let ip2: IpAddr = "127.0.0.2".parse().unwrap();

        for _ in 0..2 {
            limiter.check(ip1).await;
        }

        match limiter.check(ip1).await {
            RateLimitResult::Allowed { .. } => panic!("ip1 should be limited"),
            RateLimitResult::Limited { .. } => {}
        }

        match limiter.check(ip2).await {
            RateLimitResult::Allowed { .. } => {}
            RateLimitResult::Limited { .. } => panic!("ip2 should not be limited"),
        }
    }

    #[tokio::test]
    async fn test_rate_limiter_cleanup() {
        let config = RateLimitConfig::new(10);
        let limiter = RateLimiter::new(config);

        let ip: IpAddr = "127.0.0.1".parse().unwrap();
        limiter.check(ip).await;

        assert_eq!(limiter.client_count().await, 1);

        // `Duration::ZERO` evicts every bucket regardless of clock resolution.
        // A 1ns threshold would keep the bucket whenever the monotonic clock
        // had not advanced between `check` and `cleanup`.
        limiter.cleanup(Duration::ZERO).await;
        assert_eq!(limiter.client_count().await, 0);
    }

    #[tokio::test]
    async fn test_rate_limiter_bounded() {
        let mut config = RateLimitConfig::new(10);
        config.max_tracked_ips = 3;
        let limiter = RateLimiter::new(config);

        for i in 1..=3u8 {
            let ip: IpAddr = format!("10.0.0.{}", i).parse().unwrap();
            limiter.check(ip).await;
        }
        assert_eq!(limiter.client_count().await, 3);

        // A fourth IP evicts one existing entry instead of growing the map.
        let ip4: IpAddr = "10.0.0.4".parse().unwrap();
        limiter.check(ip4).await;
        assert_eq!(limiter.client_count().await, 3);
    }

    #[test]
    fn test_reset_after_calculation() {
        let mut bucket = TokenBucket::new(10, 10);
        assert_eq!(bucket.reset_after(), Duration::ZERO);

        for _ in 0..10 {
            bucket.try_consume();
        }

        // An empty bucket at 10 tokens/s needs roughly 100ms for one token.
        let reset = bucket.reset_after();
        assert!(reset.as_millis() >= 90 && reset.as_millis() <= 110);
    }
}