1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4use tokio::sync::RwLock;
5
6use uvb_storage_api::{
7 RateLimitConfig, RateLimitError, RateLimitResult, RateLimitScope, RateLimitStore,
8};
9
10#[derive(Clone, Debug)]
11struct RateLimitEntry {
12 count: u32,
13 window_start: i64,
14 penalty_expires_at: Option<i64>,
15}
16
17pub struct InMemoryRateLimitStore {
22 entries: Arc<RwLock<HashMap<String, RateLimitEntry>>>,
23}
24
25impl InMemoryRateLimitStore {
26 pub fn new() -> Self {
27 Self {
28 entries: Arc::new(RwLock::new(HashMap::new())),
29 }
30 }
31
32 fn make_key(&self, scope: &RateLimitScope) -> String {
33 scope.to_string()
34 }
35
36 async fn cleanup_expired(&self) {
37 let mut entries = self.entries.write().await;
38 let now = chrono::Utc::now().timestamp();
39
40 entries.retain(|_, entry| {
41 if let Some(penalty_expires) = entry.penalty_expires_at {
43 if penalty_expires > now {
44 return true;
45 }
46 }
47 now - entry.window_start < 3600
49 });
50 }
51}
52
53impl Default for InMemoryRateLimitStore {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59#[async_trait]
60impl RateLimitStore for InMemoryRateLimitStore {
61 async fn check_and_increment(
62 &self,
63 scope: &RateLimitScope,
64 config: &RateLimitConfig,
65 ) -> Result<RateLimitResult, RateLimitError> {
66 if rand::random::<f32>() < 0.1 {
68 self.cleanup_expired().await;
69 }
70
71 let key = self.make_key(scope);
72 let now = chrono::Utc::now().timestamp();
73 let mut entries = self.entries.write().await;
74
75 let entry = entries
76 .entry(key.clone())
77 .or_insert_with(|| RateLimitEntry {
78 count: 0,
79 window_start: now,
80 penalty_expires_at: None,
81 });
82
83 if let Some(penalty_expires) = entry.penalty_expires_at {
85 if penalty_expires > now {
86 return Ok(RateLimitResult::denied(
87 entry.count,
88 config.max_attempts,
89 entry.window_start + config.window_secs as i64,
90 )
91 .with_penalty(penalty_expires));
92 } else {
93 entry.penalty_expires_at = None;
95 }
96 }
97
98 let window_elapsed = now - entry.window_start;
100 if window_elapsed >= config.window_secs as i64 {
101 entry.count = 0;
103 entry.window_start = now;
104 }
105
106 entry.count += 1;
108 let current_count = entry.count;
109 let reset_at = entry.window_start + config.window_secs as i64;
110
111 if current_count > config.max_attempts {
112 if let Some(penalty_secs) = config.penalty_secs {
114 entry.penalty_expires_at = Some(now + penalty_secs as i64);
115 }
116
117 Ok(RateLimitResult::denied(
118 current_count,
119 config.max_attempts,
120 reset_at,
121 ))
122 } else {
123 Ok(RateLimitResult::allowed(
124 current_count,
125 config.max_attempts,
126 reset_at,
127 ))
128 }
129 }
130
131 async fn check(
132 &self,
133 scope: &RateLimitScope,
134 config: &RateLimitConfig,
135 ) -> Result<RateLimitResult, RateLimitError> {
136 let key = self.make_key(scope);
137 let now = chrono::Utc::now().timestamp();
138 let entries = self.entries.read().await;
139
140 let entry = match entries.get(&key) {
141 Some(entry) => entry,
142 None => {
143 return Ok(RateLimitResult::allowed(
145 0,
146 config.max_attempts,
147 now + config.window_secs as i64,
148 ));
149 }
150 };
151
152 if let Some(penalty_expires) = entry.penalty_expires_at {
154 if penalty_expires > now {
155 return Ok(RateLimitResult::denied(
156 entry.count,
157 config.max_attempts,
158 entry.window_start + config.window_secs as i64,
159 )
160 .with_penalty(penalty_expires));
161 }
162 }
163
164 let window_elapsed = now - entry.window_start;
166 let current_count = if window_elapsed >= config.window_secs as i64 {
167 0 } else {
169 entry.count
170 };
171
172 let reset_at = entry.window_start + config.window_secs as i64;
173
174 if current_count >= config.max_attempts {
175 Ok(RateLimitResult::denied(
176 current_count,
177 config.max_attempts,
178 reset_at,
179 ))
180 } else {
181 Ok(RateLimitResult::allowed(
182 current_count,
183 config.max_attempts,
184 reset_at,
185 ))
186 }
187 }
188
189 async fn reset(&self, scope: &RateLimitScope) -> Result<(), RateLimitError> {
190 let key = self.make_key(scope);
191 let mut entries = self.entries.write().await;
192 entries.remove(&key);
193 Ok(())
194 }
195
196 async fn apply_penalty(
197 &self,
198 scope: &RateLimitScope,
199 penalty_secs: u64,
200 ) -> Result<(), RateLimitError> {
201 let key = self.make_key(scope);
202 let now = chrono::Utc::now().timestamp();
203 let mut entries = self.entries.write().await;
204
205 let entry = entries.entry(key).or_insert_with(|| RateLimitEntry {
206 count: 0,
207 window_start: now,
208 penalty_expires_at: None,
209 });
210
211 entry.penalty_expires_at = Some(now + penalty_secs as i64);
212 Ok(())
213 }
214
215 async fn is_penalized(&self, scope: &RateLimitScope) -> Result<bool, RateLimitError> {
216 let key = self.make_key(scope);
217 let now = chrono::Utc::now().timestamp();
218 let entries = self.entries.read().await;
219
220 if let Some(entry) = entries.get(&key) {
221 if let Some(penalty_expires) = entry.penalty_expires_at {
222 return Ok(penalty_expires > now);
223 }
224 }
225
226 Ok(false)
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use uvb_core::TenantId;

    /// The subject scope shared by several tests.
    fn subject_scope() -> RateLimitScope {
        RateLimitScope::Subject {
            user_id: "user_1".to_string(),
            tenant_id: TenantId::new("tenant_a"),
        }
    }

    /// Moves the stored window for `scope` `secs` seconds into the past, simulating elapsed
    /// time without sleeping or depending on wall-clock second boundaries.
    async fn rewind_window(store: &InMemoryRateLimitStore, scope: &RateLimitScope, secs: i64) {
        let key = store.make_key(scope);
        let mut entries = store.entries.write().await;
        let entry = entries
            .get_mut(&key)
            .expect("scope should have an entry after recorded attempts");
        entry.window_start -= secs;
    }

    #[tokio::test]
    async fn test_rate_limit_basic() {
        let store = InMemoryRateLimitStore::new();
        let scope = subject_scope();
        let config = RateLimitConfig::new(3, 60);

        for i in 1..=3 {
            let result = store.check_and_increment(&scope, &config).await.unwrap();
            assert!(result.allowed, "attempt {} should be allowed", i);
            assert_eq!(result.current_attempts, i);
        }

        let result = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(!result.allowed);
        assert_eq!(result.current_attempts, 4);
    }

    #[tokio::test]
    async fn test_rate_limit_window_reset() {
        let store = InMemoryRateLimitStore::new();
        let scope = RateLimitScope::IpAddress {
            ip: "203.0.113.1".to_string(),
        };
        // A 60s window guarantees the first three attempts share one window; a 1s window
        // would flake whenever a wall-clock second boundary fell between them.
        let config = RateLimitConfig::new(2, 60);

        let result1 = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(result1.allowed);

        let result2 = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(result2.allowed);

        let result3 = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(!result3.allowed);

        // Simulate the window elapsing instead of sleeping.
        rewind_window(&store, &scope, 60).await;

        let result4 = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(result4.allowed);
        assert_eq!(result4.current_attempts, 1);
    }

    #[tokio::test]
    async fn test_penalty() {
        let store = InMemoryRateLimitStore::new();
        let scope = RateLimitScope::FactorAttempt {
            user_id: "user_1".to_string(),
            tenant_id: TenantId::new("tenant_a"),
            factor_id: "totp".to_string(),
        };

        store.apply_penalty(&scope, 5).await.unwrap();
        assert!(store.is_penalized(&scope).await.unwrap());

        let config = RateLimitConfig::new(3, 60);
        let result = store.check(&scope, &config).await.unwrap();
        assert!(!result.allowed);
        assert!(result.penalty_expires_at.is_some());
    }

    #[tokio::test]
    async fn test_reset() {
        let store = InMemoryRateLimitStore::new();
        let scope = subject_scope();
        let config = RateLimitConfig::new(3, 60);

        for _ in 0..4 {
            let _ = store.check_and_increment(&scope, &config).await;
        }

        let result = store.check(&scope, &config).await.unwrap();
        assert!(!result.allowed);

        store.reset(&scope).await.unwrap();

        let result = store.check(&scope, &config).await.unwrap();
        assert!(result.allowed);
        assert_eq!(result.current_attempts, 0);
    }
}