// shaperail_runtime/auth/rate_limit.rs

1use std::sync::Arc;
2
3use shaperail_core::ShaperailError;
4
/// Configuration for the rate limiter: at most `max_requests` requests are
/// allowed per key within any sliding window of `window_secs` seconds.
///
/// The [`Default`] configuration is 100 requests per 60 seconds.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Maximum requests allowed per window.
    pub max_requests: u64,
    /// Window size in seconds.
    pub window_secs: u64,
}

14impl Default for RateLimitConfig {
15    fn default() -> Self {
16        Self {
17            max_requests: 100,
18            window_secs: 60,
19        }
20    }
21}
22
23/// Redis-backed sliding window rate limiter.
24///
25/// Uses a sorted set per key with timestamps as scores.
26/// Survives server restarts since all state is in Redis.
27#[derive(Clone)]
28pub struct RateLimiter {
29    pool: Arc<deadpool_redis::Pool>,
30    config: RateLimitConfig,
31}
32
33impl RateLimiter {
34    /// Creates a new rate limiter backed by the given Redis pool.
35    pub fn new(pool: Arc<deadpool_redis::Pool>, config: RateLimitConfig) -> Self {
36        Self { pool, config }
37    }
38
39    /// Checks if the given key (IP or token) is within rate limits.
40    ///
41    /// Returns `Ok(remaining)` with the number of remaining requests,
42    /// or `Err(ShaperailError::RateLimited)` if the limit is exceeded.
43    pub async fn check(&self, key: &str) -> Result<u64, ShaperailError> {
44        let redis_key = format!("shaperail:ratelimit:{key}");
45        let now = chrono::Utc::now().timestamp_millis() as f64;
46        let window_start = now - (self.config.window_secs as f64 * 1000.0);
47
48        let mut conn = self
49            .pool
50            .get()
51            .await
52            .map_err(|e| ShaperailError::Internal(format!("Redis connection failed: {e}")))?;
53
54        // Lua script for atomic sliding window:
55        // 1. Remove entries older than the window
56        // 2. Add current timestamp
57        // 3. Count entries in window
58        // 4. Set TTL on the key
59        let script = redis::Script::new(
60            r#"
61            redis.call('ZREMRANGEBYSCORE', KEYS[1], '-inf', ARGV[1])
62            redis.call('ZADD', KEYS[1], ARGV[2], ARGV[2] .. ':' .. math.random())
63            local count = redis.call('ZCARD', KEYS[1])
64            redis.call('EXPIRE', KEYS[1], ARGV[3])
65            return count
66            "#,
67        );
68
69        let count: u64 = script
70            .key(&redis_key)
71            .arg(window_start)
72            .arg(now)
73            .arg(self.config.window_secs as i64 + 1)
74            .invoke_async(&mut *conn)
75            .await
76            .map_err(|e| ShaperailError::Internal(format!("Redis rate limit error: {e}")))?;
77
78        if count > self.config.max_requests {
79            return Err(ShaperailError::RateLimited);
80        }
81
82        Ok(self.config.max_requests - count)
83    }
84
85    /// Builds the rate limit key from IP and optional token.
86    ///
87    /// If a token (user ID) is provided, rate limits per user.
88    /// Otherwise, rate limits per IP address.
89    pub fn key_for(ip: &str, user_id: Option<&str>) -> String {
90        match user_id {
91            Some(uid) => format!("user:{uid}"),
92            None => format!("ip:{ip}"),
93        }
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config() {
        let RateLimitConfig {
            max_requests,
            window_secs,
        } = RateLimitConfig::default();
        assert_eq!(max_requests, 100);
        assert_eq!(window_secs, 60);
    }

    #[test]
    fn key_for_ip() {
        assert_eq!(RateLimiter::key_for("192.168.1.1", None), "ip:192.168.1.1");
    }

    #[test]
    fn key_for_user() {
        assert_eq!(
            RateLimiter::key_for("192.168.1.1", Some("user-123")),
            "user:user-123"
        );
    }
}
115    fn key_for_user() {
116        let key = RateLimiter::key_for("192.168.1.1", Some("user-123"));
117        assert_eq!(key, "user:user-123");
118    }
119}