1use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tracing::{debug, info, warn};
13
/// Dynamically scales fleet-wide rate limits based on observed latency.
///
/// Maintains a fixed-point multiplier (1000 == 1.0x) that `adjust()` lowers
/// when the metrics registry reports high average latency and restores as
/// latency recovers; the multiplier is pushed to every limiter reachable
/// through `manager`.
pub struct AdaptiveRateLimiter {
    /// Shared handle to the limiter registry whose rates are adjusted.
    manager: Arc<RwLock<RateLimitManager>>,
    /// Source of the average-latency signal that drives adjustments.
    metrics: Arc<crate::metrics::MetricsRegistry>,
    /// Most recently observed average latency, in milliseconds.
    // NOTE(review): in this file it is only written at construction and
    // never updated — confirm whether another module maintains it.
    last_latency_ms: AtomicU64,
    /// Current rate multiplier in fixed-point thousandths (1000 == 1.0x).
    multiplier: AtomicU64,
}
25
26impl AdaptiveRateLimiter {
27 pub fn new(
29 manager: Arc<RwLock<RateLimitManager>>,
30 metrics: Arc<crate::metrics::MetricsRegistry>,
31 ) -> Self {
32 Self {
33 manager,
34 metrics,
35 last_latency_ms: AtomicU64::new(0),
36 multiplier: AtomicU64::new(1000), }
38 }
39
40 pub fn adjust(&self) {
46 let avg_latency = self.metrics.avg_latency_ms();
47
48 let current_mult = self.multiplier.load(Ordering::Relaxed);
52 let mut next_mult = current_mult;
53
54 if avg_latency > 500.0 {
55 next_mult = (current_mult.saturating_mul(90) / 100).max(200);
58 warn!(latency = %avg_latency, multiplier = %(next_mult as f64 / 1000.0), "Adaptive RL: High latency detected, throttling fleet");
59 } else if avg_latency > 200.0 {
60 next_mult = (current_mult.saturating_mul(95) / 100).max(200);
62 debug!(latency = %avg_latency, multiplier = %(next_mult as f64 / 1000.0), "Adaptive RL: Latency rising, slowing down");
63 } else if avg_latency < 50.0 && current_mult < 1000 {
64 next_mult = (current_mult.saturating_add(50)).min(1000);
66 debug!(latency = %avg_latency, multiplier = %(next_mult as f64 / 1000.0), "Adaptive RL: Health recovered, restoring capacity");
67 }
68
69 if next_mult != current_mult {
70 self.multiplier.store(next_mult, Ordering::Relaxed);
71 self.apply_multiplier(next_mult as f64 / 1000.0);
72 }
73 }
74
75 fn apply_multiplier(&self, multiplier: f64) {
76 let manager = self.manager.read();
77
78 if let Some(global) = &manager.global_limiter {
80 global.set_multiplier(multiplier);
81 }
82
83 let sites = manager.site_limiters.read();
85 for limiter in sites.values() {
86 limiter.set_multiplier(multiplier);
87 }
88 }
89
90 pub fn current_multiplier(&self) -> f64 {
92 self.multiplier.load(Ordering::Relaxed) as f64 / 1000.0
93 }
94
95 pub fn start_background_task(self: Arc<Self>, interval: Duration) {
97 info!(?interval, "Starting adaptive rate limiting background task");
98 std::thread::spawn(move || loop {
99 self.adjust();
100 std::thread::sleep(interval);
101 });
102 }
103}
104
/// Outcome of a rate-limit check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitDecision {
    /// The request is within the configured limits.
    Allow,
    /// The request exceeded the configured rate and should be throttled.
    Limited,
}
113
/// Token-bucket rate-limit settings for a single limiter.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Sustained requests per second (token refill rate).
    pub rps: u32,
    /// Maximum burst size (token-bucket capacity).
    pub burst: u32,
    /// When false, every check returns `Allow`.
    pub enabled: bool,
    /// Window length in seconds.
    // NOTE(review): not read anywhere in this file — confirm whether other
    // modules consume it.
    pub window_secs: u64,
}
126
127impl Default for RateLimitConfig {
128 fn default() -> Self {
129 Self {
130 rps: 1000,
131 burst: 2000,
132 enabled: true,
133 window_secs: 1,
134 }
135 }
136}
137
138impl RateLimitConfig {
139 pub fn new(rps: u32) -> Self {
141 Self {
142 rps,
143 burst: rps * 2,
144 enabled: true,
145 window_secs: 1,
146 }
147 }
148
149 pub fn with_burst(mut self, burst: u32) -> Self {
151 self.burst = burst;
152 self
153 }
154
155 pub fn disabled() -> Self {
157 Self {
158 enabled: false,
159 ..Default::default()
160 }
161 }
162}
163
/// Lock-free token bucket backed by atomics.
///
/// All timestamps are stored as nanoseconds elapsed since `start_time`, so
/// they are relative, monotonic, and fit in a `u64`.
#[derive(Debug)]
pub struct TokenBucket {
    /// Currently available tokens.
    tokens: AtomicU64,
    /// Bucket capacity (burst size); `tokens` never exceeds this.
    max_tokens: u64,
    /// Refill rate in tokens per second; mutable at runtime via `set_rate`.
    refill_rate: AtomicU64,
    /// Nanoseconds-since-start of the last refill that won the CAS.
    last_refill: AtomicU64,
    /// Reference point for every relative timestamp in this struct.
    start_time: Instant,
    /// Nanoseconds-since-start of the most recent `try_acquire`; read by
    /// `KeyedRateLimiter` for least-recently-used eviction.
    last_access: AtomicU64,
}
180
impl TokenBucket {
    /// Creates a bucket that starts full at `burst` tokens and refills at
    /// `rps` tokens per second.
    pub fn new(rps: u32, burst: u32) -> Self {
        let max_tokens = burst as u64;
        Self {
            tokens: AtomicU64::new(max_tokens),
            max_tokens,
            refill_rate: AtomicU64::new(rps as u64),
            // 0 == "at start_time"; all refill math is relative to it.
            last_refill: AtomicU64::new(0),
            start_time: Instant::now(),
            last_access: AtomicU64::new(0),
        }
    }

    /// Changes the refill rate (tokens/sec) without touching stored tokens.
    pub fn set_rate(&self, rps: u32) {
        self.refill_rate.store(rps as u64, Ordering::Relaxed);
    }

    /// Attempts to take one token; returns `false` if the bucket is empty.
    ///
    /// Lock-free: refills first, then decrements `tokens` in a CAS retry
    /// loop so concurrent callers can never drive the count below zero.
    pub fn try_acquire(&self) -> bool {
        // Monotonic nanoseconds since construction; recorded for LRU
        // eviction by KeyedRateLimiter.
        let now_nanos = self.start_time.elapsed().as_nanos() as u64;
        self.last_access.store(now_nanos, Ordering::Relaxed);

        self.refill();

        loop {
            let current = self.tokens.load(Ordering::Acquire);
            if current == 0 {
                return false;
            }

            // compare_exchange_weak may fail spuriously; the loop retries
            // with a freshly loaded value.
            match self.tokens.compare_exchange_weak(
                current,
                current - 1,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return true,
                Err(_) => {
                    core::hint::spin_loop();
                    continue;
                }
            }
        }
    }

    /// Credits tokens accrued since the last refill.
    ///
    /// Exactly one racing thread wins the CAS on `last_refill` and performs
    /// the credit; losers either retry (their timestamp was stale but still
    /// older than now) or bail out (someone advanced the clock past them).
    /// Sub-token fractions are handled implicitly: when the accrued amount
    /// truncates to zero the timestamp is left untouched, so elapsed time
    /// keeps accumulating for a later call.
    fn refill(&self) {
        let now_nanos = self.start_time.elapsed().as_nanos() as u64;

        loop {
            let last = self.last_refill.load(Ordering::Acquire);

            if now_nanos <= last {
                // Another thread already refilled at an equal-or-newer time.
                return;
            }

            let elapsed_nanos = now_nanos - last;
            if elapsed_nanos < 1000 {
                // Less than 1 µs since the last refill — skip the math.
                return;
            }

            let elapsed_secs = elapsed_nanos as f64 / 1_000_000_000.0;
            let refill_rate = self.refill_rate.load(Ordering::Relaxed);
            let tokens_to_add = (elapsed_secs * refill_rate as f64) as u64;

            if tokens_to_add == 0 {
                // Truncated to zero; leave last_refill so time accrues.
                return;
            }

            match self.last_refill.compare_exchange(
                last,
                now_nanos,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => {
                    // We own this refill window; credit the tokens.
                    // NOTE(review): if this thread stalls between the CAS
                    // and add_tokens, other threads see the advanced
                    // timestamp and skip — the credit is delayed, never
                    // duplicated.
                    self.add_tokens(tokens_to_add);
                    return;
                }
                Err(actual) => {
                    if actual >= now_nanos {
                        // Lost the race to a same-or-newer refill; done.
                        return;
                    }
                    core::hint::spin_loop();
                    continue;
                }
            }
        }
    }

    /// Adds tokens, saturating at `max_tokens`, via a CAS retry loop.
    #[inline]
    fn add_tokens(&self, tokens_to_add: u64) {
        loop {
            let current = self.tokens.load(Ordering::Acquire);
            let new_tokens = (current.saturating_add(tokens_to_add)).min(self.max_tokens);

            if new_tokens == current {
                // Already at capacity (or adding zero); nothing to do.
                return;
            }

            match self.tokens.compare_exchange_weak(
                current,
                new_tokens,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return,
                Err(_) => {
                    core::hint::spin_loop();
                    continue;
                }
            }
        }
    }

    /// Current token count after applying any pending refill.
    pub fn available_tokens(&self) -> u64 {
        self.refill();
        self.tokens.load(Ordering::Acquire)
    }

    /// Nanoseconds (since construction) of the most recent `try_acquire`.
    pub fn last_access_nanos(&self) -> u64 {
        self.last_access.load(Ordering::Relaxed)
    }
}
340
/// Per-key rate limiter: one lazily created `TokenBucket` per key (e.g. per
/// client), capped at `max_keys` entries with LRU eviction.
#[derive(Debug)]
pub struct KeyedRateLimiter {
    /// Lazily populated map of key -> bucket.
    buckets: RwLock<HashMap<String, Arc<TokenBucket>>>,
    /// Base configuration applied to every bucket.
    config: RateLimitConfig,
    /// Maximum number of tracked keys before stale entries are evicted.
    max_keys: usize,
    /// Adaptive rate multiplier in fixed-point thousandths (1000 == 1.0x);
    /// applied to buckets created after `set_multiplier` was called.
    multiplier: AtomicU64,
}
353
354impl KeyedRateLimiter {
355 pub fn new(config: RateLimitConfig) -> Self {
357 Self {
358 buckets: RwLock::new(HashMap::new()),
359 config,
360 max_keys: 100_000, multiplier: AtomicU64::new(1000),
362 }
363 }
364
365 pub fn set_multiplier(&self, multiplier: f64) {
367 let m = (multiplier * 1000.0) as u64;
368 self.multiplier.store(m, Ordering::Relaxed);
369
370 let new_rps = (self.config.rps as f64 * multiplier) as u32;
372 let buckets = self.buckets.read();
373 for bucket in buckets.values() {
374 bucket.set_rate(new_rps);
375 }
376 }
377
378 pub fn with_max_keys(mut self, max_keys: usize) -> Self {
380 self.max_keys = max_keys;
381 self
382 }
383
384 pub fn check(&self, key: &str) -> RateLimitDecision {
386 if !self.config.enabled {
387 return RateLimitDecision::Allow;
388 }
389
390 {
392 let buckets = self.buckets.read();
393 if let Some(bucket) = buckets.get(key) {
394 return if bucket.try_acquire() {
395 RateLimitDecision::Allow
396 } else {
397 debug!("Rate limited key: {}", key);
398 RateLimitDecision::Limited
399 };
400 }
401 }
402
403 {
405 let mut buckets = self.buckets.write();
406
407 if buckets.len() >= self.max_keys {
411 warn!(
412 "Rate limiter at capacity ({}), evicting stale entries",
413 buckets.len()
414 );
415 let evict_count = self.max_keys / 10;
416 let mut entries: Vec<_> = buckets
417 .iter()
418 .map(|(k, v)| (k.clone(), v.last_access.load(Ordering::Relaxed)))
419 .collect();
420 entries.sort_unstable_by_key(|&(_, ts)| ts);
421 for (k, _) in entries.into_iter().take(evict_count) {
422 buckets.remove(&k);
423 }
424 }
425
426 let multiplier = self.multiplier.load(Ordering::Relaxed) as f64 / 1000.0;
427 let effective_rps = (self.config.rps as f64 * multiplier) as u32;
428 let bucket = Arc::new(TokenBucket::new(effective_rps, self.config.burst));
429 let allowed = bucket.try_acquire();
430 buckets.insert(key.to_string(), bucket);
431
432 if allowed {
433 RateLimitDecision::Allow
434 } else {
435 RateLimitDecision::Limited
436 }
437 }
438 }
439
440 pub fn key_count(&self) -> usize {
442 self.buckets.read().len()
443 }
444
445 pub fn clear(&self) {
447 self.buckets.write().clear();
448 }
449}
450
/// Routes rate-limit checks to an optional global limiter plus per-site
/// (hostname) limiters.
#[derive(Debug)]
pub struct RateLimitManager {
    /// Per-hostname limiters, keyed by lowercased hostname.
    site_limiters: RwLock<HashMap<String, Arc<KeyedRateLimiter>>>,
    /// Optional limiter consulted before any site limiter.
    global_limiter: Option<Arc<KeyedRateLimiter>>,
    /// Default configuration.
    // NOTE(review): stored and settable but never read in this file —
    // confirm whether callers rely on it.
    default_config: RateLimitConfig,
}
461
462impl RateLimitManager {
463 pub fn new() -> Self {
465 Self {
466 site_limiters: RwLock::new(HashMap::new()),
467 global_limiter: None,
468 default_config: RateLimitConfig::default(),
469 }
470 }
471
472 pub fn with_global(config: RateLimitConfig) -> Self {
474 Self {
475 site_limiters: RwLock::new(HashMap::new()),
476 global_limiter: Some(Arc::new(KeyedRateLimiter::new(config.clone()))),
477 default_config: config,
478 }
479 }
480
481 pub fn set_default_config(&mut self, config: RateLimitConfig) {
483 self.default_config = config;
484 }
485
486 pub fn add_site(&self, hostname: &str, config: RateLimitConfig) {
488 let limiter = Arc::new(KeyedRateLimiter::new(config));
489 self.site_limiters
490 .write()
491 .insert(hostname.to_lowercase(), limiter);
492 }
493
494 pub fn remove_site(&self, hostname: &str) {
496 self.site_limiters.write().remove(&hostname.to_lowercase());
497 }
498
499 pub fn check(&self, hostname: &str, key: &str) -> RateLimitDecision {
505 if let Some(global) = &self.global_limiter {
507 if matches!(global.check(key), RateLimitDecision::Limited) {
508 return RateLimitDecision::Limited;
509 }
510 }
511
512 let normalized = hostname.to_lowercase();
514 let limiters = self.site_limiters.read();
515
516 if let Some(limiter) = limiters.get(&normalized) {
517 return limiter.check(key);
518 }
519
520 RateLimitDecision::Allow
522 }
523
524 pub fn is_allowed(&self, hostname: &str, key: &str) -> bool {
526 matches!(self.check(hostname, key), RateLimitDecision::Allow)
527 }
528
529 pub fn stats(&self) -> RateLimitStats {
531 let limiters = self.site_limiters.read();
532 let total_keys: usize = limiters.values().map(|l| l.key_count()).sum();
533 let global_keys = self
534 .global_limiter
535 .as_ref()
536 .map(|l| l.key_count())
537 .unwrap_or(0);
538
539 RateLimitStats {
540 site_count: limiters.len(),
541 total_tracked_keys: total_keys + global_keys,
542 global_enabled: self.global_limiter.is_some(),
543 }
544 }
545}
546
547impl Default for RateLimitManager {
548 fn default() -> Self {
549 Self::new()
550 }
551}
552
/// Snapshot of limiter bookkeeping, serializable for reporting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitStats {
    /// Number of registered per-site limiters.
    pub site_count: usize,
    /// Tracked keys summed across all site limiters plus the global limiter.
    pub total_tracked_keys: usize,
    /// Whether a global limiter is configured.
    pub global_enabled: bool,
}
563
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_token_bucket_basic() {
        let bucket = TokenBucket::new(10, 10);

        // The full burst of 10 tokens is available immediately...
        assert!((0..10).all(|_| bucket.try_acquire()));

        // ...and the 11th acquire is refused.
        assert!(!bucket.try_acquire());
    }

    #[test]
    fn test_token_bucket_refill() {
        // Drain the 10-token burst, then wait long enough at 1000 rps for
        // at least one token to come back.
        let bucket = TokenBucket::new(1000, 10);
        for _ in 0..10 {
            bucket.try_acquire();
        }

        thread::sleep(Duration::from_millis(20));

        assert!(bucket.try_acquire());
    }

    #[test]
    fn test_rate_limit_config() {
        let cfg = RateLimitConfig::new(100).with_burst(200);
        assert_eq!(cfg.rps, 100);
        assert_eq!(cfg.burst, 200);
        assert!(cfg.enabled);
    }

    #[test]
    fn test_rate_limit_disabled() {
        // A disabled limiter never throttles, regardless of volume.
        let limiter = KeyedRateLimiter::new(RateLimitConfig::disabled());

        for _ in 0..1000 {
            assert_eq!(limiter.check("key"), RateLimitDecision::Allow);
        }
    }

    #[test]
    fn test_keyed_rate_limiter() {
        // Keys are limited independently: each gets its own 5-token burst.
        let limiter = KeyedRateLimiter::new(RateLimitConfig::new(5).with_burst(5));

        for _ in 0..5 {
            assert_eq!(limiter.check("key1"), RateLimitDecision::Allow);
            assert_eq!(limiter.check("key2"), RateLimitDecision::Allow);
        }

        assert_eq!(limiter.check("key1"), RateLimitDecision::Limited);
        assert_eq!(limiter.check("key2"), RateLimitDecision::Limited);
    }

    #[test]
    fn test_keyed_limiter_key_count() {
        let limiter = KeyedRateLimiter::new(RateLimitConfig::new(10));

        for key in ["key1", "key2", "key3"] {
            limiter.check(key);
        }

        assert_eq!(limiter.key_count(), 3);
    }

    #[test]
    fn test_rate_limit_manager() {
        let manager = RateLimitManager::new();
        manager.add_site("api.example.com", RateLimitConfig::new(2).with_burst(2));

        // Two tokens for client1 on the registered site, then throttled.
        assert!(manager.is_allowed("api.example.com", "client1"));
        assert!(manager.is_allowed("api.example.com", "client1"));
        assert!(!manager.is_allowed("api.example.com", "client1"));

        // Unregistered hosts are never limited.
        assert!(manager.is_allowed("other.example.com", "client1"));
    }

    #[test]
    fn test_global_rate_limit() {
        let manager = RateLimitManager::with_global(RateLimitConfig::new(3).with_burst(3));

        for _ in 0..3 {
            assert!(manager.is_allowed("any.com", "client1"));
        }
        assert!(!manager.is_allowed("any.com", "client1"));
    }

    #[test]
    fn test_manager_case_insensitive() {
        // Hostnames are normalized to lowercase on registration and lookup,
        // so both spellings hit the same 1-token bucket.
        let manager = RateLimitManager::new();
        manager.add_site("Example.COM", RateLimitConfig::new(1).with_burst(1));

        assert!(manager.is_allowed("example.com", "client"));
        assert!(!manager.is_allowed("EXAMPLE.COM", "client"));
    }

    #[test]
    fn test_keyed_limiter_clear() {
        let limiter = KeyedRateLimiter::new(RateLimitConfig::new(10));

        limiter.check("key1");
        limiter.check("key2");
        assert_eq!(limiter.key_count(), 2);

        limiter.clear();
        assert_eq!(limiter.key_count(), 0);
    }

    #[test]
    fn test_stats() {
        let manager = RateLimitManager::with_global(RateLimitConfig::new(100));
        manager.add_site("site1.com", RateLimitConfig::new(50));
        manager.add_site("site2.com", RateLimitConfig::new(50));

        manager.check("site1.com", "ip1");
        manager.check("site2.com", "ip2");

        let stats = manager.stats();
        assert_eq!(stats.site_count, 2);
        assert!(stats.global_enabled);
    }

    #[test]
    fn test_available_tokens() {
        // A fresh bucket reports its full burst capacity.
        let bucket = TokenBucket::new(100, 50);
        assert_eq!(bucket.available_tokens(), 50);
    }

    #[test]
    fn test_concurrent_token_bucket_no_burst_bypass() {
        use std::sync::atomic::AtomicUsize;

        // 10 threads x 50 attempts against a 100-token burst (10 rps, so
        // refill during the test is negligible): the number of successful
        // acquires must never exceed the burst capacity.
        let bucket = Arc::new(TokenBucket::new(10, 100));
        let acquired = Arc::new(AtomicUsize::new(0));

        let workers: Vec<_> = (0..10)
            .map(|_| {
                let bucket = Arc::clone(&bucket);
                let acquired = Arc::clone(&acquired);

                thread::spawn(move || {
                    for _ in 0..50 {
                        if bucket.try_acquire() {
                            acquired.fetch_add(1, Ordering::Relaxed);
                        }
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        let total = acquired.load(Ordering::Relaxed);

        assert!(
            total <= 100,
            "Race condition detected! Got {} successful acquires, expected <= 100",
            total
        );

        assert!(
            total >= 95,
            "Token bucket may have performance issue: only {} acquires, expected ~100",
            total
        );
    }

    #[test]
    fn test_concurrent_refill_no_double_add() {
        // Drain the bucket, let it refill to its 10-token cap, then hammer
        // available_tokens() (which triggers refill) from many threads; a
        // double-credit race would push the count past the cap.
        let bucket = Arc::new(TokenBucket::new(1000, 10));
        for _ in 0..10 {
            bucket.try_acquire();
        }

        thread::sleep(Duration::from_millis(50));
        let before = bucket.available_tokens();

        let readers: Vec<_> = (0..10)
            .map(|_| {
                let bucket = Arc::clone(&bucket);
                thread::spawn(move || bucket.available_tokens())
            })
            .collect();

        for reader in readers {
            reader.join().unwrap();
        }

        let after = bucket.available_tokens();

        assert!(
            after <= before + 10,
            "Possible double-add race: before={}, after={}",
            before,
            after
        );
    }
}