1use core::hash::Hash;
5use core::time::Duration;
6use std::sync::Arc;
7
8use clock_lib::Clock;
9
10use crate::decision::Decision;
11#[cfg(feature = "runtime")]
12use crate::error::ThrottleError;
13use crate::limiter::{KeyedLimiter, Limiter};
14use crate::perkey::PerKey;
15
/// Rate limiter that combines up to three scopes — a global limiter, a limiter
/// keyed by `K`, and a limiter keyed by endpoint `E` — and admits a request only
/// when every configured scope can admit it.
///
/// Each scope is optional; a scope that was never configured is skipped. `E`
/// defaults to `K`. Construct one with [`Layered::builder`].
pub struct Layered<K, E = K>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    E: Eq + Hash + Clone + Send + Sync + 'static,
{
    /// Limiter consulted for every request, regardless of key or endpoint.
    global: Option<Arc<dyn Limiter>>,
    /// Limiter consulted with the caller-supplied key.
    per_key: Option<Arc<dyn KeyedLimiter<K>>>,
    /// Limiter consulted with the caller-supplied endpoint.
    per_endpoint: Option<Arc<dyn KeyedLimiter<E>>>,
}
60
61impl<K, E> Layered<K, E>
62where
63 K: Eq + Hash + Clone + Send + Sync + 'static,
64 E: Eq + Hash + Clone + Send + Sync + 'static,
65{
66 #[must_use]
68 pub fn builder() -> LayeredBuilder<K, E> {
69 LayeredBuilder {
70 global: None,
71 per_key: None,
72 per_endpoint: None,
73 }
74 }
75
76 fn peek_scopes(&self, key: &K, endpoint: &E, cost: u32) -> Decision {
79 let mut wait: Option<Duration> = None;
80 let peeks = [
81 self.global.as_ref().map(|g| g.peek(cost)),
82 self.per_key.as_ref().map(|pk| pk.peek(key, cost)),
83 self.per_endpoint.as_ref().map(|pe| pe.peek(endpoint, cost)),
84 ];
85 for decision in peeks.into_iter().flatten() {
86 match decision {
87 Decision::Acquired => {}
88 Decision::Retry { after } => {
89 wait = Some(wait.map_or(after, |w| w.max(after)));
90 }
91 Decision::Impossible => return Decision::Impossible,
92 }
93 }
94 wait.map_or(Decision::Acquired, |after| Decision::Retry { after })
95 }
96
97 fn commit_scopes(&self, key: &K, endpoint: &E, cost: u32) -> bool {
100 if let Some(global) = &self.global {
101 if !global.acquire_cost(cost).is_acquired() {
102 return false;
103 }
104 }
105 if let Some(per_key) = &self.per_key {
106 if !per_key.try_acquire_with_cost(key, cost) {
107 return false;
108 }
109 }
110 if let Some(per_endpoint) = &self.per_endpoint {
111 if !per_endpoint.try_acquire_with_cost(endpoint, cost) {
112 return false;
113 }
114 }
115 true
116 }
117
118 fn decide(&self, key: &K, endpoint: &E, cost: u32) -> Decision {
123 match self.peek_scopes(key, endpoint, cost) {
124 Decision::Acquired => {}
125 other => return other,
126 }
127 if self.commit_scopes(key, endpoint, cost) {
128 return Decision::Acquired;
129 }
130 match self.peek_scopes(key, endpoint, cost) {
133 Decision::Acquired => Decision::Retry {
134 after: Duration::ZERO,
135 },
136 other => other,
137 }
138 }
139
140 #[must_use]
143 pub fn capacity(&self) -> u32 {
144 let caps = [
145 self.global.as_ref().map(|g| g.capacity()),
146 self.per_key.as_ref().map(|pk| pk.capacity()),
147 self.per_endpoint.as_ref().map(|pe| pe.capacity()),
148 ];
149 caps.into_iter().flatten().min().unwrap_or(u32::MAX)
150 }
151
152 #[inline]
170 #[must_use]
171 pub fn try_acquire(&self, key: &K, endpoint: &E) -> bool {
172 self.try_acquire_with_cost(key, endpoint, 1)
173 }
174
175 #[inline]
178 #[must_use]
179 pub fn try_acquire_with_cost(&self, key: &K, endpoint: &E, cost: u32) -> bool {
180 self.decide(key, endpoint, cost).is_acquired()
181 }
182
183 #[inline]
186 #[must_use]
187 pub fn peek(&self, key: &K, endpoint: &E, cost: u32) -> Decision {
188 self.peek_scopes(key, endpoint, cost)
189 }
190}
191
#[cfg(feature = "runtime")]
#[cfg_attr(docsrs, doc(cfg(feature = "runtime")))]
impl<K, E> Layered<K, E>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    E: Eq + Hash + Clone + Send + Sync + 'static,
{
    /// Waits until one token for `key` on `endpoint` is admitted by every
    /// configured scope.
    ///
    /// Shorthand for [`Layered::acquire_with_cost`] with a cost of `1`.
    ///
    /// # Errors
    ///
    /// Returns [`ThrottleError::CostExceedsCapacity`] when a configured scope
    /// reports that a single token can never be admitted.
    pub async fn acquire(&self, key: &K, endpoint: &E) -> Result<(), ThrottleError> {
        self.acquire_with_cost(key, endpoint, 1).await
    }

    /// Waits until `cost` tokens for `key` on `endpoint` are admitted by every
    /// configured scope.
    ///
    /// Each attempt runs the same peek-then-commit logic as
    /// [`Layered::try_acquire_with_cost`]; on a `Retry` it sleeps for the reported
    /// wait via `crate::rt::sleep` and tries again.
    ///
    /// # Errors
    ///
    /// Returns [`ThrottleError::CostExceedsCapacity`] (carrying `cost` and
    /// [`Layered::capacity`]) when some configured scope reports the request as
    /// impossible, rather than waiting forever.
    pub async fn acquire_with_cost(
        &self,
        key: &K,
        endpoint: &E,
        cost: u32,
    ) -> Result<(), ThrottleError> {
        loop {
            let pause = match self.decide(key, endpoint, cost) {
                Decision::Acquired => break Ok(()),
                Decision::Impossible => {
                    break Err(ThrottleError::CostExceedsCapacity {
                        cost,
                        capacity: self.capacity(),
                    });
                }
                Decision::Retry { after } => after,
            };
            crate::rt::sleep(pause).await;
        }
    }
}
236
/// Builder for [`Layered`].
///
/// Each setter configures one scope; any scope left unset is skipped by the
/// resulting limiter. Finish with [`LayeredBuilder::build`].
pub struct LayeredBuilder<K, E = K>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    E: Eq + Hash + Clone + Send + Sync + 'static,
{
    /// Scope applied to every request.
    global: Option<Arc<dyn Limiter>>,
    /// Scope keyed by the caller-supplied key.
    per_key: Option<Arc<dyn KeyedLimiter<K>>>,
    /// Scope keyed by the caller-supplied endpoint.
    per_endpoint: Option<Arc<dyn KeyedLimiter<E>>>,
}
261
262impl<K, E> LayeredBuilder<K, E>
263where
264 K: Eq + Hash + Clone + Send + Sync + 'static,
265 E: Eq + Hash + Clone + Send + Sync + 'static,
266{
267 #[must_use]
271 pub fn global(mut self, limiter: impl Limiter + 'static) -> Self {
272 self.global = Some(Arc::new(limiter));
273 self
274 }
275
276 #[must_use]
279 pub fn per_key<C>(mut self, limiter: PerKey<K, C>) -> Self
280 where
281 C: Clock + Clone + 'static,
282 {
283 self.per_key = Some(Arc::new(limiter));
284 self
285 }
286
287 #[must_use]
290 pub fn per_endpoint<C>(mut self, limiter: PerKey<E, C>) -> Self
291 where
292 C: Clock + Clone + 'static,
293 {
294 self.per_endpoint = Some(Arc::new(limiter));
295 self
296 }
297
298 #[must_use]
300 pub fn build(self) -> Layered<K, E> {
301 Layered {
302 global: self.global,
303 per_key: self.per_key,
304 per_endpoint: self.per_endpoint,
305 }
306 }
307}
308
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]

    use super::Layered;
    use crate::decision::Decision;
    use crate::perkey::PerKey;
    use crate::throttle::Throttle;

    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn test_layered_is_send_sync() {
        assert_send_sync::<Layered<String>>();
        assert_send_sync::<Layered<u64, String>>();
    }

    #[test]
    fn test_request_must_clear_all_three_scopes() {
        let layered = Layered::<&str>::builder()
            .global(Throttle::per_second(100))
            .per_key(PerKey::per_second(2))
            .per_endpoint(PerKey::per_second(100))
            .build();

        assert!(layered.try_acquire(&"tenant", &"/x"));
        assert!(layered.try_acquire(&"tenant", &"/x"));
        // The per-key scope (capacity 2) is exhausted even though the others are not.
        assert!(!layered.try_acquire(&"tenant", &"/x"));
    }

    #[test]
    fn test_keys_and_endpoints_are_independent() {
        let layered = Layered::<&str>::builder()
            .per_key(PerKey::per_second(1))
            .per_endpoint(PerKey::per_second(1))
            .build();

        assert!(layered.try_acquire(&"a", &"/x"));
        assert!(!layered.try_acquire(&"a", &"/x"));
        // A fresh key is still blocked by the drained "/x" endpoint scope...
        assert!(!layered.try_acquire(&"b", &"/x"));
        // ...and that rejection did not spend key "b"'s token.
        assert!(layered.try_acquire(&"b", &"/y"));
    }

    #[test]
    fn test_global_scope_binds_across_keys() {
        let layered = Layered::<&str>::builder()
            .global(Throttle::per_second(2))
            .per_key(PerKey::per_second(100))
            .build();

        assert!(layered.try_acquire(&"a", &"/x"));
        // A different key still draws from the shared global scope.
        assert!(layered.try_acquire(&"b", &"/x"));
        assert!(!layered.try_acquire(&"c", &"/x"));
    }

    #[test]
    fn test_no_scope_admits_everything() {
        let layered = Layered::<&str>::builder().build();
        assert!(layered.try_acquire(&"anything", &"/anywhere"));
        assert_eq!(layered.capacity(), u32::MAX);
    }

    #[test]
    fn test_no_token_spent_in_one_scope_when_another_blocks() {
        // The global capacity is deliberately tight: with a large global
        // capacity a leaked global token would go unnoticed.
        let layered = Layered::<&str>::builder()
            .global(Throttle::per_second(2))
            .per_key(PerKey::per_second(1))
            .build();

        assert!(layered.try_acquire(&"a", &"/x"));
        // Key "a" is drained; this rejection must not consume a global token.
        assert!(!layered.try_acquire(&"a", &"/x"));
        // Therefore exactly one global token remains for key "b"...
        assert!(layered.try_acquire(&"b", &"/x"));
        // ...after which the global scope is drained for everyone.
        assert!(!layered.try_acquire(&"c", &"/x"));
    }

    #[test]
    fn test_peek_does_not_consume_tokens() {
        let layered = Layered::<&str>::builder()
            .per_key(PerKey::per_second(1))
            .build();

        assert!(layered.peek(&"a", &"/x", 1).is_acquired());
        assert!(layered.peek(&"a", &"/x", 1).is_acquired());
        // Both peeks left the single token in place.
        assert!(layered.try_acquire(&"a", &"/x"));
        assert!(!layered.peek(&"a", &"/x", 1).is_acquired());
    }

    #[test]
    fn test_cost_above_a_scope_capacity_is_impossible() {
        let layered = Layered::<&str>::builder()
            .global(Throttle::per_second(1000))
            .per_key(PerKey::per_second(5))
            .build();

        assert!(matches!(layered.peek(&"a", &"/x", 9), Decision::Impossible));
        assert!(!layered.try_acquire_with_cost(&"a", &"/x", 9));
        // The impossible request committed nothing, so the full capacity is still there.
        assert!(layered.try_acquire_with_cost(&"a", &"/x", 5));
    }

    #[test]
    fn test_capacity_is_the_smallest_scope() {
        let layered = Layered::<&str>::builder()
            .global(Throttle::per_second(1000))
            .per_key(PerKey::per_second(100))
            .per_endpoint(PerKey::per_second(25))
            .build();
        assert_eq!(layered.capacity(), 25);
    }

    #[cfg(feature = "runtime")]
    #[tokio::test]
    async fn test_acquire_errors_when_a_scope_can_never_admit() {
        use crate::error::ThrottleError;

        let layered = Layered::<&str>::builder()
            .global(Throttle::per_second(1000))
            .per_key(PerKey::per_second(5))
            .build();
        let err = layered.acquire_with_cost(&"a", &"/x", 9).await.unwrap_err();
        assert!(matches!(
            err,
            ThrottleError::CostExceedsCapacity {
                cost: 9,
                capacity: 5
            }
        ));
    }
}