1use dashmap::DashMap;
18use parking_lot::RwLock;
19use pingora_limits::rate::Rate;
20use std::sync::Arc;
21use std::time::Duration;
22use tracing::{debug, trace, warn};
23
24use sentinel_config::{RateLimitAction, RateLimitBackend, RateLimitKey};
25
26#[cfg(feature = "distributed-rate-limit")]
27use crate::distributed_rate_limit::{create_redis_rate_limiter, RedisRateLimiter};
28
/// Verdict of a single rate-limit check.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RateLimitOutcome {
    /// The request is within the configured limit.
    Allowed,
    /// The request exceeded the configured limit.
    Limited,
}
37
38#[derive(Debug, Clone)]
40pub struct RateLimitConfig {
41 pub max_rps: u32,
43 pub burst: u32,
45 pub key: RateLimitKey,
47 pub action: RateLimitAction,
49 pub status_code: u16,
51 pub message: Option<String>,
53 pub backend: RateLimitBackend,
55}
56
57impl Default for RateLimitConfig {
58 fn default() -> Self {
59 Self {
60 max_rps: 100,
61 burst: 10,
62 key: RateLimitKey::ClientIp,
63 action: RateLimitAction::Reject,
64 status_code: 429,
65 message: None,
66 backend: RateLimitBackend::Local,
67 }
68 }
69}
70
71struct KeyRateLimiter {
75 rate: Rate,
77 max_requests: isize,
79}
80
81impl KeyRateLimiter {
82 fn new(max_rps: u32) -> Self {
83 Self {
84 rate: Rate::new(Duration::from_secs(1)),
85 max_requests: max_rps as isize,
86 }
87 }
88
89 fn check(&self) -> RateLimitOutcome {
91 let curr_count = self.rate.observe(&(), 1);
93
94 if curr_count > self.max_requests {
95 RateLimitOutcome::Limited
96 } else {
97 RateLimitOutcome::Allowed
98 }
99 }
100}
101
102pub enum RateLimitBackendType {
104 Local {
106 limiters: DashMap<String, Arc<KeyRateLimiter>>,
108 },
109 #[cfg(feature = "distributed-rate-limit")]
111 Distributed {
112 redis: Arc<RedisRateLimiter>,
114 local_fallback: DashMap<String, Arc<KeyRateLimiter>>,
116 },
117}
118
119pub struct RateLimiterPool {
121 backend: RateLimitBackendType,
123 config: RwLock<RateLimitConfig>,
125}
126
127impl RateLimiterPool {
128 pub fn new(config: RateLimitConfig) -> Self {
130 Self {
131 backend: RateLimitBackendType::Local {
132 limiters: DashMap::new(),
133 },
134 config: RwLock::new(config),
135 }
136 }
137
138 #[cfg(feature = "distributed-rate-limit")]
140 pub fn with_redis(config: RateLimitConfig, redis: Arc<RedisRateLimiter>) -> Self {
141 Self {
142 backend: RateLimitBackendType::Distributed {
143 redis,
144 local_fallback: DashMap::new(),
145 },
146 config: RwLock::new(config),
147 }
148 }
149
150 pub fn check(&self, key: &str) -> (RateLimitOutcome, isize) {
155 let config = self.config.read();
156 let max_rps = config.max_rps;
157 drop(config);
158
159 let limiters = match &self.backend {
160 RateLimitBackendType::Local { limiters } => limiters,
161 #[cfg(feature = "distributed-rate-limit")]
162 RateLimitBackendType::Distributed { local_fallback, .. } => local_fallback,
163 };
164
165 let limiter = limiters
167 .entry(key.to_string())
168 .or_insert_with(|| Arc::new(KeyRateLimiter::new(max_rps)))
169 .clone();
170
171 let outcome = limiter.check();
172 let count = limiter.rate.observe(&(), 0); (outcome, count)
175 }
176
177 #[cfg(feature = "distributed-rate-limit")]
181 pub async fn check_async(&self, key: &str) -> (RateLimitOutcome, i64) {
182 match &self.backend {
183 RateLimitBackendType::Local { .. } => {
184 let (outcome, count) = self.check(key);
185 (outcome, count as i64)
186 }
187 RateLimitBackendType::Distributed {
188 redis,
189 local_fallback,
190 } => {
191 match redis.check(key).await {
193 Ok((outcome, count)) => (outcome, count),
194 Err(e) => {
195 warn!(
196 error = %e,
197 key = key,
198 "Redis rate limit check failed, falling back to local"
199 );
200 redis.mark_unhealthy();
201
202 if redis.fallback_enabled() {
204 let config = self.config.read();
205 let max_rps = config.max_rps;
206 drop(config);
207
208 let limiter = local_fallback
209 .entry(key.to_string())
210 .or_insert_with(|| Arc::new(KeyRateLimiter::new(max_rps)))
211 .clone();
212
213 let outcome = limiter.check();
214 let count = limiter.rate.observe(&(), 0);
215 (outcome, count as i64)
216 } else {
217 (RateLimitOutcome::Allowed, 0)
219 }
220 }
221 }
222 }
223 }
224 }
225
226 pub fn is_distributed(&self) -> bool {
228 match &self.backend {
229 RateLimitBackendType::Local { .. } => false,
230 #[cfg(feature = "distributed-rate-limit")]
231 RateLimitBackendType::Distributed { .. } => true,
232 }
233 }
234
235 pub fn extract_key(
237 &self,
238 client_ip: &str,
239 path: &str,
240 route_id: &str,
241 headers: Option<&impl HeaderAccessor>,
242 ) -> String {
243 let config = self.config.read();
244 match &config.key {
245 RateLimitKey::ClientIp => client_ip.to_string(),
246 RateLimitKey::Path => path.to_string(),
247 RateLimitKey::Route => route_id.to_string(),
248 RateLimitKey::ClientIpAndPath => format!("{}:{}", client_ip, path),
249 RateLimitKey::Header(header_name) => headers
250 .and_then(|h| h.get_header(header_name))
251 .unwrap_or_else(|| "unknown".to_string()),
252 }
253 }
254
255 pub fn action(&self) -> RateLimitAction {
257 self.config.read().action.clone()
258 }
259
260 pub fn status_code(&self) -> u16 {
262 self.config.read().status_code
263 }
264
265 pub fn message(&self) -> Option<String> {
267 self.config.read().message.clone()
268 }
269
270 pub fn update_config(&self, config: RateLimitConfig) {
272 *self.config.write() = config;
273 self.clear_local_limiters();
275 }
276
277 fn clear_local_limiters(&self) {
279 match &self.backend {
280 RateLimitBackendType::Local { limiters } => limiters.clear(),
281 #[cfg(feature = "distributed-rate-limit")]
282 RateLimitBackendType::Distributed { local_fallback, .. } => local_fallback.clear(),
283 }
284 }
285
286 fn local_limiter_count(&self) -> usize {
288 match &self.backend {
289 RateLimitBackendType::Local { limiters } => limiters.len(),
290 #[cfg(feature = "distributed-rate-limit")]
291 RateLimitBackendType::Distributed { local_fallback, .. } => local_fallback.len(),
292 }
293 }
294
295 pub fn cleanup(&self) {
297 let max_entries = 100_000; let limiters = match &self.backend {
303 RateLimitBackendType::Local { limiters } => limiters,
304 #[cfg(feature = "distributed-rate-limit")]
305 RateLimitBackendType::Distributed { local_fallback, .. } => local_fallback,
306 };
307
308 if limiters.len() > max_entries {
309 let to_remove: Vec<_> = limiters
311 .iter()
312 .take(max_entries / 2)
313 .map(|e| e.key().clone())
314 .collect();
315
316 for key in to_remove {
317 limiters.remove(&key);
318 }
319
320 debug!(
321 entries_before = max_entries,
322 entries_after = limiters.len(),
323 "Rate limiter pool cleanup completed"
324 );
325 }
326 }
327}
328
/// Minimal read-only view of request headers, used to build header-based
/// rate-limit keys.
pub trait HeaderAccessor {
    /// Returns the value of header `name`, or `None` when it is not present.
    fn get_header(&self, name: &str) -> Option<String>;
}
333
334pub struct RateLimitManager {
336 route_limiters: DashMap<String, Arc<RateLimiterPool>>,
338 global_limiter: Option<Arc<RateLimiterPool>>,
340}
341
342impl RateLimitManager {
343 pub fn new() -> Self {
345 Self {
346 route_limiters: DashMap::new(),
347 global_limiter: None,
348 }
349 }
350
351 pub fn with_global_limit(max_rps: u32, burst: u32) -> Self {
353 let config = RateLimitConfig {
354 max_rps,
355 burst,
356 key: RateLimitKey::ClientIp,
357 action: RateLimitAction::Reject,
358 status_code: 429,
359 message: None,
360 backend: RateLimitBackend::Local,
361 };
362 Self {
363 route_limiters: DashMap::new(),
364 global_limiter: Some(Arc::new(RateLimiterPool::new(config))),
365 }
366 }
367
368 pub fn register_route(&self, route_id: &str, config: RateLimitConfig) {
370 trace!(
371 route_id = route_id,
372 max_rps = config.max_rps,
373 burst = config.burst,
374 key = ?config.key,
375 "Registering rate limiter for route"
376 );
377
378 self.route_limiters
379 .insert(route_id.to_string(), Arc::new(RateLimiterPool::new(config)));
380 }
381
382 pub fn check(
386 &self,
387 route_id: &str,
388 client_ip: &str,
389 path: &str,
390 headers: Option<&impl HeaderAccessor>,
391 ) -> RateLimitResult {
392 if let Some(ref global) = self.global_limiter {
394 let key = global.extract_key(client_ip, path, route_id, headers);
395 let (outcome, count) = global.check(&key);
396
397 if outcome == RateLimitOutcome::Limited {
398 warn!(
399 route_id = route_id,
400 client_ip = client_ip,
401 key = key,
402 count = count,
403 "Request rate limited by global limiter"
404 );
405 return RateLimitResult {
406 allowed: false,
407 action: global.action(),
408 status_code: global.status_code(),
409 message: global.message(),
410 limiter: "global".to_string(),
411 };
412 }
413 }
414
415 if let Some(pool) = self.route_limiters.get(route_id) {
417 let key = pool.extract_key(client_ip, path, route_id, headers);
418 let (outcome, count) = pool.check(&key);
419
420 if outcome == RateLimitOutcome::Limited {
421 warn!(
422 route_id = route_id,
423 client_ip = client_ip,
424 key = key,
425 count = count,
426 "Request rate limited by route limiter"
427 );
428 return RateLimitResult {
429 allowed: false,
430 action: pool.action(),
431 status_code: pool.status_code(),
432 message: pool.message(),
433 limiter: route_id.to_string(),
434 };
435 }
436
437 trace!(
438 route_id = route_id,
439 key = key,
440 count = count,
441 "Request allowed by rate limiter"
442 );
443 }
444
445 RateLimitResult {
446 allowed: true,
447 action: RateLimitAction::Reject,
448 status_code: 429,
449 message: None,
450 limiter: String::new(),
451 }
452 }
453
454 pub fn cleanup(&self) {
456 if let Some(ref global) = self.global_limiter {
457 global.cleanup();
458 }
459 for entry in self.route_limiters.iter() {
460 entry.value().cleanup();
461 }
462 }
463
464 pub fn route_count(&self) -> usize {
466 self.route_limiters.len()
467 }
468
469 #[inline]
474 pub fn is_enabled(&self) -> bool {
475 self.global_limiter.is_some() || !self.route_limiters.is_empty()
476 }
477
478 #[inline]
480 pub fn has_route_limiter(&self, route_id: &str) -> bool {
481 self.global_limiter.is_some() || self.route_limiters.contains_key(route_id)
482 }
483}
484
485impl Default for RateLimitManager {
486 fn default() -> Self {
487 Self::new()
488 }
489}
490
491#[derive(Debug, Clone)]
493pub struct RateLimitResult {
494 pub allowed: bool,
496 pub action: RateLimitAction,
498 pub status_code: u16,
500 pub message: Option<String>,
502 pub limiter: String,
504}
505
#[cfg(test)]
mod tests {
    use super::*;

    /// Header accessor stub that never yields a header value.
    struct NoHeaders;

    impl HeaderAccessor for NoHeaders {
        fn get_header(&self, _name: &str) -> Option<String> {
            None
        }
    }

    /// Builds a client-IP keyed config with the given limits and default
    /// values for every other field.
    fn client_ip_config(max_rps: u32, burst: u32) -> RateLimitConfig {
        RateLimitConfig {
            max_rps,
            burst,
            key: RateLimitKey::ClientIp,
            ..Default::default()
        }
    }

    #[test]
    fn test_rate_limiter_allows_under_limit() {
        let pool = RateLimiterPool::new(client_ip_config(10, 5));

        // Exactly `max_rps` requests must all pass.
        for _ in 0..10 {
            let (outcome, _) = pool.check("127.0.0.1");
            assert_eq!(outcome, RateLimitOutcome::Allowed);
        }
    }

    #[test]
    fn test_rate_limiter_blocks_over_limit() {
        let pool = RateLimiterPool::new(client_ip_config(5, 2));

        for _ in 0..5 {
            let (outcome, _) = pool.check("127.0.0.1");
            assert_eq!(outcome, RateLimitOutcome::Allowed);
        }

        // The first request past `max_rps` is rejected.
        let (outcome, _) = pool.check("127.0.0.1");
        assert_eq!(outcome, RateLimitOutcome::Limited);
    }

    #[test]
    fn test_rate_limiter_separate_keys() {
        let pool = RateLimiterPool::new(client_ip_config(2, 1));
        let clients = ["192.168.1.1", "192.168.1.2"];

        // Two interleaved rounds: each client stays within its own budget.
        for _ in 0..2 {
            for client in clients.iter() {
                let (outcome, _) = pool.check(client);
                assert_eq!(outcome, RateLimitOutcome::Allowed);
            }
        }

        // A third request from each client exceeds that client's limit.
        for client in clients.iter() {
            let (outcome, _) = pool.check(client);
            assert_eq!(outcome, RateLimitOutcome::Limited);
        }
    }

    #[test]
    fn test_rate_limit_manager() {
        let manager = RateLimitManager::new();
        manager.register_route("api", client_ip_config(5, 2));

        // A route without a registered limiter is never limited.
        let unregistered = manager.check("web", "127.0.0.1", "/", None::<&NoHeaders>);
        assert!(unregistered.allowed);

        for _ in 0..5 {
            let within = manager.check("api", "127.0.0.1", "/api/test", None::<&NoHeaders>);
            assert!(within.allowed);
        }

        let over = manager.check("api", "127.0.0.1", "/api/test", None::<&NoHeaders>);
        assert!(!over.allowed);
        assert_eq!(over.status_code, 429);
    }
}