tokio_rate_limit/algorithm/
simd_token_bucket.rs

//! SIMD-oriented token bucket implementation (work in progress).
//!
//! This is an experimental implementation exploring SIMD optimizations for token bucket operations.
//! The goal is to achieve 2-5x performance improvements for certain workloads. No SIMD code path
//! exists yet: every calculation currently runs as scalar code.
//!
//! ## Planned Optimization Strategies
//!
//! 1. **Vectorized Token Refill Calculations**: Process multiple buckets in parallel using SIMD
//! 2. **Batch Key Processing**: Check multiple keys in a single operation (no batch API exists yet)
//! 3. **SIMD-friendly Data Layout**: Structure data for efficient SIMD operations
//!
//! ## Target Platforms (planned)
//!
//! - x86_64: AVX2 (256-bit vectors, 4x f64 or u64)
//! - ARM: NEON (128-bit vectors, 2x f64 or u64)
//! - Fallback: scalar implementation (currently used on every platform)
//!
//! ## Current Status: EXPERIMENTAL
//!
//! This implementation is for research and benchmarking purposes.
//! It has not been hardened for production use.
22
23use crate::algorithm::Algorithm;
24use crate::error::Result;
25use crate::limiter::RateLimitDecision;
26use async_trait::async_trait;
27use flurry::HashMap as FlurryHashMap;
28use std::sync::atomic::{AtomicU64, Ordering};
29use std::sync::Arc;
30use std::time::Duration;
31use tokio::time::Instant;
32
/// Fixed-point multiplier for token counts (same as baseline).
///
/// Balances are stored in thousandths of a token so that partial refills
/// accumulate between calls instead of being rounded away.
const SCALE: u64 = 1000;

/// Largest accepted bucket capacity, in whole tokens.
///
/// Chosen so that `capacity * SCALE` stays at or below half of `u64::MAX`.
const MAX_BURST: u64 = u64::MAX / SCALE / 2;

/// Largest accepted refill rate, in whole tokens per second.
///
/// Uses the same factor-of-two headroom below `u64::MAX` (after scaling) as `MAX_BURST`.
const MAX_RATE_PER_SEC: u64 = u64::MAX / SCALE / 2;
41
42/// Atomic state for a token bucket (same as baseline for fair comparison)
43struct AtomicTokenState {
44    tokens: AtomicU64,
45    last_refill_nanos: AtomicU64,
46    last_access_nanos: AtomicU64,
47}
48
49impl AtomicTokenState {
50    fn new(capacity: u64, now_nanos: u64) -> Self {
51        Self {
52            tokens: AtomicU64::new(capacity.saturating_mul(SCALE)),
53            last_refill_nanos: AtomicU64::new(now_nanos),
54            last_access_nanos: AtomicU64::new(now_nanos),
55        }
56    }
57
58    /// SIMD-optimized token consumption (for single key - baseline equivalent)
59    ///
60    /// For single key operations, SIMD doesn't help much.
61    /// The real benefit comes from batch operations (see try_consume_batch).
62    fn try_consume(
63        &self,
64        capacity: u64,
65        refill_rate_per_second: u64,
66        now_nanos: u64,
67        cost: u64,
68    ) -> (bool, u64) {
69        self.last_access_nanos.store(now_nanos, Ordering::Relaxed);
70
71        let scaled_capacity = capacity.saturating_mul(SCALE);
72        let token_cost = cost.saturating_mul(SCALE);
73
74        loop {
75            let current_tokens = self.tokens.load(Ordering::Relaxed);
76            let last_refill = self.last_refill_nanos.load(Ordering::Relaxed);
77
78            // SIMD opportunity: Vectorize this calculation for multiple buckets
79            let elapsed_nanos = now_nanos.saturating_sub(last_refill);
80            let elapsed_secs = elapsed_nanos as f64 / 1_000_000_000.0;
81            let tokens_per_sec_scaled = refill_rate_per_second.saturating_mul(SCALE);
82            let new_tokens_to_add = (elapsed_secs * tokens_per_sec_scaled as f64) as u64;
83
84            let updated_tokens = current_tokens
85                .saturating_add(new_tokens_to_add)
86                .min(scaled_capacity);
87
88            if updated_tokens >= token_cost {
89                let new_tokens = updated_tokens.saturating_sub(token_cost);
90                let new_time = if new_tokens_to_add > 0 {
91                    now_nanos
92                } else {
93                    last_refill
94                };
95
96                match self.tokens.compare_exchange_weak(
97                    current_tokens,
98                    new_tokens,
99                    Ordering::AcqRel,
100                    Ordering::Relaxed,
101                ) {
102                    Ok(_) => {
103                        if new_tokens_to_add > 0 {
104                            let _ = self.last_refill_nanos.compare_exchange_weak(
105                                last_refill,
106                                new_time,
107                                Ordering::AcqRel,
108                                Ordering::Relaxed,
109                            );
110                        }
111                        return (true, new_tokens / SCALE);
112                    }
113                    Err(_) => continue,
114                }
115            } else {
116                let new_time = if new_tokens_to_add > 0 {
117                    now_nanos
118                } else {
119                    last_refill
120                };
121
122                match self.tokens.compare_exchange_weak(
123                    current_tokens,
124                    updated_tokens,
125                    Ordering::AcqRel,
126                    Ordering::Relaxed,
127                ) {
128                    Ok(_) => {
129                        if new_tokens_to_add > 0 {
130                            let _ = self.last_refill_nanos.compare_exchange_weak(
131                                last_refill,
132                                new_time,
133                                Ordering::AcqRel,
134                                Ordering::Relaxed,
135                            );
136                        }
137                        return (false, updated_tokens / SCALE);
138                    }
139                    Err(_) => continue,
140                }
141            }
142        }
143    }
144}
145
/// Output of the batch refill calculation: one entry per input bucket, in input order.
#[derive(Debug, Clone, PartialEq)]
// Only produced by `SimdTokenBucket::calculate_refill_batch_simd`, which is not yet
// called from the request path, so neither the type nor its fields are read anywhere.
#[allow(dead_code)]
struct BatchRefillResult {
    /// Tokens accrued by each bucket, in units of 1/`SCALE` token.
    new_tokens: Vec<u64>,
    /// Seconds elapsed since each bucket's last refill.
    elapsed_secs: Vec<f64>,
}
153
154/// SIMD-optimized token bucket implementation
155pub struct SimdTokenBucket {
156    capacity: u64,
157    refill_rate_per_second: u64,
158    reference_instant: Instant,
159    idle_ttl: Option<Duration>,
160    tokens: Arc<FlurryHashMap<String, Arc<AtomicTokenState>>>,
161}
162
163impl SimdTokenBucket {
164    /// Creates a new SIMD-optimized token bucket
165    pub fn new(capacity: u64, refill_rate_per_second: u64) -> Self {
166        let safe_capacity = capacity.min(MAX_BURST);
167        let safe_rate = refill_rate_per_second.min(MAX_RATE_PER_SEC);
168
169        Self {
170            capacity: safe_capacity,
171            refill_rate_per_second: safe_rate,
172            reference_instant: Instant::now(),
173            idle_ttl: None,
174            tokens: Arc::new(FlurryHashMap::new()),
175        }
176    }
177
178    /// Creates a token bucket with TTL-based eviction
179    pub fn with_ttl(capacity: u64, refill_rate_per_second: u64, idle_ttl: Duration) -> Self {
180        let mut bucket = Self::new(capacity, refill_rate_per_second);
181        bucket.idle_ttl = Some(idle_ttl);
182        bucket
183    }
184
185    #[inline]
186    fn now_nanos(&self) -> u64 {
187        self.reference_instant.elapsed().as_nanos() as u64
188    }
189
190    /// SIMD-optimized batch refill calculation
191    ///
192    /// NOTE: This is a placeholder for future SIMD implementation.
193    /// The current implementation uses scalar code to avoid unsafe blocks.
194    ///
195    /// Future work: Implement SIMD using portable_simd or std::simd when stabilized.
196    /// For now, this serves as a baseline for comparison and API design.
197    #[allow(dead_code)]
198    fn calculate_refill_batch_simd(
199        &self,
200        last_refills: &[u64],
201        now_nanos: u64,
202    ) -> BatchRefillResult {
203        let mut elapsed_secs = Vec::with_capacity(last_refills.len());
204        let mut new_tokens = Vec::with_capacity(last_refills.len());
205
206        // Scalar implementation - SIMD would process 2-4 values in parallel
207        // TODO: Use portable_simd when stable or add unsafe feature flag
208        for &last_refill in last_refills {
209            let elapsed_nanos = now_nanos.saturating_sub(last_refill);
210            let elapsed = elapsed_nanos as f64 / 1_000_000_000.0;
211            elapsed_secs.push(elapsed);
212
213            let tokens_per_sec_scaled = self.refill_rate_per_second.saturating_mul(SCALE);
214            new_tokens.push((elapsed * tokens_per_sec_scaled as f64) as u64);
215        }
216
217        BatchRefillResult {
218            new_tokens,
219            elapsed_secs,
220        }
221    }
222
223    fn cleanup_idle(&self, now_nanos: u64) {
224        if let Some(ttl) = self.idle_ttl {
225            let ttl_nanos = ttl.as_nanos() as u64;
226            let guard = self.tokens.guard();
227            let keys_to_remove: Vec<String> = self
228                .tokens
229                .iter(&guard)
230                .filter_map(|(key, state)| {
231                    let last_access = state.last_access_nanos.load(Ordering::Relaxed);
232                    let age = now_nanos.saturating_sub(last_access);
233                    if age >= ttl_nanos {
234                        Some(key.clone())
235                    } else {
236                        None
237                    }
238                })
239                .collect();
240
241            for key in keys_to_remove {
242                self.tokens.remove(&key, &guard);
243            }
244        }
245    }
246}
247
248impl super::private::Sealed for SimdTokenBucket {}
249
250#[async_trait]
251impl Algorithm for SimdTokenBucket {
252    async fn check(&self, key: &str) -> Result<RateLimitDecision> {
253        let now = self.now_nanos();
254
255        if self.idle_ttl.is_some() && (now % 100) == 0 {
256            self.cleanup_idle(now);
257        }
258
259        let guard = self.tokens.guard();
260        let key_string = key.to_string();
261        let state = match self.tokens.get(&key_string, &guard) {
262            Some(state) => state.clone(),
263            None => {
264                let new_state = Arc::new(AtomicTokenState::new(self.capacity, now));
265                match self
266                    .tokens
267                    .try_insert(key_string.clone(), new_state.clone(), &guard)
268                {
269                    Ok(_) => new_state,
270                    Err(current) => current.current.clone(),
271                }
272            }
273        };
274
275        let (permitted, remaining) =
276            state.try_consume(self.capacity, self.refill_rate_per_second, now, 1);
277
278        let retry_after = if !permitted {
279            let tokens_needed = 1u64.saturating_sub(remaining);
280            let seconds_to_wait = if self.refill_rate_per_second > 0 {
281                (tokens_needed as f64 / self.refill_rate_per_second as f64).ceil()
282            } else {
283                1.0
284            };
285            Some(Duration::from_secs_f64(seconds_to_wait.max(0.001)))
286        } else {
287            None
288        };
289
290        let reset = if self.refill_rate_per_second > 0 && remaining < self.capacity {
291            let tokens_to_refill = self.capacity.saturating_sub(remaining);
292            let seconds_to_full = tokens_to_refill as f64 / self.refill_rate_per_second as f64;
293            Some(Duration::from_secs_f64(seconds_to_full.max(0.001)))
294        } else if remaining >= self.capacity {
295            Some(Duration::from_secs(0))
296        } else {
297            None
298        };
299
300        Ok(RateLimitDecision {
301            permitted,
302            retry_after,
303            remaining: Some(remaining),
304            limit: self.capacity,
305            reset,
306        })
307    }
308
309    async fn check_with_cost(&self, key: &str, cost: u64) -> Result<RateLimitDecision> {
310        let now = self.now_nanos();
311
312        if self.idle_ttl.is_some() && (now % 100) == 0 {
313            self.cleanup_idle(now);
314        }
315
316        let guard = self.tokens.guard();
317        let key_string = key.to_string();
318        let state = match self.tokens.get(&key_string, &guard) {
319            Some(state) => state.clone(),
320            None => {
321                let new_state = Arc::new(AtomicTokenState::new(self.capacity, now));
322                match self
323                    .tokens
324                    .try_insert(key_string.clone(), new_state.clone(), &guard)
325                {
326                    Ok(_) => new_state,
327                    Err(current) => current.current.clone(),
328                }
329            }
330        };
331
332        let (permitted, remaining) =
333            state.try_consume(self.capacity, self.refill_rate_per_second, now, cost);
334
335        let retry_after = if !permitted {
336            let tokens_needed = cost.saturating_sub(remaining);
337            let seconds_to_wait = if self.refill_rate_per_second > 0 {
338                (tokens_needed as f64 / self.refill_rate_per_second as f64).ceil()
339            } else {
340                1.0
341            };
342            Some(Duration::from_secs_f64(seconds_to_wait.max(0.001)))
343        } else {
344            None
345        };
346
347        let reset = if self.refill_rate_per_second > 0 && remaining < self.capacity {
348            let tokens_to_refill = self.capacity.saturating_sub(remaining);
349            let seconds_to_full = tokens_to_refill as f64 / self.refill_rate_per_second as f64;
350            Some(Duration::from_secs_f64(seconds_to_full.max(0.001)))
351        } else if remaining >= self.capacity {
352            Some(Duration::from_secs(0))
353        } else {
354            None
355        };
356
357        Ok(RateLimitDecision {
358            permitted,
359            retry_after,
360            remaining: Some(remaining),
361            limit: self.capacity,
362            reset,
363        })
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    // Time is paused so no refill can slip in between checks. At 100 tokens/s a token
    // accrues every 10 ms, which a slow or heavily loaded test runner can exceed.
    #[tokio::test(start_paused = true)]
    async fn test_simd_token_bucket_basic() {
        let bucket = SimdTokenBucket::new(10, 100);

        for _ in 0..10 {
            let decision = bucket.check("test-key").await.unwrap();
            assert!(decision.permitted);
        }

        let decision = bucket.check("test-key").await.unwrap();
        assert!(!decision.permitted);
    }

    #[tokio::test(start_paused = true)]
    async fn test_simd_token_bucket_refill() {
        let bucket = SimdTokenBucket::new(10, 100);

        for _ in 0..10 {
            bucket.check("test-key").await.unwrap();
        }

        let decision = bucket.check("test-key").await.unwrap();
        assert!(!decision.permitted);

        tokio::time::advance(Duration::from_millis(100)).await;

        for _ in 0..10 {
            let decision = bucket.check("test-key").await.unwrap();
            assert!(decision.permitted);
        }
    }

    #[tokio::test(start_paused = true)]
    async fn test_check_with_cost_consumes_cost() {
        let bucket = SimdTokenBucket::new(10, 100);

        let decision = bucket.check_with_cost("test-key", 4).await.unwrap();
        assert!(decision.permitted);
        assert_eq!(decision.remaining, Some(6));
        assert!(decision.retry_after.is_none());

        // Asking for more than the 6 remaining tokens is denied and consumes nothing.
        let decision = bucket.check_with_cost("test-key", 7).await.unwrap();
        assert!(!decision.permitted);
        assert_eq!(decision.remaining, Some(6));
        assert!(decision.retry_after.is_some());

        let decision = bucket.check_with_cost("test-key", 6).await.unwrap();
        assert!(decision.permitted);
        assert_eq!(decision.remaining, Some(0));
    }

    #[tokio::test(start_paused = true)]
    async fn test_keys_are_independent() {
        let bucket = SimdTokenBucket::new(2, 100);

        for _ in 0..2 {
            assert!(bucket.check("a").await.unwrap().permitted);
        }
        assert!(!bucket.check("a").await.unwrap().permitted);

        // Draining "a" must not affect "b".
        assert!(bucket.check("b").await.unwrap().permitted);
    }
}