tokio_rate_limit/algorithm/cached_token_bucket.rs

use crate::algorithm::Algorithm;
use crate::error::Result;
use crate::limiter::RateLimitDecision;
use async_trait::async_trait;
use flurry::HashMap as FlurryHashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::time::Instant;

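// Tokens are kept in fixed-point form: 1 whole token == SCALE scaled units, so
// fractional refill amounts can accumulate in a u64. Capacity and refill rate are
// clamped so that multiplying by SCALE cannot overflow.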
const SCALE: u64 = 1000;
const MAX_BURST: u64 = u64::MAX / (2 * SCALE);
const MAX_RATE_PER_SEC: u64 = u64::MAX / (2 * SCALE);

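/// Per-key token-bucket state, updated lock-free with atomics so concurrent
/// requests against the same key never take a lock.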
struct AtomicTokenState {
    tokens: AtomicU64,
    last_refill_nanos: AtomicU64,
    last_access_nanos: AtomicU64,
    access_count: AtomicU64,
}

impl AtomicTokenState {
    fn new(capacity: u64, now_nanos: u64) -> Self {
        Self {
            tokens: AtomicU64::new(capacity.saturating_mul(SCALE)),
            last_refill_nanos: AtomicU64::new(now_nanos),
            last_access_nanos: AtomicU64::new(now_nanos),
            access_count: AtomicU64::new(0),
        }
    }

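    /// Tries to consume `cost` tokens, lazily refilling based on the time elapsed
    /// since the last refill. Returns `(permitted, remaining_whole_tokens)`.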
    fn try_consume(
        &self,
        capacity: u64,
        refill_rate_per_second: u64,
        now_nanos: u64,
        cost: u64,
    ) -> (bool, u64) {
        self.last_access_nanos.store(now_nanos, Ordering::Relaxed);
        self.access_count.fetch_add(1, Ordering::Relaxed);

        let scaled_capacity = capacity.saturating_mul(SCALE);
        let token_cost = cost.saturating_mul(SCALE);

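        // CAS loop: compute the refilled balance, then publish the new balance with
        // compare_exchange_weak and retry on contention. The refill timestamp is
        // updated with a second, best-effort CAS once the balance update succeeds.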
        loop {
            let current_tokens = self.tokens.load(Ordering::Relaxed);
            let last_refill = self.last_refill_nanos.load(Ordering::Relaxed);

            let elapsed_nanos = now_nanos.saturating_sub(last_refill);
            let elapsed_secs = elapsed_nanos as f64 / 1_000_000_000.0;
            let tokens_per_sec_scaled = refill_rate_per_second.saturating_mul(SCALE);
            let new_tokens_to_add = (elapsed_secs * tokens_per_sec_scaled as f64) as u64;

            let updated_tokens = current_tokens
                .saturating_add(new_tokens_to_add)
                .min(scaled_capacity);

            if updated_tokens >= token_cost {
                let new_tokens = updated_tokens.saturating_sub(token_cost);
                let new_time = if new_tokens_to_add > 0 {
                    now_nanos
                } else {
                    last_refill
                };

                match self.tokens.compare_exchange_weak(
                    current_tokens,
                    new_tokens,
                    Ordering::AcqRel,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        if new_tokens_to_add > 0 {
                            let _ = self.last_refill_nanos.compare_exchange_weak(
                                last_refill,
                                new_time,
                                Ordering::AcqRel,
                                Ordering::Relaxed,
                            );
                        }
                        return (true, new_tokens / SCALE);
                    }
                    Err(_) => continue,
                }
            } else {
                let new_time = if new_tokens_to_add > 0 {
                    now_nanos
                } else {
                    last_refill
                };

                match self.tokens.compare_exchange_weak(
                    current_tokens,
                    updated_tokens,
                    Ordering::AcqRel,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        if new_tokens_to_add > 0 {
                            let _ = self.last_refill_nanos.compare_exchange_weak(
                                last_refill,
                                new_time,
                                Ordering::AcqRel,
                                Ordering::Relaxed,
                            );
                        }
                        return (false, updated_tokens / SCALE);
                    }
                    Err(_) => continue,
                }
            }
        }
    }

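    /// A key counts as "hot" once it has been checked more than 10 times; only hot
    /// keys are promoted into the thread-local cache.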
    #[inline]
    fn is_hot_key(&self) -> bool {
        self.access_count.load(Ordering::Relaxed) > 10
    }
}

use std::cell::RefCell;

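/// Single-slot, per-thread cache for the most recently promoted hot key, letting
/// repeated checks on that key skip the shared concurrent map entirely.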
struct CacheEntry {
    key: Option<String>,
    state: Option<Arc<AtomicTokenState>>,
    hits: u64,
    misses: u64,
}

impl CacheEntry {
    fn new() -> Self {
        Self {
            key: None,
            state: None,
            hits: 0,
            misses: 0,
        }
    }

    #[inline]
    fn get(&mut self, key: &str) -> Option<Arc<AtomicTokenState>> {
        if let Some(cached_key) = &self.key {
            if cached_key == key {
                self.hits += 1;
                return self.state.clone();
            }
        }
        self.misses += 1;
        None
    }

    #[inline]
    fn set(&mut self, key: String, state: Arc<AtomicTokenState>) {
        self.key = Some(key);
        self.state = Some(state);
    }

    #[allow(dead_code)]
    fn hit_rate(&self) -> f64 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }
}

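// One CacheEntry per worker thread; no synchronization is needed to read or update it.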
thread_local! {
    static CACHE: RefCell<CacheEntry> = RefCell::new(CacheEntry::new());
}

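/// Token-bucket rate limiter backed by a lock-free `flurry` map of per-key state,
/// with a thread-local fast path for hot keys.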
pub struct CachedTokenBucket {
    capacity: u64,
    refill_rate_per_second: u64,
    reference_instant: Instant,
    idle_ttl: Option<Duration>,
    tokens: Arc<FlurryHashMap<String, Arc<AtomicTokenState>>>,
}

impl CachedTokenBucket {
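    /// Creates a bucket with the given burst capacity and refill rate, clamping both
    /// so that the fixed-point arithmetic cannot overflow.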
    pub fn new(capacity: u64, refill_rate_per_second: u64) -> Self {
        let safe_capacity = capacity.min(MAX_BURST);
        let safe_rate = refill_rate_per_second.min(MAX_RATE_PER_SEC);

        Self {
            capacity: safe_capacity,
            refill_rate_per_second: safe_rate,
            reference_instant: Instant::now(),
            idle_ttl: None,
            tokens: Arc::new(FlurryHashMap::new()),
        }
    }

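    /// Like `new`, but entries idle for longer than `idle_ttl` become eligible for
    /// removal during opportunistic cleanup.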
    pub fn with_ttl(capacity: u64, refill_rate_per_second: u64, idle_ttl: Duration) -> Self {
        let mut bucket = Self::new(capacity, refill_rate_per_second);
        bucket.idle_ttl = Some(idle_ttl);
        bucket
    }

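    /// Monotonic nanoseconds since the limiter was created, measured with
    /// `tokio::time::Instant` so paused-time tests can advance the clock.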
    #[inline]
    fn now_nanos(&self) -> u64 {
        self.reference_instant.elapsed().as_nanos() as u64
    }

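    /// Looks up the state for `key`: thread-local cache first, then the shared map,
    /// inserting a fresh bucket on a miss. Hot keys are promoted into the
    /// thread-local cache on the way out.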
    #[inline]
    fn get_or_create_state_cached(
        &self,
        key: &str,
        guard: &flurry::Guard<'_>,
        now_nanos: u64,
    ) -> Arc<AtomicTokenState> {
        if let Some(state) = CACHE.with(|cache| cache.borrow_mut().get(key)) {
            return state;
        }

        let state = if let Some(state) = self.tokens.get(key, guard) {
            state.clone()
        } else {
            let key_string = key.to_string();
            let new_state = Arc::new(AtomicTokenState::new(self.capacity, now_nanos));

            match self
                .tokens
                .try_insert(key_string.clone(), new_state.clone(), guard)
            {
                Ok(_) => new_state,
                Err(current) => current.current.clone(),
            }
        };

        if state.is_hot_key() {
            CACHE.with(|cache| cache.borrow_mut().set(key.to_string(), state.clone()));
        }

        state
    }

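    /// Removes entries whose last access is older than the idle TTL. Note that the
    /// thread-local hot-key cache is not invalidated here, so a thread that promoted
    /// a removed key keeps using its old state until that cache slot is overwritten.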
    fn cleanup_idle(&self, now_nanos: u64) {
        if let Some(ttl) = self.idle_ttl {
            let ttl_nanos = ttl.as_nanos() as u64;
            let guard = self.tokens.guard();
            let keys_to_remove: Vec<String> = self
                .tokens
                .iter(&guard)
                .filter_map(|(key, state)| {
                    let last_access = state.last_access_nanos.load(Ordering::Relaxed);
                    let age = now_nanos.saturating_sub(last_access);
                    if age >= ttl_nanos {
                        Some(key.clone())
                    } else {
                        None
                    }
                })
                .collect();

            for key in keys_to_remove {
                self.tokens.remove(&key, &guard);
            }
        }
    }
}

impl super::private::Sealed for CachedTokenBucket {}

#[async_trait]
impl Algorithm for CachedTokenBucket {
    async fn check(&self, key: &str) -> Result<RateLimitDecision> {
        let now = self.now_nanos();

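        // Opportunistic cleanup: runs only when the nanosecond timestamp happens to
        // be divisible by 100, i.e. on a small fraction of calls.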
        if self.idle_ttl.is_some() && (now % 100) == 0 {
            self.cleanup_idle(now);
        }

        let guard = self.tokens.guard();
        let state = self.get_or_create_state_cached(key, &guard, now);

        let (permitted, remaining) =
            state.try_consume(self.capacity, self.refill_rate_per_second, now, 1);

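        // On rejection, estimate how long until enough tokens accrue; with a zero
        // refill rate, fall back to a one-second hint.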
        let retry_after = if !permitted {
            let tokens_needed = 1u64.saturating_sub(remaining);
            let seconds_to_wait = if self.refill_rate_per_second > 0 {
                (tokens_needed as f64 / self.refill_rate_per_second as f64).ceil()
            } else {
                1.0
            };
            Some(Duration::from_secs_f64(seconds_to_wait.max(0.001)))
        } else {
            None
        };

        let reset = if self.refill_rate_per_second > 0 && remaining < self.capacity {
            let tokens_to_refill = self.capacity.saturating_sub(remaining);
            let seconds_to_full = tokens_to_refill as f64 / self.refill_rate_per_second as f64;
            Some(Duration::from_secs_f64(seconds_to_full.max(0.001)))
        } else if remaining >= self.capacity {
            Some(Duration::from_secs(0))
        } else {
            None
        };

        Ok(RateLimitDecision {
            permitted,
            retry_after,
            remaining: Some(remaining),
            limit: self.capacity,
            reset,
        })
    }

    async fn check_with_cost(&self, key: &str, cost: u64) -> Result<RateLimitDecision> {
        let now = self.now_nanos();

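        // Same opportunistic cleanup and decision logic as check(), but consuming
        // `cost` tokens per call.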
        if self.idle_ttl.is_some() && (now % 100) == 0 {
            self.cleanup_idle(now);
        }

        let guard = self.tokens.guard();
        let state = self.get_or_create_state_cached(key, &guard, now);

        let (permitted, remaining) =
            state.try_consume(self.capacity, self.refill_rate_per_second, now, cost);

        let retry_after = if !permitted {
            let tokens_needed = cost.saturating_sub(remaining);
            let seconds_to_wait = if self.refill_rate_per_second > 0 {
                (tokens_needed as f64 / self.refill_rate_per_second as f64).ceil()
            } else {
                1.0
            };
            Some(Duration::from_secs_f64(seconds_to_wait.max(0.001)))
        } else {
            None
        };

        let reset = if self.refill_rate_per_second > 0 && remaining < self.capacity {
            let tokens_to_refill = self.capacity.saturating_sub(remaining);
            let seconds_to_full = tokens_to_refill as f64 / self.refill_rate_per_second as f64;
            Some(Duration::from_secs_f64(seconds_to_full.max(0.001)))
        } else if remaining >= self.capacity {
            Some(Duration::from_secs(0))
        } else {
            None
        };

        Ok(RateLimitDecision {
            permitted,
            retry_after,
            remaining: Some(remaining),
            limit: self.capacity,
            reset,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_cached_token_bucket_basic() {
        let bucket = CachedTokenBucket::new(10, 100);

        for _ in 0..10 {
            let decision = bucket.check("test-key").await.unwrap();
            assert!(decision.permitted);
        }

        let decision = bucket.check("test-key").await.unwrap();
        assert!(!decision.permitted);
    }

    #[tokio::test(start_paused = true)]
    async fn test_cached_token_bucket_refill() {
        let bucket = CachedTokenBucket::new(10, 100);

        for _ in 0..10 {
            bucket.check("test-key").await.unwrap();
        }

        let decision = bucket.check("test-key").await.unwrap();
        assert!(!decision.permitted);

        tokio::time::advance(Duration::from_millis(100)).await;

        for _ in 0..10 {
            let decision = bucket.check("test-key").await.unwrap();
            assert!(decision.permitted);
        }
    }

    #[tokio::test]
    async fn test_cached_token_bucket_hot_keys() {
        let bucket = CachedTokenBucket::new(1000, 1000);

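        // Warm the key past the hot threshold so it gets promoted into the
        // thread-local cache.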
        for _ in 0..20 {
            bucket.check("hot-key").await.unwrap();
        }

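        // These checks should be served through the thread-local fast path.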
        for _ in 0..100 {
            bucket.check("hot-key").await.unwrap();
        }
    }
}