1use std::collections::HashMap;
21use std::sync::Arc;
22
23use crate::algorithm::Algorithm;
24use crate::decision::Decision;
25use crate::error::Result;
26use crate::key::Key;
27use crate::quota::Quota;
28use crate::storage::Storage;
29
30#[derive(Debug, Clone)]
32pub struct RouteConfig {
33 pub quota: Quota,
35 pub key_suffix: Option<String>,
37}
38
39impl RouteConfig {
40 pub fn new(quota: Quota) -> Self {
42 Self {
43 quota,
44 key_suffix: None,
45 }
46 }
47
48 pub fn with_key_suffix(mut self, suffix: impl Into<String>) -> Self {
50 self.key_suffix = Some(suffix.into());
51 self
52 }
53}
54
55impl From<Quota> for RouteConfig {
56 fn from(quota: Quota) -> Self {
57 Self::new(quota)
58 }
59}
60
/// Route-aware rate limiter: maps a request path to a quota and delegates the
/// allow/deny decision to an [`Algorithm`] operating on a shared [`Storage`].
///
/// Quota resolution order: exact entry in `routes`, then the first entry in
/// `patterns` whose pattern matches, then `default_quota`. When none applies,
/// requests are allowed without consulting the algorithm.
pub struct RateLimitManager<A, S, K> {
    /// Algorithm performing check / check-and-record / reset.
    algorithm: A,
    /// Backing storage, held behind an `Arc` and passed to the algorithm by reference.
    storage: Arc<S>,
    /// Produces the per-client prefix of the storage key from a request.
    key_extractor: K,
    /// Fallback quota for paths with no route or pattern; `None` lets such requests through.
    default_quota: Option<Quota>,
    /// Exact-path configurations.
    routes: HashMap<String, RouteConfig>,
    /// Pattern configurations, tried in insertion order (first match wins).
    patterns: Vec<(String, RouteConfig)>,
}
73
74impl<A, S, K> RateLimitManager<A, S, K>
75where
76 A: Algorithm,
77 S: Storage,
78{
79 pub fn builder() -> RateLimitManagerBuilder<K> {
81 RateLimitManagerBuilder::new()
82 }
83
84 pub async fn check_and_record<R>(&self, path: &str, request: &R) -> Result<Decision>
86 where
87 K: Key<R>,
88 {
89 let config = self.get_config(path);
90
91 let Some(quota) = config.map(|c| &c.quota).or(self.default_quota.as_ref()) else {
92 return Ok(Decision::allowed(crate::decision::RateLimitInfo::new(
94 u64::MAX,
95 u64::MAX,
96 std::time::Instant::now() + std::time::Duration::from_secs(3600),
97 std::time::Instant::now(),
98 )));
99 };
100
101 let base_key = self.key_extractor.extract(request).unwrap_or_else(|| "unknown".to_string());
103 let key = if let Some(suffix) = config.and_then(|c| c.key_suffix.as_ref()) {
104 format!("{}:{}", base_key, suffix)
105 } else {
106 format!("{}:{}", base_key, path)
107 };
108
109 self.algorithm
110 .check_and_record(&*self.storage, &key, quota)
111 .await
112 }
113
114 pub async fn check<R>(&self, path: &str, request: &R) -> Result<Decision>
116 where
117 K: Key<R>,
118 {
119 let config = self.get_config(path);
120
121 let Some(quota) = config.map(|c| &c.quota).or(self.default_quota.as_ref()) else {
122 return Ok(Decision::allowed(crate::decision::RateLimitInfo::new(
123 u64::MAX,
124 u64::MAX,
125 std::time::Instant::now() + std::time::Duration::from_secs(3600),
126 std::time::Instant::now(),
127 )));
128 };
129
130 let base_key = self.key_extractor.extract(request).unwrap_or_else(|| "unknown".to_string());
131 let key = if let Some(suffix) = config.and_then(|c| c.key_suffix.as_ref()) {
132 format!("{}:{}", base_key, suffix)
133 } else {
134 format!("{}:{}", base_key, path)
135 };
136
137 self.algorithm.check(&*self.storage, &key, quota).await
138 }
139
140 fn get_config(&self, path: &str) -> Option<&RouteConfig> {
142 if let Some(config) = self.routes.get(path) {
144 return Some(config);
145 }
146
147 for (pattern, config) in &self.patterns {
149 if pattern_matches(pattern, path) {
150 return Some(config);
151 }
152 }
153
154 None
155 }
156
157 pub async fn reset(&self, key: &str) -> Result<()> {
159 self.algorithm.reset(&*self.storage, key).await
160 }
161}
162
/// Returns `true` if `path` matches the route `pattern`.
///
/// Both strings are split on `/` with empty segments discarded, so leading,
/// trailing, and repeated slashes are irrelevant. Pattern segments are matched
/// against path segments left to right:
///
/// * a literal segment must equal the path segment exactly;
/// * `*` matches exactly one segment;
/// * `**` matches one or more segments, and whatever follows it in the pattern
///   must still match the remainder of the path. For example, `/api/**` does
///   not match `/api`, and `/api/**/posts` does not match `/api/users/comments`.
fn pattern_matches(pattern: &str, path: &str) -> bool {
    let pattern_parts: Vec<&str> = pattern.split('/').filter(|s| !s.is_empty()).collect();
    let path_parts: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
    segments_match(&pattern_parts, &path_parts)
}

/// Recursive segment matcher backing [`pattern_matches`].
///
/// Backtracking only happens at `**`. The cost grows with each `**` in a
/// pattern, which is assumed to be small for route patterns.
fn segments_match(pattern: &[&str], path: &[&str]) -> bool {
    match (pattern.split_first(), path.split_first()) {
        // Both sides exhausted together: every segment was accounted for.
        (None, None) => true,
        // `**` swallows at least one segment; try every split point for the
        // remainder of the pattern.
        (Some((&"**", rest)), Some(_)) => {
            (1..=path.len()).any(|consumed| segments_match(rest, &path[consumed..]))
        }
        (Some((&"*", rest)), Some((_, path_rest))) => segments_match(rest, path_rest),
        (Some((&literal, rest)), Some((&segment, path_rest))) if literal == segment => {
            segments_match(rest, path_rest)
        }
        _ => false,
    }
}
197
198pub struct RateLimitManagerBuilder<K> {
200 default_quota: Option<Quota>,
201 routes: HashMap<String, RouteConfig>,
202 patterns: Vec<(String, RouteConfig)>,
203 key_extractor: Option<K>,
204}
205
206impl<K> Default for RateLimitManagerBuilder<K> {
207 fn default() -> Self {
208 Self::new()
209 }
210}
211
212impl<K> RateLimitManagerBuilder<K> {
213 pub fn new() -> Self {
215 Self {
216 default_quota: None,
217 routes: HashMap::new(),
218 patterns: Vec::new(),
219 key_extractor: None,
220 }
221 }
222
223 pub fn default_quota(mut self, quota: Quota) -> Self {
225 self.default_quota = Some(quota);
226 self
227 }
228
229 pub fn route(mut self, path: impl Into<String>, config: impl Into<RouteConfig>) -> Self {
231 self.routes.insert(path.into(), config.into());
232 self
233 }
234
235 pub fn route_pattern(
239 mut self,
240 pattern: impl Into<String>,
241 config: impl Into<RouteConfig>,
242 ) -> Self {
243 self.patterns.push((pattern.into(), config.into()));
244 self
245 }
246
247 pub fn key_extractor(mut self, extractor: K) -> Self {
249 self.key_extractor = Some(extractor);
250 self
251 }
252
253 pub fn build<A, S>(self, algorithm: A, storage: S) -> RateLimitManager<A, S, K>
255 where
256 K: Default,
257 {
258 RateLimitManager {
259 algorithm,
260 storage: Arc::new(storage),
261 key_extractor: self.key_extractor.unwrap_or_default(),
262 default_quota: self.default_quota,
263 routes: self.routes,
264 patterns: self.patterns,
265 }
266 }
267
268 pub fn build_with_key<A, S>(
270 self,
271 algorithm: A,
272 storage: S,
273 key_extractor: K,
274 ) -> RateLimitManager<A, S, K> {
275 RateLimitManager {
276 algorithm,
277 storage: Arc::new(storage),
278 key_extractor,
279 default_quota: self.default_quota,
280 routes: self.routes,
281 patterns: self.patterns,
282 }
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pattern_matches_exact() {
        let pattern = "/api/users";
        assert!(pattern_matches(pattern, "/api/users"));
        assert!(!pattern_matches(pattern, "/api/posts"));
    }

    #[test]
    fn test_pattern_matches_single_wildcard() {
        let pattern = "/api/*/posts";
        for path in ["/api/users/posts", "/api/admins/posts"] {
            assert!(pattern_matches(pattern, path), "{pattern} should match {path}");
        }
        assert!(!pattern_matches(pattern, "/api/users/comments"));
    }

    #[test]
    fn test_pattern_matches_double_wildcard() {
        let pattern = "/api/**";
        for path in ["/api/users", "/api/users/123/posts"] {
            assert!(pattern_matches(pattern, path), "{pattern} should match {path}");
        }
        assert!(!pattern_matches(pattern, "/v2/api/users"));
    }

    #[test]
    fn test_route_config_from_quota() {
        let config = RouteConfig::from(Quota::per_minute(60));
        assert_eq!(config.quota.max_requests(), 60);
        assert_eq!(config.key_suffix, None);
    }
}