Skip to main content

rust_serv/throttle/
token_bucket.rs

//! Token bucket implementation for rate limiting.
2
3use std::time::{Duration, Instant};
4
/// Token bucket for rate limiting.
///
/// The bucket holds at most `capacity` tokens and is refilled continuously at
/// `refill_rate` tokens per second. Refilling is lazy: every method that reads
/// or spends tokens first credits the time elapsed since the previous refill.
/// Token counts are kept as `f64` so that fractional refills accumulate across
/// calls instead of being rounded away.
///
/// Cloning produces an exact snapshot (including the refill timestamp), so a
/// clone reports the same token count as the original at the same instant.
#[derive(Debug, Clone)]
pub struct TokenBucket {
    /// Maximum tokens in bucket
    capacity: u64,
    /// Current tokens, kept within `0.0..=capacity`
    tokens: f64,
    /// Tokens added per second
    refill_rate: f64,
    /// Instant up to which `tokens` has been brought up to date
    last_refill: Instant,
}

impl TokenBucket {
    /// Create a new token bucket that starts full.
    ///
    /// * `capacity` - maximum number of tokens the bucket can hold.
    /// * `refill_rate` - tokens added per second.
    pub fn new(capacity: u64, refill_rate: u64) -> Self {
        Self {
            capacity,
            tokens: capacity as f64,
            refill_rate: refill_rate as f64,
            last_refill: Instant::now(),
        }
    }

    /// Get bucket capacity.
    pub fn capacity(&self) -> u64 {
        self.capacity
    }

    /// Get refill rate (tokens per second).
    pub fn refill_rate(&self) -> u64 {
        self.refill_rate as u64
    }

    /// Get the current number of whole tokens, after refilling.
    pub fn tokens(&mut self) -> u64 {
        self.refill();
        self.tokens as u64
    }

    /// Credit tokens for the time elapsed since the last refill (capped at
    /// `capacity`) and advance the refill timestamp to now.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill);
        let tokens_to_add = elapsed.as_secs_f64() * self.refill_rate;

        self.tokens = (self.tokens + tokens_to_add).min(self.capacity as f64);
        self.last_refill = now;
    }

    /// Try to consume up to `requested` tokens.
    ///
    /// Returns the number of tokens actually consumed: `requested` if enough
    /// are available, otherwise every whole token currently in the bucket.
    /// Only whole tokens are removed, so any fractional remainder stays in the
    /// bucket and counts toward the next request.
    pub fn consume(&mut self, requested: u64) -> u64 {
        self.refill();

        let wanted = requested as f64;
        if self.tokens >= wanted {
            self.tokens -= wanted;
            requested
        } else {
            // Hand out whole tokens only; keeping the fraction avoids silently
            // discarding partially-refilled capacity on every short request.
            let consumed = self.tokens.floor();
            self.tokens -= consumed;
            consumed as u64
        }
    }

    /// Try to consume exactly the requested amount.
    ///
    /// Returns `true` and removes `amount` tokens if enough are available;
    /// otherwise returns `false` and leaves the bucket unchanged.
    pub fn try_consume(&mut self, amount: u64) -> bool {
        self.refill();

        if self.tokens >= amount as f64 {
            self.tokens -= amount as f64;
            true
        } else {
            false
        }
    }

    /// Compute how long to wait until `amount` tokens are available, assuming
    /// nothing else consumes tokens in the meantime.
    ///
    /// Returns [`Duration::ZERO`] if the tokens are already available. Returns
    /// [`Duration::MAX`] if the bucket never refills (refill rate of 0) or the
    /// wait is too long to represent as a `Duration`.
    ///
    /// Note: refills are capped at `capacity`, so if `amount > capacity` the
    /// returned wait is finite but the bucket will never actually hold that
    /// many tokens.
    pub fn wait_time(&mut self, amount: u64) -> Duration {
        self.refill();

        if self.tokens >= amount as f64 {
            return Duration::ZERO;
        }

        // A zero rate would yield `inf` seconds, which `from_secs_f64` panics on.
        if self.refill_rate <= 0.0 {
            return Duration::MAX;
        }

        let tokens_needed = amount as f64 - self.tokens;
        let wait_secs = tokens_needed / self.refill_rate;

        Duration::try_from_secs_f64(wait_secs).unwrap_or(Duration::MAX)
    }

    /// Reset bucket to full capacity.
    pub fn reset(&mut self) {
        self.tokens = self.capacity as f64;
        self.last_refill = Instant::now();
    }

    /// Update refill rate (tokens per second).
    pub fn set_refill_rate(&mut self, rate: u64) {
        // Settle the time elapsed so far at the old rate; otherwise the next
        // refill would credit that interval at the new rate.
        self.refill();
        self.refill_rate = rate as f64;
    }

    /// Update capacity, clamping the current token count to the new limit.
    pub fn set_capacity(&mut self, capacity: u64) {
        // Credit the time elapsed so far under the old capacity before the
        // limit changes.
        self.refill();
        self.capacity = capacity;
        self.tokens = self.tokens.min(capacity as f64);
    }

    /// Check if bucket has at least `amount` tokens (does not consume any).
    pub fn has_tokens(&mut self, amount: u64) -> bool {
        self.refill();
        self.tokens >= amount as f64
    }

    /// Get fill percentage (0.0 to 1.0).
    ///
    /// A zero-capacity bucket reports 0.0 rather than NaN.
    pub fn fill_level(&mut self) -> f64 {
        self.refill();
        if self.capacity == 0 {
            return 0.0;
        }
        self.tokens / self.capacity as f64
    }
}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for constructing a fresh bucket in each test.
    fn bucket(capacity: u64, rate: u64) -> TokenBucket {
        TokenBucket::new(capacity, rate)
    }

    #[test]
    fn test_bucket_creation() {
        let b = bucket(1000, 100);
        assert_eq!((b.capacity(), b.refill_rate()), (1000, 100));
    }

    #[test]
    fn test_bucket_starts_full() {
        assert_eq!(bucket(1000, 100).tokens(), 1000);
    }

    #[test]
    fn test_consume_success() {
        let mut b = bucket(1000, 100);
        assert_eq!(b.consume(500), 500);
        assert_eq!(b.tokens(), 500);
    }

    #[test]
    fn test_consume_partial() {
        let mut b = bucket(100, 10);
        // Asking for more than the bucket holds drains what is there.
        assert_eq!(b.consume(150), 100);
        assert_eq!(b.tokens(), 0);
    }

    #[test]
    fn test_consume_empty_bucket() {
        let mut b = bucket(100, 10);
        b.consume(100);
        assert_eq!(b.tokens(), 0);
        assert_eq!(b.consume(50), 0);
    }

    #[test]
    fn test_try_consume_success() {
        let mut b = bucket(1000, 100);
        assert!(b.try_consume(500));
        assert_eq!(b.tokens(), 500);
    }

    #[test]
    fn test_try_consume_fail() {
        let mut b = bucket(100, 10);
        assert!(!b.try_consume(150));
        // A rejected request must leave the balance untouched.
        assert_eq!(b.tokens(), 100);
    }

    #[test]
    fn test_refill() {
        let mut b = bucket(1000, 1000); // 1000 tokens/sec
        b.consume(500);
        assert_eq!(b.tokens(), 500);

        std::thread::sleep(Duration::from_millis(100));

        // Roughly 100 tokens should have trickled back; the window is wide to
        // tolerate scheduler jitter.
        let tokens = b.tokens();
        assert!((501..800).contains(&tokens), "Expected ~600 tokens, got {}", tokens);
    }

    #[test]
    fn test_wait_time_zero() {
        assert_eq!(bucket(1000, 100).wait_time(500), Duration::ZERO);
    }

    #[test]
    fn test_wait_time_needed() {
        let mut b = bucket(100, 100); // 100 tokens/sec
        b.consume(100);

        // Rebuilding 100 tokens at 100 tokens/sec takes about one second.
        let wait = b.wait_time(100);
        let window = Duration::from_millis(900)..=Duration::from_millis(1100);
        assert!(window.contains(&wait));
    }

    #[test]
    fn test_reset() {
        let mut b = bucket(1000, 100);
        b.consume(500);
        assert_eq!(b.tokens(), 500);

        b.reset();
        assert_eq!(b.tokens(), 1000);
    }

    #[test]
    fn test_set_refill_rate() {
        let mut b = bucket(1000, 100);
        b.set_refill_rate(200);
        assert_eq!(b.refill_rate(), 200);
    }

    #[test]
    fn test_set_capacity() {
        let mut b = bucket(1000, 100);
        b.consume(500);
        b.set_capacity(300);

        assert_eq!(b.capacity(), 300);
        // The existing balance is clamped down to the smaller limit.
        assert_eq!(b.tokens(), 300);
    }

    #[test]
    fn test_has_tokens() {
        let mut b = bucket(1000, 100);
        for (amount, expected) in [(500, true), (1000, true), (1500, false)] {
            assert_eq!(b.has_tokens(amount), expected);
        }
    }

    #[test]
    fn test_fill_level() {
        let mut b = bucket(1000, 100);

        let level = b.fill_level();
        assert!(level > 0.99 && level <= 1.0, "Expected ~1.0, got {}", level);

        b.consume(500);
        let level = b.fill_level();
        assert!(level > 0.49 && level < 0.51, "Expected ~0.5, got {}", level);

        b.consume(500);
        let level = b.fill_level();
        assert!(level >= 0.0 && level < 0.01, "Expected ~0.0, got {}", level);
    }

    #[test]
    fn test_clone() {
        let original = bucket(1000, 100);
        let copy = original.clone();

        assert_eq!(copy.capacity(), 1000);
        assert_eq!(copy.refill_rate(), 100);
    }

    #[test]
    fn test_no_overflow() {
        let mut b = bucket(1000, 1000);

        // Give the bucket time to (try to) refill beyond its capacity.
        std::thread::sleep(Duration::from_millis(50));

        assert!(b.tokens() <= 1000);
    }
}