// tokio_rate_limit/algorithm/probabilistic_token_bucket.rs

use crate::algorithm::Algorithm;
25use crate::error::Result;
26use crate::limiter::RateLimitDecision;
27use async_trait::async_trait;
28use flurry::HashMap as FlurryHashMap;
29use std::sync::atomic::{AtomicU64, Ordering};
30use std::sync::Arc;
31use std::time::Duration;
32use tokio::time::Instant;
33
/// Fixed-point scaling factor: token balances are stored as `tokens * SCALE`
/// so fractional refill amounts can be represented in a `u64`.
const SCALE: u64 = 1000;

/// Upper bound on bucket capacity so `capacity * SCALE` stays clear of
/// `u64` overflow (the factor of 2 leaves headroom for intermediates).
const MAX_BURST: u64 = u64::MAX / (2 * SCALE);

/// Upper bound on the refill rate; same overflow reasoning as `MAX_BURST`.
const MAX_RATE_PER_SEC: u64 = u64::MAX / (2 * SCALE);

/// Number of key shards. Must remain a power of two: shard selection masks
/// the hash with `NUM_SHARDS - 1`.
const NUM_SHARDS: usize = 256;
45
thread_local! {
    // Per-thread xorshift64 state, seeded from the wall clock so each thread
    // starts from a different nanosecond-resolution value. A seed of 0 is
    // corrected inside `fast_random` before the first step.
    static RNG_STATE: std::cell::Cell<u64> = std::cell::Cell::new(
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos() as u64
    );
}
56
57#[inline]
60fn fast_random() -> u64 {
61 RNG_STATE.with(|state| {
62 let mut x = state.get();
63 if x == 0 {
64 x = 1;
65 }
66 x ^= x << 13;
67 x ^= x >> 7;
68 x ^= x << 17;
69 state.set(x);
70 x
71 })
72}
73
/// Per-key bucket state, updated lock-free with atomics.
///
/// Token counts are stored in fixed point, pre-multiplied by
/// `SCALE * sample_rate` (see `AtomicProbabilisticState::new`).
struct AtomicProbabilisticState {
    // Current token balance, scaled by `SCALE * sample_rate`.
    tokens: AtomicU64,

    // Timestamp (nanos since the bucket's reference instant) of the last
    // refill that was applied to `tokens`.
    last_refill_nanos: AtomicU64,

    // Timestamp of the most recent access; read by TTL cleanup.
    last_access_nanos: AtomicU64,
}
86
impl AtomicProbabilisticState {
    /// Creates a full bucket.
    ///
    /// The initial balance is `capacity * SCALE * sample_rate`: the extra
    /// `sample_rate` factor keeps the units consistent with the sampled
    /// consume path, which charges `cost * sample_rate` per sampled hit.
    fn new(capacity: u64, sample_rate: u32, now_nanos: u64) -> Self {
        Self {
            tokens: AtomicU64::new(capacity.saturating_mul(SCALE).saturating_mul(sample_rate as u64)),
            last_refill_nanos: AtomicU64::new(now_nanos),
            last_access_nanos: AtomicU64::new(now_nanos),
        }
    }

    /// Attempts to consume `cost` tokens, probabilistically.
    ///
    /// With probability `1 / sample_rate` the call is "sampled": refill for
    /// elapsed time is applied and `cost * SCALE * sample_rate` scaled
    /// tokens are consumed via a CAS loop. Unsampled calls only read the
    /// current balance and never write `tokens`.
    ///
    /// Returns `(permitted, remaining_whole_tokens)`.
    ///
    /// NOTE(review): `tokens` and `last_refill_nanos` are two separate
    /// atomics updated by two separate CASes, so a concurrent sampler can
    /// observe a refreshed balance alongside a stale refill timestamp and
    /// apply some refill twice. This looks like an accepted approximation
    /// of the probabilistic design — confirm if exact accounting matters.
    fn try_consume_probabilistic(
        &self,
        capacity: u64,
        refill_rate_per_second: u64,
        now_nanos: u64,
        cost: u64,
        sample_rate: u32,
    ) -> (bool, u64) {
        // Record activity for TTL-based eviction.
        self.last_access_nanos.store(now_nanos, Ordering::Relaxed);

        // Sample roughly one in `sample_rate` calls.
        let should_sample = (fast_random() % sample_rate as u64) == 0;

        if should_sample {
            let scaled_capacity = capacity.saturating_mul(SCALE).saturating_mul(sample_rate as u64);
            let token_cost = cost.saturating_mul(SCALE).saturating_mul(sample_rate as u64);

            loop {
                let current_tokens = self.tokens.load(Ordering::Relaxed);
                let last_refill = self.last_refill_nanos.load(Ordering::Relaxed);

                // Refill owed since the last applied refill, capped at capacity.
                let elapsed_nanos = now_nanos.saturating_sub(last_refill);
                let elapsed_secs = elapsed_nanos as f64 / 1_000_000_000.0;
                let tokens_per_sec_scaled = refill_rate_per_second
                    .saturating_mul(SCALE)
                    .saturating_mul(sample_rate as u64);
                let new_tokens_to_add = (elapsed_secs * tokens_per_sec_scaled as f64) as u64;

                let updated_tokens = current_tokens
                    .saturating_add(new_tokens_to_add)
                    .min(scaled_capacity);

                if updated_tokens >= token_cost {
                    // Enough tokens: publish balance minus cost.
                    let new_tokens = updated_tokens.saturating_sub(token_cost);
                    // Only advance the refill clock when refill was applied,
                    // so sub-token elapsed time keeps accruing.
                    let new_time = if new_tokens_to_add > 0 {
                        now_nanos
                    } else {
                        last_refill
                    };

                    match self.tokens.compare_exchange_weak(
                        current_tokens,
                        new_tokens,
                        Ordering::AcqRel,
                        Ordering::Relaxed,
                    ) {
                        Ok(_) => {
                            if new_tokens_to_add > 0 {
                                // Best effort: losing this CAS means another
                                // sampler already advanced the clock.
                                let _ = self.last_refill_nanos.compare_exchange_weak(
                                    last_refill,
                                    new_time,
                                    Ordering::AcqRel,
                                    Ordering::Relaxed,
                                );
                            }
                            // Report remaining in whole (unscaled) tokens.
                            return (true, new_tokens / (SCALE * sample_rate as u64));
                        }
                        // Lost the race (or spurious weak failure): retry.
                        Err(_) => continue,
                    }
                } else {
                    // Not enough tokens: still publish the refill so the
                    // balance stays current, then deny.
                    let new_time = if new_tokens_to_add > 0 {
                        now_nanos
                    } else {
                        last_refill
                    };

                    match self.tokens.compare_exchange_weak(
                        current_tokens,
                        updated_tokens,
                        Ordering::AcqRel,
                        Ordering::Relaxed,
                    ) {
                        Ok(_) => {
                            if new_tokens_to_add > 0 {
                                let _ = self.last_refill_nanos.compare_exchange_weak(
                                    last_refill,
                                    new_time,
                                    Ordering::AcqRel,
                                    Ordering::Relaxed,
                                );
                            }
                            return (false, updated_tokens / (SCALE * sample_rate as u64));
                        }
                        Err(_) => continue,
                    }
                }
            }
        } else {
            // Fast path: read-only estimate, no writes to `tokens`.
            let current_tokens = self.tokens.load(Ordering::Relaxed);
            let token_cost = cost.saturating_mul(SCALE).saturating_mul(sample_rate as u64);

            let permitted = current_tokens >= token_cost;
            let remaining = current_tokens / (SCALE * sample_rate as u64);

            (permitted, remaining)
        }
    }
}
207
/// Probabilistic token-bucket rate limiter.
///
/// Only about one in `sample_rate` requests takes the CAS-based consume
/// path; each sampled request is charged `sample_rate` times the cost to
/// compensate, so the long-run average matches an exact bucket while
/// atomic contention is reduced. Keys are spread across `NUM_SHARDS`
/// concurrent hash maps.
pub struct ProbabilisticTokenBucket {
    // Maximum burst size in whole tokens (clamped to `MAX_BURST`).
    capacity: u64,
    // Steady-state refill rate in tokens/second (clamped to `MAX_RATE_PER_SEC`).
    refill_rate_per_second: u64,
    // Zero point for the monotonic timestamps produced by `now_nanos`.
    reference_instant: Instant,
    // If set, entries idle at least this long become eligible for eviction.
    idle_ttl: Option<Duration>,
    // Sharded key -> state maps to reduce map-level contention.
    shards: Vec<Arc<FlurryHashMap<String, Arc<AtomicProbabilisticState>>>>,

    // Sample one in this many requests (must be >= 1).
    sample_rate: u32,
}
254
255impl ProbabilisticTokenBucket {
256 #[inline]
258 fn get_shard_index(key: &str) -> usize {
259 let mut hash: u64 = 0xcbf29ce484222325;
260 for byte in key.bytes() {
261 hash ^= byte as u64;
262 hash = hash.wrapping_mul(0x100000001b3);
263 }
264 (hash as usize) & (NUM_SHARDS - 1)
265 }
266
267 #[inline]
268 fn get_shard(&self, key: &str) -> &Arc<FlurryHashMap<String, Arc<AtomicProbabilisticState>>> {
269 let index = Self::get_shard_index(key);
270 &self.shards[index]
271 }
272
273 pub fn new(capacity: u64, refill_rate_per_second: u64, sample_rate: u32) -> Self {
297 assert!(sample_rate >= 1, "Sample rate must be at least 1");
298
299 let safe_capacity = capacity.min(MAX_BURST);
300 let safe_rate = refill_rate_per_second.min(MAX_RATE_PER_SEC);
301
302 let shards = (0..NUM_SHARDS)
303 .map(|_| Arc::new(FlurryHashMap::new()))
304 .collect();
305
306 Self {
307 capacity: safe_capacity,
308 refill_rate_per_second: safe_rate,
309 reference_instant: Instant::now(),
310 idle_ttl: None,
311 shards,
312 sample_rate,
313 }
314 }
315
316 pub fn with_ttl(
318 capacity: u64,
319 refill_rate_per_second: u64,
320 sample_rate: u32,
321 idle_ttl: Duration,
322 ) -> Self {
323 let mut bucket = Self::new(capacity, refill_rate_per_second, sample_rate);
324 bucket.idle_ttl = Some(idle_ttl);
325 bucket
326 }
327
328 #[inline]
329 fn now_nanos(&self) -> u64 {
330 self.reference_instant.elapsed().as_nanos() as u64
331 }
332
333 fn cleanup_idle(&self, now_nanos: u64) {
334 if let Some(ttl) = self.idle_ttl {
335 let ttl_nanos = ttl.as_nanos() as u64;
336
337 for shard in &self.shards {
338 let guard = shard.guard();
339 let keys_to_remove: Vec<String> = shard
340 .iter(&guard)
341 .filter_map(|(key, state)| {
342 let last_access = state.last_access_nanos.load(Ordering::Relaxed);
343 let age = now_nanos.saturating_sub(last_access);
344 if age >= ttl_nanos {
345 Some(key.clone())
346 } else {
347 None
348 }
349 })
350 .collect();
351
352 for key in keys_to_remove {
353 shard.remove(&key, &guard);
354 }
355 }
356 }
357 }
358
359 pub fn sample_rate(&self) -> u32 {
361 self.sample_rate
362 }
363
364 #[cfg(test)]
366 fn len(&self) -> usize {
367 self.shards.iter().map(|shard| shard.len()).sum()
368 }
369}
370
// Seals the `Algorithm` trait for this type: downstream crates cannot add
// their own implementations because `private::Sealed` is unreachable there.
impl super::private::Sealed for ProbabilisticTokenBucket {}
372
373#[async_trait]
374impl Algorithm for ProbabilisticTokenBucket {
375 async fn check(&self, key: &str) -> Result<RateLimitDecision> {
376 let now = self.now_nanos();
377
378 if self.idle_ttl.is_some() && (fast_random() % (self.sample_rate as u64 * 100)) == 0 {
380 self.cleanup_idle(now);
381 }
382
383 let shard = self.get_shard(key);
384 let guard = shard.guard();
385 let state = match shard.get(key, &guard) {
386 Some(state) => state.clone(),
387 None => {
388 let new_state = Arc::new(AtomicProbabilisticState::new(
389 self.capacity,
390 self.sample_rate,
391 now,
392 ));
393 let key_string = key.to_string();
394 match shard.try_insert(key_string, new_state.clone(), &guard) {
395 Ok(_) => new_state,
396 Err(current) => current.current.clone(),
397 }
398 }
399 };
400
401 let (permitted, remaining) = state.try_consume_probabilistic(
402 self.capacity,
403 self.refill_rate_per_second,
404 now,
405 1,
406 self.sample_rate,
407 );
408
409 let retry_after = if !permitted {
410 let tokens_needed = 1u64.saturating_sub(remaining);
411 let seconds_to_wait = if self.refill_rate_per_second > 0 {
412 (tokens_needed as f64 / self.refill_rate_per_second as f64).ceil()
413 } else {
414 1.0
415 };
416 Some(Duration::from_secs_f64(seconds_to_wait.max(0.001)))
417 } else {
418 None
419 };
420
421 let reset = if self.refill_rate_per_second > 0 && remaining < self.capacity {
422 let tokens_to_refill = self.capacity.saturating_sub(remaining);
423 let seconds_to_full = tokens_to_refill as f64 / self.refill_rate_per_second as f64;
424 Some(Duration::from_secs_f64(seconds_to_full.max(0.001)))
425 } else if remaining >= self.capacity {
426 Some(Duration::from_secs(0))
427 } else {
428 None
429 };
430
431 Ok(RateLimitDecision {
432 permitted,
433 retry_after,
434 remaining: Some(remaining),
435 limit: self.capacity,
436 reset,
437 })
438 }
439
440 async fn check_with_cost(&self, key: &str, cost: u64) -> Result<RateLimitDecision> {
441 let now = self.now_nanos();
442
443 if self.idle_ttl.is_some() && (fast_random() % (self.sample_rate as u64 * 100)) == 0 {
444 self.cleanup_idle(now);
445 }
446
447 let shard = self.get_shard(key);
448 let guard = shard.guard();
449 let state = match shard.get(key, &guard) {
450 Some(state) => state.clone(),
451 None => {
452 let new_state = Arc::new(AtomicProbabilisticState::new(
453 self.capacity,
454 self.sample_rate,
455 now,
456 ));
457 let key_string = key.to_string();
458 match shard.try_insert(key_string, new_state.clone(), &guard) {
459 Ok(_) => new_state,
460 Err(current) => current.current.clone(),
461 }
462 }
463 };
464
465 let (permitted, remaining) = state.try_consume_probabilistic(
466 self.capacity,
467 self.refill_rate_per_second,
468 now,
469 cost,
470 self.sample_rate,
471 );
472
473 let retry_after = if !permitted {
474 let tokens_needed = cost.saturating_sub(remaining);
475 let seconds_to_wait = if self.refill_rate_per_second > 0 {
476 (tokens_needed as f64 / self.refill_rate_per_second as f64).ceil()
477 } else {
478 1.0
479 };
480 Some(Duration::from_secs_f64(seconds_to_wait.max(0.001)))
481 } else {
482 None
483 };
484
485 let reset = if self.refill_rate_per_second > 0 && remaining < self.capacity {
486 let tokens_to_refill = self.capacity.saturating_sub(remaining);
487 let seconds_to_full = tokens_to_refill as f64 / self.refill_rate_per_second as f64;
488 Some(Duration::from_secs_f64(seconds_to_full.max(0.001)))
489 } else if remaining >= self.capacity {
490 Some(Duration::from_secs(0))
491 } else {
492 None
493 };
494
495 Ok(RateLimitDecision {
496 permitted,
497 retry_after,
498 remaining: Some(remaining),
499 limit: self.capacity,
500 reset,
501 })
502 }
503}
504
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_basic_functionality() {
        // sample_rate = 1 makes the limiter deterministic: every request
        // is sampled and actually consumes tokens.
        let bucket = ProbabilisticTokenBucket::new(10, 100, 1);

        // Drain the full burst capacity.
        for _ in 0..10 {
            assert!(bucket.check("test-key").await.unwrap().permitted);
        }

        // The eleventh request must be rejected.
        assert!(!bucket.check("test-key").await.unwrap().permitted);
    }

    #[tokio::test]
    async fn test_multiple_keys() {
        let bucket = ProbabilisticTokenBucket::new(2, 10, 1);

        // Exhaust key1's two tokens, then confirm it is rejected.
        for _ in 0..2 {
            bucket.check("key1").await.unwrap();
        }
        assert!(!bucket.check("key1").await.unwrap().permitted);

        // key2 has independent state and is still allowed.
        assert!(bucket.check("key2").await.unwrap().permitted);
    }

    #[tokio::test(start_paused = true)]
    async fn test_refill() {
        let bucket = ProbabilisticTokenBucket::new(5, 10, 1);

        // Empty the bucket and confirm denial.
        for _ in 0..5 {
            bucket.check("test-key").await.unwrap();
        }
        assert!(!bucket.check("test-key").await.unwrap().permitted);

        // 100 ms at 10 tokens/s refills exactly one token.
        tokio::time::advance(Duration::from_millis(100)).await;
        assert!(bucket.check("test-key").await.unwrap().permitted);
    }

    #[tokio::test]
    async fn test_probabilistic_sampling() {
        // High sample rate: most calls take the read-only fast path.
        // Smoke test that nothing panics under sampling.
        let bucket = ProbabilisticTokenBucket::new(1_000_000, 1_000_000, 100);

        for i in 0..1000 {
            let key = format!("key-{}", i % 10);
            let _ = bucket.check(&key).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_cost_based() {
        let bucket = ProbabilisticTokenBucket::new(100, 100, 1);

        // First 50-token request fits within the 100-token budget.
        let first = bucket.check_with_cost("test-key", 50).await.unwrap();
        assert!(first.permitted);
        let left = first.remaining.unwrap();
        assert!((40..=50).contains(&left));

        // Second 50 still fits (plus any slight refill).
        assert!(bucket.check_with_cost("test-key", 50).await.unwrap().permitted);

        // Third 50 exceeds the budget.
        assert!(!bucket.check_with_cost("test-key", 50).await.unwrap().permitted);
    }

    #[tokio::test(start_paused = true)]
    async fn test_ttl_eviction() {
        let bucket = ProbabilisticTokenBucket::with_ttl(10, 100, 1, Duration::from_secs(1));

        bucket.check("key1").await.unwrap();
        assert_eq!(bucket.len(), 1);

        tokio::time::advance(Duration::from_secs(2)).await;

        // Cleanup fires with probability 1/(sample_rate * 100) per call;
        // 200 calls make at least one run very likely.
        for _ in 0..200 {
            bucket.check("key2").await.unwrap();
        }

        // key1 is usually evicted; key2 is always present.
        let count = bucket.len();
        assert!((1..=2).contains(&count));
    }
}