Skip to main content

rustauth_core/
rate_limit.rs

1//! Router-level rate limiting.
2
3use crate::context::AuthContext;
4use crate::env::allows_development_defaults;
5use crate::error::RustAuthError;
6use crate::options::{
7    validate_rate_limit_rule, MissingIpPolicy, RateLimitConsumeInput, RateLimitDecision,
8    RateLimitFuture, RateLimitRecord, RateLimitRule, RateLimitStorage, RateLimitStorageOption,
9    RateLimitStore,
10};
11use crate::utils::ip::{
12    create_rate_limit_key, create_rate_limit_key_with_suffix, is_valid_ip,
13    normalize_ip_with_options, NormalizeIpOptions,
14};
15use crate::utils::url::normalize_pathname;
16use hmac::{Hmac, Mac};
17use http::Request;
18use sha2::Sha256;
19use std::collections::HashMap;
20use std::net::IpAddr;
21use std::sync::{Arc, Mutex};
22use std::time::{Duration, Instant};
23
/// Body type of the `http::Request`s accepted by this module's rate-limit
/// functions: a fully buffered byte vector.
pub type Body = Vec<u8>;
25
/// Returned by this module's consume/check functions when a request must be
/// rejected by rate limiting.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitRejection {
    /// Suggested wait before retrying, in whole seconds.
    pub retry_after: u64,
}
30
/// Framework-neutral client IP resolved by an HTTP adapter.
///
/// Adapters insert this into the request's extensions. [`resolve_client_ip`]
/// falls back to it when none of the configured `advanced.ip_address.headers`
/// yields a valid IP.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RequestClientIp(pub IpAddr);
34
35#[derive(Default)]
36pub struct GovernorMemoryRateLimitStore {
37    records: Mutex<HashMap<String, MemoryRateLimitRecord>>,
38    cleanup_interval: Option<Duration>,
39    last_cleanup: Mutex<Option<Instant>>,
40}
41
42#[derive(Debug, Clone)]
43struct MemoryRateLimitRecord {
44    count: u64,
45    last_request: i64,
46    window_ms: i64,
47}
48
49impl GovernorMemoryRateLimitStore {
50    pub fn new() -> Self {
51        Self::with_cleanup_interval(Some(Duration::from_secs(60 * 60)))
52    }
53
54    pub fn with_cleanup_interval(cleanup_interval: Option<Duration>) -> Self {
55        Self {
56            records: Mutex::new(HashMap::new()),
57            cleanup_interval,
58            last_cleanup: Mutex::new(None),
59        }
60    }
61
62    fn cleanup_if_due(&self, now_ms: i64) -> Result<(), RustAuthError> {
63        let Some(interval) = self.cleanup_interval else {
64            return Ok(());
65        };
66
67        let mut last_cleanup =
68            self.last_cleanup
69                .lock()
70                .map_err(|_| RustAuthError::LockPoisoned {
71                    context: "rate limit cleanup",
72                })?;
73        let now = Instant::now();
74        if last_cleanup
75            .as_ref()
76            .is_some_and(|last| last.elapsed() < interval)
77        {
78            return Ok(());
79        }
80        *last_cleanup = Some(now);
81        drop(last_cleanup);
82
83        self.records
84            .lock()
85            .map_err(|_| RustAuthError::LockPoisoned {
86                context: "rate limit store",
87            })?
88            .retain(|_, record| now_ms.saturating_sub(record.last_request) <= record.window_ms);
89        Ok(())
90    }
91}
92
93impl RateLimitStore for GovernorMemoryRateLimitStore {
94    fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
95        Box::pin(async move {
96            let window_ms = validate_rate_limit_rule(&input.rule)?;
97            self.cleanup_if_due(input.now_ms)?;
98            let mut records = self
99                .records
100                .lock()
101                .map_err(|_| RustAuthError::LockPoisoned {
102                    context: "rate limit store",
103                })?;
104            let decision = match records.get_mut(&input.key) {
105                Some(record)
106                    if input.now_ms.saturating_sub(record.last_request) <= window_ms
107                        && record.count >= input.rule.max =>
108                {
109                    denied_decision(&input, record.last_request)
110                }
111                Some(record) if input.now_ms.saturating_sub(record.last_request) <= window_ms => {
112                    record.count = record.count.saturating_add(1);
113                    record.last_request = input.now_ms;
114                    record.window_ms = window_ms;
115                    allowed_decision(&input, record.count)
116                }
117                _ => {
118                    records.insert(
119                        input.key.clone(),
120                        MemoryRateLimitRecord {
121                            count: 1,
122                            last_request: input.now_ms,
123                            window_ms,
124                        },
125                    );
126                    allowed_decision(&input, 1)
127                }
128            };
129            Ok(decision)
130        })
131    }
132}
133
134pub struct LegacyRateLimitStorageAdapter {
135    storage: Arc<dyn RateLimitStorage>,
136}
137
138impl LegacyRateLimitStorageAdapter {
139    pub fn new(storage: Arc<dyn RateLimitStorage>) -> Self {
140        Self { storage }
141    }
142}
143
144impl RateLimitStore for LegacyRateLimitStorageAdapter {
145    fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
146        Box::pin(async move {
147            let window_ms = validate_rate_limit_rule(&input.rule)?;
148            let existing = self.storage.get(&input.key)?;
149            let decision = match existing {
150                Some(record)
151                    if input.now_ms.saturating_sub(record.last_request) <= window_ms
152                        && record.count >= input.rule.max =>
153                {
154                    denied_decision(&input, record.last_request)
155                }
156                Some(record) if input.now_ms.saturating_sub(record.last_request) <= window_ms => {
157                    let next_count = record.count.saturating_add(1);
158                    self.storage.set(
159                        &input.key,
160                        RateLimitRecord {
161                            key: input.key.clone(),
162                            count: next_count,
163                            last_request: input.now_ms,
164                        },
165                        input.rule.window.whole_seconds() as u64,
166                        true,
167                    )?;
168                    allowed_decision(&input, next_count)
169                }
170                _ => {
171                    self.storage.set(
172                        &input.key,
173                        RateLimitRecord {
174                            key: input.key.clone(),
175                            count: 1,
176                            last_request: input.now_ms,
177                        },
178                        input.rule.window.whole_seconds() as u64,
179                        existing.is_some(),
180                    )?;
181                    allowed_decision(&input, 1)
182                }
183            };
184            Ok(decision)
185        })
186    }
187}
188
189pub struct HybridRateLimitStore {
190    local: Arc<GovernorMemoryRateLimitStore>,
191    global: Arc<dyn RateLimitStore>,
192    local_multiplier: u64,
193}
194
195impl HybridRateLimitStore {
196    pub fn new(
197        local: Arc<GovernorMemoryRateLimitStore>,
198        global: Arc<dyn RateLimitStore>,
199        local_multiplier: u64,
200    ) -> Self {
201        Self {
202            local,
203            global,
204            local_multiplier: local_multiplier.max(1),
205        }
206    }
207}
208
209impl RateLimitStore for HybridRateLimitStore {
210    fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
211        Box::pin(async move {
212            let local_input = RateLimitConsumeInput {
213                key: input.key.clone(),
214                rule: RateLimitRule {
215                    window: input.rule.window,
216                    max: input.rule.max.saturating_mul(self.local_multiplier).max(1),
217                },
218                now_ms: input.now_ms,
219            };
220            let local = self.local.consume(local_input).await?;
221            if !local.permitted {
222                return Ok(local);
223            }
224            self.global.consume(input).await
225        })
226    }
227}
228
229/// Derive a stable, non-reversible rate-limit scope identifier.
230///
231/// Uses `HMAC-SHA256(secret, scope)` hex-encoded so storage keys never contain
232/// raw challenge tokens or other client-controlled secrets.
233pub fn hash_rate_limit_scope(secret: &str, scope: &str) -> Result<String, RustAuthError> {
234    let mut mac = Hmac::<Sha256>::new_from_slice(secret.as_bytes()).map_err(|_| {
235        RustAuthError::InvalidConfig("secret is invalid for rate limit scope HMAC".to_owned())
236    })?;
237    mac.update(scope.as_bytes());
238    Ok(hex::encode(mac.finalize().into_bytes()))
239}
240
241/// Consume a rate-limit bucket keyed by client IP, path, and an opaque scope.
242///
243/// Scope values are digested with [`hash_rate_limit_scope`] before being stored.
244/// Returns `None` when rate limiting is disabled, the request is permitted, or no
245/// client IP can be resolved under the configured [`MissingIpPolicy::Allow`] policy.
246pub async fn consume_scoped_rate_limit(
247    context: &AuthContext,
248    request: &Request<Body>,
249    path: &str,
250    scope: &str,
251    rule: RateLimitRule,
252) -> Result<Option<RateLimitRejection>, RustAuthError> {
253    if !context.rate_limit.enabled {
254        return Ok(None);
255    }
256    let scope_suffix = format!(
257        "challenge:{}",
258        hash_rate_limit_scope(&context.secret, scope)?
259    );
260    let key = match resolve_rate_limit_key_plan(context, request, path, Some(&scope_suffix))? {
261        RateLimitKeyPlan::Skip => return Ok(None),
262        RateLimitKeyPlan::Deny => {
263            return Ok(Some(RateLimitRejection {
264                retry_after: rule.window.whole_seconds() as u64,
265            }));
266        }
267        RateLimitKeyPlan::Consume { key } => key,
268    };
269    consume_rate_limit_bucket(context, key, rule).await
270}
271
272pub async fn consume_rate_limit(
273    context: &AuthContext,
274    request: &Request<Body>,
275) -> Result<Option<RateLimitRejection>, RustAuthError> {
276    if !context.rate_limit.enabled {
277        return Ok(None);
278    }
279    let config = match resolve_config(context, request)? {
280        RateLimitPlan::Skip => return Ok(None),
281        RateLimitPlan::Deny { retry_after } => {
282            return Ok(Some(RateLimitRejection { retry_after }));
283        }
284        RateLimitPlan::Consume(config) => config,
285    };
286    consume_rate_limit_bucket(context, config.key, config.rule).await
287}
288
289async fn consume_rate_limit_bucket(
290    context: &AuthContext,
291    key: String,
292    rule: RateLimitRule,
293) -> Result<Option<RateLimitRejection>, RustAuthError> {
294    let store = store(context)?;
295    let decision = store
296        .consume(RateLimitConsumeInput {
297            key,
298            rule,
299            now_ms: now_millis(),
300        })
301        .await?;
302    if decision.permitted {
303        return Ok(None);
304    }
305    Ok(Some(RateLimitRejection {
306        retry_after: decision.retry_after,
307    }))
308}
309
310pub fn on_request_rate_limit(
311    context: &AuthContext,
312    request: &Request<Body>,
313) -> Result<Option<RateLimitRejection>, RustAuthError> {
314    if !context.rate_limit.enabled {
315        return Ok(None);
316    }
317    match resolve_config(context, request)? {
318        RateLimitPlan::Skip => Ok(None),
319        RateLimitPlan::Deny { retry_after } => Ok(Some(RateLimitRejection { retry_after })),
320        RateLimitPlan::Consume(_) => Err(RustAuthError::Api(
321            "async rate limit storage requires AuthRouter::handle_async".to_owned(),
322        )),
323    }
324}
325
326pub fn on_response_rate_limit(
327    _context: &AuthContext,
328    _request: &Request<Body>,
329) -> Result<(), RustAuthError> {
330    Ok(())
331}
332
/// A request that should be counted: the bucket key plus the rule to enforce.
#[derive(Debug)]
struct ResolvedRateLimit {
    /// Storage key for the bucket. It is built from the client IP (or the
    /// shared missing-IP segment) and the normalized path.
    key: String,
    /// Window and maximum request count to enforce for this bucket.
    rule: RateLimitRule,
}
338
/// Outcome of resolving how a request should be rate limited.
enum RateLimitPlan {
    /// No applicable rule, or no bucket key applies. The latter happens when
    /// IP tracking is disabled, or when no client IP was resolved under
    /// `MissingIpPolicy::Allow`.
    Skip,
    /// A rule applies, but no client IP could be resolved and the missing-IP
    /// policy is `Deny`. `retry_after` is the rule's window in whole seconds.
    Deny { retry_after: u64 },
    /// Consume the resolved rule against the resolved bucket key.
    Consume(ResolvedRateLimit),
}
348
/// Outcome of resolving a rate-limit bucket key.
enum RateLimitKeyPlan {
    /// No bucket applies; the request is not counted.
    Skip,
    /// No client IP was resolved and the missing-IP policy rejects the request.
    Deny,
    /// Count the request against `key`.
    Consume { key: String },
}
355
/// IP segment of the bucket key used when no client IP can be resolved and
/// the configured policy is [`MissingIpPolicy::SharedBucket`]. All such
/// requests for the same path share one bucket.
///
/// This segment is not a valid IP, so it never collides with a real per-IP
/// bucket.
const ANONYMOUS_IP_BUCKET: &str = "missing-ip";
360
361fn resolve_config(
362    context: &AuthContext,
363    request: &Request<Body>,
364) -> Result<RateLimitPlan, RustAuthError> {
365    let path = normalize_pathname(&request.uri().to_string(), &context.base_path);
366    let Some(rule) = resolve_rule(context, request, &path)? else {
367        return Ok(RateLimitPlan::Skip);
368    };
369    match resolve_rate_limit_key_plan(context, request, &path, None)? {
370        RateLimitKeyPlan::Skip => Ok(RateLimitPlan::Skip),
371        RateLimitKeyPlan::Deny => Ok(RateLimitPlan::Deny {
372            retry_after: rule.window.whole_seconds() as u64,
373        }),
374        RateLimitKeyPlan::Consume { key } => {
375            Ok(RateLimitPlan::Consume(ResolvedRateLimit { key, rule }))
376        }
377    }
378}
379
380fn resolve_rate_limit_key_plan(
381    context: &AuthContext,
382    request: &Request<Body>,
383    path: &str,
384    key_suffix: Option<&str>,
385) -> Result<RateLimitKeyPlan, RustAuthError> {
386    if let Some(ip) = resolve_client_ip(context, request) {
387        let key = match key_suffix {
388            Some(suffix) => create_rate_limit_key_with_suffix(&ip, path, suffix),
389            None => create_rate_limit_key(&ip, path),
390        };
391        return Ok(RateLimitKeyPlan::Consume { key });
392    }
393    // No client IP could be resolved. When IP tracking is intentionally
394    // disabled, per-IP limiting cannot apply, so skip. Otherwise apply the
395    // configured fail-closed policy instead of silently bypassing the limit.
396    if context.options.advanced.ip_address.disable_ip_tracking {
397        return Ok(RateLimitKeyPlan::Skip);
398    }
399    match context.rate_limit.missing_ip_policy {
400        MissingIpPolicy::Allow => Ok(RateLimitKeyPlan::Skip),
401        MissingIpPolicy::SharedBucket => {
402            let key = match key_suffix {
403                Some(suffix) => {
404                    create_rate_limit_key_with_suffix(ANONYMOUS_IP_BUCKET, path, suffix)
405                }
406                None => create_rate_limit_key(ANONYMOUS_IP_BUCKET, path),
407            };
408            Ok(RateLimitKeyPlan::Consume { key })
409        }
410        MissingIpPolicy::Deny => {
411            context.logger.warn(
412                "Rate limiting denied a request because no client IP could be resolved; inject RequestClientIp or set advanced.ip_address.headers",
413                &[path],
414            );
415            Ok(RateLimitKeyPlan::Deny)
416        }
417    }
418}
419
420fn resolve_rule(
421    context: &AuthContext,
422    request: &Request<Body>,
423    path: &str,
424) -> Result<Option<RateLimitRule>, RustAuthError> {
425    let mut rule = default_rule(context);
426    if let Some(special_rule) = default_special_rule(path) {
427        rule = special_rule;
428    }
429    for plugin_rule in &context.rate_limit.plugin_rules {
430        if path_matches(&plugin_rule.path, path) {
431            rule = plugin_rule.rule.clone();
432            break;
433        }
434    }
435    for custom_rule in &context.rate_limit.custom_rules {
436        if path_matches(&custom_rule.path, path) {
437            return Ok(custom_rule.rule.clone());
438        }
439    }
440    for dynamic_rule in &context.rate_limit.dynamic_rules {
441        if path_matches(&dynamic_rule.path, path) {
442            return dynamic_rule.provider.resolve(request, &rule);
443        }
444    }
445    Ok(Some(rule))
446}
447
448fn default_rule(context: &AuthContext) -> RateLimitRule {
449    RateLimitRule {
450        window: context.rate_limit.window,
451        max: context.rate_limit.max,
452    }
453}
454
455fn default_special_rule(path: &str) -> Option<RateLimitRule> {
456    if path.starts_with("/sign-in")
457        || path.starts_with("/sign-up")
458        || path.starts_with("/change-password")
459        || path.starts_with("/change-email")
460    {
461        return Some(RateLimitRule {
462            window: time::Duration::seconds(10),
463            max: 3,
464        });
465    }
466    if path == "/request-password-reset"
467        || path == "/send-verification-email"
468        || path.starts_with("/forget-password")
469        || path == "/email-otp/send-verification-otp"
470        || path == "/email-otp/request-password-reset"
471    {
472        return Some(RateLimitRule {
473            window: time::Duration::seconds(60),
474            max: 3,
475        });
476    }
477    None
478}
479
480/// Resolve the trusted client IP for a request using `advanced.ip_address`
481/// configuration. Shared by rate limiting and request metadata so the two
482/// never disagree about the same request. Returns `None` when no trusted IP
483/// can be resolved instead of trusting raw forwarding headers.
484///
485/// Exposed so plugin crates that create sessions outside the core auth flows
486/// (e.g. passkey login) persist the same validated client IP rather than
487/// trusting raw forwarding headers.
488pub fn resolve_client_ip(context: &AuthContext, request: &Request<Body>) -> Option<String> {
489    if context.options.advanced.ip_address.disable_ip_tracking {
490        return None;
491    }
492
493    for header_name in &context.options.advanced.ip_address.headers {
494        if let Some(value) = request
495            .headers()
496            .get(header_name)
497            .and_then(|value| value.to_str().ok())
498        {
499            let Some(candidate) = value.split(',').next().map(str::trim) else {
500                continue;
501            };
502            if is_valid_ip(candidate) {
503                return Some(normalize_ip_with_options(
504                    candidate,
505                    NormalizeIpOptions {
506                        ipv6_subnet: context.options.advanced.ip_address.ipv6_subnet,
507                    },
508                ));
509            }
510        }
511    }
512
513    if let Some(client_ip) = request.extensions().get::<RequestClientIp>() {
514        return Some(normalize_ip_with_options(
515            &client_ip.0.to_string(),
516            NormalizeIpOptions {
517                ipv6_subnet: context.options.advanced.ip_address.ipv6_subnet,
518            },
519        ));
520    }
521
522    if allows_development_defaults(&context.options) {
523        return Some("127.0.0.1".to_owned());
524    }
525
526    None
527}
528
529fn store(context: &AuthContext) -> Result<Arc<dyn RateLimitStore>, RustAuthError> {
530    if let Some(store) = &context.rate_limit.custom_store {
531        if context.rate_limit.hybrid.enabled {
532            return Ok(Arc::new(HybridRateLimitStore::new(
533                Arc::clone(&context.rate_limit.memory_store),
534                Arc::clone(store),
535                context.rate_limit.hybrid.local_multiplier,
536            )));
537        }
538        return Ok(Arc::clone(store));
539    }
540    match context.rate_limit.storage {
541        RateLimitStorageOption::Memory => Ok(context.rate_limit.memory_store.clone()),
542        RateLimitStorageOption::Database => Err(RustAuthError::InvalidConfig(
543            "database rate limit storage requires a concrete RateLimitStore".to_owned(),
544        )),
545        RateLimitStorageOption::SecondaryStorage => Err(RustAuthError::InvalidConfig(
546            "secondary-storage rate limit storage requires a concrete RateLimitStore".to_owned(),
547        )),
548    }
549}
550
551fn allowed_decision(input: &RateLimitConsumeInput, count: u64) -> RateLimitDecision {
552    RateLimitDecision {
553        permitted: true,
554        retry_after: 0,
555        limit: input.rule.max,
556        remaining: input.rule.max.saturating_sub(count),
557        reset_after: input.rule.window.whole_seconds() as u64,
558    }
559}
560
561fn denied_decision(input: &RateLimitConsumeInput, last_request: i64) -> RateLimitDecision {
562    let window_ms = i64::try_from(input.rule.window.whole_milliseconds()).unwrap_or(i64::MAX);
563    let retry_after = last_request
564        .saturating_add(window_ms)
565        .saturating_sub(input.now_ms)
566        .max(0);
567    RateLimitDecision {
568        permitted: false,
569        retry_after: ceil_millis_to_seconds(retry_after),
570        limit: input.rule.max,
571        remaining: 0,
572        reset_after: ceil_millis_to_seconds(retry_after),
573    }
574}
575
/// Convert a millisecond duration to whole seconds, rounding up.
///
/// Zero and negative inputs map to `0`.
fn ceil_millis_to_seconds(milliseconds: i64) -> u64 {
    match u64::try_from(milliseconds) {
        Ok(ms) => ms.saturating_add(999) / 1000,
        Err(_) => 0,
    }
}
582
/// Match `path` against a rule pattern.
///
/// A pattern without `*` must equal `path` exactly. Otherwise, the text
/// before the first `*` must be a prefix of `path`. The text after the `*`
/// must be a suffix of the characters remaining after that prefix, so the
/// prefix and suffix can never overlap: `"/ab*b"` matches `"/abb"` but not
/// `"/ab"`.
///
/// Only the first `*` acts as a wildcard; any later `*` is matched literally.
fn path_matches(pattern: &str, path: &str) -> bool {
    match pattern.split_once('*') {
        Some((prefix, suffix)) => path
            .strip_prefix(prefix)
            .is_some_and(|rest| rest.ends_with(suffix)),
        None => pattern == path,
    }
}
589
590fn now_millis() -> i64 {
591    time::OffsetDateTime::now_utc().unix_timestamp_nanos() as i64 / 1_000_000
592}