1use dashmap::DashMap;
20use std::sync::Arc;
21use tracing::{debug, trace, warn};
22
23use sentinel_common::ids::Scope;
24use sentinel_common::limits::Limits;
25use sentinel_config::FlattenedConfig;
26
27use crate::rate_limit::{
28 HeaderAccessor, RateLimitConfig, RateLimitManager, RateLimitResult, RateLimiterPool,
29};
30
31pub struct ScopedRateLimitManager {
36 scope_managers: DashMap<Scope, Arc<RateLimitManager>>,
38
39 scope_limits: DashMap<Scope, Limits>,
41
42 fallback_manager: Arc<RateLimitManager>,
44}
45
46impl ScopedRateLimitManager {
47 pub fn new() -> Self {
49 Self {
50 scope_managers: DashMap::new(),
51 scope_limits: DashMap::new(),
52 fallback_manager: Arc::new(RateLimitManager::new()),
53 }
54 }
55
56 pub fn from_flattened(config: &FlattenedConfig) -> Self {
58 let manager = Self::new();
59
60 for (scope, limits) in &config.scope_limits {
62 manager.set_scope_limits(scope.clone(), limits.clone());
63 }
64
65 manager
66 }
67
68 pub fn set_scope_limits(&self, scope: Scope, limits: Limits) {
70 if let Some(max_rps) = limits.max_requests_per_second_global {
74 let burst = max_rps * 10;
76 let scope_manager = RateLimitManager::with_global_limit(max_rps, burst);
77
78 debug!(
79 scope = ?scope,
80 max_rps = max_rps,
81 burst = burst,
82 "Configured rate limit for scope"
83 );
84
85 self.scope_managers
86 .insert(scope.clone(), Arc::new(scope_manager));
87 }
88
89 self.scope_limits.insert(scope, limits);
90 }
91
92 pub fn register_route(&self, scope: &Scope, route_id: &str, config: RateLimitConfig) {
94 let manager = self
95 .scope_managers
96 .entry(scope.clone())
97 .or_insert_with(|| Arc::new(RateLimitManager::new()));
98
99 manager.register_route(route_id, config);
100
101 trace!(
102 scope = ?scope,
103 route_id = route_id,
104 "Registered route rate limiter in scope"
105 );
106 }
107
108 pub fn check(
113 &self,
114 scope: &Scope,
115 route_id: &str,
116 client_ip: &str,
117 path: &str,
118 headers: Option<&impl HeaderAccessor>,
119 ) -> ScopedRateLimitResult {
120 for s in scope.chain() {
122 if let Some(manager) = self.scope_managers.get(&s) {
123 let result = manager.check(route_id, client_ip, path, headers);
124
125 if !result.allowed {
126 return ScopedRateLimitResult {
127 inner: result,
128 scope: s,
129 scope_limited: true,
130 };
131 }
132
133 if result.limit > 0 {
135 return ScopedRateLimitResult {
136 inner: result,
137 scope: s,
138 scope_limited: false,
139 };
140 }
141 }
142 }
143
144 let result = self
146 .fallback_manager
147 .check(route_id, client_ip, path, headers);
148
149 ScopedRateLimitResult {
150 inner: result,
151 scope: Scope::Global,
152 scope_limited: false,
153 }
154 }
155
156 pub fn is_enabled_for_scope(&self, scope: &Scope) -> bool {
158 for s in scope.chain() {
159 if let Some(manager) = self.scope_managers.get(&s) {
160 if manager.is_enabled() {
161 return true;
162 }
163 }
164 }
165 self.fallback_manager.is_enabled()
166 }
167
168 pub fn get_effective_limits(&self, scope: &Scope) -> Option<Limits> {
172 for s in scope.chain() {
173 if let Some(limits) = self.scope_limits.get(&s) {
174 return Some(limits.clone());
175 }
176 }
177 None
178 }
179
180 pub fn cleanup(&self) {
182 for entry in self.scope_managers.iter() {
183 entry.value().cleanup();
184 }
185 self.fallback_manager.cleanup();
186 }
187
188 pub fn scope_count(&self) -> usize {
190 self.scope_managers.len()
191 }
192
193 pub fn clear(&self) {
195 self.scope_managers.clear();
196 self.scope_limits.clear();
197 }
198
199 pub fn reload(&self, config: &FlattenedConfig) {
201 self.clear();
202
203 for (scope, limits) in &config.scope_limits {
204 self.set_scope_limits(scope.clone(), limits.clone());
205 }
206
207 debug!(
208 scope_count = self.scope_count(),
209 "Reloaded scoped rate limit configuration"
210 );
211 }
212}
213
214impl Default for ScopedRateLimitManager {
215 fn default() -> Self {
216 Self::new()
217 }
218}
219
220#[derive(Debug, Clone)]
222pub struct ScopedRateLimitResult {
223 pub inner: RateLimitResult,
225
226 pub scope: Scope,
228
229 pub scope_limited: bool,
231}
232
233impl ScopedRateLimitResult {
234 pub fn allowed(&self) -> bool {
236 self.inner.allowed
237 }
238
239 pub fn namespace(&self) -> Option<&str> {
241 match &self.scope {
242 Scope::Global => None,
243 Scope::Namespace(ns) => Some(ns),
244 Scope::Service { namespace, .. } => Some(namespace),
245 }
246 }
247
248 pub fn service(&self) -> Option<&str> {
250 match &self.scope {
251 Scope::Service { service, .. } => Some(service),
252 _ => None,
253 }
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Limits` value with only the global requests-per-second cap set.
    ///
    /// `_burst` is accepted for call-site readability but ignored, because
    /// `set_scope_limits` derives the burst from the rate itself.
    fn test_limits_with_rate_limit(rps: u32, _burst: u32) -> Limits {
        let mut limits = Limits::default();
        limits.max_requests_per_second_global = Some(rps);
        limits
    }

    /// Header accessor that never yields a header.
    struct NoHeaders;

    impl HeaderAccessor for NoHeaders {
        fn get_header(&self, _name: &str) -> Option<String> {
            None
        }
    }

    /// Runs one check for route `"route"` and path `"/"` with no headers.
    fn check_once(
        manager: &ScopedRateLimitManager,
        scope: &Scope,
        client_ip: &str,
    ) -> ScopedRateLimitResult {
        manager.check(scope, "route", client_ip, "/", Option::<&NoHeaders>::None)
    }

    fn ns_scope(name: &str) -> Scope {
        Scope::Namespace(name.to_string())
    }

    fn payments_service() -> Scope {
        Scope::Service {
            namespace: "api".to_string(),
            service: "payments".to_string(),
        }
    }

    #[test]
    fn test_scope_isolation() {
        let manager = ScopedRateLimitManager::new();
        manager.set_scope_limits(Scope::Global, test_limits_with_rate_limit(10, 5));
        manager.set_scope_limits(ns_scope("api"), test_limits_with_rate_limit(5, 2));

        // The namespace limiter admits its configured number of requests...
        let api = ns_scope("api");
        for _ in 0..5 {
            assert!(check_once(&manager, &api, "127.0.0.1").allowed());
        }

        // ...then rejects the next one itself.
        let denied = check_once(&manager, &api, "127.0.0.1");
        assert!(!denied.allowed());
        assert!(matches!(denied.scope, Scope::Namespace(_)));

        // A sibling namespace has no limiter of its own and is unaffected.
        let other = ns_scope("other");
        assert!(check_once(&manager, &other, "127.0.0.2").allowed());
    }

    #[test]
    fn test_scope_chain_fallback() {
        let manager = ScopedRateLimitManager::new();
        manager.set_scope_limits(Scope::Global, test_limits_with_rate_limit(3, 1));

        // The service has no limiter, so the global one decides.
        let svc = payments_service();
        for _ in 0..3 {
            assert!(check_once(&manager, &svc, "127.0.0.1").allowed());
        }

        let denied = check_once(&manager, &svc, "127.0.0.1");
        assert!(!denied.allowed());
        assert_eq!(denied.scope, Scope::Global);
    }

    #[test]
    fn test_service_scope_limits() {
        let manager = ScopedRateLimitManager::new();
        let svc = payments_service();
        manager.set_scope_limits(svc.clone(), test_limits_with_rate_limit(2, 1));

        let first = check_once(&manager, &svc, "127.0.0.1");
        let second = check_once(&manager, &svc, "127.0.0.1");
        assert!(first.allowed());
        assert!(second.allowed());

        let third = check_once(&manager, &svc, "127.0.0.1");
        assert!(!third.allowed());
        assert!(matches!(third.scope, Scope::Service { .. }));
    }

    #[test]
    fn test_effective_limits() {
        let manager = ScopedRateLimitManager::new();
        manager.set_scope_limits(Scope::Global, test_limits_with_rate_limit(100, 50));
        manager.set_scope_limits(ns_scope("api"), test_limits_with_rate_limit(50, 25));

        // The service inherits from its namespace, the nearest scope with limits.
        let for_service = manager
            .get_effective_limits(&payments_service())
            .expect("namespace limits should apply to the service");
        assert_eq!(for_service.max_requests_per_second_global, Some(50));

        // An unconfigured namespace inherits the global limits.
        let for_other = manager
            .get_effective_limits(&ns_scope("other"))
            .expect("global limits should apply to other namespaces");
        assert_eq!(for_other.max_requests_per_second_global, Some(100));
    }

    #[test]
    fn test_is_enabled_for_scope() {
        let manager = ScopedRateLimitManager::new();
        assert!(!manager.is_enabled_for_scope(&Scope::Global));

        manager.set_scope_limits(ns_scope("api"), test_limits_with_rate_limit(10, 5));

        assert!(manager.is_enabled_for_scope(&ns_scope("api")));
        assert!(manager.is_enabled_for_scope(&payments_service()));
        assert!(!manager.is_enabled_for_scope(&ns_scope("other")));
    }

    #[test]
    fn test_reload() {
        let manager = ScopedRateLimitManager::new();
        manager.set_scope_limits(Scope::Global, test_limits_with_rate_limit(10, 5));
        assert_eq!(manager.scope_count(), 1);

        let mut new_config = FlattenedConfig::new();
        for (name, rps, burst) in [("api", 20, 10), ("web", 30, 15)] {
            new_config
                .scope_limits
                .insert(ns_scope(name), test_limits_with_rate_limit(rps, burst));
        }

        manager.reload(&new_config);

        assert_eq!(manager.scope_count(), 2);
        assert!(manager.is_enabled_for_scope(&ns_scope("api")));
        assert!(manager.is_enabled_for_scope(&ns_scope("web")));
    }
}