Skip to main content

structured_proxy/shield/
store.rs

//! Rate-limit counter storage.
//!
//! [`RateLimitStore`] abstracts where per-key counters live. [`MemoryStore`] is
//! the default and keeps counters in-process (per replica). [`RedisStore`]
//! (behind the `redis` feature) shares counters across replicas, which is what
//! a multi-instance deployment behind a load balancer needs for correct global
//! limits.
8
9use std::time::Duration;
10
11use super::rate::Rate;
12
/// Result of counting a single request against a rate-limit key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Decision {
    /// `true` when the request fits within the limit and may proceed.
    pub allowed: bool,
    /// Configured request ceiling, suitable for the `X-RateLimit-Limit` header.
    pub limit: u64,
    /// Requests still permitted in the current window; 0 once the limit is hit.
    pub remaining: u64,
    /// Time until the current window resets; only populated for rejected requests.
    pub retry_after: Option<Duration>,
}
25
26/// A backend that records request hits and decides whether each is allowed.
27#[async_trait::async_trait]
28pub trait RateLimitStore: Send + Sync {
29    /// Record one hit for `key` and return the limiting decision for `rate`.
30    ///
31    /// A store that cannot reach its backend should fail open (allow the
32    /// request) rather than reject legitimate traffic.
33    async fn hit(&self, key: &str, rate: &Rate) -> Decision;
34}
35
36/// In-process fixed-window counter store (per replica).
37///
38/// Counters are not shared between replicas, so global limits only hold for a
39/// single instance. Use [`RedisStore`] for multi-instance deployments.
40#[derive(Debug)]
41pub struct MemoryStore {
42    // no-std: caller-provided Clock + spin/hashbrown map.
43    windows: dashmap::DashMap<String, Window>,
44    base: std::time::Instant,
45    last_sweep_ms: std::sync::atomic::AtomicU64,
46}
47
/// State of one key's current fixed window.
#[derive(Debug, Clone, Copy)]
struct Window {
    /// Instant at which this window ends and the counter may be reset.
    expires_at: std::time::Instant,
    /// Hits recorded so far in this window.
    count: u64,
}
53
/// Minimum spacing, in milliseconds, between sweeps of expired entries (one minute).
const SWEEP_INTERVAL_MS: u64 = 60 * 1_000;
56
57impl Default for MemoryStore {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63impl MemoryStore {
64    /// Create an empty store.
65    pub fn new() -> Self {
66        Self {
67            windows: dashmap::DashMap::new(),
68            base: std::time::Instant::now(),
69            last_sweep_ms: std::sync::atomic::AtomicU64::new(0),
70        }
71    }
72
73    /// Drop every entry whose window has elapsed.
74    ///
75    /// Entries are reset lazily on access, but keys that are never hit again
76    /// would otherwise linger forever; with client-controlled key cardinality
77    /// (IP / identifier) that is an unbounded-growth / OOM risk.
78    fn evict_expired(&self, now: std::time::Instant) {
79        self.windows.retain(|_, w| w.expires_at > now);
80    }
81
82    /// Evict expired entries at most once per [`SWEEP_INTERVAL_MS`]; the first
83    /// thread past the interval claims the sweep so it stays O(n) infrequently.
84    fn maybe_sweep(&self, now: std::time::Instant) {
85        use std::sync::atomic::Ordering;
86        let now_ms = now.duration_since(self.base).as_millis() as u64;
87        let last = self.last_sweep_ms.load(Ordering::Relaxed);
88        if now_ms.saturating_sub(last) < SWEEP_INTERVAL_MS {
89            return;
90        }
91        if self
92            .last_sweep_ms
93            .compare_exchange(last, now_ms, Ordering::Relaxed, Ordering::Relaxed)
94            .is_ok()
95        {
96            self.evict_expired(now);
97        }
98    }
99}
100
101#[async_trait::async_trait]
102impl RateLimitStore for MemoryStore {
103    async fn hit(&self, key: &str, rate: &Rate) -> Decision {
104        let now = std::time::Instant::now();
105        self.maybe_sweep(now);
106
107        let mut entry = self.windows.entry(key.to_string()).or_insert(Window {
108            expires_at: now + rate.window,
109            count: 0,
110        });
111
112        // Reset the counter once the current window has elapsed.
113        if now >= entry.expires_at {
114            entry.expires_at = now + rate.window;
115            entry.count = 0;
116        }
117        entry.count += 1;
118
119        let count = entry.count;
120        let expires_at = entry.expires_at;
121        drop(entry);
122
123        let allowed = count <= rate.limit;
124        Decision {
125            allowed,
126            limit: rate.limit,
127            remaining: rate.limit.saturating_sub(count),
128            retry_after: (!allowed).then(|| expires_at.saturating_duration_since(now)),
129        }
130    }
131}
132
/// Redis-backed fixed-window counter store, shared across replicas.
///
/// The client is opened eagerly (URL validation) but the multiplexed
/// connection is established lazily on first use, so construction stays
/// synchronous and a Redis that is briefly unavailable at startup does not
/// block the proxy from booting.
///
/// redis-rs's `MultiplexedConnection` does not reconnect on its own once its
/// socket is gone (automatic reconnection is what its `ConnectionManager`
/// adds). The cached connection is therefore discarded after an I/O-level
/// failure and re-dialed by a later request. Otherwise a single Redis restart
/// would leave every subsequent request failing open, silently disabling rate
/// limiting until the proxy itself restarted.
#[cfg(feature = "redis")]
pub struct RedisStore {
    client: redis::Client,
    /// Cached connection; `None` before first use and after invalidation.
    conn: tokio::sync::RwLock<Option<redis::aio::MultiplexedConnection>>,
    /// The counter script, built once so its SHA1 digest is not recomputed
    /// on every request.
    script: redis::Script,
}

/// INCR the key and (re)set its TTL atomically in one round trip.
///
/// Issuing them as separate commands risks an immortal key if INCR lands but
/// EXPIRE does not. The PTTL<0 guard also re-arms the TTL on any key that
/// somehow lost it, so a key can never accumulate increments forever.
/// Returns `{count, pttl_ms}`.
#[cfg(feature = "redis")]
const REDIS_HIT_SCRIPT: &str = r"local c = redis.call('INCR', KEYS[1])
              if redis.call('PTTL', KEYS[1]) < 0 then
                redis.call('PEXPIRE', KEYS[1], ARGV[1])
              end
              return {c, redis.call('PTTL', KEYS[1])}";

#[cfg(feature = "redis")]
impl RedisStore {
    /// Open a Redis client for `url` (e.g. `redis://127.0.0.1/`).
    ///
    /// No connection is made here; the first `hit` dials it.
    ///
    /// # Errors
    /// Returns the underlying Redis error when the URL is invalid.
    pub fn open(url: &str) -> redis::RedisResult<Self> {
        Ok(Self {
            client: redis::Client::open(url)?,
            conn: tokio::sync::RwLock::new(None),
            script: redis::Script::new(REDIS_HIT_SCRIPT),
        })
    }

    /// Get the shared multiplexed connection, dialing one if none is cached.
    ///
    /// The common case only takes a read lock. Dialing happens under the
    /// write lock, re-checked after acquiring it, so concurrent callers wait
    /// for a single connect attempt instead of each opening a socket.
    async fn connection(&self) -> redis::RedisResult<redis::aio::MultiplexedConnection> {
        {
            // Scoped so the read guard is released before taking the write
            // lock (tokio's RwLock is not reentrant).
            let cached = self.conn.read().await;
            if let Some(conn) = &*cached {
                return Ok(conn.clone());
            }
        }
        let mut slot = self.conn.write().await;
        if let Some(conn) = &*slot {
            return Ok(conn.clone());
        }
        let conn = self.client.get_multiplexed_async_connection().await?;
        *slot = Some(conn.clone());
        Ok(conn)
    }

    /// Forget the cached connection after an I/O-level failure so the next
    /// request dials a fresh one.
    ///
    /// Other errors (e.g. a reply that does not convert to the expected type)
    /// leave the connection in place. If another task has already replaced
    /// the connection, this may discard the fresh one as well; the only cost
    /// is one extra reconnect.
    async fn invalidate(&self, err: &redis::RedisError) {
        if err.is_io_error() || err.is_connection_dropped() {
            *self.conn.write().await = None;
        }
    }

    /// Allow the request when Redis is unreachable: an outage must not take the
    /// proxy down.
    fn fail_open(rate: &Rate, err: redis::RedisError) -> Decision {
        tracing::warn!("rate-limit store unavailable, allowing request: {err}");
        Decision {
            allowed: true,
            limit: rate.limit,
            remaining: rate.limit,
            retry_after: None,
        }
    }
}

#[cfg(feature = "redis")]
#[async_trait::async_trait]
impl RateLimitStore for RedisStore {
    /// Record one hit for `key` in Redis and decide against `rate`.
    ///
    /// Any Redis failure fails open (see `fail_open`); I/O-level failures also
    /// drop the cached connection so a later request can recover.
    async fn hit(&self, key: &str, rate: &Rate) -> Decision {
        // Clamp to at least 1 ms so the key always receives a positive TTL.
        let window_ms = rate.window.as_millis().max(1) as u64;
        let mut conn = match self.connection().await {
            Ok(c) => c,
            Err(e) => return Self::fail_open(rate, e),
        };

        let reply: redis::RedisResult<(i64, i64)> = self
            .script
            .key(key)
            .arg(window_ms)
            .invoke_async(&mut conn)
            .await;
        let (count, pttl) = match reply {
            Ok(v) => v,
            Err(e) => {
                self.invalidate(&e).await;
                return Self::fail_open(rate, e);
            }
        };

        let count = count.max(0) as u64;
        let allowed = count <= rate.limit;
        Decision {
            allowed,
            limit: rate.limit,
            remaining: rate.limit.saturating_sub(count),
            retry_after: (!allowed).then(|| Duration::from_millis(pttl.max(0) as u64)),
        }
    }
}
216
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn memory_store_evicts_expired_entries() {
        let store = MemoryStore::new();
        let rate = Rate {
            limit: 5,
            window: Duration::from_millis(20),
        };
        store.hit("a", &rate).await;
        store.hit("b", &rate).await;
        assert_eq!(store.windows.len(), 2);

        // Evaluate eviction at an instant past both windows instead of
        // sleeping: deterministic, and adds no wall-clock time to the suite.
        // Reclaiming these entries is what keeps the map from growing without
        // bound on keys that are never hit again.
        store.evict_expired(std::time::Instant::now() + Duration::from_millis(30));
        assert_eq!(store.windows.len(), 0);
    }

    #[tokio::test]
    async fn memory_store_eviction_keeps_live_entries() {
        let store = MemoryStore::new();
        let rate = Rate {
            limit: 5,
            window: Duration::from_secs(60),
        };
        store.hit("live", &rate).await;
        store.evict_expired(std::time::Instant::now());
        assert_eq!(store.windows.len(), 1);
    }

    #[tokio::test]
    async fn memory_store_sweep_is_throttled() {
        let store = MemoryStore::new();
        let rate = Rate {
            limit: 5,
            window: Duration::from_millis(20),
        };
        store.hit("a", &rate).await;

        // Before the first interval has elapsed the sweep is skipped, even
        // though the entry's window is long over by then.
        store.maybe_sweep(store.base + Duration::from_millis(SWEEP_INTERVAL_MS - 1));
        assert_eq!(store.windows.len(), 1);

        // Once the interval has elapsed the sweep runs and reclaims it.
        store.maybe_sweep(store.base + Duration::from_millis(SWEEP_INTERVAL_MS));
        assert_eq!(store.windows.len(), 0);
    }

    #[tokio::test]
    async fn memory_store_allows_then_blocks_within_window() {
        let store = MemoryStore::new();
        let rate = Rate {
            limit: 2,
            window: Duration::from_secs(60),
        };

        let d1 = store.hit("k", &rate).await;
        assert!(d1.allowed && d1.remaining == 1);
        let d2 = store.hit("k", &rate).await;
        assert!(d2.allowed && d2.remaining == 0);
        let d3 = store.hit("k", &rate).await;
        assert!(!d3.allowed);
        assert_eq!(d3.remaining, 0);
        let retry = d3.retry_after.expect("rejected hit carries retry_after");
        // The wait can never exceed one full window.
        assert!(retry <= rate.window);
    }

    #[tokio::test]
    async fn memory_store_resets_after_window() {
        let store = MemoryStore::new();
        let rate = Rate {
            limit: 1,
            window: Duration::from_millis(50),
        };

        assert!(store.hit("k", &rate).await.allowed);
        assert!(!store.hit("k", &rate).await.allowed);
        tokio::time::sleep(Duration::from_millis(60)).await;
        // New window: the counter has reset.
        assert!(store.hit("k", &rate).await.allowed);
    }

    #[tokio::test]
    async fn memory_store_isolates_keys() {
        let store = MemoryStore::new();
        let rate = Rate {
            limit: 1,
            window: Duration::from_secs(60),
        };
        assert!(store.hit("a", &rate).await.allowed);
        // A different key has its own independent counter.
        assert!(store.hit("b", &rate).await.allowed);
        assert!(!store.hit("a", &rate).await.allowed);
    }
}