Skip to main content

trojan_auth/store/
cache.rs

//! Authentication cache with TTL support.
//!
//! Caches successful authentication results to reduce database queries.
//! Also provides:
//! - **Stale-while-revalidate**: entries past their TTL may still be served for a
//!   configurable window while they are revalidated in the background
//! - **Traffic deltas**: in-memory tracking so cache hits reflect accumulated traffic
//! - **Negative caching**: short-lived entries for invalid hashes to prevent DB flooding
7
8use std::collections::HashMap;
9#[cfg(feature = "tokio-runtime")]
10use std::collections::HashSet;
11use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
12use std::time::{Duration, Instant};
13
14use parking_lot::RwLock;
15
16/// Result of a cache lookup with stale-while-revalidate support.
17#[derive(Debug)]
18pub enum CacheLookup {
19    /// Entry is within TTL — use directly.
20    Fresh(CachedUser),
21    /// Entry past TTL but within stale window — use but revalidate in background.
22    Stale(CachedUser),
23    /// No entry or fully expired.
24    Miss,
25}
26
/// Snapshot of a user's authentication record as held in the cache.
#[derive(Clone, Debug)]
pub struct CachedUser {
    /// Optional identifier of the user this record belongs to.
    pub user_id: Option<String>,
    /// Maximum traffic allowed, in bytes; `0` means no limit.
    pub traffic_limit: i64,
    /// Traffic consumed so far, in bytes.
    pub traffic_used: i64,
    /// Expiration timestamp of the user; `0` means the user never expires.
    /// (Unit is whatever the backing store uses — NOTE(review): confirm.)
    pub expires_at: i64,
    /// `true` when the user is allowed to authenticate.
    pub enabled: bool,
    /// Moment at which this snapshot was put into the cache.
    pub cached_at: Instant,
}
43
44/// Cache entry with expiration.
45#[derive(Debug)]
46struct CacheEntry {
47    user: CachedUser,
48    expires_at: Instant,
49}
50
51/// Authentication cache with configurable TTL.
52///
53/// Beyond basic positive caching, this provides:
54/// - **Traffic deltas** (`user_id → i64`): accumulated traffic bytes since
55///   the last DB fetch, applied on cache hit so traffic-limit checks are
56///   accurate within a single cache window.
57/// - **Negative cache** (`hash → expiry`): short-lived entries for hashes
58///   that returned no DB row, preventing repeated SELECT storms from
59///   invalid/attack traffic.
60#[derive(Debug)]
61pub struct AuthCache {
62    /// Positive cache: hash → user data.
63    cache: RwLock<HashMap<String, CacheEntry>>,
64    /// TTL for positive cache entries.
65    ttl: Duration,
66    /// Stale-while-revalidate window beyond TTL.
67    ///
68    /// When an entry is past `ttl` but within `ttl + stale_ttl`, it is
69    /// considered stale: still usable, but should be revalidated in the
70    /// background. `Duration::ZERO` disables SWR.
71    stale_ttl: Duration,
72
73    /// Accumulated traffic bytes since last DB fetch, keyed by user_id.
74    traffic_deltas: RwLock<HashMap<String, AtomicI64>>,
75
76    /// Negative cache: hash → expiry instant.
77    neg_cache: RwLock<HashMap<String, Instant>>,
78    /// TTL for negative cache entries (Duration::ZERO = disabled).
79    neg_ttl: Duration,
80
81    /// Cache hit counter.
82    hits: AtomicU64,
83    /// Cache miss counter.
84    misses: AtomicU64,
85
86    /// In-flight stale revalidations (hashes currently being refreshed).
87    #[cfg(feature = "tokio-runtime")]
88    revalidating: RwLock<HashSet<String>>,
89}
90
91impl AuthCache {
92    /// Create a new auth cache.
93    ///
94    /// - `ttl` — positive cache entry lifetime
95    /// - `stale_ttl` — stale-while-revalidate window (`Duration::ZERO` to disable)
96    /// - `neg_ttl` — negative cache entry lifetime (`Duration::ZERO` to disable)
97    pub fn new(ttl: Duration, stale_ttl: Duration, neg_ttl: Duration) -> Self {
98        Self {
99            cache: RwLock::new(HashMap::new()),
100            ttl,
101            stale_ttl,
102            traffic_deltas: RwLock::new(HashMap::new()),
103            neg_cache: RwLock::new(HashMap::new()),
104            neg_ttl,
105            hits: AtomicU64::new(0),
106            misses: AtomicU64::new(0),
107            #[cfg(feature = "tokio-runtime")]
108            revalidating: RwLock::new(HashSet::new()),
109        }
110    }
111
112    // ── Positive cache ──────────────────────────────────────────
113
114    /// Get a cached user by password hash.
115    ///
116    /// Returns `Some(CachedUser)` if found and **fresh** (within TTL),
117    /// `None` otherwise. Stale entries are not returned — use
118    /// [`lookup`](Self::lookup) for stale-while-revalidate semantics.
119    pub fn get(&self, hash: &str) -> Option<CachedUser> {
120        let cache = self.cache.read();
121        if let Some(entry) = cache.get(hash)
122            && Instant::now() < entry.expires_at
123        {
124            self.hits.fetch_add(1, Ordering::Relaxed);
125            return Some(entry.user.clone());
126        }
127        drop(cache);
128
129        self.misses.fetch_add(1, Ordering::Relaxed);
130        None
131    }
132
133    /// Look up a cached user with stale-while-revalidate support.
134    ///
135    /// Returns:
136    /// - [`CacheLookup::Fresh`] — entry is within TTL, use directly
137    /// - [`CacheLookup::Stale`] — entry past TTL but within stale window,
138    ///   use but revalidate in background
139    /// - [`CacheLookup::Miss`] — no entry or fully expired
140    pub fn lookup(&self, hash: &str) -> CacheLookup {
141        let cache = self.cache.read();
142        if let Some(entry) = cache.get(hash) {
143            let now = Instant::now();
144            if now < entry.expires_at {
145                self.hits.fetch_add(1, Ordering::Relaxed);
146                return CacheLookup::Fresh(entry.user.clone());
147            }
148            // Past TTL — check stale window
149            if self.stale_ttl > Duration::ZERO && now < entry.expires_at + self.stale_ttl {
150                self.hits.fetch_add(1, Ordering::Relaxed);
151                return CacheLookup::Stale(entry.user.clone());
152            }
153        }
154        drop(cache);
155
156        self.misses.fetch_add(1, Ordering::Relaxed);
157        CacheLookup::Miss
158    }
159
160    /// Insert a user into the cache.
161    pub fn insert(&self, hash: String, user: CachedUser) {
162        let entry = CacheEntry {
163            user,
164            expires_at: Instant::now() + self.ttl,
165        };
166        self.cache.write().insert(hash, entry);
167    }
168
169    /// Remove a user from the cache.
170    pub fn remove(&self, hash: &str) {
171        self.cache.write().remove(hash);
172    }
173
174    /// Invalidate a user by user_id.
175    ///
176    /// Removes all positive cache entries with matching user_id
177    /// and clears the traffic delta for that user.
178    pub fn invalidate_user(&self, user_id: &str) {
179        self.cache
180            .write()
181            .retain(|_, entry| entry.user.user_id.as_deref() != Some(user_id));
182        self.traffic_deltas.write().remove(user_id);
183    }
184
185    /// Clear all cache entries (positive, negative, and traffic deltas).
186    pub fn clear(&self) {
187        self.cache.write().clear();
188        self.traffic_deltas.write().clear();
189        self.neg_cache.write().clear();
190        #[cfg(feature = "tokio-runtime")]
191        self.revalidating.write().clear();
192    }
193
194    /// Remove expired entries from positive and negative caches.
195    ///
196    /// Positive cache entries are kept until their stale window also expires
197    /// (i.e. `expires_at + stale_ttl`).
198    pub fn cleanup_expired(&self) {
199        let now = Instant::now();
200        let stale = self.stale_ttl;
201        self.cache
202            .write()
203            .retain(|_, entry| entry.expires_at + stale > now);
204        self.neg_cache.write().retain(|_, &mut exp| exp > now);
205    }
206
207    // ── Traffic deltas ──────────────────────────────────────────
208
209    /// Increment the in-memory traffic delta for a user.
210    ///
211    /// Called by `StoreAuth::record_traffic()` so that subsequent
212    /// cache hits reflect the accumulated traffic.
213    #[allow(clippy::cast_possible_wrap)]
214    pub fn add_traffic_delta(&self, user_id: &str, bytes: u64) {
215        // Fast path: read lock + atomic add (no write lock needed)
216        let deltas = self.traffic_deltas.read();
217        if let Some(delta) = deltas.get(user_id) {
218            delta.fetch_add(bytes as i64, Ordering::Relaxed);
219            return;
220        }
221        drop(deltas);
222        // Slow path: write lock to insert new entry
223        self.traffic_deltas
224            .write()
225            .entry(user_id.to_string())
226            .or_insert_with(|| AtomicI64::new(0))
227            .fetch_add(bytes as i64, Ordering::Relaxed);
228    }
229
230    /// Read the accumulated traffic delta for a user.
231    ///
232    /// Returns 0 if no delta is tracked (user never had traffic recorded
233    /// since the last DB fetch).
234    pub fn get_traffic_delta(&self, user_id: &str) -> i64 {
235        self.traffic_deltas
236            .read()
237            .get(user_id)
238            .map(|d| d.load(Ordering::Relaxed))
239            .unwrap_or(0)
240    }
241
242    /// Clear the traffic delta for a user.
243    ///
244    /// Called when the cache re-fetches from DB, so the delta restarts
245    /// from zero (the DB value is now authoritative).
246    pub fn clear_traffic_delta(&self, user_id: &str) {
247        self.traffic_deltas.write().remove(user_id);
248    }
249
250    // ── Negative cache ──────────────────────────────────────────
251
252    /// Record a hash as "not found" in the negative cache.
253    ///
254    /// Subsequent lookups within `neg_ttl` will return `true` from
255    /// [`is_negative`](Self::is_negative), skipping the DB query.
256    pub fn insert_negative(&self, hash: &str) {
257        if self.neg_ttl > Duration::ZERO {
258            self.neg_cache
259                .write()
260                .insert(hash.to_string(), Instant::now() + self.neg_ttl);
261        }
262    }
263
264    /// Check if a hash is in the negative cache (known invalid).
265    ///
266    /// Returns `true` if the hash was recently looked up and not found.
267    /// Expired entries are lazily removed.
268    pub fn is_negative(&self, hash: &str) -> bool {
269        if self.neg_ttl == Duration::ZERO {
270            return false;
271        }
272        let cache = self.neg_cache.read();
273        if let Some(&exp) = cache.get(hash)
274            && Instant::now() < exp
275        {
276            return true;
277        }
278        false
279    }
280
281    /// Remove a hash from the negative cache.
282    ///
283    /// Called after cache invalidation so that a newly-added user
284    /// is not blocked by a stale negative entry.
285    pub fn remove_negative(&self, hash: &str) {
286        self.neg_cache.write().remove(hash);
287    }
288
289    /// Mark a hash as revalidating; returns `true` if caller should proceed.
290    ///
291    /// Returns `false` when another task is already revalidating this hash.
292    #[cfg(feature = "tokio-runtime")]
293    pub(crate) fn start_revalidation(&self, hash: &str) -> bool {
294        self.revalidating.write().insert(hash.to_string())
295    }
296
297    /// Clear revalidation marker for a hash.
298    #[cfg(feature = "tokio-runtime")]
299    pub(crate) fn finish_revalidation(&self, hash: &str) {
300        self.revalidating.write().remove(hash);
301    }
302
303    // ── Statistics ──────────────────────────────────────────────
304
305    /// Get cache statistics.
306    pub fn stats(&self) -> CacheStats {
307        let cache = self.cache.read();
308        CacheStats {
309            size: cache.len(),
310            neg_size: self.neg_cache.read().len(),
311            hits: self.hits.load(Ordering::Relaxed),
312            misses: self.misses.load(Ordering::Relaxed),
313            ttl: self.ttl,
314        }
315    }
316}
317
/// Point-in-time counters describing an [`AuthCache`].
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// How many positive (hash → user) entries are stored.
    pub size: usize,
    /// How many negative ("not found") entries are stored.
    pub neg_size: usize,
    /// Lookups served from the cache.
    pub hits: u64,
    /// Lookups the cache could not serve.
    pub misses: u64,
    /// Freshness lifetime configured for positive entries.
    pub ttl: Duration,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
344
#[cfg(test)]
mod tests {
    use super::*;

    fn make_cache() -> AuthCache {
        AuthCache::new(
            Duration::from_secs(60),
            Duration::ZERO,
            Duration::from_secs(5),
        )
    }

    fn make_user(user_id: &str, traffic_limit: i64, traffic_used: i64) -> CachedUser {
        CachedUser {
            user_id: Some(user_id.to_string()),
            traffic_limit,
            traffic_used,
            expires_at: 0,
            enabled: true,
            cached_at: Instant::now(),
        }
    }

    #[test]
    fn test_cache_basic() {
        let cache = make_cache();
        let user = make_user("user1", 1000, 100);

        cache.insert("hash1".to_string(), user);
        let cached = cache.get("hash1").unwrap();
        assert_eq!(cached.user_id, Some("user1".to_string()));
        assert_eq!(cached.traffic_limit, 1000);

        assert!(cache.get("hash2").is_none());
    }

    #[test]
    fn test_cache_expiration() {
        // Wide margins: with a 10ms TTL the entry could already expire before
        // the first `get` if the thread is descheduled under CPU load.
        let cache = AuthCache::new(Duration::from_millis(50), Duration::ZERO, Duration::ZERO);
        let user = make_user("user1", 0, 0);

        cache.insert("hash1".to_string(), user);
        assert!(cache.get("hash1").is_some());

        std::thread::sleep(Duration::from_millis(150));
        assert!(cache.get("hash1").is_none());
    }

    #[test]
    fn test_cache_remove() {
        let cache = make_cache();

        cache.insert("hash1".to_string(), make_user("user1", 0, 0));
        cache.remove("hash1");

        assert!(cache.get("hash1").is_none());
        assert!(matches!(cache.lookup("hash1"), CacheLookup::Miss));
    }

    #[test]
    fn test_cache_invalidate_user() {
        let cache = make_cache();

        cache.insert("hash1".to_string(), make_user("user1", 0, 0));
        cache.insert("hash2".to_string(), make_user("user2", 0, 0));

        // Also add a traffic delta for user1
        cache.add_traffic_delta("user1", 500);

        cache.invalidate_user("user1");

        assert!(cache.get("hash1").is_none());
        assert!(cache.get("hash2").is_some());
        // Delta should also be cleared
        assert_eq!(cache.get_traffic_delta("user1"), 0);
    }

    #[test]
    fn test_cache_stats() {
        let cache = make_cache();
        let user = CachedUser {
            user_id: None,
            traffic_limit: 0,
            traffic_used: 0,
            expires_at: 0,
            enabled: true,
            cached_at: Instant::now(),
        };

        cache.insert("hash1".to_string(), user);

        cache.get("hash1"); // hit
        cache.get("hash1"); // hit
        cache.get("hash2"); // miss

        let stats = cache.stats();
        assert_eq!(stats.size, 1);
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate() - 0.666).abs() < 0.01);
    }

    // ── Traffic delta tests ─────────────────────────────────────

    #[test]
    fn test_traffic_delta_accumulates() {
        let cache = make_cache();

        cache.add_traffic_delta("user1", 100);
        cache.add_traffic_delta("user1", 200);
        cache.add_traffic_delta("user1", 300);

        assert_eq!(cache.get_traffic_delta("user1"), 600);
        assert_eq!(cache.get_traffic_delta("user2"), 0); // no delta
    }

    #[test]
    fn test_traffic_delta_clear() {
        let cache = make_cache();

        cache.add_traffic_delta("user1", 500);
        assert_eq!(cache.get_traffic_delta("user1"), 500);

        cache.clear_traffic_delta("user1");
        assert_eq!(cache.get_traffic_delta("user1"), 0);
    }

    #[test]
    fn test_clear_resets_everything() {
        let cache = make_cache();

        cache.insert("hash1".to_string(), make_user("user1", 0, 0));
        cache.add_traffic_delta("user1", 100);
        cache.insert_negative("bad_hash");

        cache.clear();

        assert!(cache.get("hash1").is_none());
        assert_eq!(cache.get_traffic_delta("user1"), 0);
        assert!(!cache.is_negative("bad_hash"));
    }

    // ── Negative cache tests ────────────────────────────────────

    #[test]
    fn test_negative_cache_basic() {
        let cache = make_cache();

        assert!(!cache.is_negative("bad_hash"));

        cache.insert_negative("bad_hash");
        assert!(cache.is_negative("bad_hash"));
        assert!(!cache.is_negative("other_hash"));
    }

    #[test]
    fn test_negative_cache_expiration() {
        // Wide margins so the immediate positive check cannot race the TTL.
        let cache = AuthCache::new(
            Duration::from_secs(60),
            Duration::ZERO,
            Duration::from_millis(50),
        );

        cache.insert_negative("bad_hash");
        assert!(cache.is_negative("bad_hash"));

        std::thread::sleep(Duration::from_millis(150));
        assert!(!cache.is_negative("bad_hash"));
    }

    #[test]
    fn test_negative_cache_disabled_when_zero_ttl() {
        let cache = AuthCache::new(Duration::from_secs(60), Duration::ZERO, Duration::ZERO);

        cache.insert_negative("bad_hash");
        assert!(!cache.is_negative("bad_hash"));
    }

    #[test]
    fn test_negative_cache_remove() {
        let cache = make_cache();

        cache.insert_negative("bad_hash");
        assert!(cache.is_negative("bad_hash"));

        cache.remove_negative("bad_hash");
        assert!(!cache.is_negative("bad_hash"));
    }

    #[test]
    fn test_negative_cache_in_stats() {
        let cache = make_cache();

        cache.insert_negative("hash1");
        cache.insert_negative("hash2");

        let stats = cache.stats();
        assert_eq!(stats.neg_size, 2);
    }

    #[test]
    fn test_cleanup_expired_cleans_both() {
        let cache = AuthCache::new(
            Duration::from_millis(10),
            Duration::ZERO,
            Duration::from_millis(10),
        );

        cache.insert("hash1".to_string(), make_user("user1", 0, 0));
        cache.insert_negative("bad_hash");

        std::thread::sleep(Duration::from_millis(20));
        cache.cleanup_expired();

        let stats = cache.stats();
        assert_eq!(stats.size, 0);
        assert_eq!(stats.neg_size, 0);
    }

    // ── Stale-while-revalidate tests ────────────────────────────

    #[test]
    fn test_cache_stale_lookup() {
        // Use wide margins to avoid flakiness under heavy CPU load,
        // where thread::sleep can overshoot significantly.
        let cache = AuthCache::new(
            Duration::from_millis(50),  // TTL
            Duration::from_millis(500), // stale window
            Duration::ZERO,             // neg TTL
        );
        let user = make_user("user1", 1000, 100);
        cache.insert("hash1".to_string(), user);

        // Should be Fresh
        assert!(matches!(cache.lookup("hash1"), CacheLookup::Fresh(_)));

        // Wait past TTL (100ms margin)
        std::thread::sleep(Duration::from_millis(150));

        // Should be Stale (past TTL but within stale window)
        assert!(matches!(cache.lookup("hash1"), CacheLookup::Stale(_)));

        // get() should return None for stale entries
        assert!(cache.get("hash1").is_none());

        // Wait past stale window
        std::thread::sleep(Duration::from_millis(500));

        // Should be Miss
        assert!(matches!(cache.lookup("hash1"), CacheLookup::Miss));
    }

    #[test]
    fn test_cache_stale_disabled_when_zero() {
        // When stale_ttl is ZERO, stale lookup should be Miss
        let cache = AuthCache::new(Duration::from_millis(50), Duration::ZERO, Duration::ZERO);
        let user = make_user("user1", 0, 0);
        cache.insert("hash1".to_string(), user);

        std::thread::sleep(Duration::from_millis(150));
        assert!(matches!(cache.lookup("hash1"), CacheLookup::Miss));
    }

    #[test]
    fn test_cleanup_respects_stale_window() {
        let cache = AuthCache::new(
            Duration::from_millis(50),  // TTL
            Duration::from_millis(500), // stale window
            Duration::ZERO,
        );
        cache.insert("hash1".to_string(), make_user("user1", 0, 0));

        // Past TTL but within stale window — should NOT be cleaned up
        std::thread::sleep(Duration::from_millis(150));
        cache.cleanup_expired();
        assert_eq!(cache.stats().size, 1);

        // Past stale window — should be cleaned up
        std::thread::sleep(Duration::from_millis(500));
        cache.cleanup_expired();
        assert_eq!(cache.stats().size, 0);
    }

    #[cfg(feature = "tokio-runtime")]
    #[test]
    fn test_revalidation_marker_deduplicates() {
        let cache = make_cache();
        assert!(cache.start_revalidation("hash1"));
        assert!(!cache.start_revalidation("hash1"));
        cache.finish_revalidation("hash1");
        assert!(cache.start_revalidation("hash1"));
    }
}