Skip to main content

rustauth_redis/
rate_limit.rs

1use redis::aio::ConnectionManager;
2use redis::Script;
3use rustauth_core::error::RustAuthError;
4use rustauth_core::options::{
5    validate_rate_limit_rule, RateLimitConsumeInput, RateLimitDecision, RateLimitFuture,
6    RateLimitStore,
7};
8
9use crate::url::validate_rate_limit_key_prefix;
10
/// Lua script that consumes one request from a Redis-backed rate-limit bucket.
///
/// Inputs:
/// - `KEYS[1]`: the full bucket key, a hash with the fields `count` and `last_request`.
/// - `ARGV[1]`: the current time in milliseconds (`now`), supplied by the caller rather than
///   read from the Redis server clock.
/// - `ARGV[2]`: the window length in milliseconds. It is also used as the `PEXPIRE` TTL.
/// - `ARGV[3]`: the maximum number of requests allowed in one window.
///
/// The script returns a three-element array `{permitted, count, last_request}`, where
/// `permitted` is `1` or `0`. There are three branches:
/// - **Reset.** The bucket is missing or incomplete, or its `last_request` is more than
///   `window` ms old. It is reset to `count = 1, last_request = now` and the request is
///   permitted.
/// - **Deny.** The bucket already holds `count >= max`. The request is denied, `count` and
///   `last_request` are left untouched, only the TTL is refreshed, and the stored
///   `last_request` is returned.
/// - **Increment.** Otherwise `count` is incremented, `last_request` is set to `now`, and the
///   request is permitted.
///
/// `last_request` advances on every permitted request, so the window is measured from the
/// most recent permitted request, not from a fixed start time. The reset check is strict
/// (`>`): a request arriving exactly `window` ms after `last_request` is still judged
/// against the existing count.
///
/// The whole read-modify-write runs as one script, so Redis executes it without
/// interleaving commands from other clients.
pub(crate) const RATE_LIMIT_SCRIPT: &str = r#"
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local max = tonumber(ARGV[3])

local data = redis.call("HMGET", key, "count", "last_request")
local count = tonumber(data[1])
local last_request = tonumber(data[2])

if count == nil or last_request == nil or (now - last_request) > window then
  redis.call("HSET", key, "count", 1, "last_request", now)
  redis.call("PEXPIRE", key, window)
  return {1, 1, now}
end

if count >= max then
  redis.call("PEXPIRE", key, window)
  return {0, count, last_request}
end

count = count + 1
redis.call("HSET", key, "count", count, "last_request", now)
redis.call("PEXPIRE", key, window)
return {1, count, now}
"#;
37
/// Configuration for the Redis-backed rate-limit store.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RedisRateLimitOptions {
    /// Namespace prepended to every rate-limit key.
    ///
    /// Entries are stored under `<key_prefix>rate-limit:<key>`. Defaults to `"rustauth:"`.
    pub key_prefix: String,
}

impl Default for RedisRateLimitOptions {
    /// Uses the `"rustauth:"` prefix.
    fn default() -> Self {
        let key_prefix = String::from("rustauth:");
        Self { key_prefix }
    }
}
50
51#[derive(Clone)]
52pub struct RedisRateLimitStore {
53    manager: ConnectionManager,
54    options: RedisRateLimitOptions,
55}
56
57impl RedisRateLimitStore {
58    pub async fn connect(redis_url: &str) -> Result<Self, RustAuthError> {
59        Self::connect_with_options(redis_url, RedisRateLimitOptions::default()).await
60    }
61
62    pub async fn connect_with_options(
63        redis_url: &str,
64        options: RedisRateLimitOptions,
65    ) -> Result<Self, RustAuthError> {
66        let manager = crate::connect_manager(redis_url).await?;
67        Ok(Self::new(manager, options))
68    }
69
70    pub fn new(manager: ConnectionManager, options: RedisRateLimitOptions) -> Self {
71        Self { manager, options }
72    }
73
74    fn key(&self, key: &str) -> Result<String, RustAuthError> {
75        validate_rate_limit_key_prefix(&self.options.key_prefix)?;
76        Ok(format!("{}rate-limit:{key}", self.options.key_prefix))
77    }
78}
79
80impl RateLimitStore for RedisRateLimitStore {
81    fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
82        Box::pin(async move {
83            let window_ms = validate_rate_limit_rule(&input.rule)?;
84            let redis_key = self.key(&input.key)?;
85            let mut manager = self.manager.clone();
86            let result: (i64, i64, i64) = Script::new(RATE_LIMIT_SCRIPT)
87                .key(redis_key)
88                .arg(input.now_ms)
89                .arg(window_ms)
90                .arg(input.rule.max as i64)
91                .invoke_async(&mut manager)
92                .await
93                .map_err(|error| RustAuthError::Adapter(error.to_string()))?;
94            let permitted = match result.0 {
95                0 => false,
96                1 => true,
97                _ => {
98                    return Err(RustAuthError::Adapter(
99                        "invalid redis rate limit script result: `permitted` was not 0 or 1"
100                            .to_owned(),
101                    ));
102                }
103            };
104            if result.1 < 0 {
105                return Err(RustAuthError::Adapter(
106                    "invalid redis rate limit script result: `count` was negative".to_owned(),
107                ));
108            }
109            let count = result.1 as u64;
110            let last_request = result.2;
111            let retry_ms = last_request
112                .saturating_add(window_ms)
113                .saturating_sub(input.now_ms)
114                .max(0);
115            Ok(RateLimitDecision {
116                permitted,
117                retry_after: if permitted {
118                    0
119                } else {
120                    ceil_millis_to_seconds(retry_ms)
121                },
122                limit: input.rule.max,
123                remaining: input.rule.max.saturating_sub(count),
124                reset_after: ceil_millis_to_seconds(retry_ms),
125            })
126        })
127    }
128}
129
/// Converts a millisecond duration to whole seconds, rounding any partial second up.
///
/// Zero and negative durations map to `0`, so callers may pass a raw time difference.
fn ceil_millis_to_seconds(milliseconds: i64) -> u64 {
    if milliseconds.is_positive() {
        // A positive `i64` always fits in `u64`. The saturating add keeps the rounding step
        // panic-free.
        (milliseconds as u64).saturating_add(999) / 1000
    } else {
        0
    }
}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// Key-prefix validation must reject an empty prefix with `InvalidConfig` and this exact
    /// message.
    #[test]
    fn rate_limit_key_prefix_must_not_be_empty() {
        let rejected_with_expected_message = match validate_rate_limit_key_prefix("") {
            Err(RustAuthError::InvalidConfig(message)) => {
                message == "rate limit key prefix must not be empty"
            }
            _ => false,
        };
        assert!(rejected_with_expected_message);
    }
}