1use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11
12use crate::errors::{LimitType, SentinelError, SentinelResult};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct Limits {
17 pub max_header_size_bytes: usize,
19 pub max_header_count: usize,
20 pub max_header_name_bytes: usize,
21 pub max_header_value_bytes: usize,
22
23 pub max_body_size_bytes: usize,
25 pub max_body_buffer_bytes: usize,
26 pub max_body_inspection_bytes: usize,
27
28 pub max_decompression_ratio: f32,
30 pub max_decompressed_size_bytes: usize,
31
32 pub max_connections_per_client: usize,
34 pub max_connections_per_route: usize,
35 pub max_total_connections: usize,
36 pub max_idle_connections_per_upstream: usize,
37
38 pub max_in_flight_requests: usize,
40 pub max_in_flight_requests_per_worker: usize,
41 pub max_queued_requests: usize,
42
43 pub max_agent_queue_depth: usize,
45 pub max_agent_body_bytes: usize,
46 pub max_agent_response_bytes: usize,
47
48 pub max_requests_per_second_global: Option<u32>,
50 pub max_requests_per_second_per_client: Option<u32>,
51 pub max_requests_per_second_per_route: Option<u32>,
52
53 pub max_memory_bytes: Option<usize>,
55 pub max_memory_percent: Option<f32>,
56}
57
58impl Default for Limits {
59 fn default() -> Self {
60 Self {
61 max_header_size_bytes: 8192, max_header_count: 100, max_header_name_bytes: 256, max_header_value_bytes: 4096, max_body_size_bytes: 10 * 1024 * 1024,
69 max_body_buffer_bytes: 1024 * 1024,
70 max_body_inspection_bytes: 1024 * 1024,
71
72 max_decompression_ratio: 100.0,
74 max_decompressed_size_bytes: 100 * 1024 * 1024, max_connections_per_client: 100,
78 max_connections_per_route: 1000,
79 max_total_connections: 10000,
80 max_idle_connections_per_upstream: 100,
81
82 max_in_flight_requests: 10000,
84 max_in_flight_requests_per_worker: 1000,
85 max_queued_requests: 1000,
86
87 max_agent_queue_depth: 100,
89 max_agent_body_bytes: 1024 * 1024, max_agent_response_bytes: 10 * 1024, max_requests_per_second_global: None,
94 max_requests_per_second_per_client: None,
95 max_requests_per_second_per_route: None,
96
97 max_memory_bytes: None,
99 max_memory_percent: None,
100 }
101 }
102}
103
104impl Limits {
105 pub fn for_testing() -> Self {
107 Self {
108 max_header_size_bytes: 16384,
109 max_header_count: 200,
110 max_body_size_bytes: 100 * 1024 * 1024, max_in_flight_requests: 100000,
112 ..Default::default()
113 }
114 }
115
116 pub fn for_production() -> Self {
118 Self {
119 max_header_size_bytes: 4096,
120 max_header_count: 50,
121 max_body_size_bytes: 1024 * 1024, max_in_flight_requests: 5000,
123 max_requests_per_second_global: Some(10000),
124 max_requests_per_second_per_client: Some(100),
125 max_memory_percent: Some(80.0),
126 ..Default::default()
127 }
128 }
129
130 pub fn validate(&self) -> SentinelResult<()> {
132 if self.max_header_size_bytes == 0 {
133 return Err(SentinelError::Config {
134 message: "max_header_size_bytes must be greater than 0".to_string(),
135 source: None,
136 });
137 }
138
139 if self.max_header_count == 0 {
140 return Err(SentinelError::Config {
141 message: "max_header_count must be greater than 0".to_string(),
142 source: None,
143 });
144 }
145
146 if self.max_body_buffer_bytes > self.max_body_size_bytes {
147 return Err(SentinelError::Config {
148 message: "max_body_buffer_bytes cannot exceed max_body_size_bytes".to_string(),
149 source: None,
150 });
151 }
152
153 if self.max_decompression_ratio <= 0.0 {
154 return Err(SentinelError::Config {
155 message: "max_decompression_ratio must be positive".to_string(),
156 source: None,
157 });
158 }
159
160 if let Some(pct) = self.max_memory_percent {
161 if pct <= 0.0 || pct > 100.0 {
162 return Err(SentinelError::Config {
163 message: "max_memory_percent must be between 0 and 100".to_string(),
164 source: None,
165 });
166 }
167 }
168
169 Ok(())
170 }
171
172 pub fn check_header_size(&self, size: usize) -> SentinelResult<()> {
174 if size > self.max_header_size_bytes {
175 return Err(SentinelError::limit_exceeded(
176 LimitType::HeaderSize,
177 size,
178 self.max_header_size_bytes,
179 ));
180 }
181 Ok(())
182 }
183
184 pub fn check_header_count(&self, count: usize) -> SentinelResult<()> {
186 if count > self.max_header_count {
187 return Err(SentinelError::limit_exceeded(
188 LimitType::HeaderCount,
189 count,
190 self.max_header_count,
191 ));
192 }
193 Ok(())
194 }
195
196 pub fn check_body_size(&self, size: usize) -> SentinelResult<()> {
198 if size > self.max_body_size_bytes {
199 return Err(SentinelError::limit_exceeded(
200 LimitType::BodySize,
201 size,
202 self.max_body_size_bytes,
203 ));
204 }
205 Ok(())
206 }
207}
208
209#[derive(Debug)]
211pub struct RateLimiter {
212 capacity: u32,
213 tokens: Arc<RwLock<f64>>,
214 refill_rate: f64,
215 last_refill: Arc<RwLock<Instant>>,
216}
217
218impl RateLimiter {
219 pub fn new(capacity: u32, refill_per_second: u32) -> Self {
221 Self {
222 capacity,
223 tokens: Arc::new(RwLock::new(capacity as f64)),
224 refill_rate: refill_per_second as f64,
225 last_refill: Arc::new(RwLock::new(Instant::now())),
226 }
227 }
228
229 pub fn try_acquire(&self, tokens: u32) -> bool {
231 self.refill();
232
233 let mut available_tokens = self.tokens.write();
234 if *available_tokens >= tokens as f64 {
235 *available_tokens -= tokens as f64;
236 true
237 } else {
238 false
239 }
240 }
241
242 pub fn check(&self, tokens: u32) -> bool {
244 self.refill();
245 let available_tokens = self.tokens.read();
246 *available_tokens >= tokens as f64
247 }
248
249 pub fn available(&self) -> u32 {
251 self.refill();
252 let tokens = self.tokens.read();
253 *tokens as u32
254 }
255
256 fn refill(&self) {
258 let now = Instant::now();
259 let mut last_refill = self.last_refill.write();
260 let elapsed = now.duration_since(*last_refill).as_secs_f64();
261
262 if elapsed > 0.0 {
263 let mut tokens = self.tokens.write();
264 let tokens_to_add = elapsed * self.refill_rate;
265 *tokens = (*tokens + tokens_to_add).min(self.capacity as f64);
266 *last_refill = now;
267 }
268 }
269
270 pub fn reset(&self) {
272 let mut tokens = self.tokens.write();
273 *tokens = self.capacity as f64;
274 let mut last_refill = self.last_refill.write();
275 *last_refill = Instant::now();
276 }
277}
278
279pub struct MultiRateLimiter {
281 global: Option<RateLimiter>,
282 per_client: Arc<RwLock<HashMap<String, RateLimiter>>>,
283 per_route: Arc<RwLock<HashMap<String, RateLimiter>>>,
284 client_limit: Option<(u32, u32)>, route_limit: Option<(u32, u32)>, }
287
288impl MultiRateLimiter {
289 pub fn new(limits: &Limits) -> Self {
291 let global = limits
292 .max_requests_per_second_global
293 .map(|rps| RateLimiter::new(rps * 10, rps)); let client_limit = limits
296 .max_requests_per_second_per_client
297 .map(|rps| (rps * 10, rps));
298
299 let route_limit = limits
300 .max_requests_per_second_per_route
301 .map(|rps| (rps * 10, rps));
302
303 Self {
304 global,
305 per_client: Arc::new(RwLock::new(HashMap::new())),
306 per_route: Arc::new(RwLock::new(HashMap::new())),
307 client_limit,
308 route_limit,
309 }
310 }
311
312 pub fn check_request(&self, client_id: &str, route: &str) -> SentinelResult<()> {
314 if let Some(ref limiter) = self.global {
316 if !limiter.try_acquire(1) {
317 return Err(SentinelError::RateLimit {
318 message: "Global rate limit exceeded".to_string(),
319 limit: limiter.capacity,
320 window_seconds: 10,
321 retry_after_seconds: Some(1),
322 });
323 }
324 }
325
326 if let Some((capacity, refill)) = self.client_limit {
328 let mut limiters = self.per_client.write();
329 let limiter = limiters
330 .entry(client_id.to_string())
331 .or_insert_with(|| RateLimiter::new(capacity, refill));
332
333 if !limiter.try_acquire(1) {
334 return Err(SentinelError::RateLimit {
335 message: format!("Rate limit exceeded for client {}", client_id),
336 limit: capacity,
337 window_seconds: 10,
338 retry_after_seconds: Some(1),
339 });
340 }
341 }
342
343 if let Some((capacity, refill)) = self.route_limit {
345 let mut limiters = self.per_route.write();
346 let limiter = limiters
347 .entry(route.to_string())
348 .or_insert_with(|| RateLimiter::new(capacity, refill));
349
350 if !limiter.try_acquire(1) {
351 return Err(SentinelError::RateLimit {
352 message: format!("Rate limit exceeded for route {}", route),
353 limit: capacity,
354 window_seconds: 10,
355 retry_after_seconds: Some(1),
356 });
357 }
358 }
359
360 Ok(())
361 }
362
363 pub fn cleanup(&self, _max_age: Duration) {
365 }
368}
369
370pub struct ConnectionLimiter {
372 per_client: Arc<RwLock<HashMap<String, usize>>>,
373 per_route: Arc<RwLock<HashMap<String, usize>>>,
374 total: Arc<RwLock<usize>>,
375 limits: Limits,
376}
377
378impl ConnectionLimiter {
379 pub fn new(limits: Limits) -> Self {
380 Self {
381 per_client: Arc::new(RwLock::new(HashMap::new())),
382 per_route: Arc::new(RwLock::new(HashMap::new())),
383 total: Arc::new(RwLock::new(0)),
384 limits,
385 }
386 }
387
388 pub fn try_acquire(&self, client_id: &str, route: &str) -> SentinelResult<ConnectionGuard<'_>> {
390 {
392 let mut total = self.total.write();
393 if *total >= self.limits.max_total_connections {
394 return Err(SentinelError::limit_exceeded(
395 LimitType::ConnectionCount,
396 *total,
397 self.limits.max_total_connections,
398 ));
399 }
400 *total += 1;
401 }
402
403 {
405 let mut per_client = self.per_client.write();
406 let client_count = per_client.entry(client_id.to_string()).or_insert(0);
407 if *client_count >= self.limits.max_connections_per_client {
408 *self.total.write() -= 1;
410 return Err(SentinelError::limit_exceeded(
411 LimitType::ConnectionCount,
412 *client_count,
413 self.limits.max_connections_per_client,
414 ));
415 }
416 *client_count += 1;
417 }
418
419 {
421 let mut per_route = self.per_route.write();
422 let route_count = per_route.entry(route.to_string()).or_insert(0);
423 if *route_count >= self.limits.max_connections_per_route {
424 *self.total.write() -= 1;
426 *self.per_client.write().get_mut(client_id).unwrap() -= 1;
427 return Err(SentinelError::limit_exceeded(
428 LimitType::ConnectionCount,
429 *route_count,
430 self.limits.max_connections_per_route,
431 ));
432 }
433 *route_count += 1;
434 }
435
436 Ok(ConnectionGuard {
437 limiter: self,
438 client_id: client_id.to_string(),
439 route: route.to_string(),
440 })
441 }
442
443 fn release(&self, client_id: &str, route: &str) {
445 *self.total.write() -= 1;
446
447 if let Some(count) = self.per_client.write().get_mut(client_id) {
448 *count = count.saturating_sub(1);
449 }
450
451 if let Some(count) = self.per_route.write().get_mut(route) {
452 *count = count.saturating_sub(1);
453 }
454 }
455
456 pub fn stats(&self) -> ConnectionStats {
458 ConnectionStats {
459 total: *self.total.read(),
460 per_client_count: self.per_client.read().len(),
461 per_route_count: self.per_route.read().len(),
462 }
463 }
464}
465
466pub struct ConnectionGuard<'a> {
468 limiter: &'a ConnectionLimiter,
469 client_id: String,
470 route: String,
471}
472
473impl<'a> Drop for ConnectionGuard<'a> {
474 fn drop(&mut self) {
475 self.limiter.release(&self.client_id, &self.route);
476 }
477}
478
479#[derive(Debug, Clone, Serialize)]
481pub struct ConnectionStats {
482 pub total: usize,
483 pub per_client_count: usize,
484 pub per_route_count: usize,
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_limits_validation() {
        let mut limits = Limits::default();
        assert!(limits.validate().is_ok());

        limits.max_header_size_bytes = 0;
        assert!(limits.validate().is_err());

        limits = Limits::default();
        limits.max_body_buffer_bytes = limits.max_body_size_bytes + 1;
        assert!(limits.validate().is_err());

        limits = Limits::default();
        limits.max_decompression_ratio = 0.0;
        assert!(limits.validate().is_err());

        limits = Limits::default();
        limits.max_memory_percent = Some(150.0);
        assert!(limits.validate().is_err());
    }

    #[test]
    fn test_presets_validate() {
        assert!(Limits::for_testing().validate().is_ok());
        assert!(Limits::for_production().validate().is_ok());
    }

    #[test]
    fn test_size_checks_are_inclusive() {
        let limits = Limits::default();

        assert!(limits.check_header_size(limits.max_header_size_bytes).is_ok());
        assert!(limits.check_header_size(limits.max_header_size_bytes + 1).is_err());

        assert!(limits.check_header_count(limits.max_header_count).is_ok());
        assert!(limits.check_header_count(limits.max_header_count + 1).is_err());

        assert!(limits.check_body_size(limits.max_body_size_bytes).is_ok());
        assert!(limits.check_body_size(limits.max_body_size_bytes + 1).is_err());
    }

    #[test]
    fn test_rate_limiter() {
        let limiter = RateLimiter::new(10, 10);

        for _ in 0..10 {
            assert!(limiter.try_acquire(1));
        }

        assert!(!limiter.try_acquire(1));

        thread::sleep(Duration::from_millis(200));

        assert!(limiter.try_acquire(1));
        assert!(limiter.available() > 0);
    }

    #[test]
    fn test_rate_limiter_reset_restores_capacity() {
        let limiter = RateLimiter::new(5, 1);
        assert!(limiter.try_acquire(5));
        assert!(!limiter.check(1));

        limiter.reset();
        assert_eq!(limiter.available(), 5);
    }

    #[test]
    fn test_multi_rate_limiter_global() {
        let limits = Limits {
            max_requests_per_second_global: Some(1),
            ..Default::default()
        };
        let limiter = MultiRateLimiter::new(&limits);

        // Burst capacity is ten seconds' worth of requests.
        for _ in 0..10 {
            assert!(limiter.check_request("client", "/").is_ok());
        }
        assert!(limiter.check_request("client", "/").is_err());
    }

    #[test]
    fn test_multi_rate_limiter_isolates_clients() {
        let limits = Limits {
            max_requests_per_second_per_client: Some(1),
            ..Default::default()
        };
        let limiter = MultiRateLimiter::new(&limits);

        for _ in 0..10 {
            assert!(limiter.check_request("a", "/").is_ok());
        }
        assert!(limiter.check_request("a", "/").is_err());
        assert!(limiter.check_request("b", "/").is_ok());

        // Cleanup must never hand a drained client a fresh bucket.
        limiter.cleanup(Duration::from_secs(0));
        assert!(limiter.check_request("a", "/").is_err());
    }

    #[test]
    fn test_connection_limiter() {
        let limits = Limits {
            max_total_connections: 100,
            max_connections_per_client: 10,
            max_connections_per_route: 50,
            ..Default::default()
        };

        let limiter = ConnectionLimiter::new(limits);

        let guard1 = limiter.try_acquire("client1", "route1").unwrap();
        let guard2 = limiter.try_acquire("client1", "route1").unwrap();
        assert_eq!(limiter.stats().total, 2);

        drop(guard1);
        assert_eq!(limiter.stats().total, 1);
        drop(guard2);
        assert_eq!(limiter.stats().total, 0);
    }

    #[test]
    fn test_connection_guard_releases_on_drop() {
        let limiter = ConnectionLimiter::new(Limits {
            max_connections_per_client: 1,
            ..Default::default()
        });

        let guard = limiter.try_acquire("c", "r").unwrap();
        assert!(limiter.try_acquire("c", "r").is_err());
        assert_eq!(limiter.stats().total, 1);

        drop(guard);
        assert_eq!(limiter.stats().total, 0);
        assert!(limiter.try_acquire("c", "r").is_ok());
    }

    #[test]
    fn test_route_rejection_rolls_back_counters() {
        let limiter = ConnectionLimiter::new(Limits {
            max_connections_per_client: 1,
            max_connections_per_route: 1,
            ..Default::default()
        });

        let _held = limiter.try_acquire("c1", "r").unwrap();

        // c2 passes its per-client check but the route is full.
        assert!(limiter.try_acquire("c2", "r").is_err());
        assert_eq!(limiter.stats().total, 1);

        // c2's per-client slot must have been returned by the rollback.
        assert!(limiter.try_acquire("c2", "other").is_ok());
    }
}