trojan_auth/store/
auth.rs1use std::sync::Arc;
12use std::time::{Instant, SystemTime, UNIX_EPOCH};
13
14use async_trait::async_trait;
15
16use crate::error::AuthError;
17use crate::result::{AuthMetadata, AuthResult};
18use crate::traits::AuthBackend;
19
20use super::cache::{AuthCache, CacheLookup, CacheStats, CachedUser};
21use super::config::{StoreAuthConfig, TrafficRecordingMode};
22use super::record::UserRecord;
23use super::traits::UserStore;
24
25#[cfg(feature = "batched-traffic")]
26use super::traffic::TrafficRecorder;
27
28pub struct StoreAuth<S: UserStore> {
37 store: Arc<S>,
38 auth_cache: Option<Arc<AuthCache>>,
39 #[cfg(feature = "batched-traffic")]
40 traffic_recorder: Option<TrafficRecorder>,
41 traffic_mode: TrafficRecordingMode,
42}
43
44impl<S: UserStore> StoreAuth<S> {
45 pub fn new(store: S, config: &StoreAuthConfig) -> Self {
50 let auth_cache = if config.cache_enabled {
51 Some(Arc::new(AuthCache::new(
52 config.cache_ttl,
53 config.stale_ttl,
54 config.neg_cache_ttl,
55 )))
56 } else {
57 None
58 };
59
60 Self {
61 store: Arc::new(store),
62 auth_cache,
63 #[cfg(feature = "batched-traffic")]
64 traffic_recorder: None,
65 traffic_mode: config.traffic_mode,
66 }
67 }
68
69 #[cfg(feature = "batched-traffic")]
71 pub fn with_traffic_recorder(mut self, recorder: TrafficRecorder) -> Self {
72 self.traffic_recorder = Some(recorder);
73 self
74 }
75
76 pub fn store(&self) -> &S {
78 &self.store
79 }
80
81 pub fn cache_enabled(&self) -> bool {
83 self.auth_cache.is_some()
84 }
85
86 pub fn cache_stats(&self) -> Option<CacheStats> {
88 self.auth_cache.as_ref().map(|c| c.stats())
89 }
90
91 pub fn cache_invalidate(&self, hash: &str) {
96 if let Some(ref cache) = self.auth_cache {
97 cache.remove(hash);
98 cache.remove_negative(hash);
99 }
100 }
101
102 pub fn cache_invalidate_user(&self, user_id: &str) {
104 if let Some(ref cache) = self.auth_cache {
105 cache.invalidate_user(user_id);
106 }
107 }
108
109 pub fn cache_clear(&self) {
111 if let Some(ref cache) = self.auth_cache {
112 cache.clear();
113 }
114 }
115
116 #[inline]
118 #[allow(
119 clippy::cast_possible_wrap,
120 clippy::cast_sign_loss,
121 clippy::cast_possible_truncation
122 )]
123 fn now_unix() -> i64 {
124 SystemTime::now()
125 .duration_since(UNIX_EPOCH)
126 .map(|d| d.as_secs() as i64)
127 .unwrap_or(0)
128 }
129
130 fn validate_record(record: &UserRecord) -> Result<(), AuthError> {
134 if !record.enabled {
135 return Err(AuthError::Disabled);
136 }
137
138 let now = Self::now_unix();
139 if record.expires_at > 0 && now >= record.expires_at {
140 return Err(AuthError::Expired);
141 }
142
143 if record.traffic_limit > 0 && record.traffic_used >= record.traffic_limit {
144 return Err(AuthError::TrafficExceeded);
145 }
146
147 Ok(())
148 }
149
150 fn validate_cached(
155 cache: &AuthCache,
156 hash: &str,
157 cached: CachedUser,
158 ) -> Result<AuthResult, AuthError> {
159 let delta = cached
160 .user_id
161 .as_deref()
162 .map(|uid| cache.get_traffic_delta(uid))
163 .unwrap_or(0);
164
165 let mut record = UserRecord::from(cached);
166 record.traffic_used += delta;
167
168 if let Err(e) = Self::validate_record(&record) {
169 if matches!(e, AuthError::Expired) {
170 cache.remove(hash);
171 }
172 return Err(e);
173 }
174
175 Ok(Self::record_to_result(&record))
176 }
177
178 #[allow(
180 clippy::cast_possible_wrap,
181 clippy::cast_sign_loss,
182 clippy::cast_possible_truncation
183 )]
184 fn record_to_result(record: &UserRecord) -> AuthResult {
185 let metadata = AuthMetadata {
186 traffic_limit: record.traffic_limit as u64,
187 traffic_used: record.traffic_used as u64,
188 expires_at: record.expires_at as u64,
189 enabled: record.enabled,
190 };
191
192 AuthResult {
193 user_id: record.user_id.clone(),
194 metadata: Some(metadata),
195 }
196 }
197}
198
#[cfg(feature = "tokio-runtime")]
impl<S: UserStore + 'static> StoreAuth<S> {
    /// Schedules a background refresh of the cache entry for `hash`.
    ///
    /// At most one refresh per hash is in flight: nothing is spawned unless
    /// `AuthCache::start_revalidation` accepts the hash. The matching
    /// `finish_revalidation` runs when the task ends, including when it
    /// panics or is dropped unfinished. If the current thread has no Tokio
    /// runtime, the refresh is skipped instead of panicking.
    fn spawn_revalidation(&self, cache: &Arc<AuthCache>, hash: &str) {
        /// Releases the in-flight revalidation flag when dropped.
        struct RevalidationGuard {
            cache: Arc<AuthCache>,
            hash: String,
        }

        impl Drop for RevalidationGuard {
            fn drop(&mut self) {
                self.cache.finish_revalidation(&self.hash);
            }
        }

        // `tokio::spawn` panics outside a runtime, and `verify` is an async fn
        // that another executor could be polling.
        let handle = match tokio::runtime::Handle::try_current() {
            Ok(handle) => handle,
            Err(_) => {
                tracing::debug!("no tokio runtime available; skipping background revalidation");
                return;
            }
        };

        if !cache.start_revalidation(hash) {
            return;
        }

        let guard = RevalidationGuard {
            cache: Arc::clone(cache),
            hash: hash.to_string(),
        };
        let store = Arc::clone(&self.store);
        let task_cache = Arc::clone(cache);
        let task_hash = hash.to_string();
        handle.spawn(async move {
            // Held for the whole task; dropped on completion, panic or cancel.
            let _guard = guard;
            Self::revalidate(store, task_cache, task_hash).await;
        });
    }

    /// Re-reads `hash` from `store` and updates `cache` accordingly.
    ///
    /// - Found: the user's traffic delta is cleared and a fresh entry is
    ///   inserted (validation happens later, on lookup).
    /// - Not found: the positive entry is removed and a negative one inserted.
    /// - Store error: logged, and the positive entry is evicted.
    async fn revalidate(store: Arc<S>, cache: Arc<AuthCache>, hash: String) {
        match store.find_by_hash(&hash).await {
            Ok(Some(record)) => {
                // NOTE(review): if batched traffic has not reached the store
                // yet, clearing the delta here drops it from the cached view.
                if let Some(ref uid) = record.user_id {
                    cache.clear_traffic_delta(uid);
                }
                let cached_user = CachedUser {
                    user_id: record.user_id.clone(),
                    traffic_limit: record.traffic_limit,
                    traffic_used: record.traffic_used,
                    expires_at: record.expires_at,
                    enabled: record.enabled,
                    cached_at: Instant::now(),
                };
                cache.insert(hash, cached_user);
            }
            Ok(None) => {
                cache.remove(&hash);
                cache.insert_negative(&hash);
            }
            Err(e) => {
                // `verify` accepts the hash itself as the credential, so never
                // log it in full; a short prefix suffices for correlation.
                let hash_prefix = hash.get(..8).unwrap_or("");
                tracing::warn!(hash_prefix = %hash_prefix, error = %e, "background revalidation failed");
                cache.remove(&hash);
            }
        }
    }
}
245
246#[async_trait]
247impl<S: UserStore + 'static> AuthBackend for StoreAuth<S> {
248 async fn verify(&self, hash: &str) -> Result<AuthResult, AuthError> {
249 if let Some(ref cache) = self.auth_cache {
250 if cache.is_negative(hash) {
252 return Err(AuthError::Invalid);
253 }
254
255 match cache.lookup(hash) {
257 CacheLookup::Fresh(cached) => {
258 return Self::validate_cached(cache, hash, cached);
259 }
260 CacheLookup::Stale(cached) => {
261 let result = Self::validate_cached(cache, hash, cached);
262 #[cfg(feature = "tokio-runtime")]
266 self.spawn_revalidation(cache, hash);
267 return result;
268 }
269 CacheLookup::Miss => { }
270 }
271 }
272
273 let record = match self.store.find_by_hash(hash).await? {
275 Some(record) => record,
276 None => {
277 if let Some(ref cache) = self.auth_cache {
279 cache.insert_negative(hash);
280 }
281 return Err(AuthError::Invalid);
282 }
283 };
284
285 Self::validate_record(&record)?;
287
288 if let Some(ref cache) = self.auth_cache {
290 if let Some(ref uid) = record.user_id {
291 cache.clear_traffic_delta(uid);
292 }
293 let cached_user = CachedUser {
294 user_id: record.user_id.clone(),
295 traffic_limit: record.traffic_limit,
296 traffic_used: record.traffic_used,
297 expires_at: record.expires_at,
298 enabled: record.enabled,
299 cached_at: Instant::now(),
300 };
301 cache.insert(hash.to_string(), cached_user);
302 }
303
304 Ok(Self::record_to_result(&record))
305 }
306
307 async fn record_traffic(&self, user_id: &str, bytes: u64) -> Result<(), AuthError> {
308 if let Some(ref cache) = self.auth_cache {
310 cache.add_traffic_delta(user_id, bytes);
311 }
312
313 match self.traffic_mode {
315 TrafficRecordingMode::Immediate => self.store.add_traffic(user_id, bytes).await,
316 TrafficRecordingMode::Batched => {
317 #[cfg(feature = "batched-traffic")]
318 if let Some(ref recorder) = self.traffic_recorder {
319 recorder.record(user_id.to_string(), bytes);
320 }
321 Ok(())
322 }
323 TrafficRecordingMode::Disabled => Ok(()),
324 }
325 }
326}
327
328impl<S: UserStore + std::fmt::Debug> std::fmt::Debug for StoreAuth<S> {
329 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330 f.debug_struct("StoreAuth")
331 .field("store", &self.store)
332 .field("traffic_mode", &self.traffic_mode)
333 .field("cache_enabled", &self.auth_cache.is_some())
334 .finish_non_exhaustive()
335 }
336}