trojan_auth/store/
auth.rs1use std::time::{Instant, SystemTime, UNIX_EPOCH};
10
11use async_trait::async_trait;
12
13use crate::error::AuthError;
14use crate::result::{AuthMetadata, AuthResult};
15use crate::traits::AuthBackend;
16
17use super::cache::{AuthCache, CacheStats, CachedUser};
18use super::config::{StoreAuthConfig, TrafficRecordingMode};
19use super::record::UserRecord;
20use super::traits::UserStore;
21
22#[cfg(feature = "batched-traffic")]
23use super::traffic::TrafficRecorder;
24
25pub struct StoreAuth<S: UserStore> {
34 store: S,
35 auth_cache: Option<AuthCache>,
36 #[cfg(feature = "batched-traffic")]
37 traffic_recorder: Option<TrafficRecorder>,
38 traffic_mode: TrafficRecordingMode,
39}
40
41impl<S: UserStore> StoreAuth<S> {
42 pub fn new(store: S, config: &StoreAuthConfig) -> Self {
47 let auth_cache = if config.cache_enabled {
48 Some(AuthCache::new(config.cache_ttl, config.neg_cache_ttl))
49 } else {
50 None
51 };
52
53 Self {
54 store,
55 auth_cache,
56 #[cfg(feature = "batched-traffic")]
57 traffic_recorder: None,
58 traffic_mode: config.traffic_mode,
59 }
60 }
61
62 #[cfg(feature = "batched-traffic")]
64 pub fn with_traffic_recorder(mut self, recorder: TrafficRecorder) -> Self {
65 self.traffic_recorder = Some(recorder);
66 self
67 }
68
69 pub fn store(&self) -> &S {
71 &self.store
72 }
73
74 pub fn cache_enabled(&self) -> bool {
76 self.auth_cache.is_some()
77 }
78
79 pub fn cache_stats(&self) -> Option<CacheStats> {
81 self.auth_cache.as_ref().map(|c| c.stats())
82 }
83
84 pub fn cache_invalidate(&self, hash: &str) {
89 if let Some(ref cache) = self.auth_cache {
90 cache.remove(hash);
91 cache.remove_negative(hash);
92 }
93 }
94
95 pub fn cache_invalidate_user(&self, user_id: &str) {
97 if let Some(ref cache) = self.auth_cache {
98 cache.invalidate_user(user_id);
99 }
100 }
101
102 pub fn cache_clear(&self) {
104 if let Some(ref cache) = self.auth_cache {
105 cache.clear();
106 }
107 }
108
109 #[inline]
111 #[allow(
112 clippy::cast_possible_wrap,
113 clippy::cast_sign_loss,
114 clippy::cast_possible_truncation
115 )]
116 fn now_unix() -> i64 {
117 SystemTime::now()
118 .duration_since(UNIX_EPOCH)
119 .map(|d| d.as_secs() as i64)
120 .unwrap_or(0)
121 }
122
123 fn validate_record(record: &UserRecord) -> Result<(), AuthError> {
127 if !record.enabled {
128 return Err(AuthError::Disabled);
129 }
130
131 let now = Self::now_unix();
132 if record.expires_at > 0 && now >= record.expires_at {
133 return Err(AuthError::Expired);
134 }
135
136 if record.traffic_limit > 0 && record.traffic_used >= record.traffic_limit {
137 return Err(AuthError::TrafficExceeded);
138 }
139
140 Ok(())
141 }
142
143 #[allow(
145 clippy::cast_possible_wrap,
146 clippy::cast_sign_loss,
147 clippy::cast_possible_truncation
148 )]
149 fn record_to_result(record: &UserRecord) -> AuthResult {
150 let metadata = AuthMetadata {
151 traffic_limit: record.traffic_limit as u64,
152 traffic_used: record.traffic_used as u64,
153 expires_at: record.expires_at as u64,
154 enabled: record.enabled,
155 };
156
157 AuthResult {
158 user_id: record.user_id.clone(),
159 metadata: Some(metadata),
160 }
161 }
162}
163
164#[async_trait]
165impl<S: UserStore> AuthBackend for StoreAuth<S> {
166 async fn verify(&self, hash: &str) -> Result<AuthResult, AuthError> {
167 if let Some(ref cache) = self.auth_cache {
168 if cache.is_negative(hash) {
170 return Err(AuthError::Invalid);
171 }
172
173 if let Some(cached) = cache.get(hash) {
175 let delta = cached
176 .user_id
177 .as_deref()
178 .map(|uid| cache.get_traffic_delta(uid))
179 .unwrap_or(0);
180
181 let mut record = UserRecord::from(cached);
182 record.traffic_used += delta;
183
184 if let Err(e) = Self::validate_record(&record) {
185 if matches!(e, AuthError::Expired) {
187 cache.remove(hash);
188 }
189 return Err(e);
190 }
191
192 return Ok(Self::record_to_result(&record));
193 }
194 }
195
196 let record = match self.store.find_by_hash(hash).await? {
198 Some(record) => record,
199 None => {
200 if let Some(ref cache) = self.auth_cache {
202 cache.insert_negative(hash);
203 }
204 return Err(AuthError::Invalid);
205 }
206 };
207
208 Self::validate_record(&record)?;
210
211 if let Some(ref cache) = self.auth_cache {
213 if let Some(ref uid) = record.user_id {
214 cache.clear_traffic_delta(uid);
215 }
216 let cached_user = CachedUser {
217 user_id: record.user_id.clone(),
218 traffic_limit: record.traffic_limit,
219 traffic_used: record.traffic_used,
220 expires_at: record.expires_at,
221 enabled: record.enabled,
222 cached_at: Instant::now(),
223 };
224 cache.insert(hash.to_string(), cached_user);
225 }
226
227 Ok(Self::record_to_result(&record))
228 }
229
230 async fn record_traffic(&self, user_id: &str, bytes: u64) -> Result<(), AuthError> {
231 if let Some(ref cache) = self.auth_cache {
233 cache.add_traffic_delta(user_id, bytes);
234 }
235
236 match self.traffic_mode {
238 TrafficRecordingMode::Immediate => self.store.add_traffic(user_id, bytes).await,
239 TrafficRecordingMode::Batched => {
240 #[cfg(feature = "batched-traffic")]
241 if let Some(ref recorder) = self.traffic_recorder {
242 recorder.record(user_id.to_string(), bytes);
243 }
244 Ok(())
245 }
246 TrafficRecordingMode::Disabled => Ok(()),
247 }
248 }
249}
250
251impl<S: UserStore + std::fmt::Debug> std::fmt::Debug for StoreAuth<S> {
252 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253 f.debug_struct("StoreAuth")
254 .field("store", &self.store)
255 .field("traffic_mode", &self.traffic_mode)
256 .field("cache_enabled", &self.auth_cache.is_some())
257 .finish_non_exhaustive()
258 }
259}