1use std::collections::HashMap;
9#[cfg(feature = "tokio-runtime")]
10use std::collections::HashSet;
11use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
12use std::time::{Duration, Instant};
13
14use parking_lot::RwLock;
15
16#[derive(Debug)]
18pub enum CacheLookup {
19 Fresh(CachedUser),
21 Stale(CachedUser),
23 Miss,
25}
26
/// Snapshot of a user's account data as stored in the auth cache.
#[derive(Debug, Clone)]
pub struct CachedUser {
    /// Identifier of the user, if one is known.
    ///
    /// [`AuthCache::invalidate_user`] matches entries on this field.
    pub user_id: Option<String>,
    /// Traffic allowance of the user.
    // NOTE(review): unit and the meaning of `0` are not visible here — presumably bytes; confirm.
    pub traffic_limit: i64,
    /// Traffic already consumed by the user (same unit as `traffic_limit`, presumably).
    pub traffic_used: i64,
    /// Account expiry as a plain integer.
    // NOTE(review): looks like a Unix timestamp — confirm against the producer of this value.
    pub expires_at: i64,
    /// Whether the account is enabled.
    pub enabled: bool,
    /// Monotonic timestamp supplied by whoever built this snapshot.
    pub cached_at: Instant,
}
43
44#[derive(Debug)]
46struct CacheEntry {
47 user: CachedUser,
48 expires_at: Instant,
49}
50
51#[derive(Debug)]
61pub struct AuthCache {
62 cache: RwLock<HashMap<String, CacheEntry>>,
64 ttl: Duration,
66 stale_ttl: Duration,
72
73 traffic_deltas: RwLock<HashMap<String, AtomicI64>>,
75
76 neg_cache: RwLock<HashMap<String, Instant>>,
78 neg_ttl: Duration,
80
81 hits: AtomicU64,
83 misses: AtomicU64,
85
86 #[cfg(feature = "tokio-runtime")]
88 revalidating: RwLock<HashSet<String>>,
89}
90
91impl AuthCache {
92 pub fn new(ttl: Duration, stale_ttl: Duration, neg_ttl: Duration) -> Self {
98 Self {
99 cache: RwLock::new(HashMap::new()),
100 ttl,
101 stale_ttl,
102 traffic_deltas: RwLock::new(HashMap::new()),
103 neg_cache: RwLock::new(HashMap::new()),
104 neg_ttl,
105 hits: AtomicU64::new(0),
106 misses: AtomicU64::new(0),
107 #[cfg(feature = "tokio-runtime")]
108 revalidating: RwLock::new(HashSet::new()),
109 }
110 }
111
112 pub fn get(&self, hash: &str) -> Option<CachedUser> {
120 let cache = self.cache.read();
121 if let Some(entry) = cache.get(hash)
122 && Instant::now() < entry.expires_at
123 {
124 self.hits.fetch_add(1, Ordering::Relaxed);
125 return Some(entry.user.clone());
126 }
127 drop(cache);
128
129 self.misses.fetch_add(1, Ordering::Relaxed);
130 None
131 }
132
133 pub fn lookup(&self, hash: &str) -> CacheLookup {
141 let cache = self.cache.read();
142 if let Some(entry) = cache.get(hash) {
143 let now = Instant::now();
144 if now < entry.expires_at {
145 self.hits.fetch_add(1, Ordering::Relaxed);
146 return CacheLookup::Fresh(entry.user.clone());
147 }
148 if self.stale_ttl > Duration::ZERO && now < entry.expires_at + self.stale_ttl {
150 self.hits.fetch_add(1, Ordering::Relaxed);
151 return CacheLookup::Stale(entry.user.clone());
152 }
153 }
154 drop(cache);
155
156 self.misses.fetch_add(1, Ordering::Relaxed);
157 CacheLookup::Miss
158 }
159
160 pub fn insert(&self, hash: String, user: CachedUser) {
162 let entry = CacheEntry {
163 user,
164 expires_at: Instant::now() + self.ttl,
165 };
166 self.cache.write().insert(hash, entry);
167 }
168
169 pub fn remove(&self, hash: &str) {
171 self.cache.write().remove(hash);
172 }
173
174 pub fn invalidate_user(&self, user_id: &str) {
179 self.cache
180 .write()
181 .retain(|_, entry| entry.user.user_id.as_deref() != Some(user_id));
182 self.traffic_deltas.write().remove(user_id);
183 }
184
185 pub fn clear(&self) {
187 self.cache.write().clear();
188 self.traffic_deltas.write().clear();
189 self.neg_cache.write().clear();
190 #[cfg(feature = "tokio-runtime")]
191 self.revalidating.write().clear();
192 }
193
194 pub fn cleanup_expired(&self) {
199 let now = Instant::now();
200 let stale = self.stale_ttl;
201 self.cache
202 .write()
203 .retain(|_, entry| entry.expires_at + stale > now);
204 self.neg_cache.write().retain(|_, &mut exp| exp > now);
205 }
206
207 #[allow(clippy::cast_possible_wrap)]
214 pub fn add_traffic_delta(&self, user_id: &str, bytes: u64) {
215 let deltas = self.traffic_deltas.read();
217 if let Some(delta) = deltas.get(user_id) {
218 delta.fetch_add(bytes as i64, Ordering::Relaxed);
219 return;
220 }
221 drop(deltas);
222 self.traffic_deltas
224 .write()
225 .entry(user_id.to_string())
226 .or_insert_with(|| AtomicI64::new(0))
227 .fetch_add(bytes as i64, Ordering::Relaxed);
228 }
229
230 pub fn get_traffic_delta(&self, user_id: &str) -> i64 {
235 self.traffic_deltas
236 .read()
237 .get(user_id)
238 .map(|d| d.load(Ordering::Relaxed))
239 .unwrap_or(0)
240 }
241
242 pub fn clear_traffic_delta(&self, user_id: &str) {
247 self.traffic_deltas.write().remove(user_id);
248 }
249
250 pub fn insert_negative(&self, hash: &str) {
257 if self.neg_ttl > Duration::ZERO {
258 self.neg_cache
259 .write()
260 .insert(hash.to_string(), Instant::now() + self.neg_ttl);
261 }
262 }
263
264 pub fn is_negative(&self, hash: &str) -> bool {
269 if self.neg_ttl == Duration::ZERO {
270 return false;
271 }
272 let cache = self.neg_cache.read();
273 if let Some(&exp) = cache.get(hash)
274 && Instant::now() < exp
275 {
276 return true;
277 }
278 false
279 }
280
281 pub fn remove_negative(&self, hash: &str) {
286 self.neg_cache.write().remove(hash);
287 }
288
289 #[cfg(feature = "tokio-runtime")]
293 pub(crate) fn start_revalidation(&self, hash: &str) -> bool {
294 self.revalidating.write().insert(hash.to_string())
295 }
296
297 #[cfg(feature = "tokio-runtime")]
299 pub(crate) fn finish_revalidation(&self, hash: &str) {
300 self.revalidating.write().remove(hash);
301 }
302
303 pub fn stats(&self) -> CacheStats {
307 let cache = self.cache.read();
308 CacheStats {
309 size: cache.len(),
310 neg_size: self.neg_cache.read().len(),
311 hits: self.hits.load(Ordering::Relaxed),
312 misses: self.misses.load(Ordering::Relaxed),
313 ttl: self.ttl,
314 }
315 }
316}
317
/// Point-in-time snapshot of an auth cache's sizes and counters.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries stored in the user cache.
    pub size: usize,
    /// Number of entries stored in the negative cache.
    pub neg_size: usize,
    /// Lookups answered from the cache.
    pub hits: u64,
    /// Lookups that found nothing usable.
    pub misses: u64,
    /// Freshness TTL the cache was configured with.
    pub ttl: Duration,
}

impl CacheStats {
    /// Fraction of counted lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been counted.
    // A ratio does not need exact integer precision.
    #[allow(clippy::cast_precision_loss)]
    pub fn hit_rate(&self) -> f64 {
        // Widen before adding so `hits + misses` can never overflow `u64`
        // (which would panic in debug builds and wrap in release builds).
        let total = u128::from(self.hits) + u128::from(self.misses);
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }
}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Blocks the current thread for `ms` milliseconds.
    fn pause(ms: u64) {
        std::thread::sleep(Duration::from_millis(ms));
    }

    /// Cache with a 60 s TTL, no stale window and a 5 s negative TTL.
    fn make_cache() -> AuthCache {
        let ttl = Duration::from_secs(60);
        let stale_ttl = Duration::ZERO;
        let neg_ttl = Duration::from_secs(5);
        AuthCache::new(ttl, stale_ttl, neg_ttl)
    }

    /// Enabled user with the given id and traffic figures.
    fn make_user(user_id: &str, traffic_limit: i64, traffic_used: i64) -> CachedUser {
        CachedUser {
            user_id: Some(user_id.to_owned()),
            traffic_limit,
            traffic_used,
            expires_at: 0,
            enabled: true,
            cached_at: Instant::now(),
        }
    }

    /// An inserted entry can be read back; an unknown key yields nothing.
    #[test]
    fn test_cache_basic() {
        let auth = make_cache();
        auth.insert(String::from("hash1"), make_user("user1", 1000, 100));

        let found = auth.get("hash1").unwrap();
        assert_eq!(found.user_id, Some(String::from("user1")));
        assert_eq!(found.traffic_limit, 1000);

        let unknown = auth.get("hash2");
        assert!(unknown.is_none());
    }

    /// `get` stops returning an entry once its TTL has elapsed.
    #[test]
    fn test_cache_expiration() {
        let auth = AuthCache::new(Duration::from_millis(10), Duration::ZERO, Duration::ZERO);
        auth.insert(String::from("hash1"), make_user("user1", 0, 0));

        let before = auth.get("hash1");
        assert!(before.is_some());

        pause(20);
        let after = auth.get("hash1");
        assert!(after.is_none());
    }

    /// Invalidating a user drops that user's entries and traffic delta only.
    #[test]
    fn test_cache_invalidate_user() {
        let auth = make_cache();
        for (key, owner) in [("hash1", "user1"), ("hash2", "user2")] {
            auth.insert(key.to_string(), make_user(owner, 0, 0));
        }
        auth.add_traffic_delta("user1", 500);

        auth.invalidate_user("user1");

        assert!(auth.get("hash1").is_none());
        assert!(auth.get("hash2").is_some());
        assert_eq!(auth.get_traffic_delta("user1"), 0);
    }

    /// Size, hit and miss counters and the derived hit rate are reported.
    #[test]
    fn test_cache_stats() {
        let auth = make_cache();
        let anonymous = CachedUser {
            user_id: None,
            traffic_limit: 0,
            traffic_used: 0,
            expires_at: 0,
            enabled: true,
            cached_at: Instant::now(),
        };
        auth.insert(String::from("hash1"), anonymous);

        // Two hits followed by one miss.
        for key in ["hash1", "hash1", "hash2"] {
            let _ = auth.get(key);
        }

        let snapshot = auth.stats();
        assert_eq!(snapshot.size, 1);
        assert_eq!(snapshot.hits, 2);
        assert_eq!(snapshot.misses, 1);
        assert!((snapshot.hit_rate() - 0.666).abs() < 0.01);
    }

    /// Deltas for one user add up; an unseen user reads as zero.
    #[test]
    fn test_traffic_delta_accumulates() {
        let auth = make_cache();
        for bytes in [100, 200, 300] {
            auth.add_traffic_delta("user1", bytes);
        }

        assert_eq!(auth.get_traffic_delta("user1"), 600);
        assert_eq!(auth.get_traffic_delta("user2"), 0);
    }

    /// Clearing a user's delta resets it to zero.
    #[test]
    fn test_traffic_delta_clear() {
        let auth = make_cache();
        auth.add_traffic_delta("user1", 500);
        assert_eq!(auth.get_traffic_delta("user1"), 500);

        auth.clear_traffic_delta("user1");
        assert_eq!(auth.get_traffic_delta("user1"), 0);
    }

    /// `clear` empties the user cache, the deltas and the negative cache.
    #[test]
    fn test_clear_resets_everything() {
        let auth = make_cache();
        auth.insert(String::from("hash1"), make_user("user1", 0, 0));
        auth.add_traffic_delta("user1", 100);
        auth.insert_negative("bad_hash");

        auth.clear();

        assert!(auth.get("hash1").is_none());
        assert_eq!(auth.get_traffic_delta("user1"), 0);
        assert!(!auth.is_negative("bad_hash"));
    }

    /// Only keys that were inserted as negative are reported as negative.
    #[test]
    fn test_negative_cache_basic() {
        let auth = make_cache();
        assert!(!auth.is_negative("bad_hash"));

        auth.insert_negative("bad_hash");

        assert!(auth.is_negative("bad_hash"));
        assert!(!auth.is_negative("other_hash"));
    }

    /// Negative entries stop matching once the negative TTL has elapsed.
    #[test]
    fn test_negative_cache_expiration() {
        let ttl = Duration::from_secs(60);
        let neg_ttl = Duration::from_millis(10);
        let auth = AuthCache::new(ttl, Duration::ZERO, neg_ttl);

        auth.insert_negative("bad_hash");
        assert!(auth.is_negative("bad_hash"));

        pause(20);
        assert!(!auth.is_negative("bad_hash"));
    }

    /// A zero negative TTL turns the negative cache off entirely.
    #[test]
    fn test_negative_cache_disabled_when_zero_ttl() {
        let auth = AuthCache::new(Duration::from_secs(60), Duration::ZERO, Duration::ZERO);

        auth.insert_negative("bad_hash");

        assert!(!auth.is_negative("bad_hash"));
    }

    /// An explicitly removed negative entry no longer matches.
    #[test]
    fn test_negative_cache_remove() {
        let auth = make_cache();
        auth.insert_negative("bad_hash");
        assert!(auth.is_negative("bad_hash"));

        auth.remove_negative("bad_hash");
        assert!(!auth.is_negative("bad_hash"));
    }

    /// The stats snapshot reports the number of negative entries.
    #[test]
    fn test_negative_cache_in_stats() {
        let auth = make_cache();
        for key in ["hash1", "hash2"] {
            auth.insert_negative(key);
        }

        let snapshot = auth.stats();
        assert_eq!(snapshot.neg_size, 2);
    }

    /// Cleanup evicts expired entries from both the user and negative caches.
    #[test]
    fn test_cleanup_expired_cleans_both() {
        let short = Duration::from_millis(10);
        let auth = AuthCache::new(short, Duration::ZERO, short);
        auth.insert(String::from("hash1"), make_user("user1", 0, 0));
        auth.insert_negative("bad_hash");

        pause(20);
        auth.cleanup_expired();

        let snapshot = auth.stats();
        assert_eq!(snapshot.size, 0);
        assert_eq!(snapshot.neg_size, 0);
    }

    /// `lookup` goes Fresh -> Stale -> Miss as time passes, while `get`
    /// never returns a stale entry.
    #[test]
    fn test_cache_stale_lookup() {
        let ttl = Duration::from_millis(50);
        let stale_ttl = Duration::from_millis(500);
        let auth = AuthCache::new(ttl, stale_ttl, Duration::ZERO);
        auth.insert(String::from("hash1"), make_user("user1", 1000, 100));

        let first = auth.lookup("hash1");
        assert!(matches!(first, CacheLookup::Fresh(_)));

        // Past the TTL, still inside the stale window.
        pause(150);
        let second = auth.lookup("hash1");
        assert!(matches!(second, CacheLookup::Stale(_)));
        assert!(auth.get("hash1").is_none());

        // Past the stale window as well.
        pause(500);
        let third = auth.lookup("hash1");
        assert!(matches!(third, CacheLookup::Miss));
    }

    /// With a zero stale window an expired entry is a plain miss.
    #[test]
    fn test_cache_stale_disabled_when_zero() {
        let auth = AuthCache::new(Duration::from_millis(50), Duration::ZERO, Duration::ZERO);
        auth.insert(String::from("hash1"), make_user("user1", 0, 0));

        pause(150);

        let outcome = auth.lookup("hash1");
        assert!(matches!(outcome, CacheLookup::Miss));
    }

    /// Cleanup keeps entries that are stale but still inside the stale
    /// window, and evicts them once the window has passed.
    #[test]
    fn test_cleanup_respects_stale_window() {
        let ttl = Duration::from_millis(50);
        let stale_ttl = Duration::from_millis(500);
        let auth = AuthCache::new(ttl, stale_ttl, Duration::ZERO);
        auth.insert(String::from("hash1"), make_user("user1", 0, 0));

        pause(150);
        auth.cleanup_expired();
        assert_eq!(auth.stats().size, 1);

        pause(500);
        auth.cleanup_expired();
        assert_eq!(auth.stats().size, 0);
    }

    /// A revalidation can be started only once until it is finished.
    #[cfg(feature = "tokio-runtime")]
    #[test]
    fn test_revalidation_marker_deduplicates() {
        let auth = make_cache();

        let first_start = auth.start_revalidation("hash1");
        assert!(first_start);
        let duplicate_start = auth.start_revalidation("hash1");
        assert!(!duplicate_start);

        auth.finish_revalidation("hash1");
        let restart = auth.start_revalidation("hash1");
        assert!(restart);
    }
}