trojan_auth/store/
auth.rs1use std::sync::Arc;
12#[cfg(feature = "tokio-runtime")]
13use std::time::Duration;
14use std::time::{Instant, SystemTime, UNIX_EPOCH};
15
16use async_trait::async_trait;
17
18use crate::error::AuthError;
19use crate::result::{AuthMetadata, AuthResult};
20use crate::traits::AuthBackend;
21
22use super::cache::{AuthCache, CacheLookup, CacheStats, CachedUser};
23use super::config::{StoreAuthConfig, TrafficRecordingMode};
24use super::record::UserRecord;
25use super::traits::UserStore;
26
27#[cfg(feature = "batched-traffic")]
28use super::traffic::TrafficRecorder;
29
30pub struct StoreAuth<S: UserStore> {
39 store: Arc<S>,
40 auth_cache: Option<Arc<AuthCache>>,
41 #[cfg(feature = "batched-traffic")]
42 traffic_recorder: Option<TrafficRecorder>,
43 traffic_mode: TrafficRecordingMode,
44}
45
46impl<S: UserStore> StoreAuth<S> {
47 pub fn new(store: S, config: &StoreAuthConfig) -> Self {
52 let auth_cache = if config.cache_enabled {
53 Some(Arc::new(AuthCache::new(
54 config.cache_ttl,
55 config.stale_ttl,
56 config.neg_cache_ttl,
57 )))
58 } else {
59 None
60 };
61
62 Self {
63 store: Arc::new(store),
64 auth_cache,
65 #[cfg(feature = "batched-traffic")]
66 traffic_recorder: None,
67 traffic_mode: config.traffic_mode,
68 }
69 }
70
71 #[cfg(feature = "batched-traffic")]
73 pub fn with_traffic_recorder(mut self, recorder: TrafficRecorder) -> Self {
74 self.traffic_recorder = Some(recorder);
75 self
76 }
77
78 pub fn store(&self) -> &S {
80 &self.store
81 }
82
83 pub fn cache_enabled(&self) -> bool {
85 self.auth_cache.is_some()
86 }
87
88 pub fn cache_stats(&self) -> Option<CacheStats> {
90 self.auth_cache.as_ref().map(|c| c.stats())
91 }
92
93 pub fn cache_invalidate(&self, hash: &str) {
98 if let Some(ref cache) = self.auth_cache {
99 cache.remove(hash);
100 cache.remove_negative(hash);
101 }
102 }
103
104 pub fn cache_invalidate_user(&self, user_id: &str) {
106 if let Some(ref cache) = self.auth_cache {
107 cache.invalidate_user(user_id);
108 }
109 }
110
111 pub fn cache_clear(&self) {
113 if let Some(ref cache) = self.auth_cache {
114 cache.clear();
115 }
116 }
117
118 #[inline]
120 #[allow(
121 clippy::cast_possible_wrap,
122 clippy::cast_sign_loss,
123 clippy::cast_possible_truncation
124 )]
125 fn now_unix() -> i64 {
126 SystemTime::now()
127 .duration_since(UNIX_EPOCH)
128 .map(|d| d.as_secs() as i64)
129 .unwrap_or(0)
130 }
131
132 fn validate_record(record: &UserRecord) -> Result<(), AuthError> {
136 if !record.enabled {
137 return Err(AuthError::Disabled);
138 }
139
140 let now = Self::now_unix();
141 if record.expires_at > 0 && now >= record.expires_at {
142 return Err(AuthError::Expired);
143 }
144
145 if record.traffic_limit > 0 && record.traffic_used >= record.traffic_limit {
146 return Err(AuthError::TrafficExceeded);
147 }
148
149 Ok(())
150 }
151
152 fn validate_cached(
157 cache: &AuthCache,
158 hash: &str,
159 cached: CachedUser,
160 ) -> Result<AuthResult, AuthError> {
161 let delta = cached
162 .user_id
163 .as_deref()
164 .map(|uid| cache.get_traffic_delta(uid))
165 .unwrap_or(0);
166
167 let mut record = UserRecord::from(cached);
168 record.traffic_used += delta;
169
170 if let Err(e) = Self::validate_record(&record) {
171 if matches!(e, AuthError::Expired) {
172 cache.remove(hash);
173 }
174 return Err(e);
175 }
176
177 Ok(Self::record_to_result(&record))
178 }
179
180 #[allow(
182 clippy::cast_possible_wrap,
183 clippy::cast_sign_loss,
184 clippy::cast_possible_truncation
185 )]
186 fn record_to_result(record: &UserRecord) -> AuthResult {
187 let metadata = AuthMetadata {
188 traffic_limit: record.traffic_limit as u64,
189 traffic_used: record.traffic_used as u64,
190 expires_at: record.expires_at as u64,
191 enabled: record.enabled,
192 };
193
194 AuthResult {
195 user_id: record.user_id.clone(),
196 metadata: Some(metadata),
197 }
198 }
199}
200
#[cfg(feature = "tokio-runtime")]
impl<S: UserStore + 'static> StoreAuth<S> {
    /// Starts a background refresh of the cache entry for `hash`.
    ///
    /// Nothing happens when `cache.start_revalidation(hash)` returns `false`.
    /// Otherwise a detached tokio task runs [`Self::revalidate`] and afterwards
    /// calls `cache.finish_revalidation` for the same hash.
    fn spawn_revalidation(&self, cache: &Arc<AuthCache>, hash: &str) {
        if cache.start_revalidation(hash) {
            let store = Arc::clone(&self.store);
            let shared_cache = Arc::clone(cache);
            let key = hash.to_owned();
            tokio::spawn(async move {
                Self::revalidate(store, Arc::clone(&shared_cache), key.clone()).await;
                shared_cache.finish_revalidation(&key);
            });
        }
    }

    /// Spawns a detached tokio task that calls `cleanup_expired` on the cache
    /// once per `interval`.
    ///
    /// Does nothing when caching is disabled.
    pub fn start_cache_cleanup(&self, interval: Duration) {
        let cache = match self.auth_cache.as_ref() {
            Some(cache) => Arc::clone(cache),
            None => return,
        };

        tokio::spawn(async move {
            let mut ticker = tokio::time::interval(interval);
            // Consume one tick before entering the loop; tokio documents the first
            // tick of an interval as completing immediately, so this presumably
            // delays the first cleanup by one full period.
            ticker.tick().await;
            loop {
                ticker.tick().await;
                cache.cleanup_expired();
            }
        });
    }

    /// Re-reads `hash` from the store and updates the cache with the outcome.
    ///
    /// * Record found: the user's traffic delta is cleared (when the record has
    ///   a user id) and a fresh [`CachedUser`] is inserted.
    /// * No record: the entry is removed and a negative entry is inserted.
    /// * Store error: a warning is logged and the entry is removed.
    async fn revalidate(store: Arc<S>, cache: Arc<AuthCache>, hash: String) {
        let lookup = store.find_by_hash(&hash).await;

        let record = match lookup {
            Err(err) => {
                tracing::warn!(hash = %hash, error = %err, "background revalidation failed");
                cache.remove(&hash);
                return;
            }
            Ok(None) => {
                cache.remove(&hash);
                cache.insert_negative(&hash);
                return;
            }
            Ok(Some(found)) => found,
        };

        if let Some(uid) = record.user_id.as_ref() {
            cache.clear_traffic_delta(uid);
        }

        let refreshed = CachedUser {
            user_id: record.user_id.clone(),
            traffic_limit: record.traffic_limit,
            traffic_used: record.traffic_used,
            expires_at: record.expires_at,
            enabled: record.enabled,
            cached_at: Instant::now(),
        };
        cache.insert(hash, refreshed);
    }
}
265
266#[async_trait]
267impl<S: UserStore + 'static> AuthBackend for StoreAuth<S> {
268 async fn verify(&self, hash: &str) -> Result<AuthResult, AuthError> {
269 if let Some(ref cache) = self.auth_cache {
270 if cache.is_negative(hash) {
272 return Err(AuthError::Invalid);
273 }
274
275 match cache.lookup(hash) {
277 CacheLookup::Fresh(cached) => {
278 return Self::validate_cached(cache, hash, cached);
279 }
280 CacheLookup::Stale(cached) => {
281 let result = Self::validate_cached(cache, hash, cached);
282 #[cfg(feature = "tokio-runtime")]
286 self.spawn_revalidation(cache, hash);
287 return result;
288 }
289 CacheLookup::Miss => { }
290 }
291 }
292
293 let record = match self.store.find_by_hash(hash).await? {
295 Some(record) => record,
296 None => {
297 if let Some(ref cache) = self.auth_cache {
299 cache.insert_negative(hash);
300 }
301 return Err(AuthError::Invalid);
302 }
303 };
304
305 Self::validate_record(&record)?;
307
308 if let Some(ref cache) = self.auth_cache {
310 if let Some(ref uid) = record.user_id {
311 cache.clear_traffic_delta(uid);
312 }
313 let cached_user = CachedUser {
314 user_id: record.user_id.clone(),
315 traffic_limit: record.traffic_limit,
316 traffic_used: record.traffic_used,
317 expires_at: record.expires_at,
318 enabled: record.enabled,
319 cached_at: Instant::now(),
320 };
321 cache.insert(hash.to_string(), cached_user);
322 }
323
324 Ok(Self::record_to_result(&record))
325 }
326
327 async fn record_traffic(&self, user_id: &str, bytes: u64) -> Result<(), AuthError> {
328 if let Some(ref cache) = self.auth_cache {
330 cache.add_traffic_delta(user_id, bytes);
331 }
332
333 match self.traffic_mode {
335 TrafficRecordingMode::Immediate => self.store.add_traffic(user_id, bytes).await,
336 TrafficRecordingMode::Batched => {
337 #[cfg(feature = "batched-traffic")]
338 if let Some(ref recorder) = self.traffic_recorder {
339 recorder.record(user_id.to_string(), bytes);
340 }
341 Ok(())
342 }
343 TrafficRecordingMode::Disabled => Ok(()),
344 }
345 }
346}
347
348impl<S: UserStore + std::fmt::Debug> std::fmt::Debug for StoreAuth<S> {
349 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350 f.debug_struct("StoreAuth")
351 .field("store", &self.store)
352 .field("traffic_mode", &self.traffic_mode)
353 .field("cache_enabled", &self.auth_cache.is_some())
354 .finish_non_exhaustive()
355 }
356}