1use dashmap::DashMap;
20use std::sync::Arc;
21use tracing::{debug, trace, warn};
22
23use sentinel_common::ids::Scope;
24use sentinel_common::limits::Limits;
25use sentinel_config::FlattenedConfig;
26
27use crate::rate_limit::{
28 HeaderAccessor, RateLimitConfig, RateLimitManager, RateLimitResult, RateLimiterPool,
29};
30
31pub struct ScopedRateLimitManager {
36 scope_managers: DashMap<Scope, Arc<RateLimitManager>>,
38
39 scope_limits: DashMap<Scope, Limits>,
41
42 fallback_manager: Arc<RateLimitManager>,
44}
45
46impl ScopedRateLimitManager {
47 pub fn new() -> Self {
49 Self {
50 scope_managers: DashMap::new(),
51 scope_limits: DashMap::new(),
52 fallback_manager: Arc::new(RateLimitManager::new()),
53 }
54 }
55
56 pub fn from_flattened(config: &FlattenedConfig) -> Self {
58 let manager = Self::new();
59
60 for (scope, limits) in &config.scope_limits {
62 manager.set_scope_limits(scope.clone(), limits.clone());
63 }
64
65 manager
66 }
67
68 pub fn set_scope_limits(&self, scope: Scope, limits: Limits) {
70 if let Some(max_rps) = limits.max_requests_per_second_global {
74 let burst = max_rps * 10;
76 let scope_manager = RateLimitManager::with_global_limit(max_rps, burst);
77
78 debug!(
79 scope = ?scope,
80 max_rps = max_rps,
81 burst = burst,
82 "Configured rate limit for scope"
83 );
84
85 self.scope_managers
86 .insert(scope.clone(), Arc::new(scope_manager));
87 }
88
89 self.scope_limits.insert(scope, limits);
90 }
91
92 pub fn register_route(&self, scope: &Scope, route_id: &str, config: RateLimitConfig) {
94 let manager = self
95 .scope_managers
96 .entry(scope.clone())
97 .or_insert_with(|| Arc::new(RateLimitManager::new()));
98
99 manager.register_route(route_id, config);
100
101 trace!(
102 scope = ?scope,
103 route_id = route_id,
104 "Registered route rate limiter in scope"
105 );
106 }
107
108 pub fn check(
113 &self,
114 scope: &Scope,
115 route_id: &str,
116 client_ip: &str,
117 path: &str,
118 headers: Option<&impl HeaderAccessor>,
119 ) -> ScopedRateLimitResult {
120 for s in scope.chain() {
122 if let Some(manager) = self.scope_managers.get(&s) {
123 let result = manager.check(route_id, client_ip, path, headers);
124
125 if !result.allowed {
126 return ScopedRateLimitResult {
127 inner: result,
128 scope: s,
129 scope_limited: true,
130 };
131 }
132
133 if result.limit > 0 {
135 return ScopedRateLimitResult {
136 inner: result,
137 scope: s,
138 scope_limited: false,
139 };
140 }
141 }
142 }
143
144 let result = self
146 .fallback_manager
147 .check(route_id, client_ip, path, headers);
148
149 ScopedRateLimitResult {
150 inner: result,
151 scope: Scope::Global,
152 scope_limited: false,
153 }
154 }
155
156 pub fn is_enabled_for_scope(&self, scope: &Scope) -> bool {
158 for s in scope.chain() {
159 if let Some(manager) = self.scope_managers.get(&s) {
160 if manager.is_enabled() {
161 return true;
162 }
163 }
164 }
165 self.fallback_manager.is_enabled()
166 }
167
168 pub fn get_effective_limits(&self, scope: &Scope) -> Option<Limits> {
172 for s in scope.chain() {
173 if let Some(limits) = self.scope_limits.get(&s) {
174 return Some(limits.clone());
175 }
176 }
177 None
178 }
179
180 pub fn cleanup(&self) {
182 for entry in self.scope_managers.iter() {
183 entry.value().cleanup();
184 }
185 self.fallback_manager.cleanup();
186 }
187
188 pub fn scope_count(&self) -> usize {
190 self.scope_managers.len()
191 }
192
193 pub fn clear(&self) {
195 self.scope_managers.clear();
196 self.scope_limits.clear();
197 }
198
199 pub fn reload(&self, config: &FlattenedConfig) {
201 self.clear();
202
203 for (scope, limits) in &config.scope_limits {
204 self.set_scope_limits(scope.clone(), limits.clone());
205 }
206
207 debug!(
208 scope_count = self.scope_count(),
209 "Reloaded scoped rate limit configuration"
210 );
211 }
212}
213
214impl Default for ScopedRateLimitManager {
215 fn default() -> Self {
216 Self::new()
217 }
218}
219
220#[derive(Debug, Clone)]
222pub struct ScopedRateLimitResult {
223 pub inner: RateLimitResult,
225
226 pub scope: Scope,
228
229 pub scope_limited: bool,
231}
232
233impl ScopedRateLimitResult {
234 pub fn allowed(&self) -> bool {
236 self.inner.allowed
237 }
238
239 pub fn namespace(&self) -> Option<&str> {
241 match &self.scope {
242 Scope::Global => None,
243 Scope::Namespace(ns) => Some(ns),
244 Scope::Service { namespace, .. } => Some(namespace),
245 }
246 }
247
248 pub fn service(&self) -> Option<&str> {
250 match &self.scope {
251 Scope::Service { service, .. } => Some(service),
252 _ => None,
253 }
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// `Limits` whose only non-default field is the global per-second rate.
    fn limits_with_rps(rps: u32) -> Limits {
        Limits {
            max_requests_per_second_global: Some(rps),
            ..Limits::default()
        }
    }

    /// Header accessor that never yields a header.
    struct NoHeaders;

    impl HeaderAccessor for NoHeaders {
        fn get_header(&self, _name: &str) -> Option<String> {
            None
        }
    }

    /// Performs one check for route `route` at path `/` without headers.
    fn check_once(
        manager: &ScopedRateLimitManager,
        scope: &Scope,
        client_ip: &str,
    ) -> ScopedRateLimitResult {
        manager.check(scope, "route", client_ip, "/", Option::<&NoHeaders>::None)
    }

    fn namespace(name: &str) -> Scope {
        Scope::Namespace(name.to_string())
    }

    fn api_payments() -> Scope {
        Scope::Service {
            namespace: "api".to_string(),
            service: "payments".to_string(),
        }
    }

    #[test]
    fn test_scope_isolation() {
        let manager = ScopedRateLimitManager::new();
        manager.set_scope_limits(Scope::Global, limits_with_rps(10));
        manager.set_scope_limits(namespace("api"), limits_with_rps(5));

        let api = namespace("api");
        for _ in 0..5 {
            assert!(check_once(&manager, &api, "127.0.0.1").allowed());
        }

        let denied = check_once(&manager, &api, "127.0.0.1");
        assert!(!denied.allowed());
        assert!(matches!(denied.scope, Scope::Namespace(_)));

        let elsewhere = check_once(&manager, &namespace("other"), "127.0.0.2");
        assert!(elsewhere.allowed());
    }

    #[test]
    fn test_scope_chain_fallback() {
        let manager = ScopedRateLimitManager::new();
        manager.set_scope_limits(Scope::Global, limits_with_rps(3));

        let svc = api_payments();
        for _ in 0..3 {
            assert!(check_once(&manager, &svc, "127.0.0.1").allowed());
        }

        let denied = check_once(&manager, &svc, "127.0.0.1");
        assert!(!denied.allowed());
        assert_eq!(denied.scope, Scope::Global);
    }

    #[test]
    fn test_service_scope_limits() {
        let manager = ScopedRateLimitManager::new();
        let svc = api_payments();
        manager.set_scope_limits(svc.clone(), limits_with_rps(2));

        let first = check_once(&manager, &svc, "127.0.0.1");
        let second = check_once(&manager, &svc, "127.0.0.1");
        assert!(first.allowed());
        assert!(second.allowed());

        let third = check_once(&manager, &svc, "127.0.0.1");
        assert!(!third.allowed());
        assert!(matches!(third.scope, Scope::Service { .. }));
    }

    #[test]
    fn test_effective_limits() {
        let manager = ScopedRateLimitManager::new();
        manager.set_scope_limits(Scope::Global, limits_with_rps(100));
        manager.set_scope_limits(namespace("api"), limits_with_rps(50));

        let svc_limits = manager.get_effective_limits(&api_payments()).unwrap();
        assert_eq!(svc_limits.max_requests_per_second_global.unwrap(), 50);

        let other_limits = manager.get_effective_limits(&namespace("other")).unwrap();
        assert_eq!(other_limits.max_requests_per_second_global.unwrap(), 100);
    }

    #[test]
    fn test_is_enabled_for_scope() {
        let manager = ScopedRateLimitManager::new();
        assert!(!manager.is_enabled_for_scope(&Scope::Global));

        manager.set_scope_limits(namespace("api"), limits_with_rps(10));

        assert!(manager.is_enabled_for_scope(&namespace("api")));
        assert!(manager.is_enabled_for_scope(&api_payments()));
        assert!(!manager.is_enabled_for_scope(&namespace("other")));
    }

    #[test]
    fn test_reload() {
        let manager = ScopedRateLimitManager::new();
        manager.set_scope_limits(Scope::Global, limits_with_rps(10));
        assert_eq!(manager.scope_count(), 1);

        let mut new_config = FlattenedConfig::new();
        new_config
            .scope_limits
            .insert(namespace("api"), limits_with_rps(20));
        new_config
            .scope_limits
            .insert(namespace("web"), limits_with_rps(30));

        manager.reload(&new_config);

        assert_eq!(manager.scope_count(), 2);
        assert!(manager.is_enabled_for_scope(&namespace("api")));
        assert!(manager.is_enabled_for_scope(&namespace("web")));
    }
}