1use crate::error::{SecurityError, SecurityResult};
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::RwLock;
10
11#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
13pub struct RateLimitConfig {
14 pub read_rps: u32,
16 pub write_rps: u32,
18 pub file_rps: u32,
20 pub burst_multiplier: u32,
22 pub per_ip: bool,
24 pub window_secs: u64,
26}
27
28impl Default for RateLimitConfig {
29 fn default() -> Self {
30 Self {
31 read_rps: 1000,
32 write_rps: 100,
33 file_rps: 10,
34 burst_multiplier: 2,
35 per_ip: true,
36 window_secs: 60,
37 }
38 }
39}
40
/// Class of operation being rate limited; each class has its own bucket.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum OperationType {
    /// Limited by `RateLimitConfig::read_rps`.
    Read,
    /// Limited by `RateLimitConfig::write_rps`.
    Write,
    /// Limited by `RateLimitConfig::file_rps`.
    File,
}
51
/// A token bucket: holds at most `max_tokens` tokens and regains them
/// continuously at `refill_rate` tokens per second.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens currently available (fractional because refill is continuous).
    tokens: f64,
    /// Capacity; `tokens` is clamped to this on every refill.
    max_tokens: f64,
    /// Tokens regained per second.
    refill_rate: f64,
    /// Instant of the most recent refill.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a full bucket that refills at `tokens_per_second` and holds up
    /// to `tokens_per_second * burst_multiplier` tokens.
    fn new(tokens_per_second: u32, burst_multiplier: u32) -> Self {
        let refill_rate = f64::from(tokens_per_second);
        // Multiply in f64: the product of two `u32`s can exceed `u32::MAX`,
        // which would panic in debug builds and wrap in release builds.
        let max_tokens = refill_rate * f64::from(burst_multiplier);
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    /// Takes `tokens` from the bucket after refilling it.
    ///
    /// # Errors
    ///
    /// When too few tokens are available nothing is taken and the error holds
    /// the time needed to regain the shortfall at the current refill rate.
    /// If that time is not representable (for example the refill rate is
    /// zero, so the shortfall is never regained) the error is `Duration::MAX`.
    fn try_acquire(&mut self, tokens: f64) -> Result<(), Duration> {
        self.refill();

        if self.tokens >= tokens {
            self.tokens -= tokens;
            return Ok(());
        }

        let deficit = tokens - self.tokens;
        let wait_secs = deficit / self.refill_rate;
        // A zero rate yields an infinite wait; `Duration::from_secs_f64`
        // would panic on it, so convert fallibly and saturate instead.
        Err(Duration::try_from_secs_f64(wait_secs).unwrap_or(Duration::MAX))
    }

    /// Credits the tokens earned since the last refill, capped at capacity,
    /// and moves `last_refill` to now.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill);
        let new_tokens = elapsed.as_secs_f64() * self.refill_rate;

        self.tokens = (self.tokens + new_tokens).min(self.max_tokens);
        self.last_refill = now;
    }

    /// Refills the bucket and returns the number of whole tokens available.
    fn tokens_remaining(&mut self) -> u32 {
        self.refill();
        // Float-to-int `as` truncates the fraction and saturates at
        // `u32::MAX`, which is the intended clamp for very large capacities.
        self.tokens as u32
    }
}
104
105struct RateLimiterState {
107 global_buckets: HashMap<OperationType, TokenBucket>,
109 ip_buckets: HashMap<String, HashMap<OperationType, TokenBucket>>,
111 config: RateLimitConfig,
113 last_cleanup: Instant,
115}
116
117impl RateLimiterState {
118 fn new(config: RateLimitConfig) -> Self {
119 let mut global_buckets = HashMap::new();
120
121 global_buckets.insert(
122 OperationType::Read,
123 TokenBucket::new(config.read_rps, config.burst_multiplier),
124 );
125 global_buckets.insert(
126 OperationType::Write,
127 TokenBucket::new(config.write_rps, config.burst_multiplier),
128 );
129 global_buckets.insert(
130 OperationType::File,
131 TokenBucket::new(config.file_rps, config.burst_multiplier),
132 );
133
134 Self {
135 global_buckets,
136 ip_buckets: HashMap::new(),
137 config,
138 last_cleanup: Instant::now(),
139 }
140 }
141
142 fn get_or_create_ip_bucket(&mut self, ip: &str, op: OperationType) -> &mut TokenBucket {
143 let config = &self.config;
144 let ip_map = self.ip_buckets.entry(ip.to_string()).or_default();
145
146 ip_map.entry(op).or_insert_with(|| {
147 let rps = match op {
148 OperationType::Read => config.read_rps,
149 OperationType::Write => config.write_rps,
150 OperationType::File => config.file_rps,
151 };
152 TokenBucket::new(rps, config.burst_multiplier)
153 })
154 }
155
156 fn cleanup_stale_entries(&mut self) {
157 let now = Instant::now();
158 let window = Duration::from_secs(self.config.window_secs * 2);
159
160 if now.duration_since(self.last_cleanup) > window {
161 self.ip_buckets.retain(|_, buckets| {
163 buckets
164 .values()
165 .any(|b| now.duration_since(b.last_refill) < window)
166 });
167 self.last_cleanup = now;
168 }
169 }
170}
171
172#[derive(Clone)]
174pub struct RateLimiter {
175 state: Arc<RwLock<RateLimiterState>>,
176 enabled: bool,
177}
178
179impl RateLimiter {
180 pub fn new(config: RateLimitConfig) -> Self {
182 Self {
183 state: Arc::new(RwLock::new(RateLimiterState::new(config))),
184 enabled: true,
185 }
186 }
187
188 pub fn disabled() -> Self {
190 Self {
191 state: Arc::new(RwLock::new(RateLimiterState::new(RateLimitConfig::default()))),
192 enabled: false,
193 }
194 }
195
196 pub async fn check(&self, op: OperationType, ip: Option<&str>) -> SecurityResult<()> {
206 if !self.enabled {
207 return Ok(());
208 }
209
210 let mut state = self.state.write().await;
211
212 state.cleanup_stale_entries();
214
215 if let Some(bucket) = state.global_buckets.get_mut(&op) {
217 if let Err(wait) = bucket.try_acquire(1.0) {
218 return Err(SecurityError::RateLimitExceeded {
219 retry_after_secs: wait.as_secs().max(1),
220 });
221 }
222 }
223
224 if state.config.per_ip {
226 if let Some(ip) = ip {
227 let bucket = state.get_or_create_ip_bucket(ip, op);
228 if let Err(wait) = bucket.try_acquire(1.0) {
229 return Err(SecurityError::RateLimitExceeded {
230 retry_after_secs: wait.as_secs().max(1),
231 });
232 }
233 }
234 }
235
236 Ok(())
237 }
238
239 pub async fn remaining(&self, op: OperationType, ip: Option<&str>) -> u32 {
241 if !self.enabled {
242 return u32::MAX;
243 }
244
245 let mut state = self.state.write().await;
246
247 let global_remaining = state
248 .global_buckets
249 .get_mut(&op)
250 .map(|b| b.tokens_remaining())
251 .unwrap_or(u32::MAX);
252
253 if let Some(ip) = ip {
254 if state.config.per_ip {
255 let ip_remaining = state.get_or_create_ip_bucket(ip, op).tokens_remaining();
256 return global_remaining.min(ip_remaining);
257 }
258 }
259
260 global_remaining
261 }
262
263 pub async fn limit(&self, op: OperationType) -> u32 {
265 let state = self.state.read().await;
266 match op {
267 OperationType::Read => state.config.read_rps,
268 OperationType::Write => state.config.write_rps,
269 OperationType::File => state.config.file_rps,
270 }
271 }
272}
273
274impl Default for RateLimiter {
275 fn default() -> Self {
276 Self::new(RateLimitConfig::default())
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Config with a burst multiplier of 1 and the given read rate; all other
    /// fields keep their defaults.
    fn read_config(read_rps: u32, per_ip: bool) -> RateLimitConfig {
        RateLimitConfig {
            read_rps,
            burst_multiplier: 1,
            per_ip,
            ..Default::default()
        }
    }

    /// A burst equal to the bucket capacity is admitted in full.
    #[tokio::test]
    async fn test_rate_limit_allows_within_limit() {
        let limiter = RateLimiter::new(read_config(10, false));

        for _ in 0..10 {
            let outcome = limiter.check(OperationType::Read, None).await;
            assert!(outcome.is_ok());
        }
    }

    /// The request after the bucket has been drained is rejected.
    #[tokio::test]
    async fn test_rate_limit_blocks_excess() {
        let limiter = RateLimiter::new(read_config(5, false));

        // Drain the bucket; these results are deliberately ignored.
        for _ in 0..5 {
            let _ = limiter.check(OperationType::Read, None).await;
        }

        let result = limiter.check(OperationType::Read, None).await;
        assert!(matches!(result, Err(SecurityError::RateLimitExceeded { .. })));
    }

    /// Two different IPs are both admitted and both keep tokens afterwards.
    #[tokio::test]
    async fn test_per_ip_limiting() {
        let limiter = RateLimiter::new(read_config(10, true));
        let first_ip = "192.168.1.1";
        let second_ip = "192.168.1.2";

        assert!(limiter
            .check(OperationType::Read, Some(first_ip))
            .await
            .is_ok());
        assert!(limiter
            .check(OperationType::Read, Some(second_ip))
            .await
            .is_ok());

        let remaining_ip1 = limiter.remaining(OperationType::Read, Some(first_ip)).await;
        let remaining_ip2 = limiter.remaining(OperationType::Read, Some(second_ip)).await;

        assert!(remaining_ip1 > 0);
        assert!(remaining_ip2 > 0);
    }

    /// A disabled limiter admits far more requests than any default limit.
    #[tokio::test]
    async fn test_disabled_limiter() {
        let limiter = RateLimiter::disabled();

        for _ in 0..1000 {
            let outcome = limiter.check(OperationType::Read, None).await;
            assert!(outcome.is_ok());
        }
    }
}