Skip to main content

throttle_net/
layered.rs

1//! Layered scopes: a request must clear a global, a per-key, and a per-endpoint
2//! limit, in that order.
3
4use core::hash::Hash;
5use core::time::Duration;
6use std::sync::Arc;
7
8use clock_lib::Clock;
9
10use crate::decision::Decision;
11#[cfg(feature = "runtime")]
12use crate::error::ThrottleError;
13use crate::limiter::{KeyedLimiter, Limiter};
14use crate::perkey::PerKey;
15
16/// Several scopes of limiting stacked so a request must clear every one.
17///
18/// A real service limits at more than one granularity at once: an overall
19/// ceiling for the whole process (the *global* scope), a fair share per caller
20/// (the *per-key* scope, keyed by tenant or token), and a ceiling per route (the
21/// *per-endpoint* scope). `Layered` checks the scopes that are configured and
22/// admits a request only when all of them can afford it — applied atomically by
23/// the same peek-then-commit rule the other composites use, so a request never
24/// spends in one scope when another blocks it.
25///
26/// The two key types are independent: a numeric tenant id and a string endpoint,
27/// say. They default to the same type for the common all-string case.
28///
29/// Build one with [`Layered::builder`]. Every scope is optional; a builder with
30/// none admits everything.
31///
32/// # Examples
33///
34/// ```
35/// # async fn run() -> Result<(), throttle_net::ThrottleError> {
36/// use throttle_net::{Layered, PerKey, Throttle};
37///
38/// // 1000/s overall, 100/s per tenant, 50/s per endpoint.
39/// let layered = Layered::<String>::builder()
40///     .global(Throttle::per_second(1000))
41///     .per_key(PerKey::per_second(100))
42///     .per_endpoint(PerKey::per_second(50))
43///     .build();
44///
45/// layered
46///     .acquire(&"tenant:42".to_string(), &"/v1/chat".to_string())
47///     .await?;
48/// # Ok(())
49/// # }
50/// ```
51pub struct Layered<K, E = K>
52where
53    K: Eq + Hash + Clone + Send + Sync + 'static,
54    E: Eq + Hash + Clone + Send + Sync + 'static,
55{
56    global: Option<Arc<dyn Limiter>>,
57    per_key: Option<Arc<dyn KeyedLimiter<K>>>,
58    per_endpoint: Option<Arc<dyn KeyedLimiter<E>>>,
59}
60
61impl<K, E> Layered<K, E>
62where
63    K: Eq + Hash + Clone + Send + Sync + 'static,
64    E: Eq + Hash + Clone + Send + Sync + 'static,
65{
66    /// Starts building a layered limiter.
67    #[must_use]
68    pub fn builder() -> LayeredBuilder<K, E> {
69        LayeredBuilder {
70            global: None,
71            per_key: None,
72            per_endpoint: None,
73        }
74    }
75
76    /// Aggregate, non-consuming peek across every configured scope: the longest
77    /// wait, or [`Decision::Impossible`] if any scope can never grant `cost`.
78    fn peek_scopes(&self, key: &K, endpoint: &E, cost: u32) -> Decision {
79        let mut wait: Option<Duration> = None;
80        let peeks = [
81            self.global.as_ref().map(|g| g.peek(cost)),
82            self.per_key.as_ref().map(|pk| pk.peek(key, cost)),
83            self.per_endpoint.as_ref().map(|pe| pe.peek(endpoint, cost)),
84        ];
85        for decision in peeks.into_iter().flatten() {
86            match decision {
87                Decision::Acquired => {}
88                Decision::Retry { after } => {
89                    wait = Some(wait.map_or(after, |w| w.max(after)));
90                }
91                Decision::Impossible => return Decision::Impossible,
92            }
93        }
94        wait.map_or(Decision::Acquired, |after| Decision::Retry { after })
95    }
96
97    /// Commits `cost` to each configured scope in order, short-circuiting on the
98    /// first refusal. Returns whether every scope granted.
99    fn commit_scopes(&self, key: &K, endpoint: &E, cost: u32) -> bool {
100        if let Some(global) = &self.global {
101            if !global.acquire_cost(cost).is_acquired() {
102                return false;
103            }
104        }
105        if let Some(per_key) = &self.per_key {
106            if !per_key.try_acquire_with_cost(key, cost) {
107                return false;
108            }
109        }
110        if let Some(per_endpoint) = &self.per_endpoint {
111            if !per_endpoint.try_acquire_with_cost(endpoint, cost) {
112                return false;
113            }
114        }
115        true
116    }
117
118    /// The consuming core: peek every scope, and only if all would grant, commit
119    /// each. A commit that loses a race after the peek leaves the
120    /// already-committed scopes debited and reports a retry (never an
121    /// over-admission), exactly as the other composites do.
122    fn decide(&self, key: &K, endpoint: &E, cost: u32) -> Decision {
123        match self.peek_scopes(key, endpoint, cost) {
124            Decision::Acquired => {}
125            other => return other,
126        }
127        if self.commit_scopes(key, endpoint, cost) {
128            return Decision::Acquired;
129        }
130        // Lost a race after the peek. Re-peek for an accurate wait; if it now
131        // reads grantable again, nudge the caller to retry immediately.
132        match self.peek_scopes(key, endpoint, cost) {
133            Decision::Acquired => Decision::Retry {
134                after: Duration::ZERO,
135            },
136            other => other,
137        }
138    }
139
140    /// The binding capacity across configured scopes: the smallest, since that is
141    /// the first ceiling a request hits. No scopes means unbounded ([`u32::MAX`]).
142    #[must_use]
143    pub fn capacity(&self) -> u32 {
144        let caps = [
145            self.global.as_ref().map(|g| g.capacity()),
146            self.per_key.as_ref().map(|pk| pk.capacity()),
147            self.per_endpoint.as_ref().map(|pe| pe.capacity()),
148        ];
149        caps.into_iter().flatten().min().unwrap_or(u32::MAX)
150    }
151
152    /// Attempts to admit one request for `(key, endpoint)` without waiting.
153    ///
154    /// # Examples
155    ///
156    /// ```
157    /// use throttle_net::{Layered, PerKey, Throttle};
158    ///
159    /// let layered = Layered::<&str>::builder()
160    ///     .global(Throttle::per_second(2))
161    ///     .per_key(PerKey::per_second(1))
162    ///     .build();
163    ///
164    /// assert!(layered.try_acquire(&"a", &"/x"));
165    /// // The per-key scope for "a" is now empty, even though the global has room.
166    /// assert!(!layered.try_acquire(&"a", &"/x"));
167    /// assert!(layered.try_acquire(&"b", &"/x")); // a different key is independent
168    /// ```
169    #[inline]
170    #[must_use]
171    pub fn try_acquire(&self, key: &K, endpoint: &E) -> bool {
172        self.try_acquire_with_cost(key, endpoint, 1)
173    }
174
175    /// Attempts to admit a request of weight `cost` for `(key, endpoint)` without
176    /// waiting.
177    #[inline]
178    #[must_use]
179    pub fn try_acquire_with_cost(&self, key: &K, endpoint: &E, cost: u32) -> bool {
180        self.decide(key, endpoint, cost).is_acquired()
181    }
182
183    /// Reports whether a request for `(key, endpoint)` would be admitted now,
184    /// without taking anything.
185    #[inline]
186    #[must_use]
187    pub fn peek(&self, key: &K, endpoint: &E, cost: u32) -> Decision {
188        self.peek_scopes(key, endpoint, cost)
189    }
190}
191
#[cfg(feature = "runtime")]
#[cfg_attr(docsrs, doc(cfg(feature = "runtime")))]
impl<K, E> Layered<K, E>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    E: Eq + Hash + Clone + Send + Sync + 'static,
{
    /// Waits until every configured scope admits one request for
    /// `(key, endpoint)`, then takes it.
    ///
    /// Same as [`acquire_with_cost`](Self::acquire_with_cost) with a cost of 1.
    ///
    /// # Errors
    ///
    /// Returns [`ThrottleError::CostExceedsCapacity`] if some scope's capacity is
    /// too small to ever admit the request.
    pub async fn acquire(&self, key: &K, endpoint: &E) -> Result<(), ThrottleError> {
        self.acquire_with_cost(key, endpoint, 1).await
    }

    /// Waits until every configured scope admits a request weighing `cost`
    /// for `(key, endpoint)`, then takes it.
    ///
    /// Between attempts it sleeps for the wait the scopes report.
    ///
    /// # Errors
    ///
    /// Returns [`ThrottleError::CostExceedsCapacity`] if some scope can never
    /// admit `cost`. The reported capacity is the smallest one configured.
    pub async fn acquire_with_cost(
        &self,
        key: &K,
        endpoint: &E,
        cost: u32,
    ) -> Result<(), ThrottleError> {
        loop {
            let pause = match self.decide(key, endpoint, cost) {
                Decision::Retry { after } => after,
                Decision::Acquired => break Ok(()),
                Decision::Impossible => {
                    break Err(ThrottleError::CostExceedsCapacity {
                        cost,
                        capacity: self.capacity(),
                    })
                }
            };
            crate::rt::sleep(pause).await;
        }
    }
}
236
237/// Builder for a [`Layered`] limiter.
238///
239/// Set any subset of the three scopes; omitted scopes simply do not constrain.
240///
241/// # Examples
242///
243/// ```
244/// use throttle_net::{Layered, PerKey, Throttle};
245///
246/// let layered = Layered::<u64, String>::builder()
247///     .global(Throttle::per_second(1000))
248///     .per_key(PerKey::per_second(100))   // keyed by numeric tenant id
249///     .build();
250/// # let _ = layered;
251/// ```
252pub struct LayeredBuilder<K, E = K>
253where
254    K: Eq + Hash + Clone + Send + Sync + 'static,
255    E: Eq + Hash + Clone + Send + Sync + 'static,
256{
257    global: Option<Arc<dyn Limiter>>,
258    per_key: Option<Arc<dyn KeyedLimiter<K>>>,
259    per_endpoint: Option<Arc<dyn KeyedLimiter<E>>>,
260}
261
262impl<K, E> LayeredBuilder<K, E>
263where
264    K: Eq + Hash + Clone + Send + Sync + 'static,
265    E: Eq + Hash + Clone + Send + Sync + 'static,
266{
267    /// Sets the global scope: one limiter shared by every request. Any
268    /// [`Limiter`] works, so the global ceiling can itself be a
269    /// [`Hybrid`](crate::Hybrid).
270    #[must_use]
271    pub fn global(mut self, limiter: impl Limiter + 'static) -> Self {
272        self.global = Some(Arc::new(limiter));
273        self
274    }
275
276    /// Sets the per-key scope: independent state per caller key. Accepts a
277    /// [`PerKey`] built on any clock.
278    #[must_use]
279    pub fn per_key<C>(mut self, limiter: PerKey<K, C>) -> Self
280    where
281        C: Clock + Clone + 'static,
282    {
283        self.per_key = Some(Arc::new(limiter));
284        self
285    }
286
287    /// Sets the per-endpoint scope: independent state per endpoint. Accepts a
288    /// [`PerKey`] built on any clock.
289    #[must_use]
290    pub fn per_endpoint<C>(mut self, limiter: PerKey<E, C>) -> Self
291    where
292        C: Clock + Clone + 'static,
293    {
294        self.per_endpoint = Some(Arc::new(limiter));
295        self
296    }
297
298    /// Builds the [`Layered`] limiter.
299    #[must_use]
300    pub fn build(self) -> Layered<K, E> {
301        Layered {
302            global: self.global,
303            per_key: self.per_key,
304            per_endpoint: self.per_endpoint,
305        }
306    }
307}
308
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]

    use super::Layered;
    use crate::decision::Decision;
    use crate::perkey::PerKey;
    use crate::throttle::Throttle;

    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn test_layered_is_send_sync() {
        assert_send_sync::<Layered<String>>();
        assert_send_sync::<Layered<u64, String>>();
    }

    #[test]
    fn test_request_must_clear_all_three_scopes() {
        let layered = Layered::<&str>::builder()
            .global(Throttle::per_second(100))
            .per_key(PerKey::per_second(2))
            .per_endpoint(PerKey::per_second(100))
            .build();

        assert!(layered.try_acquire(&"tenant", &"/x"));
        assert!(layered.try_acquire(&"tenant", &"/x"));
        // The per-key scope (2/s) is exhausted though global and endpoint have
        // room, so the layered limiter refuses.
        assert!(!layered.try_acquire(&"tenant", &"/x"));
    }

    #[test]
    fn test_keys_and_endpoints_are_independent() {
        let layered = Layered::<&str>::builder()
            .per_key(PerKey::per_second(1))
            .per_endpoint(PerKey::per_second(1))
            .build();

        assert!(layered.try_acquire(&"a", &"/x"));
        // Same key, same endpoint: both scopes exhausted.
        assert!(!layered.try_acquire(&"a", &"/x"));
        // A different key on the same endpoint is blocked by the endpoint scope.
        assert!(!layered.try_acquire(&"b", &"/x"));
        // A different key on a different endpoint clears both.
        assert!(layered.try_acquire(&"b", &"/y"));
    }

    #[test]
    fn test_global_scope_binds_across_keys() {
        let layered = Layered::<&str>::builder()
            .global(Throttle::per_second(2))
            .per_key(PerKey::per_second(100))
            .build();

        // The global ceiling of 2 is spent across two different keys.
        assert!(layered.try_acquire(&"a", &"/x"));
        assert!(layered.try_acquire(&"b", &"/x"));
        assert!(!layered.try_acquire(&"c", &"/x"));
    }

    #[test]
    fn test_no_scope_admits_everything() {
        let layered = Layered::<&str>::builder().build();
        assert!(layered.try_acquire(&"anything", &"/anywhere"));
        assert_eq!(layered.capacity(), u32::MAX);
    }

    #[test]
    fn test_no_token_spent_in_one_scope_when_another_blocks() {
        // The global ceiling is deliberately tight (2). If the refused request
        // below were charged to it, key "b" would find it empty.
        let layered = Layered::<&str>::builder()
            .global(Throttle::per_second(2))
            .per_key(PerKey::per_second(1))
            .build();

        assert!(layered.try_acquire(&"a", &"/x")); // global: 1 left, key a: 0 left
        // Key "a" is blocked; the global scope must not be charged for the
        // refused request.
        assert!(!layered.try_acquire(&"a", &"/x"));
        // The global token left after "a" is still there for "b"...
        assert!(layered.try_acquire(&"b", &"/x"));
        // ...and it was the last one.
        assert!(!layered.try_acquire(&"c", &"/x"));
    }

    #[test]
    fn test_peek_does_not_consume() {
        let layered = Layered::<&str>::builder()
            .global(Throttle::per_second(1))
            .per_key(PerKey::per_second(1))
            .build();

        // Repeated peeks leave the single token in each scope untouched.
        assert!(layered.peek(&"a", &"/x", 1).is_acquired());
        assert!(layered.peek(&"a", &"/x", 1).is_acquired());
        assert!(layered.try_acquire(&"a", &"/x"));
        // Once the token is spent, peek reports that it would not be granted.
        assert!(!layered.peek(&"a", &"/x", 1).is_acquired());
    }

    #[test]
    fn test_impossible_cost_charges_no_scope() {
        // The per-key scope is committed before the per-endpoint scope, so a
        // wrongful charge would show up as missing per-key tokens.
        let layered = Layered::<&str>::builder()
            .per_key(PerKey::per_second(10))
            .per_endpoint(PerKey::per_second(5))
            .build();

        assert!(matches!(layered.peek(&"a", &"/x", 9), Decision::Impossible));
        assert!(!layered.try_acquire_with_cost(&"a", &"/x", 9));
        // Key "a" still holds all 10 tokens, so a cost of 5 clears both scopes.
        assert!(layered.try_acquire_with_cost(&"a", &"/x", 5));
    }

    #[test]
    fn test_capacity_is_the_smallest_scope() {
        let layered = Layered::<&str>::builder()
            .global(Throttle::per_second(1000))
            .per_key(PerKey::per_second(100))
            .per_endpoint(PerKey::per_second(25))
            .build();
        assert_eq!(layered.capacity(), 25);
    }

    #[cfg(feature = "runtime")]
    #[tokio::test]
    async fn test_acquire_errors_when_a_scope_can_never_admit() {
        use crate::error::ThrottleError;

        let layered = Layered::<&str>::builder()
            .global(Throttle::per_second(1000))
            .per_key(PerKey::per_second(5))
            .build();
        let err = layered.acquire_with_cost(&"a", &"/x", 9).await.unwrap_err();
        assert!(matches!(
            err,
            ThrottleError::CostExceedsCapacity {
                cost: 9,
                capacity: 5
            }
        ));
    }
}