Skip to main content

rustauth_fred/
store.rs

1use fred::clients::Client;
2use fred::interfaces::ClientLike;
3use fred::prelude::{Builder, Config};
4use fred::types::scripts::Script;
5use rustauth_core::error::RustAuthError;
6use rustauth_core::options::{
7    validate_rate_limit_rule, RateLimitConsumeInput, RateLimitDecision, RateLimitFuture,
8    RateLimitStore,
9};
10
11use crate::config::FredRateLimitOptions;
12use crate::error::fred_error;
13use crate::script::{parse_rate_limit_script_result, RATE_LIMIT_SCRIPT};
14use crate::url::normalize_fred_url;
15
16#[derive(Clone)]
17pub struct FredRateLimitStore {
18    client: Client,
19    options: FredRateLimitOptions,
20    script: Script,
21}
22
23impl FredRateLimitStore {
24    pub async fn connect(url: &str) -> Result<Self, RustAuthError> {
25        Self::connect_with_options(url, FredRateLimitOptions::default()).await
26    }
27
28    pub async fn connect_with_options(
29        url: &str,
30        options: FredRateLimitOptions,
31    ) -> Result<Self, RustAuthError> {
32        let client = connect_client(url).await?;
33        Ok(Self::new(client, options))
34    }
35
36    pub fn new(client: Client, options: FredRateLimitOptions) -> Self {
37        Self {
38            client,
39            options,
40            script: Script::from_lua(RATE_LIMIT_SCRIPT),
41        }
42    }
43
44    fn key(&self, key: &str) -> Result<String, RustAuthError> {
45        validate_rate_limit_key_prefix(&self.options.key_prefix)?;
46        Ok(format!("{}rate-limit:{key}", self.options.key_prefix))
47    }
48}
49
50fn validate_rate_limit_key_prefix(prefix: &str) -> Result<(), RustAuthError> {
51    if prefix.is_empty() {
52        return Err(RustAuthError::InvalidConfig(
53            "rate limit key prefix must not be empty".to_owned(),
54        ));
55    }
56    Ok(())
57}
58
59impl RateLimitStore for FredRateLimitStore {
60    fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
61        Box::pin(async move {
62            let window_ms = validate_rate_limit_rule(&input.rule)?;
63            let redis_key = self.key(&input.key)?;
64            let result = self
65                .script
66                .evalsha_with_reload(
67                    &self.client,
68                    vec![redis_key],
69                    vec![
70                        input.now_ms.to_string(),
71                        window_ms.to_string(),
72                        input.rule.max.to_string(),
73                    ],
74                )
75                .await
76                .map_err(|error| fred_error("eval rate limit script", error))?;
77            let result = parse_rate_limit_script_result(result)?;
78            let retry_ms = result
79                .last_request
80                .saturating_add(window_ms)
81                .saturating_sub(input.now_ms)
82                .max(0);
83            Ok(RateLimitDecision {
84                permitted: result.permitted,
85                retry_after: if result.permitted {
86                    0
87                } else {
88                    ceil_millis_to_seconds(retry_ms)
89                },
90                limit: input.rule.max,
91                remaining: input.rule.max.saturating_sub(result.count),
92                reset_after: ceil_millis_to_seconds(retry_ms),
93            })
94        })
95    }
96}
97
98pub(crate) async fn connect_client(url: &str) -> Result<Client, RustAuthError> {
99    let url = normalize_fred_url(url);
100    let config = Config::from_url(url.as_ref()).map_err(|error| fred_error("parse url", error))?;
101    let client = Builder::from_config(config)
102        .build()
103        .map_err(|error| fred_error("build client", error))?;
104    client
105        .init()
106        .await
107        .map_err(|error| fred_error("connect", error))?;
108    Ok(client)
109}
110
/// Converts a millisecond duration into whole seconds, rounding up.
///
/// Zero and negative durations map to `0`.
fn ceil_millis_to_seconds(milliseconds: i64) -> u64 {
    match u64::try_from(milliseconds) {
        // Whole seconds, plus one more if there is any leftover millisecond.
        Ok(ms) if ms > 0 => ms / 1000 + u64::from(ms % 1000 != 0),
        _ => 0,
    }
}