1use std::time::Instant;
41use tokio::sync::Mutex;
42use tracing::debug;
43
44use crate::types::{ConfigProvider, RateLimitEntry, RateLimiter};
45
46static LAST_CLEANUP: Mutex<Option<Instant>> = Mutex::const_new(None);
48
49pub async fn check_rate_limit(
86 limiter: &RateLimiter,
87 ip: &str,
88 config: &impl ConfigProvider,
89) -> bool {
90 let rate_config = config.rate_limit_config();
91 let cleanup_config = config.rate_limit_cleanup_config();
92 let mut rate_map = limiter.inner().lock().await;
93 let now = Instant::now();
94
95 if cleanup_config.is_enabled() && rate_map.len() > cleanup_config.threshold {
97 let should_cleanup = {
98 let mut last_cleanup = LAST_CLEANUP.lock().await;
99 match *last_cleanup {
100 None => {
101 *last_cleanup = Some(now);
102 true
103 }
104 Some(last) if now.duration_since(last) >= cleanup_config.interval => {
105 *last_cleanup = Some(now);
106 true
107 }
108 _ => false,
109 }
110 };
111
112 if should_cleanup {
113 let before_count = rate_map.len();
114 let expiry_threshold = rate_config.window_duration * 2;
116 rate_map.retain(|_, entry| now.duration_since(entry.window_start) < expiry_threshold);
117 let removed = before_count - rate_map.len();
118 if removed > 0 {
119 debug!(
120 removed_entries = removed,
121 remaining_entries = rate_map.len(),
122 "Rate limiter cleanup completed"
123 );
124 }
125 }
126 }
127
128 match rate_map.get_mut(ip) {
129 Some(entry) => {
130 if now.duration_since(entry.window_start) >= rate_config.window_duration {
132 entry.window_start = now;
134 entry.request_count = 1;
135 true
136 } else if entry.request_count < rate_config.max_requests {
137 entry.request_count += 1;
139 true
140 } else {
141 false
143 }
144 }
145 None => {
146 rate_map.insert(ip.to_string(), RateLimitEntry::new());
148 true
149 }
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::TestConfig;
    use crate::types::{RateLimitCleanupConfig, RateLimitConfig};
    use std::time::Duration;

    /// The very first request from an IP is allowed.
    #[tokio::test]
    async fn test_first_request_allowed() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(5, 60);

        let allowed = check_rate_limit(&limiter, "192.168.1.1", &config).await;
        assert!(allowed);
    }

    /// Every request up to the configured maximum is allowed.
    #[tokio::test]
    async fn test_requests_within_limit_allowed() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(5, 60);

        for i in 0..5 {
            let allowed = check_rate_limit(&limiter, "192.168.1.1", &config).await;
            assert!(allowed, "Request {} should be allowed", i + 1);
        }
    }

    /// The first request beyond the maximum is rejected.
    #[tokio::test]
    async fn test_request_exceeding_limit_blocked() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(3, 60);

        for _ in 0..3 {
            assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        }

        let blocked = check_rate_limit(&limiter, "192.168.1.1", &config).await;
        assert!(!blocked, "Request exceeding limit should be blocked");
    }

    /// Exhausting one IP's quota does not affect another IP.
    #[tokio::test]
    async fn test_different_ips_independent() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(2, 60);

        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);

        assert!(check_rate_limit(&limiter, "192.168.1.2", &config).await);
        assert!(check_rate_limit(&limiter, "192.168.1.2", &config).await);
        assert!(!check_rate_limit(&limiter, "192.168.1.2", &config).await);
    }

    /// Each allowed request increments the stored counter by one.
    #[tokio::test]
    async fn test_counter_increments_correctly() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(5, 60);

        check_rate_limit(&limiter, "192.168.1.1", &config).await;
        check_rate_limit(&limiter, "192.168.1.1", &config).await;
        check_rate_limit(&limiter, "192.168.1.1", &config).await;

        let inner = limiter.inner().lock().await;
        let entry = inner.get("192.168.1.1").unwrap();
        assert_eq!(entry.request_count, 3);
    }

    /// With a maximum of one, only the first request is allowed.
    #[tokio::test]
    async fn test_limit_of_one() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(1, 60);

        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);
    }

    /// IPv6 address strings are tracked like any other key.
    #[tokio::test]
    async fn test_ipv6_addresses() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(2, 60);

        assert!(check_rate_limit(&limiter, "::1", &config).await);
        assert!(check_rate_limit(&limiter, "::1", &config).await);
        assert!(!check_rate_limit(&limiter, "::1", &config).await);

        assert!(check_rate_limit(&limiter, "2001:db8::1", &config).await);
    }

    /// Rejected requests are not added to the stored counter.
    #[tokio::test]
    async fn test_multiple_blocked_requests() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(1, 60);

        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);

        for _ in 0..5 {
            assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);
        }

        let inner = limiter.inner().lock().await;
        let entry = inner.get("192.168.1.1").unwrap();
        assert_eq!(entry.request_count, 1);
    }

    /// Clones of a limiter count against the same per-IP quota.
    #[tokio::test]
    async fn test_limiter_clone_shares_state() {
        let limiter1 = RateLimiter::new();
        let limiter2 = limiter1.clone();
        let config = TestConfig::new().with_rate_limit(2, 60);

        assert!(check_rate_limit(&limiter1, "192.168.1.1", &config).await);
        assert!(check_rate_limit(&limiter2, "192.168.1.1", &config).await);
        assert!(!check_rate_limit(&limiter1, "192.168.1.1", &config).await);
    }

    /// With a cleanup threshold of 0, no entry is ever evicted.
    ///
    /// The window and interval are deliberately tiny. If a sweep did run, every
    /// entry from the loop would be older than two windows by the final call and
    /// would be removed, so this test can detect cleanup that is not disabled.
    #[tokio::test]
    async fn test_cleanup_disabled_when_threshold_zero() {
        let limiter = RateLimiter::new();
        let config = TestConfig {
            rate_limit: RateLimitConfig {
                max_requests: 100,
                window_duration: Duration::from_millis(1),
            },
            cleanup: RateLimitCleanupConfig {
                threshold: 0,
                interval: Duration::from_millis(1),
            },
            ..TestConfig::default()
        };

        for i in 0..100 {
            check_rate_limit(&limiter, &format!("192.168.1.{}", i), &config).await;
        }

        tokio::time::sleep(Duration::from_millis(10)).await;
        check_rate_limit(&limiter, "10.0.0.1", &config).await;

        let inner = limiter.inner().lock().await;
        assert_eq!(inner.len(), 101);
    }

    /// Each distinct IP gets its own entry.
    #[tokio::test]
    async fn test_entries_tracked_per_ip() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(10, 60);

        for i in 0..5 {
            check_rate_limit(&limiter, &format!("10.0.0.{}", i), &config).await;
        }

        let inner = limiter.inner().lock().await;
        assert_eq!(inner.len(), 5);
    }

    /// A blocked IP is allowed again once its window has elapsed.
    #[tokio::test]
    async fn test_window_reset_after_expiration() {
        let limiter = RateLimiter::new();
        let config = TestConfig {
            rate_limit: RateLimitConfig {
                max_requests: 2,
                window_duration: Duration::from_millis(1),
            },
            cleanup: RateLimitCleanupConfig {
                threshold: 0,
                interval: Duration::from_secs(60),
            },
            ..TestConfig::default()
        };

        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);

        tokio::time::sleep(Duration::from_millis(5)).await;

        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
    }

    /// A window reset restores the full quota, not just a single request.
    #[tokio::test]
    async fn test_window_reset_resets_counter() {
        let limiter = RateLimiter::new();
        let config = TestConfig {
            rate_limit: RateLimitConfig {
                max_requests: 3,
                window_duration: Duration::from_millis(1),
            },
            cleanup: RateLimitCleanupConfig {
                threshold: 0,
                interval: Duration::from_secs(60),
            },
            ..TestConfig::default()
        };

        for _ in 0..3 {
            assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        }
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);

        tokio::time::sleep(Duration::from_millis(5)).await;

        for _ in 0..3 {
            assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        }
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);
    }

    /// Inside an unexpired window the count accumulates and tops out at the maximum.
    #[tokio::test]
    async fn test_window_not_expired_keeps_count() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(5, 3600);

        for _ in 0..3 {
            assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        }

        {
            // Scoped so the map lock is released before the next calls.
            let inner = limiter.inner().lock().await;
            assert_eq!(inner.get("192.168.1.1").unwrap().request_count, 3);
        }

        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);

        let inner = limiter.inner().lock().await;
        assert_eq!(inner.get("192.168.1.1").unwrap().request_count, 5);
    }

    /// Windows are tracked per IP: each IP's window starts at its own first request.
    #[tokio::test]
    async fn test_different_ips_different_windows() {
        let limiter = RateLimiter::new();
        let config = TestConfig {
            rate_limit: RateLimitConfig {
                max_requests: 2,
                window_duration: Duration::from_millis(50),
            },
            cleanup: RateLimitCleanupConfig {
                threshold: 0,
                interval: Duration::from_secs(60),
            },
            ..TestConfig::default()
        };

        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);

        tokio::time::sleep(Duration::from_millis(10)).await;

        assert!(check_rate_limit(&limiter, "192.168.1.2", &config).await);

        tokio::time::sleep(Duration::from_millis(50)).await;

        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);

        let inner = limiter.inner().lock().await;
        assert!(inner.contains_key("192.168.1.1"));
        assert!(inner.contains_key("192.168.1.2"));
    }

    /// Once the map exceeds the threshold and the interval has passed, entries
    /// older than two windows are swept before the new IP is inserted.
    #[tokio::test]
    async fn test_cleanup_removes_expired_entries() {
        let limiter = RateLimiter::new();
        let config = TestConfig {
            rate_limit: RateLimitConfig {
                max_requests: 100,
                window_duration: Duration::from_millis(1),
            },
            cleanup: RateLimitCleanupConfig {
                threshold: 1,
                interval: Duration::from_millis(1),
            },
            ..TestConfig::default()
        };

        check_rate_limit(&limiter, "192.168.1.1", &config).await;
        tokio::time::sleep(Duration::from_millis(10)).await;

        check_rate_limit(&limiter, "192.168.1.2", &config).await;
        tokio::time::sleep(Duration::from_millis(10)).await;

        // `LAST_CLEANUP` is process-wide, so a sweep recorded by another test
        // could otherwise postpone this one. Clearing it makes the next call due.
        *LAST_CLEANUP.lock().await = None;
        check_rate_limit(&limiter, "192.168.1.3", &config).await;

        let inner = limiter.inner().lock().await;
        assert!(!inner.contains_key("192.168.1.1"), "expired entry should be swept");
        assert!(!inner.contains_key("192.168.1.2"), "expired entry should be swept");
        assert!(inner.contains_key("192.168.1.3"));
        assert_eq!(inner.len(), 1);
    }

    /// Concurrent tasks hitting one IP consume exactly the quota, with no lost updates.
    #[tokio::test]
    async fn test_concurrent_requests_same_ip() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(10, 60);

        let mut handles = vec![];
        for _ in 0..10 {
            let limiter_clone = limiter.clone();
            let handle = tokio::spawn(async move {
                let config = TestConfig::new().with_rate_limit(10, 60);
                check_rate_limit(&limiter_clone, "192.168.1.1", &config).await
            });
            handles.push(handle);
        }

        let results: Vec<bool> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        assert_eq!(results.iter().filter(|&&r| r).count(), 10);

        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);
    }

    /// Concurrent first requests from distinct IPs are all allowed and all recorded.
    #[tokio::test]
    async fn test_concurrent_requests_different_ips() {
        let limiter = RateLimiter::new();

        let mut handles = vec![];
        for i in 0..50 {
            let limiter_clone = limiter.clone();
            let ip = format!("192.168.1.{}", i);
            let handle = tokio::spawn(async move {
                let config = TestConfig::new().with_rate_limit(5, 60);
                check_rate_limit(&limiter_clone, &ip, &config).await
            });
            handles.push(handle);
        }

        let results: Vec<bool> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        assert!(results.iter().all(|&r| r));

        let inner = limiter.inner().lock().await;
        assert_eq!(inner.len(), 50);
    }
}