1use crate::context::AuthContext;
4use crate::env::allows_development_defaults;
5use crate::error::RustAuthError;
6use crate::options::{
7 validate_rate_limit_rule, MissingIpPolicy, RateLimitConsumeInput, RateLimitDecision,
8 RateLimitFuture, RateLimitRecord, RateLimitRule, RateLimitStorage, RateLimitStorageOption,
9 RateLimitStore,
10};
11use crate::utils::ip::{
12 create_rate_limit_key, create_rate_limit_key_with_suffix, is_valid_ip,
13 normalize_ip_with_options, NormalizeIpOptions,
14};
15use crate::utils::url::normalize_pathname;
16use hmac::{Hmac, Mac};
17use http::Request;
18use sha2::Sha256;
19use std::collections::HashMap;
20use std::net::IpAddr;
21use std::sync::{Arc, Mutex};
22use std::time::{Duration, Instant};
23
/// Request body type used by the `http::Request` values handled in this module.
pub type Body = Vec<u8>;
25
/// Returned to callers when a request is rejected by rate limiting.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitRejection {
    /// Suggested wait before retrying, in seconds. Filled either from a store
    /// decision's `retry_after` or from a rule window expressed in whole seconds.
    pub retry_after: u64,
}
30
/// Client IP address that can be attached to a request's extensions.
///
/// `resolve_client_ip` reads it when none of the configured headers yields a valid IP.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RequestClientIp(pub IpAddr);
34
/// In-memory rate limit store that keeps one counter record per key behind a mutex.
///
/// NOTE(review): the derived `Default` leaves `cleanup_interval` as `None`, so a
/// defaulted store never evicts stale records, whereas `new()` configures an hourly
/// cleanup. Confirm that this difference is intended.
#[derive(Default)]
pub struct GovernorMemoryRateLimitStore {
    /// Counter records indexed by rate limit key.
    records: Mutex<HashMap<String, MemoryRateLimitRecord>>,
    /// Minimum time between eviction passes; `None` disables eviction entirely.
    cleanup_interval: Option<Duration>,
    /// Start time of the most recent eviction pass; `None` until the first pass runs.
    last_cleanup: Mutex<Option<Instant>>,
}
41
/// Counter state stored per key by the in-memory store.
#[derive(Debug, Clone)]
struct MemoryRateLimitRecord {
    /// Number of requests counted so far for the key.
    count: u64,
    /// Millisecond timestamp (`now_ms` of the consume input) of the last counted request.
    last_request: i64,
    /// Window length in milliseconds that applied to the last counted request;
    /// the cleanup pass uses it to decide whether the record is stale.
    window_ms: i64,
}
48
49impl GovernorMemoryRateLimitStore {
50 pub fn new() -> Self {
51 Self::with_cleanup_interval(Some(Duration::from_secs(60 * 60)))
52 }
53
54 pub fn with_cleanup_interval(cleanup_interval: Option<Duration>) -> Self {
55 Self {
56 records: Mutex::new(HashMap::new()),
57 cleanup_interval,
58 last_cleanup: Mutex::new(None),
59 }
60 }
61
62 fn cleanup_if_due(&self, now_ms: i64) -> Result<(), RustAuthError> {
63 let Some(interval) = self.cleanup_interval else {
64 return Ok(());
65 };
66
67 let mut last_cleanup =
68 self.last_cleanup
69 .lock()
70 .map_err(|_| RustAuthError::LockPoisoned {
71 context: "rate limit cleanup",
72 })?;
73 let now = Instant::now();
74 if last_cleanup
75 .as_ref()
76 .is_some_and(|last| last.elapsed() < interval)
77 {
78 return Ok(());
79 }
80 *last_cleanup = Some(now);
81 drop(last_cleanup);
82
83 self.records
84 .lock()
85 .map_err(|_| RustAuthError::LockPoisoned {
86 context: "rate limit store",
87 })?
88 .retain(|_, record| now_ms.saturating_sub(record.last_request) <= record.window_ms);
89 Ok(())
90 }
91}
92
93impl RateLimitStore for GovernorMemoryRateLimitStore {
94 fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
95 Box::pin(async move {
96 let window_ms = validate_rate_limit_rule(&input.rule)?;
97 self.cleanup_if_due(input.now_ms)?;
98 let mut records = self
99 .records
100 .lock()
101 .map_err(|_| RustAuthError::LockPoisoned {
102 context: "rate limit store",
103 })?;
104 let decision = match records.get_mut(&input.key) {
105 Some(record)
106 if input.now_ms.saturating_sub(record.last_request) <= window_ms
107 && record.count >= input.rule.max =>
108 {
109 denied_decision(&input, record.last_request)
110 }
111 Some(record) if input.now_ms.saturating_sub(record.last_request) <= window_ms => {
112 record.count = record.count.saturating_add(1);
113 record.last_request = input.now_ms;
114 record.window_ms = window_ms;
115 allowed_decision(&input, record.count)
116 }
117 _ => {
118 records.insert(
119 input.key.clone(),
120 MemoryRateLimitRecord {
121 count: 1,
122 last_request: input.now_ms,
123 window_ms,
124 },
125 );
126 allowed_decision(&input, 1)
127 }
128 };
129 Ok(decision)
130 })
131 }
132}
133
/// Adapter that exposes a get/set style `RateLimitStorage` backend as a `RateLimitStore`.
pub struct LegacyRateLimitStorageAdapter {
    /// Wrapped legacy storage backend.
    storage: Arc<dyn RateLimitStorage>,
}
137
138impl LegacyRateLimitStorageAdapter {
139 pub fn new(storage: Arc<dyn RateLimitStorage>) -> Self {
140 Self { storage }
141 }
142}
143
144impl RateLimitStore for LegacyRateLimitStorageAdapter {
145 fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
146 Box::pin(async move {
147 let window_ms = validate_rate_limit_rule(&input.rule)?;
148 let existing = self.storage.get(&input.key)?;
149 let decision = match existing {
150 Some(record)
151 if input.now_ms.saturating_sub(record.last_request) <= window_ms
152 && record.count >= input.rule.max =>
153 {
154 denied_decision(&input, record.last_request)
155 }
156 Some(record) if input.now_ms.saturating_sub(record.last_request) <= window_ms => {
157 let next_count = record.count.saturating_add(1);
158 self.storage.set(
159 &input.key,
160 RateLimitRecord {
161 key: input.key.clone(),
162 count: next_count,
163 last_request: input.now_ms,
164 },
165 input.rule.window.whole_seconds() as u64,
166 true,
167 )?;
168 allowed_decision(&input, next_count)
169 }
170 _ => {
171 self.storage.set(
172 &input.key,
173 RateLimitRecord {
174 key: input.key.clone(),
175 count: 1,
176 last_request: input.now_ms,
177 },
178 input.rule.window.whole_seconds() as u64,
179 existing.is_some(),
180 )?;
181 allowed_decision(&input, 1)
182 }
183 };
184 Ok(decision)
185 })
186 }
187}
188
/// Store that checks a local in-memory limiter before consulting a global store.
pub struct HybridRateLimitStore {
    /// In-process limiter consulted first.
    local: Arc<GovernorMemoryRateLimitStore>,
    /// Store consulted only when the local limiter permits the request.
    global: Arc<dyn RateLimitStore>,
    /// Factor applied to a rule's `max` for the local check; never below 1.
    local_multiplier: u64,
}
194
195impl HybridRateLimitStore {
196 pub fn new(
197 local: Arc<GovernorMemoryRateLimitStore>,
198 global: Arc<dyn RateLimitStore>,
199 local_multiplier: u64,
200 ) -> Self {
201 Self {
202 local,
203 global,
204 local_multiplier: local_multiplier.max(1),
205 }
206 }
207}
208
209impl RateLimitStore for HybridRateLimitStore {
210 fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
211 Box::pin(async move {
212 let local_input = RateLimitConsumeInput {
213 key: input.key.clone(),
214 rule: RateLimitRule {
215 window: input.rule.window,
216 max: input.rule.max.saturating_mul(self.local_multiplier).max(1),
217 },
218 now_ms: input.now_ms,
219 };
220 let local = self.local.consume(local_input).await?;
221 if !local.permitted {
222 return Ok(local);
223 }
224 self.global.consume(input).await
225 })
226 }
227}
228
229pub fn hash_rate_limit_scope(secret: &str, scope: &str) -> Result<String, RustAuthError> {
234 let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes()).map_err(|_| {
235 RustAuthError::InvalidConfig("secret is invalid for rate limit scope HMAC".to_owned())
236 })?;
237 mac.update(scope.as_bytes());
238 Ok(hex::encode(mac.finalize().into_bytes()))
239}
240
241pub async fn consume_scoped_rate_limit(
247 context: &AuthContext,
248 request: &Request<Body>,
249 path: &str,
250 scope: &str,
251 rule: RateLimitRule,
252) -> Result<Option<RateLimitRejection>, RustAuthError> {
253 if !context.rate_limit.enabled {
254 return Ok(None);
255 }
256 let scope_suffix = format!(
257 "challenge:{}",
258 hash_rate_limit_scope(&context.secret, scope)?
259 );
260 let key = match resolve_rate_limit_key_plan(context, request, path, Some(&scope_suffix))? {
261 RateLimitKeyPlan::Skip => return Ok(None),
262 RateLimitKeyPlan::Deny => {
263 return Ok(Some(RateLimitRejection {
264 retry_after: rule.window.whole_seconds() as u64,
265 }));
266 }
267 RateLimitKeyPlan::Consume { key } => key,
268 };
269 consume_rate_limit_bucket(context, key, rule).await
270}
271
272pub async fn consume_rate_limit(
273 context: &AuthContext,
274 request: &Request<Body>,
275) -> Result<Option<RateLimitRejection>, RustAuthError> {
276 if !context.rate_limit.enabled {
277 return Ok(None);
278 }
279 let config = match resolve_config(context, request)? {
280 RateLimitPlan::Skip => return Ok(None),
281 RateLimitPlan::Deny { retry_after } => {
282 return Ok(Some(RateLimitRejection { retry_after }));
283 }
284 RateLimitPlan::Consume(config) => config,
285 };
286 consume_rate_limit_bucket(context, config.key, config.rule).await
287}
288
289async fn consume_rate_limit_bucket(
290 context: &AuthContext,
291 key: String,
292 rule: RateLimitRule,
293) -> Result<Option<RateLimitRejection>, RustAuthError> {
294 let store = store(context)?;
295 let decision = store
296 .consume(RateLimitConsumeInput {
297 key,
298 rule,
299 now_ms: now_millis(),
300 })
301 .await?;
302 if decision.permitted {
303 return Ok(None);
304 }
305 Ok(Some(RateLimitRejection {
306 retry_after: decision.retry_after,
307 }))
308}
309
310pub fn on_request_rate_limit(
311 context: &AuthContext,
312 request: &Request<Body>,
313) -> Result<Option<RateLimitRejection>, RustAuthError> {
314 if !context.rate_limit.enabled {
315 return Ok(None);
316 }
317 match resolve_config(context, request)? {
318 RateLimitPlan::Skip => Ok(None),
319 RateLimitPlan::Deny { retry_after } => Ok(Some(RateLimitRejection { retry_after })),
320 RateLimitPlan::Consume(_) => Err(RustAuthError::Api(
321 "async rate limit storage requires AuthRouter::handle_async".to_owned(),
322 )),
323 }
324}
325
/// Response hook for rate limiting. Currently a no-op: both arguments are ignored
/// and the function always returns `Ok(())`.
pub fn on_response_rate_limit(
    _context: &AuthContext,
    _request: &Request<Body>,
) -> Result<(), RustAuthError> {
    Ok(())
}
332
/// Bucket key and rule selected for a request that should consume from a store.
#[derive(Debug)]
struct ResolvedRateLimit {
    /// Storage key identifying the bucket.
    key: String,
    /// Rule (window and maximum) to enforce on that bucket.
    rule: RateLimitRule,
}
338
/// What to do with a request after resolving both its rule and its bucket key.
enum RateLimitPlan {
    /// Do not rate limit this request.
    Skip,
    /// Reject the request without consulting a store.
    Deny { retry_after: u64 },
    /// Consume one request from the resolved bucket.
    Consume(ResolvedRateLimit),
}
348
/// Outcome of resolving the bucket key for a request.
enum RateLimitKeyPlan {
    /// No key is produced and the request is not rate limited.
    Skip,
    /// No client IP could be resolved and the missing-IP policy is to deny.
    Deny,
    /// A bucket key was produced and should be consumed from.
    Consume { key: String },
}
355
/// Placeholder used in place of a client IP when building the bucket key under
/// `MissingIpPolicy::SharedBucket`.
const ANONYMOUS_IP_BUCKET: &str = "missing-ip";
360
361fn resolve_config(
362 context: &AuthContext,
363 request: &Request<Body>,
364) -> Result<RateLimitPlan, RustAuthError> {
365 let path = normalize_pathname(&request.uri().to_string(), &context.base_path);
366 let Some(rule) = resolve_rule(context, request, &path)? else {
367 return Ok(RateLimitPlan::Skip);
368 };
369 match resolve_rate_limit_key_plan(context, request, &path, None)? {
370 RateLimitKeyPlan::Skip => Ok(RateLimitPlan::Skip),
371 RateLimitKeyPlan::Deny => Ok(RateLimitPlan::Deny {
372 retry_after: rule.window.whole_seconds() as u64,
373 }),
374 RateLimitKeyPlan::Consume { key } => {
375 Ok(RateLimitPlan::Consume(ResolvedRateLimit { key, rule }))
376 }
377 }
378}
379
380fn resolve_rate_limit_key_plan(
381 context: &AuthContext,
382 request: &Request<Body>,
383 path: &str,
384 key_suffix: Option<&str>,
385) -> Result<RateLimitKeyPlan, RustAuthError> {
386 if let Some(ip) = resolve_client_ip(context, request) {
387 let key = match key_suffix {
388 Some(suffix) => create_rate_limit_key_with_suffix(&ip, path, suffix),
389 None => create_rate_limit_key(&ip, path),
390 };
391 return Ok(RateLimitKeyPlan::Consume { key });
392 }
393 if context.options.advanced.ip_address.disable_ip_tracking {
397 return Ok(RateLimitKeyPlan::Skip);
398 }
399 match context.rate_limit.missing_ip_policy {
400 MissingIpPolicy::Allow => Ok(RateLimitKeyPlan::Skip),
401 MissingIpPolicy::SharedBucket => {
402 let key = match key_suffix {
403 Some(suffix) => {
404 create_rate_limit_key_with_suffix(ANONYMOUS_IP_BUCKET, path, suffix)
405 }
406 None => create_rate_limit_key(ANONYMOUS_IP_BUCKET, path),
407 };
408 Ok(RateLimitKeyPlan::Consume { key })
409 }
410 MissingIpPolicy::Deny => {
411 context.logger.warn(
412 "Rate limiting denied a request because no client IP could be resolved; inject RequestClientIp or set advanced.ip_address.headers",
413 &[path],
414 );
415 Ok(RateLimitKeyPlan::Deny)
416 }
417 }
418}
419
420fn resolve_rule(
421 context: &AuthContext,
422 request: &Request<Body>,
423 path: &str,
424) -> Result<Option<RateLimitRule>, RustAuthError> {
425 let mut rule = default_rule(context);
426 if let Some(special_rule) = default_special_rule(path) {
427 rule = special_rule;
428 }
429 for plugin_rule in &context.rate_limit.plugin_rules {
430 if path_matches(&plugin_rule.path, path) {
431 rule = plugin_rule.rule.clone();
432 break;
433 }
434 }
435 for custom_rule in &context.rate_limit.custom_rules {
436 if path_matches(&custom_rule.path, path) {
437 return Ok(custom_rule.rule.clone());
438 }
439 }
440 for dynamic_rule in &context.rate_limit.dynamic_rules {
441 if path_matches(&dynamic_rule.path, path) {
442 return dynamic_rule.provider.resolve(request, &rule);
443 }
444 }
445 Ok(Some(rule))
446}
447
448fn default_rule(context: &AuthContext) -> RateLimitRule {
449 RateLimitRule {
450 window: context.rate_limit.window,
451 max: context.rate_limit.max,
452 }
453}
454
455fn default_special_rule(path: &str) -> Option<RateLimitRule> {
456 if path.starts_with("/sign-in")
457 || path.starts_with("/sign-up")
458 || path.starts_with("/change-password")
459 || path.starts_with("/change-email")
460 {
461 return Some(RateLimitRule {
462 window: time::Duration::seconds(10),
463 max: 3,
464 });
465 }
466 if path == "/request-password-reset"
467 || path == "/send-verification-email"
468 || path.starts_with("/forget-password")
469 || path == "/email-otp/send-verification-otp"
470 || path == "/email-otp/request-password-reset"
471 {
472 return Some(RateLimitRule {
473 window: time::Duration::seconds(60),
474 max: 3,
475 });
476 }
477 None
478}
479
480pub fn resolve_client_ip(context: &AuthContext, request: &Request<Body>) -> Option<String> {
489 if context.options.advanced.ip_address.disable_ip_tracking {
490 return None;
491 }
492
493 for header_name in &context.options.advanced.ip_address.headers {
494 if let Some(value) = request
495 .headers()
496 .get(header_name)
497 .and_then(|value| value.to_str().ok())
498 {
499 let Some(candidate) = value.split(',').next().map(str::trim) else {
500 continue;
501 };
502 if is_valid_ip(candidate) {
503 return Some(normalize_ip_with_options(
504 candidate,
505 NormalizeIpOptions {
506 ipv6_subnet: context.options.advanced.ip_address.ipv6_subnet,
507 },
508 ));
509 }
510 }
511 }
512
513 if let Some(client_ip) = request.extensions().get::<RequestClientIp>() {
514 return Some(normalize_ip_with_options(
515 &client_ip.0.to_string(),
516 NormalizeIpOptions {
517 ipv6_subnet: context.options.advanced.ip_address.ipv6_subnet,
518 },
519 ));
520 }
521
522 if allows_development_defaults(&context.options) {
523 return Some("127.0.0.1".to_owned());
524 }
525
526 None
527}
528
529fn store(context: &AuthContext) -> Result<Arc<dyn RateLimitStore>, RustAuthError> {
530 if let Some(store) = &context.rate_limit.custom_store {
531 if context.rate_limit.hybrid.enabled {
532 return Ok(Arc::new(HybridRateLimitStore::new(
533 Arc::clone(&context.rate_limit.memory_store),
534 Arc::clone(store),
535 context.rate_limit.hybrid.local_multiplier,
536 )));
537 }
538 return Ok(Arc::clone(store));
539 }
540 match context.rate_limit.storage {
541 RateLimitStorageOption::Memory => Ok(context.rate_limit.memory_store.clone()),
542 RateLimitStorageOption::Database => Err(RustAuthError::InvalidConfig(
543 "database rate limit storage requires a concrete RateLimitStore".to_owned(),
544 )),
545 RateLimitStorageOption::SecondaryStorage => Err(RustAuthError::InvalidConfig(
546 "secondary-storage rate limit storage requires a concrete RateLimitStore".to_owned(),
547 )),
548 }
549}
550
551fn allowed_decision(input: &RateLimitConsumeInput, count: u64) -> RateLimitDecision {
552 RateLimitDecision {
553 permitted: true,
554 retry_after: 0,
555 limit: input.rule.max,
556 remaining: input.rule.max.saturating_sub(count),
557 reset_after: input.rule.window.whole_seconds() as u64,
558 }
559}
560
561fn denied_decision(input: &RateLimitConsumeInput, last_request: i64) -> RateLimitDecision {
562 let window_ms = i64::try_from(input.rule.window.whole_milliseconds()).unwrap_or(i64::MAX);
563 let retry_after = last_request
564 .saturating_add(window_ms)
565 .saturating_sub(input.now_ms)
566 .max(0);
567 RateLimitDecision {
568 permitted: false,
569 retry_after: ceil_millis_to_seconds(retry_after),
570 limit: input.rule.max,
571 remaining: 0,
572 reset_after: ceil_millis_to_seconds(retry_after),
573 }
574}
575
/// Converts milliseconds to seconds, rounding up; zero and negative inputs give 0.
fn ceil_millis_to_seconds(milliseconds: i64) -> u64 {
    match u64::try_from(milliseconds) {
        Ok(positive) if positive > 0 => positive.saturating_add(999) / 1000,
        // Zero, or negative (conversion failed): nothing to wait for.
        _ => 0,
    }
}
582
/// Reports whether `path` matches `pattern`.
///
/// A pattern without `*` must equal the path exactly. Otherwise the first `*` acts
/// as a wildcard for any (possibly empty) run of characters: the path must start
/// with the text before it and the remainder must end with the text after it.
/// Any further `*` characters are treated literally.
fn path_matches(pattern: &str, path: &str) -> bool {
    match pattern.split_once('*') {
        // Test the suffix against what is left after the prefix, so the two
        // literal parts cannot overlap inside a path that is too short.
        Some((prefix, suffix)) => path
            .strip_prefix(prefix)
            .is_some_and(|rest| rest.ends_with(suffix)),
        None => pattern == path,
    }
}
589
590fn now_millis() -> i64 {
591 time::OffsetDateTime::now_utc().unix_timestamp_nanos() as i64 / 1_000_000
592}