Skip to main content

tork_core/throttle/
store.rs

1//! The pluggable counter backend for rate limiting.
2
3use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5use std::time::{Duration, Instant};
6
7use crate::error::Result;
8use crate::router::BoxFuture;
9
10/// A backend that counts requests per key within a time window.
11///
12/// The throttler buckets each key by window and asks the store to count hits.
13/// An in-memory store ([`MemoryThrottleStore`]) suits a single instance; a Redis
14/// store (behind the `redis` feature) shares the count across instances.
15pub trait ThrottleStore: Send + Sync + 'static {
16    /// Atomically increments the counter at `key`, setting it to expire after
17    /// `ttl` when first created, and returns the new count.
18    fn incr(&self, key: String, ttl: Duration) -> BoxFuture<'_, Result<u64>>;
19
20    /// Returns the current count at `key` without changing it (`0` if absent or
21    /// expired). Used by the sliding-window estimate to read the previous window.
22    fn count(&self, key: String) -> BoxFuture<'_, Result<u64>>;
23}
24
/// Entry count the in-memory map must exceed before expired counters are
/// swept out of it, which keeps the map from growing without bound.
const SWEEP_THRESHOLD: usize = 1 << 12;
27
28/// One counter and the moment it expires.
29struct Entry {
30    count: u64,
31    expires_at: Instant,
32}
33
34/// An in-memory [`ThrottleStore`] backed by a map of counters.
35///
36/// Counters expire lazily on access, and the whole map is swept once it grows
37/// past a threshold, so a server with many distinct keys stays bounded.
38#[derive(Clone, Default)]
39pub struct MemoryThrottleStore {
40    inner: Arc<Mutex<HashMap<String, Entry>>>,
41}
42
43impl MemoryThrottleStore {
44    /// Creates an empty store.
45    pub fn new() -> Self {
46        Self::default()
47    }
48}
49
50impl ThrottleStore for MemoryThrottleStore {
51    fn incr(&self, key: String, ttl: Duration) -> BoxFuture<'_, Result<u64>> {
52        Box::pin(async move {
53            let now = Instant::now();
54            let mut map = self
55                .inner
56                .lock()
57                .unwrap_or_else(|poisoned| poisoned.into_inner());
58
59            if map.len() > SWEEP_THRESHOLD {
60                map.retain(|_, entry| entry.expires_at > now);
61            }
62
63            let entry = map.entry(key).or_insert(Entry {
64                count: 0,
65                expires_at: now + ttl,
66            });
67            // A new window starts when the previous one has expired.
68            if entry.expires_at <= now {
69                entry.count = 0;
70                entry.expires_at = now + ttl;
71            }
72            entry.count += 1;
73            Ok(entry.count)
74        })
75    }
76
77    fn count(&self, key: String) -> BoxFuture<'_, Result<u64>> {
78        Box::pin(async move {
79            let now = Instant::now();
80            let map = self
81                .inner
82                .lock()
83                .unwrap_or_else(|poisoned| poisoned.into_inner());
84            let count = map
85                .get(&key)
86                .filter(|entry| entry.expires_at > now)
87                .map(|entry| entry.count)
88                .unwrap_or(0);
89            Ok(count)
90        })
91    }
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    /// Hits accumulate inside one window and the count restarts at `1` once
    /// the window has elapsed.
    #[tokio::test]
    async fn counts_within_a_window_then_resets_after_it() {
        let store = MemoryThrottleStore::new();
        let ttl = Duration::from_millis(80);

        assert_eq!(store.incr("k".into(), ttl).await.unwrap(), 1);
        assert_eq!(store.incr("k".into(), ttl).await.unwrap(), 2);
        assert_eq!(store.incr("k".into(), ttl).await.unwrap(), 3);

        // After the window elapses the counter starts over.
        tokio::time::sleep(Duration::from_millis(120)).await;
        assert_eq!(store.incr("k".into(), ttl).await.unwrap(), 1);
    }

    /// Each key has its own counter.
    #[tokio::test]
    async fn distinct_keys_count_independently() {
        let store = MemoryThrottleStore::new();
        let ttl = Duration::from_secs(60);
        assert_eq!(store.incr("a".into(), ttl).await.unwrap(), 1);
        assert_eq!(store.incr("b".into(), ttl).await.unwrap(), 1);
        assert_eq!(store.incr("a".into(), ttl).await.unwrap(), 2);
    }

    /// `count` reports `0` for an unknown key, reflects earlier `incr` calls,
    /// and never changes the counter itself.
    #[tokio::test]
    async fn count_reads_without_incrementing() {
        let store = MemoryThrottleStore::new();
        let ttl = Duration::from_secs(60);

        assert_eq!(store.count("k".into()).await.unwrap(), 0);

        store.incr("k".into(), ttl).await.unwrap();
        store.incr("k".into(), ttl).await.unwrap();
        assert_eq!(store.count("k".into()).await.unwrap(), 2);
        // Reading twice gives the same answer, and the next hit continues
        // from where the counter was.
        assert_eq!(store.count("k".into()).await.unwrap(), 2);
        assert_eq!(store.incr("k".into(), ttl).await.unwrap(), 3);
    }

    /// An expired counter reads as `0` and restarts at `1`.
    #[tokio::test]
    async fn expired_counter_reads_as_zero() {
        let store = MemoryThrottleStore::new();

        // A zero `ttl` makes the window end at the instant it is created, so
        // every later call sees it as expired without needing to sleep.
        assert_eq!(store.incr("k".into(), Duration::ZERO).await.unwrap(), 1);
        assert_eq!(store.count("k".into()).await.unwrap(), 0);
        assert_eq!(store.incr("k".into(), Duration::ZERO).await.unwrap(), 1);
    }
}