shaperail_runtime/auth/
rate_limit.rs1use std::sync::Arc;
2
3use shaperail_core::ShaperailError;
4
/// Limits applied by the rate limiter: at most `max_requests` requests per
/// sliding window of `window_secs` seconds.
///
/// Implements `PartialEq`/`Eq` so configurations can be compared (e.g. in tests
/// or when detecting config changes).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitConfig {
    /// Maximum number of requests allowed within one window.
    pub max_requests: u64,
    /// Length of the sliding window, in seconds.
    pub window_secs: u64,
}

impl Default for RateLimitConfig {
    /// 100 requests per 60-second window.
    fn default() -> Self {
        Self {
            max_requests: 100,
            window_secs: 60,
        }
    }
}
22
23#[derive(Clone)]
28pub struct RateLimiter {
29 pool: Arc<deadpool_redis::Pool>,
30 config: RateLimitConfig,
31}
32
33impl RateLimiter {
34 pub fn new(pool: Arc<deadpool_redis::Pool>, config: RateLimitConfig) -> Self {
36 Self { pool, config }
37 }
38
39 pub async fn check(&self, key: &str) -> Result<u64, ShaperailError> {
44 let redis_key = format!("shaperail:ratelimit:{key}");
45 let now = chrono::Utc::now().timestamp_millis() as f64;
46 let window_start = now - (self.config.window_secs as f64 * 1000.0);
47
48 let mut conn = self
49 .pool
50 .get()
51 .await
52 .map_err(|e| ShaperailError::Internal(format!("Redis connection failed: {e}")))?;
53
54 let script = redis::Script::new(
60 r#"
61 redis.call('ZREMRANGEBYSCORE', KEYS[1], '-inf', ARGV[1])
62 redis.call('ZADD', KEYS[1], ARGV[2], ARGV[2] .. ':' .. math.random())
63 local count = redis.call('ZCARD', KEYS[1])
64 redis.call('EXPIRE', KEYS[1], ARGV[3])
65 return count
66 "#,
67 );
68
69 let count: u64 = script
70 .key(&redis_key)
71 .arg(window_start)
72 .arg(now)
73 .arg(self.config.window_secs as i64 + 1)
74 .invoke_async(&mut *conn)
75 .await
76 .map_err(|e| ShaperailError::Internal(format!("Redis rate limit error: {e}")))?;
77
78 if count > self.config.max_requests {
79 return Err(ShaperailError::RateLimited);
80 }
81
82 Ok(self.config.max_requests - count)
83 }
84
85 pub fn key_for(ip: &str, user_id: Option<&str>) -> String {
90 match user_id {
91 Some(uid) => format!("user:{uid}"),
92 None => format!("ip:{ip}"),
93 }
94 }
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config() {
        let RateLimitConfig {
            max_requests,
            window_secs,
            ..
        } = RateLimitConfig::default();
        assert_eq!((max_requests, window_secs), (100, 60));
    }

    #[test]
    fn key_for_ip() {
        // Without a user id the bucket falls back to the client IP.
        assert_eq!(
            RateLimiter::key_for("192.168.1.1", None),
            "ip:192.168.1.1"
        );
    }

    #[test]
    fn key_for_user() {
        // A user id takes precedence over the IP address.
        assert_eq!(
            RateLimiter::key_for("192.168.1.1", Some("user-123")),
            "user:user-123"
        );
    }
}