Skip to main content

throttle_net/
perkey.rs

1//! Independent throttling per key, with sharded state and bounded memory.
2
3use core::time::Duration;
4use std::collections::HashMap;
5use std::hash::Hash;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::{PoisonError, RwLock, RwLockReadGuard, RwLockWriteGuard};
8
9use ahash::RandomState;
10use clock_lib::{Clock, Monotonic, SystemClock};
11
12use crate::decision::Decision;
13#[cfg(feature = "runtime")]
14use crate::error::ThrottleError;
15use crate::eviction::Eviction;
16use crate::limiter::Limiter;
17use crate::throttle::Throttle;
18
/// Shard count used when the caller does not pick one. It is rounded up to a
/// power of two like any other count; sixteen is enough to keep unrelated keys
/// from queueing behind one lock without bloating a small store.
const DEFAULT_SHARDS: usize = 16;
22
23/// Per-key state: that key's throttle, plus a "last seen" stamp for eviction.
24///
25/// The stamp is monotonic milliseconds since the store's epoch when an idle TTL
26/// is configured (so idle expiry can compare against real time), and a per-shard
27/// logical sequence number otherwise — same least-recently-seen *ordering* for
28/// capacity eviction without a clock read on every access.
29struct Entry<C: Clock> {
30    throttle: Throttle<C>,
31    last_seen: AtomicU64,
32}
33
34/// One shard: an independently locked slice of the key space, hashed with
35/// `ahash` (fast, and collision-attack resistant via its random seed).
36struct Shard<K, C: Clock> {
37    map: RwLock<HashMap<K, Entry<C>, RandomState>>,
38    /// Per-shard counter handing out "last seen" stamps when no TTL is set.
39    /// Per-shard so unrelated shards never contend on it.
40    seq: AtomicU64,
41}
42
43impl<K, C: Clock> Shard<K, C> {
44    fn new() -> Self {
45        Self {
46            map: RwLock::new(HashMap::default()),
47            seq: AtomicU64::new(0),
48        }
49    }
50}
51
52/// A throttle that keeps independent state per key.
53///
54/// Each distinct key — a tenant, a user, an API token — gets its own token
55/// bucket with the same configured rate, so one noisy key cannot spend another's
56/// budget. State lives in a **sharded** concurrent map: keys are spread across
57/// shards by hash, each shard has its own lock, and an existing key's acquire
58/// takes only a shard *read* lock plus the bucket's own atomic accounting, so
59/// unrelated keys never contend and throughput scales with cores.
60///
61/// Memory is **bounded by eviction** (see [`Eviction`]): idle keys expire and a
62/// hard cap bounds the total, so a flood of unique keys reaches a ceiling instead
63/// of growing without limit. Eviction is lazy and per-shard — it runs while
64/// inserting a new key, never on a background thread or the steady-state path.
65/// The default policy is bounded ([`Eviction::default`]).
66///
67/// Like [`Throttle`], the headline [`acquire`](Self::acquire) **waits**; the
68/// `try_*` variants do not.
69///
70/// # Examples
71///
72/// ```
73/// # async fn run() -> Result<(), throttle_net::ThrottleError> {
74/// use throttle_net::PerKey;
75///
76/// // 100 requests per second, per tenant.
77/// let limiter: PerKey<String> = PerKey::per_second(100);
78/// limiter.acquire(&"tenant:42".to_string()).await?;
79/// # Ok(())
80/// # }
81/// ```
82pub struct PerKey<K, C = SystemClock>
83where
84    C: Clock,
85{
86    shards: Box<[Shard<K, C>]>,
87    /// `shard_count - 1`; the count is a power of two, so this masks a hash to a
88    /// shard index without a division.
89    shard_mask: u64,
90    hasher: RandomState,
91    eviction: Eviction,
92    amount: u32,
93    period: Duration,
94    clock: C,
95    epoch: Monotonic,
96}
97
98impl<K> PerKey<K, SystemClock>
99where
100    K: Eq + Hash + Clone + Send + Sync + 'static,
101{
102    /// Creates a per-key limiter giving every key `rate` units per second,
103    /// driven by the OS monotonic clock and the default [`Eviction`] policy.
104    ///
105    /// # Examples
106    ///
107    /// ```
108    /// use throttle_net::PerKey;
109    ///
110    /// let limiter: PerKey<u64> = PerKey::per_second(10);
111    /// assert!(limiter.try_acquire(&42));
112    /// ```
113    #[must_use]
114    pub fn per_second(rate: u32) -> Self {
115        Self::build(
116            rate,
117            Duration::from_secs(1),
118            SystemClock::new(),
119            DEFAULT_SHARDS,
120            Eviction::default(),
121        )
122    }
123
124    /// Creates a per-key limiter giving every key `amount` units every `period`.
125    ///
126    /// # Examples
127    ///
128    /// ```
129    /// use std::time::Duration;
130    /// use throttle_net::PerKey;
131    ///
132    /// // 1000 units per minute, per key.
133    /// let limiter: PerKey<String> = PerKey::per_duration(1000, Duration::from_secs(60));
134    /// # let _ = limiter;
135    /// ```
136    #[must_use]
137    pub fn per_duration(amount: u32, period: Duration) -> Self {
138        Self::build(
139            amount,
140            period,
141            SystemClock::new(),
142            DEFAULT_SHARDS,
143            Eviction::default(),
144        )
145    }
146}
147
148impl<K, C> PerKey<K, C>
149where
150    K: Eq + Hash + Clone + Send + Sync + 'static,
151    C: Clock + Clone,
152{
153    fn build(amount: u32, period: Duration, clock: C, shards: usize, eviction: Eviction) -> Self {
154        let shard_count = shards.max(1).next_power_of_two();
155        let shards = (0..shard_count)
156            .map(|_| Shard::new())
157            .collect::<Vec<_>>()
158            .into_boxed_slice();
159        let epoch = clock.now();
160        Self {
161            shards,
162            shard_mask: shard_count as u64 - 1,
163            hasher: RandomState::new(),
164            eviction,
165            amount,
166            period,
167            clock,
168            epoch,
169        }
170    }
171
172    /// Replaces the time source, for deterministic tests with a
173    /// [`ManualClock`](clock_lib::ManualClock). The store is rebuilt empty around
174    /// the new clock.
175    ///
176    /// # Examples
177    ///
178    /// ```
179    /// use std::sync::Arc;
180    /// use std::time::Duration;
181    /// use clock_lib::ManualClock;
182    /// use throttle_net::PerKey;
183    ///
184    /// let clock = Arc::new(ManualClock::new());
185    /// let limiter = PerKey::<&str>::per_second(1).with_clock(clock.clone());
186    ///
187    /// assert!(limiter.try_acquire(&"k"));
188    /// assert!(!limiter.try_acquire(&"k"));
189    /// clock.advance(Duration::from_secs(1));
190    /// assert!(limiter.try_acquire(&"k"));
191    /// ```
192    #[must_use]
193    pub fn with_clock<C2>(self, clock: C2) -> PerKey<K, C2>
194    where
195        C2: Clock + Clone,
196    {
197        PerKey::build(
198            self.amount,
199            self.period,
200            clock,
201            self.shards.len(),
202            self.eviction,
203        )
204    }
205
206    /// Sets the memory-bound policy (idle TTL and/or hard key cap).
207    ///
208    /// # Examples
209    ///
210    /// ```
211    /// use std::time::Duration;
212    /// use throttle_net::{Eviction, PerKey};
213    ///
214    /// let limiter: PerKey<String> = PerKey::per_second(100)
215    ///     .with_eviction(Eviction::capacity(50_000).with_idle(Duration::from_secs(300)));
216    /// # let _ = limiter;
217    /// ```
218    #[must_use]
219    pub fn with_eviction(mut self, eviction: Eviction) -> Self {
220        self.eviction = eviction;
221        self
222    }
223
224    /// Sets the shard count (rounded up to a power of two, at least one).
225    ///
226    /// More shards reduce contention between unrelated keys at the cost of a
227    /// little memory. The store is rebuilt empty.
228    ///
229    /// # Examples
230    ///
231    /// ```
232    /// use throttle_net::PerKey;
233    ///
234    /// let limiter: PerKey<u64> = PerKey::per_second(100).with_shards(64);
235    /// assert_eq!(limiter.shard_count(), 64);
236    /// ```
237    #[must_use]
238    pub fn with_shards(self, shards: usize) -> Self {
239        PerKey::build(self.amount, self.period, self.clock, shards, self.eviction)
240    }
241
242    /// The per-key capacity (burst ceiling): the configured `amount`.
243    #[inline]
244    #[must_use]
245    pub fn capacity(&self) -> u32 {
246        self.amount
247    }
248
249    /// The number of shards (a power of two).
250    #[inline]
251    #[must_use]
252    pub fn shard_count(&self) -> usize {
253        self.shards.len()
254    }
255
256    /// The number of keys with live state across all shards.
257    ///
258    /// A momentary, advisory snapshot — useful for tests and metrics, not a
259    /// synchronization point.
260    #[must_use]
261    pub fn len(&self) -> usize {
262        self.shards
263            .iter()
264            .map(|shard| read_guard(&shard.map).len())
265            .sum()
266    }
267
268    /// Returns `true` if no key currently has live state.
269    #[must_use]
270    pub fn is_empty(&self) -> bool {
271        self.shards
272            .iter()
273            .all(|shard| read_guard(&shard.map).is_empty())
274    }
275
276    /// Attempts to take one token for `key` without waiting.
277    ///
278    /// # Examples
279    ///
280    /// ```
281    /// use throttle_net::PerKey;
282    ///
283    /// let limiter: PerKey<&str> = PerKey::per_second(1);
284    /// assert!(limiter.try_acquire(&"a"));
285    /// assert!(!limiter.try_acquire(&"a"));
286    /// assert!(limiter.try_acquire(&"b")); // a different key is independent
287    /// ```
288    #[inline]
289    #[must_use]
290    pub fn try_acquire(&self, key: &K) -> bool {
291        self.try_acquire_with_cost(key, 1)
292    }
293
294    /// Attempts to take `cost` tokens for `key` without waiting.
295    #[inline]
296    #[must_use]
297    pub fn try_acquire_with_cost(&self, key: &K, cost: u32) -> bool {
298        self.decide(key, cost).is_acquired()
299    }
300
301    /// Reports whether `cost` tokens would be granted for `key` now, without
302    /// taking them — and without creating state for an unseen key.
303    #[inline]
304    #[must_use]
305    pub fn peek(&self, key: &K, cost: u32) -> Decision {
306        let shard = self.shard_for(key);
307        let guard = read_guard(&shard.map);
308        match guard.get(key) {
309            Some(entry) => entry.throttle.peek(cost),
310            // An unseen key would be a fresh, full bucket of capacity `amount`.
311            None if cost > self.amount => Decision::Impossible,
312            None => Decision::Acquired,
313        }
314    }
315
316    /// Current tokens available for `key`. An unseen key reports the full
317    /// capacity, since acquiring would create a fresh bucket.
318    #[must_use]
319    pub fn available(&self, key: &K) -> u32 {
320        let shard = self.shard_for(key);
321        let guard = read_guard(&shard.map);
322        guard
323            .get(key)
324            .map_or(self.amount, |entry| entry.throttle.available())
325    }
326
327    /// Builds a fresh throttle for a newly-seen key, sharing this store's clock.
328    #[inline]
329    fn make_throttle(&self) -> Throttle<C> {
330        Throttle::per_duration(self.amount, self.period).with_clock(self.clock.clone())
331    }
332
333    /// Milliseconds since the store's epoch, saturating.
334    #[inline]
335    fn now_ms(&self) -> u64 {
336        let elapsed = self.clock.now().saturating_duration_since(self.epoch);
337        u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX)
338    }
339
340    /// The "last seen" stamp for an access: real elapsed milliseconds when an
341    /// idle TTL is configured, otherwise a cheap per-shard sequence number.
342    #[inline]
343    fn stamp(&self, shard: &Shard<K, C>, now_ms: u64) -> u64 {
344        if self.eviction.idle_ttl().is_some() {
345            now_ms
346        } else {
347            shard.seq.fetch_add(1, Ordering::Relaxed)
348        }
349    }
350
351    #[inline]
352    fn shard_for(&self, key: &K) -> &Shard<K, C> {
353        let index = (self.hasher.hash_one(key) & self.shard_mask) as usize;
354        &self.shards[index]
355    }
356
357    /// The consuming core: acquire `cost` for `key`, creating its state on first
358    /// sight. Deducts on success.
359    fn decide(&self, key: &K, cost: u32) -> Decision {
360        let now_ms = self.now_ms();
361        let shard = self.shard_for(key);
362
363        // Fast path: a shared read lock is enough for an existing key. The
364        // bucket does its own atomic accounting, so concurrent acquires — of this
365        // key or any other in the shard — proceed without serialising.
366        {
367            let guard = read_guard(&shard.map);
368            if let Some(entry) = guard.get(key) {
369                entry
370                    .last_seen
371                    .store(self.stamp(shard, now_ms), Ordering::Relaxed);
372                return entry.throttle.acquire_cost(cost);
373            }
374        }
375
376        // Slow path: first-seen key. Take the write lock, re-check (another
377        // thread may have inserted in the gap), evict to make room, insert.
378        let mut guard = write_guard(&shard.map);
379        if let Some(entry) = guard.get(key) {
380            entry
381                .last_seen
382                .store(self.stamp(shard, now_ms), Ordering::Relaxed);
383            return entry.throttle.acquire_cost(cost);
384        }
385
386        let stamp = self.stamp(shard, now_ms);
387        self.evict_for_insert(&mut guard, now_ms);
388        let throttle = self.make_throttle();
389        let outcome = throttle.acquire_cost(cost);
390        let _ = guard.insert(
391            key.clone(),
392            Entry {
393                throttle,
394                last_seen: AtomicU64::new(stamp),
395            },
396        );
397        outcome
398    }
399
400    /// Makes room in a shard about to receive a new key: drop idle-expired keys,
401    /// then, if still at capacity, evict the least-recently-seen one. Runs under
402    /// the caller's write lock and touches only this shard.
403    fn evict_for_insert(&self, map: &mut HashMap<K, Entry<C>, RandomState>, now_ms: u64) {
404        if let Some(ttl) = self.eviction.idle_ttl() {
405            let ttl_ms = u64::try_from(ttl.as_millis()).unwrap_or(u64::MAX);
406            map.retain(|_, entry| {
407                now_ms.saturating_sub(entry.last_seen.load(Ordering::Relaxed)) < ttl_ms
408            });
409        }
410
411        if let Some(max) = self.eviction.max_keys() {
412            let per_shard_cap = max.div_ceil(self.shards.len()).max(1);
413            while map.len() >= per_shard_cap {
414                let victim = map
415                    .iter()
416                    .min_by_key(|(_, entry)| entry.last_seen.load(Ordering::Relaxed))
417                    .map(|(key, _)| key.clone());
418                match victim {
419                    Some(key) => {
420                        let _ = map.remove(&key);
421                    }
422                    None => break,
423                }
424            }
425        }
426    }
427}
428
429#[cfg(feature = "runtime")]
430#[cfg_attr(docsrs, doc(cfg(feature = "runtime")))]
431impl<K, C> PerKey<K, C>
432where
433    K: Eq + Hash + Clone + Send + Sync + 'static,
434    C: Clock + Clone,
435{
436    /// Takes one token for `key`, waiting until one is available.
437    ///
438    /// # Errors
439    ///
440    /// Returns [`ThrottleError::CostExceedsCapacity`] when the per-key capacity
441    /// is zero.
442    ///
443    /// # Examples
444    ///
445    /// ```
446    /// # async fn run() -> Result<(), throttle_net::ThrottleError> {
447    /// use throttle_net::PerKey;
448    ///
449    /// let limiter: PerKey<String> = PerKey::per_second(100);
450    /// limiter.acquire(&"tenant:7".to_string()).await?;
451    /// # Ok(())
452    /// # }
453    /// ```
454    pub async fn acquire(&self, key: &K) -> Result<(), ThrottleError> {
455        self.acquire_with_cost(key, 1).await
456    }
457
458    /// Takes `cost` tokens for `key`, waiting until they are available.
459    ///
460    /// # Errors
461    ///
462    /// Returns [`ThrottleError::CostExceedsCapacity`] when `cost` exceeds the
463    /// per-key capacity, so the request can never be granted.
464    pub async fn acquire_with_cost(&self, key: &K, cost: u32) -> Result<(), ThrottleError> {
465        loop {
466            match self.decide(key, cost) {
467                Decision::Acquired => return Ok(()),
468                Decision::Impossible => {
469                    return Err(ThrottleError::CostExceedsCapacity {
470                        cost,
471                        capacity: self.amount,
472                    });
473                }
474                Decision::Retry { after } => crate::rt::sleep(after).await,
475            }
476        }
477    }
478}
479
480impl<K, C> crate::limiter::KeyedLimiter<K> for PerKey<K, C>
481where
482    K: Eq + Hash + Clone + Send + Sync + 'static,
483    C: Clock + Clone + 'static,
484{
485    #[inline]
486    fn peek(&self, key: &K, cost: u32) -> Decision {
487        PerKey::peek(self, key, cost)
488    }
489
490    #[inline]
491    fn try_acquire_with_cost(&self, key: &K, cost: u32) -> bool {
492        PerKey::try_acquire_with_cost(self, key, cost)
493    }
494
495    #[inline]
496    fn capacity(&self) -> u32 {
497        PerKey::capacity(self)
498    }
499}
500
/// Takes a shared lock on `lock`, shrugging off poisoning. If an earlier holder
/// panicked, the guard is recovered anyway, so a poisoned shard keeps limiting
/// instead of cascading the panic into every later caller.
fn read_guard<T>(lock: &RwLock<T>) -> RwLockReadGuard<'_, T> {
    match lock.read() {
        Ok(guard) => guard,
        Err(poisoned) => PoisonError::into_inner(poisoned),
    }
}
506
/// Takes an exclusive lock on `lock`, recovering the guard if an earlier holder
/// panicked. This is the exclusive counterpart of [`read_guard`].
fn write_guard<T>(lock: &RwLock<T>) -> RwLockWriteGuard<'_, T> {
    match lock.write() {
        Ok(guard) => guard,
        Err(poisoned) => PoisonError::into_inner(poisoned),
    }
}
511
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]

    use super::PerKey;
    use crate::eviction::Eviction;
    use clock_lib::ManualClock;
    use core::time::Duration;
    use std::sync::Arc;

    /// Compile-time probe: only instantiable for `Send + Sync` types.
    fn assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn test_perkey_is_send_sync() {
        assert_send_sync::<PerKey<u64>>();
        assert_send_sync::<PerKey<String>>();
    }

    #[test]
    fn test_keys_are_independent() {
        let per_key: PerKey<&str> = PerKey::per_second(1);
        assert!(per_key.try_acquire(&"a"));
        // "a" has spent its only token...
        assert!(!per_key.try_acquire(&"a"));
        // ...but "b" has a bucket of its own.
        assert!(per_key.try_acquire(&"b"));
    }

    #[test]
    fn test_first_acquire_creates_exactly_one_key() {
        let limiter: PerKey<&str> = PerKey::per_second(10);
        assert_eq!(limiter.len(), 0);
        // The first acquire creates the key; the second reuses it.
        for _ in 0..2 {
            assert!(limiter.try_acquire(&"a"));
            assert_eq!(limiter.len(), 1);
        }
    }

    #[test]
    fn test_shard_count_rounds_up_to_power_of_two() {
        for (requested, expected) in [(5, 8), (16, 16), (0, 1)] {
            let limiter = PerKey::<u64>::per_second(1).with_shards(requested);
            assert_eq!(limiter.shard_count(), expected);
        }
    }

    #[test]
    fn test_peek_does_not_create_state() {
        let limiter: PerKey<&str> = PerKey::per_second(5);
        let verdict = limiter.peek(&"ghost", 1);
        assert!(verdict.is_acquired());
        assert_eq!(limiter.len(), 0, "peek must not insert a key");
    }

    #[test]
    fn test_available_reports_full_capacity_for_unseen_key() {
        let limiter: PerKey<&str> = PerKey::per_second(7);
        assert_eq!(limiter.available(&"unseen"), 7);
        assert!(limiter.try_acquire_with_cost(&"seen", 3));
        assert_eq!(limiter.available(&"seen"), 7 - 3);
    }

    #[test]
    fn test_refill_under_manual_clock() {
        let clock = Arc::new(ManualClock::new());
        let limiter = PerKey::<&str>::per_second(2).with_clock(Arc::clone(&clock));

        // Drain both tokens, then confirm the bucket is empty.
        (0..2).for_each(|_| assert!(limiter.try_acquire(&"k")));
        assert!(!limiter.try_acquire(&"k"));

        // One second refills the bucket.
        clock.advance(Duration::from_secs(1));
        assert!(limiter.try_acquire(&"k"));
    }

    #[test]
    fn test_capacity_bounds_total_keys_under_unique_flood() {
        let shards = 8;
        let cap = 100usize;
        let limiter: PerKey<u64> = PerKey::per_second(10)
            .with_shards(shards)
            .with_eviction(Eviction::capacity(cap));

        (0..10_000u64).for_each(|k| {
            let _ = limiter.try_acquire(&k);
        });

        // Each shard keeps at most its rounded-up share of the cap.
        let bound = cap.div_ceil(shards).max(1) * shards;
        assert!(
            limiter.len() <= bound,
            "flood grew to {} keys, bound {bound}",
            limiter.len()
        );
    }

    #[test]
    fn test_ttl_reclaims_idle_keys_on_later_insert() {
        let clock = Arc::new(ManualClock::new());
        let policy = Eviction::idle(Duration::from_millis(1000)).with_capacity(1);
        let limiter = PerKey::<&str>::per_second(10)
            .with_clock(Arc::clone(&clock))
            .with_eviction(policy)
            .with_shards(1);

        assert!(limiter.try_acquire(&"idle"));
        assert_eq!(limiter.len(), 1);

        // Let the first key sit idle well past its TTL; the next insert sweeps it.
        clock.advance(Duration::from_millis(2000));
        assert!(limiter.try_acquire(&"fresh"));
        assert_eq!(limiter.len(), 1, "the idle key should have been reclaimed");
    }

    #[test]
    fn test_recently_seen_key_survives_eviction_pressure() {
        let limiter: PerKey<String> = PerKey::per_second(1_000)
            .with_shards(1)
            .with_eviction(Eviction::capacity(4));
        let hot = String::from("hot");

        for round in 0..50u64 {
            assert!(limiter.try_acquire(&hot));
            let _ = limiter.try_acquire(&round.to_string());
        }
        // Touched every round, the hot key is never the least-recently-seen victim.
        assert!(limiter.try_acquire(&hot));
    }

    #[cfg(feature = "runtime")]
    #[tokio::test]
    async fn test_acquire_errors_when_cost_exceeds_capacity() {
        use crate::error::ThrottleError;

        let limiter: PerKey<&str> = PerKey::per_second(5);
        let expected = ThrottleError::CostExceedsCapacity {
            cost: 9,
            capacity: 5,
        };
        let err = limiter.acquire_with_cost(&"k", 9).await.unwrap_err();
        assert_eq!(err, expected);
    }

    #[cfg(feature = "runtime")]
    #[tokio::test]
    async fn test_acquire_waits_then_succeeds() {
        let limiter: PerKey<&str> = PerKey::per_second(1000);
        // Spend the whole burst, then the waiting acquire must still get through.
        (0..1000).for_each(|_| assert!(limiter.try_acquire(&"k")));
        assert!(!limiter.try_acquire(&"k"));
        assert!(limiter.acquire(&"k").await.is_ok());
    }
}