// tower_resilience_ratelimiter/lib.rs
1//! Advanced rate limiting middleware for Tower services.
2//!
3//! This crate provides enhanced rate limiting inspired by Resilience4j's RateLimiter,
4//! with features beyond Tower's built-in rate limiting.
5//!
6//! # Features
7//!
8//! - **Permit-based rate limiting**: Control requests per time period
9//! - **Multiple window types**: Fixed, sliding log, and sliding counter algorithms
10//! - **Configurable timeout**: Wait up to a specified duration for permits
11//! - **Automatic refresh**: Permits automatically refresh after each period
12//! - **Event system**: Observability through rate limiter events
13//!
14//! # Window Types
15//!
16//! The rate limiter supports three different windowing strategies:
17//!
18//! - **Fixed** (default): Resets permits at fixed intervals. Simple and efficient
19//!   but can allow bursts at window boundaries.
20//!
21//! - **SlidingLog**: Stores timestamps of each request. Provides precise rate limiting
22//!   with no burst allowance, but uses O(n) memory where n = requests in window.
23//!
24//! - **SlidingCounter**: Uses weighted averaging between time buckets. Approximate
25//!   sliding window behavior with O(1) memory - ideal for high-throughput APIs.
26//!
27//! # Examples
28//!
29//! ## Basic Rate Limiting (Fixed Window)
30//!
31//! ```
32//! use tower_resilience_ratelimiter::RateLimiterLayer;
33//! use tower::ServiceBuilder;
34//! use std::time::Duration;
35//!
36//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
37//! // Allow 100 requests per second, wait up to 500ms for a permit
38//! let rate_limiter = RateLimiterLayer::builder()
39//!     .limit_for_period(100)
40//!     .refresh_period(Duration::from_secs(1))
41//!     .timeout_duration(Duration::from_millis(500))
42//!     .on_permit_acquired(|wait_duration| {
43//!         println!("Permit acquired after {:?}", wait_duration);
44//!     })
45//!     .on_permit_rejected(|timeout| {
46//!         println!("Rate limited! Timeout: {:?}", timeout);
47//!     })
48//!     .build();
49//!
50//! // Apply to a service
51//! let service = ServiceBuilder::new()
52//!     .layer(rate_limiter)
53//!     .service(tower::service_fn(|req: String| async move {
54//!         Ok::<_, std::io::Error>(format!("Response: {}", req))
55//!     }));
56//! # Ok(())
57//! # }
58//! ```
59//!
60//! ## Sliding Log Rate Limiting (Precise)
61//!
62//! Use sliding log for precise rate limiting with no burst allowance at window
63//! boundaries. This is ideal when you need to strictly enforce rate limits,
64//! such as when calling external APIs with strict quotas.
65//!
66//! ```
67//! use tower_resilience_ratelimiter::{RateLimiterLayer, WindowType};
68//! use tower::ServiceBuilder;
69//! use std::time::Duration;
70//!
71//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
72//! let rate_limiter = RateLimiterLayer::builder()
73//!     .limit_for_period(100)
74//!     .refresh_period(Duration::from_secs(1))
75//!     .window_type(WindowType::SlidingLog)
76//!     .timeout_duration(Duration::from_millis(500))
77//!     .build();
78//!
79//! let service = ServiceBuilder::new()
80//!     .layer(rate_limiter)
81//!     .service(tower::service_fn(|req: String| async move {
82//!         Ok::<_, std::io::Error>(format!("Response: {}", req))
83//!     }));
84//! # Ok(())
85//! # }
86//! ```
87//!
88//! ## Sliding Counter Rate Limiting (Efficient)
89//!
90//! Use sliding counter for high-throughput APIs where you want approximate
91//! sliding window behavior without the memory overhead of storing timestamps.
92//!
93//! ```
94//! use tower_resilience_ratelimiter::{RateLimiterLayer, WindowType};
95//! use tower::ServiceBuilder;
96//! use std::time::Duration;
97//!
98//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
99//! let rate_limiter = RateLimiterLayer::builder()
100//!     .limit_for_period(10000)  // High throughput
101//!     .refresh_period(Duration::from_secs(1))
102//!     .window_type(WindowType::SlidingCounter)
103//!     .timeout_duration(Duration::from_millis(100))
104//!     .build();
105//!
106//! let service = ServiceBuilder::new()
107//!     .layer(rate_limiter)
108//!     .service(tower::service_fn(|req: String| async move {
109//!         Ok::<_, std::io::Error>(format!("Response: {}", req))
110//!     }));
111//! # Ok(())
112//! # }
113//! ```
114//!
115//! ## Fallback When Rate Limited
116//!
117//! Handle rate limiting errors with appropriate fallback strategies:
118//!
119//! ### Return Informative Error
120//!
121//! ```
122//! use tower_resilience_ratelimiter::{RateLimiterLayer, RateLimiterError};
123//! use tower::{Service, ServiceBuilder, ServiceExt};
124//! use std::time::Duration;
125//!
126//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
127//! let rate_limiter = RateLimiterLayer::builder()
128//!     .limit_for_period(10)
129//!     .refresh_period(Duration::from_secs(1))
130//!     .timeout_duration(Duration::from_millis(100))
131//!     .build();
132//!
133//! let mut service = ServiceBuilder::new()
134//!     .layer(rate_limiter)
135//!     .service(tower::service_fn(|req: String| async move {
136//!         Ok::<String, std::io::Error>(format!("Processed: {}", req))
137//!     }));
138//!
139//! match service.ready().await?.call("request".to_string()).await {
140//!     Ok(response) => println!("Success: {}", response),
//!     Err(_) => {
142//!         println!("Rate limited - please try again later");
143//!         // Could return 429 Too Many Requests in HTTP context
144//!     }
145//! }
146//! # Ok(())
147//! # }
148//! ```
149//!
150//! ### Queue for Later Processing
151//!
152//! ```
153//! use tower_resilience_ratelimiter::{RateLimiterLayer, RateLimiterError};
154//! use tower::{Service, ServiceBuilder, ServiceExt};
155//! use std::time::Duration;
156//! use std::sync::Arc;
157//! use tokio::sync::Mutex;
158//!
159//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
160//! let queue = Arc::new(Mutex::new(Vec::new()));
161//! let rate_limiter = RateLimiterLayer::builder()
162//!     .limit_for_period(10)
163//!     .refresh_period(Duration::from_secs(1))
164//!     .timeout_duration(Duration::from_millis(50))
165//!     .build();
166//!
167//! let mut service = ServiceBuilder::new()
168//!     .layer(rate_limiter)
169//!     .service(tower::service_fn(|req: String| async move {
170//!         Ok::<String, std::io::Error>(req)
171//!     }));
172//!
173//! let queue_clone = Arc::clone(&queue);
174//! let result: Result<String, std::io::Error> = match service.ready().await?.call("request".to_string()).await {
175//!     Ok(response) => Ok(response),
176//!     Err(_) => {
177//!         // Queue request for later processing
178//!         queue_clone.lock().await.push("request".to_string());
179//!         Ok("Queued for processing".to_string())
180//!     }
181//! };
182//! # Ok(())
183//! # }
184//! ```
185//!
186//! ### Shed Load Gracefully
187//!
188//! ```
189//! use tower_resilience_ratelimiter::{RateLimiterLayer, RateLimiterError};
190//! use tower::{Service, ServiceBuilder, ServiceExt};
191//! use std::time::Duration;
192//!
193//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
194//! let rate_limiter = RateLimiterLayer::builder()
195//!     .limit_for_period(100)
196//!     .refresh_period(Duration::from_secs(1))
197//!     .timeout_duration(Duration::from_millis(10)) // Short timeout = fast rejection
198//!     .build();
199//!
200//! let mut service = ServiceBuilder::new()
201//!     .layer(rate_limiter)
202//!     .service(tower::service_fn(|req: String| async move {
203//!         Ok::<String, std::io::Error>(req)
204//!     }));
205//!
206//! let result = service.ready().await?.call("request".to_string()).await
207//!     .unwrap_or_else(|_| {
208//!         // Shed load - return reduced functionality response
209//!         "Service at capacity - showing cached data".to_string()
210//!     });
211//! # Ok(())
212//! # }
213//! ```
214
215mod config;
216mod error;
217mod events;
218mod handle;
219mod layer;
220mod limiter;
221
222pub use config::{RateLimiterConfig, RateLimiterConfigBuilder, WindowType};
223pub use error::{RateLimiterError, RateLimiterServiceError};
224pub use events::RateLimiterEvent;
225pub use handle::RateLimiterHandle;
226pub use layer::RateLimiterLayer;
227
228use crate::limiter::SharedRateLimiter;
229use futures::future::BoxFuture;
230use futures::Future;
231use std::pin::Pin;
232use std::sync::Arc;
233use std::task::{Context, Poll};
234use std::time::Instant;
235use tower::Service;
236
237#[cfg(feature = "metrics")]
238use metrics::{counter, describe_counter, describe_histogram, histogram};
239
240#[cfg(feature = "tracing")]
241use tracing::{debug, warn};
242
/// A Tower [`Service`] that applies rate limiting.
///
/// This service wraps an inner service and limits the rate at which
/// requests can be processed according to the configured policy.
///
/// # Backpressure mode
///
/// By default, the rate limiter applies limits in `call()` and returns
/// `RateLimiterServiceError::RateLimited` when permits are exhausted (rejection mode).
///
/// When [backpressure mode](RateLimiterConfigBuilder::backpressure) is enabled,
/// limits are applied in `poll_ready()` instead: the service returns `Poll::Pending`
/// when no permits are available, causing callers to wait naturally via
/// `service.ready().await`. This integrates with Tower's load balancing and buffer
/// layers. In this mode, `RateLimiterServiceError::RateLimited` is never returned
/// and `timeout_duration` is ignored.
pub struct RateLimiter<S> {
    /// The wrapped inner service.
    inner: S,
    /// Rate limiter configuration; `Clone` shares it via `Arc::clone`.
    config: Arc<RateLimiterConfig>,
    /// Permit-accounting state. Cloning this handle (see `Clone` impl and
    /// `from_shared`) presumably shares the underlying state between service
    /// clones — confirm against `SharedRateLimiter`'s implementation.
    limiter: SharedRateLimiter,
    /// Sleep future for backpressure mode wake-ups.
    sleep: Option<Pin<Box<tokio::time::Sleep>>>,
    /// Whether a permit has been acquired in `poll_ready` (backpressure mode only).
    permit_acquired: bool,
}
268
269impl<S> RateLimiter<S> {
270    /// Creates a new `RateLimiter` wrapping the given service.
271    pub fn new(inner: S, config: Arc<RateLimiterConfig>) -> Self {
272        #[cfg(feature = "metrics")]
273        {
274            describe_counter!(
275                "ratelimiter_calls_total",
276                "Total number of rate limiter calls (permitted or rejected)"
277            );
278            describe_histogram!(
279                "ratelimiter_wait_duration_seconds",
280                "Time spent waiting for a permit"
281            );
282        }
283
284        let limiter = SharedRateLimiter::new(
285            config.window_type,
286            config.limit_for_period,
287            config.refresh_period,
288            config.timeout_duration,
289        );
290
291        Self {
292            inner,
293            config,
294            limiter,
295            sleep: None,
296            permit_acquired: false,
297        }
298    }
299
300    /// Creates a new `RateLimiter` using pre-created shared limiter state.
301    pub(crate) fn from_shared(
302        inner: S,
303        config: Arc<RateLimiterConfig>,
304        limiter: SharedRateLimiter,
305    ) -> Self {
306        #[cfg(feature = "metrics")]
307        {
308            describe_counter!(
309                "ratelimiter_calls_total",
310                "Total number of rate limiter calls (permitted or rejected)"
311            );
312            describe_histogram!(
313                "ratelimiter_wait_duration_seconds",
314                "Time spent waiting for a permit"
315            );
316        }
317
318        Self {
319            inner,
320            config,
321            limiter,
322            sleep: None,
323            permit_acquired: false,
324        }
325    }
326}
327
// Manual `Clone` (instead of `#[derive(Clone)]`): the per-task poll state
// must NOT be copied. `tokio::time::Sleep` is not `Clone`, and a clone that
// inherited `permit_acquired = true` would consume a permit it never took.
impl<S> Clone for RateLimiter<S>
where
    S: Clone,
{
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            // Configuration is shared, not duplicated.
            config: Arc::clone(&self.config),
            // Handle clone; presumably shares limiter state with the
            // original — confirm against `SharedRateLimiter::clone`.
            limiter: self.limiter.clone(),
            // Fresh poll state: no pending back-off timer, no held permit.
            sleep: None,
            permit_acquired: false,
        }
    }
}
342
343impl<S, Req> Service<Req> for RateLimiter<S>
344where
345    S: Service<Req> + Clone + Send + 'static,
346    S::Future: Send + 'static,
347    S::Error: Send + 'static,
348    Req: Send + 'static,
349{
350    type Response = S::Response;
351    type Error = RateLimiterServiceError<S::Error>;
352    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
353
354    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
355        // Check inner service readiness first
356        match self.inner.poll_ready(cx) {
357            Poll::Pending => return Poll::Pending,
358            Poll::Ready(Err(e)) => return Poll::Ready(Err(RateLimiterServiceError::Inner(e))),
359            Poll::Ready(Ok(())) => {}
360        }
361
362        if !self.config.backpressure {
363            return Poll::Ready(Ok(()));
364        }
365
366        // Backpressure mode: acquire permit in poll_ready
367        if self.permit_acquired {
368            return Poll::Ready(Ok(()));
369        }
370
371        // If we have a pending sleep, poll it first
372        if let Some(sleep) = self.sleep.as_mut() {
373            match sleep.as_mut().poll(cx) {
374                Poll::Pending => return Poll::Pending,
375                Poll::Ready(()) => {
376                    self.sleep = None;
377                    // Fall through to retry acquire
378                }
379            }
380        }
381
382        match self.limiter.try_acquire_now() {
383            Ok(()) => {
384                self.permit_acquired = true;
385                Poll::Ready(Ok(()))
386            }
387            Err(wait_duration) => {
388                let sleep = tokio::time::sleep(wait_duration);
389                let mut pinned = Box::pin(sleep);
390                // Register the waker so we get polled again when the sleep completes
391                let _ = pinned.as_mut().poll(cx);
392                self.sleep = Some(pinned);
393                Poll::Pending
394            }
395        }
396    }
397
398    fn call(&mut self, req: Req) -> Self::Future {
399        if self.permit_acquired {
400            // Backpressure mode: permit already acquired in poll_ready
401            self.permit_acquired = false;
402            let config = Arc::clone(&self.config);
403            let mut inner = self.inner.clone();
404
405            let event = RateLimiterEvent::PermitAcquired {
406                pattern_name: config.name.clone(),
407                timestamp: Instant::now(),
408                wait_duration: std::time::Duration::ZERO,
409            };
410            config.event_listeners.emit(&event);
411
412            #[cfg(feature = "metrics")]
413            {
414                counter!("ratelimiter_calls_total", "ratelimiter" => config.name.clone(), "result" => "permitted").increment(1);
415                histogram!("ratelimiter_wait_duration_seconds", "ratelimiter" => config.name.clone())
416                    .record(0.0);
417            }
418
419            #[cfg(feature = "tracing")]
420            debug!(ratelimiter = %config.name, "Permit acquired via backpressure");
421
422            return Box::pin(async move {
423                inner
424                    .call(req)
425                    .await
426                    .map_err(RateLimiterServiceError::Inner)
427            });
428        }
429
430        // Rejection mode: acquire permit in call
431        let limiter = self.limiter.clone();
432        let config = Arc::clone(&self.config);
433        let mut inner = self.inner.clone();
434
435        Box::pin(async move {
436            match limiter.acquire().await {
437                Ok(wait_duration) => {
438                    let event = RateLimiterEvent::PermitAcquired {
439                        pattern_name: config.name.clone(),
440                        timestamp: Instant::now(),
441                        wait_duration,
442                    };
443                    config.event_listeners.emit(&event);
444
445                    #[cfg(feature = "metrics")]
446                    {
447                        counter!("ratelimiter_calls_total", "ratelimiter" => config.name.clone(), "result" => "permitted").increment(1);
448                        histogram!("ratelimiter_wait_duration_seconds", "ratelimiter" => config.name.clone())
449                            .record(wait_duration.as_secs_f64());
450                    }
451
452                    #[cfg(feature = "tracing")]
453                    {
454                        if wait_duration.as_millis() > 0 {
455                            debug!(
456                                ratelimiter = %config.name,
457                                wait_ms = wait_duration.as_millis(),
458                                "Permit acquired after waiting"
459                            );
460                        } else {
461                            debug!(ratelimiter = %config.name, "Permit acquired immediately");
462                        }
463                    }
464
465                    inner
466                        .call(req)
467                        .await
468                        .map_err(RateLimiterServiceError::Inner)
469                }
470                Err(()) => {
471                    let event = RateLimiterEvent::PermitRejected {
472                        pattern_name: config.name.clone(),
473                        timestamp: Instant::now(),
474                        timeout_duration: config.timeout_duration,
475                    };
476                    config.event_listeners.emit(&event);
477
478                    #[cfg(feature = "metrics")]
479                    {
480                        counter!("ratelimiter_calls_total", "ratelimiter" => config.name.clone(), "result" => "rejected").increment(1);
481                    }
482
483                    #[cfg(feature = "tracing")]
484                    warn!(
485                        ratelimiter = %config.name,
486                        timeout_ms = config.timeout_duration.as_millis(),
487                        "Rate limit exceeded - permit rejected"
488                    );
489
490                    Err(RateLimiterServiceError::RateLimited)
491                }
492            }
493        })
494    }
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;
    use tower::service_fn;
    use tower::{Layer, ServiceExt};

    // Rejection mode: all requests within the permit budget succeed and
    // reach the inner service exactly once each.
    #[tokio::test]
    async fn test_allows_requests_within_limit() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = RateLimiterLayer::builder()
            .limit_for_period(10)
            .refresh_period(Duration::from_secs(1))
            .timeout_duration(Duration::from_millis(100))
            .build();

        let mut service = layer.layer(service);

        // Should be able to make 10 requests
        for _ in 0..10 {
            let result = service
                .ready()
                .await
                .unwrap()
                .call("test".to_string())
                .await;
            assert!(result.is_ok());
        }

        assert_eq!(call_count.load(Ordering::SeqCst), 10);
    }

    // Rejection mode: once permits are exhausted and the timeout is shorter
    // than the refresh period, calls fail with `RateLimited`.
    #[tokio::test]
    async fn test_rejects_requests_over_limit() {
        let service = service_fn(|req: String| async move {
            Ok::<_, std::io::Error>(format!("Response: {}", req))
        });

        // 10s refresh vs 10ms timeout guarantees the third call cannot wait
        // long enough for a refresh.
        let layer = RateLimiterLayer::builder()
            .limit_for_period(2)
            .refresh_period(Duration::from_secs(10))
            .timeout_duration(Duration::from_millis(10))
            .build();

        let mut service = layer.layer(service);

        // First 2 should succeed
        assert!(service
            .ready()
            .await
            .unwrap()
            .call("1".to_string())
            .await
            .is_ok());
        assert!(service
            .ready()
            .await
            .unwrap()
            .call("2".to_string())
            .await
            .is_ok());

        // Third should be rate limited
        let result = service.ready().await.unwrap().call("3".to_string()).await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            RateLimiterServiceError::RateLimited
        ));
    }

    // Permits become available again after the refresh period elapses.
    // Timing-sensitive: sleeps 150ms past a 100ms refresh period.
    #[tokio::test]
    async fn test_permits_refresh_after_period() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |_req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>("ok".to_string())
            }
        });

        let layer = RateLimiterLayer::builder()
            .limit_for_period(2)
            .refresh_period(Duration::from_millis(100))
            .timeout_duration(Duration::from_millis(200))
            .build();

        let mut service = layer.layer(service);

        // Use up permits
        assert!(service
            .ready()
            .await
            .unwrap()
            .call("1".to_string())
            .await
            .is_ok());
        assert!(service
            .ready()
            .await
            .unwrap()
            .call("2".to_string())
            .await
            .is_ok());

        // Wait for refresh
        tokio::time::sleep(Duration::from_millis(150)).await;

        // Should be able to make requests again
        assert!(service
            .ready()
            .await
            .unwrap()
            .call("3".to_string())
            .await
            .is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    // Both the acquired and rejected event callbacks fire exactly once for
    // a permitted and a rejected call respectively.
    #[tokio::test]
    async fn test_event_listeners_called() {
        let acquired_count = Arc::new(AtomicUsize::new(0));
        let rejected_count = Arc::new(AtomicUsize::new(0));

        let ac = Arc::clone(&acquired_count);
        let rc = Arc::clone(&rejected_count);

        let service =
            service_fn(|_req: String| async move { Ok::<_, std::io::Error>("ok".to_string()) });

        let layer = RateLimiterLayer::builder()
            .limit_for_period(1)
            .refresh_period(Duration::from_secs(10))
            .timeout_duration(Duration::from_millis(10))
            .on_permit_acquired(move |_| {
                ac.fetch_add(1, Ordering::SeqCst);
            })
            .on_permit_rejected(move |_| {
                rc.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        let mut service = layer.layer(service);

        // First request should succeed
        let _ = service.ready().await.unwrap().call("1".to_string()).await;
        assert_eq!(acquired_count.load(Ordering::SeqCst), 1);

        // Second should be rejected
        let _ = service.ready().await.unwrap().call("2".to_string()).await;
        assert_eq!(rejected_count.load(Ordering::SeqCst), 1);
    }

    // Rejection mode: a timeout longer than the refresh period lets a call
    // wait through one refresh and still succeed. Timing-sensitive: asserts
    // a lower bound of ~45ms elapsed (50ms refresh with tolerance).
    #[tokio::test]
    async fn test_waits_for_permit_within_timeout() {
        let service =
            service_fn(|_req: String| async move { Ok::<_, std::io::Error>("ok".to_string()) });

        let layer = RateLimiterLayer::builder()
            .limit_for_period(1)
            .refresh_period(Duration::from_millis(50))
            .timeout_duration(Duration::from_millis(100)) // Can wait through one refresh
            .build();

        let mut service = layer.layer(service);

        // First request succeeds
        assert!(service
            .ready()
            .await
            .unwrap()
            .call("1".to_string())
            .await
            .is_ok());

        // Second request should wait for refresh and succeed
        let start = std::time::Instant::now();
        let result = service.ready().await.unwrap().call("2".to_string()).await;
        let elapsed = start.elapsed();

        assert!(result.is_ok());
        assert!(elapsed >= Duration::from_millis(45)); // Should have waited
    }

    // ==================== Backpressure Mode Tests ====================

    // Backpressure mode: within-limit requests flow through unchanged.
    #[tokio::test]
    async fn test_backpressure_allows_requests_within_limit() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);

        let service = service_fn(move |req: String| {
            let cc = Arc::clone(&cc);
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<_, std::io::Error>(format!("Response: {}", req))
            }
        });

        let layer = RateLimiterLayer::builder()
            .limit_for_period(10)
            .refresh_period(Duration::from_secs(1))
            .backpressure()
            .build();

        let mut service = layer.layer(service);

        for _ in 0..10 {
            let result = service
                .ready()
                .await
                .unwrap()
                .call("test".to_string())
                .await;
            assert!(result.is_ok());
        }

        assert_eq!(call_count.load(Ordering::SeqCst), 10);
    }

    // Backpressure mode: an over-limit request blocks in `ready().await`
    // instead of erroring, then succeeds after the refresh. Timing-sensitive.
    #[tokio::test]
    async fn test_backpressure_waits_instead_of_rejecting() {
        let service =
            service_fn(|_req: String| async move { Ok::<_, std::io::Error>("ok".to_string()) });

        let layer = RateLimiterLayer::builder()
            .limit_for_period(1)
            .refresh_period(Duration::from_millis(50))
            .backpressure()
            .build();

        let mut service = layer.layer(service);

        // First request succeeds immediately
        assert!(service
            .ready()
            .await
            .unwrap()
            .call("1".to_string())
            .await
            .is_ok());

        // Second request should wait (not error) and eventually succeed
        let start = std::time::Instant::now();
        let result = service.ready().await.unwrap().call("2".to_string()).await;
        let elapsed = start.elapsed();

        assert!(result.is_ok());
        assert!(elapsed >= Duration::from_millis(40));
    }

    // Backpressure mode: the `RateLimited` error path is never taken, even
    // when the request rate exceeds the permit rate.
    #[tokio::test]
    async fn test_backpressure_never_returns_rate_limited() {
        let service =
            service_fn(|_req: String| async move { Ok::<_, std::io::Error>("ok".to_string()) });

        let layer = RateLimiterLayer::builder()
            .limit_for_period(1)
            .refresh_period(Duration::from_millis(50))
            .backpressure()
            .build();

        let mut service = layer.layer(service);

        // Make several requests; none should return RateLimited
        for _ in 0..5 {
            let result = service.ready().await.unwrap().call("x".to_string()).await;
            assert!(result.is_ok());
        }
    }

    // Backpressure mode: every call emits PermitAcquired and never
    // PermitRejected.
    #[tokio::test]
    async fn test_backpressure_events_fire_permit_acquired() {
        let acquired_count = Arc::new(AtomicUsize::new(0));
        let rejected_count = Arc::new(AtomicUsize::new(0));

        let ac = Arc::clone(&acquired_count);
        let rc = Arc::clone(&rejected_count);

        let service =
            service_fn(|_req: String| async move { Ok::<_, std::io::Error>("ok".to_string()) });

        let layer = RateLimiterLayer::builder()
            .limit_for_period(1)
            .refresh_period(Duration::from_millis(50))
            .backpressure()
            .on_permit_acquired(move |_| {
                ac.fetch_add(1, Ordering::SeqCst);
            })
            .on_permit_rejected(move |_| {
                rc.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        let mut service = layer.layer(service);

        for _ in 0..3 {
            let _ = service.ready().await.unwrap().call("x".to_string()).await;
        }

        assert_eq!(acquired_count.load(Ordering::SeqCst), 3);
        assert_eq!(rejected_count.load(Ordering::SeqCst), 0);
    }

    // Backpressure mode composes with the SlidingLog window type: all calls
    // eventually succeed (waiting as needed) with no errors.
    #[tokio::test]
    async fn test_backpressure_with_sliding_log() {
        let service =
            service_fn(|_req: String| async move { Ok::<_, std::io::Error>("ok".to_string()) });

        let layer = RateLimiterLayer::builder()
            .limit_for_period(2)
            .refresh_period(Duration::from_millis(50))
            .window_type(WindowType::SlidingLog)
            .backpressure()
            .build();

        let mut service = layer.layer(service);

        for _ in 0..4 {
            let result = service.ready().await.unwrap().call("x".to_string()).await;
            assert!(result.is_ok());
        }
    }

    // Backpressure mode composes with the SlidingCounter window type.
    #[tokio::test]
    async fn test_backpressure_with_sliding_counter() {
        let service =
            service_fn(|_req: String| async move { Ok::<_, std::io::Error>("ok".to_string()) });

        let layer = RateLimiterLayer::builder()
            .limit_for_period(2)
            .refresh_period(Duration::from_millis(50))
            .window_type(WindowType::SlidingCounter)
            .backpressure()
            .build();

        let mut service = layer.layer(service);

        for _ in 0..4 {
            let result = service.ready().await.unwrap().call("x".to_string()).await;
            assert!(result.is_ok());
        }
    }
}