Skip to main content

shaperail_runtime/auth/
rate_limit.rs

1use std::sync::Arc;
2
3use shaperail_core::ShaperailError;
4
/// Tuning knobs for [`RateLimiter`].
///
/// A key may make at most `max_requests` requests within any trailing window of
/// `window_secs` seconds.
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Upper bound on requests accepted for one key inside a single window.
    pub max_requests: u64,
    /// Length of the sliding window, in seconds.
    pub window_secs: u64,
}
13
14impl Default for RateLimitConfig {
15    fn default() -> Self {
16        Self {
17            max_requests: 100,
18            window_secs: 60,
19        }
20    }
21}
22
23/// Redis-backed sliding window rate limiter.
24///
25/// Uses a sorted set per key with timestamps as scores.
26/// Survives server restarts since all state is in Redis.
27#[derive(Clone)]
28pub struct RateLimiter {
29    pool: Arc<deadpool_redis::Pool>,
30    config: RateLimitConfig,
31}
32
33impl RateLimiter {
34    /// Creates a new rate limiter backed by the given Redis pool.
35    pub fn new(pool: Arc<deadpool_redis::Pool>, config: RateLimitConfig) -> Self {
36        Self { pool, config }
37    }
38
39    /// Checks if the given key (IP or token) is within rate limits.
40    ///
41    /// Returns `Ok(remaining)` with the number of remaining requests,
42    /// or `Err(ShaperailError::RateLimited)` if the limit is exceeded.
43    pub async fn check(&self, key: &str) -> Result<u64, ShaperailError> {
44        let redis_key = format!("shaperail:ratelimit:{key}");
45        let now = chrono::Utc::now().timestamp_millis() as f64;
46        let window_start = now - (self.config.window_secs as f64 * 1000.0);
47
48        let mut conn = self
49            .pool
50            .get()
51            .await
52            .map_err(|e| ShaperailError::Internal(format!("Redis connection failed: {e}")))?;
53
54        // Lua script for atomic sliding window:
55        // 1. Remove entries older than the window
56        // 2. Add current timestamp
57        // 3. Count entries in window
58        // 4. Set TTL on the key
59        let script = redis::Script::new(
60            r#"
61            redis.call('ZREMRANGEBYSCORE', KEYS[1], '-inf', ARGV[1])
62            redis.call('ZADD', KEYS[1], ARGV[2], ARGV[2] .. ':' .. math.random())
63            local count = redis.call('ZCARD', KEYS[1])
64            redis.call('EXPIRE', KEYS[1], ARGV[3])
65            return count
66            "#,
67        );
68
69        let count: u64 = script
70            .key(&redis_key)
71            .arg(window_start)
72            .arg(now)
73            .arg(self.config.window_secs as i64 + 1)
74            .invoke_async(&mut *conn)
75            .await
76            .map_err(|e| ShaperailError::Internal(format!("Redis rate limit error: {e}")))?;
77
78        if count > self.config.max_requests {
79            return Err(ShaperailError::RateLimited);
80        }
81
82        Ok(self.config.max_requests - count)
83    }
84
85    /// Builds the rate limit key from IP and optional token.
86    ///
87    /// If a token (user ID) is provided, rate limits per user.
88    /// Otherwise, rate limits per IP address.
89    pub fn key_for(ip: &str, user_id: Option<&str>) -> String {
90        match user_id {
91            Some(uid) => format!("user:{uid}"),
92            None => format!("ip:{ip}"),
93        }
94    }
95
96    /// Builds a tenant-scoped rate limit key (M18).
97    ///
98    /// Scopes the key by tenant_id so each tenant has independent rate limits.
99    pub fn key_for_tenant(ip: &str, user_id: Option<&str>, tenant_id: Option<&str>) -> String {
100        let base = Self::key_for(ip, user_id);
101        match tenant_id {
102            Some(tid) => format!("t:{tid}:{base}"),
103            None => base,
104        }
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config() {
        let cfg = RateLimitConfig::default();
        assert_eq!(cfg.max_requests, 100);
        assert_eq!(cfg.window_secs, 60);
    }

    #[test]
    fn key_for_ip() {
        let key = RateLimiter::key_for("192.168.1.1", None);
        assert_eq!(key, "ip:192.168.1.1");
    }

    #[test]
    fn key_for_user() {
        let key = RateLimiter::key_for("192.168.1.1", Some("user-123"));
        assert_eq!(key, "user:user-123");
    }

    #[test]
    fn user_key_ignores_ip() {
        // A logged-in user shares one budget across all of their IPs.
        assert_eq!(
            RateLimiter::key_for("10.0.0.1", Some("user-123")),
            RateLimiter::key_for("10.0.0.2", Some("user-123"))
        );
    }

    #[test]
    fn key_for_tenant_scoped() {
        let key = RateLimiter::key_for_tenant("192.168.1.1", Some("user-123"), Some("org-a"));
        assert_eq!(key, "t:org-a:user:user-123");
    }

    #[test]
    fn key_for_tenant_scoped_ip_fallback() {
        let key = RateLimiter::key_for_tenant("192.168.1.1", None, Some("org-a"));
        assert_eq!(key, "t:org-a:ip:192.168.1.1");
    }

    #[test]
    fn key_for_tenant_no_tenant() {
        let key = RateLimiter::key_for_tenant("192.168.1.1", Some("user-123"), None);
        assert_eq!(key, "user:user-123");
    }

    #[test]
    fn key_for_tenant_no_tenant_no_user_matches_ip_key() {
        let key = RateLimiter::key_for_tenant("192.168.1.1", None, None);
        assert_eq!(key, RateLimiter::key_for("192.168.1.1", None));
    }

    #[test]
    fn tenant_keys_differ() {
        let key_a = RateLimiter::key_for_tenant("192.168.1.1", Some("user-123"), Some("org-a"));
        let key_b = RateLimiter::key_for_tenant("192.168.1.1", Some("user-123"), Some("org-b"));
        assert_ne!(
            key_a, key_b,
            "Rate limit keys for different tenants must differ"
        );
    }
}