// volga_rate_limiter/rate_limiter/token_bucket.rs
use super::{
    MICROS_PER_SEC, RateLimiter, SystemTimeSource, TimeSource,
    store::{TokenBucketParams, TokenBucketStore},
};
use dashmap::DashMap;
use std::{
    sync::{
        Arc,
        atomic::{AtomicU64, Ordering::*},
    },
    time::Duration,
};

/// State of a single key's bucket.
///
/// Every field is atomic so that a bucket can be refilled, drained and touched through
/// a shared reference by several callers at once.
#[derive(Debug)]
struct Entry {
    /// Tokens currently in the bucket, in fixed-point units (whole tokens × scale).
    available_tokens: AtomicU64,

    /// Timestamp (µs) up to which refill has already been credited to the bucket.
    last_refill_us: AtomicU64,

    /// Timestamp (µs) of the most recent access; compared against the eviction grace.
    last_seen_us: AtomicU64,
}
32
33#[derive(Debug, Clone)]
38pub struct InMemoryTokenBucketStore {
39 storage: Arc<DashMap<u64, Entry>>,
40}
41
42impl InMemoryTokenBucketStore {
43 pub fn new() -> Self {
45 Self {
46 storage: Arc::new(DashMap::new()),
47 }
48 }
49}
50
51impl Default for InMemoryTokenBucketStore {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57impl TokenBucketStore for InMemoryTokenBucketStore {
58 #[inline]
59 fn try_consume(&self, params: TokenBucketParams) -> bool {
60 let TokenBucketParams {
61 key,
62 now_us,
63 capacity_scaled,
64 refill_rate_scaled_per_sec,
65 scale,
66 eviction_grace_us,
67 } = params;
68
69 if let Some(entry) = self.storage.get(&key) {
71 let last_seen = entry.last_seen_us.load(Acquire);
72 if now_us.saturating_sub(last_seen) > eviction_grace_us {
73 drop(entry);
74 self.storage.remove(&key);
75 }
76 }
77
78 let entry = self.storage.entry(key).or_insert_with(|| Entry {
79 available_tokens: AtomicU64::new(capacity_scaled),
80 last_refill_us: AtomicU64::new(now_us),
81 last_seen_us: AtomicU64::new(now_us),
82 });
83
84 entry.last_seen_us.store(now_us, Release);
86
87 Self::refill(
88 entry.value(),
89 now_us,
90 refill_rate_scaled_per_sec,
91 capacity_scaled,
92 );
93 Self::consume(entry.value(), scale)
94 }
95}
96
97impl InMemoryTokenBucketStore {
98 fn refill(entry: &Entry, now_us: u64, refill_rate_scaled_per_sec: u64, capacity_scaled: u64) {
99 if refill_rate_scaled_per_sec == 0 {
100 return;
101 }
102
103 let mut last = entry.last_refill_us.load(Acquire);
105 loop {
106 if now_us <= last {
107 return;
108 }
109
110 match entry
111 .last_refill_us
112 .compare_exchange(last, now_us, AcqRel, Acquire)
113 {
114 Ok(_) => break,
115 Err(next) => last = next,
116 }
117 }
118
119 let elapsed_us = now_us - last;
120 let num = (elapsed_us as u128) * (refill_rate_scaled_per_sec as u128);
122 let add_u128 = num / (MICROS_PER_SEC as u128);
123 let add = u64::try_from(add_u128).unwrap_or(u64::MAX);
124
125 if add == 0 {
126 return;
127 }
128
129 let mut current = entry.available_tokens.load(Relaxed);
130 loop {
131 let updated = current.saturating_add(add).min(capacity_scaled);
132 match entry
133 .available_tokens
134 .compare_exchange(current, updated, AcqRel, Relaxed)
135 {
136 Ok(_) => return,
137 Err(next) => current = next,
138 }
139 }
140 }
141
142 fn consume(entry: &Entry, scale: u64) -> bool {
143 let mut current = entry.available_tokens.load(Relaxed);
144 loop {
145 if current < scale {
146 return false;
147 }
148 let updated = current - scale;
149 match entry
150 .available_tokens
151 .compare_exchange(current, updated, AcqRel, Relaxed)
152 {
153 Ok(_) => return true,
154 Err(next) => current = next,
155 }
156 }
157 }
158}
159
160#[derive(Debug)]
196pub struct TokenBucketRateLimiter<
197 T: TimeSource = SystemTimeSource,
198 S: TokenBucketStore = InMemoryTokenBucketStore,
199> {
200 store: S,
201 capacity: u64,
202 refill_rate_scaled_per_sec: u64,
203 scale: u64,
204 capacity_scaled: u64,
205 eviction_grace_us: u64,
206 time_source: T,
207}
208
209impl<T: TimeSource, S: TokenBucketStore> RateLimiter for TokenBucketRateLimiter<T, S> {
210 #[inline]
215 fn check(&self, key: u64) -> bool {
216 self.store.try_consume(TokenBucketParams {
217 key,
218 now_us: self.time_source.now_micros(),
219 capacity_scaled: self.capacity_scaled,
220 refill_rate_scaled_per_sec: self.refill_rate_scaled_per_sec,
221 scale: self.scale,
222 eviction_grace_us: self.eviction_grace_us,
223 })
224 }
225}
226
227const DEFAULT_SCALE: u64 = MICROS_PER_SEC;
228const DEFAULT_EVICTION: u64 = 60 * MICROS_PER_SEC; impl TokenBucketRateLimiter {
231 #[inline]
251 pub fn new(capacity: u64, refill_rate: f64) -> Self {
252 Self::with_time_source(capacity, refill_rate, SystemTimeSource)
253 }
254}
255
256impl<T: TimeSource> TokenBucketRateLimiter<T> {
257 #[inline]
265 pub fn with_time_source(capacity: u64, refill_rate: f64, time_source: T) -> Self {
266 Self::with_time_source_and_store(
267 capacity,
268 refill_rate,
269 time_source,
270 InMemoryTokenBucketStore::new(),
271 )
272 }
273}
274
275impl<S: TokenBucketStore> TokenBucketRateLimiter<SystemTimeSource, S> {
276 #[inline]
282 pub fn with_store(capacity: u64, refill_rate: f64, store: S) -> Self {
283 Self::with_time_source_and_store(capacity, refill_rate, SystemTimeSource, store)
284 }
285}
286
287impl<T: TimeSource, S: TokenBucketStore> TokenBucketRateLimiter<T, S> {
288 #[inline]
294 pub fn with_time_source_and_store(
295 capacity: u64,
296 refill_rate: f64,
297 time_source: T,
298 store: S,
299 ) -> Self {
300 let scale: u64 = DEFAULT_SCALE;
301
302 assert!(refill_rate.is_finite(), "refill_rate must be finite");
303 assert!(refill_rate >= 0.0, "refill_rate must be >= 0");
304
305 let scaled_f = refill_rate * scale as f64;
306 assert!(scaled_f <= u64::MAX as f64, "refill_rate too large");
307
308 let refill_rate_scaled_per_sec = scaled_f.round() as u64;
309
310 let capacity_scaled = capacity
311 .checked_mul(scale)
312 .expect("capacity * scale overflow");
313
314 Self {
315 store,
316 capacity,
317 refill_rate_scaled_per_sec,
318 scale,
319 capacity_scaled,
320 eviction_grace_us: DEFAULT_EVICTION,
321 time_source,
322 }
323 }
324
325 #[inline]
330 pub fn set_eviction(&mut self, eviction: Duration) {
331 self.eviction_grace_us = eviction.as_micros().try_into().unwrap_or(u64::MAX);
332 }
333
334 #[inline(always)]
336 pub fn capacity(&self) -> u64 {
337 self.capacity
338 }
339
340 #[inline(always)]
342 pub fn refill_rate(&self) -> f64 {
343 self.refill_rate_scaled_per_sec as f64 / self.scale as f64
344 }
345
346 #[inline(always)]
348 pub fn eviction_grace_secs(&self) -> u64 {
349 self.eviction_grace_us / MICROS_PER_SEC
350 }
351}
352
#[cfg(test)]
mod tests {
    use super::super::test_utils::MockTimeSource;
    use super::*;

    /// Runs the fully explicit constructor with the system clock and a fresh in-memory
    /// store. The validation tests below only care whether construction panics.
    fn construct(capacity: u64, refill_rate: f64) {
        let _ = TokenBucketRateLimiter::with_time_source_and_store(
            capacity,
            refill_rate,
            SystemTimeSource,
            InMemoryTokenBucketStore::new(),
        );
    }

    #[test]
    fn token_bucket_allows_burst_up_to_capacity() {
        let limiter = TokenBucketRateLimiter::new(3, 1.0);
        let key = 99;

        // The whole burst is available straight away...
        for _ in 0..3 {
            assert!(limiter.check(key));
        }
        // ...but nothing beyond it.
        assert!(!limiter.check(key));
    }

    #[test]
    fn token_bucket_refills_over_time() {
        let clock = MockTimeSource::new(100);
        let limiter = TokenBucketRateLimiter::with_time_source(2, 1.0, clock.clone());
        let key = 7;

        for _ in 0..2 {
            assert!(limiter.check(key));
        }
        assert!(!limiter.check(key));

        // At 1 token/s, each `advance(1)` step is expected to restore one token.
        clock.advance(1);
        assert!(limiter.check(key));
        assert!(!limiter.check(key));

        clock.advance(1);
        assert!(limiter.check(key));
    }

    #[test]
    fn token_bucket_isolated_per_key() {
        let limiter = TokenBucketRateLimiter::new(1, 1.0);
        let (first, second) = (1, 2);

        assert!(limiter.check(first));
        assert!(!limiter.check(first));
        // Draining `first` must leave `second` untouched.
        assert!(limiter.check(second));
    }

    #[test]
    fn token_bucket_with_custom_store_delegates_to_store() {
        use crate::rate_limiter::store::{TokenBucketParams, TokenBucketStore};
        use std::sync::Arc;
        use std::sync::atomic::{AtomicU32, Ordering::Relaxed};

        /// Forwards to the in-memory store while counting calls.
        struct CountingStore {
            inner: InMemoryTokenBucketStore,
            calls: Arc<AtomicU32>,
        }

        impl TokenBucketStore for CountingStore {
            fn try_consume(&self, params: TokenBucketParams) -> bool {
                self.calls.fetch_add(1, Relaxed);
                self.inner.try_consume(params)
            }
        }

        let counter = Arc::new(AtomicU32::new(0));
        let limiter = TokenBucketRateLimiter::with_store(
            3,
            1.0,
            CountingStore {
                inner: InMemoryTokenBucketStore::new(),
                calls: Arc::clone(&counter),
            },
        );

        assert!(limiter.check(99));
        assert_eq!(counter.load(Relaxed), 1);
    }

    #[test]
    fn token_bucket_zero_refill_rate_is_valid() {
        let limiter = TokenBucketRateLimiter::new(2, 0.0);

        for _ in 0..2 {
            assert!(limiter.check(1));
        }
        assert!(!limiter.check(1));
    }

    #[test]
    fn token_bucket_tiny_refill_rate_rounds_to_zero_scaled() {
        let limiter = TokenBucketRateLimiter::new(1, 1e-10);

        assert!(limiter.check(1));
        assert!(!limiter.check(1));
    }

    #[test]
    #[should_panic(expected = "capacity * scale overflow")]
    fn panics_when_capacity_scaled_overflows() {
        let scale = 1_000_000u64;
        // Smallest capacity whose scaled value no longer fits in a u64.
        construct((u64::MAX / scale) + 1, 1.0);
    }

    #[test]
    #[should_panic(expected = "refill_rate must be finite")]
    fn panics_when_refill_rate_is_nan() {
        construct(1, f64::NAN);
    }

    #[test]
    #[should_panic(expected = "refill_rate must be finite")]
    fn panics_when_refill_rate_is_infinite() {
        construct(1, f64::INFINITY);
    }

    #[test]
    #[should_panic(expected = "refill_rate must be >= 0")]
    fn panics_when_refill_rate_is_negative() {
        construct(1, -0.1);
    }

    #[test]
    #[should_panic(expected = "refill_rate too large")]
    fn panics_when_refill_rate_scaled_exceeds_u64_max() {
        construct(1, 1e30);
    }
}