Skip to main content

uvb_storage_memory/
ratelimit.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4use tokio::sync::RwLock;
5
6use uvb_storage_api::{
7    RateLimitConfig, RateLimitError, RateLimitResult, RateLimitScope, RateLimitStore,
8};
9
10#[derive(Clone, Debug)]
11struct RateLimitEntry {
12    count: u32,
13    window_start: i64,
14    penalty_expires_at: Option<i64>,
15}
16
17/// In-memory rate limit store for testing
18///
19/// Not recommended for production use across multiple instances.
20/// Use Redis for distributed rate limiting.
21pub struct InMemoryRateLimitStore {
22    entries: Arc<RwLock<HashMap<String, RateLimitEntry>>>,
23}
24
25impl InMemoryRateLimitStore {
26    pub fn new() -> Self {
27        Self {
28            entries: Arc::new(RwLock::new(HashMap::new())),
29        }
30    }
31
32    fn make_key(&self, scope: &RateLimitScope) -> String {
33        scope.to_string()
34    }
35
36    async fn cleanup_expired(&self) {
37        let mut entries = self.entries.write().await;
38        let now = chrono::Utc::now().timestamp();
39
40        entries.retain(|_, entry| {
41            // Keep if penalty is still active
42            if let Some(penalty_expires) = entry.penalty_expires_at {
43                if penalty_expires > now {
44                    return true;
45                }
46            }
47            // Keep if within window (allow some grace period)
48            now - entry.window_start < 3600
49        });
50    }
51}
52
53impl Default for InMemoryRateLimitStore {
54    fn default() -> Self {
55        Self::new()
56    }
57}
58
59#[async_trait]
60impl RateLimitStore for InMemoryRateLimitStore {
61    async fn check_and_increment(
62        &self,
63        scope: &RateLimitScope,
64        config: &RateLimitConfig,
65    ) -> Result<RateLimitResult, RateLimitError> {
66        // Periodic cleanup
67        if rand::random::<f32>() < 0.1 {
68            self.cleanup_expired().await;
69        }
70
71        let key = self.make_key(scope);
72        let now = chrono::Utc::now().timestamp();
73        let mut entries = self.entries.write().await;
74
75        let entry = entries
76            .entry(key.clone())
77            .or_insert_with(|| RateLimitEntry {
78                count: 0,
79                window_start: now,
80                penalty_expires_at: None,
81            });
82
83        // Check penalty
84        if let Some(penalty_expires) = entry.penalty_expires_at {
85            if penalty_expires > now {
86                return Ok(RateLimitResult::denied(
87                    entry.count,
88                    config.max_attempts,
89                    entry.window_start + config.window_secs as i64,
90                )
91                .with_penalty(penalty_expires));
92            } else {
93                // Penalty expired, clear it
94                entry.penalty_expires_at = None;
95            }
96        }
97
98        // Check if window expired
99        let window_elapsed = now - entry.window_start;
100        if window_elapsed >= config.window_secs as i64 {
101            // Reset window
102            entry.count = 0;
103            entry.window_start = now;
104        }
105
106        // Increment
107        entry.count += 1;
108        let current_count = entry.count;
109        let reset_at = entry.window_start + config.window_secs as i64;
110
111        if current_count > config.max_attempts {
112            // Apply penalty if configured
113            if let Some(penalty_secs) = config.penalty_secs {
114                entry.penalty_expires_at = Some(now + penalty_secs as i64);
115            }
116
117            Ok(RateLimitResult::denied(
118                current_count,
119                config.max_attempts,
120                reset_at,
121            ))
122        } else {
123            Ok(RateLimitResult::allowed(
124                current_count,
125                config.max_attempts,
126                reset_at,
127            ))
128        }
129    }
130
131    async fn check(
132        &self,
133        scope: &RateLimitScope,
134        config: &RateLimitConfig,
135    ) -> Result<RateLimitResult, RateLimitError> {
136        let key = self.make_key(scope);
137        let now = chrono::Utc::now().timestamp();
138        let entries = self.entries.read().await;
139
140        let entry = match entries.get(&key) {
141            Some(entry) => entry,
142            None => {
143                // No entry means no attempts yet
144                return Ok(RateLimitResult::allowed(
145                    0,
146                    config.max_attempts,
147                    now + config.window_secs as i64,
148                ));
149            }
150        };
151
152        // Check penalty
153        if let Some(penalty_expires) = entry.penalty_expires_at {
154            if penalty_expires > now {
155                return Ok(RateLimitResult::denied(
156                    entry.count,
157                    config.max_attempts,
158                    entry.window_start + config.window_secs as i64,
159                )
160                .with_penalty(penalty_expires));
161            }
162        }
163
164        // Check if window expired
165        let window_elapsed = now - entry.window_start;
166        let current_count = if window_elapsed >= config.window_secs as i64 {
167            0 // Window expired, count is effectively 0
168        } else {
169            entry.count
170        };
171
172        let reset_at = entry.window_start + config.window_secs as i64;
173
174        if current_count >= config.max_attempts {
175            Ok(RateLimitResult::denied(
176                current_count,
177                config.max_attempts,
178                reset_at,
179            ))
180        } else {
181            Ok(RateLimitResult::allowed(
182                current_count,
183                config.max_attempts,
184                reset_at,
185            ))
186        }
187    }
188
189    async fn reset(&self, scope: &RateLimitScope) -> Result<(), RateLimitError> {
190        let key = self.make_key(scope);
191        let mut entries = self.entries.write().await;
192        entries.remove(&key);
193        Ok(())
194    }
195
196    async fn apply_penalty(
197        &self,
198        scope: &RateLimitScope,
199        penalty_secs: u64,
200    ) -> Result<(), RateLimitError> {
201        let key = self.make_key(scope);
202        let now = chrono::Utc::now().timestamp();
203        let mut entries = self.entries.write().await;
204
205        let entry = entries.entry(key).or_insert_with(|| RateLimitEntry {
206            count: 0,
207            window_start: now,
208            penalty_expires_at: None,
209        });
210
211        entry.penalty_expires_at = Some(now + penalty_secs as i64);
212        Ok(())
213    }
214
215    async fn is_penalized(&self, scope: &RateLimitScope) -> Result<bool, RateLimitError> {
216        let key = self.make_key(scope);
217        let now = chrono::Utc::now().timestamp();
218        let entries = self.entries.read().await;
219
220        if let Some(entry) = entries.get(&key) {
221            if let Some(penalty_expires) = entry.penalty_expires_at {
222                return Ok(penalty_expires > now);
223            }
224        }
225
226        Ok(false)
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use uvb_core::TenantId;

    /// Subject scope shared by several tests.
    fn subject_scope() -> RateLimitScope {
        RateLimitScope::Subject {
            user_id: "user_1".to_string(),
            tenant_id: TenantId::new("tenant_a"),
        }
    }

    #[tokio::test]
    async fn test_rate_limit_basic() {
        let store = InMemoryRateLimitStore::new();
        let scope = subject_scope();
        let config = RateLimitConfig::new(3, 60);

        // First 3 attempts should succeed
        for i in 1..=3 {
            let result = store.check_and_increment(&scope, &config).await.unwrap();
            assert!(result.allowed, "attempt {} should be allowed", i);
            assert_eq!(result.current_attempts, i);
        }

        // 4th attempt should be denied
        let result = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(!result.allowed);
        assert_eq!(result.current_attempts, 4);
    }

    #[tokio::test]
    async fn test_rate_limit_window_reset() {
        let store = InMemoryRateLimitStore::new();
        let scope = RateLimitScope::IpAddress {
            ip: "203.0.113.1".to_string(),
        };
        let config = RateLimitConfig::new(2, 1); // 2 attempts per second

        // Use up the limit
        let result1 = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(result1.allowed);

        let result2 = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(result2.allowed);

        let result3 = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(!result3.allowed);

        // Wait for window to expire. Timestamps have whole-second resolution,
        // so sleep past a full extra second to be sure the window has elapsed.
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;

        // Should be allowed again
        let result4 = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(result4.allowed);
        assert_eq!(result4.current_attempts, 1);
    }

    #[tokio::test]
    async fn test_penalty() {
        let store = InMemoryRateLimitStore::new();
        let scope = RateLimitScope::FactorAttempt {
            user_id: "user_1".to_string(),
            tenant_id: TenantId::new("tenant_a"),
            factor_id: "totp".to_string(),
        };

        // Apply penalty
        store.apply_penalty(&scope, 5).await.unwrap();

        // Should be penalized
        assert!(store.is_penalized(&scope).await.unwrap());

        // Check should return denied with penalty
        let config = RateLimitConfig::new(3, 60);
        let result = store.check(&scope, &config).await.unwrap();
        assert!(!result.allowed);
        assert!(result.penalty_expires_at.is_some());
    }

    #[tokio::test]
    async fn test_penalty_blocks_check_and_increment() {
        let store = InMemoryRateLimitStore::new();
        let scope = RateLimitScope::IpAddress {
            ip: "203.0.113.2".to_string(),
        };
        let config = RateLimitConfig::new(3, 60);

        store.apply_penalty(&scope, 60).await.unwrap();

        let result = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(!result.allowed);
        assert!(result.penalty_expires_at.is_some());
        // A penalized attempt is rejected before it is counted.
        assert_eq!(result.current_attempts, 0);
    }

    #[tokio::test]
    async fn test_config_penalty_applied_when_limit_exceeded() {
        let store = InMemoryRateLimitStore::new();
        let scope = subject_scope();
        let mut config = RateLimitConfig::new(1, 60);
        config.penalty_secs = Some(30);

        assert!(store.check_and_increment(&scope, &config).await.unwrap().allowed);
        assert!(!store.is_penalized(&scope).await.unwrap());

        // Exceeding the limit both denies and starts the configured penalty.
        let result = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(!result.allowed);
        assert!(store.is_penalized(&scope).await.unwrap());
    }

    #[tokio::test]
    async fn test_unknown_scope_not_penalized() {
        let store = InMemoryRateLimitStore::new();
        assert!(!store.is_penalized(&subject_scope()).await.unwrap());
    }

    #[tokio::test]
    async fn test_check_does_not_consume_attempts() {
        let store = InMemoryRateLimitStore::new();
        let scope = subject_scope();
        let config = RateLimitConfig::new(3, 60);

        for _ in 0..5 {
            let result = store.check(&scope, &config).await.unwrap();
            assert!(result.allowed);
            assert_eq!(result.current_attempts, 0);
        }

        let result = store.check_and_increment(&scope, &config).await.unwrap();
        assert!(result.allowed);
        assert_eq!(result.current_attempts, 1);
    }

    #[tokio::test]
    async fn test_check_predicts_next_increment() {
        let store = InMemoryRateLimitStore::new();
        let scope = subject_scope();
        let config = RateLimitConfig::new(3, 60);

        for _ in 0..2 {
            assert!(store.check_and_increment(&scope, &config).await.unwrap().allowed);
        }
        let result = store.check(&scope, &config).await.unwrap();
        assert!(result.allowed);
        assert_eq!(result.current_attempts, 2);

        assert!(store.check_and_increment(&scope, &config).await.unwrap().allowed);
        let result = store.check(&scope, &config).await.unwrap();
        assert!(!result.allowed, "limit reached; the next attempt would be denied");
        assert_eq!(result.current_attempts, 3);
    }

    #[tokio::test]
    async fn test_reset() {
        let store = InMemoryRateLimitStore::new();
        let scope = subject_scope();
        let config = RateLimitConfig::new(3, 60);

        // Use up the limit; surface store errors instead of discarding them.
        for _ in 0..4 {
            store.check_and_increment(&scope, &config).await.unwrap();
        }

        // Should be denied
        let result = store.check(&scope, &config).await.unwrap();
        assert!(!result.allowed);

        // Reset
        store.reset(&scope).await.unwrap();

        // Should be allowed again
        let result = store.check(&scope, &config).await.unwrap();
        assert!(result.allowed);
        assert_eq!(result.current_attempts, 0);
    }
}