tokio_rate_limit/algorithm/
zerocopy_token_bucket.rs1use crate::algorithm::Algorithm;
38use crate::error::Result;
39use crate::limiter::RateLimitDecision;
40use async_trait::async_trait;
41use flurry::HashMap as FlurryHashMap;
42use std::sync::atomic::{AtomicU64, Ordering};
43use std::sync::Arc;
44use std::time::Duration;
45use tokio::time::Instant;
46
/// Fixed-point scale: balances, costs and rates are stored in units of
/// `1 / SCALE` token so fractional refill is not lost between calls.
const SCALE: u64 = 1000;
/// Largest accepted bucket capacity; keeps `capacity * SCALE` far below `u64::MAX`.
const MAX_BURST: u64 = u64::MAX / (2 * SCALE);
/// Largest accepted refill rate (tokens per second), with the same headroom as `MAX_BURST`.
const MAX_RATE_PER_SEC: u64 = u64::MAX / (2 * SCALE);

/// Per-key token-bucket state, mutated with atomics only.
///
/// * `tokens` — current balance in units of `1 / SCALE` token.
/// * `last_refill_nanos` — caller-clock timestamp up to which elapsed time has
///   already been converted into tokens.
/// * `last_access_nanos` — latest caller-clock timestamp seen by `try_consume`.
struct AtomicTokenState {
    tokens: AtomicU64,
    last_refill_nanos: AtomicU64,
    last_access_nanos: AtomicU64,
}

impl AtomicTokenState {
    /// Nanoseconds per second, as `u128` so refill arithmetic cannot overflow.
    const NANOS_PER_SEC: u128 = 1_000_000_000;

    /// Creates a full bucket (`capacity` tokens) whose refill and access clocks
    /// start at `now_nanos`.
    fn new(capacity: u64, now_nanos: u64) -> Self {
        Self {
            tokens: AtomicU64::new(capacity.saturating_mul(SCALE)),
            last_refill_nanos: AtomicU64::new(now_nanos),
            last_access_nanos: AtomicU64::new(now_nanos),
        }
    }

    /// Refills the bucket for the time elapsed up to `now_nanos`, then tries to
    /// take `cost` tokens.
    ///
    /// Returns `(permitted, remaining)`, where `remaining` is the whole-token
    /// balance after the attempt (after the deduction when permitted).
    ///
    /// Elapsed time is *claimed* by advancing `last_refill_nanos` with a CAS
    /// before the matching tokens are credited. Each nanosecond is therefore
    /// converted into tokens at most once, even under contention or spurious
    /// CAS failure. A concurrent caller may briefly see the balance before a
    /// claimed refill is credited; that errs on the side of denying.
    fn try_consume(
        &self,
        capacity: u64,
        refill_rate_per_second: u64,
        now_nanos: u64,
        cost: u64,
    ) -> (bool, u64) {
        // `fetch_max` stops a racing caller with an older timestamp from
        // moving the idle clock backwards.
        self.last_access_nanos.fetch_max(now_nanos, Ordering::Relaxed);

        let scaled_capacity = capacity.saturating_mul(SCALE);
        let token_cost = cost.saturating_mul(SCALE);

        self.refill(scaled_capacity, refill_rate_per_second, now_nanos);

        let mut current = self.tokens.load(Ordering::Acquire);
        loop {
            if current < token_cost {
                return (false, current / SCALE);
            }
            let next = current - token_cost;
            match self.tokens.compare_exchange_weak(
                current,
                next,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return (true, next / SCALE),
                Err(actual) => current = actual,
            }
        }
    }

    /// Credits tokens for the time between `last_refill_nanos` and `now_nanos`,
    /// capping the balance at `scaled_capacity`.
    ///
    /// Only the caller whose CAS moves `last_refill_nanos` forward credits the
    /// corresponding interval. The refill clock only advances when at least
    /// one scaled unit is earned, so slow rates keep accumulating elapsed time.
    fn refill(&self, scaled_capacity: u64, refill_rate_per_second: u64, now_nanos: u64) {
        let rate_scaled = refill_rate_per_second.saturating_mul(SCALE);
        let mut last_refill = self.last_refill_nanos.load(Ordering::Acquire);
        loop {
            let earned = Self::scaled_refill(now_nanos.saturating_sub(last_refill), rate_scaled);
            if earned == 0 {
                return;
            }
            match self.last_refill_nanos.compare_exchange_weak(
                last_refill,
                now_nanos,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => {
                    // The closure always returns `Some`, so this cannot fail.
                    let _ = self
                        .tokens
                        .fetch_update(Ordering::AcqRel, Ordering::Acquire, |balance| {
                            Some(balance.saturating_add(earned).min(scaled_capacity))
                        });
                    return;
                }
                // Either another caller claimed (part of) the interval or the
                // weak CAS failed spuriously; recompute from the fresh value.
                Err(actual) => last_refill = actual,
            }
        }
    }

    /// Scaled tokens earned over `elapsed_nanos` at `rate_scaled` scaled tokens
    /// per second, truncated toward zero and saturated to `u64::MAX`.
    fn scaled_refill(elapsed_nanos: u64, rate_scaled: u64) -> u64 {
        let earned = u128::from(elapsed_nanos) * u128::from(rate_scaled) / Self::NANOS_PER_SEC;
        earned.min(u128::from(u64::MAX)) as u64
    }
}
149
150#[deprecated(since = "0.8.1", note = "Zero-copy optimization is now integrated into TokenBucket since v0.4.0.")]
161pub struct ZeroCopyTokenBucket {
162 capacity: u64,
163 refill_rate_per_second: u64,
164 reference_instant: Instant,
165 idle_ttl: Option<Duration>,
166 tokens: Arc<FlurryHashMap<String, Arc<AtomicTokenState>>>,
167}
168
169impl ZeroCopyTokenBucket {
170 pub fn new(capacity: u64, refill_rate_per_second: u64) -> Self {
172 let safe_capacity = capacity.min(MAX_BURST);
173 let safe_rate = refill_rate_per_second.min(MAX_RATE_PER_SEC);
174
175 Self {
176 capacity: safe_capacity,
177 refill_rate_per_second: safe_rate,
178 reference_instant: Instant::now(),
179 idle_ttl: None,
180 tokens: Arc::new(FlurryHashMap::new()),
181 }
182 }
183
184 pub fn with_ttl(capacity: u64, refill_rate_per_second: u64, idle_ttl: Duration) -> Self {
186 let mut bucket = Self::new(capacity, refill_rate_per_second);
187 bucket.idle_ttl = Some(idle_ttl);
188 bucket
189 }
190
191 #[inline]
192 fn now_nanos(&self) -> u64 {
193 self.reference_instant.elapsed().as_nanos() as u64
194 }
195
196 #[inline]
201 fn get_or_create_state(
202 &self,
203 key: &str,
204 guard: &flurry::Guard<'_>,
205 now_nanos: u64,
206 ) -> Arc<AtomicTokenState> {
207 if let Some(state) = self.tokens.get(key, guard) {
210 return state.clone();
211 }
212
213 let key_string = key.to_string();
216 let new_state = Arc::new(AtomicTokenState::new(self.capacity, now_nanos));
217
218 match self.tokens.try_insert(key_string, new_state.clone(), guard) {
219 Ok(_) => new_state,
220 Err(current) => current.current.clone(),
221 }
222 }
223
224 fn cleanup_idle(&self, now_nanos: u64) {
225 if let Some(ttl) = self.idle_ttl {
226 let ttl_nanos = ttl.as_nanos() as u64;
227 let guard = self.tokens.guard();
228 let keys_to_remove: Vec<String> = self
229 .tokens
230 .iter(&guard)
231 .filter_map(|(key, state)| {
232 let last_access = state.last_access_nanos.load(Ordering::Relaxed);
233 let age = now_nanos.saturating_sub(last_access);
234 if age >= ttl_nanos {
235 Some(key.clone())
236 } else {
237 None
238 }
239 })
240 .collect();
241
242 for key in keys_to_remove {
243 self.tokens.remove(&key, &guard);
244 }
245 }
246 }
247}
248
249impl super::private::Sealed for ZeroCopyTokenBucket {}
250
251#[async_trait]
252impl Algorithm for ZeroCopyTokenBucket {
253 async fn check(&self, key: &str) -> Result<RateLimitDecision> {
254 let now = self.now_nanos();
255
256 if self.idle_ttl.is_some() && (now % 100) == 0 {
257 self.cleanup_idle(now);
258 }
259
260 let guard = self.tokens.guard();
262 let state = self.get_or_create_state(key, &guard, now);
263
264 let (permitted, remaining) =
265 state.try_consume(self.capacity, self.refill_rate_per_second, now, 1);
266
267 let retry_after = if !permitted {
268 let tokens_needed = 1u64.saturating_sub(remaining);
269 let seconds_to_wait = if self.refill_rate_per_second > 0 {
270 (tokens_needed as f64 / self.refill_rate_per_second as f64).ceil()
271 } else {
272 1.0
273 };
274 Some(Duration::from_secs_f64(seconds_to_wait.max(0.001)))
275 } else {
276 None
277 };
278
279 let reset = if self.refill_rate_per_second > 0 && remaining < self.capacity {
280 let tokens_to_refill = self.capacity.saturating_sub(remaining);
281 let seconds_to_full = tokens_to_refill as f64 / self.refill_rate_per_second as f64;
282 Some(Duration::from_secs_f64(seconds_to_full.max(0.001)))
283 } else if remaining >= self.capacity {
284 Some(Duration::from_secs(0))
285 } else {
286 None
287 };
288
289 Ok(RateLimitDecision {
290 permitted,
291 retry_after,
292 remaining: Some(remaining),
293 limit: self.capacity,
294 reset,
295 })
296 }
297
298 async fn check_with_cost(&self, key: &str, cost: u64) -> Result<RateLimitDecision> {
299 let now = self.now_nanos();
300
301 if self.idle_ttl.is_some() && (now % 100) == 0 {
302 self.cleanup_idle(now);
303 }
304
305 let guard = self.tokens.guard();
307 let state = self.get_or_create_state(key, &guard, now);
308
309 let (permitted, remaining) =
310 state.try_consume(self.capacity, self.refill_rate_per_second, now, cost);
311
312 let retry_after = if !permitted {
313 let tokens_needed = cost.saturating_sub(remaining);
314 let seconds_to_wait = if self.refill_rate_per_second > 0 {
315 (tokens_needed as f64 / self.refill_rate_per_second as f64).ceil()
316 } else {
317 1.0
318 };
319 Some(Duration::from_secs_f64(seconds_to_wait.max(0.001)))
320 } else {
321 None
322 };
323
324 let reset = if self.refill_rate_per_second > 0 && remaining < self.capacity {
325 let tokens_to_refill = self.capacity.saturating_sub(remaining);
326 let seconds_to_full = tokens_to_refill as f64 / self.refill_rate_per_second as f64;
327 Some(Duration::from_secs_f64(seconds_to_full.max(0.001)))
328 } else if remaining >= self.capacity {
329 Some(Duration::from_secs(0))
330 } else {
331 None
332 };
333
334 Ok(RateLimitDecision {
335 permitted,
336 retry_after,
337 remaining: Some(remaining),
338 limit: self.capacity,
339 reset,
340 })
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh key admits exactly `capacity` back-to-back requests, then denies.
    #[tokio::test]
    async fn test_zerocopy_token_bucket_basic() {
        let bucket = ZeroCopyTokenBucket::new(10, 100);

        for _ in 0..10 {
            let decision = bucket.check("test-key").await.unwrap();
            assert!(decision.permitted);
        }

        let decision = bucket.check("test-key").await.unwrap();
        assert!(!decision.permitted);
    }

    /// After draining, advancing paused time by capacity / rate refills the bucket.
    #[tokio::test(start_paused = true)]
    async fn test_zerocopy_token_bucket_refill() {
        let bucket = ZeroCopyTokenBucket::new(10, 100);

        for _ in 0..10 {
            bucket.check("test-key").await.unwrap();
        }

        let decision = bucket.check("test-key").await.unwrap();
        assert!(!decision.permitted);

        tokio::time::advance(Duration::from_millis(100)).await;

        for _ in 0..10 {
            let decision = bucket.check("test-key").await.unwrap();
            assert!(decision.permitted);
        }
    }

    /// Repeated checks on one key reuse its state instead of inserting new entries.
    #[tokio::test]
    async fn test_zerocopy_no_allocation_on_second_access() {
        let bucket = ZeroCopyTokenBucket::new(1000, 1000);

        assert!(bucket.check("test-key").await.unwrap().permitted);

        for _ in 0..100 {
            assert!(bucket.check("test-key").await.unwrap().permitted);
        }

        let guard = bucket.tokens.guard();
        assert_eq!(bucket.tokens.iter(&guard).count(), 1);
    }

    /// Weighted checks deduct `cost` tokens and leave the balance untouched when denied.
    #[tokio::test]
    async fn test_zerocopy_check_with_cost() {
        // Rate 0: no refill, so the balance depends only on the costs taken.
        let bucket = ZeroCopyTokenBucket::new(10, 0);

        let first = bucket.check_with_cost("k", 4).await.unwrap();
        assert!(first.permitted);
        assert_eq!(first.remaining, Some(6));

        let second = bucket.check_with_cost("k", 7).await.unwrap();
        assert!(!second.permitted);
        assert_eq!(second.remaining, Some(6));
        assert_eq!(second.retry_after, Some(Duration::from_secs(1)));
    }
}