Skip to main content

rustauth_core/options/
rate_limit.rs

1use std::fmt;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::time::Duration;
6use time::Duration as TimeDuration;
7
8use http::Request;
9
10use super::model_schema::ModelSchemaOptions;
11use crate::error::RustAuthError;
12
13/// Rate limiting defaults.
14#[derive(Clone)]
15pub struct RateLimitOptions {
16    pub schema: ModelSchemaOptions,
17    pub enabled: Option<bool>,
18    pub window: TimeDuration,
19    pub max: u64,
20    pub storage: RateLimitStorageOption,
21    pub custom_rules: Vec<RateLimitPathRule>,
22    pub dynamic_rules: Vec<DynamicRateLimitPathRule>,
23    pub custom_store: Option<Arc<dyn RateLimitStore>>,
24    pub custom_storage: Option<Arc<dyn RateLimitStorage>>,
25    pub hybrid: HybridRateLimitOptions,
26    pub memory_cleanup_interval: Option<Duration>,
27    pub missing_ip_policy: MissingIpPolicy,
28}
29
30impl Default for RateLimitOptions {
31    fn default() -> Self {
32        Self {
33            schema: ModelSchemaOptions::default(),
34            enabled: None,
35            window: TimeDuration::seconds(10),
36            max: 100,
37            storage: RateLimitStorageOption::Memory,
38            custom_rules: Vec::new(),
39            dynamic_rules: Vec::new(),
40            custom_store: None,
41            custom_storage: None,
42            hybrid: HybridRateLimitOptions::default(),
43            memory_cleanup_interval: Some(Duration::from_secs(60 * 60)),
44            missing_ip_policy: MissingIpPolicy::default(),
45        }
46    }
47}
48
49impl RateLimitOptions {
50    pub fn new() -> Self {
51        Self::default()
52    }
53
54    pub fn builder() -> Self {
55        Self::new()
56    }
57
58    #[must_use]
59    pub fn schema(mut self, schema: ModelSchemaOptions) -> Self {
60        self.schema = schema;
61        self
62    }
63
64    pub fn memory() -> Self {
65        Self {
66            storage: RateLimitStorageOption::Memory,
67            ..Self::default()
68        }
69    }
70
71    pub fn database<S>(store: S) -> Self
72    where
73        S: RateLimitStore,
74    {
75        Self::database_arc(Arc::new(store))
76    }
77
78    pub fn database_arc(store: Arc<dyn RateLimitStore>) -> Self {
79        Self {
80            storage: RateLimitStorageOption::Database,
81            custom_store: Some(store),
82            ..Self::default()
83        }
84    }
85
86    pub fn secondary_storage<S>(store: S) -> Self
87    where
88        S: RateLimitStore,
89    {
90        Self::secondary_storage_arc(Arc::new(store))
91    }
92
93    pub fn secondary_storage_arc(store: Arc<dyn RateLimitStore>) -> Self {
94        Self {
95            storage: RateLimitStorageOption::SecondaryStorage,
96            custom_store: Some(store),
97            ..Self::default()
98        }
99    }
100
101    #[must_use]
102    pub fn enabled(mut self, enabled: bool) -> Self {
103        self.enabled = Some(enabled);
104        self
105    }
106
107    #[must_use]
108    pub fn window(mut self, window: TimeDuration) -> Self {
109        self.window = window;
110        self
111    }
112
113    #[must_use]
114    pub fn max(mut self, max: u64) -> Self {
115        self.max = max;
116        self
117    }
118
119    #[must_use]
120    pub fn storage(mut self, storage: RateLimitStorageOption) -> Self {
121        self.storage = storage;
122        self
123    }
124
125    #[must_use]
126    pub fn custom_store<S>(mut self, store: S) -> Self
127    where
128        S: RateLimitStore,
129    {
130        self.custom_store = Some(Arc::new(store));
131        self
132    }
133
134    #[must_use]
135    pub fn custom_store_arc(mut self, store: Arc<dyn RateLimitStore>) -> Self {
136        self.custom_store = Some(store);
137        self
138    }
139
140    #[must_use]
141    pub fn custom_storage(mut self, storage: Arc<dyn RateLimitStorage>) -> Self {
142        self.custom_storage = Some(storage);
143        self
144    }
145
146    #[must_use]
147    pub fn custom_rule(mut self, path: impl Into<String>, rule: RateLimitRule) -> Self {
148        self.custom_rules.push(RateLimitPathRule {
149            path: path.into(),
150            rule: Some(rule),
151        });
152        self
153    }
154
155    #[must_use]
156    pub fn disabled_path(mut self, path: impl Into<String>) -> Self {
157        self.custom_rules.push(RateLimitPathRule {
158            path: path.into(),
159            rule: None,
160        });
161        self
162    }
163
164    #[must_use]
165    pub fn dynamic_rule<P>(mut self, path: impl Into<String>, provider: P) -> Self
166    where
167        P: RateLimitRuleProvider,
168    {
169        self.dynamic_rules
170            .push(DynamicRateLimitPathRule::new(path, provider));
171        self
172    }
173
174    #[must_use]
175    pub fn hybrid(mut self, hybrid: HybridRateLimitOptions) -> Self {
176        self.hybrid = hybrid;
177        self
178    }
179
180    #[must_use]
181    pub fn memory_cleanup_interval(mut self, interval: Option<Duration>) -> Self {
182        self.memory_cleanup_interval = interval;
183        self
184    }
185
186    #[must_use]
187    pub fn missing_ip_policy(mut self, policy: MissingIpPolicy) -> Self {
188        self.missing_ip_policy = policy;
189        self
190    }
191}
192
193impl fmt::Debug for RateLimitOptions {
194    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
195        formatter
196            .debug_struct("RateLimitOptions")
197            .field("enabled", &self.enabled)
198            .field("window", &self.window)
199            .field("max", &self.max)
200            .field("storage", &self.storage)
201            .field("custom_rules", &self.custom_rules)
202            .field("dynamic_rules", &self.dynamic_rules)
203            .field(
204                "custom_store",
205                &self.custom_store.as_ref().map(|_| "<custom-store>"),
206            )
207            .field(
208                "custom_storage",
209                &self.custom_storage.as_ref().map(|_| "<custom-storage>"),
210            )
211            .field("hybrid", &self.hybrid)
212            .field("memory_cleanup_interval", &self.memory_cleanup_interval)
213            .field("missing_ip_policy", &self.missing_ip_policy)
214            .finish()
215    }
216}
217
218/// A single rate-limit bucket rule.
219///
220/// `window` is the sliding window length in **seconds** (not milliseconds).
221#[derive(Debug, Clone, PartialEq, Eq)]
222pub struct RateLimitRule {
223    /// Sliding window length in seconds.
224    pub window: TimeDuration,
225    pub max: u64,
226}
227
228impl RateLimitRule {
229    pub fn new(window: TimeDuration, max: u64) -> Self {
230        Self { window, max }
231    }
232}
233
234/// Rejects invalid rate-limit rules before any store consumes a record.
235pub fn validate_rate_limit_rule(rule: &RateLimitRule) -> Result<i64, RustAuthError> {
236    if rule.window.is_zero() {
237        return Err(RustAuthError::InvalidConfig(
238            "rate limit window must be greater than zero".to_owned(),
239        ));
240    }
241    if rule.max == 0 {
242        return Err(RustAuthError::InvalidConfig(
243            "rate limit max must be greater than zero".to_owned(),
244        ));
245    }
246    let milliseconds = rule.window.whole_milliseconds();
247    if milliseconds <= 0 {
248        return Err(RustAuthError::InvalidConfig(
249            "rate limit window must be greater than zero".to_owned(),
250        ));
251    }
252    let window_ms = i64::try_from(milliseconds)
253        .map_err(|_| RustAuthError::InvalidConfig("rate limit window is too large".to_owned()))?;
254    i64::try_from(rule.max)
255        .map_err(|_| RustAuthError::InvalidConfig("rate limit max must fit in i64".to_owned()))?;
256    Ok(window_ms)
257}
258
/// Settings for hybrid rate limiting.
///
/// NOTE(review): `local_multiplier` presumably scales a per-process limit
/// relative to the shared limit — confirm against the limiter implementation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HybridRateLimitOptions {
    /// Whether hybrid mode is on (off by default).
    pub enabled: bool,
    /// Multiplier applied to the local limit (2 by default).
    pub local_multiplier: u64,
}

impl Default for HybridRateLimitOptions {
    fn default() -> Self {
        const DEFAULT_LOCAL_MULTIPLIER: u64 = 2;
        Self {
            enabled: false,
            local_multiplier: DEFAULT_LOCAL_MULTIPLIER,
        }
    }
}

impl HybridRateLimitOptions {
    /// Returns the default (disabled) settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Entry point for chained configuration; same as [`Self::new`].
    pub fn builder() -> Self {
        Self::default()
    }

    /// Default settings with hybrid mode switched on.
    pub fn enabled() -> Self {
        Self::default().set_enabled(true)
    }

    /// Default settings, which already have hybrid mode switched off.
    pub fn disabled() -> Self {
        Self::new()
    }

    /// Sets the on/off flag, keeping the multiplier.
    #[must_use]
    pub fn set_enabled(self, enabled: bool) -> Self {
        Self { enabled, ..self }
    }

    /// Sets the local multiplier, keeping the on/off flag.
    #[must_use]
    pub fn local_multiplier(self, multiplier: u64) -> Self {
        Self {
            local_multiplier: multiplier,
            ..self
        }
    }
}
306
307#[derive(Debug, Clone, PartialEq, Eq)]
308pub struct RateLimitPathRule {
309    pub path: String,
310    pub rule: Option<RateLimitRule>,
311}
312
313pub trait RateLimitRuleProvider: Send + Sync + 'static {
314    fn resolve(
315        &self,
316        request: &Request<Vec<u8>>,
317        current_rule: &RateLimitRule,
318    ) -> Result<Option<RateLimitRule>, RustAuthError>;
319}
320
321impl<F> RateLimitRuleProvider for F
322where
323    F: Fn(&Request<Vec<u8>>, &RateLimitRule) -> Result<Option<RateLimitRule>, RustAuthError>
324        + Send
325        + Sync
326        + 'static,
327{
328    fn resolve(
329        &self,
330        request: &Request<Vec<u8>>,
331        current_rule: &RateLimitRule,
332    ) -> Result<Option<RateLimitRule>, RustAuthError> {
333        self(request, current_rule)
334    }
335}
336
337#[derive(Clone)]
338pub struct DynamicRateLimitPathRule {
339    pub path: String,
340    pub provider: Arc<dyn RateLimitRuleProvider>,
341}
342
343impl DynamicRateLimitPathRule {
344    pub fn new<P>(path: impl Into<String>, provider: P) -> Self
345    where
346        P: RateLimitRuleProvider,
347    {
348        Self {
349            path: path.into(),
350            provider: Arc::new(provider),
351        }
352    }
353}
354
355impl fmt::Debug for DynamicRateLimitPathRule {
356    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
357        formatter
358            .debug_struct("DynamicRateLimitPathRule")
359            .field("path", &self.path)
360            .field("provider", &"<request-aware>")
361            .finish()
362    }
363}
364
/// Per-key state kept by a rate-limit storage backend.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitRecord {
    /// Bucket key the record belongs to.
    pub key: String,
    /// Number of requests counted for `key`.
    pub count: u64,
    /// Timestamp of the latest request.
    /// NOTE(review): presumably Unix epoch milliseconds, like
    /// `RateLimitConsumeInput::now_ms` — confirm with the storage code.
    pub last_request: i64,
}
372
373#[derive(Debug, Clone, PartialEq, Eq)]
374pub struct RateLimitConsumeInput {
375    pub key: String,
376    pub rule: RateLimitRule,
377    /// Current time as Unix epoch **milliseconds**.
378    pub now_ms: i64,
379}
380
/// Outcome reported by a rate-limit store for one request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitDecision {
    /// Whether the request may proceed.
    pub permitted: bool,
    /// Seconds until the client may retry when `permitted` is false.
    pub retry_after: u64,
    /// Budget of the rule that was applied (presumably its `max`).
    pub limit: u64,
    /// Requests still available in the current window.
    pub remaining: u64,
    /// Time until the window resets.
    /// NOTE(review): unit not shown here; presumably seconds like `retry_after`.
    pub reset_after: u64,
}
390
391pub type RateLimitFuture<'a> =
392    Pin<Box<dyn Future<Output = Result<RateLimitDecision, RustAuthError>> + Send + 'a>>;
393
394/// Atomic rate limit storage contract.
395///
396/// Implementations must make the check-and-increment decision in one atomic
397/// operation when used for cross-process or distributed enforcement.
398pub trait RateLimitStore: Send + Sync + 'static {
399    fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a>;
400}
401
402/// Synchronous storage contract for router-level rate limiting.
403///
404/// This legacy contract is preserved for compatibility. It is not atomic across
405/// multiple processes unless the implementation makes `get`/`set` externally
406/// serializable.
407pub trait RateLimitStorage: Send + Sync + 'static {
408    fn get(&self, key: &str) -> Result<Option<RateLimitRecord>, RustAuthError>;
409    fn set(
410        &self,
411        key: &str,
412        value: RateLimitRecord,
413        ttl_seconds: u64,
414        update: bool,
415    ) -> Result<(), RustAuthError>;
416}
417
/// Chooses which backend holds rate-limit state.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateLimitStorageOption {
    /// In-process memory (the value selected by `RateLimitOptions::default`).
    Memory,
    /// Database backend; selected by the `RateLimitOptions::database*` constructors.
    Database,
    /// Secondary storage; selected by the `RateLimitOptions::secondary_storage*` constructors.
    SecondaryStorage,
}
425
/// What rate limiting does with a request whose client IP cannot be resolved.
///
/// This exists so that a deployment which turns rate limiting on, but never
/// injects [`RequestClientIp`](crate::rate_limit::RequestClientIp) or
/// configures a trusted IP header, does not end up silently unprotected on
/// auth endpoints. The policy only matters while IP tracking is on; with
/// `advanced.ip_address.disable_ip_tracking` set, per-IP limiting is skipped
/// on purpose.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MissingIpPolicy {
    /// Fail closed and reject the request. This is the default.
    Deny,
    /// Count all IP-less requests together in one shared anonymous bucket
    /// instead of a per-IP bucket.
    SharedBucket,
    /// Fail open and skip rate limiting for the request (legacy behavior).
    Allow,
}

impl Default for MissingIpPolicy {
    /// Secure default: [`MissingIpPolicy::Deny`].
    fn default() -> Self {
        Self::Deny
    }
}