tokio_rate_limit/algorithm/
simd_token_bucket.rs1use crate::algorithm::Algorithm;
24use crate::error::Result;
25use crate::limiter::RateLimitDecision;
26use async_trait::async_trait;
27use flurry::HashMap as FlurryHashMap;
28use std::sync::atomic::{AtomicU64, Ordering};
29use std::sync::Arc;
30use std::time::Duration;
31use tokio::time::Instant;
32
/// Fixed-point factor: token balances are stored as whole tokens multiplied by
/// this value, giving sub-token resolution without floating-point state.
const SCALE: u64 = 1_000;

/// Largest capacity accepted by `SimdTokenBucket::new`; larger values are
/// clamped to it. Sized so that `capacity * SCALE` stays at or below half of
/// `u64::MAX`.
const MAX_BURST: u64 = u64::MAX / (SCALE * 2);

/// Largest refill rate (tokens per second) accepted by `SimdTokenBucket::new`;
/// larger values are clamped to it. Sized so that `rate * SCALE` stays at or
/// below half of `u64::MAX`.
const MAX_RATE_PER_SEC: u64 = u64::MAX / (SCALE * 2);
41
42struct AtomicTokenState {
44 tokens: AtomicU64,
45 last_refill_nanos: AtomicU64,
46 last_access_nanos: AtomicU64,
47}
48
49impl AtomicTokenState {
50 fn new(capacity: u64, now_nanos: u64) -> Self {
51 Self {
52 tokens: AtomicU64::new(capacity.saturating_mul(SCALE)),
53 last_refill_nanos: AtomicU64::new(now_nanos),
54 last_access_nanos: AtomicU64::new(now_nanos),
55 }
56 }
57
58 fn try_consume(
63 &self,
64 capacity: u64,
65 refill_rate_per_second: u64,
66 now_nanos: u64,
67 cost: u64,
68 ) -> (bool, u64) {
69 self.last_access_nanos.store(now_nanos, Ordering::Relaxed);
70
71 let scaled_capacity = capacity.saturating_mul(SCALE);
72 let token_cost = cost.saturating_mul(SCALE);
73
74 loop {
75 let current_tokens = self.tokens.load(Ordering::Relaxed);
76 let last_refill = self.last_refill_nanos.load(Ordering::Relaxed);
77
78 let elapsed_nanos = now_nanos.saturating_sub(last_refill);
80 let elapsed_secs = elapsed_nanos as f64 / 1_000_000_000.0;
81 let tokens_per_sec_scaled = refill_rate_per_second.saturating_mul(SCALE);
82 let new_tokens_to_add = (elapsed_secs * tokens_per_sec_scaled as f64) as u64;
83
84 let updated_tokens = current_tokens
85 .saturating_add(new_tokens_to_add)
86 .min(scaled_capacity);
87
88 if updated_tokens >= token_cost {
89 let new_tokens = updated_tokens.saturating_sub(token_cost);
90 let new_time = if new_tokens_to_add > 0 {
91 now_nanos
92 } else {
93 last_refill
94 };
95
96 match self.tokens.compare_exchange_weak(
97 current_tokens,
98 new_tokens,
99 Ordering::AcqRel,
100 Ordering::Relaxed,
101 ) {
102 Ok(_) => {
103 if new_tokens_to_add > 0 {
104 let _ = self.last_refill_nanos.compare_exchange_weak(
105 last_refill,
106 new_time,
107 Ordering::AcqRel,
108 Ordering::Relaxed,
109 );
110 }
111 return (true, new_tokens / SCALE);
112 }
113 Err(_) => continue,
114 }
115 } else {
116 let new_time = if new_tokens_to_add > 0 {
117 now_nanos
118 } else {
119 last_refill
120 };
121
122 match self.tokens.compare_exchange_weak(
123 current_tokens,
124 updated_tokens,
125 Ordering::AcqRel,
126 Ordering::Relaxed,
127 ) {
128 Ok(_) => {
129 if new_tokens_to_add > 0 {
130 let _ = self.last_refill_nanos.compare_exchange_weak(
131 last_refill,
132 new_time,
133 Ordering::AcqRel,
134 Ordering::Relaxed,
135 );
136 }
137 return (false, updated_tokens / SCALE);
138 }
139 Err(_) => continue,
140 }
141 }
142 }
143 }
144}
145
/// Output of `SimdTokenBucket::calculate_refill_batch_simd`: two parallel
/// vectors with one entry per input timestamp, in input order.
#[allow(dead_code)]
#[derive(Debug)]
struct BatchRefillResult {
    /// Tokens earned for each input, in `SCALE`d fixed-point units.
    new_tokens: Vec<u64>,
    /// Seconds elapsed between each input timestamp and the batch's "now".
    elapsed_secs: Vec<f64>,
}
153
154#[deprecated(since = "0.8.1", note = "Experimental — no SIMD benefit. Use TokenBucket instead.")]
156pub struct SimdTokenBucket {
157 capacity: u64,
158 refill_rate_per_second: u64,
159 reference_instant: Instant,
160 idle_ttl: Option<Duration>,
161 tokens: Arc<FlurryHashMap<String, Arc<AtomicTokenState>>>,
162}
163
164impl SimdTokenBucket {
165 pub fn new(capacity: u64, refill_rate_per_second: u64) -> Self {
167 let safe_capacity = capacity.min(MAX_BURST);
168 let safe_rate = refill_rate_per_second.min(MAX_RATE_PER_SEC);
169
170 Self {
171 capacity: safe_capacity,
172 refill_rate_per_second: safe_rate,
173 reference_instant: Instant::now(),
174 idle_ttl: None,
175 tokens: Arc::new(FlurryHashMap::new()),
176 }
177 }
178
179 pub fn with_ttl(capacity: u64, refill_rate_per_second: u64, idle_ttl: Duration) -> Self {
181 let mut bucket = Self::new(capacity, refill_rate_per_second);
182 bucket.idle_ttl = Some(idle_ttl);
183 bucket
184 }
185
186 #[inline]
187 fn now_nanos(&self) -> u64 {
188 self.reference_instant.elapsed().as_nanos() as u64
189 }
190
191 #[allow(dead_code)]
199 fn calculate_refill_batch_simd(
200 &self,
201 last_refills: &[u64],
202 now_nanos: u64,
203 ) -> BatchRefillResult {
204 let mut elapsed_secs = Vec::with_capacity(last_refills.len());
205 let mut new_tokens = Vec::with_capacity(last_refills.len());
206
207 for &last_refill in last_refills {
210 let elapsed_nanos = now_nanos.saturating_sub(last_refill);
211 let elapsed = elapsed_nanos as f64 / 1_000_000_000.0;
212 elapsed_secs.push(elapsed);
213
214 let tokens_per_sec_scaled = self.refill_rate_per_second.saturating_mul(SCALE);
215 new_tokens.push((elapsed * tokens_per_sec_scaled as f64) as u64);
216 }
217
218 BatchRefillResult {
219 new_tokens,
220 elapsed_secs,
221 }
222 }
223
224 fn cleanup_idle(&self, now_nanos: u64) {
225 if let Some(ttl) = self.idle_ttl {
226 let ttl_nanos = ttl.as_nanos() as u64;
227 let guard = self.tokens.guard();
228 let keys_to_remove: Vec<String> = self
229 .tokens
230 .iter(&guard)
231 .filter_map(|(key, state)| {
232 let last_access = state.last_access_nanos.load(Ordering::Relaxed);
233 let age = now_nanos.saturating_sub(last_access);
234 if age >= ttl_nanos {
235 Some(key.clone())
236 } else {
237 None
238 }
239 })
240 .collect();
241
242 for key in keys_to_remove {
243 self.tokens.remove(&key, &guard);
244 }
245 }
246 }
247}
248
249impl super::private::Sealed for SimdTokenBucket {}
250
251#[async_trait]
252impl Algorithm for SimdTokenBucket {
253 async fn check(&self, key: &str) -> Result<RateLimitDecision> {
254 let now = self.now_nanos();
255
256 if self.idle_ttl.is_some() && (now % 100) == 0 {
257 self.cleanup_idle(now);
258 }
259
260 let guard = self.tokens.guard();
261 let key_string = key.to_string();
262 let state = match self.tokens.get(&key_string, &guard) {
263 Some(state) => state.clone(),
264 None => {
265 let new_state = Arc::new(AtomicTokenState::new(self.capacity, now));
266 match self
267 .tokens
268 .try_insert(key_string.clone(), new_state.clone(), &guard)
269 {
270 Ok(_) => new_state,
271 Err(current) => current.current.clone(),
272 }
273 }
274 };
275
276 let (permitted, remaining) =
277 state.try_consume(self.capacity, self.refill_rate_per_second, now, 1);
278
279 let retry_after = if !permitted {
280 let tokens_needed = 1u64.saturating_sub(remaining);
281 let seconds_to_wait = if self.refill_rate_per_second > 0 {
282 (tokens_needed as f64 / self.refill_rate_per_second as f64).ceil()
283 } else {
284 1.0
285 };
286 Some(Duration::from_secs_f64(seconds_to_wait.max(0.001)))
287 } else {
288 None
289 };
290
291 let reset = if self.refill_rate_per_second > 0 && remaining < self.capacity {
292 let tokens_to_refill = self.capacity.saturating_sub(remaining);
293 let seconds_to_full = tokens_to_refill as f64 / self.refill_rate_per_second as f64;
294 Some(Duration::from_secs_f64(seconds_to_full.max(0.001)))
295 } else if remaining >= self.capacity {
296 Some(Duration::from_secs(0))
297 } else {
298 None
299 };
300
301 Ok(RateLimitDecision {
302 permitted,
303 retry_after,
304 remaining: Some(remaining),
305 limit: self.capacity,
306 reset,
307 })
308 }
309
310 async fn check_with_cost(&self, key: &str, cost: u64) -> Result<RateLimitDecision> {
311 let now = self.now_nanos();
312
313 if self.idle_ttl.is_some() && (now % 100) == 0 {
314 self.cleanup_idle(now);
315 }
316
317 let guard = self.tokens.guard();
318 let key_string = key.to_string();
319 let state = match self.tokens.get(&key_string, &guard) {
320 Some(state) => state.clone(),
321 None => {
322 let new_state = Arc::new(AtomicTokenState::new(self.capacity, now));
323 match self
324 .tokens
325 .try_insert(key_string.clone(), new_state.clone(), &guard)
326 {
327 Ok(_) => new_state,
328 Err(current) => current.current.clone(),
329 }
330 }
331 };
332
333 let (permitted, remaining) =
334 state.try_consume(self.capacity, self.refill_rate_per_second, now, cost);
335
336 let retry_after = if !permitted {
337 let tokens_needed = cost.saturating_sub(remaining);
338 let seconds_to_wait = if self.refill_rate_per_second > 0 {
339 (tokens_needed as f64 / self.refill_rate_per_second as f64).ceil()
340 } else {
341 1.0
342 };
343 Some(Duration::from_secs_f64(seconds_to_wait.max(0.001)))
344 } else {
345 None
346 };
347
348 let reset = if self.refill_rate_per_second > 0 && remaining < self.capacity {
349 let tokens_to_refill = self.capacity.saturating_sub(remaining);
350 let seconds_to_full = tokens_to_refill as f64 / self.refill_rate_per_second as f64;
351 Some(Duration::from_secs_f64(seconds_to_full.max(0.001)))
352 } else if remaining >= self.capacity {
353 Some(Duration::from_secs(0))
354 } else {
355 None
356 };
357
358 Ok(RateLimitDecision {
359 permitted,
360 retry_after,
361 remaining: Some(remaining),
362 limit: self.capacity,
363 reset,
364 })
365 }
366}
367
#[cfg(test)]
// The type under test is marked `#[deprecated]`; using it here is intentional.
#[allow(deprecated)]
mod tests {
    use super::*;

    /// A fresh bucket permits exactly `capacity` requests, then denies.
    ///
    /// Runs on the paused tokio clock so that no refill can sneak in between
    /// checks, however slowly the test executes.
    #[tokio::test(start_paused = true)]
    async fn test_simd_token_bucket_basic() {
        let bucket = SimdTokenBucket::new(10, 100);

        for _ in 0..10 {
            let decision = bucket.check("test-key").await.unwrap();
            assert!(decision.permitted);
        }

        let decision = bucket.check("test-key").await.unwrap();
        assert!(!decision.permitted);
        assert!(decision.retry_after.is_some());
    }

    /// After draining the bucket, advancing the clock by 100 ms at
    /// 100 tokens/s makes a full 10 tokens available again.
    #[tokio::test(start_paused = true)]
    async fn test_simd_token_bucket_refill() {
        let bucket = SimdTokenBucket::new(10, 100);

        for _ in 0..10 {
            let decision = bucket.check("test-key").await.unwrap();
            assert!(decision.permitted);
        }

        let decision = bucket.check("test-key").await.unwrap();
        assert!(!decision.permitted);

        tokio::time::advance(Duration::from_millis(100)).await;

        for _ in 0..10 {
            let decision = bucket.check("test-key").await.unwrap();
            assert!(decision.permitted);
        }
    }
}