Skip to main content

trojan_auth/store/
auth.rs

1//! Generic store-based authentication wrapper.
2//!
3//! [`StoreAuth<S>`] wraps any [`UserStore`] implementation and provides:
4//! - Validation logic (enabled → expired → traffic check)
5//! - Optional result caching via [`AuthCache`], with in-memory traffic deltas
6//! - Stale-while-revalidate: stale cache entries are served immediately
7//!   while being revalidated in the background
8//! - Negative caching for invalid hashes (prevents DB flooding)
9//! - Optional batched traffic recording via [`TrafficRecorder`]
10
11use std::sync::Arc;
12use std::time::{Instant, SystemTime, UNIX_EPOCH};
13
14use async_trait::async_trait;
15
16use crate::error::AuthError;
17use crate::result::{AuthMetadata, AuthResult};
18use crate::traits::AuthBackend;
19
20use super::cache::{AuthCache, CacheLookup, CacheStats, CachedUser};
21use super::config::{StoreAuthConfig, TrafficRecordingMode};
22use super::record::UserRecord;
23use super::traits::UserStore;
24
25#[cfg(feature = "batched-traffic")]
26use super::traffic::TrafficRecorder;
27
28/// Generic authentication backend that wraps a [`UserStore`].
29///
30/// Provides shared validation, caching, and traffic batching logic.
31/// New backends only need to implement [`UserStore`] (data access).
32///
33/// # Type parameter
34///
35/// - `S` — the underlying data store (e.g. `SqlStore`)
36pub struct StoreAuth<S: UserStore> {
37    store: Arc<S>,
38    auth_cache: Option<Arc<AuthCache>>,
39    #[cfg(feature = "batched-traffic")]
40    traffic_recorder: Option<TrafficRecorder>,
41    traffic_mode: TrafficRecordingMode,
42}
43
44impl<S: UserStore> StoreAuth<S> {
45    /// Create a new `StoreAuth` wrapping the given store.
46    ///
47    /// For backends that need batched traffic recording, use
48    /// [`with_traffic_recorder`](Self::with_traffic_recorder) after construction.
49    pub fn new(store: S, config: &StoreAuthConfig) -> Self {
50        let auth_cache = if config.cache_enabled {
51            Some(Arc::new(AuthCache::new(
52                config.cache_ttl,
53                config.stale_ttl,
54                config.neg_cache_ttl,
55            )))
56        } else {
57            None
58        };
59
60        Self {
61            store: Arc::new(store),
62            auth_cache,
63            #[cfg(feature = "batched-traffic")]
64            traffic_recorder: None,
65            traffic_mode: config.traffic_mode,
66        }
67    }
68
69    /// Attach a [`TrafficRecorder`] for batched traffic writes.
70    #[cfg(feature = "batched-traffic")]
71    pub fn with_traffic_recorder(mut self, recorder: TrafficRecorder) -> Self {
72        self.traffic_recorder = Some(recorder);
73        self
74    }
75
76    /// Get a reference to the underlying store.
77    pub fn store(&self) -> &S {
78        &self.store
79    }
80
81    /// Check if caching is enabled.
82    pub fn cache_enabled(&self) -> bool {
83        self.auth_cache.is_some()
84    }
85
86    /// Get cache statistics. Returns `None` if caching is disabled.
87    pub fn cache_stats(&self) -> Option<CacheStats> {
88        self.auth_cache.as_ref().map(|c| c.stats())
89    }
90
91    /// Invalidate cache entry by password hash.
92    ///
93    /// Also removes the hash from the negative cache so that a
94    /// newly-added user is not blocked by a stale negative entry.
95    pub fn cache_invalidate(&self, hash: &str) {
96        if let Some(ref cache) = self.auth_cache {
97            cache.remove(hash);
98            cache.remove_negative(hash);
99        }
100    }
101
102    /// Invalidate all cache entries for a user.
103    pub fn cache_invalidate_user(&self, user_id: &str) {
104        if let Some(ref cache) = self.auth_cache {
105            cache.invalidate_user(user_id);
106        }
107    }
108
109    /// Clear all cache entries.
110    pub fn cache_clear(&self) {
111        if let Some(ref cache) = self.auth_cache {
112            cache.clear();
113        }
114    }
115
116    /// Get current unix timestamp.
117    #[inline]
118    #[allow(
119        clippy::cast_possible_wrap,
120        clippy::cast_sign_loss,
121        clippy::cast_possible_truncation
122    )]
123    fn now_unix() -> i64 {
124        SystemTime::now()
125            .duration_since(UNIX_EPOCH)
126            .map(|d| d.as_secs() as i64)
127            .unwrap_or(0)
128    }
129
130    /// Validate a [`UserRecord`] against business rules.
131    ///
132    /// Checks: enabled → expired → traffic exceeded.
133    fn validate_record(record: &UserRecord) -> Result<(), AuthError> {
134        if !record.enabled {
135            return Err(AuthError::Disabled);
136        }
137
138        let now = Self::now_unix();
139        if record.expires_at > 0 && now >= record.expires_at {
140            return Err(AuthError::Expired);
141        }
142
143        if record.traffic_limit > 0 && record.traffic_used >= record.traffic_limit {
144            return Err(AuthError::TrafficExceeded);
145        }
146
147        Ok(())
148    }
149
150    /// Validate a cached entry with delta-adjusted traffic.
151    ///
152    /// Returns `Ok(AuthResult)` if valid, removes stale entries from cache
153    /// and returns an error otherwise.
154    fn validate_cached(
155        cache: &AuthCache,
156        hash: &str,
157        cached: CachedUser,
158    ) -> Result<AuthResult, AuthError> {
159        let delta = cached
160            .user_id
161            .as_deref()
162            .map(|uid| cache.get_traffic_delta(uid))
163            .unwrap_or(0);
164
165        let mut record = UserRecord::from(cached);
166        record.traffic_used += delta;
167
168        if let Err(e) = Self::validate_record(&record) {
169            if matches!(e, AuthError::Expired) {
170                cache.remove(hash);
171            }
172            return Err(e);
173        }
174
175        Ok(Self::record_to_result(&record))
176    }
177
178    /// Build an [`AuthResult`] from a validated [`UserRecord`].
179    #[allow(
180        clippy::cast_possible_wrap,
181        clippy::cast_sign_loss,
182        clippy::cast_possible_truncation
183    )]
184    fn record_to_result(record: &UserRecord) -> AuthResult {
185        let metadata = AuthMetadata {
186            traffic_limit: record.traffic_limit as u64,
187            traffic_used: record.traffic_used as u64,
188            expires_at: record.expires_at as u64,
189            enabled: record.enabled,
190        };
191
192        AuthResult {
193            user_id: record.user_id.clone(),
194            metadata: Some(metadata),
195        }
196    }
197}
198
/// Background revalidation for stale cache entries (requires tokio runtime).
#[cfg(feature = "tokio-runtime")]
impl<S: UserStore + 'static> StoreAuth<S> {
    /// Spawn a background task to revalidate a stale cache entry.
    ///
    /// Single-flight: nothing is spawned if a revalidation for `hash` is
    /// already in progress. The in-progress marker is released by a drop
    /// guard, so it is cleared even if the task panics or is dropped during
    /// runtime shutdown (otherwise the hash would never be revalidated again).
    ///
    /// If the calling thread is not inside a tokio runtime, revalidation is
    /// skipped instead of panicking inside `verify`.
    fn spawn_revalidation(&self, cache: &Arc<AuthCache>, hash: &str) {
        /// Releases the single-flight marker for `hash` when dropped.
        struct InFlight {
            cache: Arc<AuthCache>,
            hash: String,
        }

        impl Drop for InFlight {
            fn drop(&mut self) {
                self.cache.finish_revalidation(&self.hash);
            }
        }

        // `tokio::spawn` panics without a runtime; check first so we never
        // mark the hash as in-flight without a task to clear it.
        let runtime = match tokio::runtime::Handle::try_current() {
            Ok(handle) => handle,
            Err(_) => {
                tracing::debug!("no tokio runtime available; skipping cache revalidation");
                return;
            }
        };

        if !cache.start_revalidation(hash) {
            return;
        }

        let in_flight = InFlight {
            cache: Arc::clone(cache),
            hash: hash.to_string(),
        };
        let store = Arc::clone(&self.store);
        runtime.spawn(async move {
            Self::revalidate(store, Arc::clone(&in_flight.cache), in_flight.hash.clone()).await;
            drop(in_flight);
        });
    }

    /// Re-fetch a user from the store and update the cache.
    ///
    /// - Found: resets the user's traffic delta and replaces the cached entry
    ///   (the record's business rules are checked on the next lookup).
    /// - Not found: evicts the entry and adds the hash to the negative cache.
    /// - Store error: evicts the entry so the next request queries the store.
    async fn revalidate(store: Arc<S>, cache: Arc<AuthCache>, hash: String) {
        match store.find_by_hash(&hash).await {
            Ok(Some(record)) => {
                if let Some(ref uid) = record.user_id {
                    cache.clear_traffic_delta(uid);
                }
                let cached_user = CachedUser {
                    user_id: record.user_id.clone(),
                    traffic_limit: record.traffic_limit,
                    traffic_used: record.traffic_used,
                    expires_at: record.expires_at,
                    enabled: record.enabled,
                    cached_at: Instant::now(),
                };
                cache.insert(hash, cached_user);
            }
            Ok(None) => {
                cache.remove(&hash);
                cache.insert_negative(&hash);
            }
            Err(e) => {
                // The hash is the user's authentication credential; keep it out of logs.
                tracing::warn!(error = %e, "background revalidation failed");
                cache.remove(&hash);
            }
        }
    }
}
245
246#[async_trait]
247impl<S: UserStore + 'static> AuthBackend for StoreAuth<S> {
248    async fn verify(&self, hash: &str) -> Result<AuthResult, AuthError> {
249        if let Some(ref cache) = self.auth_cache {
250            // 1. Negative cache — reject known-invalid hashes without store query
251            if cache.is_negative(hash) {
252                return Err(AuthError::Invalid);
253            }
254
255            // 2. Cache lookup with stale-while-revalidate support
256            match cache.lookup(hash) {
257                CacheLookup::Fresh(cached) => {
258                    return Self::validate_cached(cache, hash, cached);
259                }
260                CacheLookup::Stale(cached) => {
261                    let result = Self::validate_cached(cache, hash, cached);
262                    // Spawn background revalidation for stale hits, including
263                    // validation failures (e.g. disabled/traffic exceeded), so
264                    // cache state can recover quickly when backend state changes.
265                    #[cfg(feature = "tokio-runtime")]
266                    self.spawn_revalidation(cache, hash);
267                    return result;
268                }
269                CacheLookup::Miss => { /* fall through to store query */ }
270            }
271        }
272
273        // 3. Cache miss — query the store
274        let record = match self.store.find_by_hash(hash).await? {
275            Some(record) => record,
276            None => {
277                // Insert into negative cache
278                if let Some(ref cache) = self.auth_cache {
279                    cache.insert_negative(hash);
280                }
281                return Err(AuthError::Invalid);
282            }
283        };
284
285        // Validate business rules
286        Self::validate_record(&record)?;
287
288        // Cache successful result and reset traffic delta
289        if let Some(ref cache) = self.auth_cache {
290            if let Some(ref uid) = record.user_id {
291                cache.clear_traffic_delta(uid);
292            }
293            let cached_user = CachedUser {
294                user_id: record.user_id.clone(),
295                traffic_limit: record.traffic_limit,
296                traffic_used: record.traffic_used,
297                expires_at: record.expires_at,
298                enabled: record.enabled,
299                cached_at: Instant::now(),
300            };
301            cache.insert(hash.to_string(), cached_user);
302        }
303
304        Ok(Self::record_to_result(&record))
305    }
306
307    async fn record_traffic(&self, user_id: &str, bytes: u64) -> Result<(), AuthError> {
308        // Update in-memory traffic delta so cache hits reflect accumulated traffic
309        if let Some(ref cache) = self.auth_cache {
310            cache.add_traffic_delta(user_id, bytes);
311        }
312
313        // Persist to backend
314        match self.traffic_mode {
315            TrafficRecordingMode::Immediate => self.store.add_traffic(user_id, bytes).await,
316            TrafficRecordingMode::Batched => {
317                #[cfg(feature = "batched-traffic")]
318                if let Some(ref recorder) = self.traffic_recorder {
319                    recorder.record(user_id.to_string(), bytes);
320                }
321                Ok(())
322            }
323            TrafficRecordingMode::Disabled => Ok(()),
324        }
325    }
326}
327
328impl<S: UserStore + std::fmt::Debug> std::fmt::Debug for StoreAuth<S> {
329    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330        f.debug_struct("StoreAuth")
331            .field("store", &self.store)
332            .field("traffic_mode", &self.traffic_mode)
333            .field("cache_enabled", &self.auth_cache.is_some())
334            .finish_non_exhaustive()
335    }
336}