Skip to main content

trojan_auth/store/
cache.rs

1//! Authentication cache with TTL support.
2//!
3//! Caches successful authentication results to reduce database queries.
4//! Also provides:
5//! - **Traffic deltas**: in-memory tracking so cache hits reflect accumulated traffic
6//! - **Negative caching**: short-lived entries for invalid hashes to prevent DB flooding
7
8use std::collections::HashMap;
9#[cfg(feature = "tokio-runtime")]
10use std::collections::HashSet;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::time::{Duration, Instant};
13
14use parking_lot::RwLock;
15
16/// Result of a cache lookup with stale-while-revalidate support.
17#[derive(Debug)]
18pub enum CacheLookup {
19    /// Entry is within TTL — use directly.
20    Fresh(CachedUser),
21    /// Entry past TTL but within stale window — use but revalidate in background.
22    Stale(CachedUser),
23    /// No entry or fully expired.
24    Miss,
25}
26
/// Snapshot of a user's account data as stored in the cache.
#[derive(Clone, Debug)]
pub struct CachedUser {
    /// Optional identifier of the user.
    pub user_id: Option<String>,
    /// Maximum traffic in bytes; `0` means unlimited.
    pub traffic_limit: i64,
    /// Traffic consumed so far, in bytes.
    pub traffic_used: i64,
    /// Account expiration timestamp; `0` means the account never expires.
    pub expires_at: i64,
    /// `true` when the user is enabled.
    pub enabled: bool,
    /// Moment at which this cache entry was created.
    pub cached_at: Instant,
}
43
44/// Cache entry with expiration.
45#[derive(Debug)]
46struct CacheEntry {
47    user: CachedUser,
48    expires_at: Instant,
49}
50
51/// Authentication cache with configurable TTL.
52///
53/// Beyond basic positive caching, this provides:
54/// - **Traffic deltas** (`user_id → i64`): accumulated traffic bytes since
55///   the last DB fetch, applied on cache hit so traffic-limit checks are
56///   accurate within a single cache window.
57/// - **Negative cache** (`hash → expiry`): short-lived entries for hashes
58///   that returned no DB row, preventing repeated SELECT storms from
59///   invalid/attack traffic.
60#[derive(Debug)]
61pub struct AuthCache {
62    /// Positive cache: hash → user data.
63    cache: RwLock<HashMap<String, CacheEntry>>,
64    /// TTL for positive cache entries.
65    ttl: Duration,
66    /// Stale-while-revalidate window beyond TTL.
67    ///
68    /// When an entry is past `ttl` but within `ttl + stale_ttl`, it is
69    /// considered stale: still usable, but should be revalidated in the
70    /// background. `Duration::ZERO` disables SWR.
71    stale_ttl: Duration,
72
73    /// Accumulated traffic bytes since last DB fetch, keyed by user_id.
74    traffic_deltas: RwLock<HashMap<String, i64>>,
75
76    /// Negative cache: hash → expiry instant.
77    neg_cache: RwLock<HashMap<String, Instant>>,
78    /// TTL for negative cache entries (Duration::ZERO = disabled).
79    neg_ttl: Duration,
80
81    /// Cache hit counter.
82    hits: AtomicU64,
83    /// Cache miss counter.
84    misses: AtomicU64,
85
86    /// In-flight stale revalidations (hashes currently being refreshed).
87    #[cfg(feature = "tokio-runtime")]
88    revalidating: RwLock<HashSet<String>>,
89}
90
91impl AuthCache {
92    /// Create a new auth cache.
93    ///
94    /// - `ttl` — positive cache entry lifetime
95    /// - `stale_ttl` — stale-while-revalidate window (`Duration::ZERO` to disable)
96    /// - `neg_ttl` — negative cache entry lifetime (`Duration::ZERO` to disable)
97    pub fn new(ttl: Duration, stale_ttl: Duration, neg_ttl: Duration) -> Self {
98        Self {
99            cache: RwLock::new(HashMap::new()),
100            ttl,
101            stale_ttl,
102            traffic_deltas: RwLock::new(HashMap::new()),
103            neg_cache: RwLock::new(HashMap::new()),
104            neg_ttl,
105            hits: AtomicU64::new(0),
106            misses: AtomicU64::new(0),
107            #[cfg(feature = "tokio-runtime")]
108            revalidating: RwLock::new(HashSet::new()),
109        }
110    }
111
112    // ── Positive cache ──────────────────────────────────────────
113
114    /// Get a cached user by password hash.
115    ///
116    /// Returns `Some(CachedUser)` if found and **fresh** (within TTL),
117    /// `None` otherwise. Stale entries are not returned — use
118    /// [`lookup`](Self::lookup) for stale-while-revalidate semantics.
119    pub fn get(&self, hash: &str) -> Option<CachedUser> {
120        let cache = self.cache.read();
121        if let Some(entry) = cache.get(hash)
122            && Instant::now() < entry.expires_at
123        {
124            self.hits.fetch_add(1, Ordering::Relaxed);
125            return Some(entry.user.clone());
126        }
127        drop(cache);
128
129        self.misses.fetch_add(1, Ordering::Relaxed);
130        None
131    }
132
133    /// Look up a cached user with stale-while-revalidate support.
134    ///
135    /// Returns:
136    /// - [`CacheLookup::Fresh`] — entry is within TTL, use directly
137    /// - [`CacheLookup::Stale`] — entry past TTL but within stale window,
138    ///   use but revalidate in background
139    /// - [`CacheLookup::Miss`] — no entry or fully expired
140    pub fn lookup(&self, hash: &str) -> CacheLookup {
141        let cache = self.cache.read();
142        if let Some(entry) = cache.get(hash) {
143            let now = Instant::now();
144            if now < entry.expires_at {
145                self.hits.fetch_add(1, Ordering::Relaxed);
146                return CacheLookup::Fresh(entry.user.clone());
147            }
148            // Past TTL — check stale window
149            if self.stale_ttl > Duration::ZERO && now < entry.expires_at + self.stale_ttl {
150                self.hits.fetch_add(1, Ordering::Relaxed);
151                return CacheLookup::Stale(entry.user.clone());
152            }
153        }
154        drop(cache);
155
156        self.misses.fetch_add(1, Ordering::Relaxed);
157        CacheLookup::Miss
158    }
159
160    /// Insert a user into the cache.
161    pub fn insert(&self, hash: String, user: CachedUser) {
162        let entry = CacheEntry {
163            user,
164            expires_at: Instant::now() + self.ttl,
165        };
166        self.cache.write().insert(hash, entry);
167    }
168
169    /// Remove a user from the cache.
170    pub fn remove(&self, hash: &str) {
171        self.cache.write().remove(hash);
172    }
173
174    /// Invalidate a user by user_id.
175    ///
176    /// Removes all positive cache entries with matching user_id
177    /// and clears the traffic delta for that user.
178    pub fn invalidate_user(&self, user_id: &str) {
179        self.cache
180            .write()
181            .retain(|_, entry| entry.user.user_id.as_deref() != Some(user_id));
182        self.traffic_deltas.write().remove(user_id);
183    }
184
185    /// Clear all cache entries (positive, negative, and traffic deltas).
186    pub fn clear(&self) {
187        self.cache.write().clear();
188        self.traffic_deltas.write().clear();
189        self.neg_cache.write().clear();
190        #[cfg(feature = "tokio-runtime")]
191        self.revalidating.write().clear();
192    }
193
194    /// Remove expired entries from positive and negative caches.
195    ///
196    /// Positive cache entries are kept until their stale window also expires
197    /// (i.e. `expires_at + stale_ttl`).
198    pub fn cleanup_expired(&self) {
199        let now = Instant::now();
200        let stale = self.stale_ttl;
201        self.cache
202            .write()
203            .retain(|_, entry| entry.expires_at + stale > now);
204        self.neg_cache.write().retain(|_, &mut exp| exp > now);
205    }
206
207    // ── Traffic deltas ──────────────────────────────────────────
208
209    /// Increment the in-memory traffic delta for a user.
210    ///
211    /// Called by `StoreAuth::record_traffic()` so that subsequent
212    /// cache hits reflect the accumulated traffic.
213    #[allow(clippy::cast_possible_wrap)]
214    pub fn add_traffic_delta(&self, user_id: &str, bytes: u64) {
215        *self
216            .traffic_deltas
217            .write()
218            .entry(user_id.to_string())
219            .or_insert(0) += bytes as i64;
220    }
221
222    /// Read the accumulated traffic delta for a user.
223    ///
224    /// Returns 0 if no delta is tracked (user never had traffic recorded
225    /// since the last DB fetch).
226    pub fn get_traffic_delta(&self, user_id: &str) -> i64 {
227        self.traffic_deltas
228            .read()
229            .get(user_id)
230            .copied()
231            .unwrap_or(0)
232    }
233
234    /// Clear the traffic delta for a user.
235    ///
236    /// Called when the cache re-fetches from DB, so the delta restarts
237    /// from zero (the DB value is now authoritative).
238    pub fn clear_traffic_delta(&self, user_id: &str) {
239        self.traffic_deltas.write().remove(user_id);
240    }
241
242    // ── Negative cache ──────────────────────────────────────────
243
244    /// Record a hash as "not found" in the negative cache.
245    ///
246    /// Subsequent lookups within `neg_ttl` will return `true` from
247    /// [`is_negative`](Self::is_negative), skipping the DB query.
248    pub fn insert_negative(&self, hash: &str) {
249        if self.neg_ttl > Duration::ZERO {
250            self.neg_cache
251                .write()
252                .insert(hash.to_string(), Instant::now() + self.neg_ttl);
253        }
254    }
255
256    /// Check if a hash is in the negative cache (known invalid).
257    ///
258    /// Returns `true` if the hash was recently looked up and not found.
259    /// Expired entries are lazily removed.
260    pub fn is_negative(&self, hash: &str) -> bool {
261        if self.neg_ttl == Duration::ZERO {
262            return false;
263        }
264        let cache = self.neg_cache.read();
265        if let Some(&exp) = cache.get(hash)
266            && Instant::now() < exp
267        {
268            return true;
269        }
270        false
271    }
272
273    /// Remove a hash from the negative cache.
274    ///
275    /// Called after cache invalidation so that a newly-added user
276    /// is not blocked by a stale negative entry.
277    pub fn remove_negative(&self, hash: &str) {
278        self.neg_cache.write().remove(hash);
279    }
280
281    /// Mark a hash as revalidating; returns `true` if caller should proceed.
282    ///
283    /// Returns `false` when another task is already revalidating this hash.
284    #[cfg(feature = "tokio-runtime")]
285    pub(crate) fn start_revalidation(&self, hash: &str) -> bool {
286        self.revalidating.write().insert(hash.to_string())
287    }
288
289    /// Clear revalidation marker for a hash.
290    #[cfg(feature = "tokio-runtime")]
291    pub(crate) fn finish_revalidation(&self, hash: &str) {
292        self.revalidating.write().remove(hash);
293    }
294
295    // ── Statistics ──────────────────────────────────────────────
296
297    /// Get cache statistics.
298    pub fn stats(&self) -> CacheStats {
299        let cache = self.cache.read();
300        CacheStats {
301            size: cache.len(),
302            neg_size: self.neg_cache.read().len(),
303            hits: self.hits.load(Ordering::Relaxed),
304            misses: self.misses.load(Ordering::Relaxed),
305            ttl: self.ttl,
306        }
307    }
308}
309
/// Cache statistics: entry counts, hit/miss counters and the configured TTL.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of positive cache entries.
    pub size: usize,
    /// Number of negative cache entries.
    pub neg_size: usize,
    /// Number of cache hits.
    pub hits: u64,
    /// Number of cache misses.
    pub misses: u64,
    /// Cache TTL.
    pub ttl: Duration,
}

impl CacheStats {
    /// Calculate hit rate (0.0 to 1.0).
    ///
    /// Returns `0.0` when neither a hit nor a miss has been recorded.
    /// The total is summed in `u128`, so counters close to `u64::MAX`
    /// cannot overflow (which would panic in debug builds and wrap in
    /// release builds).
    pub fn hit_rate(&self) -> f64 {
        let total = u128::from(self.hits) + u128::from(self.misses);
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }
}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Cache with a 60 s TTL, no stale window and a 5 s negative TTL.
    fn make_cache() -> AuthCache {
        let ttl = Duration::from_secs(60);
        let neg_ttl = Duration::from_secs(5);
        AuthCache::new(ttl, Duration::ZERO, neg_ttl)
    }

    /// Enabled, never-expiring user with the given id and traffic figures.
    fn make_user(user_id: &str, traffic_limit: i64, traffic_used: i64) -> CachedUser {
        CachedUser {
            user_id: Some(user_id.to_owned()),
            traffic_limit,
            traffic_used,
            expires_at: 0,
            enabled: true,
            cached_at: Instant::now(),
        }
    }

    #[test]
    fn test_cache_basic() {
        let auth = make_cache();
        auth.insert("hash1".to_string(), make_user("user1", 1000, 100));

        let found = auth.get("hash1").expect("hash1 was just inserted");
        assert_eq!(found.user_id, Some("user1".to_string()));
        assert_eq!(found.traffic_limit, 1000);

        assert!(auth.get("hash2").is_none());
    }

    #[test]
    fn test_cache_expiration() {
        let short_ttl = Duration::from_millis(10);
        let auth = AuthCache::new(short_ttl, Duration::ZERO, Duration::ZERO);

        auth.insert("hash1".to_string(), make_user("user1", 0, 0));
        assert!(auth.get("hash1").is_some());

        std::thread::sleep(Duration::from_millis(20));
        assert!(auth.get("hash1").is_none());
    }

    #[test]
    fn test_cache_invalidate_user() {
        let auth = make_cache();
        for (hash, id) in [("hash1", "user1"), ("hash2", "user2")] {
            auth.insert(hash.to_string(), make_user(id, 0, 0));
        }
        // user1 also carries a pending traffic delta.
        auth.add_traffic_delta("user1", 500);

        auth.invalidate_user("user1");

        assert!(auth.get("hash1").is_none());
        assert!(auth.get("hash2").is_some());
        // The delta must be gone together with the entry.
        assert_eq!(auth.get_traffic_delta("user1"), 0);
    }

    #[test]
    fn test_cache_stats() {
        let auth = make_cache();
        let anonymous = CachedUser {
            user_id: None,
            ..make_user("unused", 0, 0)
        };
        auth.insert("hash1".to_string(), anonymous);

        // Two hits followed by one miss.
        let _ = auth.get("hash1");
        let _ = auth.get("hash1");
        let _ = auth.get("hash2");

        let snapshot = auth.stats();
        assert_eq!(snapshot.size, 1);
        assert_eq!(snapshot.hits, 2);
        assert_eq!(snapshot.misses, 1);
        assert!((snapshot.hit_rate() - 0.666).abs() < 0.01);
    }

    // ── Traffic delta tests ─────────────────────────────────────

    #[test]
    fn test_traffic_delta_accumulates() {
        let auth = make_cache();
        for bytes in [100, 200, 300] {
            auth.add_traffic_delta("user1", bytes);
        }

        assert_eq!(auth.get_traffic_delta("user1"), 600);
        // A user without recorded traffic reports zero.
        assert_eq!(auth.get_traffic_delta("user2"), 0);
    }

    #[test]
    fn test_traffic_delta_clear() {
        let auth = make_cache();

        auth.add_traffic_delta("user1", 500);
        assert_eq!(auth.get_traffic_delta("user1"), 500);

        auth.clear_traffic_delta("user1");
        assert_eq!(auth.get_traffic_delta("user1"), 0);
    }

    #[test]
    fn test_clear_resets_everything() {
        let auth = make_cache();
        auth.insert("hash1".to_string(), make_user("user1", 0, 0));
        auth.add_traffic_delta("user1", 100);
        auth.insert_negative("bad_hash");

        auth.clear();

        assert!(auth.get("hash1").is_none());
        assert_eq!(auth.get_traffic_delta("user1"), 0);
        assert!(!auth.is_negative("bad_hash"));
    }

    // ── Negative cache tests ────────────────────────────────────

    #[test]
    fn test_negative_cache_basic() {
        let auth = make_cache();
        assert!(!auth.is_negative("bad_hash"));

        auth.insert_negative("bad_hash");

        assert!(auth.is_negative("bad_hash"));
        assert!(!auth.is_negative("other_hash"));
    }

    #[test]
    fn test_negative_cache_expiration() {
        let ttl = Duration::from_secs(60);
        let neg_ttl = Duration::from_millis(10);
        let auth = AuthCache::new(ttl, Duration::ZERO, neg_ttl);

        auth.insert_negative("bad_hash");
        assert!(auth.is_negative("bad_hash"));

        std::thread::sleep(Duration::from_millis(20));
        assert!(!auth.is_negative("bad_hash"));
    }

    #[test]
    fn test_negative_cache_disabled_when_zero_ttl() {
        let ttl = Duration::from_secs(60);
        let auth = AuthCache::new(ttl, Duration::ZERO, Duration::ZERO);

        auth.insert_negative("bad_hash");
        assert!(!auth.is_negative("bad_hash"));
    }

    #[test]
    fn test_negative_cache_remove() {
        let auth = make_cache();

        auth.insert_negative("bad_hash");
        assert!(auth.is_negative("bad_hash"));

        auth.remove_negative("bad_hash");
        assert!(!auth.is_negative("bad_hash"));
    }

    #[test]
    fn test_negative_cache_in_stats() {
        let auth = make_cache();
        for hash in ["hash1", "hash2"] {
            auth.insert_negative(hash);
        }

        assert_eq!(auth.stats().neg_size, 2);
    }

    #[test]
    fn test_cleanup_expired_cleans_both() {
        let short = Duration::from_millis(10);
        let auth = AuthCache::new(short, Duration::ZERO, short);

        auth.insert("hash1".to_string(), make_user("user1", 0, 0));
        auth.insert_negative("bad_hash");

        std::thread::sleep(Duration::from_millis(20));
        auth.cleanup_expired();

        let snapshot = auth.stats();
        assert_eq!(snapshot.size, 0);
        assert_eq!(snapshot.neg_size, 0);
    }

    // ── Stale-while-revalidate tests ────────────────────────────

    #[test]
    fn test_cache_stale_lookup() {
        // Generous margins: under heavy CPU load `thread::sleep` may
        // overshoot considerably, which would make tight timings flaky.
        let ttl = Duration::from_millis(50);
        let stale_window = Duration::from_millis(500);
        let auth = AuthCache::new(ttl, stale_window, Duration::ZERO);
        auth.insert("hash1".to_string(), make_user("user1", 1000, 100));

        // Right after insertion the entry is fresh.
        assert!(matches!(auth.lookup("hash1"), CacheLookup::Fresh(_)));

        // Sleep beyond the TTL (100 ms of margin).
        std::thread::sleep(Duration::from_millis(150));

        // Beyond the TTL but inside the stale window: stale.
        assert!(matches!(auth.lookup("hash1"), CacheLookup::Stale(_)));

        // `get()` never hands out stale entries.
        assert!(auth.get("hash1").is_none());

        // Sleep beyond the stale window as well.
        std::thread::sleep(Duration::from_millis(500));

        // Nothing usable is left.
        assert!(matches!(auth.lookup("hash1"), CacheLookup::Miss));
    }

    #[test]
    fn test_cache_stale_disabled_when_zero() {
        // A zero stale_ttl means an expired entry is a plain miss.
        let ttl = Duration::from_millis(50);
        let auth = AuthCache::new(ttl, Duration::ZERO, Duration::ZERO);
        auth.insert("hash1".to_string(), make_user("user1", 0, 0));

        std::thread::sleep(Duration::from_millis(150));
        assert!(matches!(auth.lookup("hash1"), CacheLookup::Miss));
    }

    #[test]
    fn test_cleanup_respects_stale_window() {
        let ttl = Duration::from_millis(50);
        let stale_window = Duration::from_millis(500);
        let auth = AuthCache::new(ttl, stale_window, Duration::ZERO);
        auth.insert("hash1".to_string(), make_user("user1", 0, 0));

        // Beyond the TTL but inside the stale window: cleanup keeps it.
        std::thread::sleep(Duration::from_millis(150));
        auth.cleanup_expired();
        assert_eq!(auth.stats().size, 1);

        // Beyond the stale window: cleanup drops it.
        std::thread::sleep(Duration::from_millis(500));
        auth.cleanup_expired();
        assert_eq!(auth.stats().size, 0);
    }

    #[cfg(feature = "tokio-runtime")]
    #[test]
    fn test_revalidation_marker_deduplicates() {
        let auth = make_cache();

        let first_claim = auth.start_revalidation("hash1");
        let second_claim = auth.start_revalidation("hash1");
        assert!(first_claim);
        assert!(!second_claim);

        auth.finish_revalidation("hash1");
        assert!(auth.start_revalidation("hash1"));
    }
}