rust_serv/throttle/
token_bucket.rs1use std::time::{Duration, Instant};
4
/// A token-bucket rate limiter.
///
/// The balance is stored as a float so that partial tokens earned between
/// calls accumulate; the timestamp records when the balance was last brought
/// up to date.
#[derive(Debug)]
pub struct TokenBucket {
    /// Upper bound on the number of tokens the bucket can hold.
    capacity: u64,
    /// Current balance; fractional so sub-token refills are not lost.
    tokens: f64,
    /// Tokens regained per second of elapsed time.
    refill_rate: f64,
    /// Moment up to which elapsed time has already been credited to `tokens`.
    last_refill: Instant,
}
17
18impl TokenBucket {
19 pub fn new(capacity: u64, refill_rate: u64) -> Self {
21 Self {
22 capacity,
23 tokens: capacity as f64,
24 refill_rate: refill_rate as f64,
25 last_refill: Instant::now(),
26 }
27 }
28
29 pub fn capacity(&self) -> u64 {
31 self.capacity
32 }
33
34 pub fn refill_rate(&self) -> u64 {
36 self.refill_rate as u64
37 }
38
39 pub fn tokens(&mut self) -> u64 {
41 self.refill();
42 self.tokens as u64
43 }
44
45 fn refill(&mut self) {
47 let now = Instant::now();
48 let elapsed = now.duration_since(self.last_refill);
49 let tokens_to_add = elapsed.as_secs_f64() * self.refill_rate;
50
51 self.tokens = (self.tokens + tokens_to_add).min(self.capacity as f64);
52 self.last_refill = now;
53 }
54
55 pub fn consume(&mut self, requested: u64) -> u64 {
58 self.refill();
59
60 if self.tokens >= requested as f64 {
61 self.tokens -= requested as f64;
62 requested
63 } else {
64 let consumed = self.tokens as u64;
65 self.tokens = 0.0;
66 consumed
67 }
68 }
69
70 pub fn try_consume(&mut self, amount: u64) -> bool {
73 self.refill();
74
75 if self.tokens >= amount as f64 {
76 self.tokens -= amount as f64;
77 true
78 } else {
79 false
80 }
81 }
82
83 pub fn wait_time(&mut self, amount: u64) -> Duration {
86 self.refill();
87
88 if self.tokens >= amount as f64 {
89 return Duration::ZERO;
90 }
91
92 let tokens_needed = amount as f64 - self.tokens;
93 let wait_secs = tokens_needed / self.refill_rate;
94
95 Duration::from_secs_f64(wait_secs)
96 }
97
98 pub fn reset(&mut self) {
100 self.tokens = self.capacity as f64;
101 self.last_refill = Instant::now();
102 }
103
104 pub fn set_refill_rate(&mut self, rate: u64) {
106 self.refill_rate = rate as f64;
107 }
108
109 pub fn set_capacity(&mut self, capacity: u64) {
111 self.capacity = capacity;
112 self.tokens = self.tokens.min(capacity as f64);
113 }
114
115 pub fn has_tokens(&mut self, amount: u64) -> bool {
117 self.refill();
118 self.tokens >= amount as f64
119 }
120
121 pub fn fill_level(&mut self) -> f64 {
123 self.refill();
124 self.tokens / self.capacity as f64
125 }
126}
127
128impl Clone for TokenBucket {
129 fn clone(&self) -> Self {
130 Self {
131 capacity: self.capacity,
132 tokens: self.tokens,
133 refill_rate: self.refill_rate,
134 last_refill: Instant::now(),
135 }
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructor arguments are reported back by the getters.
    #[test]
    fn test_bucket_creation() {
        let tb = TokenBucket::new(1000, 100);
        assert_eq!(tb.capacity(), 1000);
        assert_eq!(tb.refill_rate(), 100);
    }

    /// A new bucket holds `capacity` tokens.
    #[test]
    fn test_bucket_starts_full() {
        let mut tb = TokenBucket::new(1000, 100);
        assert_eq!(tb.tokens(), 1000);
    }

    /// A request within the balance is granted in full.
    #[test]
    fn test_consume_success() {
        let mut tb = TokenBucket::new(1000, 100);

        let granted = tb.consume(500);

        assert_eq!(granted, 500);
        assert_eq!(tb.tokens(), 500);
    }

    /// A request above the balance is granted only what is available.
    #[test]
    fn test_consume_partial() {
        let mut tb = TokenBucket::new(100, 10);

        let granted = tb.consume(150);

        assert_eq!(granted, 100);
        assert_eq!(tb.tokens(), 0);
    }

    /// Consuming from a drained bucket grants nothing.
    #[test]
    fn test_consume_empty_bucket() {
        let mut tb = TokenBucket::new(100, 10);

        tb.consume(100);
        assert_eq!(tb.tokens(), 0);

        let granted = tb.consume(50);
        assert_eq!(granted, 0);
    }

    /// `try_consume` succeeds and deducts when enough tokens exist.
    #[test]
    fn test_try_consume_success() {
        let mut tb = TokenBucket::new(1000, 100);

        assert!(tb.try_consume(500));
        assert_eq!(tb.tokens(), 500);
    }

    /// `try_consume` fails without touching the balance when short.
    #[test]
    fn test_try_consume_fail() {
        let mut tb = TokenBucket::new(100, 10);

        assert!(!tb.try_consume(150));
        assert_eq!(tb.tokens(), 100);
    }

    /// Tokens come back over time at the configured rate.
    #[test]
    fn test_refill() {
        let mut tb = TokenBucket::new(1000, 1000);
        tb.consume(500);
        assert_eq!(tb.tokens(), 500);

        std::thread::sleep(Duration::from_millis(100));

        let tokens = tb.tokens();
        // Loose upper bound leaves room for scheduler jitter during the sleep.
        assert!((501..800).contains(&tokens), "Expected ~600 tokens, got {}", tokens);
    }

    /// No wait is reported when the request can be served immediately.
    #[test]
    fn test_wait_time_zero() {
        let mut tb = TokenBucket::new(1000, 100);

        let wait = tb.wait_time(500);

        assert_eq!(wait, Duration::ZERO);
    }

    /// A drained bucket at 100 tokens/s needs about a second for 100 tokens.
    #[test]
    fn test_wait_time_needed() {
        let mut tb = TokenBucket::new(100, 100);
        tb.consume(100);

        let wait = tb.wait_time(100);

        let expected = Duration::from_millis(900)..=Duration::from_millis(1100);
        assert!(expected.contains(&wait));
    }

    /// `reset` restores a full balance.
    #[test]
    fn test_reset() {
        let mut tb = TokenBucket::new(1000, 100);

        tb.consume(500);
        assert_eq!(tb.tokens(), 500);

        tb.reset();
        assert_eq!(tb.tokens(), 1000);
    }

    /// A changed refill rate is visible through the getter.
    #[test]
    fn test_set_refill_rate() {
        let mut tb = TokenBucket::new(1000, 100);

        tb.set_refill_rate(200);

        assert_eq!(tb.refill_rate(), 200);
    }

    /// Shrinking the capacity below the balance clamps the balance.
    #[test]
    fn test_set_capacity() {
        let mut tb = TokenBucket::new(1000, 100);
        tb.consume(500);

        tb.set_capacity(300);

        assert_eq!(tb.capacity(), 300);
        assert_eq!(tb.tokens(), 300);
    }

    /// `has_tokens` compares against the balance without consuming.
    #[test]
    fn test_has_tokens() {
        let mut tb = TokenBucket::new(1000, 100);

        assert!(tb.has_tokens(500));
        assert!(tb.has_tokens(1000));
        assert!(!tb.has_tokens(1500));
    }

    /// `fill_level` tracks the balance as a fraction of capacity.
    #[test]
    fn test_fill_level() {
        let mut tb = TokenBucket::new(1000, 100);

        let full = tb.fill_level();
        assert!(full > 0.99 && full <= 1.0, "Expected ~1.0, got {}", full);

        tb.consume(500);
        let half = tb.fill_level();
        assert!(half > 0.49 && half < 0.51, "Expected ~0.5, got {}", half);

        tb.consume(500);
        let empty = tb.fill_level();
        assert!(empty >= 0.0 && empty < 0.01, "Expected ~0.0, got {}", empty);
    }

    /// A clone carries over the configuration of its source.
    #[test]
    fn test_clone() {
        let source = TokenBucket::new(1000, 100);
        let copy = source.clone();

        assert_eq!(copy.capacity(), 1000);
        assert_eq!(copy.refill_rate(), 100);
    }

    /// Refilling a full bucket never pushes the balance past capacity.
    #[test]
    fn test_no_overflow() {
        let mut tb = TokenBucket::new(1000, 1000);

        std::thread::sleep(Duration::from_millis(50));

        let tokens = tb.tokens();
        assert!(tokens <= 1000);
    }
}