1use dashmap::DashMap;
7use serde::{Deserialize, Serialize};
8use std::time::{Duration, Instant};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct RateLimitConfig {
13 pub requests_per_second: u32,
15 pub requests_per_minute: u32,
17 pub requests_per_hour: u32,
19 pub burst_size: u32,
21 #[serde(default = "default_minute_capacity")]
23 pub minute_window_capacity: usize,
24 #[serde(default = "default_hour_capacity")]
26 pub hour_window_capacity: usize,
27}
28
29fn default_minute_capacity() -> usize {
30 1000
31}
32fn default_hour_capacity() -> usize {
33 10000
34}
35
36impl Default for RateLimitConfig {
37 fn default() -> Self {
38 Self {
39 requests_per_second: 10,
40 requests_per_minute: 100,
41 requests_per_hour: 1000,
42 burst_size: 20,
43 minute_window_capacity: 1000,
44 hour_window_capacity: 10000,
45 }
46 }
47}
48
/// Token bucket: holds up to `max_tokens`, refills continuously at
/// `refill_rate` tokens per second, and each admitted request spends tokens.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens currently available (fractional, because refill is continuous).
    tokens: f64,
    /// Capacity; refilling never raises `tokens` above this.
    max_tokens: f64,
    /// Tokens added per second of elapsed time.
    refill_rate: f64,
    /// When `tokens` was last brought up to date by a refill.
    last_update: Instant,
}

impl TokenBucket {
    /// Creates a full bucket.
    fn new(max_tokens: f64, refill_rate: f64) -> Self {
        Self {
            tokens: max_tokens,
            max_tokens,
            refill_rate,
            last_update: Instant::now(),
        }
    }

    /// Refills for the time elapsed since the last update, then spends
    /// `tokens` if enough are available.
    ///
    /// Returns `true` and deducts the tokens on success; on failure the
    /// balance keeps the refill but nothing is deducted.
    fn try_take(&mut self, tokens: f64) -> bool {
        // Read the clock once. Measuring `elapsed()` and then stamping a second
        // `Instant::now()` would drop the time between the two reads from
        // every refill.
        let now = Instant::now();
        let elapsed = now.saturating_duration_since(self.last_update).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_rate).min(self.max_tokens);
        self.last_update = now;

        if self.tokens >= tokens {
            self.tokens -= tokens;
            true
        } else {
            false
        }
    }
}
87
88pub struct RateLimiter {
90 config: RateLimitConfig,
92 buckets: DashMap<String, TokenBucket>,
94 counters: DashMap<String, SlidingWindowCounter>,
96}
97
/// Per-key timestamps backing the minute and hour sliding windows.
///
/// Each vector is appended in call order, so it is oldest-first. Entries that
/// have aged out of their window are dropped on the next
/// [`add_request`](Self::add_request) and are never counted.
#[derive(Debug)]
struct SlidingWindowCounter {
    /// Recorded request times considered for the 60-second window.
    minute_requests: Vec<Instant>,
    /// Recorded request times considered for the 3600-second window.
    hour_requests: Vec<Instant>,
    /// Maximum timestamps kept in `minute_requests`; the oldest are evicted first.
    minute_capacity: usize,
    /// Maximum timestamps kept in `hour_requests`; the oldest are evicted first.
    hour_capacity: usize,
}

impl SlidingWindowCounter {
    /// Span of the short window.
    const MINUTE_WINDOW: Duration = Duration::from_secs(60);
    /// Span of the long window.
    const HOUR_WINDOW: Duration = Duration::from_secs(3600);

    /// Creates an empty counter with capacities of 1000 (minute) and
    /// 10000 (hour) timestamps.
    fn new() -> Self {
        Self::with_capacity(1000, 10000)
    }

    /// Creates an empty counter with the given per-window capacities.
    fn with_capacity(minute_capacity: usize, hour_capacity: usize) -> Self {
        Self {
            minute_requests: Vec::new(),
            hour_requests: Vec::new(),
            minute_capacity,
            hour_capacity,
        }
    }

    /// Records one request at the current instant. It then drops expired
    /// entries and evicts the oldest ones beyond each window's capacity.
    fn add_request(&mut self) {
        let now = Instant::now();
        self.minute_requests.push(now);
        self.hour_requests.push(now);

        Self::prune(&mut self.minute_requests, now, Self::MINUTE_WINDOW, self.minute_capacity);
        Self::prune(&mut self.hour_requests, now, Self::HOUR_WINDOW, self.hour_capacity);
    }

    /// Number of recorded requests within the last 60 seconds.
    fn minute_count(&self) -> usize {
        Self::live_count(&self.minute_requests, Self::MINUTE_WINDOW)
    }

    /// Number of recorded requests within the last 3600 seconds.
    fn hour_count(&self) -> usize {
        Self::live_count(&self.hour_requests, Self::HOUR_WINDOW)
    }

    /// Keeps only timestamps younger than `window` (as of `now`). It then
    /// trims the oldest entries so that at most `capacity` remain.
    fn prune(stamps: &mut Vec<Instant>, now: Instant, window: Duration, capacity: usize) {
        stamps.retain(|&t| now.saturating_duration_since(t) < window);
        if stamps.len() > capacity {
            let excess = stamps.len() - capacity;
            stamps.drain(..excess);
        }
    }

    /// Counts timestamps younger than `window` as of now.
    ///
    /// The vectors are only pruned when a request is recorded, so they may
    /// still hold expired entries here. Counting those would keep a key that
    /// reached its limit rejected indefinitely. A rejected request is never
    /// recorded, so it never triggers the pruning.
    fn live_count(stamps: &[Instant], window: Duration) -> usize {
        let now = Instant::now();
        stamps
            .iter()
            .filter(|&&t| now.saturating_duration_since(t) < window)
            .count()
    }
}
160
161impl RateLimiter {
162 pub fn new() -> Self {
163 Self::with_config(RateLimitConfig::default())
164 }
165
166 pub fn with_config(config: RateLimitConfig) -> Self {
167 Self {
168 config: config.clone(),
169 buckets: DashMap::new(),
170 counters: DashMap::new(),
171 }
172 }
173
174 pub async fn check(&self, key: &str) -> anyhow::Result<bool> {
176 let bucket_result = {
178 let mut bucket = self.buckets.entry(key.to_string()).or_insert_with(|| {
179 TokenBucket::new(
180 self.config.burst_size as f64,
181 self.config.requests_per_second as f64,
182 )
183 });
184 bucket.try_take(1.0)
185 };
186
187 if !bucket_result {
188 return Ok(false);
189 }
190
191 let window_result = {
193 let minute_cap = self.config.minute_window_capacity;
194 let hour_cap = self.config.hour_window_capacity;
195 let mut counter = self
196 .counters
197 .entry(key.to_string())
198 .or_insert_with(|| SlidingWindowCounter::with_capacity(minute_cap, hour_cap));
199
200 let minute_exceeded =
201 counter.minute_count() >= self.config.requests_per_minute as usize;
202 let hour_exceeded = counter.hour_count() >= self.config.requests_per_hour as usize;
203
204 if minute_exceeded || hour_exceeded {
205 false
206 } else {
207 counter.add_request();
208 true
209 }
210 };
211
212 Ok(window_result)
213 }
214
215 pub fn reset(&self, key: &str) {
217 self.buckets.remove(key);
218 self.counters.remove(key);
219 }
220
221 pub fn get_status(&self, key: &str) -> RateLimitStatus {
223 let tokens_remaining = self
224 .buckets
225 .get(key)
226 .map(|b| b.tokens as u32)
227 .unwrap_or(self.config.burst_size);
228
229 let minute_remaining = self.config.requests_per_minute
230 - self
231 .counters
232 .get(key)
233 .map(|c| c.minute_count() as u32)
234 .unwrap_or(0);
235
236 let hour_remaining = self.config.requests_per_hour
237 - self
238 .counters
239 .get(key)
240 .map(|c| c.hour_count() as u32)
241 .unwrap_or(0);
242
243 RateLimitStatus {
244 tokens_remaining,
245 minute_remaining,
246 hour_remaining,
247 }
248 }
249
250 pub fn cleanup_expired(&self, max_age: Duration) {
252 let now = Instant::now();
253
254 self.buckets
256 .retain(|_, bucket| now.duration_since(bucket.last_update) < max_age);
257
258 self.counters.retain(|_, counter| {
260 !counter.minute_requests.is_empty() || !counter.hour_requests.is_empty()
261 });
262 }
263
264 pub fn active_keys(&self) -> usize {
266 self.buckets.len()
267 }
268}
269
270#[derive(Debug, Serialize, Deserialize)]
272pub struct RateLimitStatus {
273 pub tokens_remaining: u32,
274 pub minute_remaining: u32,
275 pub hour_remaining: u32,
276}
277
278impl Default for RateLimiter {
279 fn default() -> Self {
280 Self::new()
281 }
282}
283
// Unit tests for `RateLimiter`. Async tests run under `#[tokio::test]`; the
// synchronous `#[test]` functions build their own Tokio runtime and drive
// `check` through `block_on`. Several tests are timing-sensitive (a tight loop
// or a short sleep measured against the refill rate), so their bounds leave
// slack rather than asserting exact counts.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// With the default config (burst of 20), ten back-to-back requests all fit
    /// in the initial burst.
    #[tokio::test]
    async fn test_basic_rate_limit() {
        let limiter = RateLimiter::new();

        for _ in 0..10 {
            assert!(limiter.check("test_key").await.unwrap());
        }
    }

    /// A burst of 2 admits two immediate requests and rejects the third.
    #[tokio::test]
    async fn test_rate_limit_exceeded() {
        let config = RateLimitConfig {
            requests_per_second: 1,
            requests_per_minute: 2,
            requests_per_hour: 3,
            burst_size: 2,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        assert!(limiter.check("test_key").await.unwrap());
        assert!(limiter.check("test_key").await.unwrap());

        // The bucket (burst 2) and the minute window (2/min) are both used up.
        assert!(!limiter.check("test_key").await.unwrap());
    }

    /// 100 spawned tasks hit one shared key at once. Only asserts that some
    /// succeed; the exact split depends on task scheduling and refill timing.
    #[tokio::test]
    async fn test_concurrent_requests() {
        let config = RateLimitConfig {
            requests_per_second: 100,
            requests_per_minute: 1000,
            requests_per_hour: 10000,
            burst_size: 50,
            ..Default::default()
        };
        let limiter = Arc::new(RateLimiter::with_config(config));

        let mut tasks = vec![];

        for _ in 0..100 {
            let limiter_clone = Arc::clone(&limiter);
            tasks.push(tokio::spawn(async move {
                limiter_clone.check("concurrent_key").await.unwrap()
            }));
        }

        let results: Vec<bool> = futures::future::join_all(tasks)
            .await
            .into_iter()
            .map(|r| r.unwrap())
            .collect();

        let success_count = results.iter().filter(|&&r| r).count();
        let fail_count = results.iter().filter(|&&r| !r).count();

        assert!(success_count > 0, "At least some requests should succeed");
        // The split is printed for diagnostics only; it is not asserted.
        println!("Success: {}, Fail: {}", success_count, fail_count);
    }

    /// Twenty tight-loop requests against a burst of 10 at 5 tokens/s should
    /// admit roughly the burst. The 8..=11 bounds leave slack for refill during
    /// the loop.
    #[tokio::test]
    async fn test_burst_handling() {
        let config = RateLimitConfig {
            requests_per_second: 5,
            requests_per_minute: 100,
            requests_per_hour: 1000,
            burst_size: 10,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        let mut success_count = 0;
        for _ in 0..20 {
            if limiter.check("burst_key").await.unwrap() {
                success_count += 1;
            }
        }

        assert!(
            success_count <= 11,
            "Burst should be limited, but got {} successes",
            success_count
        );
        assert!(
            success_count >= 8,
            "At least burst_size requests should succeed, but got {}",
            success_count
        );
    }

    /// Drains a 5-token bucket, waits 150 ms at 10 tokens/s (about 1.5 tokens
    /// of refill), and expects one more request to be admitted.
    #[tokio::test]
    async fn test_token_refill_accuracy() {
        let config = RateLimitConfig {
            requests_per_second: 10,
            requests_per_minute: 100,
            requests_per_hour: 1000,
            burst_size: 5,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        for _ in 0..5 {
            assert!(limiter.check("refill_key").await.unwrap());
        }

        assert!(!limiter.check("refill_key").await.unwrap());

        tokio::time::sleep(tokio::time::Duration::from_millis(150)).await;

        assert!(
            limiter.check("refill_key").await.unwrap(),
            "Token should be refilled after waiting"
        );
    }

    /// Exhausting one key must not affect another. Each key has its own
    /// bucket and window counter.
    #[tokio::test]
    async fn test_different_keys_isolated() {
        let config = RateLimitConfig {
            requests_per_second: 1,
            requests_per_minute: 1,
            requests_per_hour: 1,
            burst_size: 1,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        assert!(limiter.check("key1").await.unwrap());
        assert!(!limiter.check("key1").await.unwrap());

        assert!(limiter.check("key2").await.unwrap());
        assert!(!limiter.check("key2").await.unwrap());
    }

    /// After `reset`, a key that was rate limited is admitted again. `reset`
    /// is called outside the runtime, between two `block_on` sections.
    #[test]
    fn test_reset_functionality() {
        let config = RateLimitConfig {
            requests_per_second: 1,
            requests_per_minute: 1,
            requests_per_hour: 1,
            burst_size: 1,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            assert!(limiter.check("reset_key").await.unwrap());
            assert!(!limiter.check("reset_key").await.unwrap());
        });

        limiter.reset("reset_key");

        rt.block_on(async {
            assert!(limiter.check("reset_key").await.unwrap());
        });
    }

    /// Five admitted requests must show up in `get_status` as consumed tokens
    /// and as used minute allowance.
    #[test]
    fn test_status_reporting() {
        let config = RateLimitConfig {
            requests_per_second: 10,
            requests_per_minute: 100,
            requests_per_hour: 1000,
            burst_size: 20,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            for _ in 0..5 {
                limiter.check("status_key").await.unwrap();
            }
        });

        let status = limiter.get_status("status_key");
        assert!(status.tokens_remaining < 20, "Tokens should be consumed");
        assert!(
            status.minute_remaining < 100,
            "Minute count should increase"
        );
    }

    /// `cleanup_expired` with a zero `max_age` removes every bucket, because
    /// no bucket's age is below zero. `active_keys`, which counts buckets,
    /// therefore drops to 0.
    #[test]
    fn test_cleanup_expired() {
        let limiter = RateLimiter::new();

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            limiter.check("key1").await.unwrap();
            limiter.check("key2").await.unwrap();
        });

        assert!(limiter.active_keys() >= 2);

        limiter.cleanup_expired(Duration::from_secs(0));

        assert_eq!(limiter.active_keys(), 0);
    }

    /// Each distinct key that is checked creates one bucket.
    #[test]
    fn test_active_keys_count() {
        let limiter = RateLimiter::new();

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            limiter.check("key1").await.unwrap();
            limiter.check("key2").await.unwrap();
            limiter.check("key3").await.unwrap();
        });

        assert_eq!(limiter.active_keys(), 3);
    }

    /// With a refill rate of 0, only the initial burst is ever available.
    #[test]
    fn test_zero_rate_limit() {
        let config = RateLimitConfig {
            requests_per_second: 0,
            requests_per_minute: 100,
            requests_per_hour: 1000,
            burst_size: 2,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let first = limiter.check("key").await.unwrap();
            assert!(first, "First request with burst_size=2 should succeed");

            let second = limiter.check("key").await.unwrap();
            assert!(second, "Second request with burst_size=2 should succeed");

            let third = limiter.check("key").await.unwrap();
            assert!(!third, "Third request should be rate limited (no refill)");
        });
    }

    /// A burst of 1 at 1 token/s admits exactly one immediate request.
    #[test]
    fn test_very_small_burst_size() {
        let config = RateLimitConfig {
            requests_per_second: 1,
            requests_per_minute: 100,
            requests_per_hour: 1000,
            burst_size: 1,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            assert!(limiter.check("key").await.unwrap());
            assert!(!limiter.check("key").await.unwrap());
        });
    }

    /// 500 requests fit inside a 1000-token burst. The assertion is
    /// deliberately loose (>= 400 of 500).
    #[test]
    fn test_large_burst_size() {
        let config = RateLimitConfig {
            requests_per_second: 1000,
            requests_per_minute: 100000,
            requests_per_hour: 1000000,
            burst_size: 1000,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let mut success_count = 0;
            for _ in 0..500 {
                if limiter.check("key").await.unwrap() {
                    success_count += 1;
                }
            }
            assert!(
                success_count >= 400,
                "Should allow most requests with large burst"
            );
        });
    }

    /// The empty string is accepted as a key.
    #[test]
    fn test_empty_key() {
        let limiter = RateLimiter::new();

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            assert!(limiter.check("").await.unwrap());
        });
    }

    /// Keys are opaque strings; punctuation has no special meaning.
    #[test]
    fn test_special_characters_in_key() {
        let limiter = RateLimiter::new();

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let special_keys = vec![
                "key:with:colons",
                "key-with-dashes",
                "key_with_underscores",
                "key.with.dots",
                "key/with/slashes",
            ];
            for key in special_keys {
                assert!(
                    limiter.check(key).await.unwrap(),
                    "Key '{}' should work",
                    key
                );
            }
        });
    }

    /// Non-ASCII keys (CJK text, emoji) are accepted.
    #[test]
    fn test_unicode_key() {
        let limiter = RateLimiter::new();

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            assert!(limiter.check("用户_123").await.unwrap());
            assert!(limiter.check("🔑_key").await.unwrap());
        });
    }

    /// A 10,000-character key is accepted.
    #[test]
    fn test_very_long_key() {
        let limiter = RateLimiter::new();
        let long_key = "a".repeat(10000);

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            assert!(limiter.check(&long_key).await.unwrap());
        });
    }

    /// Resetting an unknown key is a no-op and creates no state.
    #[test]
    fn test_reset_nonexistent_key() {
        let limiter = RateLimiter::new();

        limiter.reset("nonexistent_key");
        assert_eq!(limiter.active_keys(), 0);
    }

    /// An unknown key reports the full default allowances.
    #[test]
    fn test_status_nonexistent_key() {
        let limiter = RateLimiter::new();
        let config = RateLimitConfig::default();

        let status = limiter.get_status("nonexistent");
        assert_eq!(status.tokens_remaining, config.burst_size);
        assert_eq!(status.minute_remaining, config.requests_per_minute);
        assert_eq!(status.hour_remaining, config.requests_per_hour);
    }

    /// Twenty tight-loop requests against a burst of 5 admit at most 7. The
    /// upper bound leaves room for refill at 10 tokens/s during the loop.
    #[tokio::test]
    async fn test_rapid_requests() {
        let config = RateLimitConfig {
            requests_per_second: 10,
            requests_per_minute: 100,
            requests_per_hour: 1000,
            burst_size: 5,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        let mut success_count = 0;
        for _ in 0..20 {
            if limiter.check("rapid").await.unwrap() {
                success_count += 1;
            }
        }

        assert!(
            success_count <= 7,
            "Expected ~5 successful requests, got {}",
            success_count
        );
    }

    /// NOTE(review): despite its name, this test passes the largest
    /// representable `Duration` (`u64::MAX` seconds), not a negative one. It
    /// checks that such a cleanup removes nothing.
    #[test]
    fn test_cleanup_with_negative_duration() {
        let limiter = RateLimiter::new();

        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            limiter.check("key").await.unwrap();
        });

        limiter.cleanup_expired(Duration::from_secs(u64::MAX));

        assert!(limiter.active_keys() >= 1);
    }

    /// After 3 of 10 tokens are spent, `tokens_remaining` lies strictly
    /// between 0 and 10.
    #[tokio::test]
    async fn test_status_accuracy() {
        let config = RateLimitConfig {
            requests_per_second: 10,
            requests_per_minute: 100,
            requests_per_hour: 1000,
            burst_size: 10,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        for _ in 0..3 {
            limiter.check("status_test").await.unwrap();
        }

        let status = limiter.get_status("status_test");
        assert!(status.tokens_remaining < 10);
        assert!(status.tokens_remaining > 0);
    }

    /// Pins the limits produced by `RateLimitConfig::default()`.
    #[test]
    fn test_config_default_values() {
        let config = RateLimitConfig::default();
        assert_eq!(config.requests_per_second, 10);
        assert_eq!(config.requests_per_minute, 100);
        assert_eq!(config.requests_per_hour, 1000);
        assert_eq!(config.burst_size, 20);
    }

    /// JSON round-trip of the default config. Only `requests_per_second` is
    /// compared.
    #[test]
    fn test_config_serialization() {
        let config = RateLimitConfig::default();
        let json = serde_json::to_string(&config).unwrap();
        let parsed: RateLimitConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.requests_per_second, config.requests_per_second);
    }

    /// At 100 tokens/s, a 15 ms wait (about 1.5 tokens) after draining the
    /// 10-token burst is enough for one more request. The drain loop's results
    /// are not asserted.
    #[tokio::test]
    async fn test_token_refill_boundary() {
        let config = RateLimitConfig {
            requests_per_second: 100,
            requests_per_minute: 10000,
            requests_per_hour: 100000,
            burst_size: 10,
            ..Default::default()
        };
        let limiter = RateLimiter::with_config(config);

        for _ in 0..10 {
            limiter.check("refill_boundary").await.unwrap();
        }

        assert!(!limiter.check("refill_boundary").await.unwrap());

        tokio::time::sleep(tokio::time::Duration::from_millis(15)).await;

        assert!(limiter.check("refill_boundary").await.unwrap());
    }
}