Skip to main content

trojan_auth/store/
auth.rs

//! Generic store-based authentication wrapper.
//!
//! [`StoreAuth<S>`] wraps any [`UserStore`] implementation and provides:
//! - Validation logic (enabled → expired → traffic check)
//! - Optional result caching via [`AuthCache`], with in-memory traffic deltas
//! - Stale-while-revalidate: stale cache entries are served immediately
//!   while being revalidated in the background (with the `tokio-runtime` feature)
//! - Negative caching for invalid hashes (prevents DB flooding)
//! - Optional batched traffic recording via [`TrafficRecorder`]
10
11use std::sync::Arc;
12#[cfg(feature = "tokio-runtime")]
13use std::time::Duration;
14use std::time::{Instant, SystemTime, UNIX_EPOCH};
15
16use async_trait::async_trait;
17
18use crate::error::AuthError;
19use crate::result::{AuthMetadata, AuthResult};
20use crate::traits::AuthBackend;
21
22use super::cache::{AuthCache, CacheLookup, CacheStats, CachedUser};
23use super::config::{StoreAuthConfig, TrafficRecordingMode};
24use super::record::UserRecord;
25use super::traits::UserStore;
26
27#[cfg(feature = "batched-traffic")]
28use super::traffic::TrafficRecorder;
29
30/// Generic authentication backend that wraps a [`UserStore`].
31///
32/// Provides shared validation, caching, and traffic batching logic.
33/// New backends only need to implement [`UserStore`] (data access).
34///
35/// # Type parameter
36///
37/// - `S` — the underlying data store (e.g. `SqlStore`)
38pub struct StoreAuth<S: UserStore> {
39    store: Arc<S>,
40    auth_cache: Option<Arc<AuthCache>>,
41    #[cfg(feature = "batched-traffic")]
42    traffic_recorder: Option<TrafficRecorder>,
43    traffic_mode: TrafficRecordingMode,
44}
45
46impl<S: UserStore> StoreAuth<S> {
47    /// Create a new `StoreAuth` wrapping the given store.
48    ///
49    /// For backends that need batched traffic recording, use
50    /// [`with_traffic_recorder`](Self::with_traffic_recorder) after construction.
51    pub fn new(store: S, config: &StoreAuthConfig) -> Self {
52        let auth_cache = if config.cache_enabled {
53            Some(Arc::new(AuthCache::new(
54                config.cache_ttl,
55                config.stale_ttl,
56                config.neg_cache_ttl,
57            )))
58        } else {
59            None
60        };
61
62        Self {
63            store: Arc::new(store),
64            auth_cache,
65            #[cfg(feature = "batched-traffic")]
66            traffic_recorder: None,
67            traffic_mode: config.traffic_mode,
68        }
69    }
70
71    /// Attach a [`TrafficRecorder`] for batched traffic writes.
72    #[cfg(feature = "batched-traffic")]
73    pub fn with_traffic_recorder(mut self, recorder: TrafficRecorder) -> Self {
74        self.traffic_recorder = Some(recorder);
75        self
76    }
77
78    /// Get a reference to the underlying store.
79    pub fn store(&self) -> &S {
80        &self.store
81    }
82
83    /// Check if caching is enabled.
84    pub fn cache_enabled(&self) -> bool {
85        self.auth_cache.is_some()
86    }
87
88    /// Get cache statistics. Returns `None` if caching is disabled.
89    pub fn cache_stats(&self) -> Option<CacheStats> {
90        self.auth_cache.as_ref().map(|c| c.stats())
91    }
92
93    /// Invalidate cache entry by password hash.
94    ///
95    /// Also removes the hash from the negative cache so that a
96    /// newly-added user is not blocked by a stale negative entry.
97    pub fn cache_invalidate(&self, hash: &str) {
98        if let Some(ref cache) = self.auth_cache {
99            cache.remove(hash);
100            cache.remove_negative(hash);
101        }
102    }
103
104    /// Invalidate all cache entries for a user.
105    pub fn cache_invalidate_user(&self, user_id: &str) {
106        if let Some(ref cache) = self.auth_cache {
107            cache.invalidate_user(user_id);
108        }
109    }
110
111    /// Clear all cache entries.
112    pub fn cache_clear(&self) {
113        if let Some(ref cache) = self.auth_cache {
114            cache.clear();
115        }
116    }
117
118    /// Get current unix timestamp.
119    #[inline]
120    #[allow(
121        clippy::cast_possible_wrap,
122        clippy::cast_sign_loss,
123        clippy::cast_possible_truncation
124    )]
125    fn now_unix() -> i64 {
126        SystemTime::now()
127            .duration_since(UNIX_EPOCH)
128            .map(|d| d.as_secs() as i64)
129            .unwrap_or(0)
130    }
131
132    /// Validate a [`UserRecord`] against business rules.
133    ///
134    /// Checks: enabled → expired → traffic exceeded.
135    fn validate_record(record: &UserRecord) -> Result<(), AuthError> {
136        if !record.enabled {
137            return Err(AuthError::Disabled);
138        }
139
140        let now = Self::now_unix();
141        if record.expires_at > 0 && now >= record.expires_at {
142            return Err(AuthError::Expired);
143        }
144
145        if record.traffic_limit > 0 && record.traffic_used >= record.traffic_limit {
146            return Err(AuthError::TrafficExceeded);
147        }
148
149        Ok(())
150    }
151
152    /// Validate a cached entry with delta-adjusted traffic.
153    ///
154    /// Returns `Ok(AuthResult)` if valid, removes stale entries from cache
155    /// and returns an error otherwise.
156    fn validate_cached(
157        cache: &AuthCache,
158        hash: &str,
159        cached: CachedUser,
160    ) -> Result<AuthResult, AuthError> {
161        let delta = cached
162            .user_id
163            .as_deref()
164            .map(|uid| cache.get_traffic_delta(uid))
165            .unwrap_or(0);
166
167        let mut record = UserRecord::from(cached);
168        record.traffic_used += delta;
169
170        if let Err(e) = Self::validate_record(&record) {
171            if matches!(e, AuthError::Expired) {
172                cache.remove(hash);
173            }
174            return Err(e);
175        }
176
177        Ok(Self::record_to_result(&record))
178    }
179
180    /// Build an [`AuthResult`] from a validated [`UserRecord`].
181    #[allow(
182        clippy::cast_possible_wrap,
183        clippy::cast_sign_loss,
184        clippy::cast_possible_truncation
185    )]
186    fn record_to_result(record: &UserRecord) -> AuthResult {
187        let metadata = AuthMetadata {
188            traffic_limit: record.traffic_limit as u64,
189            traffic_used: record.traffic_used as u64,
190            expires_at: record.expires_at as u64,
191            enabled: record.enabled,
192        };
193
194        AuthResult {
195            user_id: record.user_id.clone(),
196            metadata: Some(metadata),
197        }
198    }
199}
200
/// Background revalidation for stale cache entries (requires tokio runtime).
#[cfg(feature = "tokio-runtime")]
impl<S: UserStore + 'static> StoreAuth<S> {
    /// Spawn a background task to revalidate a stale cache entry.
    ///
    /// Single-flight: if `start_revalidation` reports that a revalidation for
    /// `hash` is already in progress, nothing is spawned. The in-flight marker
    /// is released by a drop guard, so it is cleared even if the revalidation
    /// panics or the task is cancelled. Without the guard, the hash would never
    /// be revalidated again.
    fn spawn_revalidation(&self, cache: &Arc<AuthCache>, hash: &str) {
        /// Releases the single-flight marker for `hash` when dropped.
        struct InFlight {
            cache: Arc<AuthCache>,
            hash: String,
        }

        impl Drop for InFlight {
            fn drop(&mut self) {
                self.cache.finish_revalidation(&self.hash);
            }
        }

        if !cache.start_revalidation(hash) {
            return;
        }

        let guard = InFlight {
            cache: Arc::clone(cache),
            hash: hash.to_string(),
        };
        let store = Arc::clone(&self.store);
        tokio::spawn(async move {
            Self::revalidate(store, Arc::clone(&guard.cache), guard.hash.clone()).await;
            drop(guard);
        });
    }

    /// Start a background task that periodically removes expired cache entries.
    ///
    /// Call this after construction when a tokio runtime is available.
    /// Without this, expired entries are only lazily evicted on access.
    ///
    /// The task holds only a weak reference to the cache. It stops on its own
    /// once every strong reference (this `StoreAuth` and any in-flight
    /// revalidation tasks) is gone, instead of keeping the cache alive forever.
    /// A zero `interval` is rejected with a warning, because
    /// `tokio::time::interval` panics on a zero period.
    pub fn start_cache_cleanup(&self, interval: Duration) {
        if let Some(ref cache) = self.auth_cache {
            if interval.is_zero() {
                tracing::warn!("cache cleanup interval is zero; periodic cleanup not started");
                return;
            }
            let weak_cache = Arc::downgrade(cache);
            tokio::spawn(async move {
                let mut ticker = tokio::time::interval(interval);
                // The first tick completes immediately; skip it so the first
                // cleanup runs one full interval after start.
                ticker.tick().await;
                loop {
                    ticker.tick().await;
                    match weak_cache.upgrade() {
                        Some(cache) => cache.cleanup_expired(),
                        None => break,
                    }
                }
            });
        }
    }

    /// Re-fetch a user from the store and refresh its cache entry.
    ///
    /// - Found: the user's traffic delta is reset and a fresh entry is inserted
    ///   (no business-rule validation is performed here).
    /// - Not found: the entry is removed and the hash is negatively cached.
    /// - Store error: logged, and the entry is removed so the next request
    ///   queries the store directly.
    async fn revalidate(store: Arc<S>, cache: Arc<AuthCache>, hash: String) {
        match store.find_by_hash(&hash).await {
            Ok(Some(record)) => {
                if let Some(ref uid) = record.user_id {
                    cache.clear_traffic_delta(uid);
                }
                let cached_user = CachedUser {
                    user_id: record.user_id.clone(),
                    traffic_limit: record.traffic_limit,
                    traffic_used: record.traffic_used,
                    expires_at: record.expires_at,
                    enabled: record.enabled,
                    cached_at: Instant::now(),
                };
                cache.insert(hash, cached_user);
            }
            Ok(None) => {
                cache.remove(&hash);
                cache.insert_negative(&hash);
            }
            Err(e) => {
                tracing::warn!(hash = %hash, error = %e, "background revalidation failed");
                cache.remove(&hash);
            }
        }
    }
}
265
266#[async_trait]
267impl<S: UserStore + 'static> AuthBackend for StoreAuth<S> {
268    async fn verify(&self, hash: &str) -> Result<AuthResult, AuthError> {
269        if let Some(ref cache) = self.auth_cache {
270            // 1. Negative cache — reject known-invalid hashes without store query
271            if cache.is_negative(hash) {
272                return Err(AuthError::Invalid);
273            }
274
275            // 2. Cache lookup with stale-while-revalidate support
276            match cache.lookup(hash) {
277                CacheLookup::Fresh(cached) => {
278                    return Self::validate_cached(cache, hash, cached);
279                }
280                CacheLookup::Stale(cached) => {
281                    let result = Self::validate_cached(cache, hash, cached);
282                    // Spawn background revalidation for stale hits, including
283                    // validation failures (e.g. disabled/traffic exceeded), so
284                    // cache state can recover quickly when backend state changes.
285                    #[cfg(feature = "tokio-runtime")]
286                    self.spawn_revalidation(cache, hash);
287                    return result;
288                }
289                CacheLookup::Miss => { /* fall through to store query */ }
290            }
291        }
292
293        // 3. Cache miss — query the store
294        let record = match self.store.find_by_hash(hash).await? {
295            Some(record) => record,
296            None => {
297                // Insert into negative cache
298                if let Some(ref cache) = self.auth_cache {
299                    cache.insert_negative(hash);
300                }
301                return Err(AuthError::Invalid);
302            }
303        };
304
305        // Validate business rules
306        Self::validate_record(&record)?;
307
308        // Cache successful result and reset traffic delta
309        if let Some(ref cache) = self.auth_cache {
310            if let Some(ref uid) = record.user_id {
311                cache.clear_traffic_delta(uid);
312            }
313            let cached_user = CachedUser {
314                user_id: record.user_id.clone(),
315                traffic_limit: record.traffic_limit,
316                traffic_used: record.traffic_used,
317                expires_at: record.expires_at,
318                enabled: record.enabled,
319                cached_at: Instant::now(),
320            };
321            cache.insert(hash.to_string(), cached_user);
322        }
323
324        Ok(Self::record_to_result(&record))
325    }
326
327    async fn record_traffic(&self, user_id: &str, bytes: u64) -> Result<(), AuthError> {
328        // Update in-memory traffic delta so cache hits reflect accumulated traffic
329        if let Some(ref cache) = self.auth_cache {
330            cache.add_traffic_delta(user_id, bytes);
331        }
332
333        // Persist to backend
334        match self.traffic_mode {
335            TrafficRecordingMode::Immediate => self.store.add_traffic(user_id, bytes).await,
336            TrafficRecordingMode::Batched => {
337                #[cfg(feature = "batched-traffic")]
338                if let Some(ref recorder) = self.traffic_recorder {
339                    recorder.record(user_id.to_string(), bytes);
340                }
341                Ok(())
342            }
343            TrafficRecordingMode::Disabled => Ok(()),
344        }
345    }
346}
347
348impl<S: UserStore + std::fmt::Debug> std::fmt::Debug for StoreAuth<S> {
349    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350        f.debug_struct("StoreAuth")
351            .field("store", &self.store)
352            .field("traffic_mode", &self.traffic_mode)
353            .field("cache_enabled", &self.auth_cache.is_some())
354            .finish_non_exhaustive()
355    }
356}