1use anyhow::{Context, Result};
9use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
10use ring::{hmac, pbkdf2, rand as ring_rand};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::num::NonZeroU32;
14use std::path::Path;
15use std::sync::Arc;
16
/// PBKDF2-HMAC-SHA256 iteration count used both when hashing and when verifying passwords.
///
/// The count is not stored alongside each hash, so changing it invalidates every
/// existing entry in the credentials file.
const PBKDF2_ITERATIONS: u32 = 100 * 1_000;

/// Number of random salt bytes placed in front of each stored password hash.
const SALT_LENGTH: usize = 16;

/// Length in bytes of the PBKDF2 output stored after the salt (256 bits).
const CREDENTIAL_LENGTH: usize = 256 / 8;

/// Token lifetime used when no explicit expiry is configured: 24 hours, in seconds.
const DEFAULT_TOKEN_EXPIRY_SECS: u64 = 24 * 60 * 60;
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct UserCredential {
32 pub username: String,
34 pub password_hash: String,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40struct TokenPayload {
41 pub sub: String,
43 pub iat: u64,
45 pub exp: u64,
47 pub jti: String,
49}
50
51#[derive(Clone)]
53pub struct AuthService {
54 credentials: Arc<HashMap<String, UserCredential>>,
56 signing_key: Arc<hmac::Key>,
58 token_expiry_secs: u64,
60 enabled: bool,
62}
63
64impl std::fmt::Debug for AuthService {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("AuthService")
67 .field("credentials_count", &self.credentials.len())
68 .field("token_expiry_secs", &self.token_expiry_secs)
69 .field("enabled", &self.enabled)
70 .finish()
71 }
72}
73
74impl AuthService {
75 pub fn new(
83 credentials_file: Option<&Path>,
84 token_secret: Option<&str>,
85 token_expiry_secs: Option<u64>,
86 enabled: bool,
87 ) -> Result<Self> {
88 let credentials = if let Some(path) = credentials_file {
90 Self::load_credentials(path)?
91 } else {
92 HashMap::new()
93 };
94
95 let signing_key = if let Some(secret) = token_secret {
97 hmac::Key::new(hmac::HMAC_SHA256, secret.as_bytes())
98 } else {
99 let rng = ring_rand::SystemRandom::new();
100 hmac::Key::generate(hmac::HMAC_SHA256, &rng)
101 .map_err(|_| anyhow::anyhow!("Failed to generate signing key"))?
102 };
103
104 Ok(Self {
105 credentials: Arc::new(credentials),
106 signing_key: Arc::new(signing_key),
107 token_expiry_secs: token_expiry_secs.unwrap_or(DEFAULT_TOKEN_EXPIRY_SECS),
108 enabled,
109 })
110 }
111
112 pub fn disabled() -> Self {
114 Self {
115 credentials: Arc::new(HashMap::new()),
116 signing_key: Arc::new(hmac::Key::new(hmac::HMAC_SHA256, b"disabled-auth-not-used")),
117 token_expiry_secs: DEFAULT_TOKEN_EXPIRY_SECS,
118 enabled: false,
119 }
120 }
121
122 pub fn is_enabled(&self) -> bool {
124 self.enabled
125 }
126
127 fn load_credentials(path: &Path) -> Result<HashMap<String, UserCredential>> {
129 let content = std::fs::read_to_string(path)
130 .with_context(|| format!("Failed to read credentials file: {}", path.display()))?;
131
132 let credentials: Vec<UserCredential> = serde_json::from_str(&content)
133 .with_context(|| format!("Failed to parse credentials file: {}", path.display()))?;
134
135 let mut map = HashMap::new();
136 for cred in credentials {
137 map.insert(cred.username.clone(), cred);
138 }
139
140 tracing::info!("Loaded {} user credentials", map.len());
141 Ok(map)
142 }
143
144 pub fn verify_password(&self, username: &str, password: &str) -> bool {
146 let Some(credential) = self.credentials.get(username) else {
147 return false;
148 };
149
150 let Ok(stored_bytes) = URL_SAFE_NO_PAD.decode(&credential.password_hash) else {
152 tracing::warn!("Invalid base64 in password hash for user: {}", username);
153 return false;
154 };
155
156 if stored_bytes.len() != SALT_LENGTH + CREDENTIAL_LENGTH {
157 tracing::warn!("Invalid password hash length for user: {}", username);
158 return false;
159 }
160
161 let (salt, stored_hash) = stored_bytes.split_at(SALT_LENGTH);
162
163 pbkdf2::verify(
165 pbkdf2::PBKDF2_HMAC_SHA256,
166 NonZeroU32::new(PBKDF2_ITERATIONS).unwrap(),
167 salt,
168 password.as_bytes(),
169 stored_hash,
170 )
171 .is_ok()
172 }
173
174 pub fn generate_token(&self, username: &str) -> Result<String> {
176 let now = std::time::SystemTime::now()
177 .duration_since(std::time::UNIX_EPOCH)
178 .context("System time before Unix epoch")?
179 .as_secs();
180
181 let payload = TokenPayload {
182 sub: username.to_string(),
183 iat: now,
184 exp: now + self.token_expiry_secs,
185 jti: uuid::Uuid::new_v4().to_string(),
186 };
187
188 let payload_json = serde_json::to_string(&payload)?;
190 let payload_b64 = URL_SAFE_NO_PAD.encode(payload_json.as_bytes());
191
192 let signature = hmac::sign(&self.signing_key, payload_b64.as_bytes());
194 let signature_b64 = URL_SAFE_NO_PAD.encode(signature.as_ref());
195
196 Ok(format!("{}.{}", payload_b64, signature_b64))
198 }
199
200 pub fn validate_token(&self, token: &str) -> Option<String> {
202 let parts: Vec<&str> = token.split('.').collect();
203 if parts.len() != 2 {
204 return None;
205 }
206
207 let payload_b64 = parts[0];
208 let signature_b64 = parts[1];
209
210 let Ok(signature_bytes) = URL_SAFE_NO_PAD.decode(signature_b64) else {
212 return None;
213 };
214
215 if hmac::verify(&self.signing_key, payload_b64.as_bytes(), &signature_bytes).is_err() {
216 return None;
217 }
218
219 let Ok(payload_json) = URL_SAFE_NO_PAD.decode(payload_b64) else {
221 return None;
222 };
223
224 let Ok(payload): Result<TokenPayload, _> = serde_json::from_slice(&payload_json) else {
225 return None;
226 };
227
228 let now = std::time::SystemTime::now()
230 .duration_since(std::time::UNIX_EPOCH)
231 .ok()?
232 .as_secs();
233
234 if now > payload.exp {
235 return None;
236 }
237
238 Some(payload.sub)
239 }
240
241 pub fn hash_password(password: &str) -> Result<String> {
244 let rng = ring_rand::SystemRandom::new();
245
246 let mut salt = [0u8; SALT_LENGTH];
248 ring_rand::SecureRandom::fill(&rng, &mut salt)
249 .map_err(|_| anyhow::anyhow!("Failed to generate salt"))?;
250
251 let mut derived_key = [0u8; CREDENTIAL_LENGTH];
253 pbkdf2::derive(
254 pbkdf2::PBKDF2_HMAC_SHA256,
255 NonZeroU32::new(PBKDF2_ITERATIONS).unwrap(),
256 &salt,
257 password.as_bytes(),
258 &mut derived_key,
259 );
260
261 let mut combined = Vec::with_capacity(SALT_LENGTH + CREDENTIAL_LENGTH);
263 combined.extend_from_slice(&salt);
264 combined.extend_from_slice(&derived_key);
265
266 Ok(URL_SAFE_NO_PAD.encode(&combined))
267 }
268}
269
270#[derive(Debug, Clone, Serialize, Deserialize)]
272pub struct TokenRequest {
273 pub username: String,
275 pub password: String,
277}
278
279#[derive(Debug, Clone, Serialize, Deserialize)]
281pub struct TokenResponse {
282 pub token: String,
284 pub token_type: String,
286 pub expires_in: u64,
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// An enabled service with a fixed secret, no credentials file, and the given expiry.
    fn service_with_expiry(expiry_secs: u64) -> AuthService {
        AuthService::new(None, Some("test_secret"), Some(expiry_secs), true).unwrap()
    }

    #[test]
    fn test_password_hashing() {
        let encoded = AuthService::hash_password("my_secret_password").unwrap();
        let raw = URL_SAFE_NO_PAD.decode(&encoded).unwrap();

        // The stored value is exactly the salt followed by the derived key.
        assert_eq!(raw.len(), SALT_LENGTH + CREDENTIAL_LENGTH);
    }

    #[test]
    fn test_password_verification() {
        let secret = "test_password_123";
        let records = vec![UserCredential {
            username: "testuser".to_string(),
            password_hash: AuthService::hash_password(secret).unwrap(),
        }];
        let json = serde_json::to_string(&records).unwrap();

        let mut creds_file = NamedTempFile::new().unwrap();
        creds_file.write_all(json.as_bytes()).unwrap();

        let auth =
            AuthService::new(Some(creds_file.path()), Some("test_secret"), Some(3600), true)
                .unwrap();

        // Right password for a known user is accepted.
        assert!(auth.verify_password("testuser", secret));
        // Wrong password for a known user is rejected.
        assert!(!auth.verify_password("testuser", "wrong_password"));
        // Unknown user is rejected even with a valid password.
        assert!(!auth.verify_password("unknown", secret));
    }

    #[test]
    fn test_token_generation_and_validation() {
        let auth = service_with_expiry(3600);
        let token = auth.generate_token("testuser").unwrap();

        assert_eq!(auth.validate_token(&token).as_deref(), Some("testuser"));

        for garbage in ["invalid.token", "notavalidtoken"] {
            assert!(auth.validate_token(garbage).is_none());
        }
    }

    #[test]
    fn test_expired_token() {
        let auth = service_with_expiry(0);
        let token = auth.generate_token("testuser").unwrap();

        // Expiry has one-second resolution; wait past the next second boundary.
        std::thread::sleep(std::time::Duration::from_millis(1100));

        assert!(auth.validate_token(&token).is_none());
    }

    #[test]
    fn test_disabled_auth() {
        assert!(!AuthService::disabled().is_enabled());
    }

    #[test]
    fn test_token_tampering() {
        let auth = service_with_expiry(3600);
        let token = auth.generate_token("testuser").unwrap();
        let (_, genuine_signature) = token.split_once('.').unwrap();

        // Reuse a genuine signature with a forged payload.
        let forged_payload = URL_SAFE_NO_PAD
            .encode(b"{\"sub\":\"admin\",\"iat\":0,\"exp\":9999999999,\"jti\":\"fake\"}");
        let forged_token = format!("{forged_payload}.{genuine_signature}");

        assert!(auth.validate_token(&forged_token).is_none());
    }
}