1use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use tracing::{debug, trace, warn};
12
13use crate::errors::{LimitType, SentinelError, SentinelResult};
14
/// Configurable resource limits for headers, bodies, decompression, connections,
/// in-flight work, agents, request rates and memory.
///
/// `Default` supplies general-purpose values; `Limits::for_testing` and
/// `Limits::for_production` start from those defaults and override a few fields.
/// `Limits::validate` checks a subset of the fields for consistency.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Limits {
    // --- Headers ---
    /// Limit compared against by `Limits::check_header_size`; `validate` rejects 0.
    pub max_header_size_bytes: usize,
    /// Limit compared against by `Limits::check_header_count`; `validate` rejects 0.
    pub max_header_count: usize,
    /// Header-name size limit (not referenced by code in this module).
    pub max_header_name_bytes: usize,
    /// Header-value size limit (not referenced by code in this module).
    pub max_header_value_bytes: usize,

    // --- Bodies ---
    /// Limit compared against by `Limits::check_body_size`.
    pub max_body_size_bytes: usize,
    /// Body buffering limit; `validate` rejects values above `max_body_size_bytes`.
    pub max_body_buffer_bytes: usize,
    /// Body inspection limit (not referenced by code in this module).
    pub max_body_inspection_bytes: usize,

    // --- Decompression ---
    /// Maximum decompression ratio; `validate` rejects values `<= 0.0`.
    pub max_decompression_ratio: f32,
    /// Decompressed-size limit (not referenced by code in this module).
    pub max_decompressed_size_bytes: usize,

    // --- Connections (the first three are enforced by `ConnectionLimiter`) ---
    /// Concurrently held connection slots allowed per client id.
    pub max_connections_per_client: usize,
    /// Concurrently held connection slots allowed per route.
    pub max_connections_per_route: usize,
    /// Concurrently held connection slots allowed in total.
    pub max_total_connections: usize,
    /// Idle upstream connection limit (not referenced by code in this module).
    pub max_idle_connections_per_upstream: usize,

    // --- In-flight work (not referenced by code in this module) ---
    pub max_in_flight_requests: usize,
    pub max_in_flight_requests_per_worker: usize,
    pub max_queued_requests: usize,

    // --- Agents (not referenced by code in this module) ---
    pub max_agent_queue_depth: usize,
    pub max_agent_body_bytes: usize,
    pub max_agent_response_bytes: usize,

    // --- Rate limits (enforced by `MultiRateLimiter`; `None` disables a tier) ---
    /// Requests per second across all traffic; burst capacity is 10x this value.
    pub max_requests_per_second_global: Option<u32>,
    /// Requests per second per client id; burst capacity is 10x this value.
    pub max_requests_per_second_per_client: Option<u32>,
    /// Requests per second per route; burst capacity is 10x this value.
    pub max_requests_per_second_per_route: Option<u32>,

    // --- Memory ---
    /// Absolute memory cap (not referenced by code in this module).
    pub max_memory_bytes: Option<usize>,
    /// Memory cap as a percentage; when set, `validate` rejects values `<= 0` or `> 100`.
    pub max_memory_percent: Option<f32>,
}
58
59impl Default for Limits {
60 fn default() -> Self {
61 Self {
62 max_header_size_bytes: 8192, max_header_count: 100, max_header_name_bytes: 256, max_header_value_bytes: 4096, max_body_size_bytes: 10 * 1024 * 1024,
70 max_body_buffer_bytes: 1024 * 1024,
71 max_body_inspection_bytes: 1024 * 1024,
72
73 max_decompression_ratio: 100.0,
75 max_decompressed_size_bytes: 100 * 1024 * 1024, max_connections_per_client: 100,
79 max_connections_per_route: 1000,
80 max_total_connections: 10000,
81 max_idle_connections_per_upstream: 100,
82
83 max_in_flight_requests: 10000,
85 max_in_flight_requests_per_worker: 1000,
86 max_queued_requests: 1000,
87
88 max_agent_queue_depth: 100,
90 max_agent_body_bytes: 1024 * 1024, max_agent_response_bytes: 10 * 1024, max_requests_per_second_global: None,
95 max_requests_per_second_per_client: None,
96 max_requests_per_second_per_route: None,
97
98 max_memory_bytes: None,
100 max_memory_percent: None,
101 }
102 }
103}
104
105impl Limits {
106 pub fn for_testing() -> Self {
108 Self {
109 max_header_size_bytes: 16384,
110 max_header_count: 200,
111 max_body_size_bytes: 100 * 1024 * 1024, max_in_flight_requests: 100000,
113 ..Default::default()
114 }
115 }
116
117 pub fn for_production() -> Self {
119 Self {
120 max_header_size_bytes: 4096,
121 max_header_count: 50,
122 max_body_size_bytes: 1024 * 1024, max_in_flight_requests: 5000,
124 max_requests_per_second_global: Some(10000),
125 max_requests_per_second_per_client: Some(100),
126 max_memory_percent: Some(80.0),
127 ..Default::default()
128 }
129 }
130
131 pub fn validate(&self) -> SentinelResult<()> {
133 if self.max_header_size_bytes == 0 {
134 return Err(SentinelError::Config {
135 message: "max_header_size_bytes must be greater than 0".to_string(),
136 source: None,
137 });
138 }
139
140 if self.max_header_count == 0 {
141 return Err(SentinelError::Config {
142 message: "max_header_count must be greater than 0".to_string(),
143 source: None,
144 });
145 }
146
147 if self.max_body_buffer_bytes > self.max_body_size_bytes {
148 return Err(SentinelError::Config {
149 message: "max_body_buffer_bytes cannot exceed max_body_size_bytes".to_string(),
150 source: None,
151 });
152 }
153
154 if self.max_decompression_ratio <= 0.0 {
155 return Err(SentinelError::Config {
156 message: "max_decompression_ratio must be positive".to_string(),
157 source: None,
158 });
159 }
160
161 if let Some(pct) = self.max_memory_percent {
162 if pct <= 0.0 || pct > 100.0 {
163 return Err(SentinelError::Config {
164 message: "max_memory_percent must be between 0 and 100".to_string(),
165 source: None,
166 });
167 }
168 }
169
170 Ok(())
171 }
172
173 pub fn check_header_size(&self, size: usize) -> SentinelResult<()> {
175 if size > self.max_header_size_bytes {
176 return Err(SentinelError::limit_exceeded(
177 LimitType::HeaderSize,
178 size,
179 self.max_header_size_bytes,
180 ));
181 }
182 Ok(())
183 }
184
185 pub fn check_header_count(&self, count: usize) -> SentinelResult<()> {
187 if count > self.max_header_count {
188 return Err(SentinelError::limit_exceeded(
189 LimitType::HeaderCount,
190 count,
191 self.max_header_count,
192 ));
193 }
194 Ok(())
195 }
196
197 pub fn check_body_size(&self, size: usize) -> SentinelResult<()> {
199 if size > self.max_body_size_bytes {
200 return Err(SentinelError::limit_exceeded(
201 LimitType::BodySize,
202 size,
203 self.max_body_size_bytes,
204 ));
205 }
206 Ok(())
207 }
208}
209
/// Token-bucket rate limiter.
///
/// The bucket holds at most `capacity` tokens and gains `refill_rate` tokens per
/// second. Refilling is lazy: it is computed from the time since `last_refill`
/// whenever a public method queries the bucket.
#[derive(Debug)]
pub struct RateLimiter {
    /// Maximum number of tokens (the burst size); also the initial fill level.
    capacity: u32,
    /// Current token count, kept fractional so partial refills accumulate.
    tokens: Arc<RwLock<f64>>,
    /// Tokens added per second of elapsed time.
    refill_rate: f64,
    /// Time of the most recent refill. `refill` locks this before `tokens`.
    last_refill: Arc<RwLock<Instant>>,
}
218
219impl RateLimiter {
220 pub fn new(capacity: u32, refill_per_second: u32) -> Self {
222 trace!(
223 capacity = capacity,
224 refill_per_second = refill_per_second,
225 "Creating rate limiter"
226 );
227 Self {
228 capacity,
229 tokens: Arc::new(RwLock::new(capacity as f64)),
230 refill_rate: refill_per_second as f64,
231 last_refill: Arc::new(RwLock::new(Instant::now())),
232 }
233 }
234
235 pub fn try_acquire(&self, tokens: u32) -> bool {
237 self.refill();
238
239 let mut available_tokens = self.tokens.write();
240 if *available_tokens >= tokens as f64 {
241 *available_tokens -= tokens as f64;
242 trace!(
243 tokens_requested = tokens,
244 tokens_remaining = *available_tokens as u32,
245 "Rate limiter: tokens acquired"
246 );
247 true
248 } else {
249 trace!(
250 tokens_requested = tokens,
251 tokens_available = *available_tokens as u32,
252 "Rate limiter: insufficient tokens"
253 );
254 false
255 }
256 }
257
258 pub fn check(&self, tokens: u32) -> bool {
260 self.refill();
261 let available_tokens = self.tokens.read();
262 *available_tokens >= tokens as f64
263 }
264
265 pub fn available(&self) -> u32 {
267 self.refill();
268 let tokens = self.tokens.read();
269 *tokens as u32
270 }
271
272 fn refill(&self) {
274 let now = Instant::now();
275 let mut last_refill = self.last_refill.write();
276 let elapsed = now.duration_since(*last_refill).as_secs_f64();
277
278 if elapsed > 0.0 {
279 let mut tokens = self.tokens.write();
280 let tokens_to_add = elapsed * self.refill_rate;
281 *tokens = (*tokens + tokens_to_add).min(self.capacity as f64);
282 *last_refill = now;
283 }
284 }
285
286 pub fn reset(&self) {
288 let mut tokens = self.tokens.write();
289 *tokens = self.capacity as f64;
290 let mut last_refill = self.last_refill.write();
291 *last_refill = Instant::now();
292 }
293}
294
/// Three tiers of request rate limiting: an optional global bucket, plus
/// per-client and per-route buckets created on first use.
pub struct MultiRateLimiter {
    /// Bucket shared by all requests; `None` when no global limit is configured.
    global: Option<RateLimiter>,
    /// One bucket per client id, inserted the first time that id is checked.
    per_client: Arc<RwLock<HashMap<String, RateLimiter>>>,
    /// One bucket per route, inserted the first time that route is checked.
    per_route: Arc<RwLock<HashMap<String, RateLimiter>>>,
    /// `(capacity, refill_per_second)` for new per-client buckets; `None` disables the tier.
    client_limit: Option<(u32, u32)>,
    /// `(capacity, refill_per_second)` for new per-route buckets; `None` disables the tier.
    route_limit: Option<(u32, u32)>,
}
303
304impl MultiRateLimiter {
305 pub fn new(limits: &Limits) -> Self {
307 let global = limits
308 .max_requests_per_second_global
309 .map(|rps| RateLimiter::new(rps * 10, rps)); let client_limit = limits
312 .max_requests_per_second_per_client
313 .map(|rps| (rps * 10, rps));
314
315 let route_limit = limits
316 .max_requests_per_second_per_route
317 .map(|rps| (rps * 10, rps));
318
319 Self {
320 global,
321 per_client: Arc::new(RwLock::new(HashMap::new())),
322 per_route: Arc::new(RwLock::new(HashMap::new())),
323 client_limit,
324 route_limit,
325 }
326 }
327
328 pub fn check_request(&self, client_id: &str, route: &str) -> SentinelResult<()> {
330 trace!(
331 client_id = %client_id,
332 route = %route,
333 "Checking rate limits"
334 );
335
336 if let Some(ref limiter) = self.global {
338 if !limiter.try_acquire(1) {
339 warn!(
340 client_id = %client_id,
341 route = %route,
342 "Global rate limit exceeded"
343 );
344 return Err(SentinelError::RateLimit {
345 message: "Global rate limit exceeded".to_string(),
346 limit: limiter.capacity,
347 window_seconds: 10,
348 retry_after_seconds: Some(1),
349 });
350 }
351 }
352
353 if let Some((capacity, refill)) = self.client_limit {
355 let mut limiters = self.per_client.write();
356 let limiter = limiters
357 .entry(client_id.to_string())
358 .or_insert_with(|| RateLimiter::new(capacity, refill));
359
360 if !limiter.try_acquire(1) {
361 warn!(
362 client_id = %client_id,
363 route = %route,
364 "Per-client rate limit exceeded"
365 );
366 return Err(SentinelError::RateLimit {
367 message: format!("Rate limit exceeded for client {}", client_id),
368 limit: capacity,
369 window_seconds: 10,
370 retry_after_seconds: Some(1),
371 });
372 }
373 }
374
375 if let Some((capacity, refill)) = self.route_limit {
377 let mut limiters = self.per_route.write();
378 let limiter = limiters
379 .entry(route.to_string())
380 .or_insert_with(|| RateLimiter::new(capacity, refill));
381
382 if !limiter.try_acquire(1) {
383 warn!(
384 client_id = %client_id,
385 route = %route,
386 "Per-route rate limit exceeded"
387 );
388 return Err(SentinelError::RateLimit {
389 message: format!("Rate limit exceeded for route {}", route),
390 limit: capacity,
391 window_seconds: 10,
392 retry_after_seconds: Some(1),
393 });
394 }
395 }
396
397 trace!(
398 client_id = %client_id,
399 route = %route,
400 "Rate limits check passed"
401 );
402 Ok(())
403 }
404
405 pub fn cleanup(&self, _max_age: Duration) {
407 }
410}
411
/// Counts concurrently held connection slots and enforces the total, per-client
/// and per-route connection limits from `Limits`.
pub struct ConnectionLimiter {
    /// Slots currently held, per client id.
    per_client: Arc<RwLock<HashMap<String, usize>>>,
    /// Slots currently held, per route.
    per_route: Arc<RwLock<HashMap<String, usize>>>,
    /// Slots currently held across all clients and routes.
    total: Arc<RwLock<usize>>,
    /// Source of `max_total_connections`, `max_connections_per_client` and
    /// `max_connections_per_route`.
    limits: Limits,
}
419
420impl ConnectionLimiter {
421 pub fn new(limits: Limits) -> Self {
422 debug!(
423 max_total = limits.max_total_connections,
424 max_per_client = limits.max_connections_per_client,
425 max_per_route = limits.max_connections_per_route,
426 "Creating connection limiter"
427 );
428 Self {
429 per_client: Arc::new(RwLock::new(HashMap::new())),
430 per_route: Arc::new(RwLock::new(HashMap::new())),
431 total: Arc::new(RwLock::new(0)),
432 limits,
433 }
434 }
435
436 pub fn try_acquire(&self, client_id: &str, route: &str) -> SentinelResult<ConnectionGuard<'_>> {
438 trace!(
439 client_id = %client_id,
440 route = %route,
441 "Attempting to acquire connection slot"
442 );
443
444 {
446 let mut total = self.total.write();
447 if *total >= self.limits.max_total_connections {
448 warn!(
449 current = *total,
450 max = self.limits.max_total_connections,
451 "Total connection limit exceeded"
452 );
453 return Err(SentinelError::limit_exceeded(
454 LimitType::ConnectionCount,
455 *total,
456 self.limits.max_total_connections,
457 ));
458 }
459 *total += 1;
460 }
461
462 {
464 let mut per_client = self.per_client.write();
465 let client_count = per_client.entry(client_id.to_string()).or_insert(0);
466 if *client_count >= self.limits.max_connections_per_client {
467 *self.total.write() -= 1;
469 warn!(
470 client_id = %client_id,
471 current = *client_count,
472 max = self.limits.max_connections_per_client,
473 "Per-client connection limit exceeded"
474 );
475 return Err(SentinelError::limit_exceeded(
476 LimitType::ConnectionCount,
477 *client_count,
478 self.limits.max_connections_per_client,
479 ));
480 }
481 *client_count += 1;
482 }
483
484 {
486 let mut per_route = self.per_route.write();
487 let route_count = per_route.entry(route.to_string()).or_insert(0);
488 if *route_count >= self.limits.max_connections_per_route {
489 *self.total.write() -= 1;
491 *self.per_client.write().get_mut(client_id).unwrap() -= 1;
492 warn!(
493 route = %route,
494 current = *route_count,
495 max = self.limits.max_connections_per_route,
496 "Per-route connection limit exceeded"
497 );
498 return Err(SentinelError::limit_exceeded(
499 LimitType::ConnectionCount,
500 *route_count,
501 self.limits.max_connections_per_route,
502 ));
503 }
504 *route_count += 1;
505 }
506
507 trace!(
508 client_id = %client_id,
509 route = %route,
510 "Connection slot acquired"
511 );
512
513 Ok(ConnectionGuard {
514 limiter: self,
515 client_id: client_id.to_string(),
516 route: route.to_string(),
517 })
518 }
519
520 fn release(&self, client_id: &str, route: &str) {
522 trace!(
523 client_id = %client_id,
524 route = %route,
525 "Releasing connection slot"
526 );
527
528 *self.total.write() -= 1;
529
530 if let Some(count) = self.per_client.write().get_mut(client_id) {
531 *count = count.saturating_sub(1);
532 }
533
534 if let Some(count) = self.per_route.write().get_mut(route) {
535 *count = count.saturating_sub(1);
536 }
537 }
538
539 pub fn stats(&self) -> ConnectionStats {
541 ConnectionStats {
542 total: *self.total.read(),
543 per_client_count: self.per_client.read().len(),
544 per_route_count: self.per_route.read().len(),
545 }
546 }
547}
548
/// A connection slot reserved by `ConnectionLimiter::try_acquire`.
///
/// Dropping the guard calls `ConnectionLimiter::release` with the stored client id
/// and route, which gives the slot back. Keep the guard alive for as long as the
/// connection is held.
pub struct ConnectionGuard<'a> {
    /// Limiter the slot was reserved from.
    limiter: &'a ConnectionLimiter,
    /// Client id passed to `try_acquire`.
    client_id: String,
    /// Route passed to `try_acquire`.
    route: String,
}
555
556impl<'a> Drop for ConnectionGuard<'a> {
557 fn drop(&mut self) {
558 self.limiter.release(&self.client_id, &self.route);
559 }
560}
561
/// Point-in-time counters returned by `ConnectionLimiter::stats`.
#[derive(Debug, Clone, Serialize)]
pub struct ConnectionStats {
    /// Connection slots currently held across all clients and routes.
    pub total: usize,
    /// Number of entries in the limiter's per-client counter map.
    pub per_client_count: usize,
    /// Number of entries in the limiter's per-route counter map.
    pub per_route_count: usize,
}
569
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_limits_validation() {
        let mut limits = Limits::default();
        assert!(limits.validate().is_ok());

        limits.max_header_size_bytes = 0;
        assert!(limits.validate().is_err());

        limits = Limits::default();
        limits.max_body_buffer_bytes = limits.max_body_size_bytes + 1;
        assert!(limits.validate().is_err());

        limits = Limits::default();
        limits.max_decompression_ratio = 0.0;
        assert!(limits.validate().is_err());

        limits = Limits::default();
        limits.max_memory_percent = Some(150.0);
        assert!(limits.validate().is_err());
    }

    #[test]
    fn test_presets_are_valid() {
        assert!(Limits::for_testing().validate().is_ok());
        assert!(Limits::for_production().validate().is_ok());
    }

    #[test]
    fn test_header_and_body_checks_are_inclusive() {
        let limits = Limits::default();
        assert!(limits.check_header_size(limits.max_header_size_bytes).is_ok());
        assert!(limits.check_header_size(limits.max_header_size_bytes + 1).is_err());
        assert!(limits.check_header_count(limits.max_header_count).is_ok());
        assert!(limits.check_header_count(limits.max_header_count + 1).is_err());
        assert!(limits.check_body_size(limits.max_body_size_bytes).is_ok());
        assert!(limits.check_body_size(limits.max_body_size_bytes + 1).is_err());
    }

    #[test]
    fn test_rate_limiter() {
        let limiter = RateLimiter::new(10, 10);

        // The bucket starts full: the first `capacity` acquisitions succeed.
        for _ in 0..10 {
            assert!(limiter.try_acquire(1));
        }

        assert!(!limiter.try_acquire(1));

        // 200 ms at 10 tokens/s refills about two tokens.
        thread::sleep(Duration::from_millis(200));

        assert!(limiter.try_acquire(1));
        assert!(limiter.available() > 0);
    }

    #[test]
    fn test_rate_limiter_reset_restores_capacity() {
        let limiter = RateLimiter::new(5, 1);
        assert!(limiter.try_acquire(5));
        assert!(!limiter.try_acquire(1));

        limiter.reset();
        assert_eq!(limiter.available(), 5);
    }

    #[test]
    fn test_multi_rate_limiter_per_client() {
        let limits = Limits {
            max_requests_per_second_per_client: Some(1),
            ..Default::default()
        };
        let limiter = MultiRateLimiter::new(&limits);

        // Burst capacity is ten times the per-second rate.
        for _ in 0..10 {
            assert!(limiter.check_request("client1", "/").is_ok());
        }
        assert!(limiter.check_request("client1", "/").is_err());

        // Another client has its own bucket.
        assert!(limiter.check_request("client2", "/").is_ok());
    }

    #[test]
    fn test_connection_limiter() {
        let limits = Limits {
            max_total_connections: 100,
            max_connections_per_client: 10,
            max_connections_per_route: 50,
            ..Default::default()
        };

        let limiter = ConnectionLimiter::new(limits);

        let _guard1 = limiter.try_acquire("client1", "route1").unwrap();
        let _guard2 = limiter.try_acquire("client1", "route1").unwrap();

        let stats = limiter.stats();
        assert_eq!(stats.total, 2);
    }

    #[test]
    fn test_connection_guard_releases_on_drop() {
        let limits = Limits {
            max_total_connections: 1,
            ..Default::default()
        };
        let limiter = ConnectionLimiter::new(limits);

        {
            let _guard = limiter.try_acquire("client1", "route1").unwrap();
            assert_eq!(limiter.stats().total, 1);
            assert!(limiter.try_acquire("client2", "route1").is_err());
        }

        assert_eq!(limiter.stats().total, 0);
        assert!(limiter.try_acquire("client2", "route1").is_ok());
    }

    #[test]
    fn test_rejected_acquire_rolls_back_counters() {
        let limits = Limits {
            max_total_connections: 10,
            max_connections_per_client: 1,
            max_connections_per_route: 1,
            ..Default::default()
        };
        let limiter = ConnectionLimiter::new(limits);

        let _guard = limiter.try_acquire("client1", "route1").unwrap();

        // Rejected by the per-client tier: the global reservation is rolled back.
        assert!(limiter.try_acquire("client1", "route2").is_err());
        assert_eq!(limiter.stats().total, 1);

        // Rejected by the per-route tier: client2's reservation is rolled back too,
        // so client2 can still take its single slot on another route.
        assert!(limiter.try_acquire("client2", "route1").is_err());
        assert_eq!(limiter.stats().total, 1);
        let _other = limiter.try_acquire("client2", "route2").unwrap();
        assert_eq!(limiter.stats().total, 2);
    }
}