ruvector_security/
rate_limit.rs

1//! Rate limiting using token bucket algorithm
2//!
3//! Provides protection against API abuse and DoS attacks.
4
5use crate::error::{SecurityError, SecurityResult};
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::RwLock;
10
11/// Rate limit configuration
12#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
13pub struct RateLimitConfig {
14    /// Requests per second for read operations
15    pub read_rps: u32,
16    /// Requests per second for write operations
17    pub write_rps: u32,
18    /// Requests per second for file operations
19    pub file_rps: u32,
20    /// Burst size multiplier (allows temporary bursts)
21    pub burst_multiplier: u32,
22    /// Enable per-IP rate limiting
23    pub per_ip: bool,
24    /// Window size in seconds for rate tracking
25    pub window_secs: u64,
26}
27
28impl Default for RateLimitConfig {
29    fn default() -> Self {
30        Self {
31            read_rps: 1000,
32            write_rps: 100,
33            file_rps: 10,
34            burst_multiplier: 2,
35            per_ip: true,
36            window_secs: 60,
37        }
38    }
39}
40
/// Category of request being limited; each category draws from its own budget.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum OperationType {
    /// Read-only requests such as search or get.
    Read,
    /// Mutating requests such as insert, update or delete.
    Write,
    /// File-level work such as backup or restore.
    File,
}
51
/// Token bucket for rate limiting.
///
/// Starts full, refills continuously at `refill_rate` tokens per second and
/// never holds more than `max_tokens`.
#[derive(Debug)]
struct TokenBucket {
    /// Current tokens available (fractional so partial refills accumulate).
    tokens: f64,
    /// Maximum tokens (burst capacity).
    max_tokens: f64,
    /// Tokens added per second.
    refill_rate: f64,
    /// Last time `tokens` was brought up to date by `refill`.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a full bucket refilling at `tokens_per_second` with capacity
    /// `tokens_per_second * burst_multiplier`.
    ///
    /// The capacity is computed in `f64` so large configured values cannot
    /// overflow `u32` (a panic in debug builds, silent wrap-around in release).
    fn new(tokens_per_second: u32, burst_multiplier: u32) -> Self {
        let max_tokens = f64::from(tokens_per_second) * f64::from(burst_multiplier);
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate: f64::from(tokens_per_second),
            last_refill: Instant::now(),
        }
    }

    /// Tries to take `tokens` from the bucket after refilling it.
    ///
    /// # Errors
    ///
    /// Returns how long the caller would have to wait for enough tokens to
    /// accumulate. If that wait is not representable (notably when
    /// `refill_rate` is 0, which makes it infinite) `Duration::MAX` is returned
    /// instead of panicking.
    fn try_acquire(&mut self, tokens: f64) -> Result<(), Duration> {
        self.refill();

        if self.tokens >= tokens {
            self.tokens -= tokens;
            return Ok(());
        }

        let needed = tokens - self.tokens;
        let wait_secs = needed / self.refill_rate;
        // `Duration::from_secs_f64` panics on infinite/NaN/overflowing input,
        // which a zero refill rate produces; saturate instead.
        Err(Duration::try_from_secs_f64(wait_secs).unwrap_or(Duration::MAX))
    }

    /// Adds the tokens accrued since the last refill, capped at `max_tokens`.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill);
        let new_tokens = elapsed.as_secs_f64() * self.refill_rate;

        self.tokens = (self.tokens + new_tokens).min(self.max_tokens);
        self.last_refill = now;
    }

    /// Whole tokens currently available after refilling, rounded down.
    fn tokens_remaining(&mut self) -> u32 {
        self.refill();
        // f64 -> u32 `as` truncates toward zero and saturates; tokens is never
        // negative, so this is a floor.
        self.tokens as u32
    }
}
104
105/// Rate limiter state
106struct RateLimiterState {
107    /// Global buckets by operation type
108    global_buckets: HashMap<OperationType, TokenBucket>,
109    /// Per-IP buckets: IP -> (operation type -> bucket)
110    ip_buckets: HashMap<String, HashMap<OperationType, TokenBucket>>,
111    /// Configuration
112    config: RateLimitConfig,
113    /// Cleanup interval tracking
114    last_cleanup: Instant,
115}
116
117impl RateLimiterState {
118    fn new(config: RateLimitConfig) -> Self {
119        let mut global_buckets = HashMap::new();
120
121        global_buckets.insert(
122            OperationType::Read,
123            TokenBucket::new(config.read_rps, config.burst_multiplier),
124        );
125        global_buckets.insert(
126            OperationType::Write,
127            TokenBucket::new(config.write_rps, config.burst_multiplier),
128        );
129        global_buckets.insert(
130            OperationType::File,
131            TokenBucket::new(config.file_rps, config.burst_multiplier),
132        );
133
134        Self {
135            global_buckets,
136            ip_buckets: HashMap::new(),
137            config,
138            last_cleanup: Instant::now(),
139        }
140    }
141
142    fn get_or_create_ip_bucket(&mut self, ip: &str, op: OperationType) -> &mut TokenBucket {
143        let config = &self.config;
144        let ip_map = self.ip_buckets.entry(ip.to_string()).or_default();
145
146        ip_map.entry(op).or_insert_with(|| {
147            let rps = match op {
148                OperationType::Read => config.read_rps,
149                OperationType::Write => config.write_rps,
150                OperationType::File => config.file_rps,
151            };
152            TokenBucket::new(rps, config.burst_multiplier)
153        })
154    }
155
156    fn cleanup_stale_entries(&mut self) {
157        let now = Instant::now();
158        let window = Duration::from_secs(self.config.window_secs * 2);
159
160        if now.duration_since(self.last_cleanup) > window {
161            // Remove IP entries that haven't been used recently
162            self.ip_buckets.retain(|_, buckets| {
163                buckets
164                    .values()
165                    .any(|b| now.duration_since(b.last_refill) < window)
166            });
167            self.last_cleanup = now;
168        }
169    }
170}
171
172/// Thread-safe rate limiter
173#[derive(Clone)]
174pub struct RateLimiter {
175    state: Arc<RwLock<RateLimiterState>>,
176    enabled: bool,
177}
178
179impl RateLimiter {
180    /// Create a new rate limiter with configuration
181    pub fn new(config: RateLimitConfig) -> Self {
182        Self {
183            state: Arc::new(RwLock::new(RateLimiterState::new(config))),
184            enabled: true,
185        }
186    }
187
188    /// Create a disabled rate limiter (passes all requests)
189    pub fn disabled() -> Self {
190        Self {
191            state: Arc::new(RwLock::new(RateLimiterState::new(RateLimitConfig::default()))),
192            enabled: false,
193        }
194    }
195
196    /// Check rate limit for an operation
197    ///
198    /// # Arguments
199    /// * `op` - Operation type
200    /// * `ip` - Optional IP address for per-IP limiting
201    ///
202    /// # Returns
203    /// * `Ok(())` if request is allowed
204    /// * `Err(SecurityError::RateLimitExceeded)` if rate limited
205    pub async fn check(&self, op: OperationType, ip: Option<&str>) -> SecurityResult<()> {
206        if !self.enabled {
207            return Ok(());
208        }
209
210        let mut state = self.state.write().await;
211
212        // Cleanup stale entries periodically
213        state.cleanup_stale_entries();
214
215        // Check global limit first
216        if let Some(bucket) = state.global_buckets.get_mut(&op) {
217            if let Err(wait) = bucket.try_acquire(1.0) {
218                return Err(SecurityError::RateLimitExceeded {
219                    retry_after_secs: wait.as_secs().max(1),
220                });
221            }
222        }
223
224        // Check per-IP limit if enabled and IP provided
225        if state.config.per_ip {
226            if let Some(ip) = ip {
227                let bucket = state.get_or_create_ip_bucket(ip, op);
228                if let Err(wait) = bucket.try_acquire(1.0) {
229                    return Err(SecurityError::RateLimitExceeded {
230                        retry_after_secs: wait.as_secs().max(1),
231                    });
232                }
233            }
234        }
235
236        Ok(())
237    }
238
239    /// Get remaining tokens for rate limit headers
240    pub async fn remaining(&self, op: OperationType, ip: Option<&str>) -> u32 {
241        if !self.enabled {
242            return u32::MAX;
243        }
244
245        let mut state = self.state.write().await;
246
247        let global_remaining = state
248            .global_buckets
249            .get_mut(&op)
250            .map(|b| b.tokens_remaining())
251            .unwrap_or(u32::MAX);
252
253        if let Some(ip) = ip {
254            if state.config.per_ip {
255                let ip_remaining = state.get_or_create_ip_bucket(ip, op).tokens_remaining();
256                return global_remaining.min(ip_remaining);
257            }
258        }
259
260        global_remaining
261    }
262
263    /// Get rate limit for operation
264    pub async fn limit(&self, op: OperationType) -> u32 {
265        let state = self.state.read().await;
266        match op {
267            OperationType::Read => state.config.read_rps,
268            OperationType::Write => state.config.write_rps,
269            OperationType::File => state.config.file_rps,
270        }
271    }
272}
273
274impl Default for RateLimiter {
275    fn default() -> Self {
276        Self::new(RateLimitConfig::default())
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_rate_limit_allows_within_limit() {
        let config = RateLimitConfig {
            read_rps: 10,
            burst_multiplier: 1,
            per_ip: false,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);

        // Should allow 10 requests
        for _ in 0..10 {
            assert!(limiter.check(OperationType::Read, None).await.is_ok());
        }
    }

    #[tokio::test]
    async fn test_rate_limit_blocks_excess() {
        let config = RateLimitConfig {
            read_rps: 5,
            burst_multiplier: 1,
            per_ip: false,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);

        // Use up all tokens
        for _ in 0..5 {
            let _ = limiter.check(OperationType::Read, None).await;
        }

        // Next request should be rate limited
        let result = limiter.check(OperationType::Read, None).await;
        assert!(matches!(result, Err(SecurityError::RateLimitExceeded { .. })));
    }

    #[tokio::test]
    async fn test_per_ip_limiting() {
        // Global and per-IP buckets are both sized from `read_rps` here, so this
        // only checks that distinct IPs are each admitted and that `remaining`
        // still reports a positive budget for both afterwards.
        let config = RateLimitConfig {
            read_rps: 10,
            burst_multiplier: 1,
            per_ip: true,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);

        assert!(limiter
            .check(OperationType::Read, Some("192.168.1.1"))
            .await
            .is_ok());
        assert!(limiter
            .check(OperationType::Read, Some("192.168.1.2"))
            .await
            .is_ok());

        let remaining_ip1 = limiter.remaining(OperationType::Read, Some("192.168.1.1")).await;
        let remaining_ip2 = limiter.remaining(OperationType::Read, Some("192.168.1.2")).await;

        assert!(remaining_ip1 > 0);
        assert!(remaining_ip2 > 0);
    }

    #[tokio::test]
    async fn test_operation_types_have_independent_budgets() {
        let config = RateLimitConfig {
            write_rps: 1,
            burst_multiplier: 1,
            per_ip: false,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);

        assert!(limiter.check(OperationType::Write, None).await.is_ok());
        // The single write token is spent...
        assert!(matches!(
            limiter.check(OperationType::Write, None).await,
            Err(SecurityError::RateLimitExceeded { .. })
        ));
        // ...but reads draw from their own bucket.
        assert!(limiter.check(OperationType::Read, None).await.is_ok());
    }

    #[tokio::test]
    async fn test_retry_after_is_at_least_one_second() {
        let config = RateLimitConfig {
            file_rps: 1,
            burst_multiplier: 1,
            per_ip: false,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);

        let _ = limiter.check(OperationType::File, None).await;
        match limiter.check(OperationType::File, None).await {
            Err(SecurityError::RateLimitExceeded { retry_after_secs }) => {
                assert!(retry_after_secs >= 1);
            }
            _ => panic!("expected RateLimitExceeded"),
        }
    }

    #[tokio::test]
    async fn test_disabled_limiter() {
        let limiter = RateLimiter::disabled();

        // Should allow unlimited requests
        for _ in 0..1000 {
            assert!(limiter.check(OperationType::Read, None).await.is_ok());
        }
    }

    #[tokio::test]
    async fn test_disabled_limiter_reports_unlimited_remaining() {
        let limiter = RateLimiter::disabled();
        assert_eq!(
            limiter.remaining(OperationType::Write, Some("10.0.0.1")).await,
            u32::MAX
        );
    }

    #[tokio::test]
    async fn test_limit_reports_configured_rps() {
        let limiter = RateLimiter::default();
        assert_eq!(limiter.limit(OperationType::Read).await, 1000);
        assert_eq!(limiter.limit(OperationType::Write).await, 100);
        assert_eq!(limiter.limit(OperationType::File).await, 10);
    }
}