//! SIMD-style token bucket rate limiting algorithm
//! (`tokio_rate_limit/algorithm/simd_token_bucket.rs`).

use crate::algorithm::Algorithm;
use crate::error::Result;
use crate::limiter::RateLimitDecision;

use async_trait::async_trait;
use flurry::HashMap as FlurryHashMap;

use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;

use tokio::time::Instant;
32
33const SCALE: u64 = 1000;
35
36const MAX_BURST: u64 = u64::MAX / (2 * SCALE);
38
39const MAX_RATE_PER_SEC: u64 = u64::MAX / (2 * SCALE);
41
42struct AtomicTokenState {
44 tokens: AtomicU64,
45 last_refill_nanos: AtomicU64,
46 last_access_nanos: AtomicU64,
47}
48
49impl AtomicTokenState {
50 fn new(capacity: u64, now_nanos: u64) -> Self {
51 Self {
52 tokens: AtomicU64::new(capacity.saturating_mul(SCALE)),
53 last_refill_nanos: AtomicU64::new(now_nanos),
54 last_access_nanos: AtomicU64::new(now_nanos),
55 }
56 }
57
58 fn try_consume(
63 &self,
64 capacity: u64,
65 refill_rate_per_second: u64,
66 now_nanos: u64,
67 cost: u64,
68 ) -> (bool, u64) {
69 self.last_access_nanos.store(now_nanos, Ordering::Relaxed);
70
71 let scaled_capacity = capacity.saturating_mul(SCALE);
72 let token_cost = cost.saturating_mul(SCALE);
73
74 loop {
75 let current_tokens = self.tokens.load(Ordering::Relaxed);
76 let last_refill = self.last_refill_nanos.load(Ordering::Relaxed);
77
78 let elapsed_nanos = now_nanos.saturating_sub(last_refill);
80 let elapsed_secs = elapsed_nanos as f64 / 1_000_000_000.0;
81 let tokens_per_sec_scaled = refill_rate_per_second.saturating_mul(SCALE);
82 let new_tokens_to_add = (elapsed_secs * tokens_per_sec_scaled as f64) as u64;
83
84 let updated_tokens = current_tokens
85 .saturating_add(new_tokens_to_add)
86 .min(scaled_capacity);
87
88 if updated_tokens >= token_cost {
89 let new_tokens = updated_tokens.saturating_sub(token_cost);
90 let new_time = if new_tokens_to_add > 0 {
91 now_nanos
92 } else {
93 last_refill
94 };
95
96 match self.tokens.compare_exchange_weak(
97 current_tokens,
98 new_tokens,
99 Ordering::AcqRel,
100 Ordering::Relaxed,
101 ) {
102 Ok(_) => {
103 if new_tokens_to_add > 0 {
104 let _ = self.last_refill_nanos.compare_exchange_weak(
105 last_refill,
106 new_time,
107 Ordering::AcqRel,
108 Ordering::Relaxed,
109 );
110 }
111 return (true, new_tokens / SCALE);
112 }
113 Err(_) => continue,
114 }
115 } else {
116 let new_time = if new_tokens_to_add > 0 {
117 now_nanos
118 } else {
119 last_refill
120 };
121
122 match self.tokens.compare_exchange_weak(
123 current_tokens,
124 updated_tokens,
125 Ordering::AcqRel,
126 Ordering::Relaxed,
127 ) {
128 Ok(_) => {
129 if new_tokens_to_add > 0 {
130 let _ = self.last_refill_nanos.compare_exchange_weak(
131 last_refill,
132 new_time,
133 Ordering::AcqRel,
134 Ordering::Relaxed,
135 );
136 }
137 return (false, updated_tokens / SCALE);
138 }
139 Err(_) => continue,
140 }
141 }
142 }
143 }
144}
145
146#[derive(Debug)]
148#[allow(dead_code)]
149struct BatchRefillResult {
150 new_tokens: Vec<u64>,
151 elapsed_secs: Vec<f64>,
152}
153
154pub struct SimdTokenBucket {
156 capacity: u64,
157 refill_rate_per_second: u64,
158 reference_instant: Instant,
159 idle_ttl: Option<Duration>,
160 tokens: Arc<FlurryHashMap<String, Arc<AtomicTokenState>>>,
161}
162
163impl SimdTokenBucket {
164 pub fn new(capacity: u64, refill_rate_per_second: u64) -> Self {
166 let safe_capacity = capacity.min(MAX_BURST);
167 let safe_rate = refill_rate_per_second.min(MAX_RATE_PER_SEC);
168
169 Self {
170 capacity: safe_capacity,
171 refill_rate_per_second: safe_rate,
172 reference_instant: Instant::now(),
173 idle_ttl: None,
174 tokens: Arc::new(FlurryHashMap::new()),
175 }
176 }
177
178 pub fn with_ttl(capacity: u64, refill_rate_per_second: u64, idle_ttl: Duration) -> Self {
180 let mut bucket = Self::new(capacity, refill_rate_per_second);
181 bucket.idle_ttl = Some(idle_ttl);
182 bucket
183 }
184
185 #[inline]
186 fn now_nanos(&self) -> u64 {
187 self.reference_instant.elapsed().as_nanos() as u64
188 }
189
190 #[allow(dead_code)]
198 fn calculate_refill_batch_simd(
199 &self,
200 last_refills: &[u64],
201 now_nanos: u64,
202 ) -> BatchRefillResult {
203 let mut elapsed_secs = Vec::with_capacity(last_refills.len());
204 let mut new_tokens = Vec::with_capacity(last_refills.len());
205
206 for &last_refill in last_refills {
209 let elapsed_nanos = now_nanos.saturating_sub(last_refill);
210 let elapsed = elapsed_nanos as f64 / 1_000_000_000.0;
211 elapsed_secs.push(elapsed);
212
213 let tokens_per_sec_scaled = self.refill_rate_per_second.saturating_mul(SCALE);
214 new_tokens.push((elapsed * tokens_per_sec_scaled as f64) as u64);
215 }
216
217 BatchRefillResult {
218 new_tokens,
219 elapsed_secs,
220 }
221 }
222
223 fn cleanup_idle(&self, now_nanos: u64) {
224 if let Some(ttl) = self.idle_ttl {
225 let ttl_nanos = ttl.as_nanos() as u64;
226 let guard = self.tokens.guard();
227 let keys_to_remove: Vec<String> = self
228 .tokens
229 .iter(&guard)
230 .filter_map(|(key, state)| {
231 let last_access = state.last_access_nanos.load(Ordering::Relaxed);
232 let age = now_nanos.saturating_sub(last_access);
233 if age >= ttl_nanos {
234 Some(key.clone())
235 } else {
236 None
237 }
238 })
239 .collect();
240
241 for key in keys_to_remove {
242 self.tokens.remove(&key, &guard);
243 }
244 }
245 }
246}
247
248impl super::private::Sealed for SimdTokenBucket {}
249
250#[async_trait]
251impl Algorithm for SimdTokenBucket {
252 async fn check(&self, key: &str) -> Result<RateLimitDecision> {
253 let now = self.now_nanos();
254
255 if self.idle_ttl.is_some() && (now % 100) == 0 {
256 self.cleanup_idle(now);
257 }
258
259 let guard = self.tokens.guard();
260 let key_string = key.to_string();
261 let state = match self.tokens.get(&key_string, &guard) {
262 Some(state) => state.clone(),
263 None => {
264 let new_state = Arc::new(AtomicTokenState::new(self.capacity, now));
265 match self
266 .tokens
267 .try_insert(key_string.clone(), new_state.clone(), &guard)
268 {
269 Ok(_) => new_state,
270 Err(current) => current.current.clone(),
271 }
272 }
273 };
274
275 let (permitted, remaining) =
276 state.try_consume(self.capacity, self.refill_rate_per_second, now, 1);
277
278 let retry_after = if !permitted {
279 let tokens_needed = 1u64.saturating_sub(remaining);
280 let seconds_to_wait = if self.refill_rate_per_second > 0 {
281 (tokens_needed as f64 / self.refill_rate_per_second as f64).ceil()
282 } else {
283 1.0
284 };
285 Some(Duration::from_secs_f64(seconds_to_wait.max(0.001)))
286 } else {
287 None
288 };
289
290 let reset = if self.refill_rate_per_second > 0 && remaining < self.capacity {
291 let tokens_to_refill = self.capacity.saturating_sub(remaining);
292 let seconds_to_full = tokens_to_refill as f64 / self.refill_rate_per_second as f64;
293 Some(Duration::from_secs_f64(seconds_to_full.max(0.001)))
294 } else if remaining >= self.capacity {
295 Some(Duration::from_secs(0))
296 } else {
297 None
298 };
299
300 Ok(RateLimitDecision {
301 permitted,
302 retry_after,
303 remaining: Some(remaining),
304 limit: self.capacity,
305 reset,
306 })
307 }
308
309 async fn check_with_cost(&self, key: &str, cost: u64) -> Result<RateLimitDecision> {
310 let now = self.now_nanos();
311
312 if self.idle_ttl.is_some() && (now % 100) == 0 {
313 self.cleanup_idle(now);
314 }
315
316 let guard = self.tokens.guard();
317 let key_string = key.to_string();
318 let state = match self.tokens.get(&key_string, &guard) {
319 Some(state) => state.clone(),
320 None => {
321 let new_state = Arc::new(AtomicTokenState::new(self.capacity, now));
322 match self
323 .tokens
324 .try_insert(key_string.clone(), new_state.clone(), &guard)
325 {
326 Ok(_) => new_state,
327 Err(current) => current.current.clone(),
328 }
329 }
330 };
331
332 let (permitted, remaining) =
333 state.try_consume(self.capacity, self.refill_rate_per_second, now, cost);
334
335 let retry_after = if !permitted {
336 let tokens_needed = cost.saturating_sub(remaining);
337 let seconds_to_wait = if self.refill_rate_per_second > 0 {
338 (tokens_needed as f64 / self.refill_rate_per_second as f64).ceil()
339 } else {
340 1.0
341 };
342 Some(Duration::from_secs_f64(seconds_to_wait.max(0.001)))
343 } else {
344 None
345 };
346
347 let reset = if self.refill_rate_per_second > 0 && remaining < self.capacity {
348 let tokens_to_refill = self.capacity.saturating_sub(remaining);
349 let seconds_to_full = tokens_to_refill as f64 / self.refill_rate_per_second as f64;
350 Some(Duration::from_secs_f64(seconds_to_full.max(0.001)))
351 } else if remaining >= self.capacity {
352 Some(Duration::from_secs(0))
353 } else {
354 None
355 };
356
357 Ok(RateLimitDecision {
358 permitted,
359 retry_after,
360 remaining: Some(remaining),
361 limit: self.capacity,
362 reset,
363 })
364 }
365}
366
367#[cfg(test)]
368mod tests {
369 use super::*;
370
371 #[tokio::test]
372 async fn test_simd_token_bucket_basic() {
373 let bucket = SimdTokenBucket::new(10, 100);
374
375 for _ in 0..10 {
376 let decision = bucket.check("test-key").await.unwrap();
377 assert!(decision.permitted);
378 }
379
380 let decision = bucket.check("test-key").await.unwrap();
381 assert!(!decision.permitted);
382 }
383
384 #[tokio::test(start_paused = true)]
385 async fn test_simd_token_bucket_refill() {
386 let bucket = SimdTokenBucket::new(10, 100);
387
388 for _ in 0..10 {
389 bucket.check("test-key").await.unwrap();
390 }
391
392 let decision = bucket.check("test-key").await.unwrap();
393 assert!(!decision.permitted);
394
395 tokio::time::advance(Duration::from_millis(100)).await;
396
397 for _ in 0..10 {
398 let decision = bucket.check("test-key").await.unwrap();
399 assert!(decision.permitted);
400 }
401 }
402}