1use std::collections::HashMap;
9#[cfg(feature = "tokio-runtime")]
10use std::collections::HashSet;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::time::{Duration, Instant};
13
14use parking_lot::RwLock;
15
16#[derive(Debug)]
18pub enum CacheLookup {
19 Fresh(CachedUser),
21 Stale(CachedUser),
23 Miss,
25}
26
/// Snapshot of a user's authentication record as stored in [`AuthCache`].
///
/// Instances are cloned out of the cache on every successful read, so callers
/// own their copy and may mutate it freely.
#[derive(Debug, Clone)]
pub struct CachedUser {
    /// Owner of the record. [`AuthCache::invalidate_user`] evicts entries whose
    /// id matches; entries with `None` are never matched by it.
    pub user_id: Option<String>,
    /// Traffic quota for the user.
    /// NOTE(review): units (presumably bytes) and the meaning of non-positive
    /// values are defined by the caller — confirm.
    pub traffic_limit: i64,
    /// Traffic already consumed according to the cached record.
    pub traffic_used: i64,
    /// Account expiry as an integer timestamp; independent of the cache TTL.
    /// NOTE(review): presumably Unix seconds — confirm.
    pub expires_at: i64,
    /// Whether the account is enabled.
    pub enabled: bool,
    /// Creation time of this snapshot, as supplied by whoever built it.
    pub cached_at: Instant,
}
43
44#[derive(Debug)]
46struct CacheEntry {
47 user: CachedUser,
48 expires_at: Instant,
49}
50
51#[derive(Debug)]
61pub struct AuthCache {
62 cache: RwLock<HashMap<String, CacheEntry>>,
64 ttl: Duration,
66 stale_ttl: Duration,
72
73 traffic_deltas: RwLock<HashMap<String, i64>>,
75
76 neg_cache: RwLock<HashMap<String, Instant>>,
78 neg_ttl: Duration,
80
81 hits: AtomicU64,
83 misses: AtomicU64,
85
86 #[cfg(feature = "tokio-runtime")]
88 revalidating: RwLock<HashSet<String>>,
89}
90
91impl AuthCache {
92 pub fn new(ttl: Duration, stale_ttl: Duration, neg_ttl: Duration) -> Self {
98 Self {
99 cache: RwLock::new(HashMap::new()),
100 ttl,
101 stale_ttl,
102 traffic_deltas: RwLock::new(HashMap::new()),
103 neg_cache: RwLock::new(HashMap::new()),
104 neg_ttl,
105 hits: AtomicU64::new(0),
106 misses: AtomicU64::new(0),
107 #[cfg(feature = "tokio-runtime")]
108 revalidating: RwLock::new(HashSet::new()),
109 }
110 }
111
112 pub fn get(&self, hash: &str) -> Option<CachedUser> {
120 let cache = self.cache.read();
121 if let Some(entry) = cache.get(hash)
122 && Instant::now() < entry.expires_at
123 {
124 self.hits.fetch_add(1, Ordering::Relaxed);
125 return Some(entry.user.clone());
126 }
127 drop(cache);
128
129 self.misses.fetch_add(1, Ordering::Relaxed);
130 None
131 }
132
133 pub fn lookup(&self, hash: &str) -> CacheLookup {
141 let cache = self.cache.read();
142 if let Some(entry) = cache.get(hash) {
143 let now = Instant::now();
144 if now < entry.expires_at {
145 self.hits.fetch_add(1, Ordering::Relaxed);
146 return CacheLookup::Fresh(entry.user.clone());
147 }
148 if self.stale_ttl > Duration::ZERO && now < entry.expires_at + self.stale_ttl {
150 self.hits.fetch_add(1, Ordering::Relaxed);
151 return CacheLookup::Stale(entry.user.clone());
152 }
153 }
154 drop(cache);
155
156 self.misses.fetch_add(1, Ordering::Relaxed);
157 CacheLookup::Miss
158 }
159
160 pub fn insert(&self, hash: String, user: CachedUser) {
162 let entry = CacheEntry {
163 user,
164 expires_at: Instant::now() + self.ttl,
165 };
166 self.cache.write().insert(hash, entry);
167 }
168
169 pub fn remove(&self, hash: &str) {
171 self.cache.write().remove(hash);
172 }
173
174 pub fn invalidate_user(&self, user_id: &str) {
179 self.cache
180 .write()
181 .retain(|_, entry| entry.user.user_id.as_deref() != Some(user_id));
182 self.traffic_deltas.write().remove(user_id);
183 }
184
185 pub fn clear(&self) {
187 self.cache.write().clear();
188 self.traffic_deltas.write().clear();
189 self.neg_cache.write().clear();
190 #[cfg(feature = "tokio-runtime")]
191 self.revalidating.write().clear();
192 }
193
194 pub fn cleanup_expired(&self) {
199 let now = Instant::now();
200 let stale = self.stale_ttl;
201 self.cache
202 .write()
203 .retain(|_, entry| entry.expires_at + stale > now);
204 self.neg_cache.write().retain(|_, &mut exp| exp > now);
205 }
206
207 #[allow(clippy::cast_possible_wrap)]
214 pub fn add_traffic_delta(&self, user_id: &str, bytes: u64) {
215 *self
216 .traffic_deltas
217 .write()
218 .entry(user_id.to_string())
219 .or_insert(0) += bytes as i64;
220 }
221
222 pub fn get_traffic_delta(&self, user_id: &str) -> i64 {
227 self.traffic_deltas
228 .read()
229 .get(user_id)
230 .copied()
231 .unwrap_or(0)
232 }
233
234 pub fn clear_traffic_delta(&self, user_id: &str) {
239 self.traffic_deltas.write().remove(user_id);
240 }
241
242 pub fn insert_negative(&self, hash: &str) {
249 if self.neg_ttl > Duration::ZERO {
250 self.neg_cache
251 .write()
252 .insert(hash.to_string(), Instant::now() + self.neg_ttl);
253 }
254 }
255
256 pub fn is_negative(&self, hash: &str) -> bool {
261 if self.neg_ttl == Duration::ZERO {
262 return false;
263 }
264 let cache = self.neg_cache.read();
265 if let Some(&exp) = cache.get(hash)
266 && Instant::now() < exp
267 {
268 return true;
269 }
270 false
271 }
272
273 pub fn remove_negative(&self, hash: &str) {
278 self.neg_cache.write().remove(hash);
279 }
280
281 #[cfg(feature = "tokio-runtime")]
285 pub(crate) fn start_revalidation(&self, hash: &str) -> bool {
286 self.revalidating.write().insert(hash.to_string())
287 }
288
289 #[cfg(feature = "tokio-runtime")]
291 pub(crate) fn finish_revalidation(&self, hash: &str) {
292 self.revalidating.write().remove(hash);
293 }
294
295 pub fn stats(&self) -> CacheStats {
299 let cache = self.cache.read();
300 CacheStats {
301 size: cache.len(),
302 neg_size: self.neg_cache.read().len(),
303 hits: self.hits.load(Ordering::Relaxed),
304 misses: self.misses.load(Ordering::Relaxed),
305 ttl: self.ttl,
306 }
307 }
308}
309
/// Point-in-time snapshot of [`AuthCache`] sizes and counters, produced by
/// [`AuthCache::stats`].
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Positive entries currently stored (may include expired ones not yet cleaned up).
    pub size: usize,
    /// Negative entries currently stored (may include expired ones not yet cleaned up).
    pub neg_size: usize,
    /// Cumulative lookups that returned a usable user.
    pub hits: u64,
    /// Cumulative lookups that returned nothing usable.
    pub misses: u64,
    /// Fresh TTL the cache was configured with.
    pub ttl: Duration,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookups have been recorded. The total is summed
    /// in `f64`: `hits + misses` in `u64` can overflow, which panics in debug
    /// builds and wraps (yielding a bogus rate) in release builds.
    pub fn hit_rate(&self) -> f64 {
        if self.hits == 0 && self.misses == 0 {
            return 0.0;
        }
        let hits = self.hits as f64;
        let total = hits + self.misses as f64;
        hits / total
    }
}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Cache with a 60s TTL, no stale window, and a 5s negative TTL.
    fn make_cache() -> AuthCache {
        AuthCache::new(
            Duration::from_secs(60),
            Duration::ZERO,
            Duration::from_secs(5),
        )
    }

    /// Enabled user with the given id and traffic figures.
    fn make_user(user_id: &str, traffic_limit: i64, traffic_used: i64) -> CachedUser {
        CachedUser {
            user_id: Some(user_id.to_string()),
            traffic_limit,
            traffic_used,
            expires_at: 0,
            enabled: true,
            cached_at: Instant::now(),
        }
    }

    #[test]
    fn test_cache_basic() {
        let cache = make_cache();
        let user = make_user("user1", 1000, 100);

        cache.insert("hash1".to_string(), user);
        let cached = cache.get("hash1").unwrap();
        assert_eq!(cached.user_id, Some("user1".to_string()));
        assert_eq!(cached.traffic_limit, 1000);

        assert!(cache.get("hash2").is_none());
    }

    #[test]
    fn test_cache_expiration() {
        let cache = AuthCache::new(Duration::from_millis(10), Duration::ZERO, Duration::ZERO);
        let user = make_user("user1", 0, 0);

        cache.insert("hash1".to_string(), user);
        assert!(cache.get("hash1").is_some());

        std::thread::sleep(Duration::from_millis(20));
        assert!(cache.get("hash1").is_none());
    }

    #[test]
    fn test_cache_invalidate_user() {
        let cache = make_cache();

        cache.insert("hash1".to_string(), make_user("user1", 0, 0));
        cache.insert("hash2".to_string(), make_user("user2", 0, 0));

        cache.add_traffic_delta("user1", 500);

        cache.invalidate_user("user1");

        assert!(cache.get("hash1").is_none());
        assert!(cache.get("hash2").is_some());
        assert_eq!(cache.get_traffic_delta("user1"), 0);
    }

    #[test]
    fn test_invalidate_user_keeps_anonymous_entries() {
        let cache = make_cache();

        let mut anonymous = make_user("ignored", 0, 0);
        anonymous.user_id = None;
        cache.insert("anon".to_string(), anonymous);
        cache.insert("hash1".to_string(), make_user("user1", 0, 0));

        cache.invalidate_user("user1");

        assert!(cache.get("anon").is_some());
        assert!(cache.get("hash1").is_none());
    }

    #[test]
    fn test_remove_evicts_entry() {
        let cache = make_cache();
        cache.insert("hash1".to_string(), make_user("user1", 0, 0));

        cache.remove("hash1");
        cache.remove("missing"); // removing an absent key is a no-op

        assert!(cache.get("hash1").is_none());
        assert_eq!(cache.stats().size, 0);
    }

    #[test]
    fn test_cache_stats() {
        let cache = make_cache();
        let user = CachedUser {
            user_id: None,
            traffic_limit: 0,
            traffic_used: 0,
            expires_at: 0,
            enabled: true,
            cached_at: Instant::now(),
        };

        cache.insert("hash1".to_string(), user);

        assert!(cache.get("hash1").is_some());
        assert!(cache.get("hash1").is_some());
        assert!(cache.get("hash2").is_none());
        let stats = cache.stats();

        assert_eq!(stats.size, 1);
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate() - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_clear_keeps_counters() {
        let cache = make_cache();
        cache.insert("hash1".to_string(), make_user("user1", 0, 0));
        assert!(cache.get("hash1").is_some());

        cache.clear();

        let stats = cache.stats();
        assert_eq!(stats.size, 0);
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 0);
    }

    #[test]
    fn test_traffic_delta_accumulates() {
        let cache = make_cache();

        cache.add_traffic_delta("user1", 100);
        cache.add_traffic_delta("user1", 200);
        cache.add_traffic_delta("user1", 300);

        assert_eq!(cache.get_traffic_delta("user1"), 600);
        assert_eq!(cache.get_traffic_delta("user2"), 0);
    }

    #[test]
    fn test_traffic_delta_clear() {
        let cache = make_cache();

        cache.add_traffic_delta("user1", 500);
        assert_eq!(cache.get_traffic_delta("user1"), 500);

        cache.clear_traffic_delta("user1");
        assert_eq!(cache.get_traffic_delta("user1"), 0);
    }

    #[test]
    fn test_clear_resets_everything() {
        let cache = make_cache();

        cache.insert("hash1".to_string(), make_user("user1", 0, 0));
        cache.add_traffic_delta("user1", 100);
        cache.insert_negative("bad_hash");

        cache.clear();

        assert!(cache.get("hash1").is_none());
        assert_eq!(cache.get_traffic_delta("user1"), 0);
        assert!(!cache.is_negative("bad_hash"));
    }

    #[test]
    fn test_negative_cache_basic() {
        let cache = make_cache();

        assert!(!cache.is_negative("bad_hash"));

        cache.insert_negative("bad_hash");
        assert!(cache.is_negative("bad_hash"));
        assert!(!cache.is_negative("other_hash"));
    }

    #[test]
    fn test_negative_cache_expiration() {
        let cache = AuthCache::new(
            Duration::from_secs(60),
            Duration::ZERO,
            Duration::from_millis(10),
        );

        cache.insert_negative("bad_hash");
        assert!(cache.is_negative("bad_hash"));

        std::thread::sleep(Duration::from_millis(20));
        assert!(!cache.is_negative("bad_hash"));
    }

    #[test]
    fn test_negative_cache_disabled_when_zero_ttl() {
        let cache = AuthCache::new(Duration::from_secs(60), Duration::ZERO, Duration::ZERO);

        cache.insert_negative("bad_hash");
        assert!(!cache.is_negative("bad_hash"));
    }

    #[test]
    fn test_negative_cache_remove() {
        let cache = make_cache();

        cache.insert_negative("bad_hash");
        assert!(cache.is_negative("bad_hash"));

        cache.remove_negative("bad_hash");
        assert!(!cache.is_negative("bad_hash"));
    }

    #[test]
    fn test_negative_cache_in_stats() {
        let cache = make_cache();

        cache.insert_negative("hash1");
        cache.insert_negative("hash2");

        let stats = cache.stats();
        assert_eq!(stats.neg_size, 2);
    }

    #[test]
    fn test_cleanup_expired_cleans_both() {
        let cache = AuthCache::new(
            Duration::from_millis(10),
            Duration::ZERO,
            Duration::from_millis(10),
        );

        cache.insert("hash1".to_string(), make_user("user1", 0, 0));
        cache.insert_negative("bad_hash");

        std::thread::sleep(Duration::from_millis(20));
        cache.cleanup_expired();

        let stats = cache.stats();
        assert_eq!(stats.size, 0);
        assert_eq!(stats.neg_size, 0);
    }

    #[test]
    fn test_cache_stale_lookup() {
        let cache = AuthCache::new(
            Duration::from_millis(50),
            Duration::from_millis(500),
            Duration::ZERO,
        );
        let user = make_user("user1", 1000, 100);
        cache.insert("hash1".to_string(), user);

        assert!(matches!(cache.lookup("hash1"), CacheLookup::Fresh(_)));

        // Past the fresh TTL, inside the stale window.
        std::thread::sleep(Duration::from_millis(150));

        assert!(matches!(cache.lookup("hash1"), CacheLookup::Stale(_)));

        // `get` only returns fresh entries.
        assert!(cache.get("hash1").is_none());

        // Past the stale window as well.
        std::thread::sleep(Duration::from_millis(500));

        assert!(matches!(cache.lookup("hash1"), CacheLookup::Miss));
    }

    #[test]
    fn test_get_treats_stale_as_miss_but_lookup_as_hit() {
        let cache = AuthCache::new(
            Duration::from_millis(20),
            Duration::from_secs(60),
            Duration::ZERO,
        );
        cache.insert("hash1".to_string(), make_user("user1", 0, 0));
        std::thread::sleep(Duration::from_millis(40));

        assert!(cache.get("hash1").is_none());
        assert!(matches!(cache.lookup("hash1"), CacheLookup::Stale(_)));

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_cache_stale_disabled_when_zero() {
        let cache = AuthCache::new(Duration::from_millis(50), Duration::ZERO, Duration::ZERO);
        let user = make_user("user1", 0, 0);
        cache.insert("hash1".to_string(), user);

        std::thread::sleep(Duration::from_millis(150));
        assert!(matches!(cache.lookup("hash1"), CacheLookup::Miss));
    }

    #[test]
    fn test_cleanup_respects_stale_window() {
        let cache = AuthCache::new(
            Duration::from_millis(50),
            Duration::from_millis(500),
            Duration::ZERO,
        );
        cache.insert("hash1".to_string(), make_user("user1", 0, 0));

        // Expired but still within the stale window: must survive cleanup.
        std::thread::sleep(Duration::from_millis(150));
        cache.cleanup_expired();
        assert_eq!(cache.stats().size, 1);

        // Past the stale window: removed.
        std::thread::sleep(Duration::from_millis(500));
        cache.cleanup_expired();
        assert_eq!(cache.stats().size, 0);
    }

    #[cfg(feature = "tokio-runtime")]
    #[test]
    fn test_revalidation_marker_deduplicates() {
        let cache = make_cache();
        assert!(cache.start_revalidation("hash1"));
        assert!(!cache.start_revalidation("hash1"));
        cache.finish_revalidation("hash1");
        assert!(cache.start_revalidation("hash1"));
    }
}