1use std::fmt;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::time::Duration;
6use time::Duration as TimeDuration;
7
8use http::Request;
9
10use super::model_schema::ModelSchemaOptions;
11use crate::error::RustAuthError;
12
13#[derive(Clone)]
15pub struct RateLimitOptions {
16 pub schema: ModelSchemaOptions,
17 pub enabled: Option<bool>,
18 pub window: TimeDuration,
19 pub max: u64,
20 pub storage: RateLimitStorageOption,
21 pub custom_rules: Vec<RateLimitPathRule>,
22 pub dynamic_rules: Vec<DynamicRateLimitPathRule>,
23 pub custom_store: Option<Arc<dyn RateLimitStore>>,
24 pub custom_storage: Option<Arc<dyn RateLimitStorage>>,
25 pub hybrid: HybridRateLimitOptions,
26 pub memory_cleanup_interval: Option<Duration>,
27 pub missing_ip_policy: MissingIpPolicy,
28}
29
30impl Default for RateLimitOptions {
31 fn default() -> Self {
32 Self {
33 schema: ModelSchemaOptions::default(),
34 enabled: None,
35 window: TimeDuration::seconds(10),
36 max: 100,
37 storage: RateLimitStorageOption::Memory,
38 custom_rules: Vec::new(),
39 dynamic_rules: Vec::new(),
40 custom_store: None,
41 custom_storage: None,
42 hybrid: HybridRateLimitOptions::default(),
43 memory_cleanup_interval: Some(Duration::from_secs(60 * 60)),
44 missing_ip_policy: MissingIpPolicy::default(),
45 }
46 }
47}
48
49impl RateLimitOptions {
50 pub fn new() -> Self {
51 Self::default()
52 }
53
54 pub fn builder() -> Self {
55 Self::new()
56 }
57
58 #[must_use]
59 pub fn schema(mut self, schema: ModelSchemaOptions) -> Self {
60 self.schema = schema;
61 self
62 }
63
64 pub fn memory() -> Self {
65 Self {
66 storage: RateLimitStorageOption::Memory,
67 ..Self::default()
68 }
69 }
70
71 pub fn database<S>(store: S) -> Self
72 where
73 S: RateLimitStore,
74 {
75 Self::database_arc(Arc::new(store))
76 }
77
78 pub fn database_arc(store: Arc<dyn RateLimitStore>) -> Self {
79 Self {
80 storage: RateLimitStorageOption::Database,
81 custom_store: Some(store),
82 ..Self::default()
83 }
84 }
85
86 pub fn secondary_storage<S>(store: S) -> Self
87 where
88 S: RateLimitStore,
89 {
90 Self::secondary_storage_arc(Arc::new(store))
91 }
92
93 pub fn secondary_storage_arc(store: Arc<dyn RateLimitStore>) -> Self {
94 Self {
95 storage: RateLimitStorageOption::SecondaryStorage,
96 custom_store: Some(store),
97 ..Self::default()
98 }
99 }
100
101 #[must_use]
102 pub fn enabled(mut self, enabled: bool) -> Self {
103 self.enabled = Some(enabled);
104 self
105 }
106
107 #[must_use]
108 pub fn window(mut self, window: TimeDuration) -> Self {
109 self.window = window;
110 self
111 }
112
113 #[must_use]
114 pub fn max(mut self, max: u64) -> Self {
115 self.max = max;
116 self
117 }
118
119 #[must_use]
120 pub fn storage(mut self, storage: RateLimitStorageOption) -> Self {
121 self.storage = storage;
122 self
123 }
124
125 #[must_use]
126 pub fn custom_store<S>(mut self, store: S) -> Self
127 where
128 S: RateLimitStore,
129 {
130 self.custom_store = Some(Arc::new(store));
131 self
132 }
133
134 #[must_use]
135 pub fn custom_store_arc(mut self, store: Arc<dyn RateLimitStore>) -> Self {
136 self.custom_store = Some(store);
137 self
138 }
139
140 #[must_use]
141 pub fn custom_storage(mut self, storage: Arc<dyn RateLimitStorage>) -> Self {
142 self.custom_storage = Some(storage);
143 self
144 }
145
146 #[must_use]
147 pub fn custom_rule(mut self, path: impl Into<String>, rule: RateLimitRule) -> Self {
148 self.custom_rules.push(RateLimitPathRule {
149 path: path.into(),
150 rule: Some(rule),
151 });
152 self
153 }
154
155 #[must_use]
156 pub fn disabled_path(mut self, path: impl Into<String>) -> Self {
157 self.custom_rules.push(RateLimitPathRule {
158 path: path.into(),
159 rule: None,
160 });
161 self
162 }
163
164 #[must_use]
165 pub fn dynamic_rule<P>(mut self, path: impl Into<String>, provider: P) -> Self
166 where
167 P: RateLimitRuleProvider,
168 {
169 self.dynamic_rules
170 .push(DynamicRateLimitPathRule::new(path, provider));
171 self
172 }
173
174 #[must_use]
175 pub fn hybrid(mut self, hybrid: HybridRateLimitOptions) -> Self {
176 self.hybrid = hybrid;
177 self
178 }
179
180 #[must_use]
181 pub fn memory_cleanup_interval(mut self, interval: Option<Duration>) -> Self {
182 self.memory_cleanup_interval = interval;
183 self
184 }
185
186 #[must_use]
187 pub fn missing_ip_policy(mut self, policy: MissingIpPolicy) -> Self {
188 self.missing_ip_policy = policy;
189 self
190 }
191}
192
193impl fmt::Debug for RateLimitOptions {
194 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
195 formatter
196 .debug_struct("RateLimitOptions")
197 .field("enabled", &self.enabled)
198 .field("window", &self.window)
199 .field("max", &self.max)
200 .field("storage", &self.storage)
201 .field("custom_rules", &self.custom_rules)
202 .field("dynamic_rules", &self.dynamic_rules)
203 .field(
204 "custom_store",
205 &self.custom_store.as_ref().map(|_| "<custom-store>"),
206 )
207 .field(
208 "custom_storage",
209 &self.custom_storage.as_ref().map(|_| "<custom-storage>"),
210 )
211 .field("hybrid", &self.hybrid)
212 .field("memory_cleanup_interval", &self.memory_cleanup_interval)
213 .field("missing_ip_policy", &self.missing_ip_policy)
214 .finish()
215 }
216}
217
218#[derive(Debug, Clone, PartialEq, Eq)]
222pub struct RateLimitRule {
223 pub window: TimeDuration,
225 pub max: u64,
226}
227
228impl RateLimitRule {
229 pub fn new(window: TimeDuration, max: u64) -> Self {
230 Self { window, max }
231 }
232}
233
234pub fn validate_rate_limit_rule(rule: &RateLimitRule) -> Result<i64, RustAuthError> {
236 if rule.window.is_zero() {
237 return Err(RustAuthError::InvalidConfig(
238 "rate limit window must be greater than zero".to_owned(),
239 ));
240 }
241 if rule.max == 0 {
242 return Err(RustAuthError::InvalidConfig(
243 "rate limit max must be greater than zero".to_owned(),
244 ));
245 }
246 let milliseconds = rule.window.whole_milliseconds();
247 if milliseconds <= 0 {
248 return Err(RustAuthError::InvalidConfig(
249 "rate limit window must be greater than zero".to_owned(),
250 ));
251 }
252 let window_ms = i64::try_from(milliseconds)
253 .map_err(|_| RustAuthError::InvalidConfig("rate limit window is too large".to_owned()))?;
254 i64::try_from(rule.max)
255 .map_err(|_| RustAuthError::InvalidConfig("rate limit max must fit in i64".to_owned()))?;
256 Ok(window_ms)
257}
258
/// Settings for the hybrid rate-limiting mode.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct HybridRateLimitOptions {
    /// Whether hybrid mode is switched on (`false` by default).
    pub enabled: bool,
    /// Local multiplier, `2` by default (NOTE(review): confirm how the limiter applies it).
    pub local_multiplier: u64,
}

impl Default for HybridRateLimitOptions {
    /// Hybrid mode off, local multiplier `2`.
    fn default() -> Self {
        HybridRateLimitOptions {
            local_multiplier: 2,
            enabled: false,
        }
    }
}

impl HybridRateLimitOptions {
    /// Creates the [`Default`] settings.
    pub fn new() -> Self {
        Default::default()
    }

    /// Starting point for chained configuration; identical to [`HybridRateLimitOptions::new`].
    pub fn builder() -> Self {
        Self::default()
    }

    /// Default settings with hybrid mode switched on.
    pub fn enabled() -> Self {
        Self::default().set_enabled(true)
    }

    /// Default settings (hybrid mode off).
    pub fn disabled() -> Self {
        Self::new()
    }

    /// Sets whether hybrid mode is on.
    #[must_use]
    pub fn set_enabled(self, enabled: bool) -> Self {
        Self { enabled, ..self }
    }

    /// Sets the local multiplier.
    #[must_use]
    pub fn local_multiplier(self, multiplier: u64) -> Self {
        Self {
            local_multiplier: multiplier,
            ..self
        }
    }
}
306
307#[derive(Debug, Clone, PartialEq, Eq)]
308pub struct RateLimitPathRule {
309 pub path: String,
310 pub rule: Option<RateLimitRule>,
311}
312
313pub trait RateLimitRuleProvider: Send + Sync + 'static {
314 fn resolve(
315 &self,
316 request: &Request<Vec<u8>>,
317 current_rule: &RateLimitRule,
318 ) -> Result<Option<RateLimitRule>, RustAuthError>;
319}
320
321impl<F> RateLimitRuleProvider for F
322where
323 F: Fn(&Request<Vec<u8>>, &RateLimitRule) -> Result<Option<RateLimitRule>, RustAuthError>
324 + Send
325 + Sync
326 + 'static,
327{
328 fn resolve(
329 &self,
330 request: &Request<Vec<u8>>,
331 current_rule: &RateLimitRule,
332 ) -> Result<Option<RateLimitRule>, RustAuthError> {
333 self(request, current_rule)
334 }
335}
336
337#[derive(Clone)]
338pub struct DynamicRateLimitPathRule {
339 pub path: String,
340 pub provider: Arc<dyn RateLimitRuleProvider>,
341}
342
343impl DynamicRateLimitPathRule {
344 pub fn new<P>(path: impl Into<String>, provider: P) -> Self
345 where
346 P: RateLimitRuleProvider,
347 {
348 Self {
349 path: path.into(),
350 provider: Arc::new(provider),
351 }
352 }
353}
354
355impl fmt::Debug for DynamicRateLimitPathRule {
356 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
357 formatter
358 .debug_struct("DynamicRateLimitPathRule")
359 .field("path", &self.path)
360 .field("provider", &"<request-aware>")
361 .finish()
362 }
363}
364
/// Stored rate-limit state for one key, as read and written through `RateLimitStorage`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RateLimitRecord {
    /// Key identifying the rate-limited client/bucket.
    pub key: String,
    /// Request counter for `key`.
    pub count: u64,
    /// Timestamp of the last request (presumably epoch milliseconds — confirm).
    pub last_request: i64,
}
372
373#[derive(Debug, Clone, PartialEq, Eq)]
374pub struct RateLimitConsumeInput {
375 pub key: String,
376 pub rule: RateLimitRule,
377 pub now_ms: i64,
379}
380
/// Outcome of a rate-limit check.
///
/// NOTE(review): the time-valued fields are plain `u64`s; their unit (seconds vs.
/// milliseconds) is not visible here — confirm with the store implementations.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RateLimitDecision {
    /// Whether the request may proceed.
    pub permitted: bool,
    /// Wait before retrying.
    pub retry_after: u64,
    /// Request cap that applied.
    pub limit: u64,
    /// Requests left in the current window.
    pub remaining: u64,
    /// Time until the window resets.
    pub reset_after: u64,
}
390
391pub type RateLimitFuture<'a> =
392 Pin<Box<dyn Future<Output = Result<RateLimitDecision, RustAuthError>> + Send + 'a>>;
393
394pub trait RateLimitStore: Send + Sync + 'static {
399 fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a>;
400}
401
402pub trait RateLimitStorage: Send + Sync + 'static {
408 fn get(&self, key: &str) -> Result<Option<RateLimitRecord>, RustAuthError>;
409 fn set(
410 &self,
411 key: &str,
412 value: RateLimitRecord,
413 ttl_seconds: u64,
414 update: bool,
415 ) -> Result<(), RustAuthError>;
416}
417
/// Storage backend selected for rate-limit state.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RateLimitStorageOption {
    /// In-process memory.
    Memory,
    /// Database-backed store (selected by `RateLimitOptions::database*`).
    Database,
    /// Secondary storage backend (selected by `RateLimitOptions::secondary_storage*`).
    SecondaryStorage,
}
425
/// Handling for requests whose client IP cannot be determined.
///
/// Defaults to [`MissingIpPolicy::Deny`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum MissingIpPolicy {
    /// Reject the request (the default).
    #[default]
    Deny,
    /// Presumably counts all such requests against one shared bucket — confirm.
    SharedBucket,
    /// Presumably lets the request through without rate limiting — confirm.
    Allow,
}