Skip to main content

synapse_pingora/shadow/
rate_limiter.rs

1//! Per-IP rate limiter for shadow mirroring.
2//!
3//! Prevents flooding honeypots with too many requests from the same source.
4
5use dashmap::DashMap;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::time::{Duration, Instant};
8
/// Upper bound on how many distinct IPs are tracked at once, so a flood of
/// unique source addresses cannot grow the table without limit.
const MAX_ENTRIES: usize = 50 * 1_000;

/// How many `check_and_increment` calls elapse between capacity checks;
/// spreading the (expensive) check out amortizes its cost.
const CAPACITY_CHECK_INTERVAL: u64 = 1 << 7;
14
15/// Per-IP rate limiter using sliding window algorithm.
16///
17/// Uses DashMap for lock-free concurrent access, critical for high-RPS WAF scenarios.
18/// Bounded to [`MAX_ENTRIES`] to prevent memory exhaustion under DDoS.
19pub struct RateLimiter {
20    /// IP -> (count, window_start)
21    state: DashMap<String, (u32, Instant)>,
22    /// Maximum requests per window
23    limit: u32,
24    /// Window duration
25    window: Duration,
26    /// Total requests allowed (for stats)
27    allowed: AtomicU64,
28    /// Total requests rate-limited (for stats)
29    limited: AtomicU64,
30    /// Operations counter for periodic capacity enforcement
31    ops: AtomicU64,
32    /// Entries evicted due to capacity limits (for stats)
33    evicted: AtomicU64,
34}
35
36impl RateLimiter {
37    /// Creates a new rate limiter with the specified limit per minute.
38    pub fn new(limit_per_minute: u32) -> Self {
39        Self {
40            state: DashMap::new(),
41            limit: limit_per_minute,
42            window: Duration::from_secs(60),
43            allowed: AtomicU64::new(0),
44            limited: AtomicU64::new(0),
45            ops: AtomicU64::new(0),
46            evicted: AtomicU64::new(0),
47        }
48    }
49
50    /// Creates a new rate limiter with custom window duration.
51    pub fn with_window(limit: u32, window: Duration) -> Self {
52        Self {
53            state: DashMap::new(),
54            limit,
55            window,
56            allowed: AtomicU64::new(0),
57            limited: AtomicU64::new(0),
58            ops: AtomicU64::new(0),
59            evicted: AtomicU64::new(0),
60        }
61    }
62
63    /// Checks if the IP is within rate limit and increments counter.
64    ///
65    /// Returns `true` if the request is allowed, `false` if rate-limited.
66    /// Enforces [`MAX_ENTRIES`] capacity bound, evicting expired entries when full.
67    pub fn check_and_increment(&self, ip: &str) -> bool {
68        let now = Instant::now();
69
70        // Amortized capacity enforcement: check every CAPACITY_CHECK_INTERVAL ops
71        let ops = self.ops.fetch_add(1, Ordering::Relaxed);
72        if ops.is_multiple_of(CAPACITY_CHECK_INTERVAL) && self.state.len() >= MAX_ENTRIES {
73            self.evict_expired(now);
74        }
75
76        // If still over capacity after eviction, reject new IPs (existing IPs can still be tracked)
77        if self.state.len() >= MAX_ENTRIES && !self.state.contains_key(ip) {
78            self.limited.fetch_add(1, Ordering::Relaxed);
79            self.evicted.fetch_add(1, Ordering::Relaxed);
80            return false;
81        }
82
83        let allowed = {
84            let mut entry = self.state.entry(ip.to_string()).or_insert((0, now));
85
86            // Reset window if expired
87            if now.duration_since(entry.1) >= self.window {
88                entry.0 = 0;
89                entry.1 = now;
90            }
91
92            // Check limit
93            if entry.0 >= self.limit {
94                false
95            } else {
96                entry.0 += 1;
97                true
98            }
99        };
100
101        // Update stats
102        if allowed {
103            self.allowed.fetch_add(1, Ordering::Relaxed);
104        } else {
105            self.limited.fetch_add(1, Ordering::Relaxed);
106        }
107
108        allowed
109    }
110
111    /// Evicts expired entries to reclaim capacity.
112    fn evict_expired(&self, now: Instant) {
113        self.state
114            .retain(|_, (_, window_start)| now.duration_since(*window_start) < self.window);
115    }
116
117    /// Checks if the IP would be allowed without incrementing.
118    pub fn check(&self, ip: &str) -> bool {
119        let now = Instant::now();
120
121        if let Some(entry) = self.state.get(ip) {
122            // Window expired - would be allowed
123            if now.duration_since(entry.1) >= self.window {
124                return true;
125            }
126            // Check if under limit
127            entry.0 < self.limit
128        } else {
129            // New IP - would be allowed
130            true
131        }
132    }
133
134    /// Gets the current count for an IP.
135    pub fn get_count(&self, ip: &str) -> u32 {
136        self.state.get(ip).map(|e| e.0).unwrap_or(0)
137    }
138
139    /// Cleans up stale entries older than 2x the window duration.
140    ///
141    /// Call this periodically from a background task to prevent unbounded memory growth.
142    pub fn cleanup(&self) {
143        let now = Instant::now();
144        let max_age = self.window * 2;
145
146        self.state
147            .retain(|_, (_, window_start)| now.duration_since(*window_start) < max_age);
148    }
149
150    /// Returns the number of tracked IPs.
151    pub fn len(&self) -> usize {
152        self.state.len()
153    }
154
155    /// Returns true if no IPs are being tracked.
156    pub fn is_empty(&self) -> bool {
157        self.state.is_empty()
158    }
159
160    /// Returns statistics about the rate limiter.
161    pub fn stats(&self) -> RateLimiterStats {
162        RateLimiterStats {
163            tracked_ips: self.state.len(),
164            max_entries: MAX_ENTRIES,
165            allowed: self.allowed.load(Ordering::Relaxed),
166            limited: self.limited.load(Ordering::Relaxed),
167            evicted: self.evicted.load(Ordering::Relaxed),
168            limit: self.limit,
169            window_secs: self.window.as_secs(),
170        }
171    }
172
173    /// Returns the maximum number of tracked IPs.
174    pub fn max_entries(&self) -> usize {
175        MAX_ENTRIES
176    }
177
178    /// Resets all statistics and clears tracked IPs.
179    pub fn reset(&self) {
180        self.state.clear();
181        self.allowed.store(0, Ordering::Relaxed);
182        self.limited.store(0, Ordering::Relaxed);
183        self.ops.store(0, Ordering::Relaxed);
184        self.evicted.store(0, Ordering::Relaxed);
185    }
186}
187
188/// Statistics from the rate limiter.
189#[derive(Debug, Clone, serde::Serialize)]
190pub struct RateLimiterStats {
191    /// Number of IPs currently being tracked
192    pub tracked_ips: usize,
193    /// Maximum capacity
194    pub max_entries: usize,
195    /// Total requests allowed
196    pub allowed: u64,
197    /// Total requests rate-limited
198    pub limited: u64,
199    /// Entries rejected/evicted due to capacity limits
200    pub evicted: u64,
201    /// Configured limit per window
202    pub limit: u32,
203    /// Window duration in seconds
204    pub window_secs: u64,
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_new_ip_allowed() {
        let limiter = RateLimiter::new(10);
        assert!(limiter.check_and_increment("192.168.1.1"));
    }

    #[test]
    fn test_within_limit() {
        let limiter = RateLimiter::new(5);
        let ip = "10.0.0.1";

        for _ in 0..5 {
            assert!(limiter.check_and_increment(ip));
        }

        assert_eq!(limiter.get_count(ip), 5);
    }

    #[test]
    fn test_exceeds_limit() {
        let limiter = RateLimiter::new(3);
        let ip = "10.0.0.2";

        assert!(limiter.check_and_increment(ip));
        assert!(limiter.check_and_increment(ip));
        assert!(limiter.check_and_increment(ip));
        // Fourth request should be limited
        assert!(!limiter.check_and_increment(ip));
        assert!(!limiter.check_and_increment(ip));
    }

    #[test]
    fn test_window_reset() {
        // Use a very short window for testing
        let limiter = RateLimiter::with_window(2, Duration::from_millis(50));
        let ip = "10.0.0.3";

        assert!(limiter.check_and_increment(ip));
        assert!(limiter.check_and_increment(ip));
        assert!(!limiter.check_and_increment(ip)); // Limited

        // Wait for window to expire
        thread::sleep(Duration::from_millis(60));

        // Should be allowed again
        assert!(limiter.check_and_increment(ip));
    }

    #[test]
    fn test_different_ips_independent() {
        let limiter = RateLimiter::new(2);

        assert!(limiter.check_and_increment("ip1"));
        assert!(limiter.check_and_increment("ip1"));
        assert!(!limiter.check_and_increment("ip1")); // Limited

        // Different IP should be independent
        assert!(limiter.check_and_increment("ip2"));
        assert!(limiter.check_and_increment("ip2"));
        assert!(!limiter.check_and_increment("ip2")); // Limited
    }

    #[test]
    fn test_check_without_increment() {
        let limiter = RateLimiter::new(2);
        let ip = "10.0.0.4";

        assert!(limiter.check(ip)); // Would be allowed
        assert_eq!(limiter.get_count(ip), 0); // Not incremented

        limiter.check_and_increment(ip);
        limiter.check_and_increment(ip);

        assert!(!limiter.check(ip)); // Would be limited
    }

    #[test]
    fn test_cleanup() {
        let limiter = RateLimiter::with_window(10, Duration::from_millis(25));

        limiter.check_and_increment("ip1");
        limiter.check_and_increment("ip2");
        assert_eq!(limiter.len(), 2);

        // Wait for entries to become stale
        thread::sleep(Duration::from_millis(60));

        limiter.cleanup();
        assert_eq!(limiter.len(), 0);
    }

    #[test]
    fn test_stats() {
        let limiter = RateLimiter::new(2);

        limiter.check_and_increment("ip1");
        limiter.check_and_increment("ip1");
        limiter.check_and_increment("ip1"); // Limited

        let stats = limiter.stats();
        assert_eq!(stats.tracked_ips, 1);
        assert_eq!(stats.max_entries, MAX_ENTRIES);
        assert_eq!(stats.allowed, 2);
        assert_eq!(stats.limited, 1);
        assert_eq!(stats.limit, 2);
    }

    /// Fills the table with live entries and checks the capacity bound:
    /// new IPs are rejected, tracked IPs keep being served.
    #[test]
    fn test_capacity_bound() {
        // Long window so nothing expires (and nothing is evicted) during the test.
        let limiter = RateLimiter::with_window(100, Duration::from_secs(60));
        assert_eq!(limiter.max_entries(), MAX_ENTRIES);
        assert_eq!(limiter.stats().evicted, 0);

        for i in 0..MAX_ENTRIES {
            assert!(limiter.check_and_increment(&format!("fill-{i}")));
        }
        assert_eq!(limiter.len(), MAX_ENTRIES);

        // A new IP cannot be tracked while the table is full of live entries...
        assert!(!limiter.check_and_increment("newcomer"));
        assert_eq!(limiter.len(), MAX_ENTRIES);
        let stats = limiter.stats();
        assert_eq!(stats.evicted, 1);
        assert_eq!(stats.limited, 1);

        // ...but an already-tracked IP is still served.
        assert!(limiter.check_and_increment("fill-0"));
    }

    /// Once tracked windows have expired, the amortized capacity sweep must
    /// free room within one check interval.
    #[test]
    fn test_capacity_eviction_frees_expired_entries() {
        let window = Duration::from_millis(10);
        let limiter = RateLimiter::with_window(100, window);

        for i in 0..MAX_ENTRIES {
            limiter.check_and_increment(&format!("fill-{i}"));
        }
        assert_eq!(limiter.len(), MAX_ENTRIES);

        // Let every tracked window expire.
        thread::sleep(window * 2);

        let mut admitted = 0;
        for i in 0..CAPACITY_CHECK_INTERVAL {
            if limiter.check_and_increment(&format!("late-{i}")) {
                admitted += 1;
            }
        }
        assert!(admitted > 0);
        assert!(limiter.len() < MAX_ENTRIES);
    }

    #[test]
    fn test_reset() {
        let limiter = RateLimiter::new(10);

        limiter.check_and_increment("ip1");
        limiter.check_and_increment("ip2");

        limiter.reset();

        assert!(limiter.is_empty());
        let stats = limiter.stats();
        assert_eq!(stats.allowed, 0);
        assert_eq!(stats.limited, 0);
    }

    #[test]
    fn test_concurrent_access() {
        use std::sync::Arc;

        let limiter = Arc::new(RateLimiter::new(100));
        let mut handles = vec![];

        for i in 0..10 {
            let limiter = Arc::clone(&limiter);
            let handle = thread::spawn(move || {
                for _ in 0..10 {
                    limiter.check_and_increment(&format!("ip{}", i));
                }
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(limiter.len(), 10);
        let stats = limiter.stats();
        assert_eq!(stats.allowed, 100);
    }
}