1use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::{Duration, Instant};
11
12use parking_lot::RwLock;
13
/// Snapshot of a user's account data as stored in the authentication cache.
///
/// A value of this type is handed back (as a clone) on every cache hit, so it
/// reflects the state at the moment the caller inserted it, not live state.
#[derive(Debug, Clone)]
pub struct CachedUser {
    /// Identifier of the owning user. Entries whose id is `None` are never
    /// matched by `AuthCache::invalidate_user`.
    pub user_id: Option<String>,
    /// Traffic quota for the user. NOTE(review): unit (presumably bytes) and the
    /// meaning of special values such as `0` are defined by the caller — confirm.
    pub traffic_limit: i64,
    /// Traffic already consumed when this snapshot was taken. Usage recorded
    /// afterwards via `AuthCache::add_traffic_delta` is tracked separately and
    /// is not folded into this field by the cache.
    pub traffic_used: i64,
    /// Account expiry as supplied by the caller. NOTE(review): presumably a Unix
    /// timestamp — confirm. Unrelated to the cache entry's own TTL.
    pub expires_at: i64,
    /// Whether the account is enabled, as reported by whoever built the snapshot.
    pub enabled: bool,
    /// When the caller created this snapshot. `AuthCache` itself does not read
    /// this field; entry expiry is driven by the cache TTL instead.
    pub cached_at: Instant,
}
30
31#[derive(Debug)]
33struct CacheEntry {
34 user: CachedUser,
35 expires_at: Instant,
36}
37
38#[derive(Debug)]
48pub struct AuthCache {
49 cache: RwLock<HashMap<String, CacheEntry>>,
51 ttl: Duration,
53
54 traffic_deltas: RwLock<HashMap<String, i64>>,
56
57 neg_cache: RwLock<HashMap<String, Instant>>,
59 neg_ttl: Duration,
61
62 hits: AtomicU64,
64 misses: AtomicU64,
66}
67
68impl AuthCache {
69 pub fn new(ttl: Duration, neg_ttl: Duration) -> Self {
74 Self {
75 cache: RwLock::new(HashMap::new()),
76 ttl,
77 traffic_deltas: RwLock::new(HashMap::new()),
78 neg_cache: RwLock::new(HashMap::new()),
79 neg_ttl,
80 hits: AtomicU64::new(0),
81 misses: AtomicU64::new(0),
82 }
83 }
84
85 pub fn get(&self, hash: &str) -> Option<CachedUser> {
91 let cache = self.cache.read();
92 if let Some(entry) = cache.get(hash)
93 && Instant::now() < entry.expires_at
94 {
95 self.hits.fetch_add(1, Ordering::Relaxed);
96 return Some(entry.user.clone());
97 }
98 drop(cache);
99
100 self.misses.fetch_add(1, Ordering::Relaxed);
101 None
102 }
103
104 pub fn insert(&self, hash: String, user: CachedUser) {
106 let entry = CacheEntry {
107 user,
108 expires_at: Instant::now() + self.ttl,
109 };
110 self.cache.write().insert(hash, entry);
111 }
112
113 pub fn remove(&self, hash: &str) {
115 self.cache.write().remove(hash);
116 }
117
118 pub fn invalidate_user(&self, user_id: &str) {
123 self.cache
124 .write()
125 .retain(|_, entry| entry.user.user_id.as_deref() != Some(user_id));
126 self.traffic_deltas.write().remove(user_id);
127 }
128
129 pub fn clear(&self) {
131 self.cache.write().clear();
132 self.traffic_deltas.write().clear();
133 self.neg_cache.write().clear();
134 }
135
136 pub fn cleanup_expired(&self) {
138 let now = Instant::now();
139 self.cache.write().retain(|_, entry| entry.expires_at > now);
140 self.neg_cache.write().retain(|_, &mut exp| exp > now);
141 }
142
143 #[allow(clippy::cast_possible_wrap)]
150 pub fn add_traffic_delta(&self, user_id: &str, bytes: u64) {
151 *self
152 .traffic_deltas
153 .write()
154 .entry(user_id.to_string())
155 .or_insert(0) += bytes as i64;
156 }
157
158 pub fn get_traffic_delta(&self, user_id: &str) -> i64 {
163 self.traffic_deltas
164 .read()
165 .get(user_id)
166 .copied()
167 .unwrap_or(0)
168 }
169
170 pub fn clear_traffic_delta(&self, user_id: &str) {
175 self.traffic_deltas.write().remove(user_id);
176 }
177
178 pub fn insert_negative(&self, hash: &str) {
185 if self.neg_ttl > Duration::ZERO {
186 self.neg_cache
187 .write()
188 .insert(hash.to_string(), Instant::now() + self.neg_ttl);
189 }
190 }
191
192 pub fn is_negative(&self, hash: &str) -> bool {
197 if self.neg_ttl == Duration::ZERO {
198 return false;
199 }
200 let cache = self.neg_cache.read();
201 if let Some(&exp) = cache.get(hash)
202 && Instant::now() < exp
203 {
204 return true;
205 }
206 false
207 }
208
209 pub fn remove_negative(&self, hash: &str) {
214 self.neg_cache.write().remove(hash);
215 }
216
217 pub fn stats(&self) -> CacheStats {
221 let cache = self.cache.read();
222 CacheStats {
223 size: cache.len(),
224 neg_size: self.neg_cache.read().len(),
225 hits: self.hits.load(Ordering::Relaxed),
226 misses: self.misses.load(Ordering::Relaxed),
227 ttl: self.ttl,
228 }
229 }
230}
231
/// Statistics reported by `AuthCache::stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of positive entries, including expired ones not yet evicted.
    pub size: usize,
    /// Number of negative entries, including expired ones not yet evicted.
    pub neg_size: usize,
    /// Total lookups that returned an entry.
    pub hits: u64,
    /// Total lookups that returned nothing.
    pub misses: u64,
    /// Configured lifetime of positive entries.
    pub ttl: Duration,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookups have been recorded yet.
    #[must_use]
    #[allow(clippy::cast_precision_loss)] // a ratio does not need exact integer precision
    pub fn hit_rate(&self) -> f64 {
        if self.hits == 0 && self.misses == 0 {
            return 0.0;
        }
        // Sum in f64: `hits + misses` in u64 could overflow (and panic in debug
        // builds) once the counters approach u64::MAX.
        let hits = self.hits as f64;
        hits / (hits + self.misses as f64)
    }
}
258
#[cfg(test)]
mod tests {
    use super::*;

    // One-minute positive TTL and five-second negative TTL, so nothing expires
    // during a test unless the test deliberately sleeps.
    fn default_cache() -> AuthCache {
        AuthCache::new(Duration::from_secs(60), Duration::from_secs(5))
    }

    // An enabled user with the given id and traffic figures.
    fn user_fixture(id: &str, limit: i64, used: i64) -> CachedUser {
        CachedUser {
            user_id: Some(id.to_string()),
            traffic_limit: limit,
            traffic_used: used,
            expires_at: 0,
            enabled: true,
            cached_at: Instant::now(),
        }
    }

    #[test]
    fn test_cache_basic() {
        let cache = default_cache();
        cache.insert("hash1".to_string(), user_fixture("user1", 1000, 100));

        let hit = cache.get("hash1").unwrap();
        assert_eq!(hit.user_id.as_deref(), Some("user1"));
        assert_eq!(hit.traffic_limit, 1000);

        assert!(cache.get("hash2").is_none());
    }

    #[test]
    fn test_cache_expiration() {
        let cache = AuthCache::new(Duration::from_millis(10), Duration::ZERO);
        cache.insert("hash1".to_string(), user_fixture("user1", 0, 0));
        assert!(cache.get("hash1").is_some());

        std::thread::sleep(Duration::from_millis(20));
        assert!(cache.get("hash1").is_none());
    }

    #[test]
    fn test_cache_invalidate_user() {
        let cache = default_cache();
        for (hash, id) in [("hash1", "user1"), ("hash2", "user2")] {
            cache.insert(hash.to_string(), user_fixture(id, 0, 0));
        }
        cache.add_traffic_delta("user1", 500);

        cache.invalidate_user("user1");

        assert!(cache.get("hash1").is_none());
        assert!(cache.get("hash2").is_some());
        assert_eq!(cache.get_traffic_delta("user1"), 0);
    }

    #[test]
    fn test_cache_stats() {
        let cache = default_cache();
        let anonymous = CachedUser {
            user_id: None,
            ..user_fixture("", 0, 0)
        };
        cache.insert("hash1".to_string(), anonymous);

        // Two hits followed by one miss.
        for key in ["hash1", "hash1", "hash2"] {
            let _ = cache.get(key);
        }

        let stats = cache.stats();
        assert_eq!(stats.size, 1);
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate() - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_traffic_delta_accumulates() {
        let cache = default_cache();
        for bytes in [100, 200, 300] {
            cache.add_traffic_delta("user1", bytes);
        }

        assert_eq!(cache.get_traffic_delta("user1"), 600);
        assert_eq!(cache.get_traffic_delta("user2"), 0);
    }

    #[test]
    fn test_traffic_delta_clear() {
        let cache = default_cache();
        cache.add_traffic_delta("user1", 500);
        assert_eq!(cache.get_traffic_delta("user1"), 500);

        cache.clear_traffic_delta("user1");
        assert_eq!(cache.get_traffic_delta("user1"), 0);
    }

    #[test]
    fn test_clear_resets_everything() {
        let cache = default_cache();
        cache.insert("hash1".to_string(), user_fixture("user1", 0, 0));
        cache.add_traffic_delta("user1", 100);
        cache.insert_negative("bad_hash");

        cache.clear();

        assert!(cache.get("hash1").is_none());
        assert_eq!(cache.get_traffic_delta("user1"), 0);
        assert!(!cache.is_negative("bad_hash"));
    }

    #[test]
    fn test_negative_cache_basic() {
        let cache = default_cache();
        assert!(!cache.is_negative("bad_hash"));

        cache.insert_negative("bad_hash");
        assert!(cache.is_negative("bad_hash"));
        assert!(!cache.is_negative("other_hash"));
    }

    #[test]
    fn test_negative_cache_expiration() {
        let cache = AuthCache::new(Duration::from_secs(60), Duration::from_millis(10));
        cache.insert_negative("bad_hash");
        assert!(cache.is_negative("bad_hash"));

        std::thread::sleep(Duration::from_millis(20));
        assert!(!cache.is_negative("bad_hash"));
    }

    #[test]
    fn test_negative_cache_disabled_when_zero_ttl() {
        let cache = AuthCache::new(Duration::from_secs(60), Duration::ZERO);
        cache.insert_negative("bad_hash");
        assert!(!cache.is_negative("bad_hash"));
    }

    #[test]
    fn test_negative_cache_remove() {
        let cache = default_cache();
        cache.insert_negative("bad_hash");
        assert!(cache.is_negative("bad_hash"));

        cache.remove_negative("bad_hash");
        assert!(!cache.is_negative("bad_hash"));
    }

    #[test]
    fn test_negative_cache_in_stats() {
        let cache = default_cache();
        for hash in ["hash1", "hash2"] {
            cache.insert_negative(hash);
        }

        assert_eq!(cache.stats().neg_size, 2);
    }

    #[test]
    fn test_cleanup_expired_cleans_both() {
        let cache = AuthCache::new(Duration::from_millis(10), Duration::from_millis(10));
        cache.insert("hash1".to_string(), user_fixture("user1", 0, 0));
        cache.insert_negative("bad_hash");

        std::thread::sleep(Duration::from_millis(20));
        cache.cleanup_expired();

        let stats = cache.stats();
        assert_eq!(stats.size, 0);
        assert_eq!(stats.neg_size, 0);
    }
}