Skip to main content

rustauth_core/db/sql/
rate_limit.rs

1use super::*;
2use crate::options::validate_rate_limit_rule;
3
4/// Decodes a persisted rate-limit count, rejecting corrupt negative values
5/// instead of wrapping them into a huge `u64`.
6pub fn rate_limit_count_from_i64(count: i64) -> Result<u64, RustAuthError> {
7    u64::try_from(count).map_err(|_| {
8        RustAuthError::Adapter("negative rate limit count persisted in database".to_owned())
9    })
10}
11
12/// Encodes an in-memory rate-limit count for SQL `BIGINT` storage.
13pub fn rate_limit_count_to_i64(count: u64) -> Result<i64, RustAuthError> {
14    i64::try_from(count)
15        .map_err(|_| RustAuthError::Adapter("rate limit count exceeds SQL BIGINT range".to_owned()))
16}
17
18/// Builds the dialect-specific statement trio used by SQL-backed rate-limit stores.
19pub fn rate_limit_consume_statements(
20    dialect: SqlDialect,
21    table: &str,
22    key: &str,
23    count: &str,
24    last_request: &str,
25) -> Result<SqlRateLimitPlan, RustAuthError> {
26    let table = dialect.quote_identifier(table)?;
27    let key = dialect.quote_identifier(key)?;
28    let count = dialect.quote_identifier(count)?;
29    let last_request = dialect.quote_identifier(last_request)?;
30    let insert_keyword = match dialect {
31        SqlDialect::Postgres | SqlDialect::MySql => "INSERT",
32        SqlDialect::Sqlite => "INSERT OR IGNORE",
33    };
34    let conflict_suffix = match dialect {
35        SqlDialect::Postgres => format!(" ON CONFLICT ({key}) DO NOTHING"),
36        SqlDialect::MySql => String::new(),
37        SqlDialect::Sqlite => String::new(),
38    };
39    let insert_prefix = match dialect {
40        SqlDialect::MySql => "INSERT IGNORE",
41        SqlDialect::Postgres | SqlDialect::Sqlite => insert_keyword,
42    };
43    let lock_suffix = match dialect {
44        SqlDialect::Postgres | SqlDialect::MySql => " FOR UPDATE",
45        SqlDialect::Sqlite => "",
46    };
47
48    Ok(SqlRateLimitPlan {
49        insert_ignore: SqlStatement::new(format!(
50            "{insert_prefix} INTO {table} ({key}, {count}, {last_request}) VALUES ({}, 0, {}){conflict_suffix}",
51            dialect.placeholder(1),
52            dialect.placeholder(2)
53        )),
54        select: SqlStatement::new(format!(
55            "SELECT {count} AS count, {last_request} AS last_request FROM {table} WHERE {key} = {}{lock_suffix}",
56            dialect.placeholder(1)
57        )),
58        update: SqlStatement::new(format!(
59            "UPDATE {table} SET {count} = {}, {last_request} = {} WHERE {key} = {}",
60            dialect.placeholder(1),
61            dialect.placeholder(2),
62            dialect.placeholder(3)
63        )),
64    })
65}
66
67/// Applies RustAuth rate-limit semantics to a locked database record.
68///
69/// SQL adapters share this decision logic after they insert/select the row
70/// inside their own transaction or locking primitive.
71pub fn consume_sql_rate_limit_record(
72    input: RateLimitConsumeInput,
73    existing: Option<RateLimitRecord>,
74) -> Result<(RateLimitDecision, RateLimitRecord, bool), RustAuthError> {
75    let window_ms = validate_rate_limit_rule(&input.rule)?;
76    Ok(match existing {
77        Some(record)
78            if input.now_ms.saturating_sub(record.last_request) <= window_ms
79                && record.count >= input.rule.max =>
80        {
81            let retry_ms = record
82                .last_request
83                .saturating_add(window_ms)
84                .saturating_sub(input.now_ms)
85                .max(0);
86            (
87                RateLimitDecision {
88                    permitted: false,
89                    retry_after: ceil_millis_to_seconds(retry_ms),
90                    limit: input.rule.max,
91                    remaining: 0,
92                    reset_after: ceil_millis_to_seconds(retry_ms),
93                },
94                record,
95                true,
96            )
97        }
98        Some(mut record) if input.now_ms.saturating_sub(record.last_request) <= window_ms => {
99            record.key = input.key;
100            record.count = record.count.saturating_add(1);
101            record.last_request = input.now_ms;
102            let remaining = input.rule.max.saturating_sub(record.count);
103            (
104                RateLimitDecision {
105                    permitted: true,
106                    retry_after: 0,
107                    limit: input.rule.max,
108                    remaining,
109                    reset_after: input.rule.window.whole_seconds() as u64,
110                },
111                record,
112                true,
113            )
114        }
115        Some(mut record) => {
116            record.key = input.key;
117            record.count = 1;
118            record.last_request = input.now_ms;
119            (
120                RateLimitDecision {
121                    permitted: true,
122                    retry_after: 0,
123                    limit: input.rule.max,
124                    remaining: input.rule.max.saturating_sub(1),
125                    reset_after: input.rule.window.whole_seconds() as u64,
126                },
127                record,
128                true,
129            )
130        }
131        None => {
132            let record = RateLimitRecord {
133                key: input.key,
134                count: 1,
135                last_request: input.now_ms,
136            };
137            (
138                RateLimitDecision {
139                    permitted: true,
140                    retry_after: 0,
141                    limit: input.rule.max,
142                    remaining: input.rule.max.saturating_sub(1),
143                    reset_after: input.rule.window.whole_seconds() as u64,
144                },
145                record,
146                false,
147            )
148        }
149    })
150}
151
/// Converts a millisecond span to whole seconds, rounding any partial second
/// up. Zero and negative spans map to `0`.
fn ceil_millis_to_seconds(milliseconds: i64) -> u64 {
    // Negative inputs fail the conversion and fall back to 0. A zero input
    // converts to 0 and stays 0 after rounding.
    u64::try_from(milliseconds).map_or(0, |ms| ms.saturating_add(999) / 1000)
}