tower_resilience_ratelimiter/
lib.rs1mod config;
216mod error;
217mod events;
218mod handle;
219mod layer;
220mod limiter;
221
222pub use config::{RateLimiterConfig, RateLimiterConfigBuilder, WindowType};
223pub use error::{RateLimiterError, RateLimiterServiceError};
224pub use events::RateLimiterEvent;
225pub use handle::RateLimiterHandle;
226pub use layer::RateLimiterLayer;
227
228use crate::limiter::SharedRateLimiter;
229use futures::future::BoxFuture;
230use futures::Future;
231use std::pin::Pin;
232use std::sync::Arc;
233use std::task::{Context, Poll};
234use std::time::Instant;
235use tower::Service;
236
237#[cfg(feature = "metrics")]
238use metrics::{counter, describe_counter, describe_histogram, histogram};
239
240#[cfg(feature = "tracing")]
241use tracing::{debug, warn};
242
243pub struct RateLimiter<S> {
260 inner: S,
261 config: Arc<RateLimiterConfig>,
262 limiter: SharedRateLimiter,
263 sleep: Option<Pin<Box<tokio::time::Sleep>>>,
265 permit_acquired: bool,
267}
268
269impl<S> RateLimiter<S> {
270 pub fn new(inner: S, config: Arc<RateLimiterConfig>) -> Self {
272 #[cfg(feature = "metrics")]
273 {
274 describe_counter!(
275 "ratelimiter_calls_total",
276 "Total number of rate limiter calls (permitted or rejected)"
277 );
278 describe_histogram!(
279 "ratelimiter_wait_duration_seconds",
280 "Time spent waiting for a permit"
281 );
282 }
283
284 let limiter = SharedRateLimiter::new(
285 config.window_type,
286 config.limit_for_period,
287 config.refresh_period,
288 config.timeout_duration,
289 );
290
291 Self {
292 inner,
293 config,
294 limiter,
295 sleep: None,
296 permit_acquired: false,
297 }
298 }
299
300 pub(crate) fn from_shared(
302 inner: S,
303 config: Arc<RateLimiterConfig>,
304 limiter: SharedRateLimiter,
305 ) -> Self {
306 #[cfg(feature = "metrics")]
307 {
308 describe_counter!(
309 "ratelimiter_calls_total",
310 "Total number of rate limiter calls (permitted or rejected)"
311 );
312 describe_histogram!(
313 "ratelimiter_wait_duration_seconds",
314 "Time spent waiting for a permit"
315 );
316 }
317
318 Self {
319 inner,
320 config,
321 limiter,
322 sleep: None,
323 permit_acquired: false,
324 }
325 }
326}
327
328impl<S> Clone for RateLimiter<S>
329where
330 S: Clone,
331{
332 fn clone(&self) -> Self {
333 Self {
334 inner: self.inner.clone(),
335 config: Arc::clone(&self.config),
336 limiter: self.limiter.clone(),
337 sleep: None,
338 permit_acquired: false,
339 }
340 }
341}
342
343impl<S, Req> Service<Req> for RateLimiter<S>
344where
345 S: Service<Req> + Clone + Send + 'static,
346 S::Future: Send + 'static,
347 S::Error: Send + 'static,
348 Req: Send + 'static,
349{
350 type Response = S::Response;
351 type Error = RateLimiterServiceError<S::Error>;
352 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
353
354 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
355 match self.inner.poll_ready(cx) {
357 Poll::Pending => return Poll::Pending,
358 Poll::Ready(Err(e)) => return Poll::Ready(Err(RateLimiterServiceError::Inner(e))),
359 Poll::Ready(Ok(())) => {}
360 }
361
362 if !self.config.backpressure {
363 return Poll::Ready(Ok(()));
364 }
365
366 if self.permit_acquired {
368 return Poll::Ready(Ok(()));
369 }
370
371 if let Some(sleep) = self.sleep.as_mut() {
373 match sleep.as_mut().poll(cx) {
374 Poll::Pending => return Poll::Pending,
375 Poll::Ready(()) => {
376 self.sleep = None;
377 }
379 }
380 }
381
382 match self.limiter.try_acquire_now() {
383 Ok(()) => {
384 self.permit_acquired = true;
385 Poll::Ready(Ok(()))
386 }
387 Err(wait_duration) => {
388 let sleep = tokio::time::sleep(wait_duration);
389 let mut pinned = Box::pin(sleep);
390 let _ = pinned.as_mut().poll(cx);
392 self.sleep = Some(pinned);
393 Poll::Pending
394 }
395 }
396 }
397
398 fn call(&mut self, req: Req) -> Self::Future {
399 if self.permit_acquired {
400 self.permit_acquired = false;
402 let config = Arc::clone(&self.config);
403 let mut inner = self.inner.clone();
404
405 let event = RateLimiterEvent::PermitAcquired {
406 pattern_name: config.name.clone(),
407 timestamp: Instant::now(),
408 wait_duration: std::time::Duration::ZERO,
409 };
410 config.event_listeners.emit(&event);
411
412 #[cfg(feature = "metrics")]
413 {
414 counter!("ratelimiter_calls_total", "ratelimiter" => config.name.clone(), "result" => "permitted").increment(1);
415 histogram!("ratelimiter_wait_duration_seconds", "ratelimiter" => config.name.clone())
416 .record(0.0);
417 }
418
419 #[cfg(feature = "tracing")]
420 debug!(ratelimiter = %config.name, "Permit acquired via backpressure");
421
422 return Box::pin(async move {
423 inner
424 .call(req)
425 .await
426 .map_err(RateLimiterServiceError::Inner)
427 });
428 }
429
430 let limiter = self.limiter.clone();
432 let config = Arc::clone(&self.config);
433 let mut inner = self.inner.clone();
434
435 Box::pin(async move {
436 match limiter.acquire().await {
437 Ok(wait_duration) => {
438 let event = RateLimiterEvent::PermitAcquired {
439 pattern_name: config.name.clone(),
440 timestamp: Instant::now(),
441 wait_duration,
442 };
443 config.event_listeners.emit(&event);
444
445 #[cfg(feature = "metrics")]
446 {
447 counter!("ratelimiter_calls_total", "ratelimiter" => config.name.clone(), "result" => "permitted").increment(1);
448 histogram!("ratelimiter_wait_duration_seconds", "ratelimiter" => config.name.clone())
449 .record(wait_duration.as_secs_f64());
450 }
451
452 #[cfg(feature = "tracing")]
453 {
454 if wait_duration.as_millis() > 0 {
455 debug!(
456 ratelimiter = %config.name,
457 wait_ms = wait_duration.as_millis(),
458 "Permit acquired after waiting"
459 );
460 } else {
461 debug!(ratelimiter = %config.name, "Permit acquired immediately");
462 }
463 }
464
465 inner
466 .call(req)
467 .await
468 .map_err(RateLimiterServiceError::Inner)
469 }
470 Err(()) => {
471 let event = RateLimiterEvent::PermitRejected {
472 pattern_name: config.name.clone(),
473 timestamp: Instant::now(),
474 timeout_duration: config.timeout_duration,
475 };
476 config.event_listeners.emit(&event);
477
478 #[cfg(feature = "metrics")]
479 {
480 counter!("ratelimiter_calls_total", "ratelimiter" => config.name.clone(), "result" => "rejected").increment(1);
481 }
482
483 #[cfg(feature = "tracing")]
484 warn!(
485 ratelimiter = %config.name,
486 timeout_ms = config.timeout_duration.as_millis(),
487 "Rate limit exceeded - permit rejected"
488 );
489
490 Err(RateLimiterServiceError::RateLimited)
491 }
492 }
493 })
494 }
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;
    use tower::service_fn;
    use tower::{Layer, Service, ServiceExt};

    /// Waits until `service` is ready (panicking if readiness fails), then sends `req` through it.
    async fn send<S>(service: &mut S, req: &str) -> Result<S::Response, S::Error>
    where
        S: Service<String>,
        S::Error: std::fmt::Debug,
    {
        let ready = service.ready().await.unwrap();
        ready.call(req.to_owned()).await
    }

    /// Sends twice the per-period limit (2) through a backpressure limiter that uses
    /// `window`, and checks that every call eventually succeeds.
    async fn assert_backpressure_with_window(window: WindowType) {
        let inner =
            service_fn(|_req: String| async move { Ok::<_, std::io::Error>("ok".to_string()) });

        let layer = RateLimiterLayer::builder()
            .limit_for_period(2)
            .refresh_period(Duration::from_millis(50))
            .window_type(window)
            .backpressure()
            .build();
        let mut service = layer.layer(inner);

        for _ in 0..4 {
            assert!(send(&mut service, "x").await.is_ok());
        }
    }

    #[tokio::test]
    async fn test_allows_requests_within_limit() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_service = Arc::clone(&calls);
        let inner = service_fn(move |req: String| {
            let calls = Arc::clone(&calls_in_service);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = RateLimiterLayer::builder()
            .limit_for_period(10)
            .refresh_period(Duration::from_secs(1))
            .timeout_duration(Duration::from_millis(100))
            .build();
        let mut service = layer.layer(inner);

        for _ in 0..10 {
            assert!(send(&mut service, "test").await.is_ok());
        }

        assert_eq!(calls.load(Ordering::SeqCst), 10);
    }

    #[tokio::test]
    async fn test_rejects_requests_over_limit() {
        let inner = service_fn(|req: String| async move {
            Ok::<_, std::io::Error>(format!("Response: {}", req))
        });

        let layer = RateLimiterLayer::builder()
            .limit_for_period(2)
            .refresh_period(Duration::from_secs(10))
            .timeout_duration(Duration::from_millis(10))
            .build();
        let mut service = layer.layer(inner);

        for req in ["1", "2"] {
            assert!(send(&mut service, req).await.is_ok());
        }

        // The third call exceeds the limit and cannot get a permit within the 10 ms timeout.
        let third = send(&mut service, "3").await;
        assert!(third.is_err());
        assert!(matches!(
            third.unwrap_err(),
            RateLimiterServiceError::RateLimited
        ));
    }

    #[tokio::test]
    async fn test_permits_refresh_after_period() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_service = Arc::clone(&calls);
        let inner = service_fn(move |_req: String| {
            let calls = Arc::clone(&calls_in_service);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>("ok".to_string())
            }
        });

        let layer = RateLimiterLayer::builder()
            .limit_for_period(2)
            .refresh_period(Duration::from_millis(100))
            .timeout_duration(Duration::from_millis(200))
            .build();
        let mut service = layer.layer(inner);

        for req in ["1", "2"] {
            assert!(send(&mut service, req).await.is_ok());
        }

        // Outlast the 100 ms refresh period before asking for another permit.
        tokio::time::sleep(Duration::from_millis(150)).await;

        assert!(send(&mut service, "3").await.is_ok());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_event_listeners_called() {
        let acquired = Arc::new(AtomicUsize::new(0));
        let rejected = Arc::new(AtomicUsize::new(0));
        let on_acquired = Arc::clone(&acquired);
        let on_rejected = Arc::clone(&rejected);

        let inner =
            service_fn(|_req: String| async move { Ok::<_, std::io::Error>("ok".to_string()) });

        let layer = RateLimiterLayer::builder()
            .limit_for_period(1)
            .refresh_period(Duration::from_secs(10))
            .timeout_duration(Duration::from_millis(10))
            .on_permit_acquired(move |_| {
                on_acquired.fetch_add(1, Ordering::SeqCst);
            })
            .on_permit_rejected(move |_| {
                on_rejected.fetch_add(1, Ordering::SeqCst);
            })
            .build();
        let mut service = layer.layer(inner);

        let _ = send(&mut service, "1").await;
        assert_eq!(acquired.load(Ordering::SeqCst), 1);

        let _ = send(&mut service, "2").await;
        assert_eq!(rejected.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_waits_for_permit_within_timeout() {
        let inner =
            service_fn(|_req: String| async move { Ok::<_, std::io::Error>("ok".to_string()) });

        let layer = RateLimiterLayer::builder()
            .limit_for_period(1)
            .refresh_period(Duration::from_millis(50))
            .timeout_duration(Duration::from_millis(100))
            .build();
        let mut service = layer.layer(inner);

        assert!(send(&mut service, "1").await.is_ok());

        let started = std::time::Instant::now();
        let second = send(&mut service, "2").await;
        let waited = started.elapsed();

        assert!(second.is_ok());
        assert!(waited >= Duration::from_millis(45));
    }

    #[tokio::test]
    async fn test_backpressure_allows_requests_within_limit() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_service = Arc::clone(&calls);
        let inner = service_fn(move |req: String| {
            let calls = Arc::clone(&calls_in_service);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = RateLimiterLayer::builder()
            .limit_for_period(10)
            .refresh_period(Duration::from_secs(1))
            .backpressure()
            .build();
        let mut service = layer.layer(inner);

        for _ in 0..10 {
            assert!(send(&mut service, "test").await.is_ok());
        }

        assert_eq!(calls.load(Ordering::SeqCst), 10);
    }

    #[tokio::test]
    async fn test_backpressure_waits_instead_of_rejecting() {
        let inner =
            service_fn(|_req: String| async move { Ok::<_, std::io::Error>("ok".to_string()) });

        let layer = RateLimiterLayer::builder()
            .limit_for_period(1)
            .refresh_period(Duration::from_millis(50))
            .backpressure()
            .build();
        let mut service = layer.layer(inner);

        assert!(send(&mut service, "1").await.is_ok());

        let started = std::time::Instant::now();
        let second = send(&mut service, "2").await;
        let waited = started.elapsed();

        assert!(second.is_ok());
        assert!(waited >= Duration::from_millis(40));
    }

    #[tokio::test]
    async fn test_backpressure_never_returns_rate_limited() {
        let inner =
            service_fn(|_req: String| async move { Ok::<_, std::io::Error>("ok".to_string()) });

        let layer = RateLimiterLayer::builder()
            .limit_for_period(1)
            .refresh_period(Duration::from_millis(50))
            .backpressure()
            .build();
        let mut service = layer.layer(inner);

        for _ in 0..5 {
            assert!(send(&mut service, "x").await.is_ok());
        }
    }

    #[tokio::test]
    async fn test_backpressure_events_fire_permit_acquired() {
        let acquired = Arc::new(AtomicUsize::new(0));
        let rejected = Arc::new(AtomicUsize::new(0));
        let on_acquired = Arc::clone(&acquired);
        let on_rejected = Arc::clone(&rejected);

        let inner =
            service_fn(|_req: String| async move { Ok::<_, std::io::Error>("ok".to_string()) });

        let layer = RateLimiterLayer::builder()
            .limit_for_period(1)
            .refresh_period(Duration::from_millis(50))
            .backpressure()
            .on_permit_acquired(move |_| {
                on_acquired.fetch_add(1, Ordering::SeqCst);
            })
            .on_permit_rejected(move |_| {
                on_rejected.fetch_add(1, Ordering::SeqCst);
            })
            .build();
        let mut service = layer.layer(inner);

        for _ in 0..3 {
            let _ = send(&mut service, "x").await;
        }

        assert_eq!(acquired.load(Ordering::SeqCst), 3);
        assert_eq!(rejected.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_backpressure_with_sliding_log() {
        assert_backpressure_with_window(WindowType::SlidingLog).await;
    }

    #[tokio::test]
    async fn test_backpressure_with_sliding_counter() {
        assert_backpressure_with_window(WindowType::SlidingCounter).await;
    }
}