wisegate_core/
rate_limiter.rs

1//! Rate limiting implementation for WiseGate.
2//!
//! Provides per-IP rate limiting using a fixed-window counter (each IP's window
//! starts at its first request) with automatic cleanup of expired entries to
//! prevent memory exhaustion.
5//!
6//! # Algorithm
7//!
//! Uses a simple fixed-window counter per IP:
//! - Each IP has a request counter and the timestamp at which its current window started
//! - If the window has expired, a new window starts and the counter resets to 1
//! - If under the limit, the counter increments and the request is allowed
//! - If the limit has been reached, the request is denied and the counter is left unchanged
13//!
14//! # Memory Management
15//!
16//! To prevent memory exhaustion from tracking many unique IPs, the rate limiter
17//! performs automatic cleanup when:
18//! - Entry count exceeds the configured threshold
19//! - Minimum interval since last cleanup has passed
20//!
21//! # Thread Safety
22//!
23//! Uses `tokio::sync::Mutex` for async-friendly locking that won't block
24//! the Tokio thread pool.
25//!
26//! # Example
27//!
28//! ```ignore
29//! use wisegate_core::{rate_limiter, RateLimiter};
30//!
31//! let limiter = RateLimiter::new();
32//!
33//! if rate_limiter::check_rate_limit(&limiter, "192.168.1.1", &config).await {
34//!     // Request allowed
35//! } else {
36//!     // Rate limit exceeded
37//! }
38//! ```
39
40use std::time::Instant;
41use tokio::sync::Mutex;
42use tracing::debug;
43
44use crate::types::{ConfigProvider, RateLimitEntry, RateLimiter};
45
46/// Tracks the last cleanup time to enforce minimum interval between cleanups.
47static LAST_CLEANUP: Mutex<Option<Instant>> = Mutex::const_new(None);
48
49/// Checks if a request from the given IP should be allowed based on rate limits.
50///
51/// Returns `true` if the request is allowed, `false` if rate limited.
52///
53/// # Algorithm
54///
55/// 1. If the time window has expired for this IP, reset the counter
56/// 2. If the request count is under the limit, increment and allow
57/// 3. If the request count exceeds the limit, deny
58///
59/// # Cleanup
60///
61/// Automatically cleans up expired entries when:
62/// - Entry count exceeds `RATE_LIMIT_CLEANUP_THRESHOLD`
63/// - At least `RATE_LIMIT_CLEANUP_INTERVAL_SECS` since last cleanup
64///
65/// # Arguments
66///
67/// * `limiter` - Shared rate limiter state
68/// * `ip` - Client IP address to check
69/// * `config` - Configuration provider for rate limit settings
70///
71/// # Returns
72///
73/// - `true` - Request is allowed
74/// - `false` - Request is rate limited (should return 429)
75///
76/// # Example
77///
78/// ```ignore
79/// use wisegate_core::rate_limiter::check_rate_limit;
80///
81/// if !check_rate_limit(&limiter, &client_ip, &config).await {
82///     return Err(StatusCode::TOO_MANY_REQUESTS);
83/// }
84/// ```
85pub async fn check_rate_limit(
86    limiter: &RateLimiter,
87    ip: &str,
88    config: &impl ConfigProvider,
89) -> bool {
90    let rate_config = config.rate_limit_config();
91    let cleanup_config = config.rate_limit_cleanup_config();
92    let mut rate_map = limiter.inner().lock().await;
93    let now = Instant::now();
94
95    // Perform cleanup if needed (threshold exceeded and interval passed)
96    if cleanup_config.is_enabled() && rate_map.len() > cleanup_config.threshold {
97        let should_cleanup = {
98            let mut last_cleanup = LAST_CLEANUP.lock().await;
99            match *last_cleanup {
100                None => {
101                    *last_cleanup = Some(now);
102                    true
103                }
104                Some(last) if now.duration_since(last) >= cleanup_config.interval => {
105                    *last_cleanup = Some(now);
106                    true
107                }
108                _ => false,
109            }
110        };
111
112        if should_cleanup {
113            let before_count = rate_map.len();
114            // Remove entries that have expired (older than 2x window duration for safety margin)
115            let expiry_threshold = rate_config.window_duration * 2;
116            rate_map.retain(|_, entry| now.duration_since(entry.window_start) < expiry_threshold);
117            let removed = before_count - rate_map.len();
118            if removed > 0 {
119                debug!(
120                    removed_entries = removed,
121                    remaining_entries = rate_map.len(),
122                    "Rate limiter cleanup completed"
123                );
124            }
125        }
126    }
127
128    match rate_map.get_mut(ip) {
129        Some(entry) => {
130            // Check if we're in a new time window
131            if now.duration_since(entry.window_start) >= rate_config.window_duration {
132                // Reset window
133                entry.window_start = now;
134                entry.request_count = 1;
135                true
136            } else if entry.request_count < rate_config.max_requests {
137                // Within limit, increment counter
138                entry.request_count += 1;
139                true
140            } else {
141                // Rate limit exceeded
142                false
143            }
144        }
145        None => {
146            // First request from this IP
147            rate_map.insert(ip.to_string(), RateLimitEntry::new());
148            true
149        }
150    }
151}
152
#[cfg(test)]
mod tests {
    //! Unit tests for [`check_rate_limit`].
    //!
    //! Time-based tests use real sleeps because the limiter measures time with
    //! `std::time::Instant`; sleeps only guarantee a *minimum* delay, so every
    //! timing assertion below relies solely on "at least this much time passed".

    use super::*;
    use crate::test_utils::TestConfig;
    use crate::types::{RateLimitCleanupConfig, RateLimitConfig};
    use std::time::Duration;

    // ===========================================
    // Basic rate limiting tests
    // ===========================================

    /// The very first request from an unknown IP is allowed.
    #[tokio::test]
    async fn test_first_request_allowed() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(5, 60);

        let allowed = check_rate_limit(&limiter, "192.168.1.1", &config).await;
        assert!(allowed);
    }

    /// Exactly `max_requests` requests fit into one window.
    #[tokio::test]
    async fn test_requests_within_limit_allowed() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(5, 60);

        for i in 0..5 {
            let allowed = check_rate_limit(&limiter, "192.168.1.1", &config).await;
            assert!(allowed, "Request {} should be allowed", i + 1);
        }
    }

    /// The request following a full quota is rejected.
    #[tokio::test]
    async fn test_request_exceeding_limit_blocked() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(3, 60);

        // First 3 requests should be allowed
        for _ in 0..3 {
            assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        }

        // 4th request should be blocked
        let allowed = check_rate_limit(&limiter, "192.168.1.1", &config).await;
        assert!(!allowed, "Request exceeding limit should be blocked");
    }

    /// Quotas are tracked per IP; exhausting one does not affect another.
    #[tokio::test]
    async fn test_different_ips_independent() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(2, 60);

        // IP 1 makes 2 requests
        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);

        // IP 2 should still have its full quota
        assert!(check_rate_limit(&limiter, "192.168.1.2", &config).await);
        assert!(check_rate_limit(&limiter, "192.168.1.2", &config).await);
        assert!(!check_rate_limit(&limiter, "192.168.1.2", &config).await);
    }

    /// Each allowed request bumps the stored counter by one.
    #[tokio::test]
    async fn test_counter_increments_correctly() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(5, 60);

        // Make some requests
        check_rate_limit(&limiter, "192.168.1.1", &config).await;
        check_rate_limit(&limiter, "192.168.1.1", &config).await;
        check_rate_limit(&limiter, "192.168.1.1", &config).await;

        // Check the counter
        let inner = limiter.inner().lock().await;
        let entry = inner.get("192.168.1.1").unwrap();
        assert_eq!(entry.request_count, 3);
    }

    // ===========================================
    // Edge case tests
    // ===========================================

    /// A limit of one allows a single request per window.
    #[tokio::test]
    async fn test_limit_of_one() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(1, 60);

        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);
    }

    /// IPv6 textual addresses are keyed like any other string.
    #[tokio::test]
    async fn test_ipv6_addresses() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(2, 60);

        assert!(check_rate_limit(&limiter, "::1", &config).await);
        assert!(check_rate_limit(&limiter, "::1", &config).await);
        assert!(!check_rate_limit(&limiter, "::1", &config).await);

        // Different IPv6 should be independent
        assert!(check_rate_limit(&limiter, "2001:db8::1", &config).await);
    }

    /// Denied requests keep being denied and never inflate the counter.
    #[tokio::test]
    async fn test_multiple_blocked_requests() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(1, 60);

        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);

        // Multiple blocked requests should all return false
        for _ in 0..5 {
            assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);
        }

        // Counter should not increase beyond limit
        let inner = limiter.inner().lock().await;
        let entry = inner.get("192.168.1.1").unwrap();
        assert_eq!(entry.request_count, 1);
    }

    // ===========================================
    // Concurrent access tests
    // ===========================================

    /// Clones of a limiter observe and mutate the same underlying state.
    #[tokio::test]
    async fn test_limiter_clone_shares_state() {
        let limiter1 = RateLimiter::new();
        let limiter2 = limiter1.clone();
        let config = TestConfig::new().with_rate_limit(2, 60);

        // Use limiter1 for first request
        assert!(check_rate_limit(&limiter1, "192.168.1.1", &config).await);

        // Use limiter2 for second request - should share state
        assert!(check_rate_limit(&limiter2, "192.168.1.1", &config).await);

        // Third request on either should be blocked
        assert!(!check_rate_limit(&limiter1, "192.168.1.1", &config).await);
    }

    // ===========================================
    // Cleanup tests
    // ===========================================

    /// With cleanup disabled, entries are never evicted.
    #[tokio::test]
    async fn test_cleanup_disabled_when_threshold_zero() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(100, 60); // Cleanup disabled (threshold = 0)

        // Add many entries
        for i in 0..100 {
            check_rate_limit(&limiter, &format!("192.168.1.{}", i), &config).await;
        }

        // All entries should remain (no cleanup)
        let inner = limiter.inner().lock().await;
        assert_eq!(inner.len(), 100);
    }

    /// One map entry is created per distinct IP.
    #[tokio::test]
    async fn test_entries_tracked_per_ip() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(10, 60);

        // Make requests from 5 different IPs
        for i in 0..5 {
            check_rate_limit(&limiter, &format!("10.0.0.{}", i), &config).await;
        }

        let inner = limiter.inner().lock().await;
        assert_eq!(inner.len(), 5);
    }

    // ===========================================
    // Time-based / Window expiration tests
    // ===========================================

    /// A blocked IP is allowed again once its window has expired.
    #[tokio::test]
    async fn test_window_reset_after_expiration() {
        let limiter = RateLimiter::new();
        // Use very short window for testing - need direct struct for milliseconds
        let config = TestConfig {
            rate_limit: RateLimitConfig {
                max_requests: 2,
                window_duration: Duration::from_millis(1),
            },
            cleanup: RateLimitCleanupConfig {
                threshold: 0,
                interval: Duration::from_secs(60),
            },
            ..TestConfig::default()
        };

        // First two requests allowed
        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);

        // Third blocked
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);

        // Wait for window to expire
        tokio::time::sleep(Duration::from_millis(5)).await;

        // Should be allowed again after window expires
        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
    }

    /// After a window reset the full quota is available again.
    #[tokio::test]
    async fn test_window_reset_resets_counter() {
        let limiter = RateLimiter::new();
        let config = TestConfig {
            rate_limit: RateLimitConfig {
                max_requests: 3,
                window_duration: Duration::from_millis(1),
            },
            cleanup: RateLimitCleanupConfig {
                threshold: 0,
                interval: Duration::from_secs(60),
            },
            ..TestConfig::default()
        };

        // Use full quota
        for _ in 0..3 {
            assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        }
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);

        // Wait for window to expire
        tokio::time::sleep(Duration::from_millis(5)).await;

        // Counter should reset, full quota available again
        for _ in 0..3 {
            assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        }
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);
    }

    /// Within an unexpired window the counter accumulates and then saturates at the limit.
    #[tokio::test]
    async fn test_window_not_expired_keeps_count() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(5, 3600); // 1 hour window

        // Make 3 requests
        for _ in 0..3 {
            assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        }

        // Verify counter is 3 (scoped so the guard is released before the next call)
        {
            let inner = limiter.inner().lock().await;
            assert_eq!(inner.get("192.168.1.1").unwrap().request_count, 3);
        }

        // Make 2 more requests (still within limit)
        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);

        // Now should be blocked (5 requests made)
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);

        // Counter should still be 5 (not increased when blocked)
        let inner = limiter.inner().lock().await;
        assert_eq!(inner.get("192.168.1.1").unwrap().request_count, 5);
    }

    /// Each IP's window is anchored at its own first request.
    #[tokio::test]
    async fn test_different_ips_different_windows() {
        let limiter = RateLimiter::new();
        let config = TestConfig {
            rate_limit: RateLimitConfig {
                max_requests: 2,
                window_duration: Duration::from_millis(50),
            },
            cleanup: RateLimitCleanupConfig {
                threshold: 0,
                interval: Duration::from_secs(60),
            },
            ..TestConfig::default()
        };

        // IP1: exhaust quota
        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);

        // Let some time pass before IP2 shows up
        tokio::time::sleep(Duration::from_millis(10)).await;

        // IP2: start fresh
        assert!(check_rate_limit(&limiter, "192.168.1.2", &config).await);

        // Wait for IP1's window to expire (at least 60ms have now passed for IP1)
        tokio::time::sleep(Duration::from_millis(50)).await;

        // IP1 should be allowed again
        assert!(check_rate_limit(&limiter, "192.168.1.1", &config).await);

        let inner = limiter.inner().lock().await;
        // IP1 was just reset, so this request is the first of its new window.
        assert_eq!(inner.get("192.168.1.1").unwrap().request_count, 1);
        // IP2 was only touched once. A stored counter changes only when that IP
        // makes a request, so it is still 1 whether or not its window has expired.
        assert_eq!(inner.get("192.168.1.2").unwrap().request_count, 1);
    }

    // ===========================================
    // Cleanup with expiration tests
    // ===========================================

    /// Entries older than twice the window are evicted once the threshold is exceeded.
    #[tokio::test]
    async fn test_cleanup_removes_expired_entries() {
        let limiter = RateLimiter::new();
        let config = TestConfig {
            rate_limit: RateLimitConfig {
                max_requests: 100,
                window_duration: Duration::from_millis(1), // Very short window
            },
            cleanup: RateLimitCleanupConfig {
                threshold: 1,                       // Trigger cleanup when > 1 entry
                interval: Duration::from_millis(1), // Allow frequent cleanup
            },
            ..TestConfig::default()
        };

        // Two entries, so the map exceeds the threshold on the next call.
        // (Neither call cleans up: the map holds 0 resp. 1 entries beforehand.)
        check_rate_limit(&limiter, "192.168.1.1", &config).await;
        check_rate_limit(&limiter, "192.168.1.2", &config).await;

        // The last-cleanup timestamp is process-wide, so a cleanup run through
        // another limiter within the last interval can suppress ours. Retry a
        // few times with fresh IPs instead of relying on a single attempt.
        let mut stale_removed = false;
        for attempt in 0..5 {
            // Far longer than 2x the window, so both initial entries are expired
            tokio::time::sleep(Duration::from_millis(10)).await;

            let fresh_ip = format!("192.168.2.{}", attempt);
            check_rate_limit(&limiter, &fresh_ip, &config).await;

            let inner = limiter.inner().lock().await;
            if !inner.contains_key("192.168.1.1") && !inner.contains_key("192.168.1.2") {
                // The entry created by the triggering request itself must survive
                assert!(inner.contains_key(fresh_ip.as_str()));
                stale_removed = true;
                break;
            }
        }

        assert!(
            stale_removed,
            "Expired entries should be removed by cleanup"
        );
    }

    // ===========================================
    // Concurrent request tests
    // ===========================================

    /// Concurrent requests from one IP are serialized and counted exactly.
    #[tokio::test]
    async fn test_concurrent_requests_same_ip() {
        let limiter = RateLimiter::new();
        let config = TestConfig::new().with_rate_limit(10, 60);

        // Spawn multiple concurrent requests
        let mut handles = vec![];
        for _ in 0..10 {
            let limiter_clone = limiter.clone();
            let handle = tokio::spawn(async move {
                // Built inside the task so nothing is borrowed across the spawn
                let config = TestConfig::new().with_rate_limit(10, 60);
                check_rate_limit(&limiter_clone, "192.168.1.1", &config).await
            });
            handles.push(handle);
        }

        // Wait for all to complete
        let results: Vec<bool> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        // All 10 should be allowed
        assert_eq!(results.iter().filter(|&&r| r).count(), 10);

        // 11th request should be blocked
        assert!(!check_rate_limit(&limiter, "192.168.1.1", &config).await);
    }

    /// Concurrent requests from distinct IPs each get their own entry.
    #[tokio::test]
    async fn test_concurrent_requests_different_ips() {
        let limiter = RateLimiter::new();

        // Spawn requests from different IPs concurrently
        let mut handles = vec![];
        for i in 0..50 {
            let limiter_clone = limiter.clone();
            let ip = format!("192.168.1.{}", i);
            let handle = tokio::spawn(async move {
                let config = TestConfig::new().with_rate_limit(5, 60);
                check_rate_limit(&limiter_clone, &ip, &config).await
            });
            handles.push(handle);
        }

        // Wait for all to complete
        let results: Vec<bool> = futures::future::join_all(handles)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        // All should be allowed (first request from each IP)
        assert!(results.iter().all(|&r| r));

        // Verify all IPs are tracked
        let inner = limiter.inner().lock().await;
        assert_eq!(inner.len(), 50);
    }
}