Skip to main content

trojan_auth/store/
auth.rs

1//! Generic store-based authentication wrapper.
2//!
3//! [`StoreAuth<S>`] wraps any [`UserStore`] implementation and provides:
4//! - Validation logic (enabled → expired → traffic check)
5//! - Optional result caching via [`AuthCache`], with in-memory traffic deltas
6//! - Negative caching for invalid hashes (prevents DB flooding)
7//! - Optional batched traffic recording via [`TrafficRecorder`]
8
9use std::time::{Instant, SystemTime, UNIX_EPOCH};
10
11use async_trait::async_trait;
12
13use crate::error::AuthError;
14use crate::result::{AuthMetadata, AuthResult};
15use crate::traits::AuthBackend;
16
17use super::cache::{AuthCache, CacheStats, CachedUser};
18use super::config::{StoreAuthConfig, TrafficRecordingMode};
19use super::record::UserRecord;
20use super::traits::UserStore;
21
22#[cfg(feature = "batched-traffic")]
23use super::traffic::TrafficRecorder;
24
25/// Generic authentication backend that wraps a [`UserStore`].
26///
27/// Provides shared validation, caching, and traffic batching logic.
28/// New backends only need to implement [`UserStore`] (data access).
29///
30/// # Type parameter
31///
32/// - `S` — the underlying data store (e.g. `SqlStore`)
33pub struct StoreAuth<S: UserStore> {
34    store: S,
35    auth_cache: Option<AuthCache>,
36    #[cfg(feature = "batched-traffic")]
37    traffic_recorder: Option<TrafficRecorder>,
38    traffic_mode: TrafficRecordingMode,
39}
40
41impl<S: UserStore> StoreAuth<S> {
42    /// Create a new `StoreAuth` wrapping the given store.
43    ///
44    /// For backends that need batched traffic recording, use
45    /// [`with_traffic_recorder`](Self::with_traffic_recorder) after construction.
46    pub fn new(store: S, config: &StoreAuthConfig) -> Self {
47        let auth_cache = if config.cache_enabled {
48            Some(AuthCache::new(config.cache_ttl, config.neg_cache_ttl))
49        } else {
50            None
51        };
52
53        Self {
54            store,
55            auth_cache,
56            #[cfg(feature = "batched-traffic")]
57            traffic_recorder: None,
58            traffic_mode: config.traffic_mode,
59        }
60    }
61
62    /// Attach a [`TrafficRecorder`] for batched traffic writes.
63    #[cfg(feature = "batched-traffic")]
64    pub fn with_traffic_recorder(mut self, recorder: TrafficRecorder) -> Self {
65        self.traffic_recorder = Some(recorder);
66        self
67    }
68
69    /// Get a reference to the underlying store.
70    pub fn store(&self) -> &S {
71        &self.store
72    }
73
74    /// Check if caching is enabled.
75    pub fn cache_enabled(&self) -> bool {
76        self.auth_cache.is_some()
77    }
78
79    /// Get cache statistics. Returns `None` if caching is disabled.
80    pub fn cache_stats(&self) -> Option<CacheStats> {
81        self.auth_cache.as_ref().map(|c| c.stats())
82    }
83
84    /// Invalidate cache entry by password hash.
85    ///
86    /// Also removes the hash from the negative cache so that a
87    /// newly-added user is not blocked by a stale negative entry.
88    pub fn cache_invalidate(&self, hash: &str) {
89        if let Some(ref cache) = self.auth_cache {
90            cache.remove(hash);
91            cache.remove_negative(hash);
92        }
93    }
94
95    /// Invalidate all cache entries for a user.
96    pub fn cache_invalidate_user(&self, user_id: &str) {
97        if let Some(ref cache) = self.auth_cache {
98            cache.invalidate_user(user_id);
99        }
100    }
101
102    /// Clear all cache entries.
103    pub fn cache_clear(&self) {
104        if let Some(ref cache) = self.auth_cache {
105            cache.clear();
106        }
107    }
108
109    /// Get current unix timestamp.
110    #[inline]
111    #[allow(
112        clippy::cast_possible_wrap,
113        clippy::cast_sign_loss,
114        clippy::cast_possible_truncation
115    )]
116    fn now_unix() -> i64 {
117        SystemTime::now()
118            .duration_since(UNIX_EPOCH)
119            .map(|d| d.as_secs() as i64)
120            .unwrap_or(0)
121    }
122
123    /// Validate a [`UserRecord`] against business rules.
124    ///
125    /// Checks: enabled → expired → traffic exceeded.
126    fn validate_record(record: &UserRecord) -> Result<(), AuthError> {
127        if !record.enabled {
128            return Err(AuthError::Disabled);
129        }
130
131        let now = Self::now_unix();
132        if record.expires_at > 0 && now >= record.expires_at {
133            return Err(AuthError::Expired);
134        }
135
136        if record.traffic_limit > 0 && record.traffic_used >= record.traffic_limit {
137            return Err(AuthError::TrafficExceeded);
138        }
139
140        Ok(())
141    }
142
143    /// Build an [`AuthResult`] from a validated [`UserRecord`].
144    #[allow(
145        clippy::cast_possible_wrap,
146        clippy::cast_sign_loss,
147        clippy::cast_possible_truncation
148    )]
149    fn record_to_result(record: &UserRecord) -> AuthResult {
150        let metadata = AuthMetadata {
151            traffic_limit: record.traffic_limit as u64,
152            traffic_used: record.traffic_used as u64,
153            expires_at: record.expires_at as u64,
154            enabled: record.enabled,
155        };
156
157        AuthResult {
158            user_id: record.user_id.clone(),
159            metadata: Some(metadata),
160        }
161    }
162}
163
164#[async_trait]
165impl<S: UserStore> AuthBackend for StoreAuth<S> {
166    async fn verify(&self, hash: &str) -> Result<AuthResult, AuthError> {
167        if let Some(ref cache) = self.auth_cache {
168            // 1. Negative cache — reject known-invalid hashes without DB query
169            if cache.is_negative(hash) {
170                return Err(AuthError::Invalid);
171            }
172
173            // 2. Positive cache hit — validate with delta-adjusted traffic
174            if let Some(cached) = cache.get(hash) {
175                let delta = cached
176                    .user_id
177                    .as_deref()
178                    .map(|uid| cache.get_traffic_delta(uid))
179                    .unwrap_or(0);
180
181                let mut record = UserRecord::from(cached);
182                record.traffic_used += delta;
183
184                if let Err(e) = Self::validate_record(&record) {
185                    // Remove expired entries from positive cache
186                    if matches!(e, AuthError::Expired) {
187                        cache.remove(hash);
188                    }
189                    return Err(e);
190                }
191
192                return Ok(Self::record_to_result(&record));
193            }
194        }
195
196        // 3. Cache miss — query the store
197        let record = match self.store.find_by_hash(hash).await? {
198            Some(record) => record,
199            None => {
200                // Insert into negative cache
201                if let Some(ref cache) = self.auth_cache {
202                    cache.insert_negative(hash);
203                }
204                return Err(AuthError::Invalid);
205            }
206        };
207
208        // Validate business rules
209        Self::validate_record(&record)?;
210
211        // Cache successful result and reset traffic delta
212        if let Some(ref cache) = self.auth_cache {
213            if let Some(ref uid) = record.user_id {
214                cache.clear_traffic_delta(uid);
215            }
216            let cached_user = CachedUser {
217                user_id: record.user_id.clone(),
218                traffic_limit: record.traffic_limit,
219                traffic_used: record.traffic_used,
220                expires_at: record.expires_at,
221                enabled: record.enabled,
222                cached_at: Instant::now(),
223            };
224            cache.insert(hash.to_string(), cached_user);
225        }
226
227        Ok(Self::record_to_result(&record))
228    }
229
230    async fn record_traffic(&self, user_id: &str, bytes: u64) -> Result<(), AuthError> {
231        // Update in-memory traffic delta so cache hits reflect accumulated traffic
232        if let Some(ref cache) = self.auth_cache {
233            cache.add_traffic_delta(user_id, bytes);
234        }
235
236        // Persist to backend
237        match self.traffic_mode {
238            TrafficRecordingMode::Immediate => self.store.add_traffic(user_id, bytes).await,
239            TrafficRecordingMode::Batched => {
240                #[cfg(feature = "batched-traffic")]
241                if let Some(ref recorder) = self.traffic_recorder {
242                    recorder.record(user_id.to_string(), bytes);
243                }
244                Ok(())
245            }
246            TrafficRecordingMode::Disabled => Ok(()),
247        }
248    }
249}
250
251impl<S: UserStore + std::fmt::Debug> std::fmt::Debug for StoreAuth<S> {
252    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253        f.debug_struct("StoreAuth")
254            .field("store", &self.store)
255            .field("traffic_mode", &self.traffic_mode)
256            .field("cache_enabled", &self.auth_cache.is_some())
257            .finish_non_exhaustive()
258    }
259}