sentinel_common/
limits.rs

1//! Limits and rate limiting for Sentinel proxy
2//!
3//! This module implements bounded limits for all resources to ensure predictable
4//! behavior and prevent resource exhaustion - core to "sleepable ops".
5
6use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use tracing::{debug, trace, warn};
12
13use crate::errors::{LimitType, SentinelError, SentinelResult};
14
15/// System-wide limits configuration
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct Limits {
18    // Header limits
19    pub max_header_size_bytes: usize,
20    pub max_header_count: usize,
21    pub max_header_name_bytes: usize,
22    pub max_header_value_bytes: usize,
23
24    // Body limits
25    pub max_body_size_bytes: usize,
26    pub max_body_buffer_bytes: usize,
27    pub max_body_inspection_bytes: usize,
28
29    // Decompression limits
30    pub max_decompression_ratio: f32,
31    pub max_decompressed_size_bytes: usize,
32
33    // Connection limits
34    pub max_connections_per_client: usize,
35    pub max_connections_per_route: usize,
36    pub max_total_connections: usize,
37    pub max_idle_connections_per_upstream: usize,
38
39    // Request limits
40    pub max_in_flight_requests: usize,
41    pub max_in_flight_requests_per_worker: usize,
42    pub max_queued_requests: usize,
43
44    // Agent limits
45    pub max_agent_queue_depth: usize,
46    pub max_agent_body_bytes: usize,
47    pub max_agent_response_bytes: usize,
48
49    // Rate limits
50    pub max_requests_per_second_global: Option<u32>,
51    pub max_requests_per_second_per_client: Option<u32>,
52    pub max_requests_per_second_per_route: Option<u32>,
53
54    // Memory limits
55    pub max_memory_bytes: Option<usize>,
56    pub max_memory_percent: Option<f32>,
57}
58
59impl Default for Limits {
60    fn default() -> Self {
61        Self {
62            // Conservative header limits
63            max_header_size_bytes: 8192,  // 8KB total headers
64            max_header_count: 100,        // Max 100 headers
65            max_header_name_bytes: 256,   // 256 bytes per header name
66            max_header_value_bytes: 4096, // 4KB per header value
67
68            // Body limits - 10MB default, 1MB buffer
69            max_body_size_bytes: 10 * 1024 * 1024,
70            max_body_buffer_bytes: 1024 * 1024,
71            max_body_inspection_bytes: 1024 * 1024,
72
73            // Decompression protection
74            max_decompression_ratio: 100.0,
75            max_decompressed_size_bytes: 100 * 1024 * 1024, // 100MB
76
77            // Connection limits
78            max_connections_per_client: 100,
79            max_connections_per_route: 1000,
80            max_total_connections: 10000,
81            max_idle_connections_per_upstream: 100,
82
83            // Request concurrency
84            max_in_flight_requests: 10000,
85            max_in_flight_requests_per_worker: 1000,
86            max_queued_requests: 1000,
87
88            // Agent communication
89            max_agent_queue_depth: 100,
90            max_agent_body_bytes: 1024 * 1024,   // 1MB to agents
91            max_agent_response_bytes: 10 * 1024, // 10KB from agents
92
93            // Rate limits (optional by default)
94            max_requests_per_second_global: None,
95            max_requests_per_second_per_client: None,
96            max_requests_per_second_per_route: None,
97
98            // Memory limits (optional by default)
99            max_memory_bytes: None,
100            max_memory_percent: None,
101        }
102    }
103}
104
105impl Limits {
106    /// Create limits suitable for testing (more permissive)
107    pub fn for_testing() -> Self {
108        Self {
109            max_header_size_bytes: 16384,
110            max_header_count: 200,
111            max_body_size_bytes: 100 * 1024 * 1024, // 100MB
112            max_in_flight_requests: 100000,
113            ..Default::default()
114        }
115    }
116
117    /// Create limits suitable for production (more restrictive)
118    pub fn for_production() -> Self {
119        Self {
120            max_header_size_bytes: 4096,
121            max_header_count: 50,
122            max_body_size_bytes: 1024 * 1024, // 1MB
123            max_in_flight_requests: 5000,
124            max_requests_per_second_global: Some(10000),
125            max_requests_per_second_per_client: Some(100),
126            max_memory_percent: Some(80.0),
127            ..Default::default()
128        }
129    }
130
131    /// Validate the limits configuration
132    pub fn validate(&self) -> SentinelResult<()> {
133        if self.max_header_size_bytes == 0 {
134            return Err(SentinelError::Config {
135                message: "max_header_size_bytes must be greater than 0".to_string(),
136                source: None,
137            });
138        }
139
140        if self.max_header_count == 0 {
141            return Err(SentinelError::Config {
142                message: "max_header_count must be greater than 0".to_string(),
143                source: None,
144            });
145        }
146
147        if self.max_body_buffer_bytes > self.max_body_size_bytes {
148            return Err(SentinelError::Config {
149                message: "max_body_buffer_bytes cannot exceed max_body_size_bytes".to_string(),
150                source: None,
151            });
152        }
153
154        if self.max_decompression_ratio <= 0.0 {
155            return Err(SentinelError::Config {
156                message: "max_decompression_ratio must be positive".to_string(),
157                source: None,
158            });
159        }
160
161        if let Some(pct) = self.max_memory_percent {
162            if pct <= 0.0 || pct > 100.0 {
163                return Err(SentinelError::Config {
164                    message: "max_memory_percent must be between 0 and 100".to_string(),
165                    source: None,
166                });
167            }
168        }
169
170        Ok(())
171    }
172
173    /// Check if a header size exceeds limits
174    pub fn check_header_size(&self, size: usize) -> SentinelResult<()> {
175        if size > self.max_header_size_bytes {
176            return Err(SentinelError::limit_exceeded(
177                LimitType::HeaderSize,
178                size,
179                self.max_header_size_bytes,
180            ));
181        }
182        Ok(())
183    }
184
185    /// Check if header count exceeds limits
186    pub fn check_header_count(&self, count: usize) -> SentinelResult<()> {
187        if count > self.max_header_count {
188            return Err(SentinelError::limit_exceeded(
189                LimitType::HeaderCount,
190                count,
191                self.max_header_count,
192            ));
193        }
194        Ok(())
195    }
196
197    /// Check if body size exceeds limits
198    pub fn check_body_size(&self, size: usize) -> SentinelResult<()> {
199        if size > self.max_body_size_bytes {
200            return Err(SentinelError::limit_exceeded(
201                LimitType::BodySize,
202                size,
203                self.max_body_size_bytes,
204            ));
205        }
206        Ok(())
207    }
208}
209
210/// Token bucket rate limiter implementation
211#[derive(Debug)]
212pub struct RateLimiter {
213    capacity: u32,
214    tokens: Arc<RwLock<f64>>,
215    refill_rate: f64,
216    last_refill: Arc<RwLock<Instant>>,
217}
218
219impl RateLimiter {
220    /// Create a new rate limiter with specified capacity and refill rate
221    pub fn new(capacity: u32, refill_per_second: u32) -> Self {
222        trace!(
223            capacity = capacity,
224            refill_per_second = refill_per_second,
225            "Creating rate limiter"
226        );
227        Self {
228            capacity,
229            tokens: Arc::new(RwLock::new(capacity as f64)),
230            refill_rate: refill_per_second as f64,
231            last_refill: Arc::new(RwLock::new(Instant::now())),
232        }
233    }
234
235    /// Try to acquire tokens, returns true if successful
236    pub fn try_acquire(&self, tokens: u32) -> bool {
237        self.refill();
238
239        let mut available_tokens = self.tokens.write();
240        if *available_tokens >= tokens as f64 {
241            *available_tokens -= tokens as f64;
242            trace!(
243                tokens_requested = tokens,
244                tokens_remaining = *available_tokens as u32,
245                "Rate limiter: tokens acquired"
246            );
247            true
248        } else {
249            trace!(
250                tokens_requested = tokens,
251                tokens_available = *available_tokens as u32,
252                "Rate limiter: insufficient tokens"
253            );
254            false
255        }
256    }
257
258    /// Check if tokens are available without consuming
259    pub fn check(&self, tokens: u32) -> bool {
260        self.refill();
261        let available_tokens = self.tokens.read();
262        *available_tokens >= tokens as f64
263    }
264
265    /// Get current available tokens
266    pub fn available(&self) -> u32 {
267        self.refill();
268        let tokens = self.tokens.read();
269        *tokens as u32
270    }
271
272    /// Refill tokens based on elapsed time
273    fn refill(&self) {
274        let now = Instant::now();
275        let mut last_refill = self.last_refill.write();
276        let elapsed = now.duration_since(*last_refill).as_secs_f64();
277
278        if elapsed > 0.0 {
279            let mut tokens = self.tokens.write();
280            let tokens_to_add = elapsed * self.refill_rate;
281            *tokens = (*tokens + tokens_to_add).min(self.capacity as f64);
282            *last_refill = now;
283        }
284    }
285
286    /// Reset the rate limiter to full capacity
287    pub fn reset(&self) {
288        let mut tokens = self.tokens.write();
289        *tokens = self.capacity as f64;
290        let mut last_refill = self.last_refill.write();
291        *last_refill = Instant::now();
292    }
293}
294
295/// Multi-level rate limiter for different scopes
296pub struct MultiRateLimiter {
297    global: Option<RateLimiter>,
298    per_client: Arc<RwLock<HashMap<String, RateLimiter>>>,
299    per_route: Arc<RwLock<HashMap<String, RateLimiter>>>,
300    client_limit: Option<(u32, u32)>, // (capacity, refill_per_second)
301    route_limit: Option<(u32, u32)>,  // (capacity, refill_per_second)
302}
303
304impl MultiRateLimiter {
305    /// Create a new multi-level rate limiter
306    pub fn new(limits: &Limits) -> Self {
307        let global = limits
308            .max_requests_per_second_global
309            .map(|rps| RateLimiter::new(rps * 10, rps)); // 10 second burst
310
311        let client_limit = limits
312            .max_requests_per_second_per_client
313            .map(|rps| (rps * 10, rps));
314
315        let route_limit = limits
316            .max_requests_per_second_per_route
317            .map(|rps| (rps * 10, rps));
318
319        Self {
320            global,
321            per_client: Arc::new(RwLock::new(HashMap::new())),
322            per_route: Arc::new(RwLock::new(HashMap::new())),
323            client_limit,
324            route_limit,
325        }
326    }
327
328    /// Check if request is allowed for client and route
329    pub fn check_request(&self, client_id: &str, route: &str) -> SentinelResult<()> {
330        trace!(
331            client_id = %client_id,
332            route = %route,
333            "Checking rate limits"
334        );
335
336        // Check global rate limit
337        if let Some(ref limiter) = self.global {
338            if !limiter.try_acquire(1) {
339                warn!(
340                    client_id = %client_id,
341                    route = %route,
342                    "Global rate limit exceeded"
343                );
344                return Err(SentinelError::RateLimit {
345                    message: "Global rate limit exceeded".to_string(),
346                    limit: limiter.capacity,
347                    window_seconds: 10,
348                    retry_after_seconds: Some(1),
349                });
350            }
351        }
352
353        // Check per-client rate limit
354        if let Some((capacity, refill)) = self.client_limit {
355            let mut limiters = self.per_client.write();
356            let limiter = limiters
357                .entry(client_id.to_string())
358                .or_insert_with(|| RateLimiter::new(capacity, refill));
359
360            if !limiter.try_acquire(1) {
361                warn!(
362                    client_id = %client_id,
363                    route = %route,
364                    "Per-client rate limit exceeded"
365                );
366                return Err(SentinelError::RateLimit {
367                    message: format!("Rate limit exceeded for client {}", client_id),
368                    limit: capacity,
369                    window_seconds: 10,
370                    retry_after_seconds: Some(1),
371                });
372            }
373        }
374
375        // Check per-route rate limit
376        if let Some((capacity, refill)) = self.route_limit {
377            let mut limiters = self.per_route.write();
378            let limiter = limiters
379                .entry(route.to_string())
380                .or_insert_with(|| RateLimiter::new(capacity, refill));
381
382            if !limiter.try_acquire(1) {
383                warn!(
384                    client_id = %client_id,
385                    route = %route,
386                    "Per-route rate limit exceeded"
387                );
388                return Err(SentinelError::RateLimit {
389                    message: format!("Rate limit exceeded for route {}", route),
390                    limit: capacity,
391                    window_seconds: 10,
392                    retry_after_seconds: Some(1),
393                });
394            }
395        }
396
397        trace!(
398            client_id = %client_id,
399            route = %route,
400            "Rate limits check passed"
401        );
402        Ok(())
403    }
404
405    /// Clean up old rate limiters that haven't been used recently
406    pub fn cleanup(&self, _max_age: Duration) {
407        // TODO: Implement cleanup of unused rate limiters
408        // This would track last access time and remove old entries
409    }
410}
411
412/// Connection limiter for managing concurrent connections
413pub struct ConnectionLimiter {
414    per_client: Arc<RwLock<HashMap<String, usize>>>,
415    per_route: Arc<RwLock<HashMap<String, usize>>>,
416    total: Arc<RwLock<usize>>,
417    limits: Limits,
418}
419
420impl ConnectionLimiter {
421    pub fn new(limits: Limits) -> Self {
422        debug!(
423            max_total = limits.max_total_connections,
424            max_per_client = limits.max_connections_per_client,
425            max_per_route = limits.max_connections_per_route,
426            "Creating connection limiter"
427        );
428        Self {
429            per_client: Arc::new(RwLock::new(HashMap::new())),
430            per_route: Arc::new(RwLock::new(HashMap::new())),
431            total: Arc::new(RwLock::new(0)),
432            limits,
433        }
434    }
435
436    /// Try to acquire a connection slot
437    pub fn try_acquire(&self, client_id: &str, route: &str) -> SentinelResult<ConnectionGuard<'_>> {
438        trace!(
439            client_id = %client_id,
440            route = %route,
441            "Attempting to acquire connection slot"
442        );
443
444        // Check total connections
445        {
446            let mut total = self.total.write();
447            if *total >= self.limits.max_total_connections {
448                warn!(
449                    current = *total,
450                    max = self.limits.max_total_connections,
451                    "Total connection limit exceeded"
452                );
453                return Err(SentinelError::limit_exceeded(
454                    LimitType::ConnectionCount,
455                    *total,
456                    self.limits.max_total_connections,
457                ));
458            }
459            *total += 1;
460        }
461
462        // Check per-client connections
463        {
464            let mut per_client = self.per_client.write();
465            let client_count = per_client.entry(client_id.to_string()).or_insert(0);
466            if *client_count >= self.limits.max_connections_per_client {
467                // Rollback total count
468                *self.total.write() -= 1;
469                warn!(
470                    client_id = %client_id,
471                    current = *client_count,
472                    max = self.limits.max_connections_per_client,
473                    "Per-client connection limit exceeded"
474                );
475                return Err(SentinelError::limit_exceeded(
476                    LimitType::ConnectionCount,
477                    *client_count,
478                    self.limits.max_connections_per_client,
479                ));
480            }
481            *client_count += 1;
482        }
483
484        // Check per-route connections
485        {
486            let mut per_route = self.per_route.write();
487            let route_count = per_route.entry(route.to_string()).or_insert(0);
488            if *route_count >= self.limits.max_connections_per_route {
489                // Rollback counts
490                *self.total.write() -= 1;
491                *self.per_client.write().get_mut(client_id).unwrap() -= 1;
492                warn!(
493                    route = %route,
494                    current = *route_count,
495                    max = self.limits.max_connections_per_route,
496                    "Per-route connection limit exceeded"
497                );
498                return Err(SentinelError::limit_exceeded(
499                    LimitType::ConnectionCount,
500                    *route_count,
501                    self.limits.max_connections_per_route,
502                ));
503            }
504            *route_count += 1;
505        }
506
507        trace!(
508            client_id = %client_id,
509            route = %route,
510            "Connection slot acquired"
511        );
512
513        Ok(ConnectionGuard {
514            limiter: self,
515            client_id: client_id.to_string(),
516            route: route.to_string(),
517        })
518    }
519
520    /// Release a connection slot
521    fn release(&self, client_id: &str, route: &str) {
522        trace!(
523            client_id = %client_id,
524            route = %route,
525            "Releasing connection slot"
526        );
527
528        *self.total.write() -= 1;
529
530        if let Some(count) = self.per_client.write().get_mut(client_id) {
531            *count = count.saturating_sub(1);
532        }
533
534        if let Some(count) = self.per_route.write().get_mut(route) {
535            *count = count.saturating_sub(1);
536        }
537    }
538
539    /// Get current connection statistics
540    pub fn stats(&self) -> ConnectionStats {
541        ConnectionStats {
542            total: *self.total.read(),
543            per_client_count: self.per_client.read().len(),
544            per_route_count: self.per_route.read().len(),
545        }
546    }
547}
548
549/// RAII guard for connection slots
550pub struct ConnectionGuard<'a> {
551    limiter: &'a ConnectionLimiter,
552    client_id: String,
553    route: String,
554}
555
556impl<'a> Drop for ConnectionGuard<'a> {
557    fn drop(&mut self) {
558        self.limiter.release(&self.client_id, &self.route);
559    }
560}
561
562/// Connection statistics
563#[derive(Debug, Clone, Serialize)]
564pub struct ConnectionStats {
565    pub total: usize,
566    pub per_client_count: usize,
567    pub per_route_count: usize,
568}
569
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    /// Defaults validate; a zero header size and an oversized buffer do not.
    #[test]
    fn test_limits_validation() {
        let mut limits = Limits::default();
        assert!(limits.validate().is_ok());

        limits.max_header_size_bytes = 0;
        assert!(limits.validate().is_err());

        limits = Limits::default();
        limits.max_body_buffer_bytes = limits.max_body_size_bytes + 1;
        assert!(limits.validate().is_err());
    }

    /// A full bucket allows a burst of `capacity`, then refills over time.
    #[test]
    fn test_rate_limiter() {
        let limiter = RateLimiter::new(10, 10);

        // Should allow initial burst
        for _ in 0..10 {
            assert!(limiter.try_acquire(1));
        }

        // Should be exhausted
        assert!(!limiter.try_acquire(1));

        // Wait for refill
        thread::sleep(Duration::from_millis(200));

        // Should have some tokens refilled (approximately 2)
        assert!(limiter.try_acquire(1));
        assert!(limiter.available() > 0);
    }

    /// Acquired slots are counted, and dropping the guards gives them back.
    #[test]
    fn test_connection_limiter() {
        let limits = Limits {
            max_total_connections: 100,
            max_connections_per_client: 10,
            max_connections_per_route: 50,
            ..Default::default()
        };

        let limiter = ConnectionLimiter::new(limits);

        // Acquire connections
        let guard1 = limiter.try_acquire("client1", "route1").unwrap();
        let guard2 = limiter.try_acquire("client1", "route1").unwrap();

        let stats = limiter.stats();
        assert_eq!(stats.total, 2);

        // Guards release on drop
        drop(guard1);
        assert_eq!(limiter.stats().total, 1);
        drop(guard2);
        assert_eq!(limiter.stats().total, 0);
    }

    /// A full per-client limit rejects without leaking the reserved total
    /// slot, and the client can connect again once its guard is dropped.
    #[test]
    fn test_connection_limiter_rejects_and_recovers() {
        let limits = Limits {
            max_total_connections: 100,
            max_connections_per_client: 1,
            max_connections_per_route: 50,
            ..Default::default()
        };

        let limiter = ConnectionLimiter::new(limits);

        let guard = limiter.try_acquire("client1", "route1").unwrap();
        assert!(limiter.try_acquire("client1", "route1").is_err());

        // The rejected attempt must have rolled back its total reservation.
        assert_eq!(limiter.stats().total, 1);

        drop(guard);
        assert_eq!(limiter.stats().total, 0);

        let _guard = limiter.try_acquire("client1", "route1").unwrap();
        assert_eq!(limiter.stats().total, 1);
    }
}