shaperail_runtime/auth/
rate_limit.rs1use std::sync::Arc;
2
3use shaperail_core::ShaperailError;
4
/// Limits enforced by [`RateLimiter`].
///
/// A key may make at most `max_requests` requests within a trailing window
/// of `window_secs` seconds.
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Maximum number of requests allowed per window.
    pub max_requests: u64,
    /// Length of the sliding window, in seconds.
    pub window_secs: u64,
}
13
14impl Default for RateLimitConfig {
15 fn default() -> Self {
16 Self {
17 max_requests: 100,
18 window_secs: 60,
19 }
20 }
21}
22
23#[derive(Clone)]
28pub struct RateLimiter {
29 pool: Arc<deadpool_redis::Pool>,
30 config: RateLimitConfig,
31}
32
33impl RateLimiter {
34 pub fn new(pool: Arc<deadpool_redis::Pool>, config: RateLimitConfig) -> Self {
36 Self { pool, config }
37 }
38
39 pub async fn check(&self, key: &str) -> Result<u64, ShaperailError> {
44 let redis_key = format!("shaperail:ratelimit:{key}");
45 let now = chrono::Utc::now().timestamp_millis() as f64;
46 let window_start = now - (self.config.window_secs as f64 * 1000.0);
47
48 let mut conn = self
49 .pool
50 .get()
51 .await
52 .map_err(|e| ShaperailError::Internal(format!("Redis connection failed: {e}")))?;
53
54 let script = redis::Script::new(
60 r#"
61 redis.call('ZREMRANGEBYSCORE', KEYS[1], '-inf', ARGV[1])
62 redis.call('ZADD', KEYS[1], ARGV[2], ARGV[2] .. ':' .. math.random())
63 local count = redis.call('ZCARD', KEYS[1])
64 redis.call('EXPIRE', KEYS[1], ARGV[3])
65 return count
66 "#,
67 );
68
69 let count: u64 = script
70 .key(&redis_key)
71 .arg(window_start)
72 .arg(now)
73 .arg(self.config.window_secs as i64 + 1)
74 .invoke_async(&mut *conn)
75 .await
76 .map_err(|e| ShaperailError::Internal(format!("Redis rate limit error: {e}")))?;
77
78 if count > self.config.max_requests {
79 return Err(ShaperailError::RateLimited);
80 }
81
82 Ok(self.config.max_requests - count)
83 }
84
85 pub fn key_for(ip: &str, user_id: Option<&str>) -> String {
90 match user_id {
91 Some(uid) => format!("user:{uid}"),
92 None => format!("ip:{ip}"),
93 }
94 }
95
96 pub fn key_for_tenant(ip: &str, user_id: Option<&str>, tenant_id: Option<&str>) -> String {
100 let base = Self::key_for(ip, user_id);
101 match tenant_id {
102 Some(tid) => format!("t:{tid}:{base}"),
103 None => base,
104 }
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config() {
        let cfg = RateLimitConfig::default();
        assert_eq!(cfg.max_requests, 100);
        assert_eq!(cfg.window_secs, 60);
    }

    #[test]
    fn key_for_ip() {
        let key = RateLimiter::key_for("192.168.1.1", None);
        assert_eq!(key, "ip:192.168.1.1");
    }

    #[test]
    fn key_for_user() {
        let key = RateLimiter::key_for("192.168.1.1", Some("user-123"));
        assert_eq!(key, "user:user-123");
    }

    #[test]
    fn user_and_ip_keys_differ() {
        // A user id that happens to equal an IP string must not share that IP's bucket.
        let ip_key = RateLimiter::key_for("1.2.3.4", None);
        let user_key = RateLimiter::key_for("5.6.7.8", Some("1.2.3.4"));
        assert_ne!(ip_key, user_key);
    }

    #[test]
    fn key_for_tenant_scoped() {
        let key = RateLimiter::key_for_tenant("192.168.1.1", Some("user-123"), Some("org-a"));
        assert_eq!(key, "t:org-a:user:user-123");
    }

    #[test]
    fn key_for_tenant_no_tenant() {
        let key = RateLimiter::key_for_tenant("192.168.1.1", Some("user-123"), None);
        assert_eq!(key, "user:user-123");
    }

    #[test]
    fn key_for_tenant_anonymous_uses_ip() {
        let key = RateLimiter::key_for_tenant("10.0.0.1", None, Some("org-a"));
        assert_eq!(key, "t:org-a:ip:10.0.0.1");
    }

    #[test]
    fn key_for_tenant_anonymous_no_tenant() {
        let key = RateLimiter::key_for_tenant("10.0.0.1", None, None);
        assert_eq!(key, "ip:10.0.0.1");
    }

    #[test]
    fn tenant_keys_differ() {
        let key_a = RateLimiter::key_for_tenant("192.168.1.1", Some("user-123"), Some("org-a"));
        let key_b = RateLimiter::key_for_tenant("192.168.1.1", Some("user-123"), Some("org-b"));
        assert_ne!(
            key_a, key_b,
            "Rate limit keys for different tenants must differ"
        );
    }

    #[test]
    fn anonymous_tenant_keys_differ() {
        let key_a = RateLimiter::key_for_tenant("10.0.0.1", None, Some("org-a"));
        let key_b = RateLimiter::key_for_tenant("10.0.0.1", None, Some("org-b"));
        assert_ne!(key_a, key_b);
    }
}