// warpdrive_proxy/middleware/rate_limit.rs
//! Rate limiting middleware using the governor crate.
//!
//! Implements per-IP rate limiting with GCRA (the Generic Cell Rate Algorithm).
//! Returns `429 Too Many Requests` when a client exceeds its limit.
//!
//! # Trusted Sources
//!
//! Requests from trusted IP ranges (as flagged by the `trusted_ranges` middleware)
//! bypass rate limiting entirely. For all other requests, limits are keyed on the
//! real client IP resolved by `trusted_ranges`, not on the socket peer IP.

12use async_trait::async_trait;
13use governor::{
14 Quota, RateLimiter,
15 clock::{Clock, DefaultClock},
16 state::keyed::DefaultKeyedStateStore,
17};
18use pingora::prelude::*;
19use std::net::IpAddr;
20use std::num::NonZeroU32;
21use std::sync::Arc;
22use tracing::{debug, warn};
23
24use super::{Middleware, MiddlewareContext};
25use crate::metrics::RATE_LIMIT_REJECTIONS;
26
27/// Rate limiting middleware
28///
29/// Uses the governor crate's GCRA implementation to enforce rate limits
30/// per client IP address. When a client exceeds their rate limit, a 429
31/// response is returned immediately without forwarding to the upstream.
32pub struct RateLimitMiddleware {
33 /// Governor rate limiter (keyed by IP address)
34 limiter: Arc<RateLimiter<IpAddr, DefaultKeyedStateStore<IpAddr>, DefaultClock>>,
35 /// Whether rate limiting is enabled
36 enabled: bool,
37}
38
39impl RateLimitMiddleware {
40 /// Create a new rate limiting middleware
41 ///
42 /// # Arguments
43 ///
44 /// * `enabled` - Whether to enable rate limiting
45 /// * `requests_per_sec` - Maximum requests per second per IP
46 /// * `burst_size` - Maximum burst size (tokens available immediately)
47 ///
48 /// # Example
49 ///
50 /// ```no_run
51 /// use warpdrive::middleware::RateLimitMiddleware;
52 ///
53 /// // Allow 10 req/s with burst of 20
54 /// let middleware = RateLimitMiddleware::new(true, 10, 20);
55 /// ```
56 pub fn new(enabled: bool, requests_per_sec: u32, burst_size: u32) -> Self {
57 // Convert to NonZeroU32 for governor
58 let rps = NonZeroU32::new(requests_per_sec.max(1)).unwrap();
59 let burst = NonZeroU32::new(burst_size.max(1)).unwrap();
60
61 // Create quota: burst_size tokens, refill at requests_per_sec rate
62 let quota = Quota::per_second(rps).allow_burst(burst);
63
64 // Create keyed rate limiter (one bucket per IP)
65 let limiter = RateLimiter::keyed(quota);
66
67 debug!(
68 "Rate limiting initialized: enabled={}, rps={}, burst={}",
69 enabled, requests_per_sec, burst_size
70 );
71
72 Self {
73 limiter: Arc::new(limiter),
74 enabled,
75 }
76 }
77}
78
79#[async_trait]
80impl Middleware for RateLimitMiddleware {
81 async fn request_filter(
82 &self,
83 session: &mut Session,
84 ctx: &mut MiddlewareContext,
85 ) -> Result<()> {
86 if !self.enabled {
87 return Ok(()); // Pass through
88 }
89
90 // Skip rate limiting for trusted sources (proxies/CDNs)
91 if ctx.trusted_source {
92 debug!("Skipping rate limit for trusted source");
93 return Ok(());
94 }
95
96 // Use real client IP from context (normalized by trusted_ranges middleware)
97 let client_ip = ctx.real_client_ip;
98
99 // Check rate limit using governor
100 match self.limiter.check_key(&client_ip) {
101 Ok(_) => {
102 // Request allowed
103 debug!("Rate limit check passed for {}", client_ip);
104 Ok(()) // Continue to next middleware
105 }
106 Err(negative) => {
107 // Rate limit exceeded
108 let retry_after = negative.wait_time_from(DefaultClock::default().now());
109 warn!(
110 "Rate limit exceeded for {}: retry after {:?}",
111 client_ip, retry_after
112 );
113
114 // Record rate limit rejection metric
115 RATE_LIMIT_REJECTIONS
116 .with_label_values(&[&client_ip.to_string()])
117 .inc();
118
119 // Send 429 Too Many Requests using Pingora's respond_error
120 session.respond_error(429).await?;
121
122 // Return error to stop further processing
123 // Pingora interprets this as "response already sent"
124 Err(Error::explain(
125 ErrorType::HTTPStatus(429),
126 format!("Rate limit exceeded for {}", client_ip),
127 ))
128 }
129 }
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Build a distinct IPv4 test address.
    fn ip(last: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(10, 0, 0, last))
    }

    #[test]
    fn test_rate_limit_middleware_creation() {
        let middleware = RateLimitMiddleware::new(true, 10, 20);
        assert!(middleware.enabled);
    }

    #[test]
    fn test_rate_limit_middleware_disabled() {
        let middleware = RateLimitMiddleware::new(false, 10, 20);
        assert!(!middleware.enabled);
    }

    #[test]
    fn test_quota_creation() {
        // Zero values are clamped to 1, so exactly one immediate request is admitted.
        let m1 = RateLimitMiddleware::new(true, 0, 0);
        let client = ip(1);
        assert!(m1.limiter.check_key(&client).is_ok());
        assert!(m1.limiter.check_key(&client).is_err());

        // High values are accepted and admit a large immediate burst.
        let m2 = RateLimitMiddleware::new(true, 1000, 5000);
        for _ in 0..100 {
            assert!(m2.limiter.check_key(&client).is_ok());
        }
    }

    #[test]
    fn test_burst_is_enforced_per_ip() {
        let middleware = RateLimitMiddleware::new(true, 1, 3);
        let (a, b) = (ip(1), ip(2));

        // The full burst is admitted, then the next immediate request is rejected.
        for _ in 0..3 {
            assert!(middleware.limiter.check_key(&a).is_ok());
        }
        assert!(middleware.limiter.check_key(&a).is_err());

        // A different client has its own, untouched bucket.
        assert!(middleware.limiter.check_key(&b).is_ok());
    }
}