Skip to main content

trojan_auth/store/
cache.rs

//! Authentication cache with TTL support.
//!
//! Caches successful authentication results to reduce database queries.
//! Also provides:
//! - **Traffic deltas**: in-memory tracking so cache hits reflect traffic accumulated since the last DB fetch
//! - **Negative caching**: short-lived entries for invalid hashes, so repeated bad lookups do not flood the DB

8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::{Duration, Instant};
11
12use parking_lot::RwLock;
13
/// Snapshot of a user's authentication-relevant data, as stored in [`AuthCache`].
///
/// In-memory traffic deltas tracked by the cache are *not* folded into
/// `traffic_used`; callers combine the two themselves.
#[derive(Debug, Clone)]
pub struct CachedUser {
    /// Optional identifier of the user this entry belongs to.
    pub user_id: Option<String>,
    /// Maximum allowed traffic in bytes; `0` means no limit.
    pub traffic_limit: i64,
    /// Traffic consumed so far, in bytes.
    pub traffic_used: i64,
    /// Account expiration timestamp; `0` means the account never expires.
    pub expires_at: i64,
    /// `true` if the account is currently allowed to authenticate.
    pub enabled: bool,
    /// Moment at which this cache entry was created.
    pub cached_at: Instant,
}
30
31/// Cache entry with expiration.
32#[derive(Debug)]
33struct CacheEntry {
34    user: CachedUser,
35    expires_at: Instant,
36}
37
38/// Authentication cache with configurable TTL.
39///
40/// Beyond basic positive caching, this provides:
41/// - **Traffic deltas** (`user_id → i64`): accumulated traffic bytes since
42///   the last DB fetch, applied on cache hit so traffic-limit checks are
43///   accurate within a single cache window.
44/// - **Negative cache** (`hash → expiry`): short-lived entries for hashes
45///   that returned no DB row, preventing repeated SELECT storms from
46///   invalid/attack traffic.
47#[derive(Debug)]
48pub struct AuthCache {
49    /// Positive cache: hash → user data.
50    cache: RwLock<HashMap<String, CacheEntry>>,
51    /// TTL for positive cache entries.
52    ttl: Duration,
53
54    /// Accumulated traffic bytes since last DB fetch, keyed by user_id.
55    traffic_deltas: RwLock<HashMap<String, i64>>,
56
57    /// Negative cache: hash → expiry instant.
58    neg_cache: RwLock<HashMap<String, Instant>>,
59    /// TTL for negative cache entries (Duration::ZERO = disabled).
60    neg_ttl: Duration,
61
62    /// Cache hit counter.
63    hits: AtomicU64,
64    /// Cache miss counter.
65    misses: AtomicU64,
66}
67
68impl AuthCache {
69    /// Create a new auth cache.
70    ///
71    /// - `ttl` — positive cache entry lifetime
72    /// - `neg_ttl` — negative cache entry lifetime (`Duration::ZERO` to disable)
73    pub fn new(ttl: Duration, neg_ttl: Duration) -> Self {
74        Self {
75            cache: RwLock::new(HashMap::new()),
76            ttl,
77            traffic_deltas: RwLock::new(HashMap::new()),
78            neg_cache: RwLock::new(HashMap::new()),
79            neg_ttl,
80            hits: AtomicU64::new(0),
81            misses: AtomicU64::new(0),
82        }
83    }
84
85    // ── Positive cache ──────────────────────────────────────────
86
87    /// Get a cached user by password hash.
88    ///
89    /// Returns `Some(CachedUser)` if found and not expired, `None` otherwise.
90    pub fn get(&self, hash: &str) -> Option<CachedUser> {
91        let cache = self.cache.read();
92        if let Some(entry) = cache.get(hash)
93            && Instant::now() < entry.expires_at
94        {
95            self.hits.fetch_add(1, Ordering::Relaxed);
96            return Some(entry.user.clone());
97        }
98        drop(cache);
99
100        self.misses.fetch_add(1, Ordering::Relaxed);
101        None
102    }
103
104    /// Insert a user into the cache.
105    pub fn insert(&self, hash: String, user: CachedUser) {
106        let entry = CacheEntry {
107            user,
108            expires_at: Instant::now() + self.ttl,
109        };
110        self.cache.write().insert(hash, entry);
111    }
112
113    /// Remove a user from the cache.
114    pub fn remove(&self, hash: &str) {
115        self.cache.write().remove(hash);
116    }
117
118    /// Invalidate a user by user_id.
119    ///
120    /// Removes all positive cache entries with matching user_id
121    /// and clears the traffic delta for that user.
122    pub fn invalidate_user(&self, user_id: &str) {
123        self.cache
124            .write()
125            .retain(|_, entry| entry.user.user_id.as_deref() != Some(user_id));
126        self.traffic_deltas.write().remove(user_id);
127    }
128
129    /// Clear all cache entries (positive, negative, and traffic deltas).
130    pub fn clear(&self) {
131        self.cache.write().clear();
132        self.traffic_deltas.write().clear();
133        self.neg_cache.write().clear();
134    }
135
136    /// Remove expired entries from positive and negative caches.
137    pub fn cleanup_expired(&self) {
138        let now = Instant::now();
139        self.cache.write().retain(|_, entry| entry.expires_at > now);
140        self.neg_cache.write().retain(|_, &mut exp| exp > now);
141    }
142
143    // ── Traffic deltas ──────────────────────────────────────────
144
145    /// Increment the in-memory traffic delta for a user.
146    ///
147    /// Called by `StoreAuth::record_traffic()` so that subsequent
148    /// cache hits reflect the accumulated traffic.
149    #[allow(clippy::cast_possible_wrap)]
150    pub fn add_traffic_delta(&self, user_id: &str, bytes: u64) {
151        *self
152            .traffic_deltas
153            .write()
154            .entry(user_id.to_string())
155            .or_insert(0) += bytes as i64;
156    }
157
158    /// Read the accumulated traffic delta for a user.
159    ///
160    /// Returns 0 if no delta is tracked (user never had traffic recorded
161    /// since the last DB fetch).
162    pub fn get_traffic_delta(&self, user_id: &str) -> i64 {
163        self.traffic_deltas
164            .read()
165            .get(user_id)
166            .copied()
167            .unwrap_or(0)
168    }
169
170    /// Clear the traffic delta for a user.
171    ///
172    /// Called when the cache re-fetches from DB, so the delta restarts
173    /// from zero (the DB value is now authoritative).
174    pub fn clear_traffic_delta(&self, user_id: &str) {
175        self.traffic_deltas.write().remove(user_id);
176    }
177
178    // ── Negative cache ──────────────────────────────────────────
179
180    /// Record a hash as "not found" in the negative cache.
181    ///
182    /// Subsequent lookups within `neg_ttl` will return `true` from
183    /// [`is_negative`](Self::is_negative), skipping the DB query.
184    pub fn insert_negative(&self, hash: &str) {
185        if self.neg_ttl > Duration::ZERO {
186            self.neg_cache
187                .write()
188                .insert(hash.to_string(), Instant::now() + self.neg_ttl);
189        }
190    }
191
192    /// Check if a hash is in the negative cache (known invalid).
193    ///
194    /// Returns `true` if the hash was recently looked up and not found.
195    /// Expired entries are lazily removed.
196    pub fn is_negative(&self, hash: &str) -> bool {
197        if self.neg_ttl == Duration::ZERO {
198            return false;
199        }
200        let cache = self.neg_cache.read();
201        if let Some(&exp) = cache.get(hash)
202            && Instant::now() < exp
203        {
204            return true;
205        }
206        false
207    }
208
209    /// Remove a hash from the negative cache.
210    ///
211    /// Called after cache invalidation so that a newly-added user
212    /// is not blocked by a stale negative entry.
213    pub fn remove_negative(&self, hash: &str) {
214        self.neg_cache.write().remove(hash);
215    }
216
217    // ── Statistics ──────────────────────────────────────────────
218
219    /// Get cache statistics.
220    pub fn stats(&self) -> CacheStats {
221        let cache = self.cache.read();
222        CacheStats {
223            size: cache.len(),
224            neg_size: self.neg_cache.read().len(),
225            hits: self.hits.load(Ordering::Relaxed),
226            misses: self.misses.load(Ordering::Relaxed),
227            ttl: self.ttl,
228        }
229    }
230}
231
/// Point-in-time snapshot of an [`AuthCache`]'s sizes and counters.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Entries held in the positive cache (expired entries not yet purged included).
    pub size: usize,
    /// Entries held in the negative cache (expired entries not yet purged included).
    pub neg_size: usize,
    /// Lookups answered from the positive cache.
    pub hits: u64,
    /// Lookups that found no valid positive entry.
    pub misses: u64,
    /// Configured lifetime of positive cache entries.
    pub ttl: Duration,
}
246
247impl CacheStats {
248    /// Calculate hit rate (0.0 to 1.0).
249    pub fn hit_rate(&self) -> f64 {
250        let total = self.hits + self.misses;
251        if total == 0 {
252            0.0
253        } else {
254            self.hits as f64 / total as f64
255        }
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    fn make_cache() -> AuthCache {
        AuthCache::new(Duration::from_secs(60), Duration::from_secs(5))
    }

    fn make_user(user_id: &str, traffic_limit: i64, traffic_used: i64) -> CachedUser {
        CachedUser {
            user_id: Some(user_id.to_string()),
            traffic_limit,
            traffic_used,
            expires_at: 0,
            enabled: true,
            cached_at: Instant::now(),
        }
    }

    /// Mark every positive and negative entry in `cache` as already expired.
    ///
    /// `tests` is a child module, so it can reach the private maps directly.
    /// Backdating expiries keeps the expiry tests deterministic. A short TTL
    /// plus `sleep` is racy: if the scheduler delays the thread past the TTL
    /// between `insert` and the first lookup, the "still cached" assertion fails.
    fn expire_all(cache: &AuthCache) {
        let now = Instant::now();
        for entry in cache.cache.write().values_mut() {
            entry.expires_at = now;
        }
        for expiry in cache.neg_cache.write().values_mut() {
            *expiry = now;
        }
    }

    #[test]
    fn test_cache_basic() {
        let cache = make_cache();
        let user = make_user("user1", 1000, 100);

        cache.insert("hash1".to_string(), user);
        let cached = cache.get("hash1").unwrap();
        assert_eq!(cached.user_id, Some("user1".to_string()));
        assert_eq!(cached.traffic_limit, 1000);

        assert!(cache.get("hash2").is_none());
    }

    #[test]
    fn test_cache_expiration() {
        let cache = make_cache();
        cache.insert("hash1".to_string(), make_user("user1", 0, 0));
        assert!(cache.get("hash1").is_some());

        expire_all(&cache);
        assert!(cache.get("hash1").is_none());
    }

    #[test]
    fn test_zero_ttl_entries_expire_immediately() {
        let cache = AuthCache::new(Duration::ZERO, Duration::ZERO);
        cache.insert("hash1".to_string(), make_user("user1", 0, 0));
        assert!(cache.get("hash1").is_none());
    }

    #[test]
    fn test_cache_invalidate_user() {
        let cache = make_cache();

        cache.insert("hash1".to_string(), make_user("user1", 0, 0));
        cache.insert("hash2".to_string(), make_user("user2", 0, 0));

        // Also add a traffic delta for user1
        cache.add_traffic_delta("user1", 500);

        cache.invalidate_user("user1");

        assert!(cache.get("hash1").is_none());
        assert!(cache.get("hash2").is_some());
        // Delta should also be cleared
        assert_eq!(cache.get_traffic_delta("user1"), 0);
    }

    #[test]
    fn test_cache_stats() {
        let cache = make_cache();
        let user = CachedUser {
            user_id: None,
            traffic_limit: 0,
            traffic_used: 0,
            expires_at: 0,
            enabled: true,
            cached_at: Instant::now(),
        };

        cache.insert("hash1".to_string(), user);

        cache.get("hash1"); // hit
        cache.get("hash1"); // hit
        cache.get("hash2"); // miss

        let stats = cache.stats();
        assert_eq!(stats.size, 1);
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate() - 2.0 / 3.0).abs() < 1e-9);
    }

    // ── Traffic delta tests ─────────────────────────────────────

    #[test]
    fn test_traffic_delta_accumulates() {
        let cache = make_cache();

        cache.add_traffic_delta("user1", 100);
        cache.add_traffic_delta("user1", 200);
        cache.add_traffic_delta("user1", 300);

        assert_eq!(cache.get_traffic_delta("user1"), 600);
        assert_eq!(cache.get_traffic_delta("user2"), 0); // no delta
    }

    #[test]
    fn test_traffic_delta_clear() {
        let cache = make_cache();

        cache.add_traffic_delta("user1", 500);
        assert_eq!(cache.get_traffic_delta("user1"), 500);

        cache.clear_traffic_delta("user1");
        assert_eq!(cache.get_traffic_delta("user1"), 0);
    }

    #[test]
    fn test_clear_resets_everything() {
        let cache = make_cache();

        cache.insert("hash1".to_string(), make_user("user1", 0, 0));
        cache.add_traffic_delta("user1", 100);
        cache.insert_negative("bad_hash");

        cache.clear();

        assert!(cache.get("hash1").is_none());
        assert_eq!(cache.get_traffic_delta("user1"), 0);
        assert!(!cache.is_negative("bad_hash"));
    }

    // ── Negative cache tests ────────────────────────────────────

    #[test]
    fn test_negative_cache_basic() {
        let cache = make_cache();

        assert!(!cache.is_negative("bad_hash"));

        cache.insert_negative("bad_hash");
        assert!(cache.is_negative("bad_hash"));
        assert!(!cache.is_negative("other_hash"));
    }

    #[test]
    fn test_negative_cache_expiration() {
        let cache = make_cache();

        cache.insert_negative("bad_hash");
        assert!(cache.is_negative("bad_hash"));

        expire_all(&cache);
        assert!(!cache.is_negative("bad_hash"));
    }

    #[test]
    fn test_negative_cache_disabled_when_zero_ttl() {
        let cache = AuthCache::new(Duration::from_secs(60), Duration::ZERO);

        cache.insert_negative("bad_hash");
        assert!(!cache.is_negative("bad_hash"));
    }

    #[test]
    fn test_negative_cache_remove() {
        let cache = make_cache();

        cache.insert_negative("bad_hash");
        assert!(cache.is_negative("bad_hash"));

        cache.remove_negative("bad_hash");
        assert!(!cache.is_negative("bad_hash"));
    }

    #[test]
    fn test_negative_cache_in_stats() {
        let cache = make_cache();

        cache.insert_negative("hash1");
        cache.insert_negative("hash2");

        let stats = cache.stats();
        assert_eq!(stats.neg_size, 2);
    }

    #[test]
    fn test_cleanup_expired_cleans_both() {
        let cache = make_cache();

        cache.insert("hash1".to_string(), make_user("user1", 0, 0));
        cache.insert_negative("bad_hash");

        // Live entries survive a cleanup pass.
        cache.cleanup_expired();
        let stats = cache.stats();
        assert_eq!(stats.size, 1);
        assert_eq!(stats.neg_size, 1);

        expire_all(&cache);
        cache.cleanup_expired();

        let stats = cache.stats();
        assert_eq!(stats.size, 0);
        assert_eq!(stats.neg_size, 0);
    }
}