1use std::time::Duration;
10
11use super::rate::Rate;
12
/// Result of recording one request against a rate limit.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub struct Decision {
    /// Whether this request fits within the limit.
    pub allowed: bool,
    /// The configured maximum number of hits per window.
    pub limit: u64,
    /// Hits still permitted in the current window; saturates at zero.
    pub remaining: u64,
    /// Time until the current window ends. The stores in this module only
    /// populate it for rejected requests.
    pub retry_after: Option<Duration>,
}
25
26#[async_trait::async_trait]
28pub trait RateLimitStore: Send + Sync {
29 async fn hit(&self, key: &str, rate: &Rate) -> Decision;
34}
35
36#[derive(Debug)]
41pub struct MemoryStore {
42 windows: dashmap::DashMap<String, Window>,
44 base: std::time::Instant,
45 last_sweep_ms: std::sync::atomic::AtomicU64,
46}
47
/// Counting state for one key's current window.
#[derive(Clone, Copy, Debug)]
struct Window {
    /// Instant at which this window ends; a hit at or after it starts a new one.
    expires_at: std::time::Instant,
    /// Number of hits recorded in this window so far.
    count: u64,
}
53
/// Minimum spacing, in milliseconds, between two sweeps of expired windows.
const SWEEP_INTERVAL_MS: u64 = 60 * 1_000;
56
57impl Default for MemoryStore {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63impl MemoryStore {
64 pub fn new() -> Self {
66 Self {
67 windows: dashmap::DashMap::new(),
68 base: std::time::Instant::now(),
69 last_sweep_ms: std::sync::atomic::AtomicU64::new(0),
70 }
71 }
72
73 fn evict_expired(&self, now: std::time::Instant) {
79 self.windows.retain(|_, w| w.expires_at > now);
80 }
81
82 fn maybe_sweep(&self, now: std::time::Instant) {
85 use std::sync::atomic::Ordering;
86 let now_ms = now.duration_since(self.base).as_millis() as u64;
87 let last = self.last_sweep_ms.load(Ordering::Relaxed);
88 if now_ms.saturating_sub(last) < SWEEP_INTERVAL_MS {
89 return;
90 }
91 if self
92 .last_sweep_ms
93 .compare_exchange(last, now_ms, Ordering::Relaxed, Ordering::Relaxed)
94 .is_ok()
95 {
96 self.evict_expired(now);
97 }
98 }
99}
100
101#[async_trait::async_trait]
102impl RateLimitStore for MemoryStore {
103 async fn hit(&self, key: &str, rate: &Rate) -> Decision {
104 let now = std::time::Instant::now();
105 self.maybe_sweep(now);
106
107 let mut entry = self.windows.entry(key.to_string()).or_insert(Window {
108 expires_at: now + rate.window,
109 count: 0,
110 });
111
112 if now >= entry.expires_at {
114 entry.expires_at = now + rate.window;
115 entry.count = 0;
116 }
117 entry.count += 1;
118
119 let count = entry.count;
120 let expires_at = entry.expires_at;
121 drop(entry);
122
123 let allowed = count <= rate.limit;
124 Decision {
125 allowed,
126 limit: rate.limit,
127 remaining: rate.limit.saturating_sub(count),
128 retry_after: (!allowed).then(|| expires_at.saturating_duration_since(now)),
129 }
130 }
131}
132
/// Redis-backed [`RateLimitStore`] that keeps a fixed window per key.
///
/// Counts are stored in Redis under the caller-supplied key, so every client of
/// the same server shares them. When Redis cannot be reached the store fails
/// open: the request is allowed and a warning is logged.
#[cfg(feature = "redis")]
pub struct RedisStore {
    client: redis::Client,
    /// Lazily established connection. It is cleared after an I/O failure, so
    /// the next request reconnects instead of reusing a dead connection forever.
    conn: tokio::sync::RwLock<Option<redis::aio::MultiplexedConnection>>,
    /// Atomic INCR-then-PEXPIRE script. It is built once, so its SHA-1 is
    /// computed at construction rather than on every request.
    script: redis::Script,
}

#[cfg(feature = "redis")]
impl RedisStore {
    /// Creates a store for the Redis server at `url`.
    ///
    /// No connection is opened here; the first call to `hit` connects.
    ///
    /// # Errors
    /// Returns the error from [`redis::Client::open`] if `url` is not valid.
    pub fn open(url: &str) -> redis::RedisResult<Self> {
        Ok(Self {
            client: redis::Client::open(url)?,
            conn: tokio::sync::RwLock::new(None),
            script: redis::Script::new(
                r"local c = redis.call('INCR', KEYS[1])
if redis.call('PTTL', KEYS[1]) < 0 then
  redis.call('PEXPIRE', KEYS[1], ARGV[1])
end
return {c, redis.call('PTTL', KEYS[1])}",
            ),
        })
    }

    /// Returns the cached connection, connecting first if none is cached.
    ///
    /// Callers that find the cache empty queue on the write lock, so only one
    /// connection attempt is in flight at a time.
    async fn connection(&self) -> redis::RedisResult<redis::aio::MultiplexedConnection> {
        // The read guard is a temporary of this `let`, so it is released before
        // the write lock is requested below.
        let cached = self.conn.read().await.as_ref().cloned();
        if let Some(conn) = cached {
            return Ok(conn);
        }

        let mut slot = self.conn.write().await;
        // Another task may have connected while we were waiting for the write lock.
        if let Some(conn) = slot.as_ref() {
            return Ok(conn.clone());
        }
        let conn = self.client.get_multiplexed_async_connection().await?;
        *slot = Some(conn.clone());
        Ok(conn)
    }

    /// Forgets the cached connection so the next request reconnects.
    ///
    /// If another task has just reconnected, its fresh connection may be dropped
    /// too. That only costs one extra reconnect.
    async fn reset_connection(&self) {
        *self.conn.write().await = None;
    }

    /// Logs `err` and returns an "allowed" decision with the full quota remaining.
    ///
    /// Failing open means a Redis outage disables limiting rather than blocking traffic.
    fn fail_open(rate: &Rate, err: redis::RedisError) -> Decision {
        tracing::warn!("rate-limit store unavailable, allowing request: {err}");
        Decision {
            allowed: true,
            limit: rate.limit,
            remaining: rate.limit,
            retry_after: None,
        }
    }
}

#[cfg(feature = "redis")]
#[async_trait::async_trait]
impl RateLimitStore for RedisStore {
    /// Increments `key`'s counter in Redis and derives the decision from the
    /// new count and the key's remaining TTL.
    async fn hit(&self, key: &str, rate: &Rate) -> Decision {
        // A zero TTL would make PEXPIRE delete the key immediately, so clamp to 1 ms.
        let window_ms = rate.window.as_millis().max(1) as u64;
        let mut conn = match self.connection().await {
            Ok(c) => c,
            Err(e) => return Self::fail_open(rate, e),
        };

        let reply: redis::RedisResult<(i64, i64)> =
            self.script.key(key).arg(window_ms).invoke_async(&mut conn).await;
        let (count, pttl) = match reply {
            Ok(v) => v,
            Err(e) => {
                // A multiplexed connection is not re-established on its own.
                // Drop it after transport failures so later requests can recover.
                if e.is_io_error() || e.is_connection_dropped() {
                    self.reset_connection().await;
                }
                return Self::fail_open(rate, e);
            }
        };

        let count = count.max(0) as u64;
        let allowed = count <= rate.limit;
        Decision {
            allowed,
            limit: rate.limit,
            remaining: rate.limit.saturating_sub(count),
            retry_after: (!allowed).then(|| Duration::from_millis(pttl.max(0) as u64)),
        }
    }
}
216
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn memory_store_evicts_expired_entries() {
        let store = MemoryStore::new();
        let rate = Rate {
            limit: 5,
            window: Duration::from_millis(20),
        };
        store.hit("a", &rate).await;
        store.hit("b", &rate).await;
        assert_eq!(store.windows.len(), 2);

        // Evict as of an instant past both windows instead of sleeping, so the
        // test does not depend on scheduler timing.
        store.evict_expired(std::time::Instant::now() + Duration::from_millis(30));
        assert_eq!(store.windows.len(), 0);
    }

    #[tokio::test]
    async fn memory_store_sweep_is_throttled() {
        let store = MemoryStore::new();
        let rate = Rate {
            limit: 5,
            window: Duration::from_millis(20),
        };
        store.hit("a", &rate).await;

        // The window has expired relative to this instant, but less than
        // SWEEP_INTERVAL_MS has passed since construction, so nothing is swept.
        store.maybe_sweep(std::time::Instant::now() + Duration::from_millis(30));
        assert_eq!(store.windows.len(), 1);

        // Once the interval has elapsed, the sweep runs and evicts the window.
        let later = std::time::Instant::now() + Duration::from_millis(SWEEP_INTERVAL_MS + 1_000);
        store.maybe_sweep(later);
        assert_eq!(store.windows.len(), 0);

        // Immediately after a sweep, further sweeps are throttled again.
        store.hit("b", &rate).await;
        store.maybe_sweep(later + Duration::from_millis(30));
        assert_eq!(store.windows.len(), 1);
    }

    #[tokio::test]
    async fn memory_store_allows_then_blocks_within_window() {
        let store = MemoryStore::new();
        let rate = Rate {
            limit: 2,
            window: Duration::from_secs(60),
        };

        let d1 = store.hit("k", &rate).await;
        assert!(d1.allowed && d1.remaining == 1);
        let d2 = store.hit("k", &rate).await;
        assert!(d2.allowed && d2.remaining == 0);
        let d3 = store.hit("k", &rate).await;
        assert!(!d3.allowed);
        assert_eq!(d3.remaining, 0);
        assert!(d3.retry_after.is_some());
    }

    #[tokio::test]
    async fn memory_store_retry_after_is_within_window() {
        let store = MemoryStore::new();
        let rate = Rate {
            limit: 1,
            window: Duration::from_secs(60),
        };

        let first = store.hit("k", &rate).await;
        assert!(first.allowed);
        assert_eq!(first.retry_after, None);

        let blocked = store.hit("k", &rate).await;
        let wait = blocked
            .retry_after
            .expect("a rejected decision carries retry_after");
        assert!(wait > Duration::ZERO && wait <= rate.window);
    }

    #[tokio::test]
    async fn memory_store_resets_after_window() {
        let store = MemoryStore::new();
        let rate = Rate {
            limit: 1,
            window: Duration::from_millis(50),
        };

        assert!(store.hit("k", &rate).await.allowed);
        assert!(!store.hit("k", &rate).await.allowed);
        // `hit` reads the real clock, so real time has to pass here.
        tokio::time::sleep(Duration::from_millis(60)).await;
        assert!(store.hit("k", &rate).await.allowed);
    }

    #[tokio::test]
    async fn memory_store_isolates_keys() {
        let store = MemoryStore::new();
        let rate = Rate {
            limit: 1,
            window: Duration::from_secs(60),
        };
        assert!(store.hit("a", &rate).await.allowed);
        assert!(store.hit("b", &rate).await.allowed);
        assert!(!store.hit("a", &rate).await.allowed);
    }
}