Skip to main content

rust_serv/throttle/
limiter.rs

//! Throttle limiter: token-bucket bandwidth limits applied globally and per client IP.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::RwLock;
7
8use super::config::ThrottleConfig;
9use super::token_bucket::TokenBucket;
10
/// Outcome of a throttle check.
///
/// The payload is plain `u64` data, so the type is `Copy` and `Eq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThrottleResult {
    /// The request may proceed; `remaining` reports leftover bandwidth.
    Allowed { remaining: u64 },
    /// The request exceeds a limit; retry after about `wait_ms` milliseconds.
    Throttled { wait_ms: u64 },
    /// No limit applies (throttling is not active).
    Unlimited,
}

impl ThrottleResult {
    /// Returns `true` when the request may proceed.
    ///
    /// Both `Allowed` and `Unlimited` count as allowed; only `Throttled`
    /// does not.
    #[must_use]
    pub fn is_allowed(&self) -> bool {
        matches!(self, Self::Allowed { .. } | Self::Unlimited)
    }
}
28
29/// Throttle limiter manages bandwidth limits
30#[derive(Debug)]
31pub struct ThrottleLimiter {
32    config: ThrottleConfig,
33    global_bucket: Arc<RwLock<TokenBucket>>,
34    ip_buckets: Arc<RwLock<HashMap<String, TokenBucket>>>,
35}
36
37impl ThrottleLimiter {
38    /// Create a new throttle limiter
39    pub fn new(config: ThrottleConfig) -> Self {
40        let global_bucket = if config.has_global_limit() {
41            TokenBucket::new(
42                config.bucket_capacity,
43                config.global_limit,
44            )
45        } else {
46            TokenBucket::new(0, 0)
47        };
48        
49        Self {
50            config,
51            global_bucket: Arc::new(RwLock::new(global_bucket)),
52            ip_buckets: Arc::new(RwLock::new(HashMap::new())),
53        }
54    }
55
56    /// Get configuration
57    pub fn config(&self) -> &ThrottleConfig {
58        &self.config
59    }
60
61    /// Check if a request should be throttled
62    pub async fn check(&self, ip: &str, bytes: u64) -> ThrottleResult {
63        if !self.config.is_active() {
64            return ThrottleResult::Unlimited;
65        }
66        
67        // Check global limit first
68        if self.config.has_global_limit() {
69            let mut bucket = self.global_bucket.write().await;
70            if !bucket.try_consume(bytes) {
71                let wait = bucket.wait_time(bytes);
72                return ThrottleResult::Throttled {
73                    wait_ms: wait.as_millis() as u64,
74                };
75            }
76        }
77        
78        // Check per-IP limit
79        if self.config.has_per_ip_limit() {
80            let mut buckets = self.ip_buckets.write().await;
81            let bucket = buckets.entry(ip.to_string()).or_insert_with(|| {
82                TokenBucket::new(
83                    self.config.bucket_capacity,
84                    self.config.per_ip_limit,
85                )
86            });
87            
88            if !bucket.try_consume(bytes) {
89                let wait = bucket.wait_time(bytes);
90                return ThrottleResult::Throttled {
91                    wait_ms: wait.as_millis() as u64,
92                };
93            }
94        }
95        
96        ThrottleResult::Allowed { remaining: 0 }
97    }
98
99    /// Consume bandwidth tokens
100    /// Returns the number of bytes actually consumed
101    pub async fn consume(&self, ip: &str, bytes: u64) -> u64 {
102        if !self.config.is_active() {
103            return bytes;
104        }
105        
106        let mut total_consumed = bytes;
107        
108        // Consume from global bucket
109        if self.config.has_global_limit() {
110            let mut bucket = self.global_bucket.write().await;
111            total_consumed = total_consumed.min(bucket.consume(bytes));
112        }
113        
114        // Consume from per-IP bucket
115        if self.config.has_per_ip_limit() && total_consumed > 0 {
116            let mut buckets = self.ip_buckets.write().await;
117            let bucket = buckets.entry(ip.to_string()).or_insert_with(|| {
118                TokenBucket::new(
119                    self.config.bucket_capacity,
120                    self.config.per_ip_limit,
121                )
122            });
123            total_consumed = total_consumed.min(bucket.consume(bytes));
124        }
125        
126        total_consumed
127    }
128
129    /// Get wait time for a request
130    pub async fn wait_time(&self, ip: &str, bytes: u64) -> Duration {
131        if !self.config.is_active() {
132            return Duration::ZERO;
133        }
134        
135        let mut max_wait = Duration::ZERO;
136        
137        if self.config.has_global_limit() {
138            let mut bucket = self.global_bucket.write().await;
139            max_wait = max_wait.max(bucket.wait_time(bytes));
140        }
141        
142        if self.config.has_per_ip_limit() {
143            let buckets = self.ip_buckets.write().await;
144            if let Some(bucket) = buckets.get(ip) {
145                let mut bucket = bucket.clone();
146                max_wait = max_wait.max(bucket.wait_time(bytes));
147            }
148        }
149        
150        max_wait
151    }
152
153    /// Reset all buckets
154    pub async fn reset(&self) {
155        if self.config.has_global_limit() {
156            let mut bucket = self.global_bucket.write().await;
157            bucket.reset();
158        }
159        
160        let mut buckets = self.ip_buckets.write().await;
161        for bucket in buckets.values_mut() {
162            bucket.reset();
163        }
164    }
165
166    /// Clear per-IP buckets
167    pub async fn clear_ip_buckets(&self) {
168        let mut buckets = self.ip_buckets.write().await;
169        buckets.clear();
170    }
171
172    /// Get number of tracked IPs
173    pub async fn tracked_ip_count(&self) -> usize {
174        let buckets = self.ip_buckets.read().await;
175        buckets.len()
176    }
177
178    /// Remove an IP from tracking
179    pub async fn remove_ip(&self, ip: &str) -> bool {
180        let mut buckets = self.ip_buckets.write().await;
181        buckets.remove(ip).is_some()
182    }
183
184    /// Update configuration (creates new buckets)
185    pub fn update_config(&mut self, config: ThrottleConfig) {
186        self.config = config;
187        
188        // Recreate global bucket
189        let global_bucket = if self.config.has_global_limit() {
190            TokenBucket::new(
191                self.config.bucket_capacity,
192                self.config.global_limit,
193            )
194        } else {
195            TokenBucket::new(0, 0)
196        };
197        
198        self.global_bucket = Arc::new(RwLock::new(global_bucket));
199    }
200
201    /// Get global bucket tokens
202    pub async fn global_tokens(&self) -> u64 {
203        if !self.config.has_global_limit() {
204            return u64::MAX;
205        }
206        let mut bucket = self.global_bucket.write().await;
207        bucket.tokens()
208    }
209
210    /// Get per-IP bucket tokens
211    pub async fn ip_tokens(&self, ip: &str) -> u64 {
212        if !self.config.has_per_ip_limit() {
213            return u64::MAX;
214        }
215        let buckets = self.ip_buckets.read().await;
216        if let Some(bucket) = buckets.get(ip) {
217            let mut bucket = bucket.clone();
218            bucket.tokens()
219        } else {
220            self.config.bucket_capacity
221        }
222    }
223}
224
225impl Clone for ThrottleLimiter {
226    fn clone(&self) -> Self {
227        Self {
228            config: self.config.clone(),
229            global_bucket: Arc::clone(&self.global_bucket),
230            ip_buckets: Arc::clone(&self.ip_buckets),
231        }
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Primary client address used throughout the tests.
    const IP_A: &str = "127.0.0.1";
    /// Second client address, for checking per-IP isolation.
    const IP_B: &str = "127.0.0.2";

    #[test]
    fn test_limiter_creation() {
        let limiter = ThrottleLimiter::new(ThrottleConfig::new());
        // A freshly built config is not active until enabled.
        assert!(!limiter.config().is_active());
    }

    #[tokio::test]
    async fn test_check_unlimited() {
        // `ThrottleConfig::new()` is not enabled, so `check` never limits.
        let limiter = ThrottleLimiter::new(ThrottleConfig::new());
        let outcome = limiter.check(IP_A, 1000).await;
        assert!(outcome.is_allowed());
        assert_eq!(outcome, ThrottleResult::Unlimited);
    }

    #[tokio::test]
    async fn test_check_global_limit() {
        let limiter = ThrottleLimiter::new(
            ThrottleConfig::new()
                .enable()
                .with_global_limit(1000)
                .with_bucket_capacity(1000),
        );

        // Two 500-byte requests fit in the 1000-byte bucket...
        for _ in 0..2 {
            assert!(limiter.check(IP_A, 500).await.is_allowed());
        }
        // ...and a third does not.
        assert!(!limiter.check(IP_A, 500).await.is_allowed());
    }

    #[tokio::test]
    async fn test_check_per_ip_limit() {
        let limiter = ThrottleLimiter::new(
            ThrottleConfig::new()
                .enable()
                .with_per_ip_limit(500)
                .with_bucket_capacity(500),
        );

        // IP_A fits one 300-byte request in its 500-byte bucket, not two.
        assert!(limiter.check(IP_A, 300).await.is_allowed());
        assert!(!limiter.check(IP_A, 300).await.is_allowed());

        // IP_B has its own bucket, untouched by IP_A's traffic.
        assert!(limiter.check(IP_B, 300).await.is_allowed());
    }

    #[tokio::test]
    async fn test_consume_unlimited() {
        let limiter = ThrottleLimiter::new(ThrottleConfig::new());
        assert_eq!(limiter.consume(IP_A, 1000).await, 1000);
    }

    #[tokio::test]
    async fn test_consume_partial() {
        let limiter = ThrottleLimiter::new(
            ThrottleConfig::new()
                .enable()
                .with_global_limit(500)
                .with_bucket_capacity(500),
        );

        assert_eq!(limiter.consume(IP_A, 300).await, 300);
        // Only 200 of the 500 tokens are left for the second request.
        assert_eq!(limiter.consume(IP_A, 300).await, 200);
    }

    #[tokio::test]
    async fn test_wait_time_zero() {
        let limiter = ThrottleLimiter::new(
            ThrottleConfig::new()
                .enable()
                .with_global_limit(1000)
                .with_bucket_capacity(1000),
        );

        assert_eq!(limiter.wait_time(IP_A, 500).await, Duration::ZERO);
    }

    #[tokio::test]
    async fn test_reset() {
        let limiter = ThrottleLimiter::new(
            ThrottleConfig::new()
                .enable()
                .with_global_limit(1000)
                .with_bucket_capacity(1000),
        );

        // Drain the whole global bucket, then refill it via `reset`.
        let _ = limiter.check(IP_A, 1000).await;
        limiter.reset().await;
        assert_eq!(limiter.global_tokens().await, 1000);
    }

    #[tokio::test]
    async fn test_clear_ip_buckets() {
        let limiter = ThrottleLimiter::new(
            ThrottleConfig::new()
                .enable()
                .with_per_ip_limit(1000)
                .with_bucket_capacity(1000),
        );

        for ip in [IP_A, IP_B] {
            let _ = limiter.check(ip, 500).await;
        }
        assert_eq!(limiter.tracked_ip_count().await, 2);

        limiter.clear_ip_buckets().await;
        assert_eq!(limiter.tracked_ip_count().await, 0);
    }

    #[tokio::test]
    async fn test_remove_ip() {
        let limiter = ThrottleLimiter::new(
            ThrottleConfig::new()
                .enable()
                .with_per_ip_limit(1000)
                .with_bucket_capacity(1000),
        );

        for ip in [IP_A, IP_B] {
            let _ = limiter.check(ip, 500).await;
        }

        assert!(limiter.remove_ip(IP_A).await);
        assert_eq!(limiter.tracked_ip_count().await, 1);
        // A second removal finds nothing left to remove.
        assert!(!limiter.remove_ip(IP_A).await);
    }

    #[tokio::test]
    async fn test_update_config() {
        let mut limiter = ThrottleLimiter::new(
            ThrottleConfig::new()
                .enable()
                .with_global_limit(1000)
                .with_bucket_capacity(1000),
        );
        let _ = limiter.check(IP_A, 500).await;

        limiter.update_config(
            ThrottleConfig::new()
                .enable()
                .with_global_limit(2000)
                .with_bucket_capacity(2000),
        );

        // The replacement global bucket starts full at the new capacity.
        assert_eq!(limiter.global_tokens().await, 2000);
    }

    #[tokio::test]
    async fn test_global_tokens() {
        let limiter = ThrottleLimiter::new(
            ThrottleConfig::new()
                .enable()
                .with_global_limit(1000)
                .with_bucket_capacity(1000),
        );

        assert_eq!(limiter.global_tokens().await, 1000);
        limiter.consume(IP_A, 300).await;
        assert_eq!(limiter.global_tokens().await, 700);
    }

    #[tokio::test]
    async fn test_ip_tokens() {
        let limiter = ThrottleLimiter::new(
            ThrottleConfig::new()
                .enable()
                .with_per_ip_limit(1000)
                .with_bucket_capacity(1000),
        );

        // An IP that has not been seen yet reports a full bucket.
        assert_eq!(limiter.ip_tokens(IP_A).await, 1000);
        limiter.consume(IP_A, 300).await;
        assert_eq!(limiter.ip_tokens(IP_A).await, 700);
    }

    #[tokio::test]
    async fn test_clone() {
        let limiter = ThrottleLimiter::new(
            ThrottleConfig::new()
                .enable()
                .with_global_limit(1000)
                .with_bucket_capacity(1000),
        );
        let cloned = limiter.clone();

        // Clones share buckets, so consumption through one shows in the other.
        limiter.consume(IP_A, 500).await;
        assert_eq!(cloned.global_tokens().await, 500);
    }

    #[test]
    fn test_throttle_result_is_allowed() {
        let passing = [
            ThrottleResult::Allowed { remaining: 100 },
            ThrottleResult::Unlimited,
        ];
        assert!(passing.iter().all(ThrottleResult::is_allowed));
        assert!(!ThrottleResult::Throttled { wait_ms: 100 }.is_allowed());
    }
}