tokio_rate_limit/algorithm/zerocopy_token_bucket.rs

use crate::algorithm::Algorithm;
use crate::error::Result;
use crate::limiter::RateLimitDecision;
use async_trait::async_trait;
use flurry::HashMap as FlurryHashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::time::Instant;

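// Tokens are stored in fixed point: every whole token is multiplied by
// `SCALE`, so one stored unit is a millitoken. The `MAX_*` bounds clamp
// configuration so the scaled values stay well below `u64::MAX`.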
const SCALE: u64 = 1000;
const MAX_BURST: u64 = u64::MAX / (2 * SCALE);
const MAX_RATE_PER_SEC: u64 = u64::MAX / (2 * SCALE);

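/// Per-key bucket state. All fields are atomics so a single bucket can be
/// refilled and drained by concurrent callers without taking a lock.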
struct AtomicTokenState {
    /// Current token balance, scaled by `SCALE` (millitokens).
    tokens: AtomicU64,
    /// Nanoseconds since the limiter's reference instant at the last refill.
    last_refill_nanos: AtomicU64,
    /// Nanoseconds since the reference instant at the last access; used for
    /// idle-key cleanup.
    last_access_nanos: AtomicU64,
}

impl AtomicTokenState {
    fn new(capacity: u64, now_nanos: u64) -> Self {
        Self {
            tokens: AtomicU64::new(capacity.saturating_mul(SCALE)),
            last_refill_nanos: AtomicU64::new(now_nanos),
            last_access_nanos: AtomicU64::new(now_nanos),
        }
    }

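    /// Lazily refills the bucket from the elapsed time, then tries to take
    /// `cost` tokens in a lock-free CAS loop. Returns whether the request is
    /// permitted plus the remaining whole tokens after this call.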
    fn try_consume(
        &self,
        capacity: u64,
        refill_rate_per_second: u64,
        now_nanos: u64,
        cost: u64,
    ) -> (bool, u64) {
        self.last_access_nanos.store(now_nanos, Ordering::Relaxed);

        let scaled_capacity = capacity.saturating_mul(SCALE);
        let token_cost = cost.saturating_mul(SCALE);

        loop {
            let current_tokens = self.tokens.load(Ordering::Relaxed);
            let last_refill = self.last_refill_nanos.load(Ordering::Relaxed);

            // Compute how many scaled tokens have accrued since the last refill.
            let elapsed_nanos = now_nanos.saturating_sub(last_refill);
            let elapsed_secs = elapsed_nanos as f64 / 1_000_000_000.0;
            let tokens_per_sec_scaled = refill_rate_per_second.saturating_mul(SCALE);
            let new_tokens_to_add = (elapsed_secs * tokens_per_sec_scaled as f64) as u64;

            let updated_tokens = current_tokens
                .saturating_add(new_tokens_to_add)
                .min(scaled_capacity);

            if updated_tokens >= token_cost {
                // Enough tokens: try to publish the post-consumption balance.
                let new_tokens = updated_tokens.saturating_sub(token_cost);
                let new_time = if new_tokens_to_add > 0 {
                    now_nanos
                } else {
                    last_refill
                };

                match self.tokens.compare_exchange_weak(
                    current_tokens,
                    new_tokens,
                    Ordering::AcqRel,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        if new_tokens_to_add > 0 {
                            let _ = self.last_refill_nanos.compare_exchange_weak(
                                last_refill,
                                new_time,
                                Ordering::AcqRel,
                                Ordering::Relaxed,
                            );
                        }
                        return (true, new_tokens / SCALE);
                    }
                    Err(_) => continue,
                }
            } else {
                // Not enough tokens: still publish the refreshed balance so the
                // refill is not lost, then report the denial.
                let new_time = if new_tokens_to_add > 0 {
                    now_nanos
                } else {
                    last_refill
                };

                match self.tokens.compare_exchange_weak(
                    current_tokens,
                    updated_tokens,
                    Ordering::AcqRel,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        if new_tokens_to_add > 0 {
                            let _ = self.last_refill_nanos.compare_exchange_weak(
                                last_refill,
                                new_time,
                                Ordering::AcqRel,
                                Ordering::Relaxed,
                            );
                        }
                        return (false, updated_tokens / SCALE);
                    }
                    Err(_) => continue,
                }
            }
        }
    }
}

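/// Token-bucket rate limiter that keeps one lock-free `AtomicTokenState` per
/// key inside a concurrent `flurry` map, so repeat checks for an existing key
/// take no locks and perform no additional allocation.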
pub struct ZeroCopyTokenBucket {
    /// Maximum number of tokens a key may hold at once (burst size).
    capacity: u64,
    /// Tokens restored per second.
    refill_rate_per_second: u64,
    /// Fixed reference point; timestamps are nanoseconds elapsed since it.
    reference_instant: Instant,
    /// If set, keys idle for at least this long may be evicted.
    idle_ttl: Option<Duration>,
    /// Per-key bucket state in a concurrent map.
    tokens: Arc<FlurryHashMap<String, Arc<AtomicTokenState>>>,
}

impl ZeroCopyTokenBucket {
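    /// Creates a bucket with the given burst capacity and per-second refill
    /// rate. Both values are clamped (`MAX_BURST`, `MAX_RATE_PER_SEC`) so the
    /// fixed-point arithmetic stays in range.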
    pub fn new(capacity: u64, refill_rate_per_second: u64) -> Self {
        // Clamp configuration to the overflow-safe maximums.
        let safe_capacity = capacity.min(MAX_BURST);
        let safe_rate = refill_rate_per_second.min(MAX_RATE_PER_SEC);

        Self {
            capacity: safe_capacity,
            refill_rate_per_second: safe_rate,
            reference_instant: Instant::now(),
            idle_ttl: None,
            tokens: Arc::new(FlurryHashMap::new()),
        }
    }

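    /// Like [`Self::new`], but keys left untouched for `idle_ttl` become
    /// eligible for removal during opportunistic cleanup.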
    pub fn with_ttl(capacity: u64, refill_rate_per_second: u64, idle_ttl: Duration) -> Self {
        let mut bucket = Self::new(capacity, refill_rate_per_second);
        bucket.idle_ttl = Some(idle_ttl);
        bucket
    }

    #[inline]
    fn now_nanos(&self) -> u64 {
        self.reference_instant.elapsed().as_nanos() as u64
    }

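    /// Returns the state for `key`, creating a full bucket on first sight.
    /// If a concurrent caller wins the insert race, its state is reused so
    /// all callers agree on a single bucket per key.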
    #[inline]
    fn get_or_create_state(
        &self,
        key: &str,
        guard: &flurry::Guard<'_>,
        now_nanos: u64,
    ) -> Arc<AtomicTokenState> {
        // Fast path: the key already has state; no allocation is needed.
        if let Some(state) = self.tokens.get(key, guard) {
            return state.clone();
        }

        // Slow path: first sighting of the key. Allocate a full bucket and
        // race to insert it; if another caller wins, use its state instead.
        let key_string = key.to_string();
        let new_state = Arc::new(AtomicTokenState::new(self.capacity, now_nanos));

        match self.tokens.try_insert(key_string, new_state.clone(), guard) {
            Ok(_) => new_state,
            Err(current) => current.current.clone(),
        }
    }

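    /// Drops per-key state whose last access is older than the configured
    /// idle TTL. Only does work when an `idle_ttl` was set.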
    fn cleanup_idle(&self, now_nanos: u64) {
        if let Some(ttl) = self.idle_ttl {
            let ttl_nanos = ttl.as_nanos() as u64;
            let guard = self.tokens.guard();
            let keys_to_remove: Vec<String> = self
                .tokens
                .iter(&guard)
                .filter_map(|(key, state)| {
                    let last_access = state.last_access_nanos.load(Ordering::Relaxed);
                    let age = now_nanos.saturating_sub(last_access);
                    if age >= ttl_nanos {
                        Some(key.clone())
                    } else {
                        None
                    }
                })
                .collect();

            for key in keys_to_remove {
                self.tokens.remove(&key, &guard);
            }
        }
    }
}

impl super::private::Sealed for ZeroCopyTokenBucket {}

#[async_trait]
impl Algorithm for ZeroCopyTokenBucket {
    async fn check(&self, key: &str) -> Result<RateLimitDecision> {
        let now = self.now_nanos();

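        // Opportunistic cleanup: only the roughly 1-in-100 calls whose
        // nanosecond timestamp is divisible by 100 pay for sweeping idle keys.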
        if self.idle_ttl.is_some() && (now % 100) == 0 {
            self.cleanup_idle(now);
        }

        let guard = self.tokens.guard();
        let state = self.get_or_create_state(key, &guard, now);

        let (permitted, remaining) =
            state.try_consume(self.capacity, self.refill_rate_per_second, now, 1);

        // On denial, estimate how long the caller should wait for one token
        // to become available (clamped to at least 1ms).
        let retry_after = if !permitted {
            let tokens_needed = 1u64.saturating_sub(remaining);
            let seconds_to_wait = if self.refill_rate_per_second > 0 {
                (tokens_needed as f64 / self.refill_rate_per_second as f64).ceil()
            } else {
                1.0
            };
            Some(Duration::from_secs_f64(seconds_to_wait.max(0.001)))
        } else {
            None
        };

        // Time until the bucket is full again: zero when already full, and
        // None when it can never refill (zero refill rate).
        let reset = if self.refill_rate_per_second > 0 && remaining < self.capacity {
            let tokens_to_refill = self.capacity.saturating_sub(remaining);
            let seconds_to_full = tokens_to_refill as f64 / self.refill_rate_per_second as f64;
            Some(Duration::from_secs_f64(seconds_to_full.max(0.001)))
        } else if remaining >= self.capacity {
            Some(Duration::from_secs(0))
        } else {
            None
        };

        Ok(RateLimitDecision {
            permitted,
            retry_after,
            remaining: Some(remaining),
            limit: self.capacity,
            reset,
        })
    }

    async fn check_with_cost(&self, key: &str, cost: u64) -> Result<RateLimitDecision> {
        let now = self.now_nanos();

        if self.idle_ttl.is_some() && (now % 100) == 0 {
            self.cleanup_idle(now);
        }

        let guard = self.tokens.guard();
        let state = self.get_or_create_state(key, &guard, now);

        let (permitted, remaining) =
            state.try_consume(self.capacity, self.refill_rate_per_second, now, cost);

        let retry_after = if !permitted {
            let tokens_needed = cost.saturating_sub(remaining);
            let seconds_to_wait = if self.refill_rate_per_second > 0 {
                (tokens_needed as f64 / self.refill_rate_per_second as f64).ceil()
            } else {
                1.0
            };
            Some(Duration::from_secs_f64(seconds_to_wait.max(0.001)))
        } else {
            None
        };

        let reset = if self.refill_rate_per_second > 0 && remaining < self.capacity {
            let tokens_to_refill = self.capacity.saturating_sub(remaining);
            let seconds_to_full = tokens_to_refill as f64 / self.refill_rate_per_second as f64;
            Some(Duration::from_secs_f64(seconds_to_full.max(0.001)))
        } else if remaining >= self.capacity {
            Some(Duration::from_secs(0))
        } else {
            None
        };

        Ok(RateLimitDecision {
            permitted,
            retry_after,
            remaining: Some(remaining),
            limit: self.capacity,
            reset,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_zerocopy_token_bucket_basic() {
        let bucket = ZeroCopyTokenBucket::new(10, 100);

        for _ in 0..10 {
            let decision = bucket.check("test-key").await.unwrap();
            assert!(decision.permitted);
        }

        let decision = bucket.check("test-key").await.unwrap();
        assert!(!decision.permitted);
    }

    #[tokio::test(start_paused = true)]
    async fn test_zerocopy_token_bucket_refill() {
        let bucket = ZeroCopyTokenBucket::new(10, 100);

        for _ in 0..10 {
            bucket.check("test-key").await.unwrap();
        }

        let decision = bucket.check("test-key").await.unwrap();
        assert!(!decision.permitted);

        // At 100 tokens/s, 100ms of paused-clock advance refills all 10 tokens.
        tokio::time::advance(Duration::from_millis(100)).await;

        for _ in 0..10 {
            let decision = bucket.check("test-key").await.unwrap();
            assert!(decision.permitted);
        }
    }

    #[tokio::test]
    async fn test_zerocopy_no_allocation_on_second_access() {
        let bucket = ZeroCopyTokenBucket::new(1000, 1000);

        // First access creates the per-key state.
        bucket.check("test-key").await.unwrap();

        // Subsequent accesses should reuse it rather than allocate.
        for _ in 0..100 {
            bucket.check("test-key").await.unwrap();
        }
    }
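
    // A minimal sketch of the cost-based path: `check_with_cost` drains
    // `cost` tokens per call and denies once fewer than `cost` remain.
    // Paused time keeps refill out of the picture, so the arithmetic below
    // is deterministic.
    #[tokio::test(start_paused = true)]
    async fn test_zerocopy_token_bucket_with_cost() {
        let bucket = ZeroCopyTokenBucket::new(10, 100);

        // 10 tokens: two cost-4 requests succeed (10 -> 6 -> 2).
        let decision = bucket.check_with_cost("test-key", 4).await.unwrap();
        assert!(decision.permitted);
        let decision = bucket.check_with_cost("test-key", 4).await.unwrap();
        assert!(decision.permitted);
        assert_eq!(decision.remaining, Some(2));

        // Only 2 tokens remain, so a third cost-4 request is denied.
        let decision = bucket.check_with_cost("test-key", 4).await.unwrap();
        assert!(!decision.permitted);
        assert!(decision.retry_after.is_some());
    }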
}