skp_ratelimit/algorithm/
token_bucket.rs1use std::time::Duration;
4
5use crate::algorithm::{current_timestamp_ms, timestamp_to_instant, Algorithm};
6use crate::decision::{Decision, DecisionMetadata, RateLimitInfo};
7use crate::error::Result;
8use crate::quota::Quota;
9use crate::storage::{Storage, StorageEntry};
10
/// Token-bucket rate-limiting algorithm.
///
/// The type carries no fields of its own: the per-key bucket state (a fractional token
/// count plus the timestamp it was last updated) lives in the storage backend passed to
/// each call.
#[derive(Clone, Debug, Default)]
pub struct TokenBucket;

18impl TokenBucket {
19 pub fn new() -> Self {
21 Self
22 }
23
24 fn calculate_refill(&self, elapsed_ms: u64, refill_rate: f64) -> f64 {
26 let elapsed_secs = elapsed_ms as f64 / 1000.0;
27 elapsed_secs * refill_rate
28 }
29
30 fn build_info(&self, tokens: f64, quota: &Quota, now: u64) -> RateLimitInfo {
32 let max_tokens = quota.effective_burst();
33 let remaining = tokens.floor() as u64;
34 let refill_rate = quota.effective_refill_rate();
35
36 let time_to_next_token = if tokens < 1.0 {
37 ((1.0 - tokens) / refill_rate * 1000.0) as u64
38 } else {
39 0
40 };
41
42 let tokens_needed = max_tokens as f64 - tokens;
43 let time_to_full = if tokens_needed > 0.0 {
44 (tokens_needed / refill_rate * 1000.0) as u64
45 } else {
46 0
47 };
48
49 let reset_at = timestamp_to_instant(now + time_to_full);
50 let window_start = timestamp_to_instant(now);
51
52 let mut info = RateLimitInfo::new(max_tokens, remaining, reset_at, window_start)
53 .with_algorithm("token_bucket")
54 .with_metadata(DecisionMetadata::new().with_tokens_available(tokens));
55
56 if remaining == 0 && time_to_next_token > 0 {
57 info = info.with_retry_after(Duration::from_millis(time_to_next_token));
58 }
59
60 info
61 }
62}
63
64impl Algorithm for TokenBucket {
65 fn name(&self) -> &'static str {
66 "token_bucket"
67 }
68
69 async fn check_and_record<S: Storage>(
70 &self,
71 storage: &S,
72 key: &str,
73 quota: &Quota,
74 ) -> Result<Decision> {
75 let now = current_timestamp_ms();
76 let max_tokens = quota.effective_burst() as f64;
77 let refill_rate = quota.effective_refill_rate();
78
79 let ttl_ms = ((max_tokens / refill_rate) * 1000.0 * 2.0) as u64;
80 let ttl = Duration::from_millis(ttl_ms.max(1000));
81
82 let decision = storage
83 .execute_atomic(key, ttl, |entry| {
84 let (mut tokens, last_update) = match entry {
85 Some(e) => (e.tokens.unwrap_or(max_tokens), e.last_update),
86 None => (max_tokens, now),
87 };
88
89 if now > last_update {
90 let elapsed = now - last_update;
91 let refill = self.calculate_refill(elapsed, refill_rate);
92 tokens = (tokens + refill).min(max_tokens);
93 }
94
95 if tokens >= 1.0 {
96 tokens -= 1.0;
97 let new_entry = StorageEntry::with_tokens(tokens, now);
98 let info = self.build_info(tokens, quota, now);
99 (new_entry, Decision::allowed(info))
100 } else {
101 let new_entry = StorageEntry::with_tokens(tokens, now);
102 let info = self.build_info(tokens, quota, now);
103 (new_entry, Decision::denied(info))
104 }
105 })
106 .await?;
107
108 Ok(decision)
109 }
110
111 async fn check<S: Storage>(
112 &self,
113 storage: &S,
114 key: &str,
115 quota: &Quota,
116 ) -> Result<Decision> {
117 let now = current_timestamp_ms();
118 let max_tokens = quota.effective_burst() as f64;
119 let refill_rate = quota.effective_refill_rate();
120
121 let entry = storage.get(key).await?;
122
123 let (mut tokens, last_update) = match entry {
124 Some(e) => (e.tokens.unwrap_or(max_tokens), e.last_update),
125 None => (max_tokens, now),
126 };
127
128 if now > last_update {
129 let elapsed = now - last_update;
130 let refill = self.calculate_refill(elapsed, refill_rate);
131 tokens = (tokens + refill).min(max_tokens);
132 }
133
134 let info = self.build_info(tokens, quota, now);
135
136 Ok(if tokens >= 1.0 {
137 Decision::allowed(info)
138 } else {
139 Decision::denied(info)
140 })
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryStorage;

    /// A bucket of 5 allows exactly 5 back-to-back requests, then denies.
    #[tokio::test]
    async fn test_token_bucket_basic() {
        let algorithm = TokenBucket::new();
        let storage = MemoryStorage::new();
        let quota = Quota::per_minute(5).with_burst(5);

        for i in 1..=5 {
            let decision = algorithm.check_and_record(&storage, "user:1", &quota).await.unwrap();
            assert!(decision.is_allowed(), "Request {} should be allowed", i);
        }

        let decision = algorithm.check_and_record(&storage, "user:1", &quota).await.unwrap();
        assert!(decision.is_denied());
    }

    /// Burst capacity can exceed the per-second rate.
    #[tokio::test]
    async fn test_token_bucket_burst() {
        let algorithm = TokenBucket::new();
        let storage = MemoryStorage::new();
        let quota = Quota::per_second(1).with_burst(10);

        for i in 1..=10 {
            let decision = algorithm.check_and_record(&storage, "user:1", &quota).await.unwrap();
            assert!(decision.is_allowed(), "Burst request {} should be allowed", i);
        }

        let decision = algorithm.check_and_record(&storage, "user:1", &quota).await.unwrap();
        assert!(decision.is_denied());
    }

    /// At 10 tokens/s one token is back after ~100ms; sleeping 150ms leaves margin.
    #[tokio::test]
    async fn test_token_bucket_refill() {
        let algorithm = TokenBucket::new();
        let storage = MemoryStorage::new();
        let quota = Quota::per_second(10).with_burst(1);

        algorithm.check_and_record(&storage, "user:1", &quota).await.unwrap();

        let decision = algorithm.check_and_record(&storage, "user:1", &quota).await.unwrap();
        assert!(decision.is_denied());

        tokio::time::sleep(Duration::from_millis(150)).await;

        let decision = algorithm.check_and_record(&storage, "user:1", &quota).await.unwrap();
        assert!(decision.is_allowed());
    }

    /// `check` only peeks at the bucket; only `check_and_record` spends tokens.
    #[tokio::test]
    async fn test_token_bucket_check_does_not_consume() {
        let algorithm = TokenBucket::new();
        let storage = MemoryStorage::new();
        // One token refilled at one per minute, so no meaningful refill happens mid-test.
        let quota = Quota::per_minute(1).with_burst(1);

        for _ in 0..3 {
            let decision = algorithm.check(&storage, "user:1", &quota).await.unwrap();
            assert!(decision.is_allowed());
        }

        let decision = algorithm.check_and_record(&storage, "user:1", &quota).await.unwrap();
        assert!(decision.is_allowed());

        let decision = algorithm.check(&storage, "user:1", &quota).await.unwrap();
        assert!(decision.is_denied());
    }
}