volga_rate_limiter/rate_limiter/
sliding_window.rs1use super::{
4 RateLimiter, SystemTimeSource, TimeSource,
5 store::{SlidingWindowParams, SlidingWindowStore},
6};
7use dashmap::DashMap;
8use std::sync::{
9 Arc,
10 atomic::{AtomicU32, AtomicU64, Ordering::*},
11};
12use std::time::Duration;
13
/// Per-key counters used to estimate the request rate over a sliding window.
///
/// Every field is an atomic so that it can be updated through a shared
/// reference to the entry.
#[derive(Debug)]
struct Entry {
    /// Requests counted in the window immediately before the current one.
    previous_count: AtomicU32,
    /// Requests counted so far in the current window.
    current_count: AtomicU32,
    /// Start of the current window, in the same second-based unit as the
    /// `now` value handed to the store.
    window_start: AtomicU64,
}
35
36#[derive(Debug, Clone)]
40pub struct InMemorySlidingWindowStore {
41 storage: Arc<DashMap<u64, Entry>>,
42}
43
44impl InMemorySlidingWindowStore {
45 pub fn new() -> Self {
47 Self {
48 storage: Arc::new(DashMap::new()),
49 }
50 }
51}
52
53impl Default for InMemorySlidingWindowStore {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl SlidingWindowStore for InMemorySlidingWindowStore {
60 #[inline]
61 fn check_and_count(&self, params: SlidingWindowParams) -> bool {
62 let SlidingWindowParams {
64 key,
65 window,
66 window_size_secs,
67 max_requests,
68 now,
69 grace_secs,
70 } = params;
71
72 if let Some(entry) = self.storage.get(&key) {
74 let window_start = entry.window_start.load(Acquire);
75 if now.saturating_sub(window_start) > grace_secs {
76 drop(entry);
77 self.storage.remove(&key);
78 }
79 }
80
81 let entry = self.storage.entry(key).or_insert_with(|| {
82 let window_start = now / window_size_secs * window_size_secs;
83 Entry {
84 previous_count: AtomicU32::new(0),
85 current_count: AtomicU32::new(0),
86 window_start: AtomicU64::new(window_start),
87 }
88 });
89
90 let window_start = entry.window_start.load(Acquire);
91
92 if window > window_start {
93 let windows_passed = (window - window_start) / window_size_secs;
94
95 if windows_passed >= 2 {
96 entry.previous_count.store(0, Release);
97 entry.current_count.store(0, Release);
98 entry.window_start.store(window, Release);
99 } else {
100 let old_current = entry.current_count.swap(0, AcqRel);
101 entry.previous_count.store(old_current, Release);
102 entry.window_start.store(window, Release);
103 }
104 }
105
106 let previous = entry.previous_count.load(Acquire);
107 let current = entry.current_count.load(Acquire);
108
109 let elapsed_in_window = now - entry.window_start.load(Acquire);
110 let progress = (elapsed_in_window as f64 / window_size_secs as f64).min(1.0);
111 let previous_weight = 1.0 - progress;
112 let effective = previous as f64 * previous_weight + current as f64;
113
114 if effective >= f64::from(max_requests) {
115 return false;
116 }
117
118 entry.current_count.fetch_add(1, Release);
119 true
120 }
121}
122
123#[derive(Debug)]
171pub struct SlidingWindowRateLimiter<
172 T: TimeSource = SystemTimeSource,
173 S: SlidingWindowStore = InMemorySlidingWindowStore,
174> {
175 store: S,
176 max_requests: u32,
177 window_size_secs: u64,
178 eviction_grace_secs: u64,
179 time_source: T,
180}
181
182impl<T: TimeSource, S: SlidingWindowStore> RateLimiter for SlidingWindowRateLimiter<T, S> {
183 #[inline]
190 fn check(&self, key: u64) -> bool {
191 let now = self.time_source.now_secs();
192 let window = now / self.window_size_secs * self.window_size_secs;
193 self.store.check_and_count(SlidingWindowParams {
194 key,
195 window,
196 window_size_secs: self.window_size_secs,
197 max_requests: self.max_requests,
198 now,
199 grace_secs: self.eviction_grace_secs,
200 })
201 }
202}
203
204impl SlidingWindowRateLimiter {
205 #[inline]
217 pub fn new(max_requests: u32, window_size: Duration) -> Self {
218 Self::with_time_source(max_requests, window_size, SystemTimeSource)
219 }
220}
221
222impl<T: TimeSource> SlidingWindowRateLimiter<T> {
223 #[inline]
234 pub fn with_time_source(max_requests: u32, window_size: Duration, time_source: T) -> Self {
235 Self::with_time_source_and_store(
236 max_requests,
237 window_size,
238 time_source,
239 InMemorySlidingWindowStore::new(),
240 )
241 }
242}
243
244impl<S: SlidingWindowStore> SlidingWindowRateLimiter<SystemTimeSource, S> {
245 #[inline]
251 pub fn with_store(max_requests: u32, window_size: Duration, store: S) -> Self {
252 Self::with_time_source_and_store(max_requests, window_size, SystemTimeSource, store)
253 }
254}
255
256impl<T: TimeSource, S: SlidingWindowStore> SlidingWindowRateLimiter<T, S> {
257 #[inline]
263 pub fn with_time_source_and_store(
264 max_requests: u32,
265 window_size: Duration,
266 time_source: T,
267 store: S,
268 ) -> Self {
269 let window_size_secs = window_size.as_secs();
270 assert!(
271 window_size_secs > 0,
272 "window_size must be at least 1 second"
273 );
274 Self {
275 store,
276 max_requests,
277 window_size_secs,
278 eviction_grace_secs: window_size_secs.saturating_mul(2),
279 time_source,
280 }
281 }
282
283 #[inline]
288 pub fn set_eviction(&mut self, eviction: Duration) {
289 self.eviction_grace_secs = eviction.as_secs();
290 }
291
292 #[inline(always)]
294 pub fn max_requests(&self) -> u32 {
295 self.max_requests
296 }
297
298 #[inline(always)]
300 pub fn window_size_secs(&self) -> u64 {
301 self.window_size_secs
302 }
303
304 #[inline(always)]
308 pub fn eviction_grace_secs(&self) -> u64 {
309 self.eviction_grace_secs
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::super::test_utils::MockTimeSource;
    use super::*;

    /// With a limit of 3, three checks pass and the fourth is rejected.
    #[test]
    fn sliding_window_allows_within_limit() {
        let limiter = SlidingWindowRateLimiter::new(3, Duration::from_secs(10));
        let key = 7;

        for _ in 0..3 {
            assert!(limiter.check(key));
        }
        assert!(!limiter.check(key));
    }

    /// Walks a key through a full window, a rotation and a second rotation.
    #[test]
    fn it_tests_window_sliding() {
        let time = MockTimeSource::new(1000);
        let limiter =
            SlidingWindowRateLimiter::with_time_source(10, Duration::from_secs(10), time.clone());

        // Fill the first window completely.
        for request in 1..=10 {
            assert!(limiter.check(1), "Request {} should pass", request);
        }
        assert!(!limiter.check(1), "Request 11 should be denied");

        // Still inside the same window: the current count alone hits the limit.
        time.advance(5);
        assert!(!limiter.check(1), "Should be denied at 50% of window");

        // One second into the next window: 10 * 0.9 = 9 < 10.
        time.advance(6);
        assert!(limiter.check(1), "Should allow in new window");

        // Next window again: only the single request above carries over.
        time.advance(10);
        for request in 1..=10 {
            assert!(
                limiter.check(1),
                "Request {} should pass after reset",
                request
            );
        }
        assert!(!limiter.check(1), "Request 11 should be denied");
    }

    /// Checks the weighted contribution of the previous window as time moves
    /// through the current one.
    #[test]
    fn it_tests_window_transition() {
        let time = MockTimeSource::new(2000);
        let limiter =
            SlidingWindowRateLimiter::with_time_source(3, Duration::from_secs(10), time.clone());

        for _ in 0..3 {
            assert!(limiter.check(1));
        }
        assert!(!limiter.check(1), "4th request should be denied");

        time.advance(5);
        assert!(!limiter.check(1), "Should be denied at 50%");

        // t = 2011: new window, 10% elapsed, effective = 3 * 0.9 = 2.7.
        time.advance(6);
        assert!(limiter.check(1), "Should allow 1st request in new window");
        assert!(
            !limiter.check(1),
            "Should be denied - effective = 3*0.9 + 1 = 3.7"
        );

        // t = 2013: 30% elapsed.
        time.advance(2);
        assert!(
            !limiter.check(1),
            "Still denied - effective = 3*0.7 + 1 = 3.1"
        );

        // t = 2017: 70% elapsed.
        time.advance(4);
        assert!(
            limiter.check(1),
            "Should allow - effective = 3*0.3 + 1 = 1.9"
        );
        assert!(
            limiter.check(1),
            "Should allow - effective = 3*0.3 + 2 = 2.9"
        );
    }

    /// Exhausting one key must not affect another key.
    #[test]
    fn sliding_window_isolated_per_key() {
        let limiter = SlidingWindowRateLimiter::new(1, Duration::from_secs(5));

        let first_key = 1;
        assert!(limiter.check(first_key));
        assert!(!limiter.check(first_key));

        let second_key = 2;
        assert!(limiter.check(second_key));
    }

    /// A limiter built with `with_store` routes every check through the
    /// supplied store.
    #[test]
    fn sliding_window_with_custom_store_delegates_to_store() {
        use crate::rate_limiter::store::{SlidingWindowParams, SlidingWindowStore};
        use std::sync::Arc;
        use std::sync::atomic::{AtomicU32, Ordering::Relaxed};

        /// Wraps the in-memory store and counts how often it is invoked.
        struct CountingStore {
            inner: InMemorySlidingWindowStore,
            calls: Arc<AtomicU32>,
        }

        impl SlidingWindowStore for CountingStore {
            fn check_and_count(&self, params: SlidingWindowParams) -> bool {
                self.calls.fetch_add(1, Relaxed);
                self.inner.check_and_count(params)
            }
        }

        let calls = Arc::new(AtomicU32::new(0));
        let store = CountingStore {
            inner: InMemorySlidingWindowStore::new(),
            calls: Arc::clone(&calls),
        };
        let limiter = SlidingWindowRateLimiter::with_store(3, Duration::from_secs(10), store);

        assert!(limiter.check(1));
        assert_eq!(calls.load(Relaxed), 1);
    }

    /// A window shorter than one second is rejected at construction time.
    #[test]
    #[should_panic(expected = "window_size must be at least 1 second")]
    fn sliding_window_panics_on_zero_window_size() {
        let _ = SlidingWindowRateLimiter::new(10, Duration::ZERO);
    }

    /// Eight threads issue 200 checks each against one key; the number of
    /// admitted requests must stay within the limit (plus a small slack).
    #[test]
    fn sliding_window_is_thread_safe() {
        use std::sync::Arc;
        use std::thread;

        let limiter = Arc::new(SlidingWindowRateLimiter::new(1000, Duration::from_secs(10)));
        let key = 123;

        let handles: Vec<_> = (0..8)
            .map(|_| {
                let limiter = Arc::clone(&limiter);
                thread::spawn(move || {
                    (0..200)
                        .map(|_| u32::from(limiter.check(key)))
                        .sum::<u32>()
                })
            })
            .collect();

        let total: u32 = handles.into_iter().map(|h| h.join().unwrap()).sum();

        assert!(total <= 1000 + 8);
    }
}