tokio_rate_limit/algorithm/
leaky_bucket.rs

1//! Leaky bucket rate limiting algorithm implementation.
2
3use crate::algorithm::Algorithm;
4use crate::error::Result;
5use crate::limiter::RateLimitDecision;
6use async_trait::async_trait;
7use flurry::HashMap as FlurryHashMap;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::time::Instant;
12
/// Scaling factor for sub-token precision.
///
/// Bucket levels are stored in units of `1 / SCALE` of a token, so whole-token
/// quantities (capacity, cost, leak rate) are multiplied by this value before use.
const SCALE: u64 = 1000;

/// Maximum capacity to prevent overflow in scaled arithmetic.
///
/// Chosen so that `capacity * SCALE` never exceeds `u64::MAX / 2`.
const MAX_CAPACITY: u64 = u64::MAX / (2 * SCALE);

/// Maximum leak rate per second to prevent overflow.
///
/// Chosen so that `leak_rate_per_second * SCALE` never exceeds `u64::MAX / 2`.
const MAX_LEAK_RATE: u64 = u64::MAX / (2 * SCALE);
22
/// Atomic state for a single key's leaky bucket.
///
/// Requests add tokens to the bucket and the level drains ("leaks") at a constant
/// rate over time. A request is denied when adding its cost would push the level
/// above the configured capacity.
///
/// All timestamps are plain `u64` nanosecond values supplied by the caller; in this
/// file they are offsets from `LeakyBucket::reference_instant`.
struct AtomicLeakyState {
    /// Current bucket level, scaled by `SCALE` (1000) for sub-token precision.
    /// Lower is better - represents admitted cost that hasn't "leaked" out yet.
    tokens: AtomicU64,

    /// Timestamp (nanoseconds) up to which leaking has already been applied to `tokens`.
    last_leak_nanos: AtomicU64,

    /// Timestamp (nanoseconds) of the most recent `try_add` call, used for TTL-based eviction.
    last_access_nanos: AtomicU64,
}
39
40impl AtomicLeakyState {
41    /// Creates a new leaky bucket state starting empty.
42    fn new(now_nanos: u64) -> Self {
43        Self {
44            tokens: AtomicU64::new(0), // Start empty
45            last_leak_nanos: AtomicU64::new(now_nanos),
46            last_access_nanos: AtomicU64::new(now_nanos),
47        }
48    }
49
50    /// Attempts to add tokens to the bucket (i.e., make a request).
51    ///
52    /// This method performs automatic leaking based on elapsed time and uses
53    /// lock-free compare-and-swap loops for token updates.
54    ///
55    /// # Arguments
56    ///
57    /// * `capacity` - Maximum bucket capacity (water level)
58    /// * `leak_rate_per_second` - Rate at which tokens leak out
59    /// * `now_nanos` - Current time in nanoseconds
60    /// * `cost` - Number of tokens to add (request cost)
61    ///
62    /// Returns `(permitted, remaining_capacity)`
63    fn try_add(
64        &self,
65        capacity: u64,
66        leak_rate_per_second: u64,
67        now_nanos: u64,
68        cost: u64,
69    ) -> (bool, u64) {
70        // Update last access time (for TTL tracking)
71        self.last_access_nanos.store(now_nanos, Ordering::Relaxed);
72
73        // Scale capacity and cost for precision
74        let scaled_capacity = capacity.saturating_mul(SCALE);
75        let token_cost = cost.saturating_mul(SCALE);
76
77        loop {
78            // Load current state
79            let current_tokens = self.tokens.load(Ordering::Relaxed);
80            let last_leak = self.last_leak_nanos.load(Ordering::Relaxed);
81
82            // Calculate elapsed time and tokens to leak out
83            let elapsed_nanos = now_nanos.saturating_sub(last_leak);
84            let elapsed_secs = elapsed_nanos as f64 / 1_000_000_000.0;
85
86            // Calculate tokens to leak (remove from bucket)
87            let tokens_per_sec_scaled = leak_rate_per_second.saturating_mul(SCALE);
88            let tokens_to_leak = (elapsed_secs * tokens_per_sec_scaled as f64) as u64;
89
90            // Calculate updated token count after leaking (can't go below 0)
91            let leaked_tokens = current_tokens.saturating_sub(tokens_to_leak);
92
93            // Try to add the requested cost
94            let new_tokens = leaked_tokens.saturating_add(token_cost);
95
96            if new_tokens <= scaled_capacity {
97                // Request fits in bucket, allow it
98                let new_time = if tokens_to_leak > 0 {
99                    now_nanos
100                } else {
101                    last_leak
102                };
103
104                // Try to update tokens atomically
105                match self.tokens.compare_exchange_weak(
106                    current_tokens,
107                    new_tokens,
108                    Ordering::AcqRel,
109                    Ordering::Relaxed,
110                ) {
111                    Ok(_) => {
112                        // Successfully updated tokens, now update time if needed
113                        if tokens_to_leak > 0 {
114                            let _ = self.last_leak_nanos.compare_exchange_weak(
115                                last_leak,
116                                new_time,
117                                Ordering::AcqRel,
118                                Ordering::Relaxed,
119                            );
120                        }
121                        // Return remaining capacity
122                        let remaining_capacity = (scaled_capacity - new_tokens) / SCALE;
123                        return (true, remaining_capacity);
124                    }
125                    Err(_) => {
126                        // Another thread modified it, retry
127                        continue;
128                    }
129                }
130            } else {
131                // Request would overflow bucket, deny it
132                // But still update state to reflect leaking
133                let new_time = if tokens_to_leak > 0 {
134                    now_nanos
135                } else {
136                    last_leak
137                };
138
139                match self.tokens.compare_exchange_weak(
140                    current_tokens,
141                    leaked_tokens,
142                    Ordering::AcqRel,
143                    Ordering::Relaxed,
144                ) {
145                    Ok(_) => {
146                        if tokens_to_leak > 0 {
147                            let _ = self.last_leak_nanos.compare_exchange_weak(
148                                last_leak,
149                                new_time,
150                                Ordering::AcqRel,
151                                Ordering::Relaxed,
152                            );
153                        }
154                        // Return current capacity (how much room is left)
155                        let remaining_capacity = (scaled_capacity - leaked_tokens) / SCALE;
156                        return (false, remaining_capacity);
157                    }
158                    Err(_) => {
159                        // Another thread modified it, retry
160                        continue;
161                    }
162                }
163            }
164        }
165    }
166}
167
/// Leaky bucket rate limiting algorithm.
///
/// Each key owns a bucket that starts empty. A request adds its cost to the bucket and
/// is denied when that would push the level above `capacity`; the level drains ("leaks")
/// continuously at `leak_rate_per_second`.
///
/// # Algorithm Details
///
/// - **Capacity**: Maximum bucket size (water level)
/// - **Leak Rate**: Tokens removed per second (steady outflow)
/// - **Request Handling**: Each request adds tokens; if bucket would overflow, request is denied
/// - **Bursts**: From an empty bucket, up to `capacity` tokens can be admitted back-to-back;
///   after that, admissions are paced by the leak rate
///
/// # Comparison with Token Bucket
///
/// | Feature | Token Bucket | Leaky Bucket |
/// |---------|--------------|--------------|
/// | **Bursts** | Allowed up to capacity | Bounded by `capacity`, then paced by leak rate |
/// | **Rate Enforcement** | Average rate over time | Sustained rate capped at the leak rate |
/// | **Traffic Pattern** | Bursty | Smooth once the bucket is full |
/// | **Use Case** | Public APIs, user requests | Backend protection, QPS limits |
///
/// # Use Cases
///
/// - **Backend Protection**: Prevent overwhelming downstream services with consistent load
/// - **QPS Enforcement**: Cap the sustained rate at N requests/sec (choose a small
///   `capacity` to keep the initial burst small)
/// - **Traffic Limiting**: Excess requests are denied, not queued
/// - **Fairness**: State is tracked per key, so one key filling its bucket does not
///   affect other keys
///
/// # Performance
///
/// - Per-key state is updated with atomics and compare-and-swap; keys live in a
///   concurrent `flurry` hash map
/// - Expected: Similar performance to TokenBucket (15M+ ops/sec single-threaded);
///   this figure is not verified by anything in this file
///
/// # Examples
///
/// ```
/// use tokio_rate_limit::algorithm::LeakyBucket;
/// use tokio_rate_limit::RateLimiter;
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// // 100 requests/sec steady rate, capacity of 50
/// let algorithm = LeakyBucket::new(50, 100);
/// let limiter = RateLimiter::from_algorithm(algorithm);
///
/// // Denied once this key's bucket has no room for one more token
/// let decision = limiter.check("client-123").await?;
/// # Ok(())
/// # }
/// ```
pub struct LeakyBucket {
    /// Maximum tokens the bucket can hold (already clamped to `MAX_CAPACITY` by `new`)
    capacity: u64,

    /// Number of tokens leaked per second (already clamped to `MAX_LEAK_RATE` by `new`)
    leak_rate_per_second: u64,

    /// Reference instant for time measurements; captured at construction, and every
    /// nanosecond timestamp in this module is an offset from it.
    reference_instant: Instant,

    /// Time-to-live for idle keys; `None` disables eviction.
    idle_ttl: Option<Duration>,

    /// Per-key leaky bucket state.
    buckets: Arc<FlurryHashMap<String, Arc<AtomicLeakyState>>>,
}
235
236impl LeakyBucket {
237    /// Creates a new leaky bucket with the specified capacity and leak rate.
238    ///
239    /// # Arguments
240    ///
241    /// * `capacity` - Maximum bucket size, clamped to MAX_CAPACITY if exceeded
242    /// * `leak_rate_per_second` - Tokens leaked per second, clamped to MAX_LEAK_RATE if exceeded
243    ///
244    /// # Examples
245    ///
246    /// ```
247    /// use tokio_rate_limit::algorithm::LeakyBucket;
248    ///
249    /// // 100 requests/sec with capacity of 50
250    /// let bucket = LeakyBucket::new(50, 100);
251    /// ```
252    pub fn new(capacity: u64, leak_rate_per_second: u64) -> Self {
253        let safe_capacity = capacity.min(MAX_CAPACITY);
254        let safe_rate = leak_rate_per_second.min(MAX_LEAK_RATE);
255
256        Self {
257            capacity: safe_capacity,
258            leak_rate_per_second: safe_rate,
259            reference_instant: Instant::now(),
260            idle_ttl: None,
261            buckets: Arc::new(FlurryHashMap::new()),
262        }
263    }
264
265    /// Creates a new leaky bucket with TTL-based eviction.
266    ///
267    /// # Arguments
268    ///
269    /// * `capacity` - Maximum bucket size
270    /// * `leak_rate_per_second` - Tokens leaked per second
271    /// * `idle_ttl` - Duration after which idle keys are evicted
272    ///
273    /// # Examples
274    ///
275    /// ```
276    /// use tokio_rate_limit::algorithm::LeakyBucket;
277    /// use std::time::Duration;
278    ///
279    /// // Evict keys idle for more than 1 hour
280    /// let bucket = LeakyBucket::with_ttl(50, 100, Duration::from_secs(3600));
281    /// ```
282    pub fn with_ttl(capacity: u64, leak_rate_per_second: u64, idle_ttl: Duration) -> Self {
283        let mut bucket = Self::new(capacity, leak_rate_per_second);
284        bucket.idle_ttl = Some(idle_ttl);
285        bucket
286    }
287
288    /// Get current time in nanoseconds since the reference instant.
289    #[inline]
290    fn now_nanos(&self) -> u64 {
291        self.reference_instant.elapsed().as_nanos() as u64
292    }
293
294    /// Cleanup idle keys based on TTL configuration.
295    fn cleanup_idle(&self, now_nanos: u64) {
296        if let Some(ttl) = self.idle_ttl {
297            let ttl_nanos = ttl.as_nanos() as u64;
298
299            let guard = self.buckets.guard();
300            let keys_to_remove: Vec<String> = self
301                .buckets
302                .iter(&guard)
303                .filter_map(|(key, state)| {
304                    let last_access = state.last_access_nanos.load(Ordering::Relaxed);
305                    let age = now_nanos.saturating_sub(last_access);
306                    if age >= ttl_nanos {
307                        Some(key.clone())
308                    } else {
309                        None
310                    }
311                })
312                .collect();
313
314            for key in keys_to_remove {
315                self.buckets.remove(&key, &guard);
316            }
317        }
318    }
319}
320
// Implement the sealed trait marker from the parent module's `private` module.
// NOTE(review): presumably `Algorithm` requires `Sealed` so that only types inside this
// crate can implement it - confirm against the parent module.
impl super::private::Sealed for LeakyBucket {}
323
324#[async_trait]
325impl Algorithm for LeakyBucket {
326    async fn check(&self, key: &str) -> Result<RateLimitDecision> {
327        let now = self.now_nanos();
328
329        // Probabilistic cleanup (1% of the time)
330        if self.idle_ttl.is_some() && (now % 100) == 0 {
331            self.cleanup_idle(now);
332        }
333
334        // Get or create bucket state for this key
335        let guard = self.buckets.guard();
336        let key_string = key.to_string();
337        let state = match self.buckets.get(&key_string, &guard) {
338            Some(state) => state.clone(),
339            None => {
340                let new_state = Arc::new(AtomicLeakyState::new(now));
341                match self
342                    .buckets
343                    .try_insert(key_string.clone(), new_state.clone(), &guard)
344                {
345                    Ok(_) => new_state,
346                    Err(current) => current.current.clone(),
347                }
348            }
349        };
350
351        // Try to add a token (cost of 1)
352        let (permitted, remaining_capacity) =
353            state.try_add(self.capacity, self.leak_rate_per_second, now, 1);
354
355        // Calculate retry_after if rate limited
356        let retry_after = if !permitted {
357            // We need to wait for enough tokens to leak out
358            // At leak_rate_per_second, time to leak 1 token is 1/leak_rate_per_second
359            let seconds_to_wait = if self.leak_rate_per_second > 0 {
360                1.0 / self.leak_rate_per_second as f64
361            } else {
362                1.0
363            };
364            Some(Duration::from_secs_f64(seconds_to_wait.max(0.001)))
365        } else {
366            None
367        };
368
369        // Calculate reset time (time until bucket is empty)
370        // Current bucket level is (capacity - remaining_capacity)
371        // Time to empty = current_level / leak_rate_per_second
372        let current_level = self.capacity.saturating_sub(remaining_capacity);
373        let reset = if self.leak_rate_per_second > 0 && current_level > 0 {
374            let seconds_to_empty = current_level as f64 / self.leak_rate_per_second as f64;
375            Some(Duration::from_secs_f64(seconds_to_empty.max(0.001)))
376        } else {
377            Some(Duration::from_secs(0))
378        };
379
380        Ok(RateLimitDecision {
381            permitted,
382            retry_after,
383            remaining: Some(remaining_capacity),
384            limit: self.capacity,
385            reset,
386        })
387    }
388
389    async fn check_with_cost(&self, key: &str, cost: u64) -> Result<RateLimitDecision> {
390        let now = self.now_nanos();
391
392        // Probabilistic cleanup (1% of the time)
393        if self.idle_ttl.is_some() && (now % 100) == 0 {
394            self.cleanup_idle(now);
395        }
396
397        // Get or create bucket state for this key
398        let guard = self.buckets.guard();
399        let key_string = key.to_string();
400        let state = match self.buckets.get(&key_string, &guard) {
401            Some(state) => state.clone(),
402            None => {
403                let new_state = Arc::new(AtomicLeakyState::new(now));
404                match self
405                    .buckets
406                    .try_insert(key_string.clone(), new_state.clone(), &guard)
407                {
408                    Ok(_) => new_state,
409                    Err(current) => current.current.clone(),
410                }
411            }
412        };
413
414        // Try to add tokens with specified cost
415        let (permitted, remaining_capacity) =
416            state.try_add(self.capacity, self.leak_rate_per_second, now, cost);
417
418        // Calculate retry_after if rate limited
419        let retry_after = if !permitted {
420            // Time to leak enough tokens for this cost
421            let seconds_to_wait = if self.leak_rate_per_second > 0 {
422                cost as f64 / self.leak_rate_per_second as f64
423            } else {
424                1.0
425            };
426            Some(Duration::from_secs_f64(seconds_to_wait.max(0.001)))
427        } else {
428            None
429        };
430
431        // Calculate reset time
432        let current_level = self.capacity.saturating_sub(remaining_capacity);
433        let reset = if self.leak_rate_per_second > 0 && current_level > 0 {
434            let seconds_to_empty = current_level as f64 / self.leak_rate_per_second as f64;
435            Some(Duration::from_secs_f64(seconds_to_empty.max(0.001)))
436        } else {
437            Some(Duration::from_secs(0))
438        };
439
440        Ok(RateLimitDecision {
441            permitted,
442            retry_after,
443            remaining: Some(remaining_capacity),
444            limit: self.capacity,
445            reset,
446        })
447    }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    // Every test runs on a paused Tokio clock: time only moves when a test calls
    // `tokio::time::advance`, so no tokens leak between calls because of scheduling
    // delays and the outcomes do not depend on wall-clock speed.

    #[tokio::test(start_paused = true)]
    async fn test_leaky_bucket_basic() {
        let bucket = LeakyBucket::new(10, 100);

        // First request should succeed (bucket starts empty)
        let decision = bucket.check("test-key").await.unwrap();
        assert!(decision.permitted, "First request should be permitted");

        // Send 20 more requests without letting any time pass.
        let mut permitted_count = 1;
        for _ in 0..20 {
            let decision = bucket.check("test-key").await.unwrap();
            if decision.permitted {
                permitted_count += 1;
            }
        }

        // Capacity is 10 and nothing leaked, so exactly 10 of the 21 requests fit.
        assert_eq!(
            permitted_count, 10,
            "Expected exactly capacity (10) permitted requests, but allowed {}",
            permitted_count
        );
    }

    #[tokio::test(start_paused = true)]
    async fn test_leaky_bucket_leak_rate() {
        let bucket = LeakyBucket::new(5, 10); // 5 capacity, 10 per second leak rate

        // Fill the bucket to capacity
        for i in 0..5 {
            let decision = bucket.check("test-key").await.unwrap();
            assert!(decision.permitted, "Request {} should be permitted", i + 1);
        }

        // Should be rate limited now (bucket full)
        let decision = bucket.check("test-key").await.unwrap();
        assert!(!decision.permitted, "Should be rate limited when full");

        // Wait 100ms (should leak 1 token at 10/sec)
        tokio::time::advance(Duration::from_millis(100)).await;

        // Should work again
        let decision = bucket.check("test-key").await.unwrap();
        assert!(decision.permitted, "Request should be permitted after leak");
    }

    #[tokio::test(start_paused = true)]
    async fn test_leaky_bucket_multiple_keys() {
        let bucket = LeakyBucket::new(2, 10);

        // Key 1: fill bucket
        bucket.check("key1").await.unwrap();
        bucket.check("key1").await.unwrap();
        let decision = bucket.check("key1").await.unwrap();
        assert!(!decision.permitted, "key1 should be rate limited");

        // Key 2: should still work (separate bucket)
        let decision = bucket.check("key2").await.unwrap();
        assert!(decision.permitted, "key2 should be permitted");
    }

    #[tokio::test(start_paused = true)]
    async fn test_leaky_bucket_cost() {
        let bucket = LeakyBucket::new(10, 10);

        // Request with cost 5 should work
        let decision = bucket.check_with_cost("test-key", 5).await.unwrap();
        assert!(decision.permitted, "Cost 5 request should be permitted");

        // Request with cost 6 should fail (5 + 6 > 10)
        let decision = bucket.check_with_cost("test-key", 6).await.unwrap();
        assert!(!decision.permitted, "Cost 6 request should be denied");

        // Request with cost 5 should still work (the denied request added nothing)
        let decision = bucket.check_with_cost("test-key", 5).await.unwrap();
        assert!(decision.permitted, "Cost 5 request should still work");
    }

    #[tokio::test(start_paused = true)]
    async fn test_leaky_bucket_ttl() {
        let bucket = LeakyBucket::with_ttl(10, 100, Duration::from_secs(1));

        // Access key1
        bucket.check("key1").await.unwrap();
        assert_eq!(bucket.buckets.len(), 1);

        // Advance time past TTL
        tokio::time::advance(Duration::from_secs(2)).await;

        // Access key2 multiple times to give the sampled cleanup a chance to run
        for _ in 0..200 {
            bucket.check("key2").await.unwrap();
        }

        // key2 is always present; key1 is gone if a cleanup pass ran. The cleanup is
        // sampled, so both outcomes are accepted here.
        let count = bucket.buckets.len();
        assert!(
            (1..=2).contains(&count),
            "Expected 1-2 keys after TTL, got {}",
            count
        );
    }
}