rustauth_core/db/sql/
rate_limit.rs1use super::*;
2use crate::options::validate_rate_limit_rule;
3
4pub fn rate_limit_count_from_i64(count: i64) -> Result<u64, RustAuthError> {
7 u64::try_from(count).map_err(|_| {
8 RustAuthError::Adapter("negative rate limit count persisted in database".to_owned())
9 })
10}
11
12pub fn rate_limit_count_to_i64(count: u64) -> Result<i64, RustAuthError> {
14 i64::try_from(count)
15 .map_err(|_| RustAuthError::Adapter("rate limit count exceeds SQL BIGINT range".to_owned()))
16}
17
18pub fn rate_limit_consume_statements(
20 dialect: SqlDialect,
21 table: &str,
22 key: &str,
23 count: &str,
24 last_request: &str,
25) -> Result<SqlRateLimitPlan, RustAuthError> {
26 let table = dialect.quote_identifier(table)?;
27 let key = dialect.quote_identifier(key)?;
28 let count = dialect.quote_identifier(count)?;
29 let last_request = dialect.quote_identifier(last_request)?;
30 let insert_keyword = match dialect {
31 SqlDialect::Postgres | SqlDialect::MySql => "INSERT",
32 SqlDialect::Sqlite => "INSERT OR IGNORE",
33 };
34 let conflict_suffix = match dialect {
35 SqlDialect::Postgres => format!(" ON CONFLICT ({key}) DO NOTHING"),
36 SqlDialect::MySql => String::new(),
37 SqlDialect::Sqlite => String::new(),
38 };
39 let insert_prefix = match dialect {
40 SqlDialect::MySql => "INSERT IGNORE",
41 SqlDialect::Postgres | SqlDialect::Sqlite => insert_keyword,
42 };
43 let lock_suffix = match dialect {
44 SqlDialect::Postgres | SqlDialect::MySql => " FOR UPDATE",
45 SqlDialect::Sqlite => "",
46 };
47
48 Ok(SqlRateLimitPlan {
49 insert_ignore: SqlStatement::new(format!(
50 "{insert_prefix} INTO {table} ({key}, {count}, {last_request}) VALUES ({}, 0, {}){conflict_suffix}",
51 dialect.placeholder(1),
52 dialect.placeholder(2)
53 )),
54 select: SqlStatement::new(format!(
55 "SELECT {count} AS count, {last_request} AS last_request FROM {table} WHERE {key} = {}{lock_suffix}",
56 dialect.placeholder(1)
57 )),
58 update: SqlStatement::new(format!(
59 "UPDATE {table} SET {count} = {}, {last_request} = {} WHERE {key} = {}",
60 dialect.placeholder(1),
61 dialect.placeholder(2),
62 dialect.placeholder(3)
63 )),
64 })
65}
66
67pub fn consume_sql_rate_limit_record(
72 input: RateLimitConsumeInput,
73 existing: Option<RateLimitRecord>,
74) -> Result<(RateLimitDecision, RateLimitRecord, bool), RustAuthError> {
75 let window_ms = validate_rate_limit_rule(&input.rule)?;
76 Ok(match existing {
77 Some(record)
78 if input.now_ms.saturating_sub(record.last_request) <= window_ms
79 && record.count >= input.rule.max =>
80 {
81 let retry_ms = record
82 .last_request
83 .saturating_add(window_ms)
84 .saturating_sub(input.now_ms)
85 .max(0);
86 (
87 RateLimitDecision {
88 permitted: false,
89 retry_after: ceil_millis_to_seconds(retry_ms),
90 limit: input.rule.max,
91 remaining: 0,
92 reset_after: ceil_millis_to_seconds(retry_ms),
93 },
94 record,
95 true,
96 )
97 }
98 Some(mut record) if input.now_ms.saturating_sub(record.last_request) <= window_ms => {
99 record.key = input.key;
100 record.count = record.count.saturating_add(1);
101 record.last_request = input.now_ms;
102 let remaining = input.rule.max.saturating_sub(record.count);
103 (
104 RateLimitDecision {
105 permitted: true,
106 retry_after: 0,
107 limit: input.rule.max,
108 remaining,
109 reset_after: input.rule.window.whole_seconds() as u64,
110 },
111 record,
112 true,
113 )
114 }
115 Some(mut record) => {
116 record.key = input.key;
117 record.count = 1;
118 record.last_request = input.now_ms;
119 (
120 RateLimitDecision {
121 permitted: true,
122 retry_after: 0,
123 limit: input.rule.max,
124 remaining: input.rule.max.saturating_sub(1),
125 reset_after: input.rule.window.whole_seconds() as u64,
126 },
127 record,
128 true,
129 )
130 }
131 None => {
132 let record = RateLimitRecord {
133 key: input.key,
134 count: 1,
135 last_request: input.now_ms,
136 };
137 (
138 RateLimitDecision {
139 permitted: true,
140 retry_after: 0,
141 limit: input.rule.max,
142 remaining: input.rule.max.saturating_sub(1),
143 reset_after: input.rule.window.whole_seconds() as u64,
144 },
145 record,
146 false,
147 )
148 }
149 })
150}
151
/// Converts a duration in milliseconds to whole seconds, rounding up.
///
/// Zero and negative inputs (for example an already-elapsed deadline) map to `0`.
fn ceil_millis_to_seconds(milliseconds: i64) -> u64 {
    match u64::try_from(milliseconds) {
        // Any partial second counts as a full one.
        Ok(ms) if ms > 0 => ms / 1000 + u64::from(ms % 1000 != 0),
        _ => 0,
    }
}