rustauth_redis/
rate_limit.rs1use redis::aio::ConnectionManager;
2use redis::Script;
3use rustauth_core::error::RustAuthError;
4use rustauth_core::options::{
5 validate_rate_limit_rule, RateLimitConsumeInput, RateLimitDecision, RateLimitFuture,
6 RateLimitStore,
7};
8
9use crate::url::validate_rate_limit_key_prefix;
10
/// Lua script that checks and records one request for a single rate-limit key.
///
/// Inputs:
/// - `KEYS[1]`: the Redis hash holding the counter for this key.
/// - `ARGV[1]`: the caller's current time (`now`).
/// - `ARGV[2]`: the window length (`window`). It is also passed to `PEXPIRE`, so it is in
///   milliseconds, and `now` is expected in the same unit.
/// - `ARGV[3]`: the maximum number of requests permitted per window (`max`).
///
/// The hash stores two fields, `count` and `last_request`:
/// - If either field is missing, or more than `window` has elapsed since `last_request`,
///   the counter restarts at 1 with `last_request = now` and the request is permitted.
/// - Otherwise, if `count >= max`, the request is denied. The stored fields are left
///   unchanged and only the key's TTL is refreshed.
/// - Otherwise `count` is incremented, `last_request` is set to `now`, and the request
///   is permitted.
///
/// Returns `{permitted (0 or 1), count, last_request}`. On every permitted path the
/// returned `last_request` equals `now`.
///
/// NOTE(review): the reset branch always permits the first request of a window, so a
/// rule with `max == 0` would still allow one request per window. Confirm that
/// `validate_rate_limit_rule` rejects `max == 0`.
pub(crate) const RATE_LIMIT_SCRIPT: &str = r#"
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local max = tonumber(ARGV[3])

local data = redis.call("HMGET", key, "count", "last_request")
local count = tonumber(data[1])
local last_request = tonumber(data[2])

if count == nil or last_request == nil or (now - last_request) > window then
    redis.call("HSET", key, "count", 1, "last_request", now)
    redis.call("PEXPIRE", key, window)
    return {1, 1, now}
end

if count >= max then
    redis.call("PEXPIRE", key, window)
    return {0, count, last_request}
end

count = count + 1
redis.call("HSET", key, "count", count, "last_request", now)
redis.call("PEXPIRE", key, window)
return {1, count, now}
"#;
37
/// Configuration for [`RedisRateLimitStore`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RedisRateLimitOptions {
    /// Namespace prepended to every Redis key the store writes.
    ///
    /// Keys take the form `"{key_prefix}rate-limit:{key}"`, and an empty prefix is
    /// rejected as invalid configuration. Defaults to `"rustauth:"`.
    pub key_prefix: String,
}

impl Default for RedisRateLimitOptions {
    /// Returns options using the `"rustauth:"` key prefix.
    fn default() -> Self {
        let key_prefix = String::from("rustauth:");
        Self { key_prefix }
    }
}
50
51#[derive(Clone)]
52pub struct RedisRateLimitStore {
53 manager: ConnectionManager,
54 options: RedisRateLimitOptions,
55}
56
57impl RedisRateLimitStore {
58 pub async fn connect(redis_url: &str) -> Result<Self, RustAuthError> {
59 Self::connect_with_options(redis_url, RedisRateLimitOptions::default()).await
60 }
61
62 pub async fn connect_with_options(
63 redis_url: &str,
64 options: RedisRateLimitOptions,
65 ) -> Result<Self, RustAuthError> {
66 let manager = crate::connect_manager(redis_url).await?;
67 Ok(Self::new(manager, options))
68 }
69
70 pub fn new(manager: ConnectionManager, options: RedisRateLimitOptions) -> Self {
71 Self { manager, options }
72 }
73
74 fn key(&self, key: &str) -> Result<String, RustAuthError> {
75 validate_rate_limit_key_prefix(&self.options.key_prefix)?;
76 Ok(format!("{}rate-limit:{key}", self.options.key_prefix))
77 }
78}
79
80impl RateLimitStore for RedisRateLimitStore {
81 fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
82 Box::pin(async move {
83 let window_ms = validate_rate_limit_rule(&input.rule)?;
84 let redis_key = self.key(&input.key)?;
85 let mut manager = self.manager.clone();
86 let result: (i64, i64, i64) = Script::new(RATE_LIMIT_SCRIPT)
87 .key(redis_key)
88 .arg(input.now_ms)
89 .arg(window_ms)
90 .arg(input.rule.max as i64)
91 .invoke_async(&mut manager)
92 .await
93 .map_err(|error| RustAuthError::Adapter(error.to_string()))?;
94 let permitted = match result.0 {
95 0 => false,
96 1 => true,
97 _ => {
98 return Err(RustAuthError::Adapter(
99 "invalid redis rate limit script result: `permitted` was not 0 or 1"
100 .to_owned(),
101 ));
102 }
103 };
104 if result.1 < 0 {
105 return Err(RustAuthError::Adapter(
106 "invalid redis rate limit script result: `count` was negative".to_owned(),
107 ));
108 }
109 let count = result.1 as u64;
110 let last_request = result.2;
111 let retry_ms = last_request
112 .saturating_add(window_ms)
113 .saturating_sub(input.now_ms)
114 .max(0);
115 Ok(RateLimitDecision {
116 permitted,
117 retry_after: if permitted {
118 0
119 } else {
120 ceil_millis_to_seconds(retry_ms)
121 },
122 limit: input.rule.max,
123 remaining: input.rule.max.saturating_sub(count),
124 reset_after: ceil_millis_to_seconds(retry_ms),
125 })
126 })
127 }
128}
129
/// Converts a millisecond duration to whole seconds, rounding any partial second up.
///
/// Zero and negative inputs yield `0`. The rounding addition saturates rather than
/// overflowing.
fn ceil_millis_to_seconds(milliseconds: i64) -> u64 {
    const MILLIS_PER_SECOND: u64 = 1_000;
    // After clamping, a non-positive input becomes 0, and (0 + 999) / 1000 == 0.
    let non_negative = milliseconds.max(0) as u64;
    non_negative.saturating_add(MILLIS_PER_SECOND - 1) / MILLIS_PER_SECOND
}
136
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rate_limit_key_prefix_must_not_be_empty() {
        assert!(matches!(
            validate_rate_limit_key_prefix(""),
            Err(RustAuthError::InvalidConfig(message))
                if message == "rate limit key prefix must not be empty"
        ));
    }

    /// The default prefix is part of the on-disk key layout, so pin it.
    #[test]
    fn default_options_use_rustauth_prefix() {
        assert_eq!(RedisRateLimitOptions::default().key_prefix, "rustauth:");
    }

    /// `retry_after` / `reset_after` must round partial seconds up and never go negative.
    #[test]
    fn ceil_millis_to_seconds_rounds_up_and_clamps_negatives() {
        assert_eq!(ceil_millis_to_seconds(i64::MIN), 0);
        assert_eq!(ceil_millis_to_seconds(-1), 0);
        assert_eq!(ceil_millis_to_seconds(0), 0);
        assert_eq!(ceil_millis_to_seconds(1), 1);
        assert_eq!(ceil_millis_to_seconds(999), 1);
        assert_eq!(ceil_millis_to_seconds(1_000), 1);
        assert_eq!(ceil_millis_to_seconds(1_001), 2);
        assert_eq!(ceil_millis_to_seconds(60_000), 60);
    }
}