warpdrive_proxy/middleware/
rate_limit.rs

//! Rate limiting middleware using the governor crate
//!
//! Implements per-IP rate limiting using the GCRA (Generic Cell Rate Algorithm).
//! Returns 429 Too Many Requests when limits are exceeded.
//!
//! # Trusted Sources
//!
//! Requests from trusted IP ranges (flagged by the `trusted_ranges` middleware) bypass
//! rate limiting entirely. For all other requests, the real client IP (as normalized by
//! `trusted_ranges`) is used as the rate-limit key instead of the socket IP.
11
12use async_trait::async_trait;
13use governor::{
14    Quota, RateLimiter,
15    clock::{Clock, DefaultClock},
16    state::keyed::DefaultKeyedStateStore,
17};
18use pingora::prelude::*;
19use std::net::IpAddr;
20use std::num::NonZeroU32;
21use std::sync::Arc;
22use tracing::{debug, warn};
23
24use super::{Middleware, MiddlewareContext};
25use crate::metrics::RATE_LIMIT_REJECTIONS;
26
27/// Rate limiting middleware
28///
29/// Uses the governor crate's GCRA implementation to enforce rate limits
30/// per client IP address. When a client exceeds their rate limit, a 429
31/// response is returned immediately without forwarding to the upstream.
32pub struct RateLimitMiddleware {
33    /// Governor rate limiter (keyed by IP address)
34    limiter: Arc<RateLimiter<IpAddr, DefaultKeyedStateStore<IpAddr>, DefaultClock>>,
35    /// Whether rate limiting is enabled
36    enabled: bool,
37}
38
39impl RateLimitMiddleware {
40    /// Create a new rate limiting middleware
41    ///
42    /// # Arguments
43    ///
44    /// * `enabled` - Whether to enable rate limiting
45    /// * `requests_per_sec` - Maximum requests per second per IP
46    /// * `burst_size` - Maximum burst size (tokens available immediately)
47    ///
48    /// # Example
49    ///
50    /// ```no_run
51    /// use warpdrive::middleware::RateLimitMiddleware;
52    ///
53    /// // Allow 10 req/s with burst of 20
54    /// let middleware = RateLimitMiddleware::new(true, 10, 20);
55    /// ```
56    pub fn new(enabled: bool, requests_per_sec: u32, burst_size: u32) -> Self {
57        // Convert to NonZeroU32 for governor
58        let rps = NonZeroU32::new(requests_per_sec.max(1)).unwrap();
59        let burst = NonZeroU32::new(burst_size.max(1)).unwrap();
60
61        // Create quota: burst_size tokens, refill at requests_per_sec rate
62        let quota = Quota::per_second(rps).allow_burst(burst);
63
64        // Create keyed rate limiter (one bucket per IP)
65        let limiter = RateLimiter::keyed(quota);
66
67        debug!(
68            "Rate limiting initialized: enabled={}, rps={}, burst={}",
69            enabled, requests_per_sec, burst_size
70        );
71
72        Self {
73            limiter: Arc::new(limiter),
74            enabled,
75        }
76    }
77}
78
79#[async_trait]
80impl Middleware for RateLimitMiddleware {
81    async fn request_filter(
82        &self,
83        session: &mut Session,
84        ctx: &mut MiddlewareContext,
85    ) -> Result<()> {
86        if !self.enabled {
87            return Ok(()); // Pass through
88        }
89
90        // Skip rate limiting for trusted sources (proxies/CDNs)
91        if ctx.trusted_source {
92            debug!("Skipping rate limit for trusted source");
93            return Ok(());
94        }
95
96        // Use real client IP from context (normalized by trusted_ranges middleware)
97        let client_ip = ctx.real_client_ip;
98
99        // Check rate limit using governor
100        match self.limiter.check_key(&client_ip) {
101            Ok(_) => {
102                // Request allowed
103                debug!("Rate limit check passed for {}", client_ip);
104                Ok(()) // Continue to next middleware
105            }
106            Err(negative) => {
107                // Rate limit exceeded
108                let retry_after = negative.wait_time_from(DefaultClock::default().now());
109                warn!(
110                    "Rate limit exceeded for {}: retry after {:?}",
111                    client_ip, retry_after
112                );
113
114                // Record rate limit rejection metric
115                RATE_LIMIT_REJECTIONS
116                    .with_label_values(&[&client_ip.to_string()])
117                    .inc();
118
119                // Send 429 Too Many Requests using Pingora's respond_error
120                session.respond_error(429).await?;
121
122                // Return error to stop further processing
123                // Pingora interprets this as "response already sent"
124                Err(Error::explain(
125                    ErrorType::HTTPStatus(429),
126                    format!("Rate limit exceeded for {}", client_ip),
127                ))
128            }
129        }
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an address from the TEST-NET-1 documentation range (192.0.2.0/24).
    fn test_ip(last_octet: u8) -> IpAddr {
        IpAddr::from([192, 0, 2, last_octet])
    }

    #[test]
    fn test_rate_limit_middleware_creation() {
        let middleware = RateLimitMiddleware::new(true, 10, 20);
        assert!(middleware.enabled);
    }

    #[test]
    fn test_rate_limit_middleware_disabled() {
        let middleware = RateLimitMiddleware::new(false, 10, 20);
        assert!(!middleware.enabled);
    }

    #[test]
    fn test_quota_creation() {
        // Zero values are clamped to 1, so exactly one request fits in the burst.
        let clamped = RateLimitMiddleware::new(true, 0, 0);
        let ip = test_ip(1);
        assert!(clamped.limiter.check_key(&ip).is_ok());
        assert!(
            clamped.limiter.check_key(&ip).is_err(),
            "burst clamped to 1 must reject an immediate second request"
        );

        // High values: the whole burst is available immediately.
        let generous = RateLimitMiddleware::new(true, 1000, 5000);
        let ip = test_ip(2);
        assert!((0..5000).all(|_| generous.limiter.check_key(&ip).is_ok()));
    }

    #[test]
    fn test_burst_exhaustion_is_per_ip() {
        let middleware = RateLimitMiddleware::new(true, 1, 2);
        let noisy = test_ip(10);
        let quiet = test_ip(11);

        assert!(middleware.limiter.check_key(&noisy).is_ok());
        assert!(middleware.limiter.check_key(&noisy).is_ok());
        assert!(
            middleware.limiter.check_key(&noisy).is_err(),
            "third immediate request must exceed a burst of 2"
        );

        // A different client has its own bucket and is unaffected.
        assert!(middleware.limiter.check_key(&quiet).is_ok());
    }
}
155}