warp_rate_limit/
lib.rs

1#![forbid(unsafe_code)]
//! This crate provides RFC 6585-compliant, in-memory rate limiting with
3//! configurable windows and limits as lightweight middleware for 
4//! Warp web applications.
5//! 
//! It provides a `Filter` you add to your routes that exposes rate-limiting
//! information to your handlers, and a rejection type for error recovery.
8//! 
//! It does not yet provide persistence, and the `HashMap` that stores per-IP
//! state is unbounded. Both of these may change in a future version.
11//! 
12//! # Quickstart
13//! 
14//! 1. Include the crate:
15//! 
16//! `cargo add warp-rate-limit`
17//! 
18//! 2. Define one or more rate limit configurations. Following are some 
19//! examples of available builder methods. The variable names are arbitrary: 
20//! 
21//! ```rust,no_run,ignore
22//! // Limit: 60 requests per 60 Earth seconds
23//! let public_routes_rate_limit = RateLimitConfig::default();
24//! 
25//! // Limit: 100 requests per 60 Earth seconds
//! let partner_routes_rate_limit = RateLimitConfig::max_per_minute(100);
27//! 
28//! // Limit: 10 requests per 20 Earth seconds
29//! let static_route_limit = RateLimitConfig::max_per_window(10,20);
30//! ```
31//! 
//! 3. Use rate-limiting information in your request handler. If you don't need
//! the rate-limiting information for the IP address associated with this
//! request, you can skip this step.
35//! 
36//! ```rust,no_run,ignore
37//! // Example route handler
//! async fn handle_request(rate_limit_info: RateLimitInfo) -> Result<impl Reply, Rejection> {
39//!     // Create a base response
40//!     let mut response = warp::reply::with_status(
41//!         "Hello world", 
42//!         StatusCode::OK
43//!     ).into_response();
44//! 
45//!     // Optionally add rate limit headers to your response.
46//!     if let Err(e) = add_rate_limit_headers(response.headers_mut(), &rate_limit_info) {
47//!         match e {
48//!             RateLimitError::HeaderError(e) => {
49//!                 eprintln!("Failed to set rate limit headers due to invalid value: {}", e);
50//!             }
51//!             RateLimitError::Other(e) => {
52//!                 eprintln!("Unexpected error setting rate limit headers: {}", e);
53//!             }
54//!         }
55//!     } 
56//! 
57//!     // You could also replace the above `if let Err(e)` block with:
58//!     // let _ = add_rate_limit_headers(response.headers_mut(), &rate_limit_info);
59//! 
60//!     Ok(response)
61//! }
62//! ```
63//! 
64//! 4. Handle rate limit errors in your rejection handler: 
65//! 
66//! ```rust,no_run,ignore
67//! // Example rejection handler
68//! async fn handle_rejection(rejection: Rejection) -> Result<impl Reply, Infallible> {
69//!     // Somewhere in your rejection handling:
70//!     if let Some(rate_limit_rejection) = rejection.find::<RateLimitRejection>() {
71//!         // We have a rate limit rejection -- so let's get some info about it: 
72//!         let info = get_rate_limit_info(rate_limit_rejection);
73//! 
74//!         // Let's use that info to create a response:
75//!         let message = format!(
76//!             "Rate limit exceeded. Try again after {}.", 
77//!             info.retry_after
78//!         );
79//! 
80//!         // Let's build that response:
81//!         let mut response = warp::reply::with_status(
82//!             message,
83//!             StatusCode::TOO_MANY_REQUESTS
84//!         ).into_response();
85//! 
86//!         // Then, let's add the rate-limiting headers to that response:
87//!         if let Err(e) = add_rate_limit_headers(response.headers_mut(), &info) {
88//!             // Whether or not you use the specific RateLimitError in 
89//!             // your handler, consider handling errors explicitly here. 
//!             // Again, though, you're free to write `let _ = add_rate_limit_headers(...);`
//!             // if you don't care about these errors.
92//!             match e {
93//!                 RateLimitError::HeaderError(e) => {
94//!                     eprintln!("Failed to set rate limit headers due to invalid value: {}", e);
95//!                 }
96//!                 RateLimitError::Other(e) => {
97//!                     eprintln!("Unexpected error setting rate limit headers: {}", e);
98//!                 }
99//!             }
100//!         }
101//! 
102//!         Ok(response)    
103//!     } else {
104//!         // Handle other types of rejections, e.g.
105//!         Ok(warp::reply::with_status(
106//!             "Internal Server Error",
107//!             StatusCode::INTERNAL_SERVER_ERROR,
108//!         ).into_response())
109//!     }
110//! } 
111//! ```
112
113use chrono::{DateTime, Duration as ChronoDuration, Utc};
114use serde::{Deserialize, Serialize};
115
116use std::collections::HashMap;
117use std::sync::Arc;
118use std::time::{Duration, Instant};
119use tokio::sync::RwLock;
120use warp::{
121    http::header::{self, HeaderMap, HeaderValue},
122    reject, Filter, Rejection
123};
124
125pub use chrono;
126pub use serde;
127
128/// Configuration for the rate limiter
129#[derive(Clone, Debug, PartialEq)]
130pub struct RateLimitConfig {
131    /// Maximum number of requests allowed within the window
132    pub max_requests: u32,
133    /// Time window for rate limiting
134    pub window: Duration,
135    /// Format for Retry-After header (RFC 7231 Date or Seconds)
136    pub retry_after_format: RetryAfterFormat,
137}
138
139/// Format options for the Retry-After header
140#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
141pub enum RetryAfterFormat {
142    /// HTTP-date format (RFC 7231)
143    #[default]
144    HttpDate,
145    /// Number of seconds
146    Seconds,
147}
148
149/// Information about the current rate limit status
150#[derive(Clone, Debug, Serialize, Deserialize)]
151pub struct RateLimitInfo {
152    /// Time until the rate limit resets
153    pub retry_after: String,
154    /// Maximum requests allowed in the window
155    pub limit: u32,
156    /// Remaining requests in the current window
157    pub remaining: u32,
158    /// Unix timestamp when the rate limit resets
159    pub reset_timestamp: i64,
160    /// Format used for retry-after header
161    pub retry_after_format: RetryAfterFormat,
162}
163
164/// Custom rejection type for rate limiting
165#[derive(Debug)]
166pub struct RateLimitRejection {
167    /// Duration until the client can retry
168    pub retry_after: Duration,
169    /// Maximum requests allowed in the window
170    pub limit: u32,
171    /// Unix timestamp when the rate limit resets
172    pub reset_time: DateTime<Utc>,
173    /// Format to use for Retry-After header
174    pub retry_after_format: RetryAfterFormat,
175}
176
177impl warp::reject::Reject for RateLimitRejection {}
178
179/// Sensible (opinionated) defaults
180impl Default for RateLimitConfig {
181    fn default() -> Self {
182        Self {
183            max_requests: 60, // 60 req/min baseline
184            window: Duration::from_secs(60),
185            retry_after_format: RetryAfterFormat::HttpDate,
186        }
187    }
188}
189
190/// Factory methods for quickly building a rate limiter
191impl RateLimitConfig {
192    /// Build a `RateLimitConfig` with sensible defaults for requests per minute
193    pub fn max_per_minute(max: u32) -> Self {
194        Self {
195            max_requests: max,
196            window: Duration::from_secs(60),
197            ..Default::default()
198        }
199    }
200
201    /// Build a `RateLimitConfig` with custom window size in seconds
202    pub fn max_per_window(max_requests: u32, window_seconds: u64) -> Self {
203        Self {
204            max_requests,
205            window: Duration::from_secs(window_seconds),
206            ..Default::default()
207        }
208    }
209}
210
211/// Errors that can occur during rate limiting logic
212#[derive(Debug)]
213pub enum RateLimitError {
214    /// Failed to set rate limit headers
215    HeaderError(warp::http::header::InvalidHeaderValue),
216    /// Other unexpected errors
217    Other(Box<dyn std::error::Error + Send + Sync>),
218}
219
220impl std::fmt::Display for RateLimitError {
221    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222        match self {
223            RateLimitError::HeaderError(e) => write!(f, "Failed to set rate limit header: {}", e),
224            RateLimitError::Other(e) => write!(f, "Rate limit error: {}", e),
225        }
226    }
227}
228
229impl std::error::Error for RateLimitError {
230    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
231        match self {
232            RateLimitError::HeaderError(e) => Some(e),
233            RateLimitError::Other(e) => Some(&**e),
234        }
235    }
236}
237
238#[derive(Clone)]
239struct RateLimiter {
240    state: Arc<RwLock<HashMap<String, (Instant, u32)>>>,
241    config: RateLimitConfig,
242}
243
244impl RateLimiter {
245    fn new(config: RateLimitConfig) -> Self {
246        Self {
247            state: Arc::new(RwLock::new(HashMap::new())),
248            config,
249        }
250    }
251
252    async fn check_rate_limit(&self, key: &str) -> Result<RateLimitInfo, Rejection> {
253        let mut state = self.state.write().await;
254        let now = Instant::now();
255        let current = state.get(key).copied();
256
257        match current {
258            Some((last_request, count)) => {
259                if now.duration_since(last_request) > self.config.window {
260                    // Window has passed, reset counter
261                    state.insert(key.to_string(), (now, 1));
262                    Ok(self.create_info(self.config.max_requests - 1, now))
263                } else if count >= self.config.max_requests {
264                    // Rate limit exceeded
265                    let retry_after = self.config.window - now.duration_since(last_request);
266                    let reset_time = Utc::now() + ChronoDuration::from_std(retry_after).unwrap();
267
268                    Err(reject::custom(RateLimitRejection {
269                        retry_after,
270                        limit: self.config.max_requests,
271                        reset_time,
272                        retry_after_format: self.config.retry_after_format.clone(),
273                    }))
274                } else {
275                    // Increment counter
276                    state.insert(key.to_string(), (last_request, count + 1));
277                    Ok(self.create_info(
278                        self.config.max_requests - (count + 1),
279                        last_request,
280                    ))
281                }
282            }
283            None => {
284                // First request
285                state.insert(key.to_string(), (now, 1));
286                Ok(self.create_info(self.config.max_requests - 1, now))
287            }
288        }
289    }
290
291    fn create_info(&self, remaining: u32, start: Instant) -> RateLimitInfo {
292        let reset_time = start + self.config.window;
293        let retry_after = match self.config.retry_after_format {
294            RetryAfterFormat::HttpDate => {
295                (Utc::now() + ChronoDuration::from_std(self.config.window).unwrap()).to_rfc2822()
296            }
297            RetryAfterFormat::Seconds => self.config.window.as_secs().to_string(),
298        };
299
300        RateLimitInfo {
301            retry_after,
302            limit: self.config.max_requests,
303            remaining,
304            reset_timestamp: (Utc::now() + ChronoDuration::from_std(reset_time.duration_since(start)).unwrap()).timestamp(),
305            retry_after_format: self.config.retry_after_format.clone(),
306        }
307    }
308}
309
310/// Creates a rate limiting filter with the given configuration
311pub fn with_rate_limit(
312    config: RateLimitConfig,
313) -> impl Filter<Extract = (RateLimitInfo,), Error = Rejection> + Clone {
314    let rate_limiter = RateLimiter::new(config);
315
316    warp::filters::addr::remote()
317        .map(move |addr: Option<std::net::SocketAddr>| {
318            (
319                rate_limiter.clone(),
320                addr.map(|a| a.ip().to_string())
321                    .unwrap_or_else(|| "unknown".to_string()),
322            )
323        })
324        .and_then(|(rate_limiter, ip): (RateLimiter, String)| async move {
325            rate_limiter.check_rate_limit(&ip).await
326        })
327}
328
329/// Adds rate limit headers to a response
330pub fn add_rate_limit_headers(
331    headers: &mut HeaderMap,
332    info: &RateLimitInfo,
333) -> Result<(), RateLimitError> {
334    headers.insert(header::RETRY_AFTER, 
335        HeaderValue::from_str(&info.retry_after).map_err(RateLimitError::HeaderError)?);
336    headers.insert(
337        "X-RateLimit-Limit",
338        HeaderValue::from_str(&info.limit.to_string()).map_err(RateLimitError::HeaderError)?,
339    );
340    headers.insert(
341        "X-RateLimit-Remaining",
342        HeaderValue::from_str(&info.remaining.to_string()).map_err(RateLimitError::HeaderError)?,
343    );
344    headers.insert(
345        "X-RateLimit-Reset",
346        HeaderValue::from_str(&info.reset_timestamp.to_string()).map_err(RateLimitError::HeaderError)?,
347    );
348    Ok(())
349}
350
351/// Gets rate limit information from a rejection
352pub fn get_rate_limit_info(rejection: &RateLimitRejection) -> RateLimitInfo {
353    let retry_after = match rejection.retry_after_format {
354        RetryAfterFormat::HttpDate => rejection.reset_time.to_rfc2822(),
355        RetryAfterFormat::Seconds => rejection.retry_after.as_secs().to_string(),
356    };
357
358    RateLimitInfo {
359        retry_after,
360        limit: rejection.limit,
361        remaining: 0,
362        reset_timestamp: rejection.reset_time.timestamp(),
363        retry_after_format: rejection.retry_after_format.clone(),
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use tokio::task::JoinSet;
    use warp::{http::StatusCode, test::request, Filter, Reply};

    /// Builds a test route around `with_rate_limit(config)`.
    ///
    /// Successful requests reply with the remaining request count as the body.
    /// Rate-limit rejections become `429 Too Many Requests` responses carrying
    /// the rate-limit headers. Any other rejection becomes a 500.
    async fn create_test_route(
        config: RateLimitConfig,
    ) -> impl Filter<Extract = impl Reply, Error = Infallible> + Clone {
        with_rate_limit(config)
            .map(|info: RateLimitInfo| info.remaining.to_string())
            .recover(|rejection: Rejection| async move {
                if let Some(rate_limit) = rejection.find::<RateLimitRejection>() {
                    let info = get_rate_limit_info(rate_limit);
                    let mut resp = warp::reply::with_status(
                        "Rate limit exceeded",
                        StatusCode::TOO_MANY_REQUESTS,
                    )
                    .into_response();
                    add_rate_limit_headers(resp.headers_mut(), &info).unwrap();
                    Ok(resp)
                } else {
                    Ok(warp::reply::with_status(
                        "Internal error",
                        StatusCode::INTERNAL_SERVER_ERROR,
                    )
                    .into_response())
                }
            })
    }

    #[test]
    fn test_config_builders() {
        // Test max_per_minute builder
        let per_minute = RateLimitConfig::max_per_minute(60);
        assert_eq!(per_minute.window, Duration::from_secs(60));
        assert_eq!(per_minute.max_requests, 60);
        assert_eq!(per_minute.retry_after_format, RetryAfterFormat::HttpDate);

        // Test max_per_window builder
        let custom = RateLimitConfig::max_per_window(30, 120);
        assert_eq!(custom.window, Duration::from_secs(120));
        assert_eq!(custom.max_requests, 30);
        assert_eq!(custom.retry_after_format, RetryAfterFormat::HttpDate);

        // Test default config
        let default = RateLimitConfig::default();
        assert_eq!(default.window, Duration::from_secs(60));
        assert_eq!(default.max_requests, 60);
        assert_eq!(default.retry_after_format, RetryAfterFormat::HttpDate);
    }

    #[tokio::test]
    async fn test_comprehensive_rate_limit_rejection() {
        let config = RateLimitConfig {
            max_requests: 1,
            window: Duration::from_secs(5),
            retry_after_format: RetryAfterFormat::Seconds,
        };

        let route = create_test_route(config.clone()).await;

        // First request succeeds
        let resp1 = request()
            .remote_addr("127.0.0.1:1234".parse().unwrap())
            .reply(&route)
            .await;
        assert_eq!(resp1.status(), 200);
        assert_eq!(resp1.body(), "0"); // Last remaining request

        // Second request gets rejected with proper headers
        let resp2 = request()
            .remote_addr("127.0.0.1:1234".parse().unwrap())
            .reply(&route)
            .await;

        assert_eq!(resp2.status(), 429);

        // Verify rate limit headers exist and have correct format
        let headers = resp2.headers();
        assert!(headers.contains_key(header::RETRY_AFTER));
        assert!(headers.contains_key("X-RateLimit-Limit"));
        assert!(headers.contains_key("X-RateLimit-Remaining"));
        assert!(headers.contains_key("X-RateLimit-Reset"));

        // Verify header values
        assert_eq!(headers.get("X-RateLimit-Limit").unwrap(), "1");
        assert_eq!(headers.get("X-RateLimit-Remaining").unwrap(), "0");

        // Verify Retry-After is a number of seconds
        let retry_after = headers.get(header::RETRY_AFTER).unwrap().to_str().unwrap();
        assert!(retry_after.parse::<u64>().is_ok());
    }

    #[tokio::test]
    async fn test_retry_after_formats() {
        // Test HttpDate format
        let http_date_config = RateLimitConfig {
            max_requests: 1,
            window: Duration::from_secs(15),
            retry_after_format: RetryAfterFormat::HttpDate,
        };

        let http_date_route = create_test_route(http_date_config).await;

        // Trigger rate limit with HttpDate format
        let _ = request()
            .remote_addr("127.0.0.1:1234".parse().unwrap())
            .reply(&http_date_route)
            .await;

        let resp_http = request()
            .remote_addr("127.0.0.1:1234".parse().unwrap())
            .reply(&http_date_route)
            .await;

        // Verify the HttpDate value is a parseable date (RFC 2822 parsing also
        // accepts the "GMT" zone of an RFC 7231 IMF-fixdate).
        let retry_after_http = resp_http
            .headers()
            .get(header::RETRY_AFTER)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(!retry_after_http.is_empty());
        assert!(
            DateTime::parse_from_rfc2822(retry_after_http).is_ok(),
            "Retry-After is not a valid date: {}",
            retry_after_http
        );

        // Test Seconds format
        let seconds_config = RateLimitConfig {
            max_requests: 1,
            window: Duration::from_secs(5),
            retry_after_format: RetryAfterFormat::Seconds,
        };

        let seconds_route = create_test_route(seconds_config).await;

        // Trigger rate limit with Seconds format
        let _ = request()
            .remote_addr("127.0.0.2:1234".parse().unwrap())
            .reply(&seconds_route)
            .await;

        let resp_sec = request()
            .remote_addr("127.0.0.2:1234".parse().unwrap())
            .reply(&seconds_route)
            .await;

        // Verify Seconds format
        let retry_after_sec = resp_sec
            .headers()
            .get(header::RETRY_AFTER)
            .unwrap()
            .to_str()
            .unwrap();
        assert!(retry_after_sec.parse::<u64>().is_ok());
        assert!(retry_after_sec.parse::<u64>().unwrap() <= 5);
    }

    #[tokio::test]
    async fn test_remaining_counts_down_per_client() {
        let route = create_test_route(RateLimitConfig::max_per_minute(3)).await;

        for expected in ["2", "1", "0"] {
            let resp = request()
                .remote_addr("10.0.0.1:1234".parse().unwrap())
                .reply(&route)
                .await;
            assert_eq!(resp.status(), StatusCode::OK);
            assert_eq!(resp.body(), expected);
        }

        let limited = request()
            .remote_addr("10.0.0.1:1234".parse().unwrap())
            .reply(&route)
            .await;
        assert_eq!(limited.status(), StatusCode::TOO_MANY_REQUESTS);

        // A different client IP has its own, untouched quota.
        let other = request()
            .remote_addr("10.0.0.2:1234".parse().unwrap())
            .reply(&route)
            .await;
        assert_eq!(other.status(), StatusCode::OK);
        assert_eq!(other.body(), "2");
    }

    #[test]
    fn test_rate_limit_info_extraction() {
        let now = Utc::now();
        let rejection = RateLimitRejection {
            retry_after: Duration::from_secs(60),
            limit: 100,
            reset_time: now,
            retry_after_format: RetryAfterFormat::Seconds,
        };

        let info = get_rate_limit_info(&rejection);

        assert_eq!(info.limit, 100);
        assert_eq!(info.remaining, 0);
        assert_eq!(info.reset_timestamp, now.timestamp());
        assert_eq!(info.retry_after, "60");

        // Test with HttpDate format
        let rejection_http = RateLimitRejection {
            retry_after: Duration::from_secs(60),
            limit: 100,
            reset_time: now,
            retry_after_format: RetryAfterFormat::HttpDate,
        };

        let info_http = get_rate_limit_info(&rejection_http);
        assert!(!info_http.retry_after.is_empty());
        assert!(DateTime::parse_from_rfc2822(&info_http.retry_after).is_ok());
    }

    #[tokio::test]
    async fn test_concurrent_requests() {
        // A long window keeps the test deterministic: with a short one, a slow
        // machine could see the window expire mid-test and admit extra requests.
        let config = RateLimitConfig {
            max_requests: 5,
            window: Duration::from_secs(60),
            retry_after_format: RetryAfterFormat::Seconds,
        };

        let route = create_test_route(config.clone()).await;
        let mut set = JoinSet::new();

        // Launch 10 concurrent requests
        for _ in 0..10 {
            let route = route.clone();
            set.spawn(async move {
                request()
                    .remote_addr("127.0.0.1:1234".parse().unwrap())
                    .reply(&route)
                    .await
            });
        }

        let mut success_count = 0;
        let mut rate_limited_count = 0;

        // Surface task panics instead of silently stopping the loop early.
        while let Some(joined) = set.join_next().await {
            let resp = joined.expect("request task panicked");
            match resp.status() {
                StatusCode::OK => success_count += 1,
                StatusCode::TOO_MANY_REQUESTS => rate_limited_count += 1,
                other => panic!("Unexpected response status: {}", other),
            }
        }

        assert_eq!(success_count, 5, "Expected exactly 5 successful requests");
        assert_eq!(rate_limited_count, 5, "Expected exactly 5 rate-limited requests");
    }

    #[test]
    fn test_invalid_header_value_handling() {
        let mut headers = HeaderMap::new();
        let invalid_info = RateLimitInfo {
            retry_after: "invalid\u{0000}characters".to_string(),
            limit: 100,
            remaining: 50,
            reset_timestamp: 1234567890,
            retry_after_format: RetryAfterFormat::Seconds,
        };

        let result = add_rate_limit_headers(&mut headers, &invalid_info);
        assert!(matches!(result, Err(RateLimitError::HeaderError(_))));
        // Nothing may be written when the Retry-After value is rejected.
        assert!(headers.is_empty());
    }
}