Skip to main content

tokio_rate_limit/algorithm/
simd_token_bucket.rs

1//! SIMD-optimized token bucket implementation.
2//!
3//! This is an experimental implementation exploring SIMD optimizations for token bucket operations.
4//! The goal is to achieve 2-5x performance improvements for certain workloads.
5//!
6//! ## Optimization Strategies
7//!
8//! 1. **Vectorized Token Refill Calculations**: Process multiple buckets in parallel using SIMD
9//! 2. **Batch Key Processing**: Check multiple keys in a single operation
10//! 3. **SIMD-friendly Data Layout**: Structure data for efficient SIMD operations
11//!
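//! ## Platform Support
//!
//! No SIMD code paths are implemented yet: every platform currently runs the scalar
//! implementation. The planned targets are:
//!
//! - x86_64: AVX2 (256-bit vectors, 4x f64 or u64)
//! - ARM: NEON (128-bit vectors, 2x f64 or u64)
//! - Fallback: Scalar implementation on unsupported platforms
//!
//! ## Current Status: EXPERIMENTAL
//!
//! This implementation is for research and benchmarking purposes.
//! It may not be production-ready. It has been deprecated since 0.8.1 because it shows no
//! SIMD benefit; use `TokenBucket` instead.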
22
23use crate::algorithm::Algorithm;
24use crate::error::Result;
25use crate::limiter::RateLimitDecision;
26use async_trait::async_trait;
27use flurry::HashMap as FlurryHashMap;
28use std::sync::atomic::{AtomicU64, Ordering};
29use std::sync::Arc;
30use std::time::Duration;
31use tokio::time::Instant;
32
/// Fixed-point multiplier for sub-token precision: token balances are stored in
/// thousandths of a token (same scale as the baseline bucket).
const SCALE: u64 = 1000;

/// Upper bound on the burst capacity, chosen so that `capacity * SCALE` stays
/// within half of `u64::MAX`.
const MAX_BURST: u64 = u64::MAX / SCALE / 2;

/// Upper bound on the refill rate (tokens per second), chosen so that
/// `rate * SCALE` stays within half of `u64::MAX`.
const MAX_RATE_PER_SEC: u64 = u64::MAX / SCALE / 2;
41
42/// Atomic state for a token bucket (same as baseline for fair comparison)
43struct AtomicTokenState {
44    tokens: AtomicU64,
45    last_refill_nanos: AtomicU64,
46    last_access_nanos: AtomicU64,
47}
48
49impl AtomicTokenState {
50    fn new(capacity: u64, now_nanos: u64) -> Self {
51        Self {
52            tokens: AtomicU64::new(capacity.saturating_mul(SCALE)),
53            last_refill_nanos: AtomicU64::new(now_nanos),
54            last_access_nanos: AtomicU64::new(now_nanos),
55        }
56    }
57
58    /// SIMD-optimized token consumption (for single key - baseline equivalent)
59    ///
60    /// For single key operations, SIMD doesn't help much.
61    /// The real benefit comes from batch operations (see try_consume_batch).
62    fn try_consume(
63        &self,
64        capacity: u64,
65        refill_rate_per_second: u64,
66        now_nanos: u64,
67        cost: u64,
68    ) -> (bool, u64) {
69        self.last_access_nanos.store(now_nanos, Ordering::Relaxed);
70
71        let scaled_capacity = capacity.saturating_mul(SCALE);
72        let token_cost = cost.saturating_mul(SCALE);
73
74        loop {
75            let current_tokens = self.tokens.load(Ordering::Relaxed);
76            let last_refill = self.last_refill_nanos.load(Ordering::Relaxed);
77
78            // SIMD opportunity: Vectorize this calculation for multiple buckets
79            let elapsed_nanos = now_nanos.saturating_sub(last_refill);
80            let elapsed_secs = elapsed_nanos as f64 / 1_000_000_000.0;
81            let tokens_per_sec_scaled = refill_rate_per_second.saturating_mul(SCALE);
82            let new_tokens_to_add = (elapsed_secs * tokens_per_sec_scaled as f64) as u64;
83
84            let updated_tokens = current_tokens
85                .saturating_add(new_tokens_to_add)
86                .min(scaled_capacity);
87
88            if updated_tokens >= token_cost {
89                let new_tokens = updated_tokens.saturating_sub(token_cost);
90                let new_time = if new_tokens_to_add > 0 {
91                    now_nanos
92                } else {
93                    last_refill
94                };
95
96                match self.tokens.compare_exchange_weak(
97                    current_tokens,
98                    new_tokens,
99                    Ordering::AcqRel,
100                    Ordering::Relaxed,
101                ) {
102                    Ok(_) => {
103                        if new_tokens_to_add > 0 {
104                            let _ = self.last_refill_nanos.compare_exchange_weak(
105                                last_refill,
106                                new_time,
107                                Ordering::AcqRel,
108                                Ordering::Relaxed,
109                            );
110                        }
111                        return (true, new_tokens / SCALE);
112                    }
113                    Err(_) => continue,
114                }
115            } else {
116                let new_time = if new_tokens_to_add > 0 {
117                    now_nanos
118                } else {
119                    last_refill
120                };
121
122                match self.tokens.compare_exchange_weak(
123                    current_tokens,
124                    updated_tokens,
125                    Ordering::AcqRel,
126                    Ordering::Relaxed,
127                ) {
128                    Ok(_) => {
129                        if new_tokens_to_add > 0 {
130                            let _ = self.last_refill_nanos.compare_exchange_weak(
131                                last_refill,
132                                new_time,
133                                Ordering::AcqRel,
134                                Ordering::Relaxed,
135                            );
136                        }
137                        return (false, updated_tokens / SCALE);
138                    }
139                    Err(_) => continue,
140                }
141            }
142        }
143    }
144}
145
/// Batch refill calculation result: one entry per input timestamp, in input order.
#[derive(Debug)]
// Only built by the batch refill helper, which is itself `#[allow(dead_code)]`; the fields
// are not read anywhere in this module.
#[allow(dead_code)]
struct BatchRefillResult {
    /// Scaled tokens (`SCALE` units per token) each bucket would gain for its elapsed time.
    new_tokens: Vec<u64>,
    /// Seconds elapsed since each bucket's last refill.
    elapsed_secs: Vec<f64>,
}
153
154/// SIMD-optimized token bucket implementation
155#[deprecated(since = "0.8.1", note = "Experimental — no SIMD benefit. Use TokenBucket instead.")]
156pub struct SimdTokenBucket {
157    capacity: u64,
158    refill_rate_per_second: u64,
159    reference_instant: Instant,
160    idle_ttl: Option<Duration>,
161    tokens: Arc<FlurryHashMap<String, Arc<AtomicTokenState>>>,
162}
163
164impl SimdTokenBucket {
165    /// Creates a new SIMD-optimized token bucket
166    pub fn new(capacity: u64, refill_rate_per_second: u64) -> Self {
167        let safe_capacity = capacity.min(MAX_BURST);
168        let safe_rate = refill_rate_per_second.min(MAX_RATE_PER_SEC);
169
170        Self {
171            capacity: safe_capacity,
172            refill_rate_per_second: safe_rate,
173            reference_instant: Instant::now(),
174            idle_ttl: None,
175            tokens: Arc::new(FlurryHashMap::new()),
176        }
177    }
178
179    /// Creates a token bucket with TTL-based eviction
180    pub fn with_ttl(capacity: u64, refill_rate_per_second: u64, idle_ttl: Duration) -> Self {
181        let mut bucket = Self::new(capacity, refill_rate_per_second);
182        bucket.idle_ttl = Some(idle_ttl);
183        bucket
184    }
185
186    #[inline]
187    fn now_nanos(&self) -> u64 {
188        self.reference_instant.elapsed().as_nanos() as u64
189    }
190
191    /// SIMD-optimized batch refill calculation
192    ///
193    /// NOTE: This is a placeholder for future SIMD implementation.
194    /// The current implementation uses scalar code to avoid unsafe blocks.
195    ///
196    /// Future work: Implement SIMD using portable_simd or std::simd when stabilized.
197    /// For now, this serves as a baseline for comparison and API design.
198    #[allow(dead_code)]
199    fn calculate_refill_batch_simd(
200        &self,
201        last_refills: &[u64],
202        now_nanos: u64,
203    ) -> BatchRefillResult {
204        let mut elapsed_secs = Vec::with_capacity(last_refills.len());
205        let mut new_tokens = Vec::with_capacity(last_refills.len());
206
207        // Scalar implementation - SIMD would process 2-4 values in parallel
208        // TODO: Use portable_simd when stable or add unsafe feature flag
209        for &last_refill in last_refills {
210            let elapsed_nanos = now_nanos.saturating_sub(last_refill);
211            let elapsed = elapsed_nanos as f64 / 1_000_000_000.0;
212            elapsed_secs.push(elapsed);
213
214            let tokens_per_sec_scaled = self.refill_rate_per_second.saturating_mul(SCALE);
215            new_tokens.push((elapsed * tokens_per_sec_scaled as f64) as u64);
216        }
217
218        BatchRefillResult {
219            new_tokens,
220            elapsed_secs,
221        }
222    }
223
224    fn cleanup_idle(&self, now_nanos: u64) {
225        if let Some(ttl) = self.idle_ttl {
226            let ttl_nanos = ttl.as_nanos() as u64;
227            let guard = self.tokens.guard();
228            let keys_to_remove: Vec<String> = self
229                .tokens
230                .iter(&guard)
231                .filter_map(|(key, state)| {
232                    let last_access = state.last_access_nanos.load(Ordering::Relaxed);
233                    let age = now_nanos.saturating_sub(last_access);
234                    if age >= ttl_nanos {
235                        Some(key.clone())
236                    } else {
237                        None
238                    }
239                })
240                .collect();
241
242            for key in keys_to_remove {
243                self.tokens.remove(&key, &guard);
244            }
245        }
246    }
247}
248
249impl super::private::Sealed for SimdTokenBucket {}
250
251#[async_trait]
252impl Algorithm for SimdTokenBucket {
253    async fn check(&self, key: &str) -> Result<RateLimitDecision> {
254        let now = self.now_nanos();
255
256        if self.idle_ttl.is_some() && (now % 100) == 0 {
257            self.cleanup_idle(now);
258        }
259
260        let guard = self.tokens.guard();
261        let key_string = key.to_string();
262        let state = match self.tokens.get(&key_string, &guard) {
263            Some(state) => state.clone(),
264            None => {
265                let new_state = Arc::new(AtomicTokenState::new(self.capacity, now));
266                match self
267                    .tokens
268                    .try_insert(key_string.clone(), new_state.clone(), &guard)
269                {
270                    Ok(_) => new_state,
271                    Err(current) => current.current.clone(),
272                }
273            }
274        };
275
276        let (permitted, remaining) =
277            state.try_consume(self.capacity, self.refill_rate_per_second, now, 1);
278
279        let retry_after = if !permitted {
280            let tokens_needed = 1u64.saturating_sub(remaining);
281            let seconds_to_wait = if self.refill_rate_per_second > 0 {
282                (tokens_needed as f64 / self.refill_rate_per_second as f64).ceil()
283            } else {
284                1.0
285            };
286            Some(Duration::from_secs_f64(seconds_to_wait.max(0.001)))
287        } else {
288            None
289        };
290
291        let reset = if self.refill_rate_per_second > 0 && remaining < self.capacity {
292            let tokens_to_refill = self.capacity.saturating_sub(remaining);
293            let seconds_to_full = tokens_to_refill as f64 / self.refill_rate_per_second as f64;
294            Some(Duration::from_secs_f64(seconds_to_full.max(0.001)))
295        } else if remaining >= self.capacity {
296            Some(Duration::from_secs(0))
297        } else {
298            None
299        };
300
301        Ok(RateLimitDecision {
302            permitted,
303            retry_after,
304            remaining: Some(remaining),
305            limit: self.capacity,
306            reset,
307        })
308    }
309
310    async fn check_with_cost(&self, key: &str, cost: u64) -> Result<RateLimitDecision> {
311        let now = self.now_nanos();
312
313        if self.idle_ttl.is_some() && (now % 100) == 0 {
314            self.cleanup_idle(now);
315        }
316
317        let guard = self.tokens.guard();
318        let key_string = key.to_string();
319        let state = match self.tokens.get(&key_string, &guard) {
320            Some(state) => state.clone(),
321            None => {
322                let new_state = Arc::new(AtomicTokenState::new(self.capacity, now));
323                match self
324                    .tokens
325                    .try_insert(key_string.clone(), new_state.clone(), &guard)
326                {
327                    Ok(_) => new_state,
328                    Err(current) => current.current.clone(),
329                }
330            }
331        };
332
333        let (permitted, remaining) =
334            state.try_consume(self.capacity, self.refill_rate_per_second, now, cost);
335
336        let retry_after = if !permitted {
337            let tokens_needed = cost.saturating_sub(remaining);
338            let seconds_to_wait = if self.refill_rate_per_second > 0 {
339                (tokens_needed as f64 / self.refill_rate_per_second as f64).ceil()
340            } else {
341                1.0
342            };
343            Some(Duration::from_secs_f64(seconds_to_wait.max(0.001)))
344        } else {
345            None
346        };
347
348        let reset = if self.refill_rate_per_second > 0 && remaining < self.capacity {
349            let tokens_to_refill = self.capacity.saturating_sub(remaining);
350            let seconds_to_full = tokens_to_refill as f64 / self.refill_rate_per_second as f64;
351            Some(Duration::from_secs_f64(seconds_to_full.max(0.001)))
352        } else if remaining >= self.capacity {
353            Some(Duration::from_secs(0))
354        } else {
355            None
356        };
357
358        Ok(RateLimitDecision {
359            permitted,
360            retry_after,
361            remaining: Some(remaining),
362            limit: self.capacity,
363            reset,
364        })
365    }
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Key shared by every scenario below.
    const KEY: &str = "test-key";

    /// Issues `n` checks against `key` and asserts that each one is permitted.
    async fn assert_permits(bucket: &SimdTokenBucket, key: &str, n: usize) {
        for _ in 0..n {
            let decision = bucket.check(key).await.unwrap();
            assert!(decision.permitted);
        }
    }

    #[tokio::test]
    async fn test_simd_token_bucket_basic() {
        let bucket = SimdTokenBucket::new(10, 100);

        // The full burst is granted, and the next request is refused.
        assert_permits(&bucket, KEY, 10).await;
        assert!(!bucket.check(KEY).await.unwrap().permitted);
    }

    #[tokio::test(start_paused = true)]
    async fn test_simd_token_bucket_refill() {
        let bucket = SimdTokenBucket::new(10, 100);

        // Drain the burst; the individual outcomes are not asserted here.
        for _ in 0..10 {
            bucket.check(KEY).await.unwrap();
        }
        assert!(!bucket.check(KEY).await.unwrap().permitted);

        // 100 ms at 100 tokens/s refills all 10 tokens.
        tokio::time::advance(Duration::from_millis(100)).await;
        assert_permits(&bucket, KEY, 10).await;
    }
}