sentinel_proxy/rate_limit.rs

//! Rate limiting using pingora-limits
//!
//! This module provides efficient per-route, per-client rate limiting using
//! Pingora's optimized rate limiting primitives. Supports both local (single-instance)
//! and distributed (Redis-backed) rate limiting.
//!
//! # Local Rate Limiting
//!
//! Uses `pingora_limits::rate::Rate` for efficient in-memory rate limiting.
//! Suitable for single-instance deployments.
//!
//! # Distributed Rate Limiting
//!
//! Uses Redis sorted sets for sliding window rate limiting across multiple instances.
//! Requires the `distributed-rate-limit` feature.
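//!
//! # Example
//!
//! A minimal usage sketch. Marked `ignore` because the crate path and the
//! `NoopHeaders` helper are assumptions for illustration; the types themselves
//! are defined in this module.
//!
//! ```ignore
//! use sentinel_proxy::rate_limit::{HeaderAccessor, RateLimitConfig, RateLimitManager};
//!
//! // A no-op header accessor for requests that are not keyed by a header.
//! struct NoopHeaders;
//! impl HeaderAccessor for NoopHeaders {
//!     fn get_header(&self, _name: &str) -> Option<String> {
//!         None
//!     }
//! }
//!
//! // Register a per-route limit of 100 requests/second keyed by client IP.
//! let manager = RateLimitManager::new();
//! manager.register_route(
//!     "api",
//!     RateLimitConfig {
//!         max_rps: 100,
//!         ..Default::default()
//!     },
//! );
//!
//! // Check an incoming request.
//! let result = manager.check("api", "203.0.113.7", "/api/users", Option::<&NoopHeaders>::None);
//! if !result.allowed {
//!     // Reject with `result.status_code` (429 by default).
//! }
//! ```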

use dashmap::DashMap;
use parking_lot::RwLock;
use pingora_limits::rate::Rate;
use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, trace, warn};

use sentinel_config::{RateLimitAction, RateLimitBackend, RateLimitKey};

#[cfg(feature = "distributed-rate-limit")]
use crate::distributed_rate_limit::RedisRateLimiter;

/// Rate limiter outcome
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitOutcome {
    /// Request is allowed
    Allowed,
    /// Request is rate limited
    Limited,
}

/// Rate limiter configuration
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum requests per second
    pub max_rps: u32,
    /// Burst size (not applied by the local in-memory limiter)
    pub burst: u32,
    /// Key type for bucketing
    pub key: RateLimitKey,
    /// Action when limited
    pub action: RateLimitAction,
    /// HTTP status code to return when limited
    pub status_code: u16,
    /// Custom message
    pub message: Option<String>,
    /// Backend for rate limiting (local or distributed)
    pub backend: RateLimitBackend,
}

impl Default for RateLimitConfig {
    fn default() -> Self {
        Self {
            max_rps: 100,
            burst: 10,
            key: RateLimitKey::ClientIp,
            action: RateLimitAction::Reject,
            status_code: 429,
            message: None,
            backend: RateLimitBackend::Local,
        }
    }
}

/// Per-key rate limiter using pingora-limits Rate
///
/// Uses a sliding window algorithm with 1-second granularity.
struct KeyRateLimiter {
    /// The rate limiter instance (tracks requests in current window)
    rate: Rate,
    /// Maximum requests per window
    max_requests: isize,
}

impl KeyRateLimiter {
    fn new(max_rps: u32) -> Self {
        Self {
            rate: Rate::new(Duration::from_secs(1)),
            max_requests: max_rps as isize,
        }
    }

    /// Check if a request should be allowed
    fn check(&self) -> RateLimitOutcome {
        // Rate::observe() records the event and returns the number of events
        // seen so far in the current window, including this one
        let curr_count = self.rate.observe(&(), 1);

        if curr_count > self.max_requests {
            RateLimitOutcome::Limited
        } else {
            RateLimitOutcome::Allowed
        }
    }
}

/// Backend type for rate limiting
pub enum RateLimitBackendType {
    /// Local in-memory backend
    Local {
        /// Rate limiters by key (e.g., client IP -> limiter)
        limiters: DashMap<String, Arc<KeyRateLimiter>>,
    },
    /// Distributed Redis backend
    #[cfg(feature = "distributed-rate-limit")]
    Distributed {
        /// Redis rate limiter
        redis: Arc<RedisRateLimiter>,
        /// Local fallback
        local_fallback: DashMap<String, Arc<KeyRateLimiter>>,
    },
}

/// Thread-safe rate limiter pool managing multiple rate limiters by key
pub struct RateLimiterPool {
    /// Backend for rate limiting
    backend: RateLimitBackendType,
    /// Configuration
    config: RwLock<RateLimitConfig>,
}

impl RateLimiterPool {
    /// Create a new rate limiter pool with the given configuration (local backend)
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            backend: RateLimitBackendType::Local {
                limiters: DashMap::new(),
            },
            config: RwLock::new(config),
        }
    }

    /// Create a new rate limiter pool with a distributed Redis backend
    #[cfg(feature = "distributed-rate-limit")]
    pub fn with_redis(config: RateLimitConfig, redis: Arc<RedisRateLimiter>) -> Self {
        Self {
            backend: RateLimitBackendType::Distributed {
                redis,
                local_fallback: DashMap::new(),
            },
            config: RwLock::new(config),
        }
    }

    /// Check if a request should be rate limited (synchronous, local only)
    ///
    /// Returns the outcome and the current request count.
    /// For distributed backends, this falls back to local limiting.
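    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`; the surrounding crate path is assumed):
    ///
    /// ```ignore
    /// let pool = RateLimiterPool::new(RateLimitConfig {
    ///     max_rps: 50,
    ///     ..Default::default()
    /// });
    ///
    /// let (outcome, count) = pool.check("203.0.113.7");
    /// if outcome == RateLimitOutcome::Limited {
    ///     // Reject: this client exceeded 50 requests in the current window.
    /// }
    /// ```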
    pub fn check(&self, key: &str) -> (RateLimitOutcome, isize) {
        let config = self.config.read();
        let max_rps = config.max_rps;
        drop(config);

        let limiters = match &self.backend {
            RateLimitBackendType::Local { limiters } => limiters,
            #[cfg(feature = "distributed-rate-limit")]
            RateLimitBackendType::Distributed { local_fallback, .. } => local_fallback,
        };

        // Get or create limiter for this key
        let limiter = limiters
            .entry(key.to_string())
            .or_insert_with(|| Arc::new(KeyRateLimiter::new(max_rps)))
            .clone();

        let outcome = limiter.check();
        let count = limiter.rate.observe(&(), 0); // Get current count without incrementing

        (outcome, count)
    }

    /// Check if a request should be rate limited (async, supports distributed backends)
    ///
    /// Returns the outcome and the current request count.
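    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`; it assumes a pool built with
    /// `with_redis` and an async runtime such as Tokio):
    ///
    /// ```ignore
    /// let (outcome, count) = pool.check_async("203.0.113.7").await;
    /// if outcome == RateLimitOutcome::Limited {
    ///     // Reject; `count` reflects the distributed (or fallback local) window.
    /// }
    /// ```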
    #[cfg(feature = "distributed-rate-limit")]
    pub async fn check_async(&self, key: &str) -> (RateLimitOutcome, i64) {
        match &self.backend {
            RateLimitBackendType::Local { .. } => {
                let (outcome, count) = self.check(key);
                (outcome, count as i64)
            }
            RateLimitBackendType::Distributed {
                redis,
                local_fallback,
            } => {
                // Try Redis first
                match redis.check(key).await {
                    Ok((outcome, count)) => (outcome, count),
                    Err(e) => {
                        warn!(
                            error = %e,
                            key = key,
                            "Redis rate limit check failed, falling back to local"
                        );
                        redis.mark_unhealthy();

                        // Fallback to local
                        if redis.fallback_enabled() {
                            let config = self.config.read();
                            let max_rps = config.max_rps;
                            drop(config);

                            let limiter = local_fallback
                                .entry(key.to_string())
                                .or_insert_with(|| Arc::new(KeyRateLimiter::new(max_rps)))
                                .clone();

                            let outcome = limiter.check();
                            let count = limiter.rate.observe(&(), 0);
                            (outcome, count as i64)
                        } else {
                            // Fail open if no fallback
                            (RateLimitOutcome::Allowed, 0)
                        }
                    }
                }
            }
        }
    }

    /// Check if this pool uses a distributed backend
    pub fn is_distributed(&self) -> bool {
        match &self.backend {
            RateLimitBackendType::Local { .. } => false,
            #[cfg(feature = "distributed-rate-limit")]
            RateLimitBackendType::Distributed { .. } => true,
        }
    }

    /// Get the rate limit key from request context
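    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore`; it assumes the `RateLimitKey::Header`
    /// payload can be built from a string and that `req` is some type
    /// implementing `HeaderAccessor`):
    ///
    /// ```ignore
    /// // With `key: RateLimitKey::Header("x-api-key".into())`, requests are
    /// // bucketed by API key, falling back to "unknown" when the header is absent.
    /// let key = pool.extract_key("203.0.113.7", "/api/users", "api", Some(&req));
    /// ```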
    pub fn extract_key(
        &self,
        client_ip: &str,
        path: &str,
        route_id: &str,
        headers: Option<&impl HeaderAccessor>,
    ) -> String {
        let config = self.config.read();
        match &config.key {
            RateLimitKey::ClientIp => client_ip.to_string(),
            RateLimitKey::Path => path.to_string(),
            RateLimitKey::Route => route_id.to_string(),
            RateLimitKey::ClientIpAndPath => format!("{}:{}", client_ip, path),
            RateLimitKey::Header(header_name) => headers
                .and_then(|h| h.get_header(header_name))
                .unwrap_or_else(|| "unknown".to_string()),
        }
    }

    /// Get the action to take when rate limited
    pub fn action(&self) -> RateLimitAction {
        self.config.read().action.clone()
    }

    /// Get the HTTP status code for rate limit responses
    pub fn status_code(&self) -> u16 {
        self.config.read().status_code
    }

    /// Get the custom message for rate limit responses
    pub fn message(&self) -> Option<String> {
        self.config.read().message.clone()
    }

    /// Update the configuration
    pub fn update_config(&self, config: RateLimitConfig) {
        *self.config.write() = config;
        // Clear existing limiters so they get recreated with new config
        self.clear_local_limiters();
    }

    /// Clear local limiters (for config updates)
    fn clear_local_limiters(&self) {
        match &self.backend {
            RateLimitBackendType::Local { limiters } => limiters.clear(),
            #[cfg(feature = "distributed-rate-limit")]
            RateLimitBackendType::Distributed { local_fallback, .. } => local_fallback.clear(),
        }
    }

    /// Get the number of local limiter entries
    fn local_limiter_count(&self) -> usize {
        match &self.backend {
            RateLimitBackendType::Local { limiters } => limiters.len(),
            #[cfg(feature = "distributed-rate-limit")]
            RateLimitBackendType::Distributed { local_fallback, .. } => local_fallback.len(),
        }
    }

    /// Evict rate limiter entries when the pool grows too large (call periodically)
    pub fn cleanup(&self) {
        // `Rate` handles its own window cleanup, so this is purely memory
        // management: it caps the number of tracked keys when many unique
        // keys have been seen
        let max_entries = 100_000; // Prevent unbounded growth

        let limiters = match &self.backend {
            RateLimitBackendType::Local { limiters } => limiters,
            #[cfg(feature = "distributed-rate-limit")]
            RateLimitBackendType::Distributed { local_fallback, .. } => local_fallback,
        };

        if limiters.len() > max_entries {
            let entries_before = limiters.len();

            // Simple eviction: remove roughly half of the entries
            let to_remove: Vec<_> = limiters
                .iter()
                .take(max_entries / 2)
                .map(|e| e.key().clone())
                .collect();

            for key in to_remove {
                limiters.remove(&key);
            }

            debug!(
                entries_before = entries_before,
                entries_after = limiters.len(),
                "Rate limiter pool cleanup completed"
            );
        }
    }
}

/// Trait for accessing headers (allows abstracting over different header types)
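///
/// # Example
///
/// A minimal sketch of an implementation over the `http` crate's `HeaderMap`
/// (marked `ignore`; whether callers actually use that type is an assumption
/// for illustration):
///
/// ```ignore
/// impl HeaderAccessor for http::HeaderMap {
///     fn get_header(&self, name: &str) -> Option<String> {
///         self.get(name)
///             .and_then(|v| v.to_str().ok())
///             .map(str::to_owned)
///     }
/// }
/// ```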
pub trait HeaderAccessor {
    fn get_header(&self, name: &str) -> Option<String>;
}

/// Route-level rate limiter manager
pub struct RateLimitManager {
    /// Per-route rate limiter pools
    route_limiters: DashMap<String, Arc<RateLimiterPool>>,
    /// Global rate limiter (optional)
    global_limiter: Option<Arc<RateLimiterPool>>,
}

impl RateLimitManager {
    /// Create a new rate limit manager
    pub fn new() -> Self {
        Self {
            route_limiters: DashMap::new(),
            global_limiter: None,
        }
    }

    /// Create a new rate limit manager with a global rate limit
    pub fn with_global_limit(max_rps: u32, burst: u32) -> Self {
        let config = RateLimitConfig {
            max_rps,
            burst,
            key: RateLimitKey::ClientIp,
            action: RateLimitAction::Reject,
            status_code: 429,
            message: None,
            backend: RateLimitBackend::Local,
        };
        Self {
            route_limiters: DashMap::new(),
            global_limiter: Some(Arc::new(RateLimiterPool::new(config))),
        }
    }

    /// Register a rate limiter for a route
    pub fn register_route(&self, route_id: &str, config: RateLimitConfig) {
        trace!(
            route_id = route_id,
            max_rps = config.max_rps,
            burst = config.burst,
            key = ?config.key,
            "Registering rate limiter for route"
        );

        self.route_limiters
            .insert(route_id.to_string(), Arc::new(RateLimiterPool::new(config)));
    }

    /// Check if a request should be rate limited
    ///
    /// Checks both global and route-specific limits.
    pub fn check(
        &self,
        route_id: &str,
        client_ip: &str,
        path: &str,
        headers: Option<&impl HeaderAccessor>,
    ) -> RateLimitResult {
        // Check global limit first
        if let Some(ref global) = self.global_limiter {
            let key = global.extract_key(client_ip, path, route_id, headers);
            let (outcome, count) = global.check(&key);

            if outcome == RateLimitOutcome::Limited {
                warn!(
                    route_id = route_id,
                    client_ip = client_ip,
                    key = key,
                    count = count,
                    "Request rate limited by global limiter"
                );
                return RateLimitResult {
                    allowed: false,
                    action: global.action(),
                    status_code: global.status_code(),
                    message: global.message(),
                    limiter: "global".to_string(),
                };
            }
        }

        // Check route-specific limit
        if let Some(pool) = self.route_limiters.get(route_id) {
            let key = pool.extract_key(client_ip, path, route_id, headers);
            let (outcome, count) = pool.check(&key);

            if outcome == RateLimitOutcome::Limited {
                warn!(
                    route_id = route_id,
                    client_ip = client_ip,
                    key = key,
                    count = count,
                    "Request rate limited by route limiter"
                );
                return RateLimitResult {
                    allowed: false,
                    action: pool.action(),
                    status_code: pool.status_code(),
                    message: pool.message(),
                    limiter: route_id.to_string(),
                };
            }

            trace!(
                route_id = route_id,
                key = key,
                count = count,
                "Request allowed by rate limiter"
            );
        }

        RateLimitResult {
            allowed: true,
            action: RateLimitAction::Reject,
            status_code: 429,
            message: None,
            limiter: String::new(),
        }
    }

    /// Perform periodic cleanup
    pub fn cleanup(&self) {
        if let Some(ref global) = self.global_limiter {
            global.cleanup();
        }
        for entry in self.route_limiters.iter() {
            entry.value().cleanup();
        }
    }

    /// Get the number of registered route limiters
    pub fn route_count(&self) -> usize {
        self.route_limiters.len()
    }

    /// Check if any rate limiting is configured (fast path)
    ///
    /// Returns true if there's a global limiter or any route-specific limiters.
    /// Use this to skip rate limit checks entirely when no limiting is configured.
    #[inline]
    pub fn is_enabled(&self) -> bool {
        self.global_limiter.is_some() || !self.route_limiters.is_empty()
    }

    /// Check if rate limiting applies to a specific route (fast path)
    ///
    /// Returns true if a global limiter exists or the route has its own limiter.
    #[inline]
    pub fn has_route_limiter(&self, route_id: &str) -> bool {
        self.global_limiter.is_some() || self.route_limiters.contains_key(route_id)
    }
}

impl Default for RateLimitManager {
    fn default() -> Self {
        Self::new()
    }
}

/// Result of a rate limit check
#[derive(Debug, Clone)]
pub struct RateLimitResult {
    /// Whether the request is allowed
    pub allowed: bool,
    /// Action to take if limited
    pub action: RateLimitAction,
    /// HTTP status code for rejection
    pub status_code: u16,
    /// Custom message
    pub message: Option<String>,
    /// Which limiter triggered (for logging)
    pub limiter: String,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limiter_allows_under_limit() {
        let config = RateLimitConfig {
            max_rps: 10,
            burst: 5,
            key: RateLimitKey::ClientIp,
            ..Default::default()
        };
        let pool = RateLimiterPool::new(config);

        // Should allow first 10 requests
        for _ in 0..10 {
            let (outcome, _) = pool.check("127.0.0.1");
            assert_eq!(outcome, RateLimitOutcome::Allowed);
        }
    }

    #[test]
    fn test_rate_limiter_blocks_over_limit() {
        let config = RateLimitConfig {
            max_rps: 5,
            burst: 2,
            key: RateLimitKey::ClientIp,
            ..Default::default()
        };
        let pool = RateLimiterPool::new(config);

        // Should allow first 5 requests
        for _ in 0..5 {
            let (outcome, _) = pool.check("127.0.0.1");
            assert_eq!(outcome, RateLimitOutcome::Allowed);
        }

        // 6th request should be limited
        let (outcome, _) = pool.check("127.0.0.1");
        assert_eq!(outcome, RateLimitOutcome::Limited);
    }
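
    // Additional coverage sketch: `update_config` clears existing per-key limiters,
    // so a stricter limit can be relaxed at runtime. Like the tests above, this
    // assumes all checks land within the same one-second window.
    #[test]
    fn test_update_config_resets_limiters() {
        let pool = RateLimiterPool::new(RateLimitConfig {
            max_rps: 1,
            ..Default::default()
        });

        let (outcome, _) = pool.check("127.0.0.1");
        assert_eq!(outcome, RateLimitOutcome::Allowed);
        let (outcome, _) = pool.check("127.0.0.1");
        assert_eq!(outcome, RateLimitOutcome::Limited);

        // Raising the limit recreates the per-key limiters.
        pool.update_config(RateLimitConfig {
            max_rps: 5,
            ..Default::default()
        });
        let (outcome, _) = pool.check("127.0.0.1");
        assert_eq!(outcome, RateLimitOutcome::Allowed);
    }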

    #[test]
    fn test_rate_limiter_separate_keys() {
        let config = RateLimitConfig {
            max_rps: 2,
            burst: 1,
            key: RateLimitKey::ClientIp,
            ..Default::default()
        };
        let pool = RateLimiterPool::new(config);

        // Each IP gets its own bucket
        let (outcome1, _) = pool.check("192.168.1.1");
        let (outcome2, _) = pool.check("192.168.1.2");
        let (outcome3, _) = pool.check("192.168.1.1");
        let (outcome4, _) = pool.check("192.168.1.2");

        assert_eq!(outcome1, RateLimitOutcome::Allowed);
        assert_eq!(outcome2, RateLimitOutcome::Allowed);
        assert_eq!(outcome3, RateLimitOutcome::Allowed);
        assert_eq!(outcome4, RateLimitOutcome::Allowed);

        // Both should hit limit now
        let (outcome5, _) = pool.check("192.168.1.1");
        let (outcome6, _) = pool.check("192.168.1.2");

        assert_eq!(outcome5, RateLimitOutcome::Limited);
        assert_eq!(outcome6, RateLimitOutcome::Limited);
    }
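
    // Additional coverage sketch for `extract_key` with the unit-variant key types.
    // The `Header` variant is not constructed here because its payload type is
    // defined in `sentinel_config` and constructing it would be an assumption.
    #[test]
    fn test_extract_key_variants() {
        let pool = RateLimiterPool::new(RateLimitConfig {
            key: RateLimitKey::ClientIpAndPath,
            ..Default::default()
        });
        let key = pool.extract_key("10.0.0.1", "/api/users", "api", Option::<&NoHeaders>::None);
        assert_eq!(key, "10.0.0.1:/api/users");

        let pool = RateLimiterPool::new(RateLimitConfig {
            key: RateLimitKey::Route,
            ..Default::default()
        });
        let key = pool.extract_key("10.0.0.1", "/api/users", "api", Option::<&NoHeaders>::None);
        assert_eq!(key, "api");
    }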

    #[test]
    fn test_rate_limit_manager() {
        let manager = RateLimitManager::new();

        manager.register_route(
            "api",
            RateLimitConfig {
                max_rps: 5,
                burst: 2,
                key: RateLimitKey::ClientIp,
                ..Default::default()
            },
        );

        // Route without limiter should always pass
        let result = manager.check("web", "127.0.0.1", "/", Option::<&NoHeaders>::None);
        assert!(result.allowed);

        // Route with limiter should enforce limits
        for _ in 0..5 {
            let result = manager.check("api", "127.0.0.1", "/api/test", Option::<&NoHeaders>::None);
            assert!(result.allowed);
        }

        let result = manager.check("api", "127.0.0.1", "/api/test", Option::<&NoHeaders>::None);
        assert!(!result.allowed);
        assert_eq!(result.status_code, 429);
    }
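
    // Additional coverage sketch for the global limiter path. As with the other
    // tests, this assumes all checks fall within the same one-second window.
    #[test]
    fn test_global_rate_limit() {
        let manager = RateLimitManager::with_global_limit(2, 1);

        // The global limiter is keyed by client IP, so it applies across routes.
        let result = manager.check("a", "127.0.0.1", "/", Option::<&NoHeaders>::None);
        assert!(result.allowed);
        let result = manager.check("b", "127.0.0.1", "/", Option::<&NoHeaders>::None);
        assert!(result.allowed);

        let result = manager.check("c", "127.0.0.1", "/", Option::<&NoHeaders>::None);
        assert!(!result.allowed);
        assert_eq!(result.limiter, "global");
    }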

    // Helper type for tests that don't need header access
    struct NoHeaders;
    impl HeaderAccessor for NoHeaders {
        fn get_header(&self, _name: &str) -> Option<String> {
            None
        }
    }
}