skp_ratelimit/algorithm/
token_bucket.rs

1//! Token Bucket rate limiting algorithm.
2
3use std::time::Duration;
4
5use crate::algorithm::{current_timestamp_ms, timestamp_to_instant, Algorithm};
6use crate::decision::{Decision, DecisionMetadata, RateLimitInfo};
7use crate::error::Result;
8use crate::quota::Quota;
9use crate::storage::{Storage, StorageEntry};
10
/// Token Bucket rate limiter.
///
/// The bucket holds up to the quota's burst size in tokens and is topped up
/// continuously at the quota's refill rate; each admitted request spends one
/// token. This lets short bursts through while bounding the long-run average.
#[derive(Clone, Debug, Default)]
pub struct TokenBucket;
17
18impl TokenBucket {
19    /// Create a new Token Bucket algorithm instance.
20    pub fn new() -> Self {
21        Self
22    }
23
24    /// Calculate token refill based on elapsed time.
25    fn calculate_refill(&self, elapsed_ms: u64, refill_rate: f64) -> f64 {
26        let elapsed_secs = elapsed_ms as f64 / 1000.0;
27        elapsed_secs * refill_rate
28    }
29
30    /// Build rate limit info from current state.
31    fn build_info(&self, tokens: f64, quota: &Quota, now: u64) -> RateLimitInfo {
32        let max_tokens = quota.effective_burst();
33        let remaining = tokens.floor() as u64;
34        let refill_rate = quota.effective_refill_rate();
35
36        let time_to_next_token = if tokens < 1.0 {
37            ((1.0 - tokens) / refill_rate * 1000.0) as u64
38        } else {
39            0
40        };
41
42        let tokens_needed = max_tokens as f64 - tokens;
43        let time_to_full = if tokens_needed > 0.0 {
44            (tokens_needed / refill_rate * 1000.0) as u64
45        } else {
46            0
47        };
48
49        let reset_at = timestamp_to_instant(now + time_to_full);
50        let window_start = timestamp_to_instant(now);
51
52        let mut info = RateLimitInfo::new(max_tokens, remaining, reset_at, window_start)
53            .with_algorithm("token_bucket")
54            .with_metadata(DecisionMetadata::new().with_tokens_available(tokens));
55
56        if remaining == 0 && time_to_next_token > 0 {
57            info = info.with_retry_after(Duration::from_millis(time_to_next_token));
58        }
59
60        info
61    }
62}
63
64impl Algorithm for TokenBucket {
65    fn name(&self) -> &'static str {
66        "token_bucket"
67    }
68
69    async fn check_and_record<S: Storage>(
70        &self,
71        storage: &S,
72        key: &str,
73        quota: &Quota,
74    ) -> Result<Decision> {
75        let now = current_timestamp_ms();
76        let max_tokens = quota.effective_burst() as f64;
77        let refill_rate = quota.effective_refill_rate();
78
79        let ttl_ms = ((max_tokens / refill_rate) * 1000.0 * 2.0) as u64;
80        let ttl = Duration::from_millis(ttl_ms.max(1000));
81
82        let decision = storage
83            .execute_atomic(key, ttl, |entry| {
84                let (mut tokens, last_update) = match entry {
85                    Some(e) => (e.tokens.unwrap_or(max_tokens), e.last_update),
86                    None => (max_tokens, now),
87                };
88
89                if now > last_update {
90                    let elapsed = now - last_update;
91                    let refill = self.calculate_refill(elapsed, refill_rate);
92                    tokens = (tokens + refill).min(max_tokens);
93                }
94
95                if tokens >= 1.0 {
96                    tokens -= 1.0;
97                    let new_entry = StorageEntry::with_tokens(tokens, now);
98                    let info = self.build_info(tokens, quota, now);
99                    (new_entry, Decision::allowed(info))
100                } else {
101                    let new_entry = StorageEntry::with_tokens(tokens, now);
102                    let info = self.build_info(tokens, quota, now);
103                    (new_entry, Decision::denied(info))
104                }
105            })
106            .await?;
107
108        Ok(decision)
109    }
110
111    async fn check<S: Storage>(
112        &self,
113        storage: &S,
114        key: &str,
115        quota: &Quota,
116    ) -> Result<Decision> {
117        let now = current_timestamp_ms();
118        let max_tokens = quota.effective_burst() as f64;
119        let refill_rate = quota.effective_refill_rate();
120
121        let entry = storage.get(key).await?;
122
123        let (mut tokens, last_update) = match entry {
124            Some(e) => (e.tokens.unwrap_or(max_tokens), e.last_update),
125            None => (max_tokens, now),
126        };
127
128        if now > last_update {
129            let elapsed = now - last_update;
130            let refill = self.calculate_refill(elapsed, refill_rate);
131            tokens = (tokens + refill).min(max_tokens);
132        }
133
134        let info = self.build_info(tokens, quota, now);
135
136        Ok(if tokens >= 1.0 {
137            Decision::allowed(info)
138        } else {
139            Decision::denied(info)
140        })
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryStorage;

    /// Record one request against the shared test key and return the verdict.
    async fn hit(storage: &MemoryStorage, quota: &Quota) -> Decision {
        TokenBucket::new()
            .check_and_record(storage, "user:1", quota)
            .await
            .unwrap()
    }

    /// A bucket with burst 5 admits exactly five requests, then refuses.
    #[tokio::test]
    async fn test_token_bucket_basic() {
        let storage = MemoryStorage::new();
        let quota = Quota::per_minute(5).with_burst(5);

        for n in 1..=5 {
            let verdict = hit(&storage, &quota).await;
            assert!(verdict.is_allowed(), "Request {} should be allowed", n);
        }

        assert!(hit(&storage, &quota).await.is_denied());
    }

    /// The burst size, not the per-second rate, bounds the initial spike.
    #[tokio::test]
    async fn test_token_bucket_burst() {
        let storage = MemoryStorage::new();
        let quota = Quota::per_second(1).with_burst(10);

        for n in 1..=10 {
            let verdict = hit(&storage, &quota).await;
            assert!(verdict.is_allowed(), "Burst request {} should be allowed", n);
        }

        assert!(hit(&storage, &quota).await.is_denied());
    }

    /// An exhausted bucket admits again once enough time has passed to refill.
    #[tokio::test]
    async fn test_token_bucket_refill() {
        let storage = MemoryStorage::new();
        let quota = Quota::per_second(10).with_burst(1);

        // Spend the only token; the immediate follow-up must be refused.
        let _ = hit(&storage, &quota).await;
        assert!(hit(&storage, &quota).await.is_denied());

        // With a 10/s quota, 150 ms should accrue more than one token.
        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(hit(&storage, &quota).await.is_allowed());
    }
}