synapse_pingora/shadow/
rate_limiter.rs1use dashmap::DashMap;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::time::{Duration, Instant};
8
/// Soft cap on how many distinct IPs the limiter tracks at once. Once the table
/// holds this many entries, requests from IPs that are not already tracked are
/// rejected.
const MAX_ENTRIES: usize = 50_000;

/// Every `CAPACITY_CHECK_INTERVAL`-th call to `check_and_increment` checks
/// whether the table is full and, if so, sweeps out entries whose window has
/// elapsed. This amortizes the cost of the full-table sweep.
const CAPACITY_CHECK_INTERVAL: u64 = 128;
15pub struct RateLimiter {
20 state: DashMap<String, (u32, Instant)>,
22 limit: u32,
24 window: Duration,
26 allowed: AtomicU64,
28 limited: AtomicU64,
30 ops: AtomicU64,
32 evicted: AtomicU64,
34}
35
36impl RateLimiter {
37 pub fn new(limit_per_minute: u32) -> Self {
39 Self {
40 state: DashMap::new(),
41 limit: limit_per_minute,
42 window: Duration::from_secs(60),
43 allowed: AtomicU64::new(0),
44 limited: AtomicU64::new(0),
45 ops: AtomicU64::new(0),
46 evicted: AtomicU64::new(0),
47 }
48 }
49
50 pub fn with_window(limit: u32, window: Duration) -> Self {
52 Self {
53 state: DashMap::new(),
54 limit,
55 window,
56 allowed: AtomicU64::new(0),
57 limited: AtomicU64::new(0),
58 ops: AtomicU64::new(0),
59 evicted: AtomicU64::new(0),
60 }
61 }
62
63 pub fn check_and_increment(&self, ip: &str) -> bool {
68 let now = Instant::now();
69
70 let ops = self.ops.fetch_add(1, Ordering::Relaxed);
72 if ops.is_multiple_of(CAPACITY_CHECK_INTERVAL) && self.state.len() >= MAX_ENTRIES {
73 self.evict_expired(now);
74 }
75
76 if self.state.len() >= MAX_ENTRIES && !self.state.contains_key(ip) {
78 self.limited.fetch_add(1, Ordering::Relaxed);
79 self.evicted.fetch_add(1, Ordering::Relaxed);
80 return false;
81 }
82
83 let allowed = {
84 let mut entry = self.state.entry(ip.to_string()).or_insert((0, now));
85
86 if now.duration_since(entry.1) >= self.window {
88 entry.0 = 0;
89 entry.1 = now;
90 }
91
92 if entry.0 >= self.limit {
94 false
95 } else {
96 entry.0 += 1;
97 true
98 }
99 };
100
101 if allowed {
103 self.allowed.fetch_add(1, Ordering::Relaxed);
104 } else {
105 self.limited.fetch_add(1, Ordering::Relaxed);
106 }
107
108 allowed
109 }
110
111 fn evict_expired(&self, now: Instant) {
113 self.state
114 .retain(|_, (_, window_start)| now.duration_since(*window_start) < self.window);
115 }
116
117 pub fn check(&self, ip: &str) -> bool {
119 let now = Instant::now();
120
121 if let Some(entry) = self.state.get(ip) {
122 if now.duration_since(entry.1) >= self.window {
124 return true;
125 }
126 entry.0 < self.limit
128 } else {
129 true
131 }
132 }
133
134 pub fn get_count(&self, ip: &str) -> u32 {
136 self.state.get(ip).map(|e| e.0).unwrap_or(0)
137 }
138
139 pub fn cleanup(&self) {
143 let now = Instant::now();
144 let max_age = self.window * 2;
145
146 self.state
147 .retain(|_, (_, window_start)| now.duration_since(*window_start) < max_age);
148 }
149
150 pub fn len(&self) -> usize {
152 self.state.len()
153 }
154
155 pub fn is_empty(&self) -> bool {
157 self.state.is_empty()
158 }
159
160 pub fn stats(&self) -> RateLimiterStats {
162 RateLimiterStats {
163 tracked_ips: self.state.len(),
164 max_entries: MAX_ENTRIES,
165 allowed: self.allowed.load(Ordering::Relaxed),
166 limited: self.limited.load(Ordering::Relaxed),
167 evicted: self.evicted.load(Ordering::Relaxed),
168 limit: self.limit,
169 window_secs: self.window.as_secs(),
170 }
171 }
172
173 pub fn max_entries(&self) -> usize {
175 MAX_ENTRIES
176 }
177
178 pub fn reset(&self) {
180 self.state.clear();
181 self.allowed.store(0, Ordering::Relaxed);
182 self.limited.store(0, Ordering::Relaxed);
183 self.ops.store(0, Ordering::Relaxed);
184 self.evicted.store(0, Ordering::Relaxed);
185 }
186}
187
188#[derive(Debug, Clone, serde::Serialize)]
190pub struct RateLimiterStats {
191 pub tracked_ips: usize,
193 pub max_entries: usize,
195 pub allowed: u64,
197 pub limited: u64,
199 pub evicted: u64,
201 pub limit: u32,
203 pub window_secs: u64,
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_new_ip_allowed() {
        let limiter = RateLimiter::new(10);
        assert!(limiter.check_and_increment("192.168.1.1"));
    }

    #[test]
    fn test_within_limit() {
        let limiter = RateLimiter::new(5);
        let ip = "10.0.0.1";

        for _ in 0..5 {
            assert!(limiter.check_and_increment(ip));
        }

        assert_eq!(limiter.get_count(ip), 5);
    }

    #[test]
    fn test_exceeds_limit() {
        let limiter = RateLimiter::new(3);
        let ip = "10.0.0.2";

        assert!(limiter.check_and_increment(ip));
        assert!(limiter.check_and_increment(ip));
        assert!(limiter.check_and_increment(ip));
        // Over the limit, and rejected requests do not reset anything.
        assert!(!limiter.check_and_increment(ip));
        assert!(!limiter.check_and_increment(ip));
    }

    #[test]
    fn test_window_reset() {
        let limiter = RateLimiter::with_window(2, Duration::from_millis(50));
        let ip = "10.0.0.3";

        assert!(limiter.check_and_increment(ip));
        assert!(limiter.check_and_increment(ip));
        assert!(!limiter.check_and_increment(ip));

        // Once the window has elapsed, the IP gets a fresh quota.
        thread::sleep(Duration::from_millis(60));
        assert!(limiter.check_and_increment(ip));
    }

    #[test]
    fn test_different_ips_independent() {
        let limiter = RateLimiter::new(2);

        assert!(limiter.check_and_increment("ip1"));
        assert!(limiter.check_and_increment("ip1"));
        assert!(!limiter.check_and_increment("ip1"));

        // Exhausting ip1 must not affect ip2.
        assert!(limiter.check_and_increment("ip2"));
        assert!(limiter.check_and_increment("ip2"));
        assert!(!limiter.check_and_increment("ip2"));
    }

    #[test]
    fn test_check_without_increment() {
        let limiter = RateLimiter::new(2);
        let ip = "10.0.0.4";

        // `check` must not record anything.
        assert!(limiter.check(ip));
        assert_eq!(limiter.get_count(ip), 0);

        limiter.check_and_increment(ip);
        limiter.check_and_increment(ip);

        assert!(!limiter.check(ip));
    }

    #[test]
    fn test_cleanup() {
        let limiter = RateLimiter::with_window(10, Duration::from_millis(25));

        limiter.check_and_increment("ip1");
        limiter.check_and_increment("ip2");
        assert_eq!(limiter.len(), 2);

        // Longer than twice the window, so both entries are old enough to drop.
        thread::sleep(Duration::from_millis(60));

        limiter.cleanup();
        assert_eq!(limiter.len(), 0);
    }

    #[test]
    fn test_stats() {
        let limiter = RateLimiter::new(2);

        limiter.check_and_increment("ip1");
        limiter.check_and_increment("ip1");
        limiter.check_and_increment("ip1");

        let stats = limiter.stats();
        assert_eq!(stats.tracked_ips, 1);
        assert_eq!(stats.max_entries, MAX_ENTRIES);
        assert_eq!(stats.allowed, 2);
        assert_eq!(stats.limited, 1);
        assert_eq!(stats.limit, 2);
    }

    #[test]
    fn test_capacity_bound() {
        let limiter = RateLimiter::with_window(100, Duration::from_secs(60));
        assert_eq!(limiter.max_entries(), MAX_ENTRIES);

        // Fill the table to capacity with distinct IPs; every one of them fits.
        for i in 0..MAX_ENTRIES {
            assert!(limiter.check_and_increment(&format!("ip-{i}")));
        }
        assert_eq!(limiter.len(), MAX_ENTRIES);
        assert_eq!(limiter.stats().evicted, 0);

        // A previously unseen IP is rejected while the table is full, and is not
        // added to it...
        assert!(!limiter.check_and_increment("overflow-ip"));
        assert_eq!(limiter.len(), MAX_ENTRIES);

        // ...while IPs that are already tracked continue to be served.
        assert!(limiter.check_and_increment("ip-0"));

        let stats = limiter.stats();
        assert_eq!(stats.evicted, 1);
        assert_eq!(stats.limited, 1);
        assert_eq!(stats.allowed, MAX_ENTRIES as u64 + 1);
    }

    #[test]
    fn test_reset() {
        let limiter = RateLimiter::new(10);

        limiter.check_and_increment("ip1");
        limiter.check_and_increment("ip2");

        limiter.reset();

        assert!(limiter.is_empty());
        let stats = limiter.stats();
        assert_eq!(stats.allowed, 0);
        assert_eq!(stats.limited, 0);
    }

    #[test]
    fn test_concurrent_access() {
        use std::sync::Arc;

        let limiter = Arc::new(RateLimiter::new(100));

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let limiter = Arc::clone(&limiter);
                thread::spawn(move || {
                    for _ in 0..10 {
                        limiter.check_and_increment(&format!("ip{}", i));
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(limiter.len(), 10);
        let stats = limiter.stats();
        assert_eq!(stats.allowed, 100);
    }
}